Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions benchmark/mcoplib_mxbenchmark_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
try:
import cuda.bench._nvbench as bench
except ImportError:
print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.")
sys.exit(1)
bench = None

# Import base class for type checking
from mcoplib_mxbenchmark_op_wrapper import OpBenchmarkBase
Expand Down Expand Up @@ -115,12 +114,17 @@
def get_base_dir():
return os.path.dirname(os.path.abspath(__file__))

def list_supported_operators():
def list_supported_operators(json_output=False):
operators = sorted(SUPPORTED_OPERATORS)
if json_output:
print(json.dumps({"operators": operators, "count": len(operators)}, sort_keys=True))
return

print("\n" + "="*40 + f"\n{' Supported Operators ':=^40}\n" + "="*40)
if not SUPPORTED_OPERATORS:
if not operators:
print(" (No operators defined in SUPPORTED_OPERATORS list)")
else:
for op in sorted(SUPPORTED_OPERATORS):
for op in operators:
print(f" * {op}")
print("="*40 + "\n")

Expand Down Expand Up @@ -574,8 +578,14 @@ def perform_comparison(cur_raw, hist_raw):
# =============================================================================
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MCOPLIB Operator Performance Benchmark")
parser.add_argument("--op", type=str, default=None, help="Operator name (Required, unless --list is used)")
parser.add_argument(
"--op",
type=str,
default=None,
help="Operator name (Required, unless --list or --list-json is used)",
)
parser.add_argument("--list", action="store_true", help="List all supported operators and exit")
parser.add_argument("--list-json", action="store_true", help="Print supported operators as JSON and exit")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

由于新增了 --list-json 参数,第 581 行中 --op 参数的帮助信息 (Required, unless --list is used) 变得不够准确。建议将其更新为 (Required, unless --list or --list-json is used),以避免误导用户。

parser.add_argument("--csv", type=str, default=None, help="Path to result CSV")

group = parser.add_mutually_exclusive_group()
Expand All @@ -587,8 +597,8 @@ def perform_comparison(cur_raw, hist_raw):
args, unknown = parser.parse_known_args()

# 1. Handle --list
if args.list:
list_supported_operators()
if args.list or args.list_json:
list_supported_operators(json_output=args.list_json)
sys.exit(0)

# 2. Validate Core Argument --op
Expand All @@ -601,6 +611,10 @@ def perform_comparison(cur_raw, hist_raw):
print("-"*80 + "\n")
sys.exit(1)

if bench is None:
print("[ERROR] Runtime environment missing 'nvbench'. Please check configuration.", file=sys.stderr)
sys.exit(1)

# 3. Load Operator
op_name = args.op
op_instance = load_operator_runner(op_name)
Expand Down