From b6cb2d2be859bd84551f92b0e55aafff8da439ba Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 19 Jun 2025 17:28:26 +0100 Subject: [PATCH] Make calls to numtracker async --- pyproject.toml | 2 + src/blueapi/client/numtracker.py | 15 +-- src/blueapi/core/context.py | 4 +- tests/conftest.py | 139 +-------------------- tests/unit_tests/client/test_numtracker.py | 87 +++++++++---- tests/unit_tests/service/test_interface.py | 50 +++++--- uv.lock | 17 +++ 7 files changed, 130 insertions(+), 184 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2a29b112..e22570152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "pyjwt[crypto]", "tomlkit", "graypy>=2.1.0", + "httpx>=0.28.1", ] dynamic = ["version"] license.file = "LICENSE" @@ -53,6 +54,7 @@ dev = [ "pyright<1.1.407", # https://github.com/bluesky/scanspec/issues/190 "pytest-cov", "pytest-asyncio", + "pytest-httpx>=0.35.0", "responses", "ruff", "semver", diff --git a/src/blueapi/client/numtracker.py b/src/blueapi/client/numtracker.py index 3ec7a9ab8..54b9aa4d7 100644 --- a/src/blueapi/client/numtracker.py +++ b/src/blueapi/client/numtracker.py @@ -3,7 +3,7 @@ from pathlib import Path from textwrap import dedent -import requests +import httpx from pydantic import Field, HttpUrl from blueapi.utils import BlueapiBaseModel @@ -60,7 +60,7 @@ def set_headers(self, headers: Mapping[str, str]) -> None: self._headers = headers - def create_scan( + async def create_scan( self, instrument_session: str, instrument: str ) -> NumtrackerScanMutationResponse: """ @@ -92,11 +92,12 @@ def create_scan( """) } - response = requests.post( - self._url.unicode_string(), - headers=self._headers, - json=query, - ) + async with httpx.AsyncClient() as client: + response = await client.post( + self._url.unicode_string(), + headers=self._headers, + json=query, + ) response.raise_for_status() json = response.json() diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 681660076..78728c5f1 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -155,8 +155,8 @@ def __post_init__(self, configuration: ApplicationConfig | None): # local reference so it's available in _update_scan_num numtracker = self.numtracker - def _update_scan_num(md: dict[str, Any]) -> int: - scan = numtracker.create_scan( + async def _update_scan_num(md: dict[str, Any]) -> int: + scan = await numtracker.create_scan( md["instrument_session"], md["instrument"] ) md["data_session_directory"] = str(scan.scan.directory.path) diff --git a/tests/conftest.py b/tests/conftest.py index b2f7cb947..81b79af61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio import base64 import time -from collections.abc import Iterable from pathlib import Path from textwrap import dedent from typing import Any, cast @@ -20,7 +19,6 @@ from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.trace import get_tracer_provider -from responses.matchers import json_params_matcher from blueapi.config import ApplicationConfig, OIDCConfig from blueapi.service.model import Cache @@ -335,12 +333,9 @@ def mock_jwks_fetch(json_web_keyset: JWK): return patch("jwt.PyJWKClient.fetch_data", mock) -NOT_CONFIGURED_INSTRUMENT = "p100" - - -@pytest.fixture(scope="module") -def mock_numtracker_server() -> Iterable[responses.RequestsMock]: - query_working = { +@pytest.fixture +def nt_query() -> dict[str, str]: + return { "query": dedent(""" mutation{ scan( @@ -358,94 +353,11 @@ def mock_numtracker_server() -> Iterable[responses.RequestsMock]: } """) } - query_400 = { - "query": dedent(""" - mutation{ - scan( - instrument: "p47", - instrumentSession: "ab123" - ) { - directory{ - instrumentSession - instrument - path - } - scanFile - scanNumber - } - } - """) - } - query_500 = { - "query": dedent(""" - mutation{ - scan( - instrument: "p48", - instrumentSession: "ab123" - ) { - directory{ - instrumentSession - instrument - path - } - scanFile - scanNumber - } - } - """) - } - query_key_error = { - "query": dedent(""" - mutation{ - scan( - instrument: "p49", - instrumentSession: "ab123" - ) { - directory{ - instrumentSession - instrument - path - } - scanFile - scanNumber - } - } - """) - } - query_200_with_errors = { - "query": dedent(f""" - mutation{{ - scan( - instrument: "{NOT_CONFIGURED_INSTRUMENT}", - instrumentSession: "ab123" - ) {{ - directory{{ - instrumentSession - instrument - path - }} - scanFile - scanNumber - }} - }} - """) - } - response_with_errors = { - "data": None, - "errors": [ - { - "message": ( - "No configuration available for instrument " - f'"{NOT_CONFIGURED_INSTRUMENT}"' - ), - "locations": [{"line": 3, "column": 5}], - "path": ["scan"], - } - ], - } - working_response = { +@pytest.fixture +def nt_response() -> dict[str, Any]: + return { "data": { "scan": { "scanFile": "p46-11", @@ -458,42 +370,3 @@ def mock_numtracker_server() -> Iterable[responses.RequestsMock]: } } } - empty_response = {} - - with responses.RequestsMock(assert_all_requests_are_fired=False) as requests_mock: - requests_mock.add( - responses.POST, - url="https://numtracker-example.com/graphql", - match=[json_params_matcher(query_working)], - status=200, - json=working_response, - ) - requests_mock.add( - responses.POST, - url="https://numtracker-example.com/graphql", - match=[json_params_matcher(query_400)], - status=400, - json=empty_response, - ) - requests_mock.add( - responses.POST, - url="https://numtracker-example.com/graphql", - match=[json_params_matcher(query_500)], - status=500, - json=empty_response, - ) - requests_mock.add( - responses.POST, - url="https://numtracker-example.com/graphql", - match=[json_params_matcher(query_key_error)], - status=200, - json=empty_response, - ) - requests_mock.add( - responses.POST, - "https://numtracker-example.com/graphql", - match=[json_params_matcher(query_200_with_errors)], - status=200, - json=response_with_errors, - ) - yield requests_mock diff --git a/tests/unit_tests/client/test_numtracker.py b/tests/unit_tests/client/test_numtracker.py index 067653a1f..21373eee3 100644 --- a/tests/unit_tests/client/test_numtracker.py +++ b/tests/unit_tests/client/test_numtracker.py @@ -1,10 +1,9 @@ from pathlib import Path +import httpx import pytest -import responses from pydantic import HttpUrl -from requests import HTTPError -from tests.conftest import NOT_CONFIGURED_INSTRUMENT +from pytest_httpx import HTTPXMock from blueapi.client.numtracker import ( DirectoryPath, @@ -19,11 +18,33 @@ def numtracker() -> NumtrackerClient: return NumtrackerClient(HttpUrl("https://numtracker-example.com/graphql")) -def test_create_scan( - numtracker: NumtrackerClient, - mock_numtracker_server: responses.RequestsMock, +URL = "https://numtracker-example.com/graphql" + +EMPTY = {} + +ERRORS = { + "data": None, + "errors": [ + { + "message": "No configuration available for instrument p46", + "locations": [{"line": 3, "column": 5}], + "path": ["scan"], + } + ], +} + + +async def test_create_scan( + numtracker: NumtrackerClient, httpx_mock: HTTPXMock, nt_query, nt_response ): - scan = numtracker.create_scan("ab123", "p46") + httpx_mock.add_response( + method="POST", + url=URL, + match_json=nt_query, + status_code=200, + json=nt_response, + ) + scan = await numtracker.create_scan("ab123", "p46") assert scan == NumtrackerScanMutationResponse( scan=ScanPaths( scanFile="p46-11", @@ -37,42 +58,54 @@ def test_create_scan( ) -def test_create_scan_raises_400_error( - numtracker: NumtrackerClient, - mock_numtracker_server: responses.RequestsMock, +async def test_create_scan_raises_400_error( + numtracker: NumtrackerClient, httpx_mock: HTTPXMock, nt_query ): + httpx_mock.add_response( + method="POST", url=URL, match_json=nt_query, status_code=400, json=EMPTY + ) with pytest.raises( - HTTPError, - match="400 Client Error: Bad Request for url: https://numtracker-example.com/graphql", + httpx.HTTPStatusError, + match="Client error '400 Bad Request' for url 'https://numtracker-example.com/graphql'", ): - numtracker.create_scan("ab123", "p47") + await numtracker.create_scan("ab123", "p46") -def test_create_scan_raises_500_error( - numtracker: NumtrackerClient, - mock_numtracker_server: responses.RequestsMock, +async def test_create_scan_raises_500_error( + numtracker: NumtrackerClient, httpx_mock: HTTPXMock, nt_query ): + httpx_mock.add_response( + method="POST", url=URL, match_json=nt_query, status_code=500, json=EMPTY + ) with pytest.raises( - HTTPError, - match="500 Server Error: Internal Server Error for url: https://numtracker-example.com/graphql", + httpx.HTTPStatusError, + match="Server error '500 Internal Server Error' for url 'https://numtracker-example.com/graphql'", ): - numtracker.create_scan("ab123", "p48") + await numtracker.create_scan("ab123", "p46") -def test_create_scan_raises_key_error_on_incorrectly_formatted_responses( - numtracker: NumtrackerClient, - mock_numtracker_server: responses.RequestsMock, +async def test_create_scan_raises_key_error_on_incorrectly_formatted_responses( + numtracker: NumtrackerClient, httpx_mock: HTTPXMock, nt_query ): + httpx_mock.add_response( + method="POST", url=URL, match_json=nt_query, status_code=200, json=EMPTY + ) with pytest.raises( KeyError, match="data", ): - numtracker.create_scan("ab123", "p49") + await numtracker.create_scan("ab123", "p46") -def test_create_scan_raises_runtime_error_on_graphql_error( - numtracker: NumtrackerClient, - mock_numtracker_server: responses.RequestsMock, +async def test_create_scan_raises_runtime_error_on_graphql_error( + numtracker: NumtrackerClient, httpx_mock: HTTPXMock, nt_query ): + httpx_mock.add_response( + method="POST", + url=URL, + match_json=nt_query, + status_code=200, + json=ERRORS, + ) with pytest.raises(RuntimeError, match="Numtracker error:"): - numtracker.create_scan("ab123", NOT_CONFIGURED_INSTRUMENT) + await numtracker.create_scan("ab123", "p46") diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index 565ad1c39..4a7017e94 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -1,6 +1,7 @@ import json import uuid from dataclasses import dataclass +from inspect import isawaitable from typing import Any from unittest.mock import ANY, MagicMock, Mock, patch @@ -15,6 +16,7 @@ ) from ophyd_async.epics.motor import Motor from pydantic import HttpUrl +from pytest_httpx import HTTPXMock from stomp.connect import StompConnection11 as Connection from blueapi.client.numtracker import NumtrackerClient @@ -434,8 +436,8 @@ def test_configure_numtracker(): assert nt._url.unicode_string() == "https://numtracker-example.com/graphql" -@patch("blueapi.client.numtracker.requests.post") -def test_headers_are_cleared(mock_post): +@patch("blueapi.client.numtracker.httpx.AsyncClient.post") +async def test_headers_are_cleared(mock_post): mock_response = Mock() mock_post.return_value = mock_response mock_response.raise_for_status.side_effect = None @@ -465,16 +467,18 @@ def test_headers_are_cleared(mock_post): interface.begin_task(task=WorkerTask(task_id=None), pass_through_headers=headers) ctx = interface.context() assert ctx.run_engine.scan_id_source is not None - ctx.run_engine.scan_id_source( + scan_id = ctx.run_engine.scan_id_source( {"instrument_session": "cm12345-1", "instrument": "p46"} ) + assert isawaitable(scan_id) and await scan_id mock_post.assert_called_once() assert mock_post.call_args.kwargs["headers"] == headers interface.begin_task(task=WorkerTask(task_id=None)) - ctx.run_engine.scan_id_source( + scan_id = ctx.run_engine.scan_id_source( {"instrument_session": "cm12345-1", "instrument": "p46"} ) + assert isawaitable(scan_id) and await scan_id assert mock_post.call_count == 2 assert mock_post.call_args.kwargs["headers"] == {} @@ -553,7 +557,9 @@ def test_setup_with_numtracker_raises_if_provider_is_defined_in_device_module(): @patch("blueapi.client.numtracker.NumtrackerClient.create_scan") -def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_scan): +async def test_numtracker_create_scan_called_with_arguments_from_metadata( + mock_create_scan, +): conf = ApplicationConfig( numtracker=NumtrackerConfig( url=HttpUrl("https://numtracker-example.com/graphql") @@ -570,14 +576,24 @@ def test_numtracker_create_scan_called_with_arguments_from_metadata(mock_create_ ctx.numtracker.set_headers(headers) ctx.run_engine.md["instrument_session"] = "ab123" - ctx.run_engine.scan_id_source(ctx.run_engine.md) + scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md) + assert isawaitable(scan_id) and await scan_id mock_create_scan.assert_called_once_with("ab123", "p46") -def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md( - mock_numtracker_server, +async def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md( + httpx_mock, + nt_query, + nt_response, ): + httpx_mock.add_response( + method="POST", + url="https://numtracker-example.com/graphql", + match_json=nt_query, + status_code=200, + json=nt_response, + ) conf = ApplicationConfig( env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")), numtracker=NumtrackerConfig( @@ -590,21 +606,24 @@ def test_update_scan_num_side_effect_sets_data_session_directory_in_re_md( assert ctx.run_engine.scan_id_source is not None ctx.run_engine.md["instrument_session"] = "ab123" - ctx.run_engine.scan_id_source(ctx.run_engine.md) + scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md) + assert isawaitable(scan_id) and await scan_id assert ( ctx.run_engine.md["data_session_directory"] == "/exports/mybeamline/data/2025" ) -def test_update_scan_num_side_effect_sets_scan_file_in_re_md( - mock_numtracker_server, +async def test_update_scan_num_side_effect_sets_scan_file_in_re_md( + httpx_mock: HTTPXMock, nt_query, nt_response ): + nt_url = "https://numtracker-example.com/graphql" + httpx_mock.add_response( + method="POST", url=nt_url, match_json=nt_query, json=nt_response + ) conf = ApplicationConfig( env=EnvironmentConfig(metadata=MetadataConfig(instrument="p46")), - numtracker=NumtrackerConfig( - url=HttpUrl("https://numtracker-example.com/graphql") - ), + numtracker=NumtrackerConfig(url=HttpUrl(nt_url)), ) interface.setup(conf) ctx = interface.context() @@ -612,6 +631,7 @@ def test_update_scan_num_side_effect_sets_scan_file_in_re_md( assert ctx.run_engine.scan_id_source is not None ctx.run_engine.md["instrument_session"] = "ab123" - ctx.run_engine.scan_id_source(ctx.run_engine.md) + scan_id = ctx.run_engine.scan_id_source(ctx.run_engine.md) + assert isawaitable(scan_id) and await scan_id assert ctx.run_engine.md["scan_file"] == "p46-11" diff --git a/uv.lock b/uv.lock index 08aaa59f9..1a00e6a3b 100644 --- a/uv.lock +++ b/uv.lock @@ -402,6 +402,7 @@ dependencies = [ { name = "fastapi" }, { name = "gitpython" }, { name = "graypy" }, + { name = "httpx" }, { name = "observability-utils" }, { name = "opentelemetry-distro" }, { name = "opentelemetry-instrumentation-fastapi" }, @@ -432,6 +433,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-httpx" }, { name = "responses" }, { name = "ruff" }, { name = "semver" }, @@ -458,6 +460,7 @@ requires-dist = [ { name = "fastapi", specifier = ">=0.112.0" }, { name = "gitpython" }, { name = "graypy", specifier = ">=2.1.0" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "observability-utils", specifier = ">=0.1.4" }, { name = "opentelemetry-distro", specifier = ">=0.48b0" }, { name = "opentelemetry-instrumentation-fastapi", specifier = ">=0.48b0" }, @@ -488,6 +491,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-httpx", specifier = ">=0.35.0" }, { name = "responses" }, { name = "ruff" }, { name = "semver" }, @@ -3983,6 +3987,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "pytest-httpx" +version = "0.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/5574834da9499066fa1a5ea9c336f94dba2eae02298d36dab192fcf95c86/pytest_httpx-0.36.0.tar.gz", hash = "sha256:9edb66a5fd4388ce3c343189bc67e7e1cb50b07c2e3fc83b97d511975e8a831b", size = 56793, upload-time = "2025-12-02T16:34:57.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/d2/1eb1ea9c84f0d2033eb0b49675afdc71aa4ea801b74615f00f3c33b725e3/pytest_httpx-0.36.0-py3-none-any.whl", hash = "sha256:bd4c120bb80e142df856e825ec9f17981effb84d159f9fa29ed97e2357c3a9c8", size = 20229, upload-time = "2025-12-02T16:34:56.45Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"