-
Notifications
You must be signed in to change notification settings - Fork 109
feat: add identity client passthrough and tests #429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,11 @@ | |
| import boto3 | ||
| from pydantic import BaseModel | ||
|
|
||
| from bedrock_agentcore._utils.endpoints import CP_ENDPOINT_OVERRIDE, DP_ENDPOINT_OVERRIDE | ||
| from bedrock_agentcore._utils.endpoints import ( | ||
| CP_ENDPOINT_OVERRIDE, | ||
| DP_ENDPOINT_OVERRIDE, | ||
| ) | ||
| from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs | ||
|
|
||
|
|
||
| class TokenPoller(ABC): | ||
|
|
@@ -85,15 +89,49 @@ def __init__(self, region: str): | |
| self.dp_client = boto3.client("bedrock-agentcore", **dp_kwargs) | ||
| self.logger = logging.getLogger("bedrock_agentcore.identity_client") | ||
|
|
||
| def create_oauth2_credential_provider(self, req): | ||
| """Create an OAuth2 credential provider.""" | ||
| self.logger.info("Creating OAuth2 credential provider...") | ||
| return self.cp_client.create_oauth2_credential_provider(**req) | ||
|
|
||
| def create_api_key_credential_provider(self, req): | ||
| """Create an API key credential provider.""" | ||
| self.logger.info("Creating API key credential provider...") | ||
| return self.cp_client.create_api_key_credential_provider(**req) | ||
| # Pass-through | ||
| # ------------------------------------------------------------------------- | ||
| _ALLOWED_CP_METHODS = { | ||
| # OAuth2 credential provider CRUD | ||
| "create_oauth2_credential_provider", | ||
| "get_oauth2_credential_provider", | ||
| "list_oauth2_credential_providers", | ||
| "update_oauth2_credential_provider", | ||
| "delete_oauth2_credential_provider", | ||
| # API key credential provider CRUD | ||
| "create_api_key_credential_provider", | ||
| "get_api_key_credential_provider", | ||
| "list_api_key_credential_providers", | ||
| "delete_api_key_credential_provider", | ||
| # Workload identity | ||
| "get_workload_identity", | ||
| "update_workload_identity", | ||
| } | ||
|
|
||
| _ALLOWED_DP_METHODS = { | ||
| "get_resource_oauth2_token", | ||
| "get_resource_api_key", | ||
| "get_workload_access_token", | ||
| "get_workload_access_token_for_jwt", | ||
| "get_workload_access_token_for_user_id", | ||
| } | ||
|
|
||
| def __getattr__(self, name: str): | ||
| """Dynamically forward allowlisted method calls to the appropriate boto3 client.""" | ||
| if name in self._ALLOWED_DP_METHODS and hasattr(self.dp_client, name): | ||
| method = getattr(self.dp_client, name) | ||
| return accept_snake_case_kwargs(method) | ||
|
|
||
| if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name): | ||
| method = getattr(self.cp_client, name) | ||
| return accept_snake_case_kwargs(method) | ||
|
|
||
| raise AttributeError( | ||
| f"'{self.__class__.__name__}' object has no attribute '{name}'. " | ||
| f"Method not found on data plane or control plane client. " | ||
| f"Available methods can be found in the boto3 documentation for " | ||
| f"'bedrock-agentcore' and 'bedrock-agentcore-control' services." | ||
| ) | ||
|
|
||
| def get_workload_access_token( | ||
| self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None | ||
|
|
@@ -125,20 +163,6 @@ def create_workload_identity( | |
| name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls or [] | ||
| ) | ||
|
|
||
| def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now that we are proxying it through boto core, does the signature change at all? Would this be a breaking change for customers that already rely on this signature? Same question for the other ones where we switched to passthrough. |
||
| """Update an existing workload identity with allowed resource OAuth2 callback urls.""" | ||
| self.logger.info( | ||
| "Updating workload identity '%s' with callback urls: %s", name, allowed_resource_oauth_2_return_urls | ||
| ) | ||
| return self.cp_client.update_workload_identity( | ||
| name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls | ||
| ) | ||
|
|
||
| def get_workload_identity(self, name: str) -> Dict: | ||
| """Retrieves information about a workload identity.""" | ||
| self.logger.info("Fetching workload identity '%s'", name) | ||
| return self.cp_client.get_workload_identity(name=name) | ||
|
|
||
| def complete_resource_token_auth( | ||
| self, session_uri: str, user_identifier: Union[UserTokenIdentifier, UserIdIdentifier] | ||
| ): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| """Integration tests for IdentityClient passthrough and __getattr__ methods.""" | ||
|
|
||
| import os | ||
| import time | ||
|
|
||
| import pytest | ||
| from botocore.exceptions import ClientError | ||
|
|
||
| from bedrock_agentcore.services.identity import IdentityClient | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| class TestIdentityClientPassthrough: | ||
| """Integration tests for IdentityClient passthrough via __getattr__. | ||
|
|
||
| Tests read-only operations that don't require pre-existing resources. | ||
| """ | ||
|
|
||
| @classmethod | ||
| def setup_class(cls): | ||
| cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") | ||
| cls.client = IdentityClient(region=cls.region) | ||
|
|
||
| @pytest.mark.order(1) | ||
| def test_list_oauth2_credential_providers_passthrough(self): | ||
| response = self.client.list_oauth2_credential_providers() | ||
| assert "credentialProviders" in response | ||
|
|
||
| @pytest.mark.order(2) | ||
| def test_list_api_key_credential_providers_passthrough(self): | ||
| response = self.client.list_api_key_credential_providers() | ||
| assert "credentialProviders" in response | ||
|
|
||
| @pytest.mark.order(3) | ||
| def test_list_oauth2_snake_case(self): | ||
| response = self.client.list_oauth2_credential_providers( | ||
| max_results=10, | ||
| ) | ||
| assert "credentialProviders" in response | ||
|
|
||
| @pytest.mark.order(4) | ||
| def test_get_nonexistent_oauth2_provider(self): | ||
| with pytest.raises(ClientError) as exc_info: | ||
| self.client.get_oauth2_credential_provider( | ||
| name="nonexistent-provider", | ||
| ) | ||
| assert exc_info.value.response["Error"]["Code"] in ( | ||
| "ResourceNotFoundException", | ||
| "AccessDeniedException", | ||
| ) | ||
|
|
||
| @pytest.mark.order(5) | ||
| def test_get_nonexistent_api_key_provider(self): | ||
| with pytest.raises(ClientError) as exc_info: | ||
| self.client.get_api_key_credential_provider( | ||
| name="nonexistent-provider", | ||
| ) | ||
| assert exc_info.value.response["Error"]["Code"] in ( | ||
| "ResourceNotFoundException", | ||
| "AccessDeniedException", | ||
| ) | ||
|
|
||
| @pytest.mark.order(6) | ||
| def test_non_allowlisted_method_raises(self): | ||
| with pytest.raises(AttributeError): | ||
| self.client.not_a_real_method() | ||
|
|
||
|
|
||
| @pytest.mark.integration | ||
| class TestIdentityClientOauth2Crud: | ||
| """Integration tests for OAuth2 credential provider CRUD via passthrough. | ||
|
|
||
| Requires COGNITO_POOL_ID, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET. | ||
| """ | ||
|
|
||
| @classmethod | ||
| def setup_class(cls): | ||
| cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2") | ||
| cls.pool_id = os.environ.get("COGNITO_POOL_ID") | ||
| cls.client_id = os.environ.get("COGNITO_CLIENT_ID") | ||
| cls.client_secret = os.environ.get("COGNITO_CLIENT_SECRET") | ||
| if not all([cls.pool_id, cls.client_id, cls.client_secret]): | ||
| pytest.skip("COGNITO_POOL_ID, COGNITO_CLIENT_ID, and COGNITO_CLIENT_SECRET must all be set") | ||
| cls.client = IdentityClient(region=cls.region) | ||
| cls.discovery_url = ( | ||
| f"https://cognito-idp.{cls.region}.amazonaws.com/{cls.pool_id}/.well-known/openid-configuration" | ||
| ) | ||
| cls.provider_name = f"sdk-integ-{int(time.time())}" | ||
|
|
||
| @classmethod | ||
| def teardown_class(cls): | ||
| try: | ||
| cls.client.delete_oauth2_credential_provider( | ||
| name=cls.provider_name, | ||
| ) | ||
| except Exception as e: | ||
| print(f"Teardown: {e}") | ||
|
|
||
| @pytest.mark.order(10) | ||
| def test_create_oauth2_credential_provider(self): | ||
| self.client.create_oauth2_credential_provider( | ||
| name=self.provider_name, | ||
| credentialProviderVendor="CustomOauth2", | ||
| oauth2ProviderConfigInput={ | ||
| "customOauth2ProviderConfig": { | ||
| "oauthDiscovery": { | ||
| "discoveryUrl": self.discovery_url, | ||
| }, | ||
| "clientId": self.client_id, | ||
| "clientSecret": self.client_secret, | ||
| } | ||
| }, | ||
| ) | ||
| provider = self.client.get_oauth2_credential_provider( | ||
| name=self.provider_name, | ||
| ) | ||
| assert provider["name"] == self.provider_name | ||
|
|
||
| @pytest.mark.order(11) | ||
| def test_get_oauth2_provider_passthrough(self): | ||
| provider = self.client.get_oauth2_credential_provider( | ||
| name=self.provider_name, | ||
| ) | ||
| assert provider["name"] == self.provider_name | ||
|
|
||
| @pytest.mark.order(12) | ||
| def test_delete_oauth2_credential_provider(self): | ||
| self.client.delete_oauth2_credential_provider( | ||
| name=self.provider_name, | ||
| ) | ||
| # Provider may take a moment to delete | ||
| import time | ||
|
|
||
| time.sleep(5) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we poll here? |
||
| with pytest.raises(ClientError): | ||
| self.client.get_oauth2_credential_provider( | ||
| name=self.provider_name, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is defined as an explicit method and a passthrough. I think the passthrough takes precedence? If so, should we delete the old one?