Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mcoplib/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def _timestamp() -> str:
return datetime.now().strftime("%Y%m%dT%H%M%S")


def _resolve_output_dir(default_output_dir):
return os.getenv("PROFILER_OUTPUT_DIR", default_output_dir)


def _track_handler(prof, output_dir, func_name):
"""
Track handler implementation that matches test.py track_handler format.
Expand Down Expand Up @@ -120,7 +124,9 @@ def wrapper(*args, **kwargs):
active=1,
repeat=repeat
),
on_trace_ready=lambda prof: _track_handler(prof, output_dir, func.__name__),
on_trace_ready=lambda prof: _track_handler(
prof, _resolve_output_dir(output_dir), func.__name__
),
with_modules=True,
record_shapes=True,
profile_memory=True
Expand Down
33 changes: 33 additions & 0 deletions unit_test/test_profiler_output_dir_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import sys
import unittest
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from mcoplib.profiler import _resolve_output_dir


class ProfilerOutputDirEnvTest(unittest.TestCase):
def test_uses_default_without_env(self):
old = os.environ.pop("PROFILER_OUTPUT_DIR", None)
try:
self.assertEqual(_resolve_output_dir("./profiles"), "./profiles")
finally:
if old is not None:
os.environ["PROFILER_OUTPUT_DIR"] = old

def test_env_overrides_default(self):
old = os.environ.get("PROFILER_OUTPUT_DIR")
os.environ["PROFILER_OUTPUT_DIR"] = "/tmp/mcoplib-profiles"
try:
self.assertEqual(_resolve_output_dir("./profiles"), "/tmp/mcoplib-profiles")
finally:
if old is None:
os.environ.pop("PROFILER_OUTPUT_DIR", None)
else:
os.environ["PROFILER_OUTPUT_DIR"] = old


if __name__ == "__main__":
unittest.main()