diff --git a/ICtest.py b/ICtest.py index b1080bb..527200e 100644 --- a/ICtest.py +++ b/ICtest.py @@ -1,14 +1,12 @@ import pandas as pd import numpy as np -import sys -import importlib -from pathlib import Path from typing import List, Dict, Any, Optional from alpha191.utils import ( load_benchmark_csv, get_benchmark_members, format_alpha_name, - parallel_load_stocks_with_alpha + parallel_load_stocks_with_alpha, + get_alpha_func ) from assessment import get_clean_factor_and_forward_returns, compute_performance_metrics, compute_stability_metrics from datetime import datetime @@ -19,15 +17,10 @@ def assess_alpha(alpha_name: str, benchmark: str = "zz800", horizons: List[int] # Import alpha function try: - alpha_module = importlib.import_module(f"alpha191.{alpha_name}") - func_name = alpha_name[:5] + "_" + alpha_name[5:] - alpha_func = getattr(alpha_module, func_name) - except (ImportError, AttributeError) as e: + alpha_func = get_alpha_func(alpha_name, use_df=True) + except ValueError as e: print(f"Error importing {alpha_name}: {e}") - try: - alpha_func = getattr(alpha_module, alpha_name) - except AttributeError: - return + return codes = get_benchmark_members(benchmark) diff --git a/alpha191/utils.py b/alpha191/utils.py index 71444d6..dca7292 100644 --- a/alpha191/utils.py +++ b/alpha191/utils.py @@ -6,7 +6,7 @@ import pandas as pd from pathlib import Path from functools import lru_cache -from typing import List, Optional +from typing import List, Optional, Union, Any # Cache for benchmark data _benchmark_cache = {} @@ -191,6 +191,61 @@ def format_alpha_name(alpha_name: str) -> str: raise ValueError(f"Invalid alpha name: {alpha_name}. Expected format: '1' or 'alpha001'") +def get_alpha_func(alpha_id: Union[int, str], use_df: bool = False, ignore_errors: bool = False) -> Optional[Any]: + """ + Get the alpha function by number or name. + + Args: + alpha_id: Alpha number (e.g., 17) or name (e.g., "alpha017") + use_df: If True, return the function that takes a DataFrame (alpha_XXX). + If False, return the function that takes code/benchmark (alphaXXX). + ignore_errors: If True, return None on failure instead of raising ValueError. + + Returns: + The alpha function or None if not found and ignore_errors is True. + """ + import importlib + try: + alpha_name = format_alpha_name(str(alpha_id)) + module = importlib.import_module(f"alpha191.{alpha_name}") + + if use_df: + # Try alpha_XXX first (preferred for DataFrame input) + func_name = f"alpha_{int(alpha_name[5:]):03d}" + if hasattr(module, func_name): + return getattr(module, func_name) + # Fallback to alphaXXX + if hasattr(module, alpha_name): + return getattr(module, alpha_name) + else: + # Try alphaXXX first (preferred for code/benchmark input) + if hasattr(module, alpha_name): + return getattr(module, alpha_name) + # Fallback to alpha_XXX + func_name = f"alpha_{int(alpha_name[5:]):03d}" + if hasattr(module, func_name): + return getattr(module, func_name) + + except (ImportError, ModuleNotFoundError, ValueError): + pass + + if ignore_errors: + return None + raise ValueError(f"Alpha function for '{alpha_id}' not found") + + +def get_stock_codes(benchmark: str) -> List[str]: + """Get list of stock codes available in the benchmark directory.""" + benchmark_dir = PROJECT_ROOT / 'bao' / benchmark + if not benchmark_dir.exists(): + raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}") + + # Get all CSV files and extract stock codes (without .csv extension) + csv_files = sorted(benchmark_dir.glob('*.csv')) + stock_codes = [f.stem for f in csv_files] + return stock_codes + + def _load_single_stock_with_alpha(args): """ Helper function for parallel loading of stock data and alpha computation. diff --git a/calculate_covariance.py b/calculate_covariance.py index 7aab3d1..0d541e8 100644 --- a/calculate_covariance.py +++ b/calculate_covariance.py @@ -1,21 +1,9 @@ -import os -import sys import pandas as pd import numpy as np import time -from pathlib import Path from alpha191 import * -from alpha191.utils import get_benchmark_members, load_stock_csv, load_benchmark_csv - -def get_alpha_func(alpha_num: int): - """Get the alpha function (the one that takes a DataFrame) by number.""" - func_name = f"alpha_{alpha_num:03d}" - # Check in individual modules first or from alpha191 package - import alpha191 - if hasattr(alpha191, func_name): - return getattr(alpha191, func_name) - return None +from alpha191.utils import get_benchmark_members, load_stock_csv, load_benchmark_csv, get_alpha_func, get_stock_codes def main(): benchmark = "hs300" @@ -25,8 +13,7 @@ def main(): except Exception as e: print(f"Error loading benchmark members: {e}") # Fallback to listing directory if needed - benchmark_dir = Path('bao') / benchmark - stock_codes = [f.stem for f in benchmark_dir.glob('*.csv')] + stock_codes = get_stock_codes(benchmark) print(f"Using all {len(stock_codes)} stocks for covariance calculation.") @@ -60,7 +47,7 @@ def main(): if i in missing_alphas: continue - alpha_func = get_alpha_func(i) + alpha_func = get_alpha_func(i, use_df=True, ignore_errors=True) if alpha_func: try: res = alpha_func(df) diff --git a/fulltest.py b/fulltest.py index 99cad48..7f66d6d 100644 --- a/fulltest.py +++ b/fulltest.py @@ -18,28 +18,8 @@ import sys import time -from pathlib import Path from alpha191 import * - - -def get_alpha_func(alpha_num: int): - """Get the alpha function by number (e.g., 17 -> alpha017).""" - func_name = f"alpha{alpha_num:03d}" - if hasattr(sys.modules['alpha191'], func_name): - return getattr(sys.modules['alpha191'], func_name) - raise ValueError(f"Alpha function '{func_name}' not found in alpha191 module") - - -def get_stock_codes(benchmark: str) -> list: - """Get list of stock codes from the benchmark directory.""" - benchmark_dir = Path('bao') / benchmark - if not benchmark_dir.exists(): - raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}") - - # Get all CSV files and extract stock codes (without .csv extension) - csv_files = sorted(benchmark_dir.glob('*.csv')) - stock_codes = [f.stem for f in csv_files] - return stock_codes +from alpha191.utils import get_alpha_func, get_stock_codes def test_alpha(alpha_num: int, benchmark: str, stock_codes: list) -> dict: diff --git a/grouptest.py b/grouptest.py index beb2cf1..1b30b5b 100644 --- a/grouptest.py +++ b/grouptest.py @@ -1,15 +1,13 @@ import pandas as pd import numpy as np -import sys -import importlib -from pathlib import Path from typing import List, Dict, Any, Union from alpha191.utils import ( load_benchmark_csv, get_benchmark_members, format_alpha_name, - parallel_load_stocks_with_alpha + parallel_load_stocks_with_alpha, + get_alpha_func ) from assessment import ( get_clean_factor_and_forward_returns, @@ -32,15 +30,10 @@ def run_group_test(alpha_name: str, horizons: List[int] = [20], benchmark: str = # Import alpha function try: - alpha_module = importlib.import_module(f"alpha191.{alpha_name}") - func_name = alpha_name[:5] + "_" + alpha_name[5:] - alpha_func = getattr(alpha_module, func_name) - except (ImportError, AttributeError) as e: - try: - alpha_func = getattr(alpha_module, alpha_name) - except (AttributeError, NameError): - print(f"Error importing {alpha_name}: {e}") - return + alpha_func = get_alpha_func(alpha_name, use_df=True) + except ValueError as e: + print(f"Error importing {alpha_name}: {e}") + return codes = get_benchmark_members(benchmark) diff --git a/run_alpha_tests.py b/run_alpha_tests.py index 091984c..6de62ff 100644 --- a/run_alpha_tests.py +++ b/run_alpha_tests.py @@ -12,7 +12,6 @@ import pandas as pd import numpy as np -import importlib # Add the current directory to path for imports sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) @@ -21,7 +20,8 @@ load_benchmark_csv, get_benchmark_members, format_alpha_name, - parallel_load_stocks_with_alpha + parallel_load_stocks_with_alpha, + get_alpha_func ) from assessment import ( get_clean_factor_and_forward_returns, @@ -49,15 +49,9 @@ def run_ic_test_for_alpha(alpha_name: str, benchmark: str = "zz800", horizons: l with contextlib.redirect_stdout(output_buffer): try: - alpha_module = importlib.import_module(f"alpha191.{alpha_name}") - func_name = alpha_name[:5] + "_" + alpha_name[5:] - alpha_func = getattr(alpha_module, func_name) - except (ImportError, AttributeError) as e: - print(f"Error importing {alpha_name}: {e}") - try: - alpha_func = getattr(alpha_module, alpha_name) - except AttributeError: - return f"ERROR: Could not load alpha {alpha_name}\n{str(e)}" + alpha_func = get_alpha_func(alpha_name, use_df=True) + except ValueError as e: + return f"ERROR: Could not load alpha {alpha_name}\n{str(e)}" codes = get_benchmark_members(benchmark) benchmark_df = load_benchmark_csv(benchmark) @@ -195,15 +189,10 @@ def run_group_test_for_alpha(alpha_name: str, benchmark: str = "zz800", horizons with contextlib.redirect_stdout(output_buffer): try: - alpha_module = importlib.import_module(f"alpha191.{alpha_name}") - func_name = alpha_name[:5] + "_" + alpha_name[5:] - alpha_func = getattr(alpha_module, func_name) - except (ImportError, AttributeError) as e: - try: - alpha_func = getattr(alpha_module, alpha_name) - except (AttributeError, NameError): - print(f"Error importing {alpha_name}: {e}") - return + alpha_func = get_alpha_func(alpha_name, use_df=True) + except ValueError as e: + print(f"Error importing {alpha_name}: {e}") + return codes = get_benchmark_members(benchmark) benchmark_df = load_benchmark_csv(benchmark) diff --git a/select_alphas.py b/select_alphas.py index 252f79f..1e55e5c 100644 --- a/select_alphas.py +++ b/select_alphas.py @@ -12,7 +12,6 @@ import pandas as pd import numpy as np -import importlib import sys import os from pathlib import Path @@ -26,7 +25,8 @@ from alpha191.utils import ( load_benchmark_csv, get_benchmark_members, - load_stock_csv + load_stock_csv, + get_alpha_func ) from assessment import get_clean_factor_and_forward_returns, compute_performance_metrics_light @@ -38,22 +38,6 @@ warnings.filterwarnings("ignore") -def get_alpha_function(alpha_name): - """Dynamically import alpha function.""" - try: - module = importlib.import_module(f"alpha191.{alpha_name}") - func_name1 = alpha_name[:5] + "_" + alpha_name[5:] # alpha_001 - func_name2 = alpha_name # alpha001 - - if hasattr(module, func_name1): - return getattr(module, func_name1) - elif hasattr(module, func_name2): - return getattr(module, func_name2) - except (ImportError, ModuleNotFoundError): - pass - return None - - def preload_data(benchmark): """Load all stock data into memory using float32.""" benchmark_df = load_benchmark_csv(benchmark) @@ -106,7 +90,7 @@ def compute_alpha_for_all_stocks(alpha_func, stock_cache): def process_one_alpha(alpha_name, stock_cache, timeline): """Compute metrics for one alpha. Returns dict or None.""" - alpha_func = get_alpha_function(alpha_name) + alpha_func = get_alpha_func(alpha_name, use_df=True, ignore_errors=True) if not alpha_func: return None diff --git a/speedtest.py b/speedtest.py index 41d279f..65d27a9 100644 --- a/speedtest.py +++ b/speedtest.py @@ -15,28 +15,8 @@ import sys import time -from pathlib import Path from alpha191 import * - - -def get_alpha_func(alpha_num: int): - """Get the alpha function by number (e.g., 17 -> alpha017).""" - func_name = f"alpha{alpha_num:03d}" - if hasattr(sys.modules['alpha191'], func_name): - return getattr(sys.modules['alpha191'], func_name) - raise ValueError(f"Alpha function '{func_name}' not found in alpha191 module") - - -def get_stock_codes(benchmark: str) -> list: - """Get list of stock codes from the benchmark directory.""" - benchmark_dir = Path('bao') / benchmark - if not benchmark_dir.exists(): - raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}") - - # Get all CSV files and extract stock codes (without .csv extension) - csv_files = sorted(benchmark_dir.glob('*.csv')) - stock_codes = [f.stem for f in csv_files] - return stock_codes +from alpha191.utils import get_alpha_func, get_stock_codes def main(): diff --git a/tests/test_alphas.py b/tests/test_alphas.py index b39cba8..75a5ea6 100644 --- a/tests/test_alphas.py +++ b/tests/test_alphas.py @@ -992,6 +992,46 @@ def test_format_alpha_name_invalid(self): with self.assertRaises(ValueError): format_alpha_name("a123") + def test_get_alpha_func(self): + """Test get_alpha_func retrieves correct functions.""" + from alpha191.utils import get_alpha_func + # Test getting code version (alphaXXX) + func = get_alpha_func(1, use_df=False) + self.assertTrue(callable(func)) + self.assertEqual(func.__name__, "alpha001") + + # Test getting df version (alpha_XXX) + func_df = get_alpha_func(1, use_df=True) + self.assertTrue(callable(func_df)) + self.assertEqual(func_df.__name__, "alpha_001") + + # Test with string name + func2 = get_alpha_func("alpha002") + self.assertEqual(func2.__name__, "alpha002") + + # Test ignore_errors + func_none = get_alpha_func(999, ignore_errors=True) + self.assertIsNone(func_none) + + # Test raising error + with self.assertRaises(ValueError): + get_alpha_func(999, ignore_errors=False) + + def test_get_stock_codes(self): + """Test get_stock_codes lists files correctly.""" + from alpha191.utils import get_stock_codes + from pathlib import Path + + # This test depends on the existence of bao/hs300 directory + if Path("bao/hs300").exists(): + codes = get_stock_codes("hs300") + self.assertIsInstance(codes, list) + if len(codes) > 0: + self.assertTrue(isinstance(codes[0], str)) + else: + with self.assertRaises(FileNotFoundError): + get_stock_codes("nonexistent_benchmark") + if __name__ == '__main__': unittest.main()