diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d643ee..4e3abec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [4.0.7] - 2025-06-19 + +- Support requesting a temporary token (JWT) with region and client ref + ## [4.0.6] - 2025-06-19 - Moved channel_diarization_labels field from realtime transcription config to common class. diff --git a/VERSION b/VERSION index 9eefef7..43beb40 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -4.0.6 \ No newline at end of file +4.0.7 diff --git a/speechmatics/client.py b/speechmatics/client.py index dcc6ff6..bb1c6a7 100644 --- a/speechmatics/client.py +++ b/speechmatics/client.py @@ -587,7 +587,7 @@ async def run( self.connection_settings.generate_temp_token and self.connection_settings.auth_token is not None ): - temp_token = await _get_temp_token(self.connection_settings.auth_token) + temp_token = await _get_temp_token(self.connection_settings) token = f"Bearer {temp_token}" extra_headers["Authorization"] = token @@ -670,18 +670,30 @@ async def send_message(self, message_type: str, data: Optional[Any] = None): raise exc -async def _get_temp_token(api_key): +async def _get_temp_token(connection_settings: ConnectionSettings): """ Used to get a temporary token from management platform api for SaaS users """ version = get_version() - mp_api_url = os.getenv("SM_MANAGEMENT_PLATFORM_URL", "https://mp.speechmatics.com") + mp_api_url = os.getenv("SM_MANAGEMENT_PLATFORM_URL", connection_settings.mp_url) + + assert mp_api_url, "Management platform URL not set" + endpoint = mp_api_url + "/v1/api_keys" params = {"type": "rt", "sm-sdk": f"python-{version}"} - body = {"ttl": 60} - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + payload: dict[str, Union[str, int]] = {"ttl": 60} + + if connection_settings.region: + payload["region"] = connection_settings.region + if connection_settings.client_ref: + payload["client_ref"] = connection_settings.client_ref + + headers = { + "Authorization": f"Bearer {connection_settings.auth_token}", + "Content-Type": "application/json", + } # pylint: disable=no-member - response = httpx.post(endpoint, json=body, params=params, headers=headers) + response = httpx.post(endpoint, json=payload, params=params, headers=headers) response.raise_for_status() response.read() key_object = response.json() diff --git a/speechmatics/models.py b/speechmatics/models.py index c0539f0..f79899b 100644 --- a/speechmatics/models.py +++ b/speechmatics/models.py @@ -469,6 +469,15 @@ class ConnectionSettings: """Automatically generate a temporary token for authentication. Enterprise customers should set this to False.""" + mp_url: Optional[str] = "https://mp.speechmatics.com" + """Management platform URL for generating temporary tokens.""" + + region: Optional[str] = "eu" + """Region for generating temporary tokens.""" + + client_ref: Optional[str] = None + """Client reference for generating temporary tokens.""" + def set_missing_values_from_config(self, mode: UsageMode): stored_config = read_config_from_home() if self.url is None or self.url == "": diff --git a/tests/test_cli.py b/tests/test_cli.py index 6bc504d..7fb098f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -782,6 +782,7 @@ def test_rt_main_with_multichannel_option(mock_server): ] cli.main(vars(cli.parse_args(args))) + mock_server.wait_for_clean_disconnects() assert mock_server.clients_connected_count == 1 assert mock_server.clients_disconnected_count == 1 diff --git a/tests/test_models.py b/tests/test_models.py index 19d3af4..dafc6f3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -90,6 +90,25 @@ def test_connection_settings_url(url, want): assert got == want +def test_default_jwt_connection_settings(): + connection_settings = models.ConnectionSettings(url="examples.com") + assert connection_settings.generate_temp_token is False + assert connection_settings.region == "eu" + assert connection_settings.client_ref is None + + +def test_custom_jwt_connection_settings(): + connection_settings = models.ConnectionSettings( + url="examples.com", + generate_temp_token=True, + region="usa", + client_ref="test", + ) + assert connection_settings.generate_temp_token is True + assert connection_settings.region == "usa" + assert connection_settings.client_ref == "test" + + @mark.parametrize( "params, want", [