Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions ICtest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import pandas as pd
import numpy as np
import sys
import importlib
from pathlib import Path
from typing import List, Dict, Any, Optional
from alpha191.utils import (
load_benchmark_csv,
get_benchmark_members,
format_alpha_name,
parallel_load_stocks_with_alpha
parallel_load_stocks_with_alpha,
get_alpha_func
)
from assessment import get_clean_factor_and_forward_returns, compute_performance_metrics, compute_stability_metrics
from datetime import datetime
Expand All @@ -19,15 +17,10 @@ def assess_alpha(alpha_name: str, benchmark: str = "zz800", horizons: List[int]

# Import alpha function
try:
alpha_module = importlib.import_module(f"alpha191.{alpha_name}")
func_name = alpha_name[:5] + "_" + alpha_name[5:]
alpha_func = getattr(alpha_module, func_name)
except (ImportError, AttributeError) as e:
alpha_func = get_alpha_func(alpha_name, use_df=True)
except ValueError as e:
print(f"Error importing {alpha_name}: {e}")
try:
alpha_func = getattr(alpha_module, alpha_name)
except AttributeError:
return
return

codes = get_benchmark_members(benchmark)

Expand Down
57 changes: 56 additions & 1 deletion alpha191/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from pathlib import Path
from functools import lru_cache
from typing import List, Optional
from typing import List, Optional, Union, Any

# Cache for benchmark data
_benchmark_cache = {}
Expand Down Expand Up @@ -191,6 +191,61 @@ def format_alpha_name(alpha_name: str) -> str:
raise ValueError(f"Invalid alpha name: {alpha_name}. Expected format: '1' or 'alpha001'")


def get_alpha_func(alpha_id: Union[int, str], use_df: bool = False, ignore_errors: bool = False) -> Optional[Any]:
"""
Get the alpha function by number or name.

Args:
alpha_id: Alpha number (e.g., 17) or name (e.g., "alpha017")
use_df: If True, return the function that takes a DataFrame (alpha_XXX).
If False, return the function that takes code/benchmark (alphaXXX).
ignore_errors: If True, return None on failure instead of raising ValueError.

Returns:
The alpha function or None if not found and ignore_errors is True.
"""
import importlib
try:
alpha_name = format_alpha_name(str(alpha_id))
module = importlib.import_module(f"alpha191.{alpha_name}")

if use_df:
# Try alpha_XXX first (preferred for DataFrame input)
func_name = f"alpha_{int(alpha_name[5:]):03d}"
if hasattr(module, func_name):
return getattr(module, func_name)
# Fallback to alphaXXX
if hasattr(module, alpha_name):
return getattr(module, alpha_name)
else:
# Try alphaXXX first (preferred for code/benchmark input)
if hasattr(module, alpha_name):
return getattr(module, alpha_name)
# Fallback to alpha_XXX
func_name = f"alpha_{int(alpha_name[5:]):03d}"
if hasattr(module, func_name):
return getattr(module, func_name)

except (ImportError, ModuleNotFoundError, ValueError):
pass

if ignore_errors:
return None
raise ValueError(f"Alpha function for '{alpha_id}' not found")
Comment on lines +194 to +234

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function can be improved for clarity, robustness, and to follow best practices.

  1. Local Import: import importlib is inside the function. According to PEP 8, imports should be at the top of the file unless there's a specific reason like avoiding circular dependencies. Please move import importlib to the top of alpha191/utils.py.
  2. DRY Principle: The logic for finding the function when use_df is True or False is very similar and can be refactored to avoid repetition. You can define the search order based on the use_df flag and then iterate through it.
  3. Error Handling: The except (ImportError, ModuleNotFoundError, ValueError): pass block is too broad. It catches ValueError from format_alpha_name but then raises a generic "not found" error, hiding the original, more specific error. It's better to handle exceptions more granularly to provide clearer error messages.

Here is a suggested refactoring that addresses points 2 and 3 (please also address point 1 separately by moving the import):

def get_alpha_func(alpha_id: Union[int, str], use_df: bool = False, ignore_errors: bool = False) -> Optional[Any]:
    """
    Get the alpha function by number or name.

    Args:
        alpha_id: Alpha number (e.g., 17) or name (e.g., "alpha017")
        use_df: If True, return the function that takes a DataFrame (alpha_XXX).
               If False, return the function that takes code/benchmark (alphaXXX).
        ignore_errors: If True, return None on failure instead of raising ValueError.

    Returns:
        The alpha function or None if not found and ignore_errors is True.
    """
    try:
        alpha_name = format_alpha_name(str(alpha_id))
    except ValueError as e:
        if ignore_errors:
            return None
        raise ValueError(f"Invalid alpha name format for '{alpha_id}'") from e

    try:
        module = importlib.import_module(f"alpha191.{alpha_name}")
    except (ImportError, ModuleNotFoundError) as e:
        if ignore_errors:
            return None
        raise ValueError(f"Could not import module for alpha '{alpha_id}'") from e

    df_func_name = f"alpha_{int(alpha_name[5:]):03d}"
    code_func_name = alpha_name

    # Determine search order based on use_df flag
    search_order = [df_func_name, code_func_name] if use_df else [code_func_name, df_func_name]

    for func_name in search_order:
        if hasattr(module, func_name):
            return getattr(module, func_name)

    if ignore_errors:
        return None
    raise ValueError(f"Alpha function for '{alpha_id}' not found in module '{alpha_name}'")



def get_stock_codes(benchmark: str) -> List[str]:
"""Get list of stock codes available in the benchmark directory."""
benchmark_dir = PROJECT_ROOT / 'bao' / benchmark
if not benchmark_dir.exists():
raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}")

# Get all CSV files and extract stock codes (without .csv extension)
csv_files = sorted(benchmark_dir.glob('*.csv'))
stock_codes = [f.stem for f in csv_files]
return stock_codes


def _load_single_stock_with_alpha(args):
"""
Helper function for parallel loading of stock data and alpha computation.
Expand Down
19 changes: 3 additions & 16 deletions calculate_covariance.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@

import os
import sys
import pandas as pd
import numpy as np
import time
from pathlib import Path
from alpha191 import *
from alpha191.utils import get_benchmark_members, load_stock_csv, load_benchmark_csv

def get_alpha_func(alpha_num: int):
"""Get the alpha function (the one that takes a DataFrame) by number."""
func_name = f"alpha_{alpha_num:03d}"
# Check in individual modules first or from alpha191 package
import alpha191
if hasattr(alpha191, func_name):
return getattr(alpha191, func_name)
return None
from alpha191.utils import get_benchmark_members, load_stock_csv, load_benchmark_csv, get_alpha_func, get_stock_codes

def main():
benchmark = "hs300"
Expand All @@ -25,8 +13,7 @@ def main():
except Exception as e:
print(f"Error loading benchmark members: {e}")
# Fallback to listing directory if needed
benchmark_dir = Path('bao') / benchmark
stock_codes = [f.stem for f in benchmark_dir.glob('*.csv')]
stock_codes = get_stock_codes(benchmark)

print(f"Using all {len(stock_codes)} stocks for covariance calculation.")

Expand Down Expand Up @@ -60,7 +47,7 @@ def main():
if i in missing_alphas:
continue

alpha_func = get_alpha_func(i)
alpha_func = get_alpha_func(i, use_df=True, ignore_errors=True)
if alpha_func:
try:
res = alpha_func(df)
Expand Down
22 changes: 1 addition & 21 deletions fulltest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,8 @@

import sys
import time
from pathlib import Path
from alpha191 import *


def get_alpha_func(alpha_num: int):
"""Get the alpha function by number (e.g., 17 -> alpha017)."""
func_name = f"alpha{alpha_num:03d}"
if hasattr(sys.modules['alpha191'], func_name):
return getattr(sys.modules['alpha191'], func_name)
raise ValueError(f"Alpha function '{func_name}' not found in alpha191 module")


def get_stock_codes(benchmark: str) -> list:
"""Get list of stock codes from the benchmark directory."""
benchmark_dir = Path('bao') / benchmark
if not benchmark_dir.exists():
raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}")

# Get all CSV files and extract stock codes (without .csv extension)
csv_files = sorted(benchmark_dir.glob('*.csv'))
stock_codes = [f.stem for f in csv_files]
return stock_codes
from alpha191.utils import get_alpha_func, get_stock_codes


def test_alpha(alpha_num: int, benchmark: str, stock_codes: list) -> dict:
Expand Down
19 changes: 6 additions & 13 deletions grouptest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import pandas as pd
import numpy as np
import sys
import importlib
from pathlib import Path
from typing import List, Dict, Any, Union

from alpha191.utils import (
load_benchmark_csv,
get_benchmark_members,
format_alpha_name,
parallel_load_stocks_with_alpha
parallel_load_stocks_with_alpha,
get_alpha_func
)
from assessment import (
get_clean_factor_and_forward_returns,
Expand All @@ -32,15 +30,10 @@ def run_group_test(alpha_name: str, horizons: List[int] = [20], benchmark: str =

# Import alpha function
try:
alpha_module = importlib.import_module(f"alpha191.{alpha_name}")
func_name = alpha_name[:5] + "_" + alpha_name[5:]
alpha_func = getattr(alpha_module, func_name)
except (ImportError, AttributeError) as e:
try:
alpha_func = getattr(alpha_module, alpha_name)
except (AttributeError, NameError):
print(f"Error importing {alpha_name}: {e}")
return
alpha_func = get_alpha_func(alpha_name, use_df=True)
except ValueError as e:
print(f"Error importing {alpha_name}: {e}")
return

codes = get_benchmark_members(benchmark)

Expand Down
29 changes: 9 additions & 20 deletions run_alpha_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import pandas as pd
import numpy as np
import importlib

# Add the current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
Expand All @@ -21,7 +20,8 @@
load_benchmark_csv,
get_benchmark_members,
format_alpha_name,
parallel_load_stocks_with_alpha
parallel_load_stocks_with_alpha,
get_alpha_func
)
from assessment import (
get_clean_factor_and_forward_returns,
Expand Down Expand Up @@ -49,15 +49,9 @@ def run_ic_test_for_alpha(alpha_name: str, benchmark: str = "zz800", horizons: l

with contextlib.redirect_stdout(output_buffer):
try:
alpha_module = importlib.import_module(f"alpha191.{alpha_name}")
func_name = alpha_name[:5] + "_" + alpha_name[5:]
alpha_func = getattr(alpha_module, func_name)
except (ImportError, AttributeError) as e:
print(f"Error importing {alpha_name}: {e}")
try:
alpha_func = getattr(alpha_module, alpha_name)
except AttributeError:
return f"ERROR: Could not load alpha {alpha_name}\n{str(e)}"
alpha_func = get_alpha_func(alpha_name, use_df=True)
except ValueError as e:
return f"ERROR: Could not load alpha {alpha_name}\n{str(e)}"

codes = get_benchmark_members(benchmark)
benchmark_df = load_benchmark_csv(benchmark)
Expand Down Expand Up @@ -195,15 +189,10 @@ def run_group_test_for_alpha(alpha_name: str, benchmark: str = "zz800", horizons

with contextlib.redirect_stdout(output_buffer):
try:
alpha_module = importlib.import_module(f"alpha191.{alpha_name}")
func_name = alpha_name[:5] + "_" + alpha_name[5:]
alpha_func = getattr(alpha_module, func_name)
except (ImportError, AttributeError) as e:
try:
alpha_func = getattr(alpha_module, alpha_name)
except (AttributeError, NameError):
print(f"Error importing {alpha_name}: {e}")
return
alpha_func = get_alpha_func(alpha_name, use_df=True)
except ValueError as e:
print(f"Error importing {alpha_name}: {e}")
return

codes = get_benchmark_members(benchmark)
benchmark_df = load_benchmark_csv(benchmark)
Expand Down
22 changes: 3 additions & 19 deletions select_alphas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import pandas as pd
import numpy as np
import importlib
import sys
import os
from pathlib import Path
Expand All @@ -26,7 +25,8 @@
from alpha191.utils import (
load_benchmark_csv,
get_benchmark_members,
load_stock_csv
load_stock_csv,
get_alpha_func
)
from assessment import get_clean_factor_and_forward_returns, compute_performance_metrics_light

Expand All @@ -38,22 +38,6 @@
warnings.filterwarnings("ignore")


def get_alpha_function(alpha_name):
"""Dynamically import alpha function."""
try:
module = importlib.import_module(f"alpha191.{alpha_name}")
func_name1 = alpha_name[:5] + "_" + alpha_name[5:] # alpha_001
func_name2 = alpha_name # alpha001

if hasattr(module, func_name1):
return getattr(module, func_name1)
elif hasattr(module, func_name2):
return getattr(module, func_name2)
except (ImportError, ModuleNotFoundError):
pass
return None


def preload_data(benchmark):
"""Load all stock data into memory using float32."""
benchmark_df = load_benchmark_csv(benchmark)
Expand Down Expand Up @@ -106,7 +90,7 @@ def compute_alpha_for_all_stocks(alpha_func, stock_cache):

def process_one_alpha(alpha_name, stock_cache, timeline):
"""Compute metrics for one alpha. Returns dict or None."""
alpha_func = get_alpha_function(alpha_name)
alpha_func = get_alpha_func(alpha_name, use_df=True, ignore_errors=True)
if not alpha_func:
return None

Expand Down
22 changes: 1 addition & 21 deletions speedtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,8 @@

import sys
import time
from pathlib import Path
from alpha191 import *


def get_alpha_func(alpha_num: int):
"""Get the alpha function by number (e.g., 17 -> alpha017)."""
func_name = f"alpha{alpha_num:03d}"
if hasattr(sys.modules['alpha191'], func_name):
return getattr(sys.modules['alpha191'], func_name)
raise ValueError(f"Alpha function '{func_name}' not found in alpha191 module")


def get_stock_codes(benchmark: str) -> list:
"""Get list of stock codes from the benchmark directory."""
benchmark_dir = Path('bao') / benchmark
if not benchmark_dir.exists():
raise FileNotFoundError(f"Benchmark directory not found: {benchmark_dir}")

# Get all CSV files and extract stock codes (without .csv extension)
csv_files = sorted(benchmark_dir.glob('*.csv'))
stock_codes = [f.stem for f in csv_files]
return stock_codes
from alpha191.utils import get_alpha_func, get_stock_codes


def main():
Expand Down
Loading