diff --git a/changes/9653.fix.md b/changes/9653.fix.md new file mode 100644 index 00000000000..c88cf022c86 --- /dev/null +++ b/changes/9653.fix.md @@ -0,0 +1 @@ +Validate cloud detection IMDS responses and harden metadata JSON parsing to prevent false positives on non-major cloud providers diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index 0df9de5ea7a..048d5650b8d 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -1192,3 +1192,15 @@ def error_code(self) -> ErrorCode: operation=ErrorOperation.EXECUTE, error_detail=ErrorDetail.INVALID_PARAMETERS, ) + + +class CloudDetectionError(BackendAIError, web.HTTPInternalServerError): + error_type = "https://api.backend.ai/probs/cloud-detection-failed" + error_title = "Cloud Provider Detection Failed" + + def error_code(self) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.EXTERNAL_SYSTEM, + operation=ErrorOperation.READ, + error_detail=ErrorDetail.BAD_REQUEST, + ) diff --git a/src/ai/backend/common/identity.py b/src/ai/backend/common/identity.py index 41e82105210..d28d13d0197 100644 --- a/src/ai/backend/common/identity.py +++ b/src/ai/backend/common/identity.py @@ -22,6 +22,7 @@ import ifaddr import psutil +from .exception import CloudDetectionError from .networking import curl __all__ = ( @@ -58,7 +59,17 @@ def is_containerized() -> bool: async def _detect_aws(session: aiohttp.ClientSession) -> CloudProvider: async with session.get( "http://169.254.169.254/latest/meta-data/", - ): + ) as resp: + body = await resp.text() + if resp.status != 200: + raise CloudDetectionError( + f"AWS detection failed with status {resp.status}", extra_data=f"{body[:200]!r}" + ) + # ref: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html + if "instance-id" not in body: + raise CloudDetectionError( + f"AWS detection failed with status {resp.status}", extra_data=f"{body[:200]!r}" + ) return CloudProvider.AWS @@ -67,7 +78,26 @@ async def _detect_azure(session: aiohttp.ClientSession) -> CloudProvider: "http://169.254.169.254/metadata/instance/compute", params={"api-version": "2021-02-01"}, headers={"Metadata": "true"}, - ): + ) as resp: + body = await resp.text() + if resp.status != 200: + raise CloudDetectionError( + f"Azure detection failed with status {resp.status}", extra_data=f"{body[:200]!r})" + ) + try: + data = await resp.json() + except (json.JSONDecodeError, ValueError): + body = await resp.text() + raise CloudDetectionError( + f"Azure detection failed with status {resp.status}", extra_data=f"{body[:200]!r}" + ) from None + # ref: https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service?tabs=linux + if not isinstance(data, dict) or "vmId" not in data: + body = await resp.text() + raise CloudDetectionError( + f"Azure detection failed with status {resp.status}", + extra_data=f"missing vmId key (body={body[:200]!r}) ", + ) from None return CloudProvider.AZURE @@ -75,7 +105,20 @@ async def _detect_gcp(session: aiohttp.ClientSession) -> CloudProvider: async with session.get( "http://169.254.169.254/computeMetadata/v1/instance/id", headers={"Metadata-Flavor": "Google"}, - ): + ) as resp: + body = await resp.text() + if resp.status != 200: + raise CloudDetectionError( + f"GCP detection failed with status {resp.status}", extra_data=f"{body[:64]!r})" + ) + try: + # ref: https://docs.cloud.google.com/compute/docs/metadata/predefined-metadata-keys + int(body.strip()) + except ValueError: + raise CloudDetectionError( + f"GCP detection failed with status {resp.status}", + extra_data=f"non-numeric body (body={body[:64]!r})", + ) from None return CloudProvider.GCP @@ -85,8 +128,7 @@ async def detect_cloud() -> CloudProvider | None: to get the fastest returning result from multiple metadata URL detectors. """ async with aiohttp.ClientSession( - raise_for_status=True, - timeout=aiohttp.ClientTimeout(connect=0.3), + timeout=aiohttp.ClientTimeout(total=1.0, connect=0.3, sock_read=0.5), ) as session: detection_tasks = [ functools.partial(_detect_aws, session), @@ -100,6 +142,9 @@ async def detect_cloud() -> CloudProvider | None: if winner_value is not None: result: CloudProvider | None = winner_value return result + for exc in exceptions: + if exc is not None: + log.debug(f"Cloud detection failed: {exc}") return None @@ -203,9 +248,12 @@ async def _get_instance_type() -> str: async def _get_instance_region() -> str: doc = await curl(_dynamic_prefix + "instance-identity/document", None) - if doc is None: + if not doc: + return "amazon/unknown" + try: + region = json.loads(doc)["region"] + except (json.JSONDecodeError, KeyError): return "amazon/unknown" - region = json.loads(doc)["region"] return f"amazon/{region}" case CloudProvider.AZURE: @@ -219,11 +267,16 @@ async def _get_instance_id() -> str: params={"api-version": "2021-02-01"}, headers={"Metadata": "true"}, ) - if data is None: + if not data: + return f"i-{socket.gethostname()}" + try: + o = json.loads(data) + vm_name = o["compute"]["name"] # unique within the resource group + vm_id = uuid.UUID( + o["compute"]["vmId"] + ) # prevent conflicts across resource group + except (json.JSONDecodeError, KeyError, ValueError): return f"i-{socket.gethostname()}" - o = json.loads(data) - vm_name = o["compute"]["name"] # unique within the resource group - vm_id = uuid.UUID(o["compute"]["vmId"]) # prevent conflicts across resource group vm_id_hash = base64.b32encode(vm_id.bytes[-5:]).decode().lower() return f"i-{vm_name}-{vm_id_hash}" @@ -234,12 +287,15 @@ async def _get_instance_ip(_subnet_hint: BaseIPNetwork[Any] | None = None) -> st params={"api-version": "2021-02-01"}, headers={"Metadata": "true"}, ) - if data is None: + if not data: + return "127.0.0.1" + try: + o = json.loads(data) + result: str = o["network"]["interface"][0]["ipv4"]["ipAddress"][0][ + "privateIpAddress" + ] + except (json.JSONDecodeError, KeyError, IndexError): return "127.0.0.1" - o = json.loads(data) - result: str = o["network"]["interface"][0]["ipv4"]["ipAddress"][0][ - "privateIpAddress" - ] return result async def _get_instance_type() -> str: @@ -249,10 +305,13 @@ async def _get_instance_type() -> str: params={"api-version": "2021-02-01"}, headers={"Metadata": "true"}, ) - if data is None: + if not data: + return "unknown" + try: + o = json.loads(data) + result: str = o["compute"]["vmSize"] + except (json.JSONDecodeError, KeyError): return "unknown" - o = json.loads(data) - result: str = o["compute"]["vmSize"] return result async def _get_instance_region() -> str: @@ -262,10 +321,13 @@ async def _get_instance_region() -> str: params={"api-version": "2021-02-01"}, headers={"Metadata": "true"}, ) - if data is None: + if not data: + return "azure/unknown" + try: + o = json.loads(data) + region = o["compute"]["location"] + except (json.JSONDecodeError, KeyError): return "azure/unknown" - o = json.loads(data) - region = o["compute"]["location"] return f"azure/{region}" case CloudProvider.GCP: @@ -278,14 +340,18 @@ async def _get_instance_id() -> str: None, headers={"Metadata-Flavor": "Google"}, ) - if vm_id is None: + if not vm_id: + return f"i-{socket.gethostname()}" + try: + vm_id_int = int(vm_id) + except ValueError: return f"i-{socket.gethostname()}" vm_name = await curl( _metadata_prefix + "instance/name", None, headers={"Metadata-Flavor": "Google"}, ) - vm_id_hash = base64.b32encode(int(vm_id).to_bytes(8, "big")[-5:]).decode().lower() + vm_id_hash = base64.b32encode(vm_id_int.to_bytes(8, "big")[-5:]).decode().lower() return f"i-{vm_name}-{vm_id_hash}" async def _get_instance_ip(_subnet_hint: BaseIPNetwork[Any] | None = None) -> str: diff --git a/tests/unit/common/test_identity.py b/tests/unit/common/test_identity.py index 588951fcec0..7b1dcab5801 100644 --- a/tests/unit/common/test_identity.py +++ b/tests/unit/common/test_identity.py @@ -1,16 +1,28 @@ from __future__ import annotations +import json import random import secrets import socket +from collections.abc import AsyncGenerator, Generator +from dataclasses import dataclass from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import aiodns +import aiohttp import pytest from aioresponses import aioresponses import ai.backend.common.identity +from ai.backend.common.exception import CloudDetectionError +from ai.backend.common.identity import ( + CloudProvider, + _detect_aws, + _detect_azure, + _detect_gcp, + detect_cloud, +) def test_is_containerized() -> None: @@ -211,3 +223,364 @@ async def test_get_instance_type(provider: str | None) -> None: elif provider is None: ret = await ai.backend.common.identity.get_instance_type() assert ret == "default" + + +_AWS_URL = "http://169.254.169.254/latest/meta-data/" +_AZURE_URL = "http://169.254.169.254/metadata/instance/compute?api-version=2021-02-01" +_GCP_URL = "http://169.254.169.254/computeMetadata/v1/instance/id" + + +class TestDetectCloudServices: + @pytest.fixture + def mock_responses(self) -> Generator[aioresponses, None, None]: + with aioresponses() as m: + yield m + + @pytest.fixture + async def client_session( + self, mock_responses: aioresponses + ) -> AsyncGenerator[aiohttp.ClientSession, None]: + async with aiohttp.ClientSession() as session: + yield session + + @pytest.fixture + def aws_metadata_url(self) -> str: + return _AWS_URL + + async def test_valid_aws_metadata( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + aws_metadata_url: str, + ) -> None: + mock_responses.get( + aws_metadata_url, + body="ami-id\nami-launch-index\ninstance-id\ninstance-type\nlocal-hostname", + ) + result = await _detect_aws(client_session) + assert result == CloudProvider.AWS + + @pytest.mark.parametrize( + "body", + ["Cloud metadata", ""], + ids=["non_aws_body", "empty_body"], + ) + async def test_rejects_non_aws_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + aws_metadata_url: str, + body: str, + ) -> None: + mock_responses.get(aws_metadata_url, body=body) + with pytest.raises(CloudDetectionError, match="AWS detection failed"): + await _detect_aws(client_session) + + @pytest.mark.parametrize("status", [404, 500, 503], ids=["404", "500", "503"]) + async def test_rejects_non_200_aws_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + aws_metadata_url: str, + status: int, + ) -> None: + mock_responses.get(aws_metadata_url, status=status, body="error") + with pytest.raises(CloudDetectionError, match=f"AWS detection failed with status {status}"): + await _detect_aws(client_session) + + @pytest.fixture + def azure_metadata_url(self) -> str: + return _AZURE_URL + + async def test_valid_azure_metadata( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + azure_metadata_url: str, + ) -> None: + mock_responses.get( + azure_metadata_url, + body=json.dumps({"vmId": "abc-123", "name": "myvm", "vmSize": "Standard_D2s_v3"}), + ) + result = await _detect_azure(client_session) + assert result == CloudProvider.AZURE + + @pytest.mark.parametrize( + "body", + [ + "not json at all", + json.dumps({"someOtherKey": "value"}), + "", + ], + ids=["non_json", "json_without_vmid", "empty_body"], + ) + async def test_rejects_non_azure_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + azure_metadata_url: str, + body: str, + ) -> None: + mock_responses.get(azure_metadata_url, body=body) + with pytest.raises(CloudDetectionError, match="Azure detection failed"): + await _detect_azure(client_session) + + @pytest.mark.parametrize("status", [404, 500, 503], ids=["404", "500", "503"]) + async def test_rejects_non_200_azure_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + azure_metadata_url: str, + status: int, + ) -> None: + mock_responses.get(azure_metadata_url, status=status, body="error") + with pytest.raises( + CloudDetectionError, match=f"Azure detection failed with status {status}" + ): + await _detect_azure(client_session) + + @pytest.fixture + def gcp_metadata_url(self) -> str: + return _GCP_URL + + async def test_valid_gcp_metadata( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + gcp_metadata_url: str, + ) -> None: + mock_responses.get(gcp_metadata_url, body="1234567890123456") + result = await _detect_gcp(client_session) + assert result == CloudProvider.GCP + + @pytest.mark.parametrize( + "body", + ["not-a-number", ""], + ids=["non_numeric", "empty_body"], + ) + async def test_rejects_non_gcp_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + gcp_metadata_url: str, + body: str, + ) -> None: + mock_responses.get(gcp_metadata_url, body=body) + with pytest.raises(CloudDetectionError, match="GCP detection failed"): + await _detect_gcp(client_session) + + @pytest.mark.parametrize("status", [404, 500, 503], ids=["404", "500", "503"]) + async def test_rejects_non_200_gcp_response( + self, + mock_responses: aioresponses, + client_session: aiohttp.ClientSession, + gcp_metadata_url: str, + status: int, + ) -> None: + mock_responses.get(gcp_metadata_url, status=status, body="error") + with pytest.raises(CloudDetectionError, match=f"GCP detection failed with status {status}"): + await _detect_gcp(client_session) + + +@dataclass(frozen=True) +class IMDSMock: + """Mocked IMDS endpoint response specification.""" + + body: str = "" + status: int = 200 + + +@dataclass(frozen=True) +class DetectCloudScenario: + """Bundled scenario for detect_cloud() parametrized tests.""" + + aws: IMDSMock + azure: IMDSMock + gcp: IMDSMock + expected: CloudProvider | None + + +class TestDetectCloud: + @pytest.fixture + def mock_responses(self) -> Generator[aioresponses, None, None]: + with aioresponses() as m: + yield m + + @pytest.mark.parametrize( + "scenario", + [ + pytest.param( + DetectCloudScenario( + aws=IMDSMock(body="ami-id\ninstance-id\ninstance-type"), + azure=IMDSMock(status=404), + gcp=IMDSMock(status=404), + expected=CloudProvider.AWS, + ), + id="aws_wins", + ), + pytest.param( + DetectCloudScenario( + aws=IMDSMock(status=404), + azure=IMDSMock(body=json.dumps({"vmId": "abc-123"})), + gcp=IMDSMock(status=404), + expected=CloudProvider.AZURE, + ), + id="azure_wins", + ), + pytest.param( + DetectCloudScenario( + aws=IMDSMock(status=404), + azure=IMDSMock(status=404), + gcp=IMDSMock(body="1234567890123456"), + expected=CloudProvider.GCP, + ), + id="gcp_wins", + ), + pytest.param( + DetectCloudScenario( + aws=IMDSMock(status=404), + azure=IMDSMock(status=404), + gcp=IMDSMock(status=404), + expected=None, + ), + id="all_non_200", + ), + ], + ) + async def test_detect_cloud( + self, + mock_responses: aioresponses, + scenario: DetectCloudScenario, + ) -> None: + mock_responses.get(_AWS_URL, status=scenario.aws.status, body=scenario.aws.body) + mock_responses.get(_AZURE_URL, status=scenario.azure.status, body=scenario.azure.body) + mock_responses.get(_GCP_URL, status=scenario.gcp.status, body=scenario.gcp.body) + result = await detect_cloud() + assert result == scenario.expected + + async def test_detect_cloud_returns_none_on_network_errors( + self, + mock_responses: aioresponses, + ) -> None: + mock_responses.get(_AWS_URL, exception=aiohttp.ClientConnectionError()) + mock_responses.get(_AZURE_URL, exception=aiohttp.ClientConnectionError()) + mock_responses.get(_GCP_URL, exception=aiohttp.ClientConnectionError()) + result = await detect_cloud() + assert result is None + + async def test_detect_cloud_picks_valid_when_others_fail( + self, + mock_responses: aioresponses, + ) -> None: + mock_responses.get(_AWS_URL, body="not aws") + mock_responses.get(_AZURE_URL, exception=aiohttp.ClientConnectionError()) + mock_responses.get(_GCP_URL, body="1234567890123456") + result = await detect_cloud() + assert result == CloudProvider.GCP + + +class TestIdentityFunctions: + @pytest.fixture + def mock_curl(self) -> Generator[AsyncMock, None, None]: + mock = AsyncMock() + with patch("ai.backend.common.identity.curl", mock): + yield mock + + @pytest.fixture + def mock_hostname(self) -> Generator[None, None, None]: + with patch("socket.gethostname", return_value="testhost"): + yield + + @pytest.fixture + def aws_provider(self) -> None: + ai.backend.common.identity.current_provider = CloudProvider.AWS + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + return + + @pytest.mark.parametrize( + ("curl_return", "expected"), + [ + (json.dumps({"region": "us-east-1"}), "amazon/us-east-1"), + ("not json", "amazon/unknown"), + (json.dumps({"otherKey": "value"}), "amazon/unknown"), + ("", "amazon/unknown"), + ], + ids=["valid_json", "invalid_json", "missing_key", "empty_response"], + ) + async def test_get_instance_region( + self, mock_curl: AsyncMock, aws_provider: None, curl_return: str, expected: str + ) -> None: + mock_curl.return_value = curl_return + result = await ai.backend.common.identity.get_instance_region() + assert result == expected + + @pytest.fixture + def azure_provider(self) -> None: + ai.backend.common.identity.current_provider = CloudProvider.AZURE + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + return + + async def test_get_instance_id_with_invalid_json( + self, mock_curl: AsyncMock, mock_hostname: None, azure_provider: None + ) -> None: + mock_curl.return_value = "not json" + result = await ai.backend.common.identity.get_instance_id() + assert result == "i-testhost" + + @pytest.mark.parametrize( + ("curl_return", "expected"), + [ + ("not json", "127.0.0.1"), + ("", "127.0.0.1"), + ], + ids=["invalid_json", "empty_response"], + ) + async def test_get_instance_ip_fallback( + self, mock_curl: AsyncMock, azure_provider: None, curl_return: str, expected: str + ) -> None: + mock_curl.return_value = curl_return + result = await ai.backend.common.identity.get_instance_ip(None) + assert result == expected + + async def test_get_instance_type_with_invalid_json( + self, mock_curl: AsyncMock, azure_provider: None + ) -> None: + mock_curl.return_value = "not json" + result = await ai.backend.common.identity.get_instance_type() + assert result == "unknown" + + @pytest.mark.parametrize( + ("curl_return", "expected"), + [ + ("not json", "azure/unknown"), + (json.dumps({"compute": {"otherKey": "val"}}), "azure/unknown"), + ], + ids=["invalid_json", "missing_key"], + ) + async def test_get_instance_region_fallback( + self, mock_curl: AsyncMock, azure_provider: None, curl_return: str, expected: str + ) -> None: + mock_curl.return_value = curl_return + result = await ai.backend.common.identity.get_instance_region() + assert result == expected + + @pytest.fixture + def gcp_provider(self) -> None: + ai.backend.common.identity.current_provider = CloudProvider.GCP + ai.backend.common.identity._defined = False + ai.backend.common.identity._define_functions() + return + + @pytest.mark.parametrize( + "curl_return", + ["not-a-number", ""], + ids=["non_numeric", "empty"], + ) + async def test_get_instance_id_fallback( + self, mock_curl: AsyncMock, mock_hostname: None, gcp_provider: None, curl_return: str + ) -> None: + mock_curl.return_value = curl_return + result = await ai.backend.common.identity.get_instance_id() + assert result == "i-testhost"