diff --git a/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py b/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py index dd30846daa3..e36b861dac6 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py +++ b/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py @@ -99,13 +99,28 @@ def _handle_challenge_phase(login_server, logger.debug(add_timestamp("Sending a HTTP Get request to {}".format(request_url))) challenge = requests.get(request_url, verify=not should_disable_connection_verify()) - if challenge.status_code != 401 or 'WWW-Authenticate' not in challenge.headers: + if challenge.status_code not in [401, 403]: from ._errors import CONNECTIVITY_CHALLENGE_ERROR if is_diagnostics_context: return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server) raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message()) - authenticate = challenge.headers['WWW-Authenticate'] + authenticate = challenge.headers.get('WWW-Authenticate') + if not authenticate: + if is_aad_token and challenge.status_code == 403: + logger.warning( + "Received 403 challenge response without WWW-Authenticate from '%s'. " + "Falling back to default ACR token endpoints.", + login_server, + ) + return { + 'realm': 'https://{}/oauth2/token'.format(login_server), + 'service': login_server + } + from ._errors import CONNECTIVITY_CHALLENGE_ERROR + if is_diagnostics_context: + return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server) + raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message()) tokens = authenticate.split(' ', 2) diff --git a/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py b/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py index 95c8e3c34db..9e52a4d4285 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py +++ b/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py @@ -9,6 +9,8 @@ from unittest import mock import sys +from knack.util import CLIError + from azure.cli.command_modules.acr.repository import ( acr_repository_list, acr_repository_show_tags, @@ -44,6 +46,7 @@ get_access_credentials, get_authorization_header, get_manifest_authorization_header, + _handle_challenge_phase, _resolve_acr_scope, RepoAccessTokenPermission, HelmAccessTokenPermission, @@ -1219,6 +1222,71 @@ def _core_token_scenarios(self, mock_get_raw_token, mock_requests_get, mock_requ get_access_credentials(cmd, registry_name, tenant_suffix=tenant_suffix, artifact_repository=TEST_REPOSITORY, permission=HelmAccessTokenPermission.PULL.value) self._validate_access_token_request(mock_requests_get, mock_requests_post, login_server, 'artifact-repository:{}:{}'.format(TEST_REPOSITORY, HelmAccessTokenPermission.PULL.value)) + @mock.patch('azure.cli.core._profile.Profile.get_subscription_id', autospec=True) + @mock.patch('azure.cli.command_modules.acr._docker_utils.get_registry_by_name') + @mock.patch('requests.post', autospec=True) + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.Profile.get_raw_token') + def test_get_access_credentials_fallback_on_403_without_www_authenticate( + self, mock_get_raw_token, mock_requests_get, mock_requests_post, mock_get_registry_by_name, + mock_get_subscription): + from azure.mgmt.containerregistry.models import Registry, Sku + + registry = Registry(location='westus', sku=Sku(name='Standard')) + login_server = 'testregistry.azurecr.io' + registry.login_server = login_server + mock_get_registry_by_name.return_value = registry, None + + cmd = self._setup_cmd() + mock_get_subscription.return_value = TEST_SUBSCRIPTION + mock_get_raw_token.return_value = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT + + initial_connectivity_response = mock.MagicMock() + initial_connectivity_response.status_code = 200 + initial_connectivity_response.headers = {} + challenge_response = mock.MagicMock() + challenge_response.status_code = 403 + challenge_response.headers = {} + mock_requests_get.side_effect = [initial_connectivity_response, challenge_response] + + token_response = mock.MagicMock() + token_response.status_code = 200 + token_response.headers = {} + token_response.content = json.dumps({ + 'refresh_token': TEST_ACR_REFRESH_TOKEN, + 'access_token': TEST_ACR_ACCESS_TOKEN + }).encode() + mock_requests_post.return_value = token_response + + login_server, username, password = get_access_credentials( + cmd, + 'testregistry', + artifact_repository=TEST_REPOSITORY, + permission=HelmAccessTokenPermission.PULL.value + ) + self.assertEqual((login_server, username, password), ('testregistry.azurecr.io', EMPTY_GUID, TEST_ACR_ACCESS_TOKEN)) + + mock_requests_post.assert_any_call( + 'https://{}/oauth2/exchange'.format(login_server), + urlencode({ + 'grant_type': 'access_token', + 'service': login_server, + 'tenant': TEST_TENANT, + 'access_token': TEST_AAD_ACCESS_TOKEN + }), + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + verify=mock.ANY) + mock_requests_post.assert_any_call( + 'https://{}/oauth2/token'.format(login_server), + urlencode({ + 'grant_type': 'refresh_token', + 'service': login_server, + 'scope': 'artifact-repository:{}:{}'.format(TEST_REPOSITORY, HelmAccessTokenPermission.PULL.value), + 'refresh_token': TEST_ACR_REFRESH_TOKEN + }), + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + verify=mock.ANY) + def _setup_mock_token_requests(self, mock_get_aad_token, mock_requests_get, mock_requests_post, login_server): # Set up AAD token with only access token mock_get_aad_token.return_value = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT @@ -1240,6 +1308,42 @@ def _setup_mock_token_requests(self, mock_get_aad_token, mock_requests_get, mock 'access_token': TEST_ACR_ACCESS_TOKEN}).encode() mock_requests_post.return_value = token_response + @mock.patch('requests.get', autospec=True) + def test_handle_challenge_phase_allows_403_with_www_authenticate(self, mock_requests_get): + challenge_response = mock.MagicMock() + challenge_response.status_code = 403 + challenge_response.headers = { + 'WWW-Authenticate': 'Bearer realm="https://testregistry.azurecr.io/oauth2/token",service="testregistry.azurecr.io"' + } + mock_requests_get.return_value = challenge_response + + token_params = _handle_challenge_phase( + login_server='testregistry.azurecr.io', + repository=TEST_REPOSITORY, + artifact_repository=None, + permission=RepoAccessTokenPermission.METADATA_READ.value + ) + self.assertEqual( + token_params, + {'realm': 'https://testregistry.azurecr.io/oauth2/token', 'service': 'testregistry.azurecr.io'} + ) + + @mock.patch('requests.get', autospec=True) + def test_handle_challenge_phase_rejects_403_without_www_authenticate_for_non_aad_auth(self, mock_requests_get): + challenge_response = mock.MagicMock() + challenge_response.status_code = 403 + challenge_response.headers = {} + mock_requests_get.return_value = challenge_response + + with self.assertRaises(CLIError): + _handle_challenge_phase( + login_server='testregistry.azurecr.io', + repository=TEST_REPOSITORY, + artifact_repository=None, + permission=RepoAccessTokenPermission.METADATA_READ.value, + is_aad_token=False + ) + def _validate_raw_token_request(self, mock_get_raw_token): mock_get_raw_token.assert_called_with(mock.ANY, resource="https://containerregistry.azure.net", subscription=mock.ANY)