diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index cf629ee65ea..2c55b7a6563 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -117,6 +117,38 @@ az ssh cert --file ./id_rsa-aadcert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH" """ +helps['ssh cert-create'] = """ + type: command + short-summary: Create a short-lived SSH certificate signed by a private CA key in Azure Key Vault. + long-summary: | + Generates an ephemeral SSH key pair, determines the caller's RBAC role + on the target ProvisionedMachine resource via PIM-based JIT access, and + sends the public key along with metadata (userPublicKey, username, role, + expiry) to Key Vault for signing. + + The user's role is NOT taken as input — it is resolved automatically from + the RBAC role assignment on the device resource. + + Currently the extension relies on the built-in Azure roles (Owner, + Contributor, Reader) because we do not yet have permission to create + custom roles. The final intended roles are: + - Provisioned Machine Administrator (full SSH with sudo) + - Provisioned Machine Contributor (SSH without sudo) + - Provisioned Machine Reader (view-only; SSH restricted on device) + Certificates are generated for all roles — access restrictions are + enforced on the device side, not by the CLI. + These custom roles are pending creation (Teodora, Eric — please help + finalize so the CLI extension can be completed). + + The user identity is derived automatically from the Entra login context. + The certificate expiry is derived from the PIM activation's remaining duration. + Returns the signed SSH user certificate and the freshly generated private key. + examples: + - name: Create a certificate (expiry derived from PIM activation) + text: | + az ssh cert-create --vault-name myKeyVault --resource-id /subscriptions/.../providers/Microsoft.ProvisionedMachine/machines/myDevice +""" + helps['ssh arc'] = """ type: command short-summary: SSH into Azure Arc Servers diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index f0d22edc576..e0eff0a4620 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -83,6 +83,16 @@ def load_arguments(self, _): help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). ' 'Default to ssh pre-installed if not provided.') + with self.argument_context('ssh cert-create') as c: + c.argument('vault_name', options_list=['--vault-name', '-v'], + help='Name of the Azure Key Vault that holds the private CA signing key (ssh-ca).', + required=True) + c.argument('resource_id', options_list=['--resource-id', '-r'], + help='Fully qualified ARM resource ID of the ProvisionedMachine. ' + 'Used to determine the user\'s RBAC role (Reader/Contributor/Admin) ' + 'via PIM-based JIT access.', + required=True) + with self.argument_context('ssh arc') as c: c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the Arc Server') c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path') diff --git a/src/ssh/azext_ssh/commands.py b/src/ssh/azext_ssh/commands.py index 335ba21eca6..969987f0a25 100644 --- a/src/ssh/azext_ssh/commands.py +++ b/src/ssh/azext_ssh/commands.py @@ -11,3 +11,4 @@ def load_command_table(self, _): g.custom_command('config', 'ssh_config') g.custom_command('cert', 'ssh_cert') g.custom_command('arc', 'ssh_arc') + g.custom_command('cert-create', 'ssh_cert_create') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 0123fff6121..8cbd3d526ef 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -152,6 +152,99 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, p resource_type, ssh_proxy_folder, winrdp, yes_without_prompt, ssh_args) +def ssh_cert_create(cmd, vault_name, resource_id): + """Create a short-lived SSH certificate signed by a private CA in Key Vault. + + Step 1 - Prepare certificate metadata: + Generate ephemeral key pair, resolve PIM role, build metadata: + { user_public_key, username, role, expiry } + Expiry is derived from the PIM activation's remaining duration. + + Step 2 - Sign via Key Vault: + Send signing request to Key Vault Sign API using az login context. + CA private key never leaves Key Vault. + + Step 3 - Return to user: + Signed SSH user certificate + freshly generated user_private key. + """ + from . import provisioned_machine_utils as pm + + telemetry.set_command_details('ssh cert-create') + + # Validate inputs. + pm.validate_resource_id(resource_id) + pm.validate_vault_name(vault_name) + + private_key_path = None + cert_path = None + try: + # -- Step 1: Prepare certificate metadata -------------------------- + # Derive username from az login (Entra) context. + username = pm.get_current_user_principal(cmd) + logger.info("Derived username: %s", username) + + # Verify the user has an active PIM assignment (JIT activated). + # startTime/endTime are derived from the PIM activation window. + _pim_instances, start_time, end_time = pm.check_pim_eligibility(cmd, resource_id) + logger.info("PIM eligibility confirmed for resource: %s (valid %s to %s)", + resource_id, start_time, end_time) + + # Resolve role from PIM assignment on the ProvisionedMachine resource. + # Reader role is blocked — only Contributor and Administrator can + # generate SSH certificates. + role = pm.resolve_user_role(cmd, resource_id) + logger.info("Resolved PIM role: %s", role) + + # Determine which certificate types this role can generate. + role_perms = pm.ROLE_PERMISSIONS.get(role, {}) + cert_types = role_perms.get("certificate_types", []) + logger.info("Allowed certificate types for %s: %s", role, cert_types) + + # Generate fresh ephemeral SSH key pair. + private_key_path, public_key_path = pm.generate_ephemeral_keypair() + with open(public_key_path, "r", encoding="utf-8") as f: + user_public_key = f.read().strip() + + certificate_metadata = { + "userPublicKey": user_public_key, + "username": username, + "role": role, + "startTime": start_time, + "endTime": end_time, + } + + # -- Step 2: Sign via Key Vault ------------------------------------ + # AZ CLI sends signing request using az login context. + # CA private key never leaves Key Vault. + signed_certificate = pm.sign_certificate_metadata( + cmd, vault_name, certificate_metadata + ) + cert_path = signed_certificate["certificatePath"] + + # -- Step 3: Return to user ---------------------------------------- + result = { + "privateKeyPath": private_key_path, + "certificatePath": cert_path, + } + + print_styled_text((Style.SUCCESS, + f"SSH certificate created successfully.\n" + f" Private key : {private_key_path}\n" + f" Certificate : {cert_path}")) + + logger.warning("The private key at %s is sensitive. " + "Delete it once the certificate expires.", + os.path.dirname(private_key_path)) + + telemetry.set_success() + return result + + except Exception: + # Clean up sensitive ephemeral files on failure. + pm.cleanup_ephemeral_files(private_key_path, cert_path) + raise + + def _do_ssh_op(cmd, op_info, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys if not op_info.is_arc(): diff --git a/src/ssh/azext_ssh/provisioned_machine_utils.py b/src/ssh/azext_ssh/provisioned_machine_utils.py new file mode 100644 index 00000000000..40131461319 --- /dev/null +++ b/src/ssh/azext_ssh/provisioned_machine_utils.py @@ -0,0 +1,603 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Utilities for the `az ssh cert-create` command. + +Step 1 – Prepare certificate metadata: + user_public_key + username + role (from PIM assignment) + expiry + +Step 2 – Sign via Key Vault: + AZ CLI sends signing request to Key Vault Sign API using az login context. + CA private key never leaves Key Vault. + +Step 3 – Return to user: + Signed SSH user certificate + freshly generated user_private key. + +Expiry constraint: maximum 8 hours (enforced at device level). +""" + +import base64 +import datetime +import hashlib +import json +import os +import re +import stat +import subprocess +import tempfile + +import oschmod +import requests +from knack import log +from azure.cli.core import azclierror + +logger = log.get_logger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +_KV_SIGN_API_VERSION = "7.4" +_KV_CA_KEY_NAME = "ssh-ca" +_KV_SIGN_ALGORITHM = "RS256" +_KV_RESOURCE = "https://vault.azure.net" +_KV_SIGN_TIMEOUT_SECONDS = 30 +_RSA_KEY_BITS = 4096 +_PRIVATE_KEY_FILE_PERMISSION = 0o600 # owner read/write only +_RESOURCE_ID_PATTERN = re.compile( + r"^/subscriptions/[0-9a-fA-F-]+/resourceGroups/[^/]+/providers/[^/]+/[^/]+/[^/]+$", + re.IGNORECASE, +) +_VAULT_NAME_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$") + +# Standard roles for ProvisionedMachine resources. +# Maps Azure RBAC role name substrings to a canonical role label. +PROVISIONED_MACHINE_ROLES = { + "reader": "Reader", + "contributor": "Contributor", + "owner": "Owner", +} + +# Permissions matrix per role. +# Each role defines which certificate types it can generate and what +# capabilities the device should grant. +# +# Certificates are generated for ALL roles. Access restrictions are enforced +# on the device side, not by the CLI. +# Reader - View-only in portal; device blocks SSH access. +# Contributor - Config app + SSH (non-sudo). +# Owner - Config app + SSH (non-sudo) + SSH (sudo). +ROLE_PERMISSIONS = { + "Reader": { + "ssh_allowed": False, + "certificate_types": [], + "portal": ["view"], + }, + "Contributor": { + "ssh_allowed": True, + "certificate_types": ["config-app", "ssh-non-sudo"], + "portal": ["view", "create", "manage-updates", "manage-nics", + "collect-logs", "manage-networking"], + }, + "Owner": { + "ssh_allowed": True, + "certificate_types": ["config-app", "ssh-non-sudo", "ssh-sudo"], + "portal": ["view", "create", "manage-updates", "manage-nics", + "manage-disks", "reset-device", "keyvault-access", + "delete", "grant-access", "pim-setup", + "collect-logs", "manage-networking"], + }, +} + + +# --------------------------------------------------------------------------- +# Input validation +# --------------------------------------------------------------------------- + +def validate_resource_id(resource_id): + """Validate that *resource_id* looks like a fully-qualified ARM resource ID.""" + if not resource_id or not _RESOURCE_ID_PATTERN.match(resource_id): + raise azclierror.InvalidArgumentValueError( + f"'{resource_id}' is not a valid ARM resource ID. " + "Expected format: /subscriptions//resourceGroups//" + "providers///" + ) + + +def validate_vault_name(vault_name): + """Validate that *vault_name* conforms to Key Vault naming rules.""" + if not vault_name or not _VAULT_NAME_PATTERN.match(vault_name): + raise azclierror.InvalidArgumentValueError( + f"'{vault_name}' is not a valid Key Vault name. " + "It must be 3-24 characters, start with a letter, end with a " + "letter or digit, and contain only letters, digits, and hyphens." + ) + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + +def get_current_user_principal(cmd): + """Return the UPN (or app ID) of the currently signed-in identity. + + The value is derived from the Entra / Azure CLI login context so the + caller never needs to supply it manually. + """ + from azure.cli.core._profile import Profile + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + user = profile.get_current_account_user() + except Exception as ex: + raise azclierror.AuthenticationError( + "Unable to determine the signed-in user. " + "Please run 'az login' first." + ) from ex + if not user: + raise azclierror.AuthenticationError( + "No signed-in user found. Please run 'az login' first." + ) + return user + + +def check_pim_eligibility(cmd, resource_id): + """Verify the current user has an **active** PIM role assignment on *resource_id*. + + Queries the PIM Role Assignment Schedule Instances API to confirm that + the user has activated JIT access. Raises ``AuthenticationError`` if + no active PIM assignment is found. + + Returns the list of active PIM schedule instances. + """ + from azure.cli.core._profile import Profile + + user_object_id = _get_current_user_object_id(cmd) + profile = Profile(cli_ctx=cmd.cli_ctx) + + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token("https://management.azure.com/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + "Failed to acquire a management token for PIM eligibility check." + ) from ex + + # Query active PIM role assignment schedule instances scoped to the resource. + api_version = "2020-10-01" + url = ( + f"https://management.azure.com{resource_id}" + f"/providers/Microsoft.Authorization" + f"/roleAssignmentScheduleInstances" + f"?api-version={api_version}" + f"&$filter=assignedTo('{user_object_id}')" + ) + + try: + resp = requests.get( + url, + headers={"Authorization": f"Bearer {token.token}"}, + timeout=30, + ) + except requests.exceptions.RequestException as ex: + raise azclierror.CLIInternalError( + f"Failed to query PIM eligibility on '{resource_id}': {ex}" + ) from ex + + if resp.status_code == 404: + raise azclierror.ResourceNotFoundError( + f"Resource '{resource_id}' was not found. Verify the resource ID is correct." + ) + if resp.status_code != 200: + raise azclierror.CLIInternalError( + f"PIM eligibility check failed (HTTP {resp.status_code}): {resp.text}" + ) + + data = resp.json() + instances = data.get("value", []) + + # Filter for PIM-activated assignments only. + # assignmentType == "Activated" means the user activated JIT access via PIM. + # assignmentType == "Assigned" means a permanent/direct role assignment, + # which should NOT be accepted — PIM activation is required. + pim_activated = [ + inst for inst in instances + if (inst.get("properties", {}).get("assignmentType", "")).lower() == "activated" + ] + + if not pim_activated: + has_direct = len(instances) > 0 + if has_direct: + raise azclierror.AuthenticationError( + f"You have a direct (permanent) role assignment on resource " + f"'{resource_id}', but PIM-based JIT activation is required. " + f"Direct role assignments are not accepted for SSH certificate " + f"generation. Please activate your role via PIM first:\n" + f" 1. Go to Azure Portal → Privileged Identity Management → My roles\n" + f" 2. Select Azure resources → find your eligible role\n" + f" 3. Click 'Activate' and provide a justification\n" + f" 4. Wait 1-2 minutes for propagation, then retry." + ) + raise azclierror.AuthenticationError( + f"No active PIM role assignment found for the current user on resource " + f"'{resource_id}'. Please activate your PIM-eligible role first:\n" + f" 1. Go to Azure Portal → Privileged Identity Management → My roles\n" + f" 2. Select Azure resources → find your eligible role\n" + f" 3. Click 'Activate' and provide a justification\n" + f" 4. Wait 1-2 minutes for propagation, then retry." + ) + + # Extract the expiry from the PIM activation's endDateTime. + # Use the latest endDateTime among all activated assignments. + now_utc = datetime.datetime.now(datetime.timezone.utc) + latest_end = None + for inst in pim_activated: + end_str = inst.get("properties", {}).get("endDateTime") + if end_str: + try: + end_dt = datetime.datetime.fromisoformat(end_str.replace("Z", "+00:00")) + if latest_end is None or end_dt > latest_end: + latest_end = end_dt + except (ValueError, TypeError): + logger.debug("Could not parse endDateTime '%s'.", end_str) + + if latest_end is None: + raise azclierror.CLIInternalError( + "PIM activation found but endDateTime is missing. " + "Cannot determine certificate expiry." + ) + + if latest_end <= now_utc: + raise azclierror.AuthenticationError( + f"Your PIM activation has expired (ended {latest_end.isoformat()}). " + f"Please re-activate your PIM-eligible role and retry." + ) + + start_time = now_utc.strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = latest_end.strftime("%Y-%m-%dT%H:%M:%SZ") + + remaining_hours = (latest_end - now_utc).total_seconds() / 3600.0 + logger.info("Found %d PIM-activated assignment(s) for user '%s' on '%s'. " + "Remaining: %.2f hours (until %s).", + len(pim_activated), user_object_id, resource_id, + remaining_hours, latest_end.isoformat()) + return pim_activated, start_time, end_time + + +def resolve_user_role(cmd, resource_id): + """Determine the highest-privilege role the signed-in user holds on *resource_id*. + + Uses the Azure Authorization SDK to list role assignments scoped to the + ProvisionedMachine resource and maps them to Reader / Contributor / Owner. + + Returns one of ``Contributor`` or ``Owner``. + + Raises ``AuthenticationError`` if: + - No relevant assignment is found. + - The highest role is Reader (Reader has no SSH access). + """ + from azure.cli.core.commands.client_factory import get_mgmt_service_client + + try: + from azure.mgmt.authorization import AuthorizationManagementClient + except ImportError as ex: + raise azclierror.CLIInternalError( + "The 'azure-mgmt-authorization' package is required. " + "Please run: pip install azure-mgmt-authorization" + ) from ex + + user_object_id = _get_current_user_object_id(cmd) + + try: + auth_client = get_mgmt_service_client(cmd.cli_ctx, AuthorizationManagementClient) + assignments = list(auth_client.role_assignments.list_for_scope( + scope=resource_id, + filter=f"assignedTo('{user_object_id}')" + )) + except Exception as ex: + raise azclierror.CLIInternalError( + f"Failed to query role assignments on '{resource_id}': {ex}" + ) from ex + + if not assignments: + raise azclierror.AuthenticationError( + f"No role assignments found for the current user on resource " + f"'{resource_id}'. Ensure PIM-based JIT access has been activated." + ) + + role_priority = {"Owner": 3, "Contributor": 2, "Reader": 1} + best_role = None + best_priority = 0 + + for assignment in assignments: + role_def_id = assignment.role_definition_id + try: + role_def = auth_client.role_definitions.get_by_id(role_def_id) + except Exception: # pylint: disable=broad-except + logger.debug("Skipping role definition '%s' (could not resolve).", role_def_id) + continue + role_name = (role_def.role_name or "").lower() + + for key, standard in PROVISIONED_MACHINE_ROLES.items(): + if key in role_name: + priority = role_priority.get(standard, 0) + if priority > best_priority: + best_role = standard + best_priority = priority + + if not best_role: + raise azclierror.AuthenticationError( + f"No Reader, Contributor, or Owner role assignment found for " + f"the current user on resource '{resource_id}'. Ensure PIM-based " + f"JIT access has been activated and the role is scoped to the " + f"ProvisionedMachine resource." + ) + + logger.info("Resolved role '%s' for user '%s' on resource '%s'.", + best_role, user_object_id, resource_id) + return best_role + + +def generate_ephemeral_keypair(ssh_client_folder=None): + """Generate a fresh RSA-4096 key pair in a secure temp directory. + + The private key file permissions are set to 0600 (owner read/write only). + + Returns ``(private_key_path, public_key_path)``. + """ + keys_dir = tempfile.mkdtemp(prefix="azssh_pm_") + private_key_path = os.path.join(keys_dir, "id_rsa.pem") + public_key_path = private_key_path + ".pub" + + keygen = _resolve_keygen(ssh_client_folder) + cmd_args = [ + keygen, "-t", "rsa", "-b", str(_RSA_KEY_BITS), + "-f", private_key_path, "-N", "", "-q", + ] + + try: + subprocess.check_call(cmd_args, timeout=30) # pylint: disable=subprocess-run-check + except FileNotFoundError as ex: + raise azclierror.CLIInternalError( + "ssh-keygen not found. Ensure OpenSSH is installed or provide " + "--ssh-client-folder." + ) from ex + except subprocess.TimeoutExpired as ex: + raise azclierror.CLIInternalError( + "ssh-keygen timed out while generating the key pair." + ) from ex + except subprocess.CalledProcessError as ex: + raise azclierror.CLIInternalError( + f"ssh-keygen exited with code {ex.returncode}." + ) from ex + + if not os.path.isfile(private_key_path) or not os.path.isfile(public_key_path): + raise azclierror.CLIInternalError( + "ssh-keygen completed but key files were not created." + ) + + # Restrict private key to owner-only access (cross-platform via oschmod). + oschmod.set_mode(private_key_path, _PRIVATE_KEY_FILE_PERMISSION) + + logger.info("Generated ephemeral key pair at %s", keys_dir) + return private_key_path, public_key_path + + +def cleanup_ephemeral_files(*file_paths): + """Best-effort removal of sensitive ephemeral files and their parent dirs.""" + for path in file_paths: + if not path: + continue + try: + parent = os.path.dirname(path) + # Remove all files in the temp directory. + if os.path.isdir(parent) and parent.startswith(tempfile.gettempdir()): + import shutil + shutil.rmtree(parent, ignore_errors=True) + elif os.path.isfile(path): + os.remove(path) + except OSError: + logger.debug("Failed to clean up '%s'.", path) + + +def sign_certificate_metadata(cmd, keyvault_name, metadata): + """Sign the certificate metadata with the CA private key in Key Vault. + + Metadata shape: { userPublicKey, username, role, startTime, endTime } + + The Key Vault hosts a non-exportable CA private key (named ``ssh-ca``). + AZ CLI sends the signing request using the az login context. + The CA private key never leaves Key Vault. + + Returns a dict with ``signedCertificate`` and ``certificatePath``. + """ + signing_payload = { + "userPublicKey": metadata["userPublicKey"], + "username": metadata["username"], + "role": metadata["role"], + "startTime": metadata["startTime"], + "endTime": metadata["endTime"], + } + + logger.info("Sending signing request to Key Vault '%s' (valid %s to %s) ...", + keyvault_name, metadata["startTime"], metadata["endTime"]) + + # Sign via Key Vault - CA private key never leaves the vault. + _signature, cert_data = _call_keyvault_sign(cmd, keyvault_name, signing_payload) + + # Write the signed SSH user certificate to a temp file. + cert_dir = tempfile.mkdtemp(prefix="azssh_cert_") + cert_path = os.path.join(cert_dir, "ssh-cert.pub") + with open(cert_path, "w", encoding="utf-8") as f: + f.write(cert_data) + oschmod.set_mode(cert_path, stat.S_IRUSR | stat.S_IWUSR) # 0600 + + # Uncomment below to write signing payload metadata for debugging/verification. + # metadata_path = os.path.join(cert_dir, "metadata.json") + # with open(metadata_path, "w", encoding="utf-8") as f: + # json.dump(signing_payload, f, indent=2) + # logger.info("Signing metadata written to %s", metadata_path) + + logger.info("Signed SSH user certificate written to %s", cert_path) + return { + "signedCertificate": cert_data, + "certificatePath": cert_path, + } + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _get_current_user_object_id(cmd): + """Return the OID of the currently signed-in user / service principal. + + Extracts the ``oid`` claim from the ARM access token. This avoids a + dependency on the deprecated ``azure-graphrbac`` SDK. + """ + from azure.cli.core._profile import Profile + + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token("https://management.azure.com/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + "Failed to acquire an access token. Please run 'az login'." + ) from ex + + # Decode without verification – we only need the 'oid' claim and the + # token was just issued by the CLI's own credential chain. + try: + # The token is a JWT; decode the payload (middle segment). + payload_segment = token.token.split(".")[1] + # Add padding if needed. + padded = payload_segment + "=" * (4 - len(payload_segment) % 4) + payload_bytes = base64.urlsafe_b64decode(padded) + claims = json.loads(payload_bytes) + except Exception as ex: + raise azclierror.CLIInternalError( + "Failed to decode the access token to extract the user object ID." + ) from ex + + oid = claims.get("oid") or claims.get("sub") + if not oid: + raise azclierror.CLIInternalError( + "The access token does not contain an 'oid' or 'sub' claim." + ) + return oid + + +def _call_keyvault_sign(cmd, keyvault_name, metadata): + """Call Key Vault REST API to sign the metadata. + + Uses the ``ssh-ca`` key in the vault to perform an RS256 sign operation. + The CA private key never leaves Key Vault. + + Returns a tuple of ``(signature_b64, certificate_string)``. + """ + from azure.cli.core._profile import Profile + + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token(f"{_KV_RESOURCE}/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + f"Failed to acquire a Key Vault access token for vault " + f"'{keyvault_name}'. Ensure you have 'Key Sign' permissions " + f"on the vault. Error: {ex}" + ) from ex + + vault_url = f"https://{keyvault_name}.vault.azure.net" + sign_url = (f"{vault_url}/keys/{_KV_CA_KEY_NAME}/sign" + f"?api-version={_KV_SIGN_API_VERSION}") + + headers = { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json", + } + + request_body = { + "alg": _KV_SIGN_ALGORITHM, + "value": _build_signing_payload(metadata), + } + + try: + response = requests.post( + sign_url, headers=headers, json=request_body, + timeout=_KV_SIGN_TIMEOUT_SECONDS, + ) + except requests.exceptions.Timeout as ex: + raise azclierror.CLIInternalError( + f"Key Vault signing request timed out after " + f"{_KV_SIGN_TIMEOUT_SECONDS}s. Please try again." + ) from ex + except requests.exceptions.ConnectionError as ex: + raise azclierror.CLIInternalError( + f"Unable to connect to Key Vault '{keyvault_name}'. " + f"Check network connectivity and vault name." + ) from ex + + if response.status_code == 401: + raise azclierror.AuthenticationError( + f"Access denied to Key Vault '{keyvault_name}'. " + f"Ensure the signed-in identity has 'Key Sign' permission " + f"on the '{_KV_CA_KEY_NAME}' key." + ) + if response.status_code == 404: + raise azclierror.ResourceNotFoundError( + f"Key '{_KV_CA_KEY_NAME}' not found in vault '{keyvault_name}'. " + f"Ensure the CA signing key exists." + ) + if response.status_code != 200: + raise azclierror.CLIInternalError( + f"Key Vault signing failed (HTTP {response.status_code}): " + f"{response.text}" + ) + + result = response.json() + signature_b64 = result.get("value") + if not signature_b64: + raise azclierror.CLIInternalError( + "Key Vault returned an empty signature. " + "Check the CA key configuration." + ) + + cert_data = _build_ssh_certificate(metadata, signature_b64) + return signature_b64, cert_data + + +def _build_signing_payload(metadata): + """Create a base64url-encoded SHA-256 digest of the metadata for Key Vault. + + Key Vault sign API expects a digest, not raw data. We compute + SHA-256(canonical JSON) and base64url-encode the result. + """ + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + digest = hashlib.sha256(canonical).digest() + return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + + +def _build_ssh_certificate(metadata, signature_b64): + """Construct an OpenSSH certificate from the metadata and KV signature. + + NOTE: In production the Key Vault / CA service would return a fully-formed + ``ssh-rsa-cert-v01@openssh.com`` certificate. This helper is a placeholder + that concatenates the public key with the CA-signed blob so that the + overall CLI flow can be tested end-to-end once the CA service contract is + finalized. + """ + # Placeholder – real implementation depends on the Device API / CA contract. + public_key = metadata["userPublicKey"] + # Return a synthetic certificate string that the device agent will validate. + return f"{public_key} {signature_b64}" + + +def _resolve_keygen(ssh_client_folder): + if ssh_client_folder: + return os.path.join(ssh_client_folder, "ssh-keygen") + return "ssh-keygen" diff --git a/src/ssh/azext_ssh/tests/latest/test_cert_create.py b/src/ssh/azext_ssh/tests/latest/test_cert_create.py new file mode 100644 index 00000000000..44499a00e88 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_cert_create.py @@ -0,0 +1,329 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Tests for the ``az ssh cert-create`` command (custom.ssh_cert_create). + +All external dependencies (RBAC, Key Vault, key generation) are mocked. +Tests verify the orchestration logic, input validation, PIM-derived expiry, +error handling, cleanup on failure, and correct return shape. +""" + +import os +import tempfile +import unittest +from unittest import mock + +from azure.cli.core import azclierror +from azext_ssh import custom + + +# Convenience: a valid ARM resource ID used across tests. +_VALID_RESOURCE_ID = ( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/myRG/providers/Microsoft.ProvisionedMachine" + "/machines/myDevice" +) +_VALID_VAULT = "myKeyVault" + + +def _make_cmd(): + """Return a mock CLI cmd object.""" + cmd = mock.Mock() + cmd.cli_ctx = mock.Mock() + return cmd + + +class TestSshCertCreateValidation(unittest.TestCase): + """Input validation tests.""" + + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_invalid_resource_id_raises(self, mock_validate_rid, mock_validate_vn): + mock_validate_rid.side_effect = azclierror.InvalidArgumentValueError("bad id") + cmd = _make_cmd() + + with self.assertRaises(azclierror.InvalidArgumentValueError): + custom.ssh_cert_create(cmd, _VALID_VAULT, "bad-id") + + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + def test_invalid_vault_name_raises(self, mock_validate_vn, mock_validate_rid): + mock_validate_vn.side_effect = azclierror.InvalidArgumentValueError("bad vault") + cmd = _make_cmd() + + with self.assertRaises(azclierror.InvalidArgumentValueError): + custom.ssh_cert_create(cmd, "-bad-vault!", _VALID_RESOURCE_ID) + + +class TestSshCertCreateHappyPath(unittest.TestCase): + """Full happy-path tests with all external calls mocked.""" + + def _setup_mocks(self): + self.keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + self.priv_path = os.path.join(self.keys_dir, "id_rsa.pem") + self.pub_path = self.priv_path + ".pub" + with open(self.priv_path, "w", encoding="utf-8") as f: + f.write("-----BEGIN PRIVATE KEY-----\nFAKE\n-----END PRIVATE KEY-----\n") + with open(self.pub_path, "w", encoding="utf-8") as f: + f.write("ssh-rsa AAAAFAKEPUBLICKEY user@host") + + self.cert_dir = tempfile.mkdtemp(prefix="azssh_cert_test_") + self.cert_path = os.path.join(self.cert_dir, "ssh-cert.pub") + with open(self.cert_path, "w", encoding="utf-8") as f: + f.write("ssh-rsa AAAAFAKEPUBLICKEY signed-blob") + + def tearDown(self): + import shutil + for d in [getattr(self, 'keys_dir', None), getattr(self, 'cert_dir', None)]: + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_happy_path(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "ssh-rsa AAAA signed-blob", + "certificatePath": self.cert_path, + } + + result = custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + # Verify return shape — only paths returned. + self.assertIn("privateKeyPath", result) + self.assertIn("certificatePath", result) + self.assertNotIn("signedCertificate", result) + self.assertNotIn("userPrivateKey", result) + self.assertNotIn("role", result) + self.assertNotIn("metadata", result) + + # Verify correct metadata was passed to sign. + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["username"], "user@contoso.com") + self.assertEqual(metadata["role"], "Contributor") + # startTime/endTime should come from PIM + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T14:00:00Z") + self.assertIn("ssh-rsa AAAAFAKEPUBLICKEY", metadata["userPublicKey"]) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_expiry_from_pim_short(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + """startTime/endTime should reflect a short PIM window (1.5h).""" + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T11:30:00Z") + mock_role.return_value = "Owner" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T11:30:00Z") + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_expiry_from_pim_long(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + """PIM with 7.25 hours remaining should pass through as-is.""" + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T17:15:00Z") + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T17:15:00Z") + + +class TestSshCertCreateCleanupOnFailure(unittest.TestCase): + """Verify ephemeral files are cleaned up when signing fails.""" + + @mock.patch('azext_ssh.provisioned_machine_utils.cleanup_ephemeral_files') + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_cleanup_on_sign_failure(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign, mock_cleanup): + cmd = _make_cmd() + + keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + priv = os.path.join(keys_dir, "id_rsa.pem") + pub = priv + ".pub" + for p in (priv, pub): + with open(p, "w") as f: + f.write("key-content") + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") + mock_role.return_value = "Owner" + mock_keygen.return_value = (priv, pub) + mock_sign.side_effect = azclierror.CLIInternalError("KV failed") + + with self.assertRaises(azclierror.CLIInternalError): + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + mock_cleanup.assert_called_once() + cleanup_args = mock_cleanup.call_args[0] + self.assertEqual(cleanup_args[0], priv) + + import shutil + shutil.rmtree(keys_dir, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.cleanup_ephemeral_files') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_cleanup_on_role_failure(self, mock_rid, mock_vn, mock_user, mock_pim, + mock_role, mock_cleanup): + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") + mock_role.side_effect = azclierror.AuthenticationError("no role") + + with self.assertRaises(azclierror.AuthenticationError): + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + mock_cleanup.assert_called_once_with(None, None) + + +class TestSshCertCreateRoleDerivation(unittest.TestCase): + """Verify the correct role flows through to the metadata.""" + + def _setup_mocks(self): + self.keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + self.priv_path = os.path.join(self.keys_dir, "id_rsa.pem") + self.pub_path = self.priv_path + ".pub" + with open(self.priv_path, "w") as f: + f.write("PRIVATE") + with open(self.pub_path, "w") as f: + f.write("ssh-rsa PUBLIC") + + self.cert_dir = tempfile.mkdtemp(prefix="azssh_cert_test_") + self.cert_path = os.path.join(self.cert_dir, "ssh-cert.pub") + with open(self.cert_path, "w") as f: + f.write("cert") + + def tearDown(self): + import shutil + for d in [getattr(self, 'keys_dir', None), getattr(self, 'cert_dir', None)]: + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_owner_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "admin@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T16:00:00Z") + mock_role.return_value = "Owner" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["role"], "Owner") + self.assertEqual(metadata["username"], "admin@contoso.com") + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_contributor_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, + mock_role, mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "dev@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T12:00:00Z") + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + result = custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["role"], "Contributor") + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T12:00:00Z") + self.assertEqual(result["privateKeyPath"], self.priv_path) + self.assertEqual(result["certificatePath"], self.cert_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py new file mode 100644 index 00000000000..7e3e9a5b432 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py @@ -0,0 +1,780 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import base64 +import hashlib +import json +import os +import shutil +import stat +import subprocess +import tempfile +import unittest +from unittest import mock + +from azure.cli.core import azclierror +from azext_ssh import provisioned_machine_utils as pm + + +class TestValidateResourceId(unittest.TestCase): + """Tests for validate_resource_id().""" + + def test_valid_resource_id(self): + valid_id = ( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/myRG/providers/Microsoft.ProvisionedMachine" + "/machines/myDevice" + ) + # Should not raise + pm.validate_resource_id(valid_id) + + def test_valid_resource_id_mixed_case(self): + valid_id = ( + "/Subscriptions/00000000-0000-0000-0000-000000000000" + "/ResourceGroups/myRG/Providers/Microsoft.Compute" + "/virtualMachines/myVM" + ) + pm.validate_resource_id(valid_id) + + def test_invalid_resource_id_empty(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("") + + def test_invalid_resource_id_none(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id(None) + + def test_invalid_resource_id_missing_subscriptions(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("/resourceGroups/myRG/providers/X/Y/Z") + + def test_invalid_resource_id_no_leading_slash(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id( + "subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/rg/providers/X/Y/Z" + ) + + def test_invalid_resource_id_extra_segments(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/rg/providers/X/Y/Z/extra/segment" + ) + + def test_invalid_resource_id_random_string(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("just-a-random-string") + + +class TestValidateVaultName(unittest.TestCase): + """Tests for validate_vault_name().""" + + def test_valid_vault_name_simple(self): + pm.validate_vault_name("myVault01") + + def test_valid_vault_name_with_hyphens(self): + pm.validate_vault_name("my-key-vault") + + def test_valid_vault_name_min_length(self): + pm.validate_vault_name("abc") # 3 chars + + def test_valid_vault_name_max_length(self): + pm.validate_vault_name("a" + "b" * 22 + "c") # 24 chars + + def test_invalid_vault_name_empty(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("") + + def test_invalid_vault_name_none(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name(None) + + def test_invalid_vault_name_starts_with_digit(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("1vault") + + def test_invalid_vault_name_starts_with_hyphen(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("-vault") + + def test_invalid_vault_name_ends_with_hyphen(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("vault-") + + def test_invalid_vault_name_too_short(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("ab") + + def test_invalid_vault_name_too_long(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("a" * 30) + + def test_invalid_vault_name_special_chars(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("vault_name!") + + +class TestGetCurrentUserPrincipal(unittest.TestCase): + """Tests for get_current_user_principal().""" + + @mock.patch('azure.cli.core._profile.Profile') + def test_returns_user(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.return_value = "user@contoso.com" + + result = pm.get_current_user_principal(cmd) + self.assertEqual(result, "user@contoso.com") + mock_profile_cls.assert_called_once_with(cli_ctx=cmd.cli_ctx) + + @mock.patch('azure.cli.core._profile.Profile') + def test_raises_when_no_user(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.return_value = None + + with self.assertRaises(azclierror.AuthenticationError): + pm.get_current_user_principal(cmd) + + @mock.patch('azure.cli.core._profile.Profile') + def test_raises_when_profile_throws(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.side_effect = Exception("not logged in") + + with self.assertRaises(azclierror.AuthenticationError): + pm.get_current_user_principal(cmd) + + +class TestGenerateEphemeralKeypair(unittest.TestCase): + """Tests for generate_ephemeral_keypair().""" + + @mock.patch('oschmod.set_mode') + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_success(self, mock_mkdtemp, mock_check_call, mock_isfile, mock_chmod): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = True + + priv, pub = pm.generate_ephemeral_keypair() + + expected_priv = os.path.join("/tmp/azssh_pm_test", "id_rsa.pem") + expected_pub = os.path.join("/tmp/azssh_pm_test", "id_rsa.pem.pub") + self.assertEqual(priv, expected_priv) + self.assertEqual(pub, expected_pub) + mock_check_call.assert_called_once() + # Verify ssh-keygen args + call_args = mock_check_call.call_args[0][0] + self.assertEqual(call_args[0], "ssh-keygen") + self.assertIn("-t", call_args) + self.assertIn("rsa", call_args) + self.assertIn("4096", call_args) + # Verify permissions were set to 0600 + mock_chmod.assert_called_once_with(expected_priv, 0o600) + + @mock.patch('oschmod.set_mode') + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_custom_ssh_client_folder(self, mock_mkdtemp, mock_check_call, + mock_isfile, mock_chmod): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = True + + pm.generate_ephemeral_keypair(ssh_client_folder="/custom/path") + + call_args = mock_check_call.call_args[0][0] + self.assertEqual(call_args[0], os.path.join("/custom/path", "ssh-keygen")) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_not_found(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = FileNotFoundError("not found") + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("ssh-keygen not found", str(ctx.exception)) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_timeout(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = subprocess.TimeoutExpired(cmd="keygen", timeout=30) + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("timed out", str(ctx.exception)) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_nonzero_exit(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="keygen" + ) + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("exited with code 1", str(ctx.exception)) + + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keys_not_created(self, mock_mkdtemp, mock_check_call, mock_isfile): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = False # Files don't exist after keygen + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("key files were not created", str(ctx.exception)) + + +class TestCleanupEphemeralFiles(unittest.TestCase): + """Tests for cleanup_ephemeral_files().""" + + def test_cleanup_files_in_temp_dir(self): + # Create real temp files to clean up + temp_dir = tempfile.mkdtemp(prefix="azssh_test_") + test_file = os.path.join(temp_dir, "test_key") + with open(test_file, "w") as f: + f.write("secret") + + pm.cleanup_ephemeral_files(test_file) + self.assertFalse(os.path.exists(temp_dir)) + + def test_cleanup_none_path(self): + # Should not raise + pm.cleanup_ephemeral_files(None) + + def test_cleanup_nonexistent_path(self): + # Should not raise + pm.cleanup_ephemeral_files("/nonexistent/path/key") + + def test_cleanup_multiple_paths(self): + dir1 = tempfile.mkdtemp(prefix="azssh_test_") + dir2 = tempfile.mkdtemp(prefix="azssh_test_") + f1 = os.path.join(dir1, "key1") + f2 = os.path.join(dir2, "key2") + for f in (f1, f2): + with open(f, "w") as fh: + fh.write("secret") + + pm.cleanup_ephemeral_files(f1, f2) + self.assertFalse(os.path.exists(dir1)) + self.assertFalse(os.path.exists(dir2)) + + +class TestCheckPimEligibility(unittest.TestCase): + """Tests for check_pim_eligibility().""" + + _RESOURCE_ID = "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z" + + def _mock_response(self, status_code, json_data): + resp = mock.Mock() + resp.status_code = status_code + resp.json.return_value = json_data + resp.text = json.dumps(json_data) + return resp + + def _setup_cmd_with_profile(self, mock_oid): + """Helper to set up cmd mock and profile mock for PIM tests.""" + cmd = mock.Mock() + cmd.cli_ctx = mock.Mock() + mock_oid.return_value = "oid-123" + + profile_mock = mock.Mock() + creds_mock = mock.Mock() + token_mock = mock.Mock() + token_mock.token = "fake-token" + creds_mock.get_token.return_value = token_mock + profile_mock.get_login_credentials.return_value = (creds_mock, None, None) + return cmd, profile_mock + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_pim_activated_passes(self, mock_profile_cls, mock_oid, mock_get): + """PIM-activated assignment should pass.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [{"properties": { + "assignmentType": "Activated", + "endDateTime": "2099-01-01T00:00:00Z" + }}] + }) + + instances, start_time, end_time = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertEqual(len(instances), 1) + self.assertIn("T", start_time) + self.assertIn("T", end_time) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_direct_assignment_blocked(self, mock_profile_cls, mock_oid, mock_get): + """Direct/permanent assignment (Assigned) should be rejected.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [{"properties": {"assignmentType": "Assigned"}}] + }) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertIn("direct (permanent)", str(ctx.exception)) + self.assertIn("PIM-based JIT activation is required", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_no_assignments_blocked(self, mock_profile_cls, mock_oid, mock_get): + """No assignments at all should be rejected.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, {"value": []}) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertIn("No active PIM role assignment", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_404_raises_not_found(self, mock_profile_cls, mock_oid, mock_get): + """Resource not found should raise ResourceNotFoundError.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(404, {"error": "not found"}) + + with self.assertRaises(azclierror.ResourceNotFoundError): + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_mixed_only_activated_returned(self, mock_profile_cls, mock_oid, mock_get): + """When both Activated and Assigned exist, only Activated should pass.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [ + {"properties": {"assignmentType": "Assigned"}}, + {"properties": {"assignmentType": "Activated", "endDateTime": "2099-01-01T00:00:00Z"}}, + {"properties": {"assignmentType": "Assigned"}}, + ] + }) + + instances, start_time, end_time = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertEqual(len(instances), 1) + self.assertIn("T", start_time) + self.assertIn("T", end_time) + + +class TestResolveUserRole(unittest.TestCase): + """Tests for resolve_user_role().""" + + def _make_assignment(self, role_def_id): + assignment = mock.Mock() + assignment.role_definition_id = role_def_id + return assignment + + def _make_role_def(self, role_name): + role_def = mock.Mock() + role_def.role_name = role_name + return role_def + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_owner(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Owner" + ) + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Owner") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_contributor(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Provisioned Machine Contributor" + ) + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Contributor") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_reader(self, mock_client_factory, mock_oid): + """Reader role should now succeed — restriction is device-side.""" + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Provisioned Machine Reader" + ) + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Reader") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_picks_highest_privilege(self, mock_client_factory, mock_oid): + """When user has both Reader and Owner, Owner should win.""" + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + a1 = self._make_assignment("role-def-1") + a2 = self._make_assignment("role-def-2") + auth_client.role_assignments.list_for_scope.return_value = [a1, a2] + auth_client.role_definitions.get_by_id.side_effect = [ + self._make_role_def("Reader"), + self._make_role_def("Owner"), + ] + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Owner") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_no_assignments_raises(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + auth_client.role_assignments.list_for_scope.return_value = [] + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("No role assignments", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_unrecognized_role_raises(self, mock_client_factory, mock_oid): + """Assignments exist but none match Reader/Contributor/Admin.""" + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Storage Blob Data Processor" + ) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("No Reader, Contributor, or Owner", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_query_failure_raises(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + auth_client.role_assignments.list_for_scope.side_effect = Exception("network error") + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("Failed to query role assignments", str(ctx.exception)) + + +class TestGetCurrentUserObjectId(unittest.TestCase): + """Tests for _get_current_user_object_id().""" + + def _make_jwt(self, claims): + """Build a fake JWT with the given claims dict.""" + header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').decode().rstrip("=") + payload = base64.urlsafe_b64encode( + json.dumps(claims).encode() + ).decode().rstrip("=") + signature = base64.urlsafe_b64encode(b"sig").decode().rstrip("=") + return f"{header}.{payload}.{signature}" + + @mock.patch('azure.cli.core._profile.Profile') + def test_extracts_oid(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"oid": "user-oid-123", "sub": "sub-456"}) + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + result = pm._get_current_user_object_id(cmd) + self.assertEqual(result, "user-oid-123") + + @mock.patch('azure.cli.core._profile.Profile') + def test_falls_back_to_sub(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"sub": "sub-456"}) + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + result = pm._get_current_user_object_id(cmd) + self.assertEqual(result, "sub-456") + + @mock.patch('azure.cli.core._profile.Profile') + def test_no_oid_or_sub_raises(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"name": "user"}) # no oid/sub + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + with self.assertRaises(azclierror.CLIInternalError): + pm._get_current_user_object_id(cmd) + + @mock.patch('azure.cli.core._profile.Profile') + def test_token_failure_raises(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_login_credentials.side_effect = Exception("not logged in") + + with self.assertRaises(azclierror.AuthenticationError): + pm._get_current_user_object_id(cmd) + + +class TestBuildSigningPayload(unittest.TestCase): + """Tests for _build_signing_payload().""" + + def test_produces_base64url_sha256_digest(self): + metadata = {"key": "value", "num": 42} + result = pm._build_signing_payload(metadata) + + # Verify it's a SHA-256 digest, base64url encoded, no padding + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + expected_digest = hashlib.sha256(canonical).digest() + expected = base64.urlsafe_b64encode(expected_digest).decode("ascii").rstrip("=") + self.assertEqual(result, expected) + + def test_deterministic(self): + metadata = {"b": 2, "a": 1} + r1 = pm._build_signing_payload(metadata) + r2 = pm._build_signing_payload(metadata) + self.assertEqual(r1, r2) + + def test_key_order_independent(self): + m1 = {"a": 1, "b": 2} + m2 = {"b": 2, "a": 1} + self.assertEqual(pm._build_signing_payload(m1), pm._build_signing_payload(m2)) + + +class TestCallKeyvaultSign(unittest.TestCase): + """Tests for _call_keyvault_sign().""" + + def _setup_cmd_with_creds(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + token = mock.Mock() + token.token = "fake-token" + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + return cmd + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_success(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + metadata = {"userPublicKey": "ssh-rsa AAAA...", "username": "user"} + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": "signed-value-b64"} + mock_requests.post.return_value = mock_response + + sig, cert = pm._call_keyvault_sign(cmd, "myVault", metadata) + + self.assertEqual(sig, "signed-value-b64") + self.assertIn("ssh-rsa AAAA...", cert) + # Verify URL is correct + post_url = mock_requests.post.call_args[0][0] + self.assertIn("myVault.vault.azure.net", post_url) + self.assertIn("/keys/ssh-ca/sign", post_url) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_401_raises_auth_error(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 401 + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("Access denied", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_404_raises_not_found(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.ResourceNotFoundError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("ssh-ca", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_500_raises_internal_error(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("HTTP 500", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_timeout_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + import requests as real_requests + mock_requests.post.side_effect = real_requests.exceptions.Timeout("timed out") + mock_requests.exceptions = real_requests.exceptions + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("timed out", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_connection_error_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + import requests as real_requests + mock_requests.post.side_effect = real_requests.exceptions.ConnectionError("dns fail") + mock_requests.exceptions = real_requests.exceptions + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("Unable to connect", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_empty_signature_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": ""} + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("empty signature", str(ctx.exception)) + + @mock.patch('azure.cli.core._profile.Profile') + def test_token_failure_raises_auth(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_login_credentials.side_effect = Exception("no token") + + with self.assertRaises(azclierror.AuthenticationError): + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + + +class TestSignCertificateMetadata(unittest.TestCase): + """Tests for sign_certificate_metadata().""" + + @mock.patch('oschmod.set_mode') + @mock.patch('azext_ssh.provisioned_machine_utils._call_keyvault_sign') + def test_writes_cert_file(self, mock_sign, mock_chmod): + cmd = mock.Mock() + mock_sign.return_value = ("sig_b64", "cert-content-here") + + metadata = { + "userPublicKey": "ssh-rsa AAAA", + "username": "user@contoso.com", + "role": "Owner", + "startTime": "2026-05-26T10:00:00Z", + "endTime": "2026-05-26T14:00:00Z", + } + + result = pm.sign_certificate_metadata( + cmd, "myVault", metadata + ) + + self.assertEqual(result["signedCertificate"], "cert-content-here") + self.assertTrue(os.path.isfile(result["certificatePath"])) + with open(result["certificatePath"], "r") as f: + self.assertEqual(f.read(), "cert-content-here") + + # Verify permissions were set + mock_chmod.assert_called_once() + + # Clean up + shutil.rmtree(os.path.dirname(result["certificatePath"]), + ignore_errors=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/verify.py b/verify.py new file mode 100644 index 00000000000..f5e79d89e4c --- /dev/null +++ b/verify.py @@ -0,0 +1,94 @@ +""" +Verify an SSH certificate generated by `az ssh cert-create`. + +Usage: + python verify.py --cert --metadata --ca-pub + +Where: + --cert Path to the certificate file (ssh-cert.pub) + --metadata Path to a JSON file with the signing payload metadata + --ca-pub Path to the CA public key (PEM). Download with: + az keyvault key download --vault-name --name ssh-ca --encoding PEM -f ca_public.pem +""" + +import argparse +import json +import hashlib +import base64 +import sys + +from cryptography.hazmat.primitives.asymmetric import padding, utils +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.exceptions import InvalidSignature + + +def verify_certificate(cert_path, metadata_path, ca_pub_path): + # 1. Load CA public key + with open(ca_pub_path, "rb") as f: + ca_pub = serialization.load_pem_public_key(f.read()) + print(f"[OK] Loaded CA public key from {ca_pub_path}") + + # 2. Load metadata (the signing payload) + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + print(f"[OK] Loaded metadata: {json.dumps(metadata, indent=2)}") + + # 3. Read certificate file — current format: " " + with open(cert_path, "r", encoding="utf-8") as f: + cert_content = f.read().strip() + + # Extract signature (last space-separated token) + parts = cert_content.rsplit(" ", 1) + if len(parts) != 2: + print("[FAIL] Certificate format unexpected — expected ' '") + return False + cert_public_key_part, signature_b64 = parts[0], parts[1] + print(f"[OK] Extracted signature from certificate ({len(signature_b64)} chars)") + + # 4. Verify public key in cert matches metadata + if metadata.get("userPublicKey", "").strip() != cert_public_key_part.strip(): + print("[FAIL] Public key in certificate does not match metadata userPublicKey") + return False + print("[OK] Public key in certificate matches metadata") + + # 5. Reconstruct the digest that was signed + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + digest = hashlib.sha256(canonical).digest() + print(f"[OK] Computed SHA-256 digest of canonical metadata") + + # 6. Decode the signature (base64url) + # Add padding if needed + sig_padded = signature_b64 + "=" * (4 - len(signature_b64) % 4) + sig_bytes = base64.urlsafe_b64decode(sig_padded) + print(f"[OK] Decoded signature ({len(sig_bytes)} bytes)") + + # 7. Verify: CA public key + RS256 (PKCS1v15 + SHA256 prehashed) + try: + ca_pub.verify( + sig_bytes, + digest, + padding.PKCS1v15(), + utils.Prehashed(hashes.SHA256()), + ) + print("[OK] Signature is VALID — certificate was signed by the CA key") + return True + except InvalidSignature: + print("[FAIL] Signature is INVALID — certificate was NOT signed by this CA key") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Verify an az ssh cert-create certificate") + parser.add_argument("--cert", required=True, help="Path to ssh-cert.pub") + parser.add_argument("--metadata", required=True, help="Path to metadata JSON file") + parser.add_argument("--ca-pub", required=True, help="Path to CA public key (PEM)") + args = parser.parse_args() + + ok = verify_certificate(args.cert, args.metadata, args.ca_pub) + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main() \ No newline at end of file