diff --git a/src/runpod_flash/runtime/context.py b/src/runpod_flash/runtime/context.py index 0e3a7a6f..541c9bd5 100644 --- a/src/runpod_flash/runtime/context.py +++ b/src/runpod_flash/runtime/context.py @@ -9,10 +9,19 @@ def is_deployed_container() -> bool: A deployed container is identified by: - RUNPOD_ENDPOINT_ID is set (RunPod sets this for serverless endpoints) - OR RUNPOD_POD_ID is set (RunPod sets this for pods) + - BUT NOT when FLASH_IS_LIVE_PROVISIONING is true (explicit local dev mode) + + The FLASH_IS_LIVE_PROVISIONING flag allows local development with on-demand + provisioning even when RunPod environment variables are present (e.g., from + testing or previous deployments). Returns: True if running in deployed container, False for local dev """ + # Explicit local development mode - overrides container detection + if os.getenv("FLASH_IS_LIVE_PROVISIONING", "").lower() == "true": + return False + return bool(os.getenv("RUNPOD_ENDPOINT_ID") or os.getenv("RUNPOD_POD_ID")) diff --git a/tests/unit/runtime/test_context.py b/tests/unit/runtime/test_context.py index ce708a9c..a601ebfb 100644 --- a/tests/unit/runtime/test_context.py +++ b/tests/unit/runtime/test_context.py @@ -42,6 +42,77 @@ def test_local_development_empty_env_vars(self): ): assert is_deployed_container() is False + def test_live_provisioning_overrides_endpoint_id(self): + """FLASH_IS_LIVE_PROVISIONING=true should override RUNPOD_ENDPOINT_ID.""" + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "FLASH_IS_LIVE_PROVISIONING": "true", + }, + ): + assert is_deployed_container() is False + + def test_live_provisioning_overrides_pod_id(self): + """FLASH_IS_LIVE_PROVISIONING=true should override RUNPOD_POD_ID.""" + with patch.dict( + os.environ, + { + "RUNPOD_POD_ID": "test-pod-456", + "FLASH_IS_LIVE_PROVISIONING": "true", + }, + clear=True, + ): + assert is_deployed_container() is False + + def test_live_provisioning_overrides_both_ids(self): + """FLASH_IS_LIVE_PROVISIONING=true should override both RunPod IDs.""" + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "RUNPOD_POD_ID": "test-pod-456", + "FLASH_IS_LIVE_PROVISIONING": "true", + }, + ): + assert is_deployed_container() is False + + def test_live_provisioning_case_insensitive(self): + """FLASH_IS_LIVE_PROVISIONING should be case-insensitive.""" + test_values = ["true", "True", "TRUE", "TrUe"] + + for value in test_values: + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "FLASH_IS_LIVE_PROVISIONING": value, + }, + ): + assert is_deployed_container() is False + + def test_live_provisioning_false_does_not_override(self): + """FLASH_IS_LIVE_PROVISIONING=false should not override deployment detection.""" + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "FLASH_IS_LIVE_PROVISIONING": "false", + }, + ): + assert is_deployed_container() is True + + def test_live_provisioning_empty_does_not_override(self): + """Empty FLASH_IS_LIVE_PROVISIONING should not override deployment detection.""" + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "FLASH_IS_LIVE_PROVISIONING": "", + }, + ): + assert is_deployed_container() is True + class TestIsLocalDevelopment: """Tests for is_local_development function.""" @@ -73,3 +144,14 @@ def test_inverse_of_is_deployed(self): for env_vars in test_cases: with patch.dict(os.environ, env_vars, clear=True): assert is_local_development() == (not is_deployed_container()) + + def test_local_with_live_provisioning(self): + """Should return True when FLASH_IS_LIVE_PROVISIONING=true even with RunPod IDs.""" + with patch.dict( + os.environ, + { + "RUNPOD_ENDPOINT_ID": "test-endpoint-123", + "FLASH_IS_LIVE_PROVISIONING": "true", + }, + ): + assert is_local_development() is True