Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/9653.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Validate cloud detection IMDS responses and harden metadata JSON parsing to prevent false positives on non-major cloud providers
12 changes: 12 additions & 0 deletions src/ai/backend/common/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,15 @@ def error_code(self) -> ErrorCode:
operation=ErrorOperation.EXECUTE,
error_detail=ErrorDetail.INVALID_PARAMETERS,
)


class CloudDetectionError(BackendAIError, web.HTTPInternalServerError):
error_type = "https://api.backend.ai/probs/cloud-detection-failed"
error_title = "Cloud Provider Detection Failed"

def error_code(self) -> ErrorCode:
return ErrorCode(
domain=ErrorDomain.EXTERNAL_SYSTEM,
operation=ErrorOperation.READ,
error_detail=ErrorDetail.BAD_REQUEST,
)
114 changes: 90 additions & 24 deletions src/ai/backend/common/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ifaddr
import psutil

from .exception import CloudDetectionError
from .networking import curl

__all__ = (
Expand Down Expand Up @@ -58,7 +59,17 @@ def is_containerized() -> bool:
async def _detect_aws(session: aiohttp.ClientSession) -> CloudProvider:
async with session.get(
"http://169.254.169.254/latest/meta-data/",
):
) as resp:
body = await resp.text()
if resp.status != 200:
raise CloudDetectionError(
f"AWS detection failed with status {resp.status}", extra_data=f"{body[:200]!r}"
)
# ref: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
if "instance-id" not in body:
raise CloudDetectionError(
f"AWS detection failed with status {resp.status}", extra_data=f"{body[:200]!r}"
)
return CloudProvider.AWS


Expand All @@ -67,15 +78,47 @@ async def _detect_azure(session: aiohttp.ClientSession) -> CloudProvider:
"http://169.254.169.254/metadata/instance/compute",
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
):
) as resp:
body = await resp.text()
if resp.status != 200:
raise CloudDetectionError(
f"Azure detection failed with status {resp.status}", extra_data=f"{body[:200]!r})"
)
try:
data = await resp.json()
except (json.JSONDecodeError, ValueError):
body = await resp.text()
raise CloudDetectionError(
f"Azure detection failed with status {resp.status}", extra_data=f"{body[:200]!r}"
) from None
# ref: https://learn.microsoft.com/azure/virtual-machines/instance-metadata-service?tabs=linux
if not isinstance(data, dict) or "vmId" not in data:
body = await resp.text()
raise CloudDetectionError(
f"Azure detection failed with status {resp.status}",
extra_data=f"missing vmId key (body={body[:200]!r}) ",
) from None
return CloudProvider.AZURE


async def _detect_gcp(session: aiohttp.ClientSession) -> CloudProvider:
async with session.get(
"http://169.254.169.254/computeMetadata/v1/instance/id",
headers={"Metadata-Flavor": "Google"},
):
) as resp:
body = await resp.text()
if resp.status != 200:
raise CloudDetectionError(
f"GCP detection failed with status {resp.status}", extra_data=f"{body[:64]!r})"
)
try:
# ref: https://docs.cloud.google.com/compute/docs/metadata/predefined-metadata-keys
int(body.strip())
except ValueError:
raise CloudDetectionError(
f"GCP detection failed with status {resp.status}",
extra_data=f"non-numeric body (body={body[:64]!r})",
) from None
return CloudProvider.GCP


Expand All @@ -85,8 +128,7 @@ async def detect_cloud() -> CloudProvider | None:
to get the fastest returning result from multiple metadata URL detectors.
"""
async with aiohttp.ClientSession(
raise_for_status=True,
timeout=aiohttp.ClientTimeout(connect=0.3),
timeout=aiohttp.ClientTimeout(total=1.0, connect=0.3, sock_read=0.5),
) as session:
detection_tasks = [
functools.partial(_detect_aws, session),
Expand All @@ -100,6 +142,9 @@ async def detect_cloud() -> CloudProvider | None:
if winner_value is not None:
result: CloudProvider | None = winner_value
return result
for exc in exceptions:
if exc is not None:
log.debug(f"Cloud detection failed: {exc}")
return None


Expand Down Expand Up @@ -203,9 +248,12 @@ async def _get_instance_type() -> str:

async def _get_instance_region() -> str:
doc = await curl(_dynamic_prefix + "instance-identity/document", None)
if doc is None:
if not doc:
return "amazon/unknown"
try:
region = json.loads(doc)["region"]
except (json.JSONDecodeError, KeyError):
return "amazon/unknown"
region = json.loads(doc)["region"]
return f"amazon/{region}"

case CloudProvider.AZURE:
Expand All @@ -219,11 +267,16 @@ async def _get_instance_id() -> str:
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
)
if data is None:
if not data:
return f"i-{socket.gethostname()}"
try:
o = json.loads(data)
vm_name = o["compute"]["name"] # unique within the resource group
vm_id = uuid.UUID(
o["compute"]["vmId"]
) # prevent conflicts across resource group
except (json.JSONDecodeError, KeyError, ValueError):
return f"i-{socket.gethostname()}"
o = json.loads(data)
vm_name = o["compute"]["name"] # unique within the resource group
vm_id = uuid.UUID(o["compute"]["vmId"]) # prevent conflicts across resource group
vm_id_hash = base64.b32encode(vm_id.bytes[-5:]).decode().lower()
return f"i-{vm_name}-{vm_id_hash}"

Expand All @@ -234,12 +287,15 @@ async def _get_instance_ip(_subnet_hint: BaseIPNetwork[Any] | None = None) -> st
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
)
if data is None:
if not data:
return "127.0.0.1"
try:
o = json.loads(data)
result: str = o["network"]["interface"][0]["ipv4"]["ipAddress"][0][
"privateIpAddress"
]
except (json.JSONDecodeError, KeyError, IndexError):
return "127.0.0.1"
o = json.loads(data)
result: str = o["network"]["interface"][0]["ipv4"]["ipAddress"][0][
"privateIpAddress"
]
return result

async def _get_instance_type() -> str:
Expand All @@ -249,10 +305,13 @@ async def _get_instance_type() -> str:
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
)
if data is None:
if not data:
return "unknown"
try:
o = json.loads(data)
result: str = o["compute"]["vmSize"]
except (json.JSONDecodeError, KeyError):
return "unknown"
o = json.loads(data)
result: str = o["compute"]["vmSize"]
return result

async def _get_instance_region() -> str:
Expand All @@ -262,10 +321,13 @@ async def _get_instance_region() -> str:
params={"api-version": "2021-02-01"},
headers={"Metadata": "true"},
)
if data is None:
if not data:
return "azure/unknown"
try:
o = json.loads(data)
region = o["compute"]["location"]
except (json.JSONDecodeError, KeyError):
return "azure/unknown"
o = json.loads(data)
region = o["compute"]["location"]
return f"azure/{region}"

case CloudProvider.GCP:
Expand All @@ -278,14 +340,18 @@ async def _get_instance_id() -> str:
None,
headers={"Metadata-Flavor": "Google"},
)
if vm_id is None:
if not vm_id:
return f"i-{socket.gethostname()}"
try:
vm_id_int = int(vm_id)
except ValueError:
return f"i-{socket.gethostname()}"
vm_name = await curl(
_metadata_prefix + "instance/name",
None,
headers={"Metadata-Flavor": "Google"},
)
vm_id_hash = base64.b32encode(int(vm_id).to_bytes(8, "big")[-5:]).decode().lower()
vm_id_hash = base64.b32encode(vm_id_int.to_bytes(8, "big")[-5:]).decode().lower()
return f"i-{vm_name}-{vm_id_hash}"

async def _get_instance_ip(_subnet_hint: BaseIPNetwork[Any] | None = None) -> str:
Expand Down
Loading
Loading