Add structured output for collective-communication bus-bandwidth factors #6
Code Review
This pull request introduces a Python script, tools/busbw_factor.py, along with its unit tests, to compute MCCL bus-bandwidth correction factors for various collective operations. The review makes two suggestions: make collective-name parsing more robust by ignoring underscores and hyphens, and add a validation check that the algorithm bandwidth (algbw_gbps) is non-negative.
    collective = collective.lower()
    if collective == "all_reduce":
        return 2 * (ranks - 1) / ranks
    if collective in {"all_gather", "reduce_scatter"}:
        return (ranks - 1) / ranks
    if collective in {"broadcast", "reduce"}:
        return 1.0
    raise ValueError(f"unsupported collective: {collective}")
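To make the script more robust, consider ignoring underscores and hyphens when handling the collective argument, so that common spellings such as allreduce, all_reduce, and all-reduce are all accepted. This keeps a run from failing just because the name was formatted differently.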
Suggested change
-    collective = collective.lower()
-    if collective == "all_reduce":
-        return 2 * (ranks - 1) / ranks
-    if collective in {"all_gather", "reduce_scatter"}:
-        return (ranks - 1) / ranks
-    if collective in {"broadcast", "reduce"}:
-        return 1.0
-    raise ValueError(f"unsupported collective: {collective}")
+    collective = collective.lower().replace("_", "").replace("-", "")
+    if collective == "allreduce":
+        return 2 * (ranks - 1) / ranks
+    if collective in {"allgather", "reducescatter"}:
+        return (ranks - 1) / ranks
+    if collective in {"broadcast", "reduce"}:
+        return 1.0
+    raise ValueError(f"unsupported collective: {collective}")
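For reference, the normalization suggested here can be tried in isolation. The sketch below is not the code in tools/busbw_factor.py: the helper name `normalize_collective` and the `ranks < 1` guard are assumptions added for illustration, and the error message reports the caller's original spelling instead of the normalized one.

```python
def normalize_collective(name: str) -> str:
    """Lowercase the name and drop underscores and hyphens (hypothetical helper)."""
    return name.lower().replace("_", "").replace("-", "")


def correction_factor(collective: str, ranks: int) -> float:
    """Return the bus-bandwidth correction factor for a collective."""
    # Assumption: this guard is not part of the reviewed code. It avoids a
    # ZeroDivisionError when ranks is 0 and rejects negative rank counts.
    if ranks < 1:
        raise ValueError("ranks must be a positive integer")
    key = normalize_collective(collective)
    if key == "allreduce":
        return 2 * (ranks - 1) / ranks
    if key in {"allgather", "reducescatter"}:
        return (ranks - 1) / ranks
    if key in {"broadcast", "reduce"}:
        return 1.0
    # Report the original spelling so the message matches what the user typed.
    raise ValueError(f"unsupported collective: {collective}")


# All three spellings resolve to the same factor.
for spelling in ("all_reduce", "All-Reduce", "allreduce"):
    print(spelling, correction_factor(spelling, 8))
```

With 8 ranks, each spelling yields 2 * 7 / 8 = 1.75.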
def summarize(collective: str, ranks: int, algbw_gbps: float | None = None) -> dict[str, object]:
    factor = correction_factor(collective, ranks)
    payload: dict[str, object] = {"collective": collective, "ranks": ranks, "factor": factor}
    if algbw_gbps is not None:
        payload["algbw_gbps"] = algbw_gbps
        payload["busbw_gbps"] = algbw_gbps * factor
    return payload
To guard against invalid bandwidth values, consider checking that algbw_gbps is non-negative before it is used.
Suggested change
-def summarize(collective: str, ranks: int, algbw_gbps: float | None = None) -> dict[str, object]:
-    factor = correction_factor(collective, ranks)
-    payload: dict[str, object] = {"collective": collective, "ranks": ranks, "factor": factor}
-    if algbw_gbps is not None:
-        payload["algbw_gbps"] = algbw_gbps
-        payload["busbw_gbps"] = algbw_gbps * factor
-    return payload
+def summarize(collective: str, ranks: int, algbw_gbps: float | None = None) -> dict[str, object]:
+    if algbw_gbps is not None and algbw_gbps < 0:
+        raise ValueError("algbw_gbps must be non-negative")
+    factor = correction_factor(collective, ranks)
+    payload: dict[str, object] = {"collective": collective, "ranks": ranks, "factor": factor}
+    if algbw_gbps is not None:
+        payload["algbw_gbps"] = algbw_gbps
+        payload["busbw_gbps"] = algbw_gbps * factor
+    return payload
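One gap in a plain `algbw_gbps < 0` check is that NaN passes it, because every comparison with NaN is false. The sketch below shows one possible stricter variant. The `math.isfinite` check goes beyond what the review asks for, and the bundled `correction_factor` is only a compact stand-in that makes the example self-contained; neither should be read as the script's actual code.

```python
import json
import math
from typing import Dict, Optional


def correction_factor(collective: str, ranks: int) -> float:
    # Compact stand-in mirroring the factor logic quoted in this review.
    key = collective.lower().replace("_", "").replace("-", "")
    if key == "allreduce":
        return 2 * (ranks - 1) / ranks
    if key in {"allgather", "reducescatter"}:
        return (ranks - 1) / ranks
    if key in {"broadcast", "reduce"}:
        return 1.0
    raise ValueError(f"unsupported collective: {collective}")


def summarize(collective: str, ranks: int,
              algbw_gbps: Optional[float] = None) -> Dict[str, object]:
    if algbw_gbps is not None:
        # Assumption: rejecting NaN and infinity is an extra check that the
        # review does not request; the review only asks for non-negativity.
        if not math.isfinite(algbw_gbps) or algbw_gbps < 0:
            raise ValueError("algbw_gbps must be a finite, non-negative number")
    factor = correction_factor(collective, ranks)
    payload: Dict[str, object] = {
        "collective": collective,
        "ranks": ranks,
        "factor": factor,
    }
    if algbw_gbps is not None:
        # Bus bandwidth is the algorithm bandwidth scaled by the factor.
        payload["algbw_gbps"] = algbw_gbps
        payload["busbw_gbps"] = algbw_gbps * factor
    return payload


# Emit the structured result as JSON for archiving or further processing.
print(json.dumps(summarize("all_reduce", 8, 100.0)))
```

For all_reduce on 8 ranks at 100.0 GB/s, the payload carries a factor of 1.75 and a busbw_gbps of 175.0. Validating before computing the factor means a bad bandwidth value is reported even when the collective name is also wrong, which may or may not be the preferred error ordering.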
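This change adds structured output for collective-communication bus-bandwidth factors. It addresses a problem in the collective-communication testing and result-collation workflow: the relevant information was scattered, and collating it by hand was costly. With this change, routine troubleshooting, verification, and result archiving become more direct.
The implementation adds the corresponding tool/script logic and the tests for it, while keeping existing usage unchanged as far as possible so that existing workflows are not affected.
This branch has been validated in practice in a MetaX (沐曦) compute environment, all relevant checks have passed, and it is now submitted for merging.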