Changes from all commits
25 commits
03de4db
Consolidate FunctionRanker: merge rank/rerank/filter methods into sin…
KRRT7 Dec 12, 2025
902a982
calculate in own file time
KRRT7 Dec 14, 2025
9d005b1
implement suggestions
KRRT7 Dec 14, 2025
6b7c435
cleanup code
KRRT7 Dec 14, 2025
713f135
let's make it clear it's an sqlite3 db
KRRT7 Dec 14, 2025
3c8533b
forgot this one
KRRT7 Dec 14, 2025
267030c
cleanup
KRRT7 Dec 14, 2025
afdb0f4
tessl add
KRRT7 Dec 14, 2025
3dde686
improve filtering
KRRT7 Dec 14, 2025
a1eee7d
cleanup
KRRT7 Dec 14, 2025
e0d8900
Optimize FunctionRanker.get_function_stats_summary (#971)
codeflash-ai[bot] Dec 14, 2025
f276474
Revert "let's make it clear it's an sqlite3 db"
KRRT7 Dec 16, 2025
6c93082
cleanup trace file
KRRT7 Dec 16, 2025
53d5e3e
cleanup
KRRT7 Dec 16, 2025
4ab0682
addressable time
KRRT7 Dec 16, 2025
0d44424
Optimize TestResults.add
codeflash-ai[bot] Dec 16, 2025
9e15667
bugfix
KRRT7 Dec 17, 2025
8b91de1
Merge pull request #972 from codeflash-ai/codeflash/optimize-TestResu…
misrasaurabh1 Dec 17, 2025
813922e
Merge branch 'ranking-changes' of https://github.com/codeflash-ai/cod…
KRRT7 Dec 17, 2025
9d95745
cleanup
KRRT7 Dec 17, 2025
2e82259
type checks
KRRT7 Dec 17, 2025
fe2a5a2
pre-commit
KRRT7 Dec 17, 2025
e8fba39
⚡️ Speed up function `get_cached_gh_event_data` by 13% (#975)
codeflash-ai[bot] Dec 17, 2025
29bffe9
⚡️ Speed up function `function_is_a_property` by 60% (#974)
codeflash-ai[bot] Dec 17, 2025
675abb2
Optimize function_is_a_property (#976)
codeflash-ai[bot] Dec 17, 2025
7 changes: 6 additions & 1 deletion .gitignore
@@ -254,4 +254,9 @@ fabric.properties

# Mac
.DS_Store
WARP.MD
WARP.MD

.mcp.json
.tessl/
CLAUDE.md
tessl.json
6 changes: 5 additions & 1 deletion AGENTS.md
@@ -315,4 +315,8 @@ Language Server Protocol support in `codeflash/lsp/` enables IDE integration dur
### Performance Optimization
- Profile before and after changes
- Use benchmarks to validate improvements
- Generate detailed performance reports
- Generate detailed performance reports

# Agent Rules <!-- tessl-managed -->

@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)
Binary file not shown.
197 changes: 132 additions & 65 deletions codeflash/benchmarking/function_ranker.py
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.tracing.profile_stats import ProfileStats
@@ -12,29 +12,63 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize

pytest_patterns = {
"<frozen", # Frozen modules like runpy
"<string>", # Dynamically evaluated code
"_pytest/", # Pytest internals
"pytest", # Pytest files
"pluggy/", # Plugin system
"_pydev", # PyDev debugger
"runpy.py", # Python module runner
}
pytest_func_patterns = {"pytest_", "_pytest", "runtest"}


def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
"""Check if a function is part of pytest infrastructure that should be excluded from ranking.

This filters out pytest internal functions, hooks, and test framework code that
would otherwise dominate the ranking but aren't candidates for optimization.
"""
# Check filename patterns
for pattern in pytest_patterns:
if pattern in filename:
return True

return any(pattern in function_name.lower() for pattern in pytest_func_patterns)


class FunctionRanker:
"""Ranks and filters functions based on a ttX score derived from profiling data.
"""Ranks and filters functions based on % of addressable time derived from profiling data.

The ttX score is calculated as:
ttX = own_time + (time_spent_in_callees / call_count)
The % of addressable time is calculated as:
addressable_time = own_time + (time_spent_in_callees / call_count)

This score prioritizes functions that are computationally heavy themselves (high `own_time`)
or that make expensive calls to other functions (high average `time_spent_in_callees`).
This represents the runtime of a function plus the runtime of its immediate dependent functions,
as a fraction of overall runtime. It prioritizes functions that are computationally heavy themselves
(high `own_time`) or that make expensive calls to other functions (high average `time_spent_in_callees`).

Functions are first filtered by an importance threshold based on their `own_time` as a
fraction of the total runtime. The remaining functions are then ranked by their ttX score
fraction of the total runtime. The remaining functions are then ranked by their % of addressable time
to identify the best candidates for optimization.
"""

def __init__(self, trace_file_path: Path) -> None:
self.trace_file_path = trace_file_path
self._profile_stats = ProfileStats(trace_file_path.as_posix())
self._function_stats: dict[str, dict] = {}
self._function_stats_by_name: dict[str, list[tuple[str, dict]]] = {}
self.load_function_stats()

# Build index for faster lookups: map function_name to list of (key, stats)
for key, stats in self._function_stats.items():
func_name = stats.get("function_name")
if func_name:
self._function_stats_by_name.setdefault(func_name, []).append((key, stats))

def load_function_stats(self) -> None:
try:
pytest_filtered_count = 0
for (filename, line_number, func_name), (
call_count,
_num_callers,
Expand All @@ -45,6 +79,10 @@ def load_function_stats(self) -> None:
if call_count <= 0:
continue

if is_pytest_infrastructure(filename, func_name):
pytest_filtered_count += 1
continue

# Parse function name to handle methods within classes
class_name, qualified_name, base_function_name = (None, func_name, func_name)
if "." in func_name and not func_name.startswith("<"):
@@ -56,8 +94,8 @@ def load_function_stats(self) -> None:
own_time_ns = total_time_ns
time_in_callees_ns = cumulative_time_ns - total_time_ns

# Calculate ttX score
ttx_score = own_time_ns + (time_in_callees_ns / call_count)
# Calculate addressable time (own time + avg time in immediate callees)
addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)

function_key = f"{filename}:{qualified_name}"
self._function_stats[function_key] = {
@@ -70,89 +108,118 @@
"own_time_ns": own_time_ns,
"cumulative_time_ns": cumulative_time_ns,
"time_in_callees_ns": time_in_callees_ns,
"ttx_score": ttx_score,
"addressable_time_ns": addressable_time_ns,
}

logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")
logger.debug(
f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats "
f"(filtered {pytest_filtered_count} pytest infrastructure functions)"
)

except Exception as e:
logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
self._function_stats = {}

def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
target_filename = function_to_optimize.file_path.name
for key, stats in self._function_stats.items():
if stats.get("function_name") == function_to_optimize.function_name and (
key.endswith(f"/{target_filename}") or target_filename in key
):
candidates = self._function_stats_by_name.get(function_to_optimize.function_name)
if not candidates:
logger.debug(
f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
)
return None

for key, stats in candidates:
# The check preserves exact logic: "key.endswith(f"/{target_filename}") or target_filename in key"
if key.endswith(f"/{target_filename}") or target_filename in key:
return stats

logger.debug(
f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
)
return None

def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
stats = self._get_function_stats(function_to_optimize)
return stats["ttx_score"] if stats else 0.0
def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float:
"""Get the addressable time in nanoseconds for a function.

def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
logger.debug(
f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
)
return ranked
Addressable time = own_time + (time_in_callees / call_count)
This represents the runtime of the function plus runtime of immediate dependent functions.
"""
stats = self.get_function_stats_summary(function_to_optimize)
return stats["addressable_time_ns"] if stats else 0.0

def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
return self._get_function_stats(function_to_optimize)
def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
"""Ranks and filters functions based on their % of addressable time and importance.

def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
"""Ranks functions based on their ttX score.
Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD
of file-relative runtime, then ranks the remaining functions by addressable time.

This method calculates the ttX score for each function and returns
the functions sorted in descending order of their ttX score.
"""
if not self._function_stats:
logger.warning("No function stats available to rank functions.")
return []
Importance is calculated relative to functions in the same file(s) rather than
total program time. This avoids filtering out functions due to test infrastructure
overhead.

return self.rank_functions(functions_to_optimize)
The addressable time metric (own_time + avg time in immediate callees) prioritizes
functions that are computationally heavy themselves or that make expensive calls
to other functions.

def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
"""Reranks and filters functions based on their impact on total runtime.
Args:
functions_to_optimize: List of functions to rank.

This method first calculates the total runtime of all profiled functions.
It then filters out functions whose own_time is less than a specified
percentage of the total runtime (importance_threshold).
Returns:
Important functions sorted in descending order of their addressable time.

The remaining 'important' functions are then ranked by their ttX score.
"""
stats_map = self._function_stats
if not stats_map:
if not self._function_stats:
logger.warning("No function stats available to rank functions.")
return []

total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
# Calculate total time from functions in the same file(s) as functions to optimize
if functions_to_optimize:
# Get unique files from functions to optimize
target_files = {func.file_path.name for func in functions_to_optimize}
# Calculate total time only from functions in these files
total_program_time = sum(
s["own_time_ns"]
for s in self._function_stats.values()
if s.get("own_time_ns", 0) > 0
and any(
str(s.get("filename", "")).endswith("/" + target_file) or s.get("filename") == target_file
for target_file in target_files
)
)
logger.debug(
f"Using file-relative importance for {len(target_files)} file(s): {target_files}. "
f"Total file time: {total_program_time:,} ns"
)
else:
total_program_time = sum(
s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
)

if total_program_time == 0:
logger.warning("Total program time is zero, cannot determine function importance.")
return self.rank_functions(functions_to_optimize)

important_functions = []
for func in functions_to_optimize:
func_stats = self._get_function_stats(func)
if func_stats and func_stats.get("own_time_ns", 0) > 0:
importance = func_stats["own_time_ns"] / total_program_time
if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
important_functions.append(func)
else:
logger.debug(
f"Filtering out function {func.qualified_name} with importance "
f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
)

logger.info(
f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
functions_to_rank = functions_to_optimize
else:
functions_to_rank = []
for func in functions_to_optimize:
func_stats = self.get_function_stats_summary(func)
if func_stats and func_stats.get("own_time_ns", 0) > 0:
importance = func_stats["own_time_ns"] / total_program_time
if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
functions_to_rank.append(func)
else:
logger.debug(
f"Filtering out function {func.qualified_name} with importance "
f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
)

logger.info(
f"Filtered down to {len(functions_to_rank)} important functions "
f"from {len(functions_to_optimize)} total functions"
)

ranked = sorted(functions_to_rank, key=self.get_function_addressable_time, reverse=True)
logger.debug(
f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}"
)
console.rule()

return self.rank_functions(important_functions)
return ranked
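
To make the ranking flow introduced in this file concrete, below is a minimal, self-contained sketch (not the actual FunctionRanker code) of the three steps this diff adds: filtering out pytest infrastructure, filtering by file-relative importance, and ranking by addressable time. The file names, timing numbers, and the IMPORTANCE_THRESHOLD value are illustrative assumptions; in the real implementation the statistics come from a SQLite trace via ProfileStats and the threshold is DEFAULT_IMPORTANCE_THRESHOLD from config_consts.

IMPORTANCE_THRESHOLD = 0.001  # assumed stand-in; the actual DEFAULT_IMPORTANCE_THRESHOLD is not shown in this diff

PYTEST_FILE_PATTERNS = {"<frozen", "<string>", "_pytest/", "pytest", "pluggy/", "_pydev", "runpy.py"}
PYTEST_FUNC_PATTERNS = {"pytest_", "_pytest", "runtest"}


def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
    # Mirrors the filter above: exclude pytest/plugin internals from ranking.
    if any(p in filename for p in PYTEST_FILE_PATTERNS):
        return True
    return any(p in function_name.lower() for p in PYTEST_FUNC_PATTERNS)


def addressable_time_ns(own_ns: float, callees_ns: float, calls: int) -> float:
    # addressable_time = own_time + (time_spent_in_callees / call_count)
    return own_ns + (callees_ns / calls)


# (filename, function_name) -> (call_count, own_time_ns, cumulative_time_ns); made-up numbers.
trace_stats = {
    ("app/service.py", "serialize"): (100, 40_000_000, 90_000_000),
    ("app/service.py", "fetch"): (10, 5_000_000, 6_000_000),
    ("app/service.py", "tiny_helper"): (1_000, 10_000, 20_000),
    ("_pytest/python.py", "pytest_pyfunc_call"): (500, 80_000_000, 120_000_000),
}

# 1. Drop pytest infrastructure, as load_function_stats() now does.
stats = {k: v for k, v in trace_stats.items() if not is_pytest_infrastructure(k[0], k[1])}

# 2. File-relative importance: own_time as a fraction of the summed own_time of
#    functions in the same file, so test-harness overhead does not dilute it.
total_file_time = sum(own for _, own, _ in stats.values())
important = {k: v for k, v in stats.items() if v[1] / total_file_time >= IMPORTANCE_THRESHOLD}

# 3. Rank the survivors by addressable time, descending (tiny_helper is filtered out in step 2).
ranked = sorted(
    important.items(),
    key=lambda item: addressable_time_ns(item[1][1], item[1][2] - item[1][1], item[1][0]),
    reverse=True,
)
for (filename, func), (calls, own, cumulative) in ranked:
    print(f"{filename}:{func}  addressable_time={addressable_time_ns(own, cumulative - own, calls):,.0f} ns")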
36 changes: 8 additions & 28 deletions codeflash/benchmarking/replay_test.py
@@ -64,30 +64,23 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c


def create_trace_replay_test_code(
trace_file: str,
functions_data: list[dict[str, Any]],
test_framework: str = "pytest",
max_run_count=256, # noqa: ANN001
trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256
) -> str:
"""Create a replay test for functions based on trace data.

Args:
----
trace_file: Path to the SQLite database file
functions_data: List of dictionaries with function info extracted from DB
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include in the test

Returns:
-------
A string containing the test code

"""
assert test_framework in ["pytest", "unittest"]

# Create Imports
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
imports = """from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""

@@ -158,13 +151,7 @@ def create_trace_replay_test_code(
)

# Create main body

if test_framework == "unittest":
self = "self"
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
else:
test_template = ""
self = ""
test_template = ""

for func in functions_data:
module_name = func.get("module_name")
@@ -223,30 +210,26 @@ def create_trace_replay_test_code(
filter_variables=filter_variables,
)

formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ")
formatted_test_body = textwrap.indent(test_body, " ")

test_template += " " if test_framework == "unittest" else ""
unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
test_template += f"def test_{unique_test_name}():\n{formatted_test_body}\n"

return imports + "\n" + metadata + "\n" + test_template


def generate_replay_test(
trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100
) -> int:
def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int:
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.

Args:
----
trace_file_path: Path to the SQLite database file
output_dir: Directory to write the generated tests (if None, only returns the code)
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include per function

Returns:
-------
Dictionary mapping benchmark names to generated test code
The number of replay tests generated

"""
count = 0
@@ -293,10 +276,7 @@ def generate_replay_test(
continue
# Generate the test code for this benchmark
test_code = create_trace_replay_test_code(
trace_file=trace_file_path.as_posix(),
functions_data=functions_data,
test_framework=test_framework,
max_run_count=max_run_count,
trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count
)
test_code = sort_imports(code=test_code)
output_file = get_test_file_path(
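
For context on the replay_test.py changes above: unittest support is removed, so generated replay tests are plain pytest functions and generate_replay_test no longer takes a test_framework argument. The following is a hedged usage sketch; the signatures are taken from this diff, but the trace-file and output paths are assumptions.

from pathlib import Path

from codeflash.benchmarking.replay_test import generate_replay_test

trace_file = Path("codeflash_benchmarks.sqlite")  # assumed path to the SQLite trace produced by benchmarking
output_dir = Path("tests/replay")                 # assumed directory for the generated pytest files

# test_framework is gone from the signature; tests are grouped by benchmark, and
# the return value is the number of replay tests generated (per the docstring).
num_tests = generate_replay_test(trace_file, output_dir, max_run_count=100)
print(f"Generated {num_tests} replay tests")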