Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions benchmark/mcoplib_mxbenchmark_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
try:
import cuda.bench._nvbench as bench
except ImportError:
print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.")
sys.exit(1)
bench = None

# Import base class for type checking
from mcoplib_mxbenchmark_op_wrapper import OpBenchmarkBase
Expand Down Expand Up @@ -124,6 +123,10 @@ def list_supported_operators():
print(f" * {op}")
print("="*40 + "\n")

def find_literal_matches(query, candidates):
pattern = re.compile(re.escape(query), re.IGNORECASE)
return [item for item in candidates if pattern.search(item)]

def load_operator_runner(op_name):
current_dir = get_base_dir()
if current_dir not in sys.path:
Expand All @@ -144,9 +147,8 @@ def load_operator_runner(op_name):
else:
print(f"[INFO] Exact config match not found for '{op_name}', trying fuzzy search...")
try:
pattern = re.compile(op_name, re.IGNORECASE)
json_files = [f for f in os.listdir(config_dir) if f.endswith(".json")]
matched_jsons = [f for f in json_files if pattern.search(f)]
matched_jsons = find_literal_matches(op_name, json_files)

if len(matched_jsons) == 0:
print(f"[ERROR] Config file not found, and fuzzy search for '{op_name}' yielded no results.")
Expand Down Expand Up @@ -195,14 +197,12 @@ def load_operator_runner(op_name):

# Strategy A: Config Name
if canonical_name:
pattern_canon = re.compile(canonical_name, re.IGNORECASE)
matched_runners = [f for f in py_files if pattern_canon.search(f)]
matched_runners = find_literal_matches(canonical_name, py_files)

# Strategy B: Input Name Fallback
if not matched_runners and op_name and op_name != canonical_name:
print(f"[INFO] Canonical name match failed, falling back to input name '{op_name}'...")
pattern_op = re.compile(op_name, re.IGNORECASE)
matched_runners = [f for f in py_files if pattern_op.search(f)]
matched_runners = find_literal_matches(op_name, py_files)

except re.error as e:
print(f"[ERROR] Regex error: {e}")
Expand Down Expand Up @@ -601,6 +601,10 @@ def perform_comparison(cur_raw, hist_raw):
print("-"*80 + "\n")
sys.exit(1)

if bench is None:
print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.")
sys.exit(1)

# 3. Load Operator
op_name = args.op
op_instance = load_operator_runner(op_name)
Expand Down
22 changes: 22 additions & 0 deletions unit_test/test_benchmark_fuzzy_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import sys
import unittest
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "benchmark"))

from mcoplib_mxbenchmark_ops import find_literal_matches


class BenchmarkFuzzyMatchTest(unittest.TestCase):
def test_matches_literal_special_characters(self):
candidates = ["op+a.json", "opxa.json", "other.json"]
self.assertEqual(find_literal_matches("op+a", candidates), ["op+a.json"])

def test_unbalanced_regex_character_is_safe(self):
candidates = ["scaled_mm[fp8].json", "scaled_mm_fp8.json"]
self.assertEqual(find_literal_matches("mm[fp8]", candidates), ["scaled_mm[fp8].json"])


if __name__ == "__main__":
unittest.main()