Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions src/bedrock_agentcore/services/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import boto3
from pydantic import BaseModel

from bedrock_agentcore._utils.endpoints import CP_ENDPOINT_OVERRIDE, DP_ENDPOINT_OVERRIDE
from bedrock_agentcore._utils.endpoints import (
CP_ENDPOINT_OVERRIDE,
DP_ENDPOINT_OVERRIDE,
)
from bedrock_agentcore._utils.snake_case import accept_snake_case_kwargs


class TokenPoller(ABC):
Expand Down Expand Up @@ -85,15 +89,49 @@ def __init__(self, region: str):
self.dp_client = boto3.client("bedrock-agentcore", **dp_kwargs)
self.logger = logging.getLogger("bedrock_agentcore.identity_client")

def create_oauth2_credential_provider(self, req):
"""Create an OAuth2 credential provider."""
self.logger.info("Creating OAuth2 credential provider...")
return self.cp_client.create_oauth2_credential_provider(**req)

def create_api_key_credential_provider(self, req):
"""Create an API key credential provider."""
self.logger.info("Creating API key credential provider...")
return self.cp_client.create_api_key_credential_provider(**req)
# Pass-through
# -------------------------------------------------------------------------
_ALLOWED_CP_METHODS = {
# OAuth2 credential provider CRUD
"create_oauth2_credential_provider",
"get_oauth2_credential_provider",
"list_oauth2_credential_providers",
"update_oauth2_credential_provider",
"delete_oauth2_credential_provider",
# API key credential provider CRUD
"create_api_key_credential_provider",
"get_api_key_credential_provider",
"list_api_key_credential_providers",
"delete_api_key_credential_provider",
# Workload identity
"get_workload_identity",
"update_workload_identity",
}

_ALLOWED_DP_METHODS = {
"get_resource_oauth2_token",
"get_resource_api_key",
"get_workload_access_token",
"get_workload_access_token_for_jwt",
"get_workload_access_token_for_user_id",
}

def __getattr__(self, name: str):
"""Dynamically forward allowlisted method calls to the appropriate boto3 client."""
if name in self._ALLOWED_DP_METHODS and hasattr(self.dp_client, name):
method = getattr(self.dp_client, name)
return accept_snake_case_kwargs(method)

if name in self._ALLOWED_CP_METHODS and hasattr(self.cp_client, name):
method = getattr(self.cp_client, name)
return accept_snake_case_kwargs(method)

raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'. "
f"Method not found on data plane or control plane client. "
f"Available methods can be found in the boto3 documentation for "
f"'bedrock-agentcore' and 'bedrock-agentcore-control' services."
)

def get_workload_access_token(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is defined as an explicit method and a passthrough. I think the passthrough takes precedence? If so, should we delete the old one?

self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None
Expand Down Expand Up @@ -125,20 +163,6 @@ def create_workload_identity(
name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls or []
)

def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that we are proxying it through boto core, does the signature change at all? Would this be a breaking change for customers that already rely on this signature?

Same question for the other ones where we switched to passthrough.

"""Update an existing workload identity with allowed resource OAuth2 callback urls."""
self.logger.info(
"Updating workload identity '%s' with callback urls: %s", name, allowed_resource_oauth_2_return_urls
)
return self.cp_client.update_workload_identity(
name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls
)

def get_workload_identity(self, name: str) -> Dict:
"""Retrieves information about a workload identity."""
self.logger.info("Fetching workload identity '%s'", name)
return self.cp_client.get_workload_identity(name=name)

def complete_resource_token_auth(
self, session_uri: str, user_identifier: Union[UserTokenIdentifier, UserIdIdentifier]
):
Expand Down
59 changes: 37 additions & 22 deletions tests/bedrock_agentcore/services/test_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,44 +32,54 @@ def test_initialization(self):
)

def test_create_oauth2_credential_provider(self):
"""Test OAuth2 credential provider creation."""
"""Test OAuth2 credential provider creation via passthrough."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client
mock_cp_client = Mock()
mock_dp_client = Mock()
mock_boto_client.side_effect = [mock_cp_client, mock_dp_client]

identity_client = IdentityClient(region)

# Test data
req = {"name": "test-provider", "clientId": "test-client"}
expected_response = {"providerId": "test-provider-id"}
mock_client.create_oauth2_credential_provider.return_value = expected_response
mock_cp_client.create_oauth2_credential_provider.return_value = expected_response

result = identity_client.create_oauth2_credential_provider(req)
result = identity_client.create_oauth2_credential_provider(
name="test-provider",
clientId="test-client",
)

assert result == expected_response
mock_client.create_oauth2_credential_provider.assert_called_once_with(**req)
mock_cp_client.create_oauth2_credential_provider.assert_called_once_with(
name="test-provider",
clientId="test-client",
)

def test_create_api_key_credential_provider(self):
"""Test API key credential provider creation."""
"""Test API key credential provider creation via passthrough."""
region = "us-west-2"

with patch("boto3.client") as mock_boto_client:
mock_client = Mock()
mock_boto_client.return_value = mock_client
mock_cp_client = Mock()
mock_dp_client = Mock()
mock_boto_client.side_effect = [mock_cp_client, mock_dp_client]

identity_client = IdentityClient(region)

# Test data
req = {"name": "test-api-provider", "apiKeyName": "test-key"}
expected_response = {"providerId": "test-api-provider-id"}
mock_client.create_api_key_credential_provider.return_value = expected_response
mock_cp_client.create_api_key_credential_provider.return_value = expected_response

result = identity_client.create_api_key_credential_provider(req)
result = identity_client.create_api_key_credential_provider(
name="test-api-provider",
apiKeyName="test-key",
)

assert result == expected_response
mock_client.create_api_key_credential_provider.assert_called_once_with(**req)
mock_cp_client.create_api_key_credential_provider.assert_called_once_with(
name="test-api-provider",
apiKeyName="test-key",
)

@pytest.mark.asyncio
async def test_get_token_direct_response(self):
Expand Down Expand Up @@ -505,11 +515,15 @@ def test_update_workload_identity(self):

mock_cp_client.update_workload_identity.return_value = expected_response

result = identity_client.update_workload_identity(workload_name, allowed_urls)
result = identity_client.update_workload_identity(
name=workload_name,
allowedResourceOauth2ReturnUrls=allowed_urls,
)

assert result == expected_response
mock_cp_client.update_workload_identity.assert_called_once_with(
name=workload_name, allowedResourceOauth2ReturnUrls=allowed_urls
name=workload_name,
allowedResourceOauth2ReturnUrls=allowed_urls,
)

def test_get_workload_identity(self):
Expand All @@ -523,15 +537,16 @@ def test_get_workload_identity(self):
identity_client = IdentityClient(region)

workload_name = "test-workload"
allowed_urls = ["https://unit-test.com/callback", "https://test.com/oauth"]
expected_response = {"name": workload_name, "allowedResourceOauth2ReturnUrls": allowed_urls}
expected_response = {"name": workload_name}

mock_cp_client.get_workload_identity.return_value = expected_response

result = identity_client.get_workload_identity(workload_name)
result = identity_client.get_workload_identity(name=workload_name)

assert result == expected_response
mock_cp_client.get_workload_identity.assert_called_once_with(name=workload_name)
mock_cp_client.get_workload_identity.assert_called_once_with(
name=workload_name,
)

def test_complete_resource_token_auth_with_user_id(self):
region = "us-west-2"
Expand Down
138 changes: 138 additions & 0 deletions tests_integ/identity/test_identity_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Integration tests for IdentityClient passthrough and __getattr__ methods."""

import os
import time

import pytest
from botocore.exceptions import ClientError

from bedrock_agentcore.services.identity import IdentityClient


@pytest.mark.integration
class TestIdentityClientPassthrough:
"""Integration tests for IdentityClient passthrough via __getattr__.

Tests read-only operations that don't require pre-existing resources.
"""

@classmethod
def setup_class(cls):
cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2")
cls.client = IdentityClient(region=cls.region)

@pytest.mark.order(1)
def test_list_oauth2_credential_providers_passthrough(self):
response = self.client.list_oauth2_credential_providers()
assert "credentialProviders" in response

@pytest.mark.order(2)
def test_list_api_key_credential_providers_passthrough(self):
response = self.client.list_api_key_credential_providers()
assert "credentialProviders" in response

@pytest.mark.order(3)
def test_list_oauth2_snake_case(self):
response = self.client.list_oauth2_credential_providers(
max_results=10,
)
assert "credentialProviders" in response

@pytest.mark.order(4)
def test_get_nonexistent_oauth2_provider(self):
with pytest.raises(ClientError) as exc_info:
self.client.get_oauth2_credential_provider(
name="nonexistent-provider",
)
assert exc_info.value.response["Error"]["Code"] in (
"ResourceNotFoundException",
"AccessDeniedException",
)

@pytest.mark.order(5)
def test_get_nonexistent_api_key_provider(self):
with pytest.raises(ClientError) as exc_info:
self.client.get_api_key_credential_provider(
name="nonexistent-provider",
)
assert exc_info.value.response["Error"]["Code"] in (
"ResourceNotFoundException",
"AccessDeniedException",
)

@pytest.mark.order(6)
def test_non_allowlisted_method_raises(self):
with pytest.raises(AttributeError):
self.client.not_a_real_method()


@pytest.mark.integration
class TestIdentityClientOauth2Crud:
"""Integration tests for OAuth2 credential provider CRUD via passthrough.

Requires COGNITO_POOL_ID, COGNITO_CLIENT_ID, COGNITO_CLIENT_SECRET.
"""

@classmethod
def setup_class(cls):
cls.region = os.environ.get("BEDROCK_TEST_REGION", "us-west-2")
cls.pool_id = os.environ.get("COGNITO_POOL_ID")
cls.client_id = os.environ.get("COGNITO_CLIENT_ID")
cls.client_secret = os.environ.get("COGNITO_CLIENT_SECRET")
if not all([cls.pool_id, cls.client_id, cls.client_secret]):
pytest.skip("COGNITO_POOL_ID, COGNITO_CLIENT_ID, and COGNITO_CLIENT_SECRET must all be set")
cls.client = IdentityClient(region=cls.region)
cls.discovery_url = (
f"https://cognito-idp.{cls.region}.amazonaws.com/{cls.pool_id}/.well-known/openid-configuration"
)
cls.provider_name = f"sdk-integ-{int(time.time())}"

@classmethod
def teardown_class(cls):
try:
cls.client.delete_oauth2_credential_provider(
name=cls.provider_name,
)
except Exception as e:
print(f"Teardown: {e}")

@pytest.mark.order(10)
def test_create_oauth2_credential_provider(self):
self.client.create_oauth2_credential_provider(
name=self.provider_name,
credentialProviderVendor="CustomOauth2",
oauth2ProviderConfigInput={
"customOauth2ProviderConfig": {
"oauthDiscovery": {
"discoveryUrl": self.discovery_url,
},
"clientId": self.client_id,
"clientSecret": self.client_secret,
}
},
)
provider = self.client.get_oauth2_credential_provider(
name=self.provider_name,
)
assert provider["name"] == self.provider_name

@pytest.mark.order(11)
def test_get_oauth2_provider_passthrough(self):
provider = self.client.get_oauth2_credential_provider(
name=self.provider_name,
)
assert provider["name"] == self.provider_name

@pytest.mark.order(12)
def test_delete_oauth2_credential_provider(self):
self.client.delete_oauth2_credential_provider(
name=self.provider_name,
)
# Provider may take a moment to delete
import time

time.sleep(5)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we poll here?

with pytest.raises(ClientError):
self.client.get_oauth2_credential_provider(
name=self.provider_name,
)
Loading