diff --git a/.mypy.ini b/.mypy.ini index 8ff36fda08..8868a875f1 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -77,6 +77,7 @@ modules = azul.service.user_controller, azul.service.user_service, scripts.pull_request, + urllib3_mock, packages = diff --git a/scripts/request_flooder.py b/scripts/request_flooder.py index e64a6a4eb6..5f72da2d4d 100755 --- a/scripts/request_flooder.py +++ b/scripts/request_flooder.py @@ -11,11 +11,12 @@ import sys import time -import requests - from azul.args import ( AzulArgumentHelpFormatter, ) +from azul.http import ( + http_client, +) from azul.lib import ( R, ) @@ -57,10 +58,6 @@ def parse_args(argv): type=int, default='300', help='Total duration of the test in seconds.') - parser.add_argument('--log-headers', - default=False, - action='store_true', - help='Include response headers in log output') args = parser.parse_args(argv) args.method = args.method.upper() assert args.method in ['HEAD', 'GET', 'PUT'], R( @@ -75,16 +72,12 @@ def parse_args(argv): return args -def request_url(method: str, url: str, log_headers: bool) -> int: - log.info('Making %s request to %r', method, url) - start_time = time.time() - response = requests.request(method=method, url=url) - duration = time.time() - start_time - if log_headers: - log.info('… with response headers %r', response.headers) - log.info('Got %i response after %.3fs from %s to %s', - response.status_code, duration, method, url) - return response.status_code +http = http_client(log=log) + + +def request_url(method: str, url: str) -> int: + response = http.request(method=method, url=url) + return response.status def main(argv): @@ -98,7 +91,7 @@ def main(argv): end_time = start_time + args.duration while time.time() < end_time: time.sleep(sleep_delay) - futures.append(tpe.submit(request_url, args.method, args.url, args.log_headers)) + futures.append(tpe.submit(request_url, args.method, args.url)) for f in as_completed(futures): assert f.result() in [200, 429] diff --git a/src/azul/health.py b/src/azul/health.py index 2d82159f13..0bdc613dba 100644 --- a/src/azul/health.py +++ b/src/azul/health.py @@ -26,7 +26,6 @@ from furl import ( furl, ) -import requests from azul import ( CatalogName, @@ -40,6 +39,11 @@ from azul.deployment import ( aws, ) +from azul.http import ( + HTTPStatusError, + HasCachedHttpClient, + raise_on_status, +) from azul.lib import ( R, cache, @@ -176,7 +180,7 @@ def _make_response(self, body: JSON) -> Response: @attr.s(frozen=True, kw_only=True, auto_attribs=True) -class Health: +class Health(HasCachedHttpClient): """ Encapsulates information about the health status of an Azul deployment. All aspects of health are exposed as lazily loaded properties. 
Instantiating the @@ -262,14 +266,10 @@ def progress(self) -> JSON: def _api_endpoint(self, entity_type: str) -> JSON: relative_url = furl(path=('index', entity_type), args={'size': '1'}) url = str(config.service_endpoint.join(relative_url)) - log.info('Making HEAD request to %s', url) - start = time.time() - response = requests.api.head(url) - log.info('Got %s response after %.3fs from HEAD request to %s', - response.status_code, time.time() - start, url) + response = self._http_client.request('HEAD', url) try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: + raise_on_status(response) + except HTTPStatusError as e: return {'up': False, 'error': repr(e)} else: return {'up': True} @@ -300,9 +300,8 @@ def _lambda(self, lambda_name) -> JSON: try: url = config.lambda_endpoint(lambda_name).set(path='/health/basic', args={'catalog': self.catalog}) - log.info('Requesting %r', url) - response = requests.api.get(str(url)) - response.raise_for_status() + response = self._http_client.request('GET', str(url)) + raise_on_status(response) up = response.json()['up'] except Exception as e: return { diff --git a/src/azul/http.py b/src/azul/http.py index 940acae5b4..9e61db4de8 100644 --- a/src/azul/http.py +++ b/src/azul/http.py @@ -174,6 +174,18 @@ def http_client(log: logging.Logger | None = None) -> HttpClient: return StatusRetryHttpClient(client) +class HTTPStatusError(Exception): + + def __init__(self, url: str | None, status: int, reason: str | None = None): + # The URL is intentionally passed as the last argument because URLs tend to be long. + super().__init__('Unexpected response status', status, reason, url) + + +def raise_on_status(response: urllib3.BaseHTTPResponse) -> None: + if not 200 <= response.status <= 399: + raise HTTPStatusError(response.url, response.status, response.reason) + + class LimitedTimeoutException(Exception): def __init__(self, url: furl, timeout: float): diff --git a/src/azul/plugins/repository/dss/__init__.py b/src/azul/plugins/repository/dss/__init__.py index 4486fc143f..abd11c04f5 100644 --- a/src/azul/plugins/repository/dss/__init__.py +++ b/src/azul/plugins/repository/dss/__init__.py @@ -18,7 +18,6 @@ from more_itertools import ( one, ) -import requests from azul import ( config, @@ -31,6 +30,7 @@ ) from azul.http import ( HasCachedHttpClient, + raise_on_status, ) from azul.indexer import ( SourcedBundleFQID, @@ -208,7 +208,7 @@ def validate_version(self, version: str) -> None: parse_dcp2_version(version) -class DSSFileDownload(RepositoryFileDownload): +class DSSFileDownload(RepositoryFileDownload, HasCachedHttpClient): _location: str | None = None _retry_after: int | None = None @@ -222,8 +222,8 @@ def update(self, authentication: Authentication | None) -> None: file_version=self.file.version, replica=self.replica, token=self.token) - dss_response = requests.get(dss_url, allow_redirects=False) - if dss_response.status_code == 301: + dss_response = self._http_client.request('GET', dss_url, redirect=False) + if dss_response.status == 301: retry_after = int(dss_response.headers.get('Retry-After')) location = dss_response.headers['Location'] @@ -233,7 +233,7 @@ def update(self, authentication: Authentication | None) -> None: self.replica = one(query['replica']) self.file = attrs.evolve(self.file, version=one(query['version'])) self._retry_after = retry_after - elif dss_response.status_code == 302: + elif dss_response.status == 302: location = dss_response.headers['Location'] # Remove once https://github.com/HumanCellAtlas/data-store/issues/1837 is resolved 
if True: @@ -256,7 +256,7 @@ def update(self, authentication: Authentication | None) -> None: Params=params) self._location = location else: - dss_response.raise_for_status() + raise_on_status(dss_response) assert False @property diff --git a/src/azul/service/drs_controller.py b/src/azul/service/drs_controller.py index 8ff29c4742..04ffd6ee2d 100644 --- a/src/azul/service/drs_controller.py +++ b/src/azul/service/drs_controller.py @@ -12,6 +12,7 @@ from datetime import ( datetime, ) +import logging from typing import ( Any, ) @@ -27,7 +28,7 @@ from more_itertools import ( one, ) -import requests +import urllib3 from azul import ( config, @@ -38,6 +39,9 @@ drs_object_uri, drs_object_url_path, ) +from azul.http import ( + HasCachedHttpClient, +) from azul.lib import ( cached_property, mutable_furl, @@ -62,8 +66,10 @@ IndexService, ) +log = logging.getLogger(__name__) + -class DRSController(ServiceController): +class DRSController(ServiceController, HasCachedHttpClient): @cached_property def _service(self) -> IndexService: @@ -207,7 +213,7 @@ def get_object(self, file_uuid, query_params): # We only want direct URLs for Google extra_params = dict(query_params, directurl=access_method.replica == 'gcp') response = self._dss_get_file(file_uuid, access_method.replica, **extra_params) - if response.status_code == 301: + if response.status == 301: retry_url = response.headers['location'] query = urllib.parse.urlparse(retry_url).query query = urllib.parse.parse_qs(query, strict_parsing=True) @@ -215,14 +221,14 @@ def get_object(self, file_uuid, query_params): # We use the encoded token string as the key for our access ID. access_id = encode_access_id(token, access_method.replica) drs_object.add_access_method(access_method, access_id=access_id) - elif response.status_code == 302: + elif response.status == 302: retry_url = response.headers['location'] if access_method.replica == 'gcp': assert retry_url.startswith('gs:') drs_object.add_access_method(access_method, url=retry_url) else: # For errors, just proxy DSS response - return Response(response.text, status_code=response.status_code) + return Response(response.data, status_code=response.status) return Response(drs_object.to_json()) def get_object_access(self, access_id, file_uuid, query_params): @@ -240,24 +246,32 @@ def get_object_access(self, access_id, file_uuid, query_params): 'directurl': replica == 'gcp', 'token': token }) - if response.status_code == 301: - headers = {'retry-after': response.headers['retry-after']} + if response.status == 301: + header_name = 'retry-after' + retry_after = response.headers[header_name] # DRS says no body for 202 responses - return Response(body='', status_code=202, headers=headers) - elif response.status_code == 302: + return Response(body='', status_code=202, headers={header_name: retry_after}) + elif response.status == 302: retry_url = response.headers['location'] return Response(self._access_url(retry_url)) else: # For errors, just proxy DSS response - return Response(response.text, status_code=response.status_code) + return Response(response.data, status_code=response.status) - def _dss_get_file(self, file_uuid, replica, **kwargs): + def _dss_get_file(self, + file_uuid, + replica, + **kwargs + ) -> urllib3.BaseHTTPResponse: dss_params = { 'replica': replica, **kwargs } url = self.dss_file_url(file_uuid) - return requests.api.get(str(url), params=dss_params, allow_redirects=False) + return self._http_client.request('GET', + str(url), + fields=dss_params, + redirect=False) @classmethod def dss_file_url(cls, 
file_uuid: str) -> mutable_furl: @@ -269,7 +283,7 @@ class GatewayTimeoutError(ChaliceViewError): @dataclass -class DRSObject: +class DRSObject(HasCachedHttpClient): """" Used to build up a https://ga4gh.github.io/data-repository-service-schemas/docs/#_drsobject """ @@ -295,7 +309,7 @@ def add_access_method(self, def to_json(self) -> JSON: args = _url_query(replica='aws', version=self.version) url = DRSController.dss_file_url(self.uuid).add(args=args) - headers = requests.api.head(str(url)).headers + headers = self._http_client.request('HEAD', str(url)).headers version = headers['x-dss-version'] if self.version is not None: assert version == self.version diff --git a/src/humancellatlas/data/metadata/helpers/schema_validation.py b/src/humancellatlas/data/metadata/helpers/schema_validation.py index 4c8d9937fc..f0f693f03a 100644 --- a/src/humancellatlas/data/metadata/helpers/schema_validation.py +++ b/src/humancellatlas/data/metadata/helpers/schema_validation.py @@ -15,8 +15,11 @@ Registry, Resource, ) -import requests +from azul.http import ( + HasCachedHttpClient, + raise_on_status, +) from azul.lib import ( R, cached_property, @@ -28,7 +31,7 @@ log = logging.getLogger(__name__) -class SchemaValidator: +class SchemaValidator(HasCachedHttpClient): def validate_json(self, file_json: JSON, file_name: str): try: @@ -45,8 +48,8 @@ def validate_json(self, file_json: JSON, file_name: str): @lru_cache(maxsize=None) def _download_json_file(self, file_url: str) -> JSON: - response = requests.get(file_url, allow_redirects=False) - response.raise_for_status() + response = self._http_client.request('GET', file_url, redirect=False) + raise_on_status(response) return response.json() def _retrieve_resource(self, resource_url: str) -> Resource: diff --git a/test/app_test_case.py b/test/app_test_case.py index 2252f87335..e489e1fcc6 100644 --- a/test/app_test_case.py +++ b/test/app_test_case.py @@ -18,14 +18,16 @@ from furl import ( furl, ) -import requests - +import urllib3 from azul import ( config, ) from azul.chalice import ( AzulChaliceApp, ) +from azul.http import ( + raise_on_status, +) from azul.lib import ( mutable_furl, ) @@ -129,7 +131,7 @@ def setUp(self): while True: try: response = self._ping() - response.raise_for_status() + raise_on_status(response) except Exception: if time.time() > deadline: raise @@ -138,8 +140,14 @@ def setUp(self): else: break - def _ping(self): - return requests.get(str(self.base_url.set(path='/health/basic'))) + def _ping(self) -> urllib3.BaseHTTPResponse: + return self._http_client.urlopen('GET', + str(self.base_url.set(path='/health/basic')), + retries=urllib3.Retry(connect=2, + read=2, + status=0, + redirect=0, + status_forcelist={500})) def chalice_config(self): return ChaliceConfig.create(lambda_timeout=config.api_gateway_lambda_timeout) diff --git a/test/azul_test_case.py b/test/azul_test_case.py index 20d482a261..124d402c82 100644 --- a/test/azul_test_case.py +++ b/test/azul_test_case.py @@ -45,6 +45,7 @@ from opensearchpy.exceptions import ( OpenSearchWarning, ) +import urllib3 from azul import ( CatalogName, @@ -54,6 +55,9 @@ from azul.deployment import ( aws, ) +from azul.http import ( + HasCachedHttpClient, +) from azul.logging import ( configure_test_logging, get_test_logger, ) @@ -240,8 +244,13 @@ def stacked_patches(self, patches: Iterable[Patch]): context.enter_context(cm) yield + # Disable HTTP retries for tests that intentionally return error status codes + no_retries: urllib3.Retry = urllib3.Retry(status=0, + status_forcelist={500}, + 
raise_on_status=False) + -class AzulUnitTestCase(AzulTestCase): +class AzulUnitTestCase(AzulTestCase, HasCachedHttpClient): @classmethod def setUpClass(cls) -> None: diff --git a/test/health_check_test_case.py b/test/health_check_test_case.py index 1049b35b33..b84f78789b 100644 --- a/test/health_check_test_case.py +++ b/test/health_check_test_case.py @@ -18,15 +18,15 @@ MagicMock, patch, ) - from furl import ( furl, ) from moto import ( mock_aws, ) -import requests -import responses +from urllib3 import ( + BaseHTTPResponse, +) from app_test_case import ( LocalAppTestCase, @@ -42,6 +42,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.modules import ( load_app_module, @@ -55,6 +56,9 @@ from sqs_test_case import ( SqsTestCase, ) +from urllib3_mock import ( + Urllib3Mock, +) # FIXME: This is inelegant: https://github.com/DataBiosphere/azul/issues/652 @@ -69,6 +73,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class HealthCheckTestCase(LocalAppTestCase, OpenSearchTestCase, StorageServiceTestCase, @@ -76,21 +83,25 @@ class HealthCheckTestCase(LocalAppTestCase, metaclass=ABCMeta): def test_basic(self): - response = requests.get(str(self.base_url.set(path='/health/basic'))) - self.assertEqual(200, response.status_code) + response = self._http_client.urlopen('GET', + str(self.base_url.set(path='/health/basic')), + retries=self.no_retries) + self.assertEqual(200, response.status) self.assertEqual({'up': True}, response.json()) def test_validation(self): for path in ['foo', 'opensearch,', ',opensearch', ',', '1']: - response = requests.get(str(self.base_url.set(path=('health', path)))) - self.assertEqual(400, response.status_code) + response = self._http_client.urlopen('GET', + str(self.base_url.set(path=('health', path))), + retries=self.no_retries) + self.assertEqual(400, response.status) @mock_aws def test_health_all_ok(self): self._create_mock_queues() with self._mock(): response = self._test('/health/') - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) self.assertEqual({ 'up': True, **self._expected_opensearch(up=True), @@ -120,7 +131,7 @@ def test_health_endpoint_keys(self): with self.subTest(keys=keys): with self._mock(): response = self._test(f'health/{keys}') - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) self.assertEqual(expected_response, response.json()) @mock_aws @@ -128,7 +139,7 @@ def test_cached_health(self): # No health object is available in S3 bucket, yielding an error with self._mock(): response = self._test('/health/cached') - self.assertEqual(404, response.status_code) + self.assertEqual(404, response.status) expected_response = { 'Code': 'NotFoundError', 'Message': 'Cached health object does not exist' @@ -141,7 +152,7 @@ def test_cached_health(self): with self._mock(): app.update_health_cache(MagicMock(), MagicMock()) response = self._test('/health/cached') - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) # Another failure is observed when the cache health object is older than # 2 minutes @@ -149,7 +160,7 @@ def test_cached_health(self): with patch('time.time', new=lambda: future_time): with self._mock(): response = self._test('/health/cached') - self.assertEqual(500, response.status_code) + self.assertEqual(500, response.status) expected_response = { 'Code': 'ChaliceViewError', 'Message': 'Cached health object is stale' @@ -164,7 +175,7 @@ def test_laziness(self): # The use of 
subTests ensures that we see the result of both # assertions. In the case of the health endpoint, a 503 # response may carry a body with additional information. - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) expected_response = {'up': True, **self._expected_other_lambdas(up=True)} self.assertEqual(expected_response, response.json()) @@ -185,7 +196,7 @@ def test_opensearch_down(self): with patch.dict(os.environ, **mock_env): with self._mock(): response = self._test('/health/fast') - self.assertEqual(503, response.status_code) + self.assertEqual(503, response.status) self.assertEqual(self._expected_health(opensearch_up=False), response.json()) def _expected_queues(self, *, up: bool) -> MutableJSON: @@ -214,9 +225,9 @@ def _expected_api_endpoints(self, *, up: bool) -> MutableJSON: } if up else { 'up': up, 'error': ( - "HTTPError('503 Server Error: " - "Service Unavailable for url: " - f"{self._endpoint('/index/bundles?size=1')}')" + "HTTPStatusError('Unexpected response status', 503, 'Service Unavailable', " + f"'{self._endpoint('/index/bundles?size=1')}')" ) } } @@ -267,38 +278,34 @@ def _mock(self, *, endpoints_up: bool = True, lambdas_up: bool = True): with self._mock_service_endpoints(helper, up=endpoints_up): yield - def _test(self, path: str) -> requests.Response: - return requests.get(str(self.base_url.set(path=path))) + def _test(self, path: str) -> BaseHTTPResponse: + return self._http_client.urlopen('GET', + str(self.base_url.set(path=path)), + retries=self.no_retries) def helper(self): - helper = responses.RequestsMock() - helper.add_passthru(str(self.base_url)) - # We originally shared the Requests mock with Moto which had this set - # to False. Because of that, and without noticing, we ended up mocking - # more responses than necessary for some of the tests. Instead of - # rewriting the tests to only mock what is actually used, we simply - # disable the assertion, just like Moto did. - helper.assert_all_requests_are_fired = False - return helper + return Urllib3Mock(Health) def _mock_service_endpoints(self, - helper: responses.RequestsMock, + helper: Urllib3Mock, *, up: bool ) -> ContextManager: - helper.add(responses.Response(method='HEAD', - url=self._endpoint('/index/bundles?size=1'), - status=200 if up else 503, - body='')) + helper.add(method='HEAD', + url=self._endpoint('/index/bundles?size=1'), + status=200 if up else 503, + reason='OK' if up else 'Service Unavailable') # Patching the Health class to use a random generator with a pinned # seed allows us to predict the service endpoint that will be picked # to check the health of the service REST API. 
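# random.Random(x=42) yields the same sequence on every run, making that pick deterministic.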
return patch.object(Health, '_random', random.Random(x=42)) - def _mock_other_lambdas(self, helper: responses.RequestsMock, *, up: bool): + def _mock_other_lambdas(self, helper: Urllib3Mock, *, up: bool): for lambda_name in self._other_lambda_names(): - url = config.lambda_endpoint(lambda_name).set(path='/health/basic') - helper.add(responses.Response(method='GET', - url=str(url), - status=200 if up else 500, - json={'up': up})) + url = config.lambda_endpoint(lambda_name).set(path='/health/basic', + args={'catalog': self.catalog}) + helper.add(method='GET', + url=str(url), + status=200 if up else 500, + reason='OK' if up else 'Internal Server Error', + body={'up': up}) diff --git a/test/indexer/test_anvil.py b/test/indexer/test_anvil.py index 03e04c4675..c87acbcc5d 100644 --- a/test/indexer/test_anvil.py +++ b/test/indexer/test_anvil.py @@ -42,6 +42,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.plugins.repository import ( tdr_anvil, @@ -69,6 +70,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class DUOSTestCase(TDRTestCase, ABC): def _mock_normal_duos(self): diff --git a/test/indexer/test_health_check.py b/test/indexer/test_health_check.py index a6322facf3..f898a71437 100644 --- a/test/indexer/test_health_check.py +++ b/test/indexer/test_health_check.py @@ -7,6 +7,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul_test_case import ( DCP1TestCase, @@ -21,6 +22,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestIndexerHealthCheck(DCP1TestCase, HealthCheckTestCase): @classmethod @@ -42,7 +46,7 @@ def _expected_health(self, def test_queues_down(self): with self._mock(): response = self._test('/health/fast') - self.assertEqual(503, response.status_code) + self.assertEqual(503, response.status) self.assertEqual(self._expected_health(), response.json()) diff --git a/test/integration_test.py b/test/integration_test.py index 0cbcf8b82d..5488b27699 100644 --- a/test/integration_test.py +++ b/test/integration_test.py @@ -74,7 +74,6 @@ validate, ) import opensearchpy -import requests import urllib3 from azul import ( @@ -102,8 +101,9 @@ AccessMethod, ) from azul.http import ( + HasCachedHttpClient, HttpClient, - http_client, + raise_on_status, ) from azul.indexer import ( SourcedBundleFQID, @@ -371,7 +371,7 @@ def _filter(source: tuple[SourceSpec, SourceConfig]) -> bool: return Source(ref=plugin.resolve_source(source), config=config) -class IndexingIntegrationTest(IntegrationTestCase): +class IndexingIntegrationTest(IntegrationTestCase, HasCachedHttpClient): """ An integration test case that tests indexing of public and managed-access metadata from a random selection of bundles, and the expected effects on the @@ -395,7 +395,7 @@ class IndexingIntegrationTest(IntegrationTestCase): def setUp(self) -> None: super().setUp() - self._plain_http = http_client(log) + self._plain_http = self._http_client self._http = self._plain_http @contextmanager @@ -1851,7 +1851,7 @@ def test_azul_client_error_handling(self): self.assertEqual({expected}, cm.exception.args[1]) -class OpenAPIIntegrationTest(AzulTestCase): +class OpenAPIIntegrationTest(AzulTestCase, HasCachedHttpClient): def test_openapi(self): for component, url in [ @@ -1860,19 +1860,22 @@ def test_openapi(self): ]: with self.subTest(component=component): url.set(path='/') - response = requests.get(str(url)) - self.assertEqual(response.status_code, 200) + response = 
self._http_client.urlopen(GET, + str(url), + redirect=True, + retries=self.no_retries) + self.assertEqual(response.status, 200) self.assertEqual(response.headers['content-type'], 'text/html') - self.assertGreater(len(response.content), 0) + self.assertGreater(len(response.data), 0) # validate OpenAPI spec url.set(path='/openapi.json') - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen(GET, str(url)) + raise_on_status(response) spec = response.json() validate(spec) -class AzulChaliceLocalIntegrationTest(AzulTestCase): +class AzulChaliceLocalIntegrationTest(AzulTestCase, HasCachedHttpClient): url = furl(scheme='http', host='127.0.0.1', port=8000) server = None server_thread = None @@ -1898,21 +1901,24 @@ def tearDownClass(cls) -> None: super().tearDownClass() def test_local_chalice(self): - response = requests.get(str(self.url)) - self.assertEqual(200, response.status_code) + response = self._http_client.urlopen(GET, + str(self.url), + redirect=True, + retries=self.no_retries) + self.assertEqual(200, response.status) def test_local_chalice_health_endpoint(self): url = str(self.url.copy().set(path='health')) - response = requests.get(url) - self.assertEqual(200, response.status_code) + response = self._http_client.urlopen(GET, url) + self.assertEqual(200, response.status) catalog = first(config.integration_test_catalogs) def test_local_chalice_index_endpoints(self): url = str(self.url.copy().set(path='repository/sources', query=dict(catalog=self.catalog))) - response = requests.get(url) - self.assertEqual(200, response.status_code, response.content) + response = self._http_client.urlopen(GET, url) + self.assertEqual(200, response.status, response.data) def test_local_filtered_index_endpoints(self): if config.is_hca_enabled(self.catalog): @@ -1925,11 +1931,11 @@ def test_local_filtered_index_endpoints(self): url = str(self.url.copy().set(path='index/files', query=dict(filters=json.dumps(filters), catalog=self.catalog))) - response = requests.get(url) - self.assertEqual(200, response.status_code, response.content) + response = self._http_client.urlopen(GET, url) + self.assertEqual(200, response.status, response.data) -class CanBundleScriptIntegrationTest(IntegrationTestCase): +class CanBundleScriptIntegrationTest(IntegrationTestCase, HasCachedHttpClient): def _test_catalog(self, catalog: config.Catalog): fqid = self.bundle_fqid(catalog.name) @@ -2036,10 +2042,9 @@ def _can_bundle_main(self) -> Callable[[Sequence[str]], None]: return can_bundle.main -class SwaggerResourceIntegrationTest(AzulTestCase): +class SwaggerResourceIntegrationTest(AzulTestCase, HasCachedHttpClient): def test(self): - http = http_client(log) for component, base_url in [ ('service', config.service_endpoint), ('indexer', config.indexer_endpoint) @@ -2055,11 +2060,11 @@ def test(self): ('..%2Fdoes-not-exist', 403), ]: with self.subTest(component=component, file=file): - response = http.request(GET, str(base_url / 'swagger' / file)) + response = self._http_client.request(GET, str(base_url / 'swagger' / file)) self.assertEqual(expected_status, response.status) -class DeployedVersionIntegrationTest(AzulTestCase): +class DeployedVersionIntegrationTest(AzulTestCase, HasCachedHttpClient): def test_version(self): local_status = config.git_status @@ -2068,8 +2073,8 @@ def test_version(self): ('indexer', config.indexer_endpoint) ]: endpoint.set(path='/version') - response = requests.get(str(endpoint)) - self.assertEqual(response.status_code, 200) + response = 
self._http_client.urlopen(GET, str(endpoint)) + self.assertEqual(response.status, 200) lambda_status = response.json()['git'] self.assertEqual(local_status, lambda_status) @@ -2089,7 +2094,7 @@ def test(self): opensearch.indices.delete(index=[index_name]) -class ResponseHeadersTest(AzulTestCase): +class ResponseHeadersTest(AzulTestCase, HasCachedHttpClient): def test_response_security_headers(self): no_cache = 'no-store' @@ -2105,8 +2110,8 @@ def test_response_security_headers(self): for endpoint in (config.service_endpoint, config.indexer_endpoint): for path, cache_control in test_cases.items(): with self.subTest(endpoint=endpoint, path=path): - response = requests.get(str(endpoint / path)) - response.raise_for_status() + response = self._http_client.urlopen(GET, str(endpoint / path)) + raise_on_status(response) actual_csp = response.headers['Content-Security-Policy'] parsed_csp = CSP.parse(actual_csp) parsed_csp.validate() @@ -2130,12 +2135,13 @@ def test_response_security_headers(self): # expected value. 'Content-Security-Policy': str(parsed_csp) } - self.assertIsSubset(expected_headers.items(), response.headers.items()) + self.assertIsSubset(expected_headers.items(), + set(response.headers.items())) def test_default_4xx_response_headers(self): for endpoint in (config.service_endpoint, config.indexer_endpoint): with self.subTest(endpoint=endpoint): - response = requests.get(str(endpoint / 'does-not-exist')) - self.assertEqual(403, response.status_code) + response = self._http_client.urlopen(GET, str(endpoint / 'does-not-exist')) + self.assertEqual(403, response.status) self.assertIsSubset(AzulChaliceApp.security_headers().items(), - response.headers.items()) + set(response.headers.items())) diff --git a/test/service/test_app_logging.py b/test/service/test_app_logging.py index 15c2aab84d..ca4f2fc44c 100644 --- a/test/service/test_app_logging.py +++ b/test/service/test_app_logging.py @@ -11,20 +11,19 @@ patch, ) -import requests - from azul import ( Config, ) from azul.chalice import ( AzulChaliceApp, - log, + log as chalice_log, ) from azul.lib.types import ( JSON, ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from indexer import ( DCP1CannedBundleTestCase, ) @@ -39,6 +38,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestServiceAppLogging(DCP1CannedBundleTestCase, WebServiceTestCase): @classmethod @@ -57,6 +59,7 @@ def app_name(cls) -> str: def test_request_logs(self): prefix_len = 1024 + http = self._http_client def filter_body(organ: str) -> JSON: return {'filters': json.dumps({'organ': {'is': [organ]}})} @@ -77,7 +80,7 @@ def filter_body(organ: str) -> JSON: url = self.base_url.set(path='/index/projects') request_headers = {'authorization': 'Bearer foo_token'} if authenticated else {} level = [INFO, DEBUG, DEBUG][debug] - with self.assertLogs(logger=log, level=level) as logs: + with self.assertLogs(logger=chalice_log, level=level) as logs: with patch.object(Config, 'debug', new=PropertyMock(return_value=debug)): if body: request_headers = { @@ -85,17 +88,15 @@ def filter_body(organ: str) -> JSON: 'content-type': 'application/json', **request_headers } - response = requests.get(str(url), + response = http.request('GET', str(url), headers=request_headers, - json=body_json) + body=body or None) logs = [(r.levelno, r.getMessage()) for r in logs.records] body_log_level, body_log_message = logs.pop() # asserted separately request_headers = { 'host': url.netloc, - 'user-agent': 'python-requests/2.33.1', - 
'accept-encoding': 'gzip, deflate, zstd', - 'accept': '*/*', - 'connection': 'keep-alive', + 'accept-encoding': 'identity', + 'user-agent': 'python-urllib3/2.6.3', **request_headers, } response_headers = { @@ -146,7 +147,7 @@ def filter_body(organ: str) -> JSON: ], logs ) - body = json.dumps(response.json()) + body = json.dumps(json.loads(response.data)) self.assertGreater(len(body), prefix_len) if debug == 0: expected_log = "… with a response body of type ()" @@ -158,4 +159,4 @@ def filter_body(organ: str) -> JSON: assert False self.assertEqual(expected_log, body_log_message) self.assertEqual(INFO, body_log_level) - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) diff --git a/test/service/test_cache_poisoning.py b/test/service/test_cache_poisoning.py index 60f8e8a722..335b63afd4 100644 --- a/test/service/test_cache_poisoning.py +++ b/test/service/test_cache_poisoning.py @@ -5,13 +5,15 @@ patch, ) -import requests - from app_test_case import ( LocalAppTestCase, ) +from azul.http import ( + raise_on_status, +) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.terra import ( TDRClient, @@ -27,6 +29,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class CachePoisoningTestCase(LocalAppTestCase, metaclass=ABCMeta): snapshot_mock = None @@ -50,8 +55,8 @@ def app_name(cls) -> str: def _test(self): url = self.base_url.set(path='/repository/sources') - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) # Note that the test cases are named intentionally to force the order in which diff --git a/test/service/test_drs.py b/test/service/test_drs.py index 7201c36aac..dd8832e1e2 100644 --- a/test/service/test_drs.py +++ b/test/service/test_drs.py @@ -5,9 +5,6 @@ ) import urllib.parse -import requests -import responses - from app_test_case import ( LocalAppTestCase, ) @@ -17,14 +14,19 @@ from azul.drs import ( AccessMethod, ) +from azul.http import ( + raise_on_status, +) from azul.lib.types import ( MutableJSON, ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.service.drs_controller import ( DRSController, + DRSObject, dss_drs_object_uri, dss_drs_object_url, ) @@ -34,6 +36,9 @@ from indexer import ( DCP1CannedBundleTestCase, ) +from urllib3_mock import ( + Urllib3Mock, +) # noinspection PyPep8Naming @@ -41,6 +46,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestDRSEndpoint(DCP1CannedBundleTestCase, LocalAppTestCase): maxDiff = None @@ -67,15 +75,14 @@ def test_drs(self): file_version = '2018-11-02T113344.698028Z' for redirects in (0, 1, 2, 6): with self.subTest(redirects=redirects): - with responses.RequestsMock() as helper: - helper.add_passthru(str(self.base_url)) + with Urllib3Mock(DRSController, DRSObject) as helper: self._mock_responses(helper, redirects, file_uuid, file_version=file_version) # Make first client request url = dss_drs_object_url(file_uuid=file_uuid, file_version=file_version, base_url=self.base_url) - drs_response = requests.get(str(url)) - drs_response.raise_for_status() + drs_response = self._http_client.urlopen('GET', str(url)) + raise_on_status(drs_response) drs_object = drs_response.json() uri = dss_drs_object_uri(file_uuid=file_uuid, file_version='2018-11-02T113344.698028Z') @@ -133,16 +140,16 @@ def test_drs(self): file_version=file_version, base_url=self.base_url, access_id=access_id) - 
drs_response = requests.get(str(drs_access_url)) - self.assertEqual(drs_response.status_code, 202) - self.assertEqual(drs_response.text, '') + drs_response = self._http_client.urlopen('GET', str(drs_access_url)) + self.assertEqual(drs_response.status, 202) + self.assertEqual(drs_response.data.decode(), '') # The final request should give us just the access URL drs_access_url = dss_drs_object_url(file_uuid=file_uuid, file_version=file_version, base_url=self.base_url, access_id=access_id) - drs_response = requests.get(str(drs_access_url)) - self.assertEqual(drs_response.status_code, 200) + drs_response = self._http_client.urlopen('GET', str(drs_access_url)) + self.assertEqual(drs_response.status, 200) if method['type'] == AccessMethod.https.scheme: self.assertEqual(drs_response.json(), {'url': self.signed_url}) elif method['type'] == AccessMethod.gs.scheme: @@ -151,13 +158,14 @@ def test_drs(self): assert False, f'Access type {method["type"]} is not supported' def _dss_response(self, + helper: Urllib3Mock, file_uuid, file_version, replica, head=False, initial=True, _301=False - ): + ) -> None: request_query = { 'replica': replica, **({'version': file_version} if file_version else {}), @@ -181,45 +189,91 @@ def _dss_response(self, 'retry-after': '1' } if head: - return responses.Response(method=responses.HEAD, url=initial_url, status=200, headers=self.dss_headers) + helper.add(method='HEAD', + url=initial_url, + status=200, + headers=self.dss_headers) else: - return responses.Response(method=responses.GET, - url=initial_url if initial else retry_url, - status=301 if _301 else 302, - headers=headers_301 if _301 else headers_302) + helper.add(method='GET', + url=initial_url if initial else retry_url, + status=301 if _301 else 302, + headers=headers_301 if _301 else headers_302) - def _mock_responses(self, helper, redirects, file_uuid, file_version=None): + def _mock_responses(self, + helper: Urllib3Mock, + redirects, + file_uuid, + file_version=None): assert redirects >= 0 - helper.add_passthru(str(self.base_url)) if redirects == 0: - helper.add(self._dss_response(file_uuid, file_version, 'aws', initial=True, _301=False)) - helper.add(self._dss_response(file_uuid, file_version, 'gcp', initial=True, _301=False)) - helper.add(self._dss_response(file_uuid, file_version, 'aws', head=True)) + self._dss_response(helper, + file_uuid, + file_version, + 'aws', + initial=True, + _301=False) + self._dss_response(helper, + file_uuid, + file_version, + 'gcp', + initial=True, + _301=False) + self._dss_response(helper, file_uuid, file_version, 'aws', head=True) else: - helper.add(self._dss_response(file_uuid, file_version, 'aws', initial=True, _301=True)) - helper.add(self._dss_response(file_uuid, file_version, 'gcp', initial=True, _301=True)) - helper.add(self._dss_response(file_uuid, file_version, 'aws', head=True)) + self._dss_response(helper, + file_uuid, + file_version, + 'aws', + initial=True, + _301=True) + self._dss_response(helper, + file_uuid, + file_version, + 'gcp', + initial=True, + _301=True) + self._dss_response(helper, file_uuid, file_version, 'aws', head=True) redirects -= 1 for _ in range(redirects): - helper.add(self._dss_response(file_uuid, file_version, 'aws', initial=False, _301=True)) - helper.add(self._dss_response(file_uuid, file_version, 'gcp', initial=False, _301=True)) - helper.add(self._dss_response(file_uuid, file_version, 'aws', initial=False, _301=False)) - helper.add(self._dss_response(file_uuid, file_version, 'gcp', initial=False, _301=False)) + self._dss_response(helper, + 
file_uuid, + file_version, + 'aws', + initial=False, + _301=True) + self._dss_response(helper, + file_uuid, + file_version, + 'gcp', + initial=False, + _301=True) + self._dss_response(helper, + file_uuid, + file_version, + 'aws', + initial=False, + _301=False) + self._dss_response(helper, + file_uuid, + file_version, + 'gcp', + initial=False, + _301=False) def test_data_object_not_found(self): file_uuid = 'NOT_A_GOOD_IDEA' error_body = 'DRS should just proxy the DSS for error responses' - with responses.RequestsMock() as helper: - helper.add_passthru(str(self.base_url)) - url = f'{config.dss_endpoint}/files/{file_uuid}' - helper.add(responses.Response(method=responses.GET, - body=error_body, - url=url, - status=404)) + with Urllib3Mock(DRSController, DRSObject) as helper: + # The controller calls _dss_get_file, which uses request() with + # fields={'replica': 'aws', 'directurl': False}. RequestMethods + # encodes these fields into the URL for GET requests. + dss_url = f'{config.dss_endpoint}/files/{file_uuid}' + query = urllib.parse.urlencode({'replica': 'aws', 'directurl': False}) + helper.add(method='GET', url=f'{dss_url}?{query}', status=404, body=error_body) url = dss_drs_object_url(file_uuid=file_uuid, base_url=self.base_url) - drs_response = requests.get(str(url)) - self.assertEqual(drs_response.status_code, 404) - self.assertEqual(drs_response.text, error_body) + drs_response = self._http_client.urlopen('GET', str(url)) + self.assertEqual(404, drs_response.status) + self.assertEqual(error_body, drs_response.data.decode()) class TestDRSController(AzulUnitTestCase): diff --git a/test/service/test_health_check.py b/test/service/test_health_check.py index 770123fe05..282b9fa6f6 100644 --- a/test/service/test_health_check.py +++ b/test/service/test_health_check.py @@ -7,6 +7,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul_test_case import ( DCP1TestCase, @@ -21,6 +22,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestServiceHealthCheck(DCP1TestCase, HealthCheckTestCase): @classmethod @@ -42,7 +46,7 @@ def test_all_api_endpoints_down(self): self._create_mock_queues() with self._mock(endpoints_up=False): response = self._test('/health/fast') - self.assertEqual(503, response.status_code) + self.assertEqual(503, response.status) self.assertEqual(self._expected_health(endpoints_up=False), response.json()) diff --git a/test/service/test_index_projects.py b/test/service/test_index_projects.py index 4b3e28e7d3..bee26fe8c8 100644 --- a/test/service/test_index_projects.py +++ b/test/service/test_index_projects.py @@ -1,10 +1,13 @@ from more_itertools import ( one, ) -import requests +from azul.http import ( + raise_on_status, +) from azul.logging import ( configure_test_logging, + get_test_logger, ) from indexer import ( DCP1CannedBundleTestCase, ) @@ -19,6 +22,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestIndexProjectsEndpoint(DCP1CannedBundleTestCase, WebServiceTestCase): # Set a seed so that we can test the detail response with a stable project ID seed = 123 @@ -43,8 +49,8 @@ def test_projects_response(self): def get_response_json(uuid=None): url = self.base_url.set(path=('index', 'projects', uuid or ''), args=dict(catalog=self.catalog)) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) return response.json() def assert_file_type_summaries(hit): diff --git 
a/test/service/test_index_samples.py b/test/service/test_index_samples.py index b743daa663..1726996a36 100644 --- a/test/service/test_index_samples.py +++ b/test/service/test_index_samples.py @@ -1,7 +1,9 @@ -import requests - +from azul.http import ( + raise_on_status, +) from azul.logging import ( configure_test_logging, + get_test_logger, ) from indexer import ( DCP1CannedBundleTestCase, @@ -16,6 +18,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestIndexSamplesEndpoint(DCP1CannedBundleTestCase, WebServiceTestCase): @classmethod @@ -31,8 +36,8 @@ def tearDownClass(cls): def test_basic_response(self): url = self.base_url.set(path='/index/samples', args=dict(catalog=self.catalog)) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() def assert_file_type_summaries(hit): diff --git a/test/service/test_manifest_async.py b/test/service/test_manifest_async.py index 9b492e2b06..3b00f8cdf9 100644 --- a/test/service/test_manifest_async.py +++ b/test/service/test_manifest_async.py @@ -45,6 +45,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.plugins import ( ManifestFormat, @@ -79,6 +80,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + @patch.object(AsyncManifestService, '_sfn') class TestAsyncManifestService(AzulUnitTestCase): generation_id = UUID('1ea94a54-a64d-54f1-8b41-15455fb958db') diff --git a/test/service/test_pagination.py b/test/service/test_pagination.py index 71c7edb746..f9b2c1a50b 100644 --- a/test/service/test_pagination.py +++ b/test/service/test_pagination.py @@ -14,8 +14,10 @@ from more_itertools import ( unzip, ) -import requests +from azul.http import ( + raise_on_status, +) from azul.logging import ( configure_test_logging, get_test_logger, @@ -124,8 +126,8 @@ def sort_field_value(doc): return value def fetch(url): - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response = response.json() values = tuple(map(sort_field_value, response['hits'])) self.assertEqual(values, tuple(sorted(unique(values), reverse=reverse))) diff --git a/test/service/test_repository_files.py b/test/service/test_repository_files.py index bff73a37ae..0d25b558f5 100644 --- a/test/service/test_repository_files.py +++ b/test/service/test_repository_files.py @@ -25,8 +25,6 @@ from google.auth.transport.urllib3 import ( AuthorizedHttp, ) -import requests -import responses import urllib3 from app_test_case import ( @@ -45,6 +43,7 @@ ) from azul.http import ( http_client, + raise_on_status, ) from azul.indexer.mirror_service import ( MirrorService, @@ -57,6 +56,9 @@ from azul.plugins.metadata.hca import ( HCAFile, ) +from azul.plugins.repository.dss import ( + DSSFileDownload, +) from azul.service.index_service import ( IndexService, ) @@ -71,6 +73,9 @@ MirrorTestCase, S3TestCase, ) +from urllib3_mock import ( + Urllib3Mock, +) log = get_test_logger(__name__) @@ -216,8 +221,7 @@ def test(self): ('foo bar.txt', 'grbM6udwp0n/QE/L/RYfjtQCS/U='), ('foo&bar.txt', 'r4C8YxpJ4nXTZh+agBsfhZ2e7fI=')]: with self.subTest(fetch=fetch, file_name=file_name, wait=wait): - with responses.RequestsMock() as helper: - helper.add_passthru(str(self.base_url)) + with Urllib3Mock(DSSFileDownload) as helper: fixed_time = 1547691253.07010 expires = str(round(fixed_time + 3600)) s3_url = 
furl(url=f'https://{bucket_name}.s3.amazonaws.com', @@ -228,13 +232,13 @@ def test(self): 'x-amz-security-token': 'SOMETOKEN', 'Expires': expires }) - helper.add(responses.Response(method='GET', - url=str(dss_url), - status=301, - headers={ - 'Location': str(dss_url_with_token), - 'Retry-After': '10' - })) + helper.add(method='GET', + url=str(dss_url), + status=301, + headers={ + 'Location': str(dss_url_with_token), + 'Retry-After': '10' + }) azul_url = self.base_url.set(path=['repository', 'files', file_uuid], args=dict(catalog=self.catalog, version=file_version)) if fetch: @@ -250,16 +254,16 @@ def request_azul(url, expect_status): before = time.monotonic() with patch.object(type(aws), 'dss_checkout_bucket', return_value=bucket_name): with patch('time.time', new=lambda: 1547691253.07010): - response = requests.get(url, allow_redirects=False) + response = self._http_client.urlopen('GET', url, redirect=False) if wait and expect_status == 301: self.assertLess(retry_after, time.monotonic() - before) if fetch: - self.assertEqual(200, response.status_code) + self.assertEqual(200, response.status) response = response.json() self.assertEqual(expect_status, response['Status']) else: - if response.status_code != expect_status: - response.raise_for_status() + if response.status != expect_status: + raise_on_status(response) response = dict(response.headers) if expect_retry_after is None: self.assertNotIn('Retry-After', response) @@ -283,10 +287,10 @@ def request_azul(url, expect_status): azul_url.args['sha256'] = file.sha256 self.assertUrlEqual(azul_url, location) - helper.add(responses.Response(method='GET', - url=str(dss_url_with_token), - status=302, - headers={'Location': str(s3_url)})) + helper.add(method='GET', + url=str(dss_url_with_token), + status=302, + headers={'Location': str(s3_url)}) location = request_azul(url=location, expect_status=302) diff --git a/test/service/test_request_builder.py b/test/service/test_request_builder.py index c0c8940b2f..5d2c2ccf8c 100644 --- a/test/service/test_request_builder.py +++ b/test/service/test_request_builder.py @@ -18,6 +18,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.plugins import ( FieldPath, @@ -48,6 +49,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class TestRequestBuilder(DCP1CannedBundleTestCase, WebServiceTestCase): # Subclass the class under test so we can inject a mock plugin @attr.s(frozen=True, auto_attribs=True) diff --git a/test/service/test_request_validation.py b/test/service/test_request_validation.py index 1bba52f3a8..8952364d31 100644 --- a/test/service/test_request_validation.py +++ b/test/service/test_request_validation.py @@ -3,13 +3,11 @@ from furl import ( furl, ) -import requests -from requests import ( - Response, -) +import urllib3 from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.plugins import ( MetadataPlugin, @@ -30,6 +28,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + class RequestParameterValidationTest(DCP1CannedBundleTestCase, WebServiceTestCase): maxDiff = None @@ -57,13 +58,16 @@ def _metadata_plugin(self) -> MetadataPlugin: assert isinstance(plugin, MetadataPlugin) return plugin - def assertResponseStatus(self, url: furl, status: int) -> Response: + def assertResponseStatus(self, + url: furl, + status: int + ) -> urllib3.BaseHTTPResponse: if str(url.path) in {'/manifest/files', '/fetch/manifest/files'}: method = 'PUT' else: method = 'GET' - response = 
requests.request(method, str(url)) - self.assertEqual(status, response.status_code, response.content) + response = self._http_client.urlopen(method, str(url)) + self.assertEqual(status, response.status, response.data) return response def assertErrorMessage(self, url: furl, status: int, code: str, message: str): diff --git a/test/service/test_response.py b/test/service/test_response.py index 0486d313c9..13f82ff50f 100644 --- a/test/service/test_response.py +++ b/test/service/test_response.py @@ -33,8 +33,6 @@ from more_itertools import ( one, ) -import requests - from app_test_case import ( LocalAppTestCase, ) @@ -47,6 +45,9 @@ from azul.field_type import ( null_str, ) +from azul.http import ( + raise_on_status, +) from azul.indexer import ( BundleFQID, SourcedBundleFQID, @@ -68,6 +69,7 @@ ) from azul.logging import ( configure_test_logging, + get_test_logger, ) from azul.plugins import ( FieldPath, @@ -112,6 +114,9 @@ def setUpModule(): configure_test_logging() +log = get_test_logger(__name__) + + def parse_url_qs(url) -> dict[str, str]: url_parts = urlparse(url) query_dict = dict(parse_qsl(url_parts.query, keep_blank_values=True)) @@ -1039,9 +1044,10 @@ def test_response_stage_files_file(self): def test_sorting_details(self): for entity_type in 'files', 'samples', 'projects', 'bundles': with self.subTest(entity_type=entity_type): - response = requests.get(str(self.base_url.set(path=('index', entity_type), - args=self._params()))) - response.raise_for_status() + response = self._http_client.urlopen('GET', + str(self.base_url.set(path=('index', entity_type), + args=self._params()))) + raise_on_status(response) response_json = response.json() # Verify default sort field is set correctly self.assertEqual(response_json['pagination']['sort'], @@ -1054,8 +1060,8 @@ def test_transform_request_with_file_url(self): for entity_type in ('files', 'bundles'): with self.subTest(entity_type=entity_type): url = self.base_url.set(path=('index', entity_type), args=self._params()) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() for hit in response_json['hits']: if entity_type == 'files': @@ -1081,8 +1087,8 @@ def test_filter_with_none(self): with self.subTest(test_data=test_data): params = self._params(size=10, filters={'specimenDisease': {'is': test_data}}) url = self.base_url.set(path='/index/samples', args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() diseases = { disease @@ -1111,8 +1117,8 @@ def test_filter_by_projectId(self): with self.subTest(entity_type=entity_type): params = self._params(size=2, filters={'projectId': {'is': [test_data['id']]}}) url = self.base_url.set(path=('index', entity_type), args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() for hit in response_json['hits']: for project in hit['projects']: @@ -1131,8 +1137,8 @@ def test_filter_by_contentDescription(self): filters={'contentDescription': {'is': ['RNA sequence']}}, sort='fileName', order='asc') - response = requests.get(str(url), params=params) - response.raise_for_status() + response = self._http_client.request('GET', str(url), fields=params) + raise_on_status(response) response_json = response.json() expected = [ 
'Cortex2.CCJ15ANXX.SM2_052318p4_D8.unmapped.1.fastq.gz', @@ -1150,8 +1156,8 @@ def test_translated_facets(self): """ url = self.base_url.set(path='/index/samples', args=(self._params(size=10))) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() facets = response_json['termFacets'] @@ -1177,8 +1183,8 @@ def test_sample(self): for entity_type in 'projects', 'samples', 'files', 'bundles': with self.subTest(entity_type=entity_type): url = self.base_url.set(path=('index', entity_type), args=self._params()) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() if entity_type == 'samples': for hit in response_json['hits']: @@ -1199,8 +1205,8 @@ def test_sample(self): def test_bundles_outer_entity(self): entity_type = 'bundles' url = self.base_url.set(path=('index', entity_type), args=self._params()) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response = response.json() indexed_bundles = set(self.bundles()) self.assertEqual(len(self.bundles()), len(indexed_bundles)) @@ -1299,8 +1305,8 @@ def test_ranged_values(self): order='desc', sort='entryId') url = self.base_url.set(path='/index/projects', args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) actual_hits = [hit['donorOrganisms'] for hit in response.json()['hits']] self.assertElasticEqual(expected_hits, actual_hits) @@ -1312,13 +1318,15 @@ def test_ordering(self): for sort_field, accessor in sort_fields: responses = { - order: requests.get(str(self.base_url.set(path='/index/projects', - args=self._params(order=order, sort=sort_field)))) + order: self._http_client.urlopen('GET', + str(self.base_url.set(path='/index/projects', + args=self._params(order=order, + sort=sort_field)))) for order in ['asc', 'desc'] } hit_sort_values = {} for order, response in responses.items(): - response.raise_for_status() + raise_on_status(response) hit_sort_values[order] = [accessor(hit) for hit in response.json()['hits']] self.assertEqual(hit_sort_values['asc'], @@ -1356,8 +1364,8 @@ def extract_cell_line_types(response_json): sort='cellLineType', order='asc' if ascending else 'desc') url = self.base_url.set(path='/index/projects', args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() actual_values = list(extract_cell_line_types(response_json)) expected = ascending_values if ascending else list(reversed(ascending_values)) @@ -1374,8 +1382,8 @@ def test_multivalued_field_sorting(self): with self.subTest(order=order, reverse=reverse): params = self._params(size=15, sort='laboratory', order=order) url = self.base_url.set(path='/index/projects', args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() laboratories = [] for hit in response_json['hits']: @@ -1692,8 +1700,8 @@ def test_aggregate_date_sort(self): sorted(expected, key=lambda x: (x[0] is None, x[0]))) params = self._params(size=50, sort=field, order=direction) url = 
self.base_url.set(path=('index', entity_type), args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() actual = [ (dates[field], hit['entryId']) @@ -1866,8 +1874,8 @@ def test_aggregate_date_filter(self): } params = self._params(filters=filters, size=15, sort=field, order='asc') url = self.base_url.set(path=('index', entity_type), args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() actual = [ (dates[field], hit['entryId']) @@ -1898,8 +1906,8 @@ def test_contributors_order(self): # Next assert the order of contributors in the service response url = self.base_url.set(path=('index', 'projects', project_id)) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() project = one(response_json['projects']) actual = [r['email'] for r in project['contributors']] @@ -1945,8 +1953,8 @@ def _assert_term_facets(self, project_term_facets: JSON, url: str) -> None: for project_id, term_facets in project_term_facets.items(): with self.subTest(project_id=project_id): params = self._params(filters={'projectId': {'is': [project_id]}}) - response = requests.get(url, params=params) - response.raise_for_status() + response = self._http_client.request('GET', url, fields=params) + raise_on_status(response) response_json = response.json() actual_term_facets = response_json['termFacets'] for facet, terms in term_facets.items(): @@ -2028,11 +2036,11 @@ def test_organism_age_facet_search(self): url = self.base_url.set(path='/index/projects', args=dict(catalog=self.catalog, filters=json.dumps({'organismAge': filters}))) - response = requests.get(str(url)) + response = self._http_client.urlopen('GET', str(url)) if project_id is None: - self.assertTrue(response.status_code, 400) + self.assertEqual(400, response.status) else: - response.raise_for_status() + raise_on_status(response) response = response.json() hit = one(response['hits']) self.assertEqual(hit['entryId'], project_id) @@ -2048,8 +2056,8 @@ def test_pagination_search_after_search_before(self): """ params = self._params(size=3, sort='workflow', order='asc') url = self.base_url.set(path='/index/samples', args=params) - response = requests.get(str(url)) - response.raise_for_status() + response = self._http_client.urlopen('GET', str(url)) + raise_on_status(response) response_json = response.json() first_page_next = parse_url_qs(response_json['pagination']['next']) @@ -2071,8 +2079,8 @@ def test_pagination_search_after_search_before(self): self.assertEqual([None, '2d8282f0-6cbb-4d5a-822c-4b01718b4d0d'], json.loads(first_page_next['search_after'])) - response = requests.get(response_json['pagination']['next']) - response.raise_for_status() + response = self._http_client.urlopen('GET', response_json['pagination']['next']) + raise_on_status(response) response_json = response.json() second_page_next = parse_url_qs(response_json['pagination']['next']) second_page_previous = parse_url_qs(response_json['pagination']['previous']) @@ -2096,12 +2104,12 @@ def test_bad_search_after_search_before(self): query_params = self._params(size=1, sort='sampleId', order='asc') url = self.base_url.set(path='/index/samples', args=query_params) # Get page 1 - response = 
@@ -2096,12 +2104,12 @@ def test_bad_search_after_search_before(self):
         query_params = self._params(size=1, sort='sampleId', order='asc')
         url = self.base_url.set(path='/index/samples', args=query_params)
         # Get page 1
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         # Get page 2
-        response = requests.get(response_json['pagination']['next'])
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', response_json['pagination']['next'])
+        raise_on_status(response)
         response_json = response.json()
         test_cases = {
             'search_before': response_json['pagination']['previous'],
@@ -2110,17 +2118,17 @@ def test_bad_search_after_search_before(self):
         for pagination_key, good_url in test_cases.items():
             with self.subTest(pagination_key=pagination_key):
                 # Verify URL works before modifying
-                response = requests.get(good_url)
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', good_url)
+                raise_on_status(response)
                 # Modify search_… param in URL and verify expected error occurs
                 bad_url = furl(good_url)
                 self.assertIn('"', bad_url.args[pagination_key])
                 bad_url.args[pagination_key] = bad_url.args[pagination_key].replace('"', '')
-                response = requests.get(str(bad_url))
+                response = self._http_client.urlopen('GET', str(bad_url))
                 error_msg = f'The {pagination_key!r} parameter is not valid JSON'
                 expected_text = f'{{"Code":"BadRequestError","Message":"{error_msg}"}}'
-                self.assertEqual(400, response.status_code)
-                self.assertEqual(expected_text, response.text)
+                self.assertEqual(400, response.status)
+                self.assertEqual(expected_text, response.data.decode())

     def test_filter_by_publication_title(self):
         cases = [
@@ -2169,8 +2177,8 @@ def test_filter_by_publication_title(self):
             }
         url = self.base_url.set(path='/index/files',
                                 args=dict(filters=json.dumps(filters)))
-        response = requests.get(str(url))
-        self.assertEqual(200, response.status_code)
+        response = self._http_client.urlopen('GET', str(url))
+        self.assertEqual(200, response.status)
         self.assertEqual(expected_terms,
                          response.json()['termFacets']['publicationTitle'])
         files = {
@@ -2192,8 +2200,8 @@ def _test(entity_type: str, expect_empty: bool, expect_accessible: bool):
             with self.subTest(entity_type=entity_type,
                               expect_accessible=expect_accessible):
                 url = str(self.base_url.set(path=('index', entity_type)))
-                response = requests.get(url)
-                self.assertEqual(200, response.status_code)
+                response = self._http_client.urlopen('GET', url)
+                self.assertEqual(200, response.status)
                 hits = response.json()['hits']
                 if expect_empty:
                     self.assertEqual([], hits)
@@ -2218,8 +2226,8 @@ def request_accessions(nested_properties):
                 }
             })
             url = self.base_url.set(path='/index/projects')
-            response = requests.get(str(url), params=params)
-            self.assertEqual(200, response.status_code)
+            response = self._http_client.request('GET', str(url), fields=params)
+            self.assertEqual(200, response.status)
             return response.json()

         cases = [
@@ -2268,8 +2276,8 @@ def test_version(self):
                                azul_git_commit=commit,
                                azul_git_dirty=str(int(dirty))):
                 url = self.base_url.set(path='/version')
-                response = requests.get(str(url))
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 expected_json = {
                     'commit': commit,
                     'dirty': dirty
@@ -2303,8 +2311,8 @@ def test_grouping(self):
             'catalog': self.catalog,
             'filters': json.dumps(filters)
         }
-        response = requests.get(str(url), params=params)
-        response.raise_for_status()
+        response = self._http_client.request('GET', str(url), fields=params)
+        raise_on_status(response)
         response_json = response.json()
         file_type_summaries = one(response_json['hits'])['fileTypeSummaries']
         expected = [
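The assertion changes track the change in response type: urllib3's `BaseHTTPResponse` exposes `status` rather than `status_code`, raw bytes in `data` rather than decoded `text`, and (since urllib3 2.x) a `json()` method, which is why the JSON-handling code ports unchanged. A side-by-side sketch, not part of the patch:

    # requests.Response             ->  urllib3.BaseHTTPResponse
    response.status_code             #  response.status
    response.text                    #  response.data.decode()
    response.json()                  #  response.json()  (urllib3 >= 2.0)
    response.raise_for_status()      #  raise_on_status(response)
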
@@ -2477,8 +2485,8 @@ def test_inner_entity_samples(self):
             'order': 'asc',
         }
         url = self.base_url.set(path='/index/projects', args=params)
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         hits = response_json['hits']
         self.assertEqual(expected_hits, [hit['samples'] for hit in hits])
@@ -2532,8 +2540,8 @@ def test_project_cell_count(self):
         for entity_type in expected_cell_counts.keys():
             with self.subTest(entity_type=entity_type):
                 url = self.base_url.set(path=('index', entity_type), args=params)
-                response = requests.get(url)
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 response_json = response.json()
                 actual_cell_counts = []
                 for hit in response_json['hits']:
@@ -2545,8 +2553,8 @@ def test_summary_cell_counts(self):
         url = self.base_url.set(path='/index/summary',
                                 args=dict(catalog=self.catalog))
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         summary = response.json()
         self.assertEqual(1, summary['projectCount'])
         self.assertEqual(10, summary['fileCount'])
@@ -2574,8 +2582,8 @@ def test_protocols(self):
         """
         params = {'catalog': self.catalog}
         url = self.base_url.set(path='/index/projects', args=params)
-        response = requests.get(url)
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         hit = one(response_json['hits'])
         expected_protocols = [
@@ -2710,8 +2718,8 @@ def test_sorting_by_cell_count(self):
                 'order': 'asc' if ascending else 'desc'
             }
             url = self.base_url.set(path='/index/projects', args=params)
-            response = requests.get(str(url))
-            response.raise_for_status()
+            response = self._http_client.urlopen('GET', str(url))
+            raise_on_status(response)
             response = response.json()
             actual = list(map(CellCounts.from_response, response['hits']))
             if not ascending:
@@ -2777,8 +2785,8 @@ def test_filter_by_cell_count(self):
                 'filters': json.dumps(filters)
             }
             url = self.base_url.set(path='/index/projects', args=params)
-            response = requests.get(str(url))
-            response.raise_for_status()
+            response = self._http_client.urlopen('GET', str(url))
+            raise_on_status(response)
             response = response.json()
             actual = list(map(CellCounts.from_response, response['hits']))
             self.assertEqual(actual, expected)
@@ -2848,8 +2856,8 @@ def test_file_source_facet(self):
         """
         params = self.params(project_id='8185730f-4113-40d3-9cc3-929271784c2b')
         url = self.base_url.set(path='/index/files', args=params)
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         facets = response_json['termFacets']
         expected_counts = {
@@ -2875,8 +2883,8 @@ def test_is_intermediate_facet(self):
         """
         params = self.params(project_id='8185730f-4113-40d3-9cc3-929271784c2b')
         url = self.base_url.set(path='/index/files', args=params)
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         facets = response_json['termFacets']
         expected = [
@@ -2940,8 +2948,8 @@ def test_contributor_matrix_files(self):
                                          facet=facet,
                                          value=value)
                 url = self.base_url.set(path='/index/files', args=params)
-                response = requests.get(str(url))
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 response_json = response.json()
                 actual_files = [one(hit['files'])['name'] for hit in response_json['hits']]
                 self.assertEqual(sorted(expected_files), sorted(actual_files))
@@ -2955,8 +2963,8 @@ def test_matrices_tree(self):
         url = self.base_url.set(path='/index/projects', args=params)
         drs_uri = furl(scheme='drs',
                        netloc=config.drs_domain or config.api_lambda_domain('service'))
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         hit = one(response_json['hits'])
         self.assertEqual('8185730f-4113-40d3-9cc3-929271784c2b', hit['entryId'])
@@ -3325,8 +3333,8 @@ def test_matrix_cell_count(self):
         for endpoint in ('projects', 'samples'):
             with self.subTest(endpoint=endpoint):
                 url = self.base_url.set(path=('index', endpoint), args=params)
-                response = requests.get(str(url))
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 response_json = response.json()
                 for hit in response_json['hits']:
                     actual_counts = {
@@ -3341,8 +3349,8 @@ def test_matrix_cell_count(self):
         }
         actual_counts = Counter()
         url = self.base_url.set(path='/index/files', args=params)
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         for hit in response_json['hits']:
             file = one(hit['files'])
@@ -3390,8 +3398,8 @@ def tearDownClass(cls):
     def test_summary_response(self):
         url = self.base_url.set(path='/index/summary',
                                 args=dict(catalog=self.catalog))
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         summary = response.json()
         self.assertEqual(1 + 1 + 1 + 1, summary['projectCount'])
         self.assertEqual(1 + 1 + 3 + 1, summary['specimenCount'])
@@ -3483,8 +3491,8 @@ def test_filtered_summary_cell_counts(self):
             url = self.base_url.set(path='/index/summary',
                                     args=dict(catalog=self.catalog,
                                               filters=json.dumps(filters)))
-            response = requests.get(str(url))
-            response.raise_for_status()
+            response = self._http_client.urlopen('GET', str(url))
+            raise_on_status(response)
             summary = response.json()
             self.assertElasticEqual(expected_projects, summary['projects'])
@@ -3495,8 +3503,8 @@ def test_summary_filter_none(self):
                 if use_filter:
                     params['filters'] = json.dumps({"organPart": {"is": [None]}})
                 url = self.base_url.set(path='/index/summary', args=params)
-                response = requests.get(str(url))
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 summary_object = response.json()
                 self.assertEqual(summary_object['labCount'], labCount)
@@ -3514,8 +3522,8 @@ def test_projects_response(self):
             })
         }
         url = self.base_url.set(path='/index/projects', args=params)
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         project = one(one(response_json['hits'])['projects'])
         expected_contributors = [
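Several of these hunks also add an explicit `str(url)` where the old code handed a `furl` instance straight to `requests.get()`. requests coerces whatever URL object it receives with `str()`, but urllib3 expects a plain string, so the conversion spells it out. A sketch of the distinction (the endpoint is illustrative):

    url = furl('http://example.org/index/projects', args={'catalog': 'test'})
    # requests.get(url) happened to work because requests stringifies its argument;
    # with urllib3 the coercion is explicit:
    response = http.urlopen('GET', str(url))
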
@@ -3595,8 +3603,8 @@ def test_data_use_and_duos_id(self):
         plugin = self.index_service.metadata_plugin(self.catalog)
         for entity_type in plugin.exposed_indices:
             url = self.base_url.set(path=('index', entity_type), args=params)
-            response = requests.get(url)
-            response.raise_for_status()
+            response = self._http_client.urlopen('GET', str(url))
+            raise_on_status(response)
             response = response.json()
             if field != 'duosId':
                 facets = response['termFacets']
@@ -3658,8 +3666,8 @@ def test_empty_response(self):
             with self.subTest(entity_type=entity_type):
                 url = self.base_url.set(path=('index', entity_type),
                                         args=dict(order='asc'))
-                response = requests.get(str(url))
-                response.raise_for_status()
+                response = self._http_client.urlopen('GET', str(url))
+                raise_on_status(response)
                 sort_field = self._metadata_plugin.exposed_indices[entity_type].field_name
                 expected_response = {
                     'hits': [],
@@ -3693,8 +3701,8 @@ def test_sorted_responses(self):
                 with self.subTest(entity=entity_type, field=field):
                     url = self.base_url.set(path=('index', entity_type),
                                             args=dict(sort=field))
-                    response = requests.get(str(url))
-                    self.assertEqual(200, response.status_code)
+                    response = self._http_client.urlopen('GET', str(url))
+                    self.assertEqual(200, response.status)


 class TestListCatalogsResponse(DCP1CannedBundleTestCase, LocalAppTestCase):
@@ -3705,8 +3713,9 @@ def app_name(cls) -> str:
         return 'service'

     def test(self):
-        response = requests.get(str(self.base_url.set(path='/index/catalogs')))
-        self.assertEqual(200, response.status_code)
+        response = self._http_client.urlopen('GET',
+                                             str(self.base_url.set(path='/index/catalogs')))
+        self.assertEqual(200, response.status)
         self.assertEqual({
             'default_catalog': 'test',
             'catalogs': {
@@ -3773,8 +3782,8 @@ def bundles(cls) -> list[SourcedBundleFQID]:

     def test_tdr_sources(self):
         url = self.base_url.set(path='/index/projects')
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         plugin = self.index_service.metadata_plugin(self.catalog)
         special_fields = plugin.special_fields
@@ -3790,8 +3799,8 @@ def get_file(self, entry_id: str) -> JSON:
         url = self.base_url.set(path=('index', 'files', entry_id))
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         return one(response.json()['files'])

     def test_file_urls(self):
@@ -3816,8 +3825,8 @@ def test_contributed_analyses_matrix(self):
         self.maxDiff = None
         project_id = '9b876d31-0739-4e96-9846-f76e6a427279'
         url = self.base_url.set(path=('index', 'projects', project_id))
-        response = requests.get(str(url))
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', str(url))
+        raise_on_status(response)
         response_json = response.json()
         project = one(response_json['projects'])
         file_url = str(self.base_url.set(
diff --git a/test/service/test_response_anvil.py b/test/service/test_response_anvil.py
index a4899e12f3..f87eaebd38 100644
--- a/test/service/test_response_anvil.py
+++ b/test/service/test_response_anvil.py
@@ -1,13 +1,15 @@
-import requests
-
 from azul.deployment import (
     aws,
 )
+from azul.http import (
+    raise_on_status,
+)
 from azul.lib.types import (
     JSON,
 )
 from azul.logging import (
     configure_test_logging,
+    get_test_logger,
 )
 from azul.plugins.repository.tdr_anvil import (
     TDRAnvilBundleFQID,
@@ -25,6 +27,9 @@ def setUpModule():
     configure_test_logging()


+log = get_test_logger(__name__)
+
+
 class TestAnvilResponse(AnvilIndexerTestCase, WebServiceTestCase):

     @classmethod
@@ -1540,7 +1545,7 @@ def test_summary(self):
         self._assertResponse(url, expected_response)

     def _assertResponse(self, url: str,
                         expected_response: JSON):
-        response = requests.get(url)
-        response.raise_for_status()
+        response = self._http_client.urlopen('GET', url)
+        raise_on_status(response)
         response = response.json()
         self.assertEqual(expected_response, response)
diff --git a/test/test_app_logging.py b/test/test_app_logging.py
index 55429dc09c..58bf6887ed 100644
--- a/test/test_app_logging.py
+++ b/test/test_app_logging.py
@@ -18,7 +18,6 @@
 from more_itertools import (
     one,
 )
-import requests

 from app_test_case import (
     ChaliceServerThread,
@@ -30,6 +29,7 @@
 from azul.logging import (
     azul_log_level,
     configure_test_logging,
+    get_test_logger,
 )
 from azul_test_case import (
     AzulUnitTestCase,
@@ -41,6 +41,9 @@ def setUpModule():
     configure_test_logging()


+log = get_test_logger(__name__)
+
+
 class TestAppLogging(AzulUnitTestCase):

     def test(self):
@@ -66,23 +69,23 @@ def fail():
                     host, port = server_thread.address
                     with self.assertLogs(app.log, level=log_level) as app_log:
                         with self.assertLogs(azul.log, level=log_level) as azul_log:
-                            response = requests.get(f'http://{host}:{port}{path}')
+                            response = self._http_client.urlopen('GET',
+                                                                 f'http://{host}:{port}{path}',
+                                                                 retries=self.no_retries)
                 finally:
                     server_thread.kill_thread()
                     server_thread.join(timeout=10)
                     if server_thread.is_alive():
                         self.fail('Thread is still alive after joining')
-                self.assertEqual(500, response.status_code)
+                self.assertEqual(500, response.status)
                 # The request is always logged
                 self.assertEqual(5, len(azul_log.output))
                 info = {
                     'host': f'{host}:{port}',
-                    'user-agent': 'python-requests/2.33.1',
-                    'accept-encoding': 'gzip, deflate, zstd',
-                    'accept': '*/*',
-                    'connection': 'keep-alive'
+                    'accept-encoding': 'identity',
+                    'user-agent': 'python-urllib3/2.6.3',
                 }
                 self.assertEqual(f'INFO:azul.chalice:Received GET request for {path!r}, '
                                  f"with {json.dumps({'query': None, 'headers': info})}.",
@@ -99,7 +102,7 @@ def fail():
                     self.assertIn(magic_message, app_log.output[0])
                     self.assertIn(traceback_header, app_log.output[0])

-                body = response.content.decode()
+                body = response.data.decode()
                 if debug < 2:
                     # We don't allow stacktraces in error responses …
                     self.assertNotIn(traceback_header, body)
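The expected request headers shrink because urllib3 adds far fewer defaults than requests: the standard library's `http.client` contributes `Accept-Encoding: identity`, and urllib3 a `User-Agent` of `python-urllib3/<version>`, whereas requests also sent `Accept` and `Connection` headers. The pinned `python-urllib3/2.6.3` string will need updating whenever the dependency is bumped; a hedged alternative (not what the patch does) would derive it from the installed version:

    import urllib3

    # Keeps the expectation in sync with whatever urllib3 version is installed.
    expected_user_agent = f'python-urllib3/{urllib3.__version__}'
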
+ """ + self._responses: _QueuedResponses = defaultdict(deque) + self._client = _MockHttpClient(self._responses) + self._patches = deque( + patch.object(target, + '_http_client', + new=PropertyMock(return_value=self._client)) + for target in targets + ) + + def add(self, + *, + method: str, + url: str, + status: int, + headers: dict[str, str] | None = None, + body: bytes | str | JSON = b'', + reason: str | None = None, + ) -> None: + """ + Queue a mock response for the given combination of HTTP method and URL. + If multiple responses are queued for the same combination, they are + returned in the order they were queued in. + + :param method: the request method, e.g. 'GET' + + :param url: the request URL + + :param status: the status of the returned response + + :param headers: the headers of the returned response + + :param body: the body of the returned response + + :param reason: optional text to follow the numeric response status + """ + if headers is None: + headers = {} + if isinstance(body, dict): + body = json.dumps(body) + headers['Content-Type'] = 'application/json' + if isinstance(body, str): + body = body.encode() + assert isinstance(body, bytes), type(body) + response = HTTPResponse(body=body, + headers=headers, + status=status, + reason=reason, + request_method=method, + request_url=url, + preload_content=True) + key = (method, _normalize_url(url)) + self._responses[key].append(response) + + def __enter__(self): + for p in self._patches: + p.start() + return self + + def __exit__(self, *_args): + for p in reversed(self._patches): + p.stop() + self._responses.clear() + + +class _MockHttpClient(HttpClient): + + def __init__(self, responses: _QueuedResponses) -> None: + super().__init__() + self._responses = responses + + def urlopen(self, method: str, url: str, *args, **kwargs) -> BaseHTTPResponse: + key = (method, _normalize_url(url)) + responses = self._responses[key] + assert responses, f'No responses queued for {key!r}' + return responses.popleft() + + +def _normalize_url(url: str) -> str: + url = mutable_furl(url) + url.set(args=dict(sorted(url.args.items()))) + return str(url)