From 52bdc0ef5cc916cc413e1c396ba5613e61719029 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Thu, 17 Apr 2025 11:07:47 -0700 Subject: [PATCH] Small update to collect_profile so that it can be added as a script in pyproject.toml PiperOrigin-RevId: 748743061 --- pathwaysutils/collect_profile.py | 55 +++++++++++++++++++------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/pathwaysutils/collect_profile.py b/pathwaysutils/collect_profile.py index 18ecf57..26a01e6 100644 --- a/pathwaysutils/collect_profile.py +++ b/pathwaysutils/collect_profile.py @@ -23,6 +23,7 @@ from pathwaysutils import profiling _logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) _DESCRIPTION = """ @@ -33,29 +34,38 @@ for a provided duration. The trace file will be dumped into a GCS bucket (determined by `--log_dir`). """ -parser = argparse.ArgumentParser(description=_DESCRIPTION) -parser.add_argument( - "--log_dir", - required=True, - help="GCS path to store log files.", - type=str, -) -parser.add_argument("port", help="Port to collect trace", type=int) -parser.add_argument( - "duration_ms", help="Duration to collect trace in milliseconds", type=int -) -parser.add_argument( - "--host", - default="127.0.0.1", - help=( - "Host to collect trace. This host IP/DNS address should be accessible" - " from where this API is being called. Defaults to 127.0.0.1" - ), - type=str, -) -def main(args): +def _get_parser(): + """Returns an argument parser for the collect_profile script.""" + parser = argparse.ArgumentParser(description=_DESCRIPTION) + parser.add_argument( + "--log_dir", + required=True, + help="GCS path to store log files.", + type=str, + ) + parser.add_argument("port", help="Port to collect trace", type=int) + parser.add_argument( + "duration_ms", help="Duration to collect trace in milliseconds", type=int + ) + parser.add_argument( + "--host", + default="127.0.0.1", + help=( + "Host to collect trace. This host IP/DNS address should be accessible" + " from where this API is being called. Defaults to 127.0.0.1" + ), + type=str, + ) + + return parser + + +def main(): + parser = _get_parser() + args = parser.parse_args() + if profiling.collect_profile( args.port, args.duration_ms, args.host, args.log_dir ): @@ -63,5 +73,6 @@ def main(args): else: _logger.error("Failed to collect profiling information.") + if __name__ == "__main__": - main(parser.parse_args()) + main()