From af2240b35e8fdac05c9f067a408729bbe8e71053 Mon Sep 17 00:00:00 2001 From: Pushkar Srivastava Date: Mon, 25 May 2026 22:35:14 +0530 Subject: [PATCH 1/3] ssh cert-create extension for remote login into edgemachine --- src/ssh/azext_ssh/_help.py | 17 + src/ssh/azext_ssh/_params.py | 10 + src/ssh/azext_ssh/commands.py | 1 + src/ssh/azext_ssh/custom.py | 92 +++ .../azext_ssh/provisioned_machine_utils.py | 609 ++++++++++++++ .../tests/latest/test_cert_create.py | 318 +++++++ .../latest/test_provisioned_machine_utils.py | 779 ++++++++++++++++++ verify.py | 94 +++ 8 files changed, 1920 insertions(+) create mode 100644 src/ssh/azext_ssh/provisioned_machine_utils.py create mode 100644 src/ssh/azext_ssh/tests/latest/test_cert_create.py create mode 100644 src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py create mode 100644 verify.py diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index cf629ee65ea..6155486dd57 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -117,6 +117,23 @@ az ssh cert --file ./id_rsa-aadcert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH" """ +helps['ssh cert-create'] = """ + type: command + short-summary: Create a short-lived SSH certificate signed by a private CA key in Azure Key Vault. + long-summary: | + Generates an ephemeral SSH key pair, determines the caller's RBAC role + (Reader/Contributor/Owner) on the target ProvisionedMachine resource via + PIM-based JIT access, and sends the public key along with metadata + (userPublicKey, username, role, expiry) to Key Vault for signing. + The user identity is derived automatically from the Entra login context. + The certificate expiry is derived from the PIM activation's remaining duration. + Returns the signed SSH user certificate and the freshly generated private key. + examples: + - name: Create a certificate (expiry derived from PIM activation) + text: | + az ssh cert-create --vault-name myKeyVault --resource-id /subscriptions/.../providers/Microsoft.ProvisionedMachine/machines/myDevice +""" + helps['ssh arc'] = """ type: command short-summary: SSH into Azure Arc Servers diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index f0d22edc576..e0eff0a4620 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -83,6 +83,16 @@ def load_arguments(self, _): help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). ' 'Default to ssh pre-installed if not provided.') + with self.argument_context('ssh cert-create') as c: + c.argument('vault_name', options_list=['--vault-name', '-v'], + help='Name of the Azure Key Vault that holds the private CA signing key (ssh-ca).', + required=True) + c.argument('resource_id', options_list=['--resource-id', '-r'], + help='Fully qualified ARM resource ID of the ProvisionedMachine. ' + 'Used to determine the user\'s RBAC role (Reader/Contributor/Admin) ' + 'via PIM-based JIT access.', + required=True) + with self.argument_context('ssh arc') as c: c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the Arc Server') c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path') diff --git a/src/ssh/azext_ssh/commands.py b/src/ssh/azext_ssh/commands.py index 335ba21eca6..969987f0a25 100644 --- a/src/ssh/azext_ssh/commands.py +++ b/src/ssh/azext_ssh/commands.py @@ -11,3 +11,4 @@ def load_command_table(self, _): g.custom_command('config', 'ssh_config') g.custom_command('cert', 'ssh_cert') g.custom_command('arc', 'ssh_arc') + g.custom_command('cert-create', 'ssh_cert_create') diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 0123fff6121..f530c37730f 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -152,6 +152,98 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, p resource_type, ssh_proxy_folder, winrdp, yes_without_prompt, ssh_args) +def ssh_cert_create(cmd, vault_name, resource_id): + """Create a short-lived SSH certificate signed by a private CA in Key Vault. + + Step 1 - Prepare certificate metadata: + Generate ephemeral key pair, resolve PIM role, build metadata: + { user_public_key, username, role, expiry } + Expiry is derived from the PIM activation's remaining duration. + + Step 2 - Sign via Key Vault: + Send signing request to Key Vault Sign API using az login context. + CA private key never leaves Key Vault. + + Step 3 - Return to user: + Signed SSH user certificate + freshly generated user_private key. + """ + from . import provisioned_machine_utils as pm + + telemetry.set_command_details('ssh cert-create') + + # Validate inputs. + pm.validate_resource_id(resource_id) + pm.validate_vault_name(vault_name) + + private_key_path = None + cert_path = None + try: + # -- Step 1: Prepare certificate metadata -------------------------- + # Derive username from az login (Entra) context. + username = pm.get_current_user_principal(cmd) + logger.info("Derived username: %s", username) + + # Verify the user has an active PIM assignment (JIT activated). + # Expiry is derived from the PIM activation's remaining duration. + _pim_instances, expiry = pm.check_pim_eligibility(cmd, resource_id) + logger.info("PIM eligibility confirmed for resource: %s (%.2f hours remaining)", + resource_id, expiry) + + # Resolve role from PIM assignment on the ProvisionedMachine resource. + # Reader role is blocked — only Contributor and Administrator can + # generate SSH certificates. + role = pm.resolve_user_role(cmd, resource_id) + logger.info("Resolved PIM role: %s", role) + + # Determine which certificate types this role can generate. + role_perms = pm.ROLE_PERMISSIONS.get(role, {}) + cert_types = role_perms.get("certificate_types", []) + logger.info("Allowed certificate types for %s: %s", role, cert_types) + + # Generate fresh ephemeral SSH key pair. + private_key_path, public_key_path = pm.generate_ephemeral_keypair() + with open(public_key_path, "r", encoding="utf-8") as f: + user_public_key = f.read().strip() + + certificate_metadata = { + "userPublicKey": user_public_key, + "username": username, + "role": role, + "expiry": expiry, + } + + # -- Step 2: Sign via Key Vault ------------------------------------ + # AZ CLI sends signing request using az login context. + # CA private key never leaves Key Vault. + signed_certificate = pm.sign_certificate_metadata( + cmd, vault_name, certificate_metadata, resource_id + ) + cert_path = signed_certificate["certificatePath"] + + # -- Step 3: Return to user ---------------------------------------- + result = { + "privateKeyPath": private_key_path, + "certificatePath": cert_path, + } + + print_styled_text((Style.SUCCESS, + f"SSH certificate created successfully.\n" + f" Private key : {private_key_path}\n" + f" Certificate : {cert_path}")) + + logger.warning("The private key at %s is sensitive. " + "Delete it once the certificate expires.", + os.path.dirname(private_key_path)) + + telemetry.set_success() + return result + + except Exception: + # Clean up sensitive ephemeral files on failure. + pm.cleanup_ephemeral_files(private_key_path, cert_path) + raise + + def _do_ssh_op(cmd, op_info, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys if not op_info.is_arc(): diff --git a/src/ssh/azext_ssh/provisioned_machine_utils.py b/src/ssh/azext_ssh/provisioned_machine_utils.py new file mode 100644 index 00000000000..09e9b6bcedf --- /dev/null +++ b/src/ssh/azext_ssh/provisioned_machine_utils.py @@ -0,0 +1,609 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Utilities for the `az ssh cert-create` command. + +Step 1 – Prepare certificate metadata: + user_public_key + username + role (from PIM assignment) + expiry + +Step 2 – Sign via Key Vault: + AZ CLI sends signing request to Key Vault Sign API using az login context. + CA private key never leaves Key Vault. + +Step 3 – Return to user: + Signed SSH user certificate + freshly generated user_private key. + +Expiry constraint: maximum 8 hours (enforced at device level). +""" + +import base64 +import datetime +import hashlib +import json +import os +import re +import stat +import subprocess +import tempfile + +import oschmod +import requests +from knack import log +from azure.cli.core import azclierror + +logger = log.get_logger(__name__) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +_KV_SIGN_API_VERSION = "7.4" +_KV_CA_KEY_NAME = "ssh-ca" +_KV_SIGN_ALGORITHM = "RS256" +_KV_RESOURCE = "https://vault.azure.net" +_KV_SIGN_TIMEOUT_SECONDS = 30 +_RSA_KEY_BITS = 4096 +_PRIVATE_KEY_FILE_PERMISSION = 0o600 # owner read/write only +_RESOURCE_ID_PATTERN = re.compile( + r"^/subscriptions/[0-9a-fA-F-]+/resourceGroups/[^/]+/providers/[^/]+/[^/]+/[^/]+$", + re.IGNORECASE, +) +_VAULT_NAME_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9-]{1,22}[a-zA-Z0-9]$") + +# Standard roles for ProvisionedMachine resources. +# Maps Azure RBAC role name substrings to a canonical role label. +PROVISIONED_MACHINE_ROLES = { + "reader": "Reader", + "contributor": "Contributor", + "owner": "Owner", +} + +# Permissions matrix per role. +# Each role defines which certificate types it can generate and what +# capabilities the device should grant. +# +# Reader - View-only in portal; NO SSH access, NO certificate generation. +# Contributor - Config app + SSH (non-sudo). +# Owner - Config app + SSH (non-sudo) + SSH (sudo). +ROLE_PERMISSIONS = { + "Reader": { + "ssh_allowed": False, + "certificate_types": [], + "portal": ["view"], + }, + "Contributor": { + "ssh_allowed": True, + "certificate_types": ["config-app", "ssh-non-sudo"], + "portal": ["view", "create", "manage-updates", "manage-nics", + "collect-logs", "manage-networking"], + }, + "Owner": { + "ssh_allowed": True, + "certificate_types": ["config-app", "ssh-non-sudo", "ssh-sudo"], + "portal": ["view", "create", "manage-updates", "manage-nics", + "manage-disks", "reset-device", "keyvault-access", + "delete", "grant-access", "pim-setup", + "collect-logs", "manage-networking"], + }, +} + + +# --------------------------------------------------------------------------- +# Input validation +# --------------------------------------------------------------------------- + +def validate_resource_id(resource_id): + """Validate that *resource_id* looks like a fully-qualified ARM resource ID.""" + if not resource_id or not _RESOURCE_ID_PATTERN.match(resource_id): + raise azclierror.InvalidArgumentValueError( + f"'{resource_id}' is not a valid ARM resource ID. " + "Expected format: /subscriptions//resourceGroups//" + "providers///" + ) + + +def validate_vault_name(vault_name): + """Validate that *vault_name* conforms to Key Vault naming rules.""" + if not vault_name or not _VAULT_NAME_PATTERN.match(vault_name): + raise azclierror.InvalidArgumentValueError( + f"'{vault_name}' is not a valid Key Vault name. " + "It must be 3-24 characters, start with a letter, end with a " + "letter or digit, and contain only letters, digits, and hyphens." + ) + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + +def get_current_user_principal(cmd): + """Return the UPN (or app ID) of the currently signed-in identity. + + The value is derived from the Entra / Azure CLI login context so the + caller never needs to supply it manually. + """ + from azure.cli.core._profile import Profile + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + user = profile.get_current_account_user() + except Exception as ex: + raise azclierror.AuthenticationError( + "Unable to determine the signed-in user. " + "Please run 'az login' first." + ) from ex + if not user: + raise azclierror.AuthenticationError( + "No signed-in user found. Please run 'az login' first." + ) + return user + + +def check_pim_eligibility(cmd, resource_id): + """Verify the current user has an **active** PIM role assignment on *resource_id*. + + Queries the PIM Role Assignment Schedule Instances API to confirm that + the user has activated JIT access. Raises ``AuthenticationError`` if + no active PIM assignment is found. + + Returns the list of active PIM schedule instances. + """ + from azure.cli.core._profile import Profile + + user_object_id = _get_current_user_object_id(cmd) + profile = Profile(cli_ctx=cmd.cli_ctx) + + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token("https://management.azure.com/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + "Failed to acquire a management token for PIM eligibility check." + ) from ex + + # Query active PIM role assignment schedule instances scoped to the resource. + api_version = "2020-10-01" + url = ( + f"https://management.azure.com{resource_id}" + f"/providers/Microsoft.Authorization" + f"/roleAssignmentScheduleInstances" + f"?api-version={api_version}" + f"&$filter=assignedTo('{user_object_id}')" + ) + + try: + resp = requests.get( + url, + headers={"Authorization": f"Bearer {token.token}"}, + timeout=30, + ) + except requests.exceptions.RequestException as ex: + raise azclierror.CLIInternalError( + f"Failed to query PIM eligibility on '{resource_id}': {ex}" + ) from ex + + if resp.status_code == 404: + raise azclierror.ResourceNotFoundError( + f"Resource '{resource_id}' was not found. Verify the resource ID is correct." + ) + if resp.status_code != 200: + raise azclierror.CLIInternalError( + f"PIM eligibility check failed (HTTP {resp.status_code}): {resp.text}" + ) + + data = resp.json() + instances = data.get("value", []) + + # Filter for PIM-activated assignments only. + # assignmentType == "Activated" means the user activated JIT access via PIM. + # assignmentType == "Assigned" means a permanent/direct role assignment, + # which should NOT be accepted — PIM activation is required. + pim_activated = [ + inst for inst in instances + if (inst.get("properties", {}).get("assignmentType", "")).lower() == "activated" + ] + + if not pim_activated: + has_direct = len(instances) > 0 + if has_direct: + raise azclierror.AuthenticationError( + f"You have a direct (permanent) role assignment on resource " + f"'{resource_id}', but PIM-based JIT activation is required. " + f"Direct role assignments are not accepted for SSH certificate " + f"generation. Please activate your role via PIM first:\n" + f" 1. Go to Azure Portal → Privileged Identity Management → My roles\n" + f" 2. Select Azure resources → find your eligible role\n" + f" 3. Click 'Activate' and provide a justification\n" + f" 4. Wait 1-2 minutes for propagation, then retry." + ) + raise azclierror.AuthenticationError( + f"No active PIM role assignment found for the current user on resource " + f"'{resource_id}'. Please activate your PIM-eligible role first:\n" + f" 1. Go to Azure Portal → Privileged Identity Management → My roles\n" + f" 2. Select Azure resources → find your eligible role\n" + f" 3. Click 'Activate' and provide a justification\n" + f" 4. Wait 1-2 minutes for propagation, then retry." + ) + + # Extract the expiry from the PIM activation's endDateTime. + # Use the latest endDateTime among all activated assignments. + now_utc = datetime.datetime.now(datetime.timezone.utc) + latest_end = None + for inst in pim_activated: + end_str = inst.get("properties", {}).get("endDateTime") + if end_str: + try: + end_dt = datetime.datetime.fromisoformat(end_str.replace("Z", "+00:00")) + if latest_end is None or end_dt > latest_end: + latest_end = end_dt + except (ValueError, TypeError): + logger.debug("Could not parse endDateTime '%s'.", end_str) + + if latest_end is None: + raise azclierror.CLIInternalError( + "PIM activation found but endDateTime is missing. " + "Cannot determine certificate expiry." + ) + + remaining_hours = (latest_end - now_utc).total_seconds() / 3600.0 + if remaining_hours <= 0: + raise azclierror.AuthenticationError( + f"Your PIM activation has expired (ended {latest_end.isoformat()}). " + f"Please re-activate your PIM-eligible role and retry." + ) + + logger.info("Found %d PIM-activated assignment(s) for user '%s' on '%s'. " + "Remaining: %.2f hours (until %s).", + len(pim_activated), user_object_id, resource_id, + remaining_hours, latest_end.isoformat()) + return pim_activated, remaining_hours + + +def resolve_user_role(cmd, resource_id): + """Determine the highest-privilege role the signed-in user holds on *resource_id*. + + Uses the Azure Authorization SDK to list role assignments scoped to the + ProvisionedMachine resource and maps them to Reader / Contributor / Owner. + + Returns one of ``Contributor`` or ``Owner``. + + Raises ``AuthenticationError`` if: + - No relevant assignment is found. + - The highest role is Reader (Reader has no SSH access). + """ + from azure.cli.core.commands.client_factory import get_mgmt_service_client + + try: + from azure.mgmt.authorization import AuthorizationManagementClient + except ImportError as ex: + raise azclierror.CLIInternalError( + "The 'azure-mgmt-authorization' package is required. " + "Please run: pip install azure-mgmt-authorization" + ) from ex + + user_object_id = _get_current_user_object_id(cmd) + + try: + auth_client = get_mgmt_service_client(cmd.cli_ctx, AuthorizationManagementClient) + assignments = list(auth_client.role_assignments.list_for_scope( + scope=resource_id, + filter=f"assignedTo('{user_object_id}')" + )) + except Exception as ex: + raise azclierror.CLIInternalError( + f"Failed to query role assignments on '{resource_id}': {ex}" + ) from ex + + if not assignments: + raise azclierror.AuthenticationError( + f"No role assignments found for the current user on resource " + f"'{resource_id}'. Ensure PIM-based JIT access has been activated." + ) + + role_priority = {"Owner": 3, "Contributor": 2, "Reader": 1} + best_role = None + best_priority = 0 + + for assignment in assignments: + role_def_id = assignment.role_definition_id + try: + role_def = auth_client.role_definitions.get_by_id(role_def_id) + except Exception: # pylint: disable=broad-except + logger.debug("Skipping role definition '%s' (could not resolve).", role_def_id) + continue + role_name = (role_def.role_name or "").lower() + + for key, standard in PROVISIONED_MACHINE_ROLES.items(): + if key in role_name: + priority = role_priority.get(standard, 0) + if priority > best_priority: + best_role = standard + best_priority = priority + + if not best_role: + raise azclierror.AuthenticationError( + f"No Reader, Contributor, or Owner role assignment found for " + f"the current user on resource '{resource_id}'. Ensure PIM-based " + f"JIT access has been activated and the role is scoped to the " + f"ProvisionedMachine resource." + ) + + # Reader role has no SSH access — block certificate generation. + if best_role == "Reader": + raise azclierror.AuthenticationError( + f"Your highest role on '{resource_id}' is Reader. " + f"Reader role does not have SSH access and cannot generate certificates. " + f"You need at least Contributor role for SSH (non-sudo) access, " + f"or Owner role for SSH (sudo) access." + ) + + logger.info("Resolved role '%s' for user '%s' on resource '%s'.", + best_role, user_object_id, resource_id) + return best_role + + +def generate_ephemeral_keypair(ssh_client_folder=None): + """Generate a fresh RSA-4096 key pair in a secure temp directory. + + The private key file permissions are set to 0600 (owner read/write only). + + Returns ``(private_key_path, public_key_path)``. + """ + keys_dir = tempfile.mkdtemp(prefix="azssh_pm_") + private_key_path = os.path.join(keys_dir, "id_rsa.pem") + public_key_path = private_key_path + ".pub" + + keygen = _resolve_keygen(ssh_client_folder) + cmd_args = [ + keygen, "-t", "rsa", "-b", str(_RSA_KEY_BITS), + "-f", private_key_path, "-N", "", "-q", + ] + + try: + subprocess.check_call(cmd_args, timeout=30) # pylint: disable=subprocess-run-check + except FileNotFoundError as ex: + raise azclierror.CLIInternalError( + "ssh-keygen not found. Ensure OpenSSH is installed or provide " + "--ssh-client-folder." + ) from ex + except subprocess.TimeoutExpired as ex: + raise azclierror.CLIInternalError( + "ssh-keygen timed out while generating the key pair." + ) from ex + except subprocess.CalledProcessError as ex: + raise azclierror.CLIInternalError( + f"ssh-keygen exited with code {ex.returncode}." + ) from ex + + if not os.path.isfile(private_key_path) or not os.path.isfile(public_key_path): + raise azclierror.CLIInternalError( + "ssh-keygen completed but key files were not created." + ) + + # Restrict private key to owner-only access (cross-platform via oschmod). + oschmod.set_mode(private_key_path, _PRIVATE_KEY_FILE_PERMISSION) + + logger.info("Generated ephemeral key pair at %s", keys_dir) + return private_key_path, public_key_path + + +def cleanup_ephemeral_files(*file_paths): + """Best-effort removal of sensitive ephemeral files and their parent dirs.""" + for path in file_paths: + if not path: + continue + try: + parent = os.path.dirname(path) + # Remove all files in the temp directory. + if os.path.isdir(parent) and parent.startswith(tempfile.gettempdir()): + import shutil + shutil.rmtree(parent, ignore_errors=True) + elif os.path.isfile(path): + os.remove(path) + except OSError: + logger.debug("Failed to clean up '%s'.", path) + + +def sign_certificate_metadata(cmd, keyvault_name, metadata, resource_id): + """Sign the certificate metadata with the CA private key in Key Vault. + + Metadata shape: { userPublicKey, username, role, expiry } + + The Key Vault hosts a non-exportable CA private key (named ``ssh-ca``). + AZ CLI sends the signing request using the az login context. + The CA private key never leaves Key Vault. + + Returns a dict with ``signedCertificate`` and ``certificatePath``. + """ + expiry_hours = metadata["expiry"] + + signing_payload = { + "userPublicKey": metadata["userPublicKey"], + "username": metadata["username"], + "role": metadata["role"], + "expiry": expiry_hours, + "resourceId": resource_id, + } + + logger.info("Sending signing request to Key Vault '%s' (expiry %.2f hours) ...", + keyvault_name, expiry_hours) + + # Sign via Key Vault - CA private key never leaves the vault. + _signature, cert_data = _call_keyvault_sign(cmd, keyvault_name, signing_payload) + + # Write the signed SSH user certificate to a temp file. + cert_dir = tempfile.mkdtemp(prefix="azssh_cert_") + cert_path = os.path.join(cert_dir, "ssh-cert.pub") + with open(cert_path, "w", encoding="utf-8") as f: + f.write(cert_data) + oschmod.set_mode(cert_path, stat.S_IRUSR | stat.S_IWUSR) # 0600 + + # Write the signing payload metadata alongside the cert for verification. + metadata_path = os.path.join(cert_dir, "metadata.json") + with open(metadata_path, "w", encoding="utf-8") as f: + json.dump(signing_payload, f, indent=2) + logger.info("Signing metadata written to %s", metadata_path) + + logger.info("Signed SSH user certificate written to %s", cert_path) + return { + "signedCertificate": cert_data, + "certificatePath": cert_path, + } + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _get_current_user_object_id(cmd): + """Return the OID of the currently signed-in user / service principal. + + Extracts the ``oid`` claim from the ARM access token. This avoids a + dependency on the deprecated ``azure-graphrbac`` SDK. + """ + from azure.cli.core._profile import Profile + + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token("https://management.azure.com/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + "Failed to acquire an access token. Please run 'az login'." + ) from ex + + # Decode without verification – we only need the 'oid' claim and the + # token was just issued by the CLI's own credential chain. + try: + # The token is a JWT; decode the payload (middle segment). + payload_segment = token.token.split(".")[1] + # Add padding if needed. + padded = payload_segment + "=" * (4 - len(payload_segment) % 4) + payload_bytes = base64.urlsafe_b64decode(padded) + claims = json.loads(payload_bytes) + except Exception as ex: + raise azclierror.CLIInternalError( + "Failed to decode the access token to extract the user object ID." + ) from ex + + oid = claims.get("oid") or claims.get("sub") + if not oid: + raise azclierror.CLIInternalError( + "The access token does not contain an 'oid' or 'sub' claim." + ) + return oid + + +def _call_keyvault_sign(cmd, keyvault_name, metadata): + """Call Key Vault REST API to sign the metadata. + + Uses the ``ssh-ca`` key in the vault to perform an RS256 sign operation. + The CA private key never leaves Key Vault. + + Returns a tuple of ``(signature_b64, certificate_string)``. + """ + from azure.cli.core._profile import Profile + + profile = Profile(cli_ctx=cmd.cli_ctx) + try: + creds, _, _ = profile.get_login_credentials() + token = creds.get_token(f"{_KV_RESOURCE}/.default") + except Exception as ex: + raise azclierror.AuthenticationError( + f"Failed to acquire a Key Vault access token for vault " + f"'{keyvault_name}'. Ensure you have 'Key Sign' permissions " + f"on the vault. Error: {ex}" + ) from ex + + vault_url = f"https://{keyvault_name}.vault.azure.net" + sign_url = (f"{vault_url}/keys/{_KV_CA_KEY_NAME}/sign" + f"?api-version={_KV_SIGN_API_VERSION}") + + headers = { + "Authorization": f"Bearer {token.token}", + "Content-Type": "application/json", + } + + request_body = { + "alg": _KV_SIGN_ALGORITHM, + "value": _build_signing_payload(metadata), + } + + try: + response = requests.post( + sign_url, headers=headers, json=request_body, + timeout=_KV_SIGN_TIMEOUT_SECONDS, + ) + except requests.exceptions.Timeout as ex: + raise azclierror.CLIInternalError( + f"Key Vault signing request timed out after " + f"{_KV_SIGN_TIMEOUT_SECONDS}s. Please try again." + ) from ex + except requests.exceptions.ConnectionError as ex: + raise azclierror.CLIInternalError( + f"Unable to connect to Key Vault '{keyvault_name}'. " + f"Check network connectivity and vault name." + ) from ex + + if response.status_code == 401: + raise azclierror.AuthenticationError( + f"Access denied to Key Vault '{keyvault_name}'. " + f"Ensure the signed-in identity has 'Key Sign' permission " + f"on the '{_KV_CA_KEY_NAME}' key." + ) + if response.status_code == 404: + raise azclierror.ResourceNotFoundError( + f"Key '{_KV_CA_KEY_NAME}' not found in vault '{keyvault_name}'. " + f"Ensure the CA signing key exists." + ) + if response.status_code != 200: + raise azclierror.CLIInternalError( + f"Key Vault signing failed (HTTP {response.status_code}): " + f"{response.text}" + ) + + result = response.json() + signature_b64 = result.get("value") + if not signature_b64: + raise azclierror.CLIInternalError( + "Key Vault returned an empty signature. " + "Check the CA key configuration." + ) + + cert_data = _build_ssh_certificate(metadata, signature_b64) + return signature_b64, cert_data + + +def _build_signing_payload(metadata): + """Create a base64url-encoded SHA-256 digest of the metadata for Key Vault. + + Key Vault sign API expects a digest, not raw data. We compute + SHA-256(canonical JSON) and base64url-encode the result. + """ + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + digest = hashlib.sha256(canonical).digest() + return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + + +def _build_ssh_certificate(metadata, signature_b64): + """Construct an OpenSSH certificate from the metadata and KV signature. + + NOTE: In production the Key Vault / CA service would return a fully-formed + ``ssh-rsa-cert-v01@openssh.com`` certificate. This helper is a placeholder + that concatenates the public key with the CA-signed blob so that the + overall CLI flow can be tested end-to-end once the CA service contract is + finalized. + """ + # Placeholder – real implementation depends on the Device API / CA contract. + public_key = metadata["userPublicKey"] + # Return a synthetic certificate string that the device agent will validate. + return f"{public_key} {signature_b64}" + + +def _resolve_keygen(ssh_client_folder): + if ssh_client_folder: + return os.path.join(ssh_client_folder, "ssh-keygen") + return "ssh-keygen" diff --git a/src/ssh/azext_ssh/tests/latest/test_cert_create.py b/src/ssh/azext_ssh/tests/latest/test_cert_create.py new file mode 100644 index 00000000000..72c847bc80a --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_cert_create.py @@ -0,0 +1,318 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Tests for the ``az ssh cert-create`` command (custom.ssh_cert_create). + +All external dependencies (RBAC, Key Vault, key generation) are mocked. +Tests verify the orchestration logic, input validation, PIM-derived expiry, +error handling, cleanup on failure, and correct return shape. +""" + +import os +import tempfile +import unittest +from unittest import mock + +from azure.cli.core import azclierror +from azext_ssh import custom + + +# Convenience: a valid ARM resource ID used across tests. +_VALID_RESOURCE_ID = ( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/myRG/providers/Microsoft.ProvisionedMachine" + "/machines/myDevice" +) +_VALID_VAULT = "myKeyVault" + + +def _make_cmd(): + """Return a mock CLI cmd object.""" + cmd = mock.Mock() + cmd.cli_ctx = mock.Mock() + return cmd + + +class TestSshCertCreateValidation(unittest.TestCase): + """Input validation tests.""" + + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_invalid_resource_id_raises(self, mock_validate_rid, mock_validate_vn): + mock_validate_rid.side_effect = azclierror.InvalidArgumentValueError("bad id") + cmd = _make_cmd() + + with self.assertRaises(azclierror.InvalidArgumentValueError): + custom.ssh_cert_create(cmd, _VALID_VAULT, "bad-id") + + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + def test_invalid_vault_name_raises(self, mock_validate_vn, mock_validate_rid): + mock_validate_vn.side_effect = azclierror.InvalidArgumentValueError("bad vault") + cmd = _make_cmd() + + with self.assertRaises(azclierror.InvalidArgumentValueError): + custom.ssh_cert_create(cmd, "-bad-vault!", _VALID_RESOURCE_ID) + + +class TestSshCertCreateHappyPath(unittest.TestCase): + """Full happy-path tests with all external calls mocked.""" + + def _setup_mocks(self): + self.keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + self.priv_path = os.path.join(self.keys_dir, "id_rsa.pem") + self.pub_path = self.priv_path + ".pub" + with open(self.priv_path, "w", encoding="utf-8") as f: + f.write("-----BEGIN PRIVATE KEY-----\nFAKE\n-----END PRIVATE KEY-----\n") + with open(self.pub_path, "w", encoding="utf-8") as f: + f.write("ssh-rsa AAAAFAKEPUBLICKEY user@host") + + self.cert_dir = tempfile.mkdtemp(prefix="azssh_cert_test_") + self.cert_path = os.path.join(self.cert_dir, "ssh-cert.pub") + with open(self.cert_path, "w", encoding="utf-8") as f: + f.write("ssh-rsa AAAAFAKEPUBLICKEY signed-blob") + + def tearDown(self): + import shutil + for d in [getattr(self, 'keys_dir', None), getattr(self, 'cert_dir', None)]: + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_happy_path(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "ssh-rsa AAAA signed-blob", + "certificatePath": self.cert_path, + } + + result = custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + # Verify return shape — only paths returned. + self.assertIn("privateKeyPath", result) + self.assertIn("certificatePath", result) + self.assertNotIn("signedCertificate", result) + self.assertNotIn("userPrivateKey", result) + self.assertNotIn("role", result) + self.assertNotIn("metadata", result) + + # Verify correct metadata was passed to sign. + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["username"], "user@contoso.com") + self.assertEqual(metadata["role"], "Contributor") + # Expiry should come from PIM (4.0 hours remaining) + self.assertEqual(metadata["expiry"], 4.0) + self.assertIn("ssh-rsa AAAAFAKEPUBLICKEY", metadata["userPublicKey"]) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_expiry_from_pim_short(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + """Expiry should match the remaining PIM duration (1.5h).""" + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 1.5) + mock_role.return_value = "Owner" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["expiry"], 1.5) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_expiry_from_pim_long(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + """PIM with 7.25 hours remaining should pass through as-is.""" + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 7.25) + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["expiry"], 7.25) + + +class TestSshCertCreateCleanupOnFailure(unittest.TestCase): + """Verify ephemeral files are cleaned up when signing fails.""" + + @mock.patch('azext_ssh.provisioned_machine_utils.cleanup_ephemeral_files') + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_cleanup_on_sign_failure(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign, mock_cleanup): + cmd = _make_cmd() + + keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + priv = os.path.join(keys_dir, "id_rsa.pem") + pub = priv + ".pub" + for p in (priv, pub): + with open(p, "w") as f: + f.write("key-content") + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_role.return_value = "Owner" + mock_keygen.return_value = (priv, pub) + mock_sign.side_effect = azclierror.CLIInternalError("KV failed") + + with self.assertRaises(azclierror.CLIInternalError): + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + mock_cleanup.assert_called_once() + cleanup_args = mock_cleanup.call_args[0] + self.assertEqual(cleanup_args[0], priv) + + import shutil + shutil.rmtree(keys_dir, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.cleanup_ephemeral_files') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_cleanup_on_role_failure(self, mock_rid, mock_vn, mock_user, mock_pim, + mock_role, mock_cleanup): + cmd = _make_cmd() + + mock_user.return_value = "user@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_role.side_effect = azclierror.AuthenticationError("no role") + + with self.assertRaises(azclierror.AuthenticationError): + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + mock_cleanup.assert_called_once_with(None, None) + + +class TestSshCertCreateRoleDerivation(unittest.TestCase): + """Verify the correct role flows through to the metadata.""" + + def _setup_mocks(self): + self.keys_dir = tempfile.mkdtemp(prefix="azssh_test_") + self.priv_path = os.path.join(self.keys_dir, "id_rsa.pem") + self.pub_path = self.priv_path + ".pub" + with open(self.priv_path, "w") as f: + f.write("PRIVATE") + with open(self.pub_path, "w") as f: + f.write("ssh-rsa PUBLIC") + + self.cert_dir = tempfile.mkdtemp(prefix="azssh_cert_test_") + self.cert_path = os.path.join(self.cert_dir, "ssh-cert.pub") + with open(self.cert_path, "w") as f: + f.write("cert") + + def tearDown(self): + import shutil + for d in [getattr(self, 'keys_dir', None), getattr(self, 'cert_dir', None)]: + if d and os.path.exists(d): + shutil.rmtree(d, ignore_errors=True) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_owner_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, + mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "admin@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 6.0) + mock_role.return_value = "Owner" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["role"], "Owner") + self.assertEqual(metadata["username"], "admin@contoso.com") + self.assertEqual(metadata["expiry"], 6.0) + + @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') + @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') + @mock.patch('azext_ssh.provisioned_machine_utils.resolve_user_role') + @mock.patch('azext_ssh.provisioned_machine_utils.check_pim_eligibility') + @mock.patch('azext_ssh.provisioned_machine_utils.get_current_user_principal') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_vault_name') + @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') + def test_contributor_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, + mock_role, mock_keygen, mock_sign): + self._setup_mocks() + cmd = _make_cmd() + + mock_user.return_value = "dev@contoso.com" + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 2.0) + mock_role.return_value = "Contributor" + mock_keygen.return_value = (self.priv_path, self.pub_path) + mock_sign.return_value = { + "signedCertificate": "cert", + "certificatePath": self.cert_path, + } + + result = custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) + + metadata = mock_sign.call_args[0][2] + self.assertEqual(metadata["role"], "Contributor") + self.assertEqual(metadata["expiry"], 2.0) + self.assertEqual(result["privateKeyPath"], self.priv_path) + self.assertEqual(result["certificatePath"], self.cert_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py new file mode 100644 index 00000000000..ace32cef1fb --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py @@ -0,0 +1,779 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import base64 +import hashlib +import json +import os +import shutil +import stat +import subprocess +import tempfile +import unittest +from unittest import mock + +from azure.cli.core import azclierror +from azext_ssh import provisioned_machine_utils as pm + + +class TestValidateResourceId(unittest.TestCase): + """Tests for validate_resource_id().""" + + def test_valid_resource_id(self): + valid_id = ( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/myRG/providers/Microsoft.ProvisionedMachine" + "/machines/myDevice" + ) + # Should not raise + pm.validate_resource_id(valid_id) + + def test_valid_resource_id_mixed_case(self): + valid_id = ( + "/Subscriptions/00000000-0000-0000-0000-000000000000" + "/ResourceGroups/myRG/Providers/Microsoft.Compute" + "/virtualMachines/myVM" + ) + pm.validate_resource_id(valid_id) + + def test_invalid_resource_id_empty(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("") + + def test_invalid_resource_id_none(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id(None) + + def test_invalid_resource_id_missing_subscriptions(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("/resourceGroups/myRG/providers/X/Y/Z") + + def test_invalid_resource_id_no_leading_slash(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id( + "subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/rg/providers/X/Y/Z" + ) + + def test_invalid_resource_id_extra_segments(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id( + "/subscriptions/00000000-0000-0000-0000-000000000000" + "/resourceGroups/rg/providers/X/Y/Z/extra/segment" + ) + + def test_invalid_resource_id_random_string(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_resource_id("just-a-random-string") + + +class TestValidateVaultName(unittest.TestCase): + """Tests for validate_vault_name().""" + + def test_valid_vault_name_simple(self): + pm.validate_vault_name("myVault01") + + def test_valid_vault_name_with_hyphens(self): + pm.validate_vault_name("my-key-vault") + + def test_valid_vault_name_min_length(self): + pm.validate_vault_name("abc") # 3 chars + + def test_valid_vault_name_max_length(self): + pm.validate_vault_name("a" + "b" * 22 + "c") # 24 chars + + def test_invalid_vault_name_empty(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("") + + def test_invalid_vault_name_none(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name(None) + + def test_invalid_vault_name_starts_with_digit(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("1vault") + + def test_invalid_vault_name_starts_with_hyphen(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("-vault") + + def test_invalid_vault_name_ends_with_hyphen(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("vault-") + + def test_invalid_vault_name_too_short(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("ab") + + def test_invalid_vault_name_too_long(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("a" * 30) + + def test_invalid_vault_name_special_chars(self): + with self.assertRaises(azclierror.InvalidArgumentValueError): + pm.validate_vault_name("vault_name!") + + +class TestGetCurrentUserPrincipal(unittest.TestCase): + """Tests for get_current_user_principal().""" + + @mock.patch('azure.cli.core._profile.Profile') + def test_returns_user(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.return_value = "user@contoso.com" + + result = pm.get_current_user_principal(cmd) + self.assertEqual(result, "user@contoso.com") + mock_profile_cls.assert_called_once_with(cli_ctx=cmd.cli_ctx) + + @mock.patch('azure.cli.core._profile.Profile') + def test_raises_when_no_user(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.return_value = None + + with self.assertRaises(azclierror.AuthenticationError): + pm.get_current_user_principal(cmd) + + @mock.patch('azure.cli.core._profile.Profile') + def test_raises_when_profile_throws(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_current_account_user.side_effect = Exception("not logged in") + + with self.assertRaises(azclierror.AuthenticationError): + pm.get_current_user_principal(cmd) + + +class TestGenerateEphemeralKeypair(unittest.TestCase): + """Tests for generate_ephemeral_keypair().""" + + @mock.patch('oschmod.set_mode') + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_success(self, mock_mkdtemp, mock_check_call, mock_isfile, mock_chmod): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = True + + priv, pub = pm.generate_ephemeral_keypair() + + expected_priv = os.path.join("/tmp/azssh_pm_test", "id_rsa.pem") + expected_pub = os.path.join("/tmp/azssh_pm_test", "id_rsa.pem.pub") + self.assertEqual(priv, expected_priv) + self.assertEqual(pub, expected_pub) + mock_check_call.assert_called_once() + # Verify ssh-keygen args + call_args = mock_check_call.call_args[0][0] + self.assertEqual(call_args[0], "ssh-keygen") + self.assertIn("-t", call_args) + self.assertIn("rsa", call_args) + self.assertIn("4096", call_args) + # Verify permissions were set to 0600 + mock_chmod.assert_called_once_with(expected_priv, 0o600) + + @mock.patch('oschmod.set_mode') + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_custom_ssh_client_folder(self, mock_mkdtemp, mock_check_call, + mock_isfile, mock_chmod): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = True + + pm.generate_ephemeral_keypair(ssh_client_folder="/custom/path") + + call_args = mock_check_call.call_args[0][0] + self.assertEqual(call_args[0], os.path.join("/custom/path", "ssh-keygen")) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_not_found(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = FileNotFoundError("not found") + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("ssh-keygen not found", str(ctx.exception)) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_timeout(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = subprocess.TimeoutExpired(cmd="keygen", timeout=30) + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("timed out", str(ctx.exception)) + + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keygen_nonzero_exit(self, mock_mkdtemp, mock_check_call): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_check_call.side_effect = subprocess.CalledProcessError( + returncode=1, cmd="keygen" + ) + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("exited with code 1", str(ctx.exception)) + + @mock.patch('os.path.isfile') + @mock.patch('subprocess.check_call') + @mock.patch('tempfile.mkdtemp') + def test_keys_not_created(self, mock_mkdtemp, mock_check_call, mock_isfile): + mock_mkdtemp.return_value = "/tmp/azssh_pm_test" + mock_isfile.return_value = False # Files don't exist after keygen + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.generate_ephemeral_keypair() + self.assertIn("key files were not created", str(ctx.exception)) + + +class TestCleanupEphemeralFiles(unittest.TestCase): + """Tests for cleanup_ephemeral_files().""" + + def test_cleanup_files_in_temp_dir(self): + # Create real temp files to clean up + temp_dir = tempfile.mkdtemp(prefix="azssh_test_") + test_file = os.path.join(temp_dir, "test_key") + with open(test_file, "w") as f: + f.write("secret") + + pm.cleanup_ephemeral_files(test_file) + self.assertFalse(os.path.exists(temp_dir)) + + def test_cleanup_none_path(self): + # Should not raise + pm.cleanup_ephemeral_files(None) + + def test_cleanup_nonexistent_path(self): + # Should not raise + pm.cleanup_ephemeral_files("/nonexistent/path/key") + + def test_cleanup_multiple_paths(self): + dir1 = tempfile.mkdtemp(prefix="azssh_test_") + dir2 = tempfile.mkdtemp(prefix="azssh_test_") + f1 = os.path.join(dir1, "key1") + f2 = os.path.join(dir2, "key2") + for f in (f1, f2): + with open(f, "w") as fh: + fh.write("secret") + + pm.cleanup_ephemeral_files(f1, f2) + self.assertFalse(os.path.exists(dir1)) + self.assertFalse(os.path.exists(dir2)) + + +class TestCheckPimEligibility(unittest.TestCase): + """Tests for check_pim_eligibility().""" + + _RESOURCE_ID = "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z" + + def _mock_response(self, status_code, json_data): + resp = mock.Mock() + resp.status_code = status_code + resp.json.return_value = json_data + resp.text = json.dumps(json_data) + return resp + + def _setup_cmd_with_profile(self, mock_oid): + """Helper to set up cmd mock and profile mock for PIM tests.""" + cmd = mock.Mock() + cmd.cli_ctx = mock.Mock() + mock_oid.return_value = "oid-123" + + profile_mock = mock.Mock() + creds_mock = mock.Mock() + token_mock = mock.Mock() + token_mock.token = "fake-token" + creds_mock.get_token.return_value = token_mock + profile_mock.get_login_credentials.return_value = (creds_mock, None, None) + return cmd, profile_mock + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_pim_activated_passes(self, mock_profile_cls, mock_oid, mock_get): + """PIM-activated assignment should pass.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [{"properties": { + "assignmentType": "Activated", + "endDateTime": "2099-01-01T00:00:00Z" + }}] + }) + + instances, expiry_hours = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertEqual(len(instances), 1) + self.assertGreater(expiry_hours, 0) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_direct_assignment_blocked(self, mock_profile_cls, mock_oid, mock_get): + """Direct/permanent assignment (Assigned) should be rejected.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [{"properties": {"assignmentType": "Assigned"}}] + }) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertIn("direct (permanent)", str(ctx.exception)) + self.assertIn("PIM-based JIT activation is required", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_no_assignments_blocked(self, mock_profile_cls, mock_oid, mock_get): + """No assignments at all should be rejected.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, {"value": []}) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertIn("No active PIM role assignment", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_404_raises_not_found(self, mock_profile_cls, mock_oid, mock_get): + """Resource not found should raise ResourceNotFoundError.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(404, {"error": "not found"}) + + with self.assertRaises(azclierror.ResourceNotFoundError): + pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core._profile.Profile') + def test_mixed_only_activated_returned(self, mock_profile_cls, mock_oid, mock_get): + """When both Activated and Assigned exist, only Activated should pass.""" + cmd, profile_mock = self._setup_cmd_with_profile(mock_oid) + mock_profile_cls.return_value = profile_mock + + mock_get.return_value = self._mock_response(200, { + "value": [ + {"properties": {"assignmentType": "Assigned"}}, + {"properties": {"assignmentType": "Activated", "endDateTime": "2099-01-01T00:00:00Z"}}, + {"properties": {"assignmentType": "Assigned"}}, + ] + }) + + instances, expiry_hours = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + self.assertEqual(len(instances), 1) + self.assertGreater(expiry_hours, 0) + + +class TestResolveUserRole(unittest.TestCase): + """Tests for resolve_user_role().""" + + def _make_assignment(self, role_def_id): + assignment = mock.Mock() + assignment.role_definition_id = role_def_id + return assignment + + def _make_role_def(self, role_name): + role_def = mock.Mock() + role_def.role_name = role_name + return role_def + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_owner(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Owner" + ) + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Owner") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_contributor(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Provisioned Machine Contributor" + ) + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Contributor") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_resolves_reader(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Provisioned Machine Reader" + ) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("Reader", str(ctx.exception)) + self.assertIn("does not have SSH access", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_picks_highest_privilege(self, mock_client_factory, mock_oid): + """When user has both Reader and Owner, Owner should win.""" + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + a1 = self._make_assignment("role-def-1") + a2 = self._make_assignment("role-def-2") + auth_client.role_assignments.list_for_scope.return_value = [a1, a2] + auth_client.role_definitions.get_by_id.side_effect = [ + self._make_role_def("Reader"), + self._make_role_def("Owner"), + ] + + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Owner") + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_no_assignments_raises(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + auth_client.role_assignments.list_for_scope.return_value = [] + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("No role assignments", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_unrecognized_role_raises(self, mock_client_factory, mock_oid): + """Assignments exist but none match Reader/Contributor/Admin.""" + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + + assignment = self._make_assignment("role-def-1") + auth_client.role_assignments.list_for_scope.return_value = [assignment] + auth_client.role_definitions.get_by_id.return_value = self._make_role_def( + "Storage Blob Data Processor" + ) + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("No Reader, Contributor, or Owner", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') + @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') + def test_query_failure_raises(self, mock_client_factory, mock_oid): + cmd = mock.Mock() + mock_oid.return_value = "oid-123" + + auth_client = mock.Mock() + mock_client_factory.return_value = auth_client + auth_client.role_assignments.list_for_scope.side_effect = Exception("network error") + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertIn("Failed to query role assignments", str(ctx.exception)) + + +class TestGetCurrentUserObjectId(unittest.TestCase): + """Tests for _get_current_user_object_id().""" + + def _make_jwt(self, claims): + """Build a fake JWT with the given claims dict.""" + header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').decode().rstrip("=") + payload = base64.urlsafe_b64encode( + json.dumps(claims).encode() + ).decode().rstrip("=") + signature = base64.urlsafe_b64encode(b"sig").decode().rstrip("=") + return f"{header}.{payload}.{signature}" + + @mock.patch('azure.cli.core._profile.Profile') + def test_extracts_oid(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"oid": "user-oid-123", "sub": "sub-456"}) + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + result = pm._get_current_user_object_id(cmd) + self.assertEqual(result, "user-oid-123") + + @mock.patch('azure.cli.core._profile.Profile') + def test_falls_back_to_sub(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"sub": "sub-456"}) + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + result = pm._get_current_user_object_id(cmd) + self.assertEqual(result, "sub-456") + + @mock.patch('azure.cli.core._profile.Profile') + def test_no_oid_or_sub_raises(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + + token = mock.Mock() + token.token = self._make_jwt({"name": "user"}) # no oid/sub + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + + with self.assertRaises(azclierror.CLIInternalError): + pm._get_current_user_object_id(cmd) + + @mock.patch('azure.cli.core._profile.Profile') + def test_token_failure_raises(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_login_credentials.side_effect = Exception("not logged in") + + with self.assertRaises(azclierror.AuthenticationError): + pm._get_current_user_object_id(cmd) + + +class TestBuildSigningPayload(unittest.TestCase): + """Tests for _build_signing_payload().""" + + def test_produces_base64url_sha256_digest(self): + metadata = {"key": "value", "num": 42} + result = pm._build_signing_payload(metadata) + + # Verify it's a SHA-256 digest, base64url encoded, no padding + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + expected_digest = hashlib.sha256(canonical).digest() + expected = base64.urlsafe_b64encode(expected_digest).decode("ascii").rstrip("=") + self.assertEqual(result, expected) + + def test_deterministic(self): + metadata = {"b": 2, "a": 1} + r1 = pm._build_signing_payload(metadata) + r2 = pm._build_signing_payload(metadata) + self.assertEqual(r1, r2) + + def test_key_order_independent(self): + m1 = {"a": 1, "b": 2} + m2 = {"b": 2, "a": 1} + self.assertEqual(pm._build_signing_payload(m1), pm._build_signing_payload(m2)) + + +class TestCallKeyvaultSign(unittest.TestCase): + """Tests for _call_keyvault_sign().""" + + def _setup_cmd_with_creds(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + token = mock.Mock() + token.token = "fake-token" + mock_creds = mock.Mock() + mock_creds.get_token.return_value = token + mock_profile.get_login_credentials.return_value = (mock_creds, None, None) + return cmd + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_success(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + metadata = {"userPublicKey": "ssh-rsa AAAA...", "username": "user"} + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": "signed-value-b64"} + mock_requests.post.return_value = mock_response + + sig, cert = pm._call_keyvault_sign(cmd, "myVault", metadata) + + self.assertEqual(sig, "signed-value-b64") + self.assertIn("ssh-rsa AAAA...", cert) + # Verify URL is correct + post_url = mock_requests.post.call_args[0][0] + self.assertIn("myVault.vault.azure.net", post_url) + self.assertIn("/keys/ssh-ca/sign", post_url) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_401_raises_auth_error(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 401 + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.AuthenticationError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("Access denied", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_404_raises_not_found(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.ResourceNotFoundError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("ssh-ca", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_500_raises_internal_error(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("HTTP 500", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_timeout_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + import requests as real_requests + mock_requests.post.side_effect = real_requests.exceptions.Timeout("timed out") + mock_requests.exceptions = real_requests.exceptions + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("timed out", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_connection_error_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + import requests as real_requests + mock_requests.post.side_effect = real_requests.exceptions.ConnectionError("dns fail") + mock_requests.exceptions = real_requests.exceptions + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("Unable to connect", str(ctx.exception)) + + @mock.patch('azext_ssh.provisioned_machine_utils.requests') + @mock.patch('azure.cli.core._profile.Profile') + def test_empty_signature_raises(self, mock_profile_cls, mock_requests): + cmd = self._setup_cmd_with_creds(mock_profile_cls) + + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"value": ""} + mock_requests.post.return_value = mock_response + + with self.assertRaises(azclierror.CLIInternalError) as ctx: + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + self.assertIn("empty signature", str(ctx.exception)) + + @mock.patch('azure.cli.core._profile.Profile') + def test_token_failure_raises_auth(self, mock_profile_cls): + cmd = mock.Mock() + mock_profile = mock.Mock() + mock_profile_cls.return_value = mock_profile + mock_profile.get_login_credentials.side_effect = Exception("no token") + + with self.assertRaises(azclierror.AuthenticationError): + pm._call_keyvault_sign(cmd, "myVault", {"userPublicKey": "k"}) + + +class TestSignCertificateMetadata(unittest.TestCase): + """Tests for sign_certificate_metadata().""" + + @mock.patch('oschmod.set_mode') + @mock.patch('azext_ssh.provisioned_machine_utils._call_keyvault_sign') + def test_writes_cert_file(self, mock_sign, mock_chmod): + cmd = mock.Mock() + mock_sign.return_value = ("sig_b64", "cert-content-here") + + metadata = { + "userPublicKey": "ssh-rsa AAAA", + "username": "user@contoso.com", + "role": "Owner", + "expiry": 4.0, + } + + result = pm.sign_certificate_metadata( + cmd, "myVault", metadata, + "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z" + ) + + self.assertEqual(result["signedCertificate"], "cert-content-here") + self.assertTrue(os.path.isfile(result["certificatePath"])) + with open(result["certificatePath"], "r") as f: + self.assertEqual(f.read(), "cert-content-here") + + # Verify permissions were set + mock_chmod.assert_called_once() + + # Clean up + shutil.rmtree(os.path.dirname(result["certificatePath"]), + ignore_errors=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/verify.py b/verify.py new file mode 100644 index 00000000000..f5e79d89e4c --- /dev/null +++ b/verify.py @@ -0,0 +1,94 @@ +""" +Verify an SSH certificate generated by `az ssh cert-create`. + +Usage: + python verify.py --cert --metadata --ca-pub + +Where: + --cert Path to the certificate file (ssh-cert.pub) + --metadata Path to a JSON file with the signing payload metadata + --ca-pub Path to the CA public key (PEM). Download with: + az keyvault key download --vault-name --name ssh-ca --encoding PEM -f ca_public.pem +""" + +import argparse +import json +import hashlib +import base64 +import sys + +from cryptography.hazmat.primitives.asymmetric import padding, utils +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.exceptions import InvalidSignature + + +def verify_certificate(cert_path, metadata_path, ca_pub_path): + # 1. Load CA public key + with open(ca_pub_path, "rb") as f: + ca_pub = serialization.load_pem_public_key(f.read()) + print(f"[OK] Loaded CA public key from {ca_pub_path}") + + # 2. Load metadata (the signing payload) + with open(metadata_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + print(f"[OK] Loaded metadata: {json.dumps(metadata, indent=2)}") + + # 3. Read certificate file — current format: " " + with open(cert_path, "r", encoding="utf-8") as f: + cert_content = f.read().strip() + + # Extract signature (last space-separated token) + parts = cert_content.rsplit(" ", 1) + if len(parts) != 2: + print("[FAIL] Certificate format unexpected — expected ' '") + return False + cert_public_key_part, signature_b64 = parts[0], parts[1] + print(f"[OK] Extracted signature from certificate ({len(signature_b64)} chars)") + + # 4. Verify public key in cert matches metadata + if metadata.get("userPublicKey", "").strip() != cert_public_key_part.strip(): + print("[FAIL] Public key in certificate does not match metadata userPublicKey") + return False + print("[OK] Public key in certificate matches metadata") + + # 5. Reconstruct the digest that was signed + canonical = json.dumps( + metadata, separators=(",", ":"), sort_keys=True, ensure_ascii=True + ).encode("utf-8") + digest = hashlib.sha256(canonical).digest() + print(f"[OK] Computed SHA-256 digest of canonical metadata") + + # 6. Decode the signature (base64url) + # Add padding if needed + sig_padded = signature_b64 + "=" * (4 - len(signature_b64) % 4) + sig_bytes = base64.urlsafe_b64decode(sig_padded) + print(f"[OK] Decoded signature ({len(sig_bytes)} bytes)") + + # 7. Verify: CA public key + RS256 (PKCS1v15 + SHA256 prehashed) + try: + ca_pub.verify( + sig_bytes, + digest, + padding.PKCS1v15(), + utils.Prehashed(hashes.SHA256()), + ) + print("[OK] Signature is VALID — certificate was signed by the CA key") + return True + except InvalidSignature: + print("[FAIL] Signature is INVALID — certificate was NOT signed by this CA key") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Verify an az ssh cert-create certificate") + parser.add_argument("--cert", required=True, help="Path to ssh-cert.pub") + parser.add_argument("--metadata", required=True, help="Path to metadata JSON file") + parser.add_argument("--ca-pub", required=True, help="Path to CA public key (PEM)") + args = parser.parse_args() + + ok = verify_certificate(args.cert, args.metadata, args.ca_pub) + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main() \ No newline at end of file From a252112315aa0856f174582287e70c4498f1d923 Mon Sep 17 00:00:00 2001 From: Pushkar Srivastava Date: Tue, 26 May 2026 10:32:19 +0530 Subject: [PATCH 2/3] Reader role permission change --- src/ssh/azext_ssh/_help.py | 21 ++++++++++++++++--- .../azext_ssh/provisioned_machine_utils.py | 13 +++--------- .../latest/test_provisioned_machine_utils.py | 7 +++---- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index 6155486dd57..2c55b7a6563 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -122,9 +122,24 @@ short-summary: Create a short-lived SSH certificate signed by a private CA key in Azure Key Vault. long-summary: | Generates an ephemeral SSH key pair, determines the caller's RBAC role - (Reader/Contributor/Owner) on the target ProvisionedMachine resource via - PIM-based JIT access, and sends the public key along with metadata - (userPublicKey, username, role, expiry) to Key Vault for signing. + on the target ProvisionedMachine resource via PIM-based JIT access, and + sends the public key along with metadata (userPublicKey, username, role, + expiry) to Key Vault for signing. + + The user's role is NOT taken as input — it is resolved automatically from + the RBAC role assignment on the device resource. + + Currently the extension relies on the built-in Azure roles (Owner, + Contributor, Reader) because we do not yet have permission to create + custom roles. The final intended roles are: + - Provisioned Machine Administrator (full SSH with sudo) + - Provisioned Machine Contributor (SSH without sudo) + - Provisioned Machine Reader (view-only; SSH restricted on device) + Certificates are generated for all roles — access restrictions are + enforced on the device side, not by the CLI. + These custom roles are pending creation (Teodora, Eric — please help + finalize so the CLI extension can be completed). + The user identity is derived automatically from the Entra login context. The certificate expiry is derived from the PIM activation's remaining duration. Returns the signed SSH user certificate and the freshly generated private key. diff --git a/src/ssh/azext_ssh/provisioned_machine_utils.py b/src/ssh/azext_ssh/provisioned_machine_utils.py index 09e9b6bcedf..971d62aedde 100644 --- a/src/ssh/azext_ssh/provisioned_machine_utils.py +++ b/src/ssh/azext_ssh/provisioned_machine_utils.py @@ -64,7 +64,9 @@ # Each role defines which certificate types it can generate and what # capabilities the device should grant. # -# Reader - View-only in portal; NO SSH access, NO certificate generation. +# Certificates are generated for ALL roles. Access restrictions are enforced +# on the device side, not by the CLI. +# Reader - View-only in portal; device blocks SSH access. # Contributor - Config app + SSH (non-sudo). # Owner - Config app + SSH (non-sudo) + SSH (sudo). ROLE_PERMISSIONS = { @@ -329,15 +331,6 @@ def resolve_user_role(cmd, resource_id): f"ProvisionedMachine resource." ) - # Reader role has no SSH access — block certificate generation. - if best_role == "Reader": - raise azclierror.AuthenticationError( - f"Your highest role on '{resource_id}' is Reader. " - f"Reader role does not have SSH access and cannot generate certificates. " - f"You need at least Contributor role for SSH (non-sudo) access, " - f"or Owner role for SSH (sudo) access." - ) - logger.info("Resolved role '%s' for user '%s' on resource '%s'.", best_role, user_object_id, resource_id) return best_role diff --git a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py index ace32cef1fb..c0c20f8dd20 100644 --- a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py @@ -434,6 +434,7 @@ def test_resolves_contributor(self, mock_client_factory, mock_oid): @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') def test_resolves_reader(self, mock_client_factory, mock_oid): + """Reader role should now succeed — restriction is device-side.""" cmd = mock.Mock() mock_oid.return_value = "oid-123" @@ -446,10 +447,8 @@ def test_resolves_reader(self, mock_client_factory, mock_oid): "Provisioned Machine Reader" ) - with self.assertRaises(azclierror.AuthenticationError) as ctx: - pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") - self.assertIn("Reader", str(ctx.exception)) - self.assertIn("does not have SSH access", str(ctx.exception)) + result = pm.resolve_user_role(cmd, "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z") + self.assertEqual(result, "Reader") @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') @mock.patch('azure.cli.core.commands.client_factory.get_mgmt_service_client') From 148e34ce525f70589db68da67ffac37be6a947ca Mon Sep 17 00:00:00 2001 From: Pushkar Srivastava Date: Tue, 26 May 2026 14:08:09 +0530 Subject: [PATCH 3/3] changing expiry in metadat to start and end time --- src/ssh/azext_ssh/custom.py | 13 ++++--- .../azext_ssh/provisioned_machine_utils.py | 33 ++++++++-------- .../tests/latest/test_cert_create.py | 39 ++++++++++++------- .../latest/test_provisioned_machine_utils.py | 16 ++++---- 4 files changed, 58 insertions(+), 43 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index f530c37730f..8cbd3d526ef 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -184,10 +184,10 @@ def ssh_cert_create(cmd, vault_name, resource_id): logger.info("Derived username: %s", username) # Verify the user has an active PIM assignment (JIT activated). - # Expiry is derived from the PIM activation's remaining duration. - _pim_instances, expiry = pm.check_pim_eligibility(cmd, resource_id) - logger.info("PIM eligibility confirmed for resource: %s (%.2f hours remaining)", - resource_id, expiry) + # startTime/endTime are derived from the PIM activation window. + _pim_instances, start_time, end_time = pm.check_pim_eligibility(cmd, resource_id) + logger.info("PIM eligibility confirmed for resource: %s (valid %s to %s)", + resource_id, start_time, end_time) # Resolve role from PIM assignment on the ProvisionedMachine resource. # Reader role is blocked — only Contributor and Administrator can @@ -209,14 +209,15 @@ def ssh_cert_create(cmd, vault_name, resource_id): "userPublicKey": user_public_key, "username": username, "role": role, - "expiry": expiry, + "startTime": start_time, + "endTime": end_time, } # -- Step 2: Sign via Key Vault ------------------------------------ # AZ CLI sends signing request using az login context. # CA private key never leaves Key Vault. signed_certificate = pm.sign_certificate_metadata( - cmd, vault_name, certificate_metadata, resource_id + cmd, vault_name, certificate_metadata ) cert_path = signed_certificate["certificatePath"] diff --git a/src/ssh/azext_ssh/provisioned_machine_utils.py b/src/ssh/azext_ssh/provisioned_machine_utils.py index 971d62aedde..40131461319 100644 --- a/src/ssh/azext_ssh/provisioned_machine_utils.py +++ b/src/ssh/azext_ssh/provisioned_machine_utils.py @@ -248,18 +248,21 @@ def check_pim_eligibility(cmd, resource_id): "Cannot determine certificate expiry." ) - remaining_hours = (latest_end - now_utc).total_seconds() / 3600.0 - if remaining_hours <= 0: + if latest_end <= now_utc: raise azclierror.AuthenticationError( f"Your PIM activation has expired (ended {latest_end.isoformat()}). " f"Please re-activate your PIM-eligible role and retry." ) + start_time = now_utc.strftime("%Y-%m-%dT%H:%M:%SZ") + end_time = latest_end.strftime("%Y-%m-%dT%H:%M:%SZ") + + remaining_hours = (latest_end - now_utc).total_seconds() / 3600.0 logger.info("Found %d PIM-activated assignment(s) for user '%s' on '%s'. " "Remaining: %.2f hours (until %s).", len(pim_activated), user_object_id, resource_id, remaining_hours, latest_end.isoformat()) - return pim_activated, remaining_hours + return pim_activated, start_time, end_time def resolve_user_role(cmd, resource_id): @@ -398,10 +401,10 @@ def cleanup_ephemeral_files(*file_paths): logger.debug("Failed to clean up '%s'.", path) -def sign_certificate_metadata(cmd, keyvault_name, metadata, resource_id): +def sign_certificate_metadata(cmd, keyvault_name, metadata): """Sign the certificate metadata with the CA private key in Key Vault. - Metadata shape: { userPublicKey, username, role, expiry } + Metadata shape: { userPublicKey, username, role, startTime, endTime } The Key Vault hosts a non-exportable CA private key (named ``ssh-ca``). AZ CLI sends the signing request using the az login context. @@ -409,18 +412,16 @@ def sign_certificate_metadata(cmd, keyvault_name, metadata, resource_id): Returns a dict with ``signedCertificate`` and ``certificatePath``. """ - expiry_hours = metadata["expiry"] - signing_payload = { "userPublicKey": metadata["userPublicKey"], "username": metadata["username"], "role": metadata["role"], - "expiry": expiry_hours, - "resourceId": resource_id, + "startTime": metadata["startTime"], + "endTime": metadata["endTime"], } - logger.info("Sending signing request to Key Vault '%s' (expiry %.2f hours) ...", - keyvault_name, expiry_hours) + logger.info("Sending signing request to Key Vault '%s' (valid %s to %s) ...", + keyvault_name, metadata["startTime"], metadata["endTime"]) # Sign via Key Vault - CA private key never leaves the vault. _signature, cert_data = _call_keyvault_sign(cmd, keyvault_name, signing_payload) @@ -432,11 +433,11 @@ def sign_certificate_metadata(cmd, keyvault_name, metadata, resource_id): f.write(cert_data) oschmod.set_mode(cert_path, stat.S_IRUSR | stat.S_IWUSR) # 0600 - # Write the signing payload metadata alongside the cert for verification. - metadata_path = os.path.join(cert_dir, "metadata.json") - with open(metadata_path, "w", encoding="utf-8") as f: - json.dump(signing_payload, f, indent=2) - logger.info("Signing metadata written to %s", metadata_path) + # Uncomment below to write signing payload metadata for debugging/verification. + # metadata_path = os.path.join(cert_dir, "metadata.json") + # with open(metadata_path, "w", encoding="utf-8") as f: + # json.dump(signing_payload, f, indent=2) + # logger.info("Signing metadata written to %s", metadata_path) logger.info("Signed SSH user certificate written to %s", cert_path) return { diff --git a/src/ssh/azext_ssh/tests/latest/test_cert_create.py b/src/ssh/azext_ssh/tests/latest/test_cert_create.py index 72c847bc80a..44499a00e88 100644 --- a/src/ssh/azext_ssh/tests/latest/test_cert_create.py +++ b/src/ssh/azext_ssh/tests/latest/test_cert_create.py @@ -94,7 +94,8 @@ def test_happy_path(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, cmd = _make_cmd() mock_user.return_value = "user@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") mock_role.return_value = "Contributor" mock_keygen.return_value = (self.priv_path, self.pub_path) mock_sign.return_value = { @@ -116,8 +117,9 @@ def test_happy_path(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, metadata = mock_sign.call_args[0][2] self.assertEqual(metadata["username"], "user@contoso.com") self.assertEqual(metadata["role"], "Contributor") - # Expiry should come from PIM (4.0 hours remaining) - self.assertEqual(metadata["expiry"], 4.0) + # startTime/endTime should come from PIM + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T14:00:00Z") self.assertIn("ssh-rsa AAAAFAKEPUBLICKEY", metadata["userPublicKey"]) @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') @@ -129,12 +131,13 @@ def test_happy_path(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, @mock.patch('azext_ssh.provisioned_machine_utils.validate_resource_id') def test_expiry_from_pim_short(self, mock_rid, mock_vn, mock_user, mock_pim, mock_role, mock_keygen, mock_sign): - """Expiry should match the remaining PIM duration (1.5h).""" + """startTime/endTime should reflect a short PIM window (1.5h).""" self._setup_mocks() cmd = _make_cmd() mock_user.return_value = "user@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 1.5) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T11:30:00Z") mock_role.return_value = "Owner" mock_keygen.return_value = (self.priv_path, self.pub_path) mock_sign.return_value = { @@ -145,7 +148,8 @@ def test_expiry_from_pim_short(self, mock_rid, mock_vn, mock_user, mock_pim, moc custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) metadata = mock_sign.call_args[0][2] - self.assertEqual(metadata["expiry"], 1.5) + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T11:30:00Z") @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') @@ -161,7 +165,8 @@ def test_expiry_from_pim_long(self, mock_rid, mock_vn, mock_user, mock_pim, mock cmd = _make_cmd() mock_user.return_value = "user@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 7.25) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T17:15:00Z") mock_role.return_value = "Contributor" mock_keygen.return_value = (self.priv_path, self.pub_path) mock_sign.return_value = { @@ -172,7 +177,8 @@ def test_expiry_from_pim_long(self, mock_rid, mock_vn, mock_user, mock_pim, mock custom.ssh_cert_create(cmd, _VALID_VAULT, _VALID_RESOURCE_ID) metadata = mock_sign.call_args[0][2] - self.assertEqual(metadata["expiry"], 7.25) + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T17:15:00Z") class TestSshCertCreateCleanupOnFailure(unittest.TestCase): @@ -198,7 +204,8 @@ def test_cleanup_on_sign_failure(self, mock_rid, mock_vn, mock_user, mock_pim, m f.write("key-content") mock_user.return_value = "user@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") mock_role.return_value = "Owner" mock_keygen.return_value = (priv, pub) mock_sign.side_effect = azclierror.CLIInternalError("KV failed") @@ -224,7 +231,8 @@ def test_cleanup_on_role_failure(self, mock_rid, mock_vn, mock_user, mock_pim, cmd = _make_cmd() mock_user.return_value = "user@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 4.0) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T14:00:00Z") mock_role.side_effect = azclierror.AuthenticationError("no role") with self.assertRaises(azclierror.AuthenticationError): @@ -269,7 +277,8 @@ def test_owner_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, mo cmd = _make_cmd() mock_user.return_value = "admin@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 6.0) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T16:00:00Z") mock_role.return_value = "Owner" mock_keygen.return_value = (self.priv_path, self.pub_path) mock_sign.return_value = { @@ -282,7 +291,7 @@ def test_owner_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_pim, mo metadata = mock_sign.call_args[0][2] self.assertEqual(metadata["role"], "Owner") self.assertEqual(metadata["username"], "admin@contoso.com") - self.assertEqual(metadata["expiry"], 6.0) + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") @mock.patch('azext_ssh.provisioned_machine_utils.sign_certificate_metadata') @mock.patch('azext_ssh.provisioned_machine_utils.generate_ephemeral_keypair') @@ -297,7 +306,8 @@ def test_contributor_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_p cmd = _make_cmd() mock_user.return_value = "dev@contoso.com" - mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], 2.0) + mock_pim.return_value = ([{"properties": {"assignmentType": "Activated"}}], + "2026-05-26T10:00:00Z", "2026-05-26T12:00:00Z") mock_role.return_value = "Contributor" mock_keygen.return_value = (self.priv_path, self.pub_path) mock_sign.return_value = { @@ -309,7 +319,8 @@ def test_contributor_role_in_metadata(self, mock_rid, mock_vn, mock_user, mock_p metadata = mock_sign.call_args[0][2] self.assertEqual(metadata["role"], "Contributor") - self.assertEqual(metadata["expiry"], 2.0) + self.assertEqual(metadata["startTime"], "2026-05-26T10:00:00Z") + self.assertEqual(metadata["endTime"], "2026-05-26T12:00:00Z") self.assertEqual(result["privateKeyPath"], self.priv_path) self.assertEqual(result["certificatePath"], self.cert_path) diff --git a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py index c0c20f8dd20..7e3e9a5b432 100644 --- a/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_provisioned_machine_utils.py @@ -313,9 +313,10 @@ def test_pim_activated_passes(self, mock_profile_cls, mock_oid, mock_get): }}] }) - instances, expiry_hours = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + instances, start_time, end_time = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) self.assertEqual(len(instances), 1) - self.assertGreater(expiry_hours, 0) + self.assertIn("T", start_time) + self.assertIn("T", end_time) @mock.patch('azext_ssh.provisioned_machine_utils.requests.get') @mock.patch('azext_ssh.provisioned_machine_utils._get_current_user_object_id') @@ -377,9 +378,10 @@ def test_mixed_only_activated_returned(self, mock_profile_cls, mock_oid, mock_ge ] }) - instances, expiry_hours = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) + instances, start_time, end_time = pm.check_pim_eligibility(cmd, self._RESOURCE_ID) self.assertEqual(len(instances), 1) - self.assertGreater(expiry_hours, 0) + self.assertIn("T", start_time) + self.assertIn("T", end_time) class TestResolveUserRole(unittest.TestCase): @@ -753,12 +755,12 @@ def test_writes_cert_file(self, mock_sign, mock_chmod): "userPublicKey": "ssh-rsa AAAA", "username": "user@contoso.com", "role": "Owner", - "expiry": 4.0, + "startTime": "2026-05-26T10:00:00Z", + "endTime": "2026-05-26T14:00:00Z", } result = pm.sign_certificate_metadata( - cmd, "myVault", metadata, - "/subscriptions/sub/resourceGroups/rg/providers/X/Y/Z" + cmd, "myVault", metadata ) self.assertEqual(result["signedCertificate"], "cert-content-here")