diff --git a/benchmark/mcoplib_mxbenchmark_ops.py b/benchmark/mcoplib_mxbenchmark_ops.py index 93b6932..5d80b69 100644 --- a/benchmark/mcoplib_mxbenchmark_ops.py +++ b/benchmark/mcoplib_mxbenchmark_ops.py @@ -12,8 +12,7 @@ try: import cuda.bench._nvbench as bench except ImportError: - print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.") - sys.exit(1) + bench = None # Import base class for type checking from mcoplib_mxbenchmark_op_wrapper import OpBenchmarkBase @@ -124,6 +123,10 @@ def list_supported_operators(): print(f" * {op}") print("="*40 + "\n") +def find_literal_matches(query, candidates): + pattern = re.compile(re.escape(query), re.IGNORECASE) + return [item for item in candidates if pattern.search(item)] + def load_operator_runner(op_name): current_dir = get_base_dir() if current_dir not in sys.path: @@ -144,9 +147,8 @@ def load_operator_runner(op_name): else: print(f"[INFO] Exact config match not found for '{op_name}', trying fuzzy search...") try: - pattern = re.compile(op_name, re.IGNORECASE) json_files = [f for f in os.listdir(config_dir) if f.endswith(".json")] - matched_jsons = [f for f in json_files if pattern.search(f)] + matched_jsons = find_literal_matches(op_name, json_files) if len(matched_jsons) == 0: print(f"[ERROR] Config file not found, and fuzzy search for '{op_name}' yielded no results.") @@ -195,14 +197,12 @@ def load_operator_runner(op_name): # Strategy A: Config Name if canonical_name: - pattern_canon = re.compile(canonical_name, re.IGNORECASE) - matched_runners = [f for f in py_files if pattern_canon.search(f)] + matched_runners = find_literal_matches(canonical_name, py_files) # Strategy B: Input Name Fallback if not matched_runners and op_name and op_name != canonical_name: print(f"[INFO] Canonical name match failed, falling back to input name '{op_name}'...") - pattern_op = re.compile(op_name, re.IGNORECASE) - matched_runners = [f for f in py_files if pattern_op.search(f)] + matched_runners = find_literal_matches(op_name, py_files) except re.error as e: print(f"[ERROR] Regex error: {e}") @@ -601,6 +601,10 @@ def perform_comparison(cur_raw, hist_raw): print("-"*80 + "\n") sys.exit(1) + if bench is None: + print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.") + sys.exit(1) + # 3. Load Operator op_name = args.op op_instance = load_operator_runner(op_name) diff --git a/unit_test/test_benchmark_fuzzy_match.py b/unit_test/test_benchmark_fuzzy_match.py new file mode 100644 index 0000000..882faba --- /dev/null +++ b/unit_test/test_benchmark_fuzzy_match.py @@ -0,0 +1,22 @@ +import sys +import unittest +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "benchmark")) + +from mcoplib_mxbenchmark_ops import find_literal_matches + + +class BenchmarkFuzzyMatchTest(unittest.TestCase): + def test_matches_literal_special_characters(self): + candidates = ["op+a.json", "opxa.json", "other.json"] + self.assertEqual(find_literal_matches("op+a", candidates), ["op+a.json"]) + + def test_unbalanced_regex_character_is_safe(self): + candidates = ["scaled_mm[fp8].json", "scaled_mm_fp8.json"] + self.assertEqual(find_literal_matches("mm[fp8]", candidates), ["scaled_mm[fp8].json"]) + + +if __name__ == "__main__": + unittest.main()