diff --git a/src/azul/es.py b/src/azul/es.py index 70ab428dac..aa6ee30334 100644 --- a/src/azul/es.py +++ b/src/azul/es.py @@ -1,24 +1,36 @@ +from abc import ( + ABCMeta, + abstractmethod, +) from collections.abc import ( Collection, + Iterator, ) +import json import logging from typing import ( Any, Mapping, cast, ) +import unittest.mock from urllib.parse import ( urlencode, ) +import attrs from aws_requests_auth.boto_utils import ( BotoAWSRequestsAuth, ) from opensearchpy import ( Connection, OpenSearch, + Search, Urllib3HttpConnection, ) +from opensearchpy.connection.connections import ( + get_connection, +) import requests import requests.auth import urllib3 @@ -33,10 +45,17 @@ from azul.http import ( HttpClient, ) +from azul.json import ( + copy_json, +) from azul.logging import ( es_log, http_body_log_message, ) +from azul.types import ( + AnyJSON, + MutableJSON, +) log = logging.getLogger(__name__) @@ -243,3 +262,91 @@ def _create_client(cls, host, port, timeout): else: return OpenSearch(connection_class=AzulUrllib3HttpConnection, **common_params) + + +@attrs.frozen(auto_attribs=True, kw_only=True) +class Template(metaclass=ABCMeta): + param_name: str + value: AnyJSON + + @abstractmethod + def to_source(self) -> AnyJSON: + raise NotImplementedError + + +class RawStr(str): + """ + Instances of this class will not be surrounded by quotes when encoded as + JSON using a :class:`TemplateSearchJSONEncoder`. + """ + + +@attrs.frozen(auto_attribs=True, kw_only=True) +class ToJsonTemplate(Template): + + def to_source(self) -> RawStr: + return RawStr('{{#toJson}}' + self.param_name + '{{/toJson}}') + + +_original = json.encoder.py_encode_basestring_ascii + + +def _encode_basestring_ascii(s: str) -> str: + result = _original(s) + assert result[0] == result[-1] == '"', result + return result[1:-1] if isinstance(s, RawStr) else result + + +class TemplateSearchJSONEncoder(json.JSONEncoder): + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.params: dict[str, AnyJSON] = {} + + def default(self, obj: Any) -> Any: + if isinstance(obj, Template): + try: + old_value = self.params[obj.param_name] + except KeyError: + self.params[obj.param_name] = obj.value + else: + # If two parameters have the same name, they ought to come from + # the same template object. Having two template objects with the + # same name probably indicates a bug even if they happen to use + # the same value. + assert obj.value is old_value, (obj, old_value) + return obj.to_source() + else: + return super().default(obj) + + def iterencode(self, o: AnyJSON, _one_shot: bool = False) -> Iterator[str]: + with unittest.mock.patch('json.encoder.encode_basestring_ascii', + wraps=_encode_basestring_ascii): + return super().iterencode(o, _one_shot=_one_shot) + + +class TemplateSearch(Search): + + def to_dict(self, count: bool = False, **kwargs) -> MutableJSON: + # Sorting ensures consistent output for unit tests + encoder = TemplateSearchJSONEncoder(sort_keys=True) + return { + 'source': encoder.encode(super().to_dict(count=count, **kwargs)), + 'params': copy_json(encoder.params), + } + + def execute(self, ignore_cache: bool = False) -> Any: + # The body of this method is mostly copied from the superclass, with the + # only change being switching `search` for `search_template`. We could + # also monkeypatch that method, but this approach is more robust because + # we retain control over which arguments are passed. Note that `search` + # supports many parameters that `search_template` currently does not. + if ignore_cache or not hasattr(self, '_response'): + opensearch = get_connection(self._using) + self._response = self._response_class( + self, + opensearch.search_template( + index=self._index, body=self.to_dict(), **self._params + ), + ) + return self._response diff --git a/src/azul/service/manifest_service.py b/src/azul/service/manifest_service.py index 84902c6ddd..aa2e3da962 100644 --- a/src/azul/service/manifest_service.py +++ b/src/azul/service/manifest_service.py @@ -1286,6 +1286,53 @@ class ClientSidePagingManifestGenerator(ManifestGenerator, metaclass=ABCMeta): """ page_size = 500 + def _paginate_hits(self, + request_factory: Callable[[SortKey | None], Search] + ) -> Iterable[Hit]: + """ + Yield all hits in every page of Elasticsearch hits in responses to + requests that use client-side paging. + + :param request_factory: A callable that returns a prepared Elasticsearch + request for the given search-after key, with the + appropriate filters and sorting applied. The + returned request should yield one page worth of + hits, starting at the first page (if the argument + is None), or the hit right after the hit with + given search-after key + """ + search_after = None + while True: + request = request_factory(search_after) + response = request.execute() + if response.hits: + hit = None + for hit in response.hits: + yield hit + assert hit is not None + search_after = self._search_after(hit) + else: + break + + def _paginate_hits_sorted(self, + request: Search, + sort: SortKey + ) -> Iterable[Hit]: + """ + Wrapper around :meth:`_paginate_hits` for simple cases where the request + does not require any additional setup between pages + """ + request = request.extra(size=self.page_size) + request = request.sort(*sort) + + def request_factory(search_after: SortKey | None) -> Search: + if search_after is None: + return request + else: + return request.extra(search_after=search_after) + + return self._paginate_hits(request_factory) + def _create_paged_request(self, search_after: SortKey | None) -> Search: pagination = Pagination(sort='entryId', order='asc', @@ -1788,7 +1835,8 @@ def write_page_to(self, Bundles = dict[FQID, Bundle] -class PFBManifestGenerator(FileBasedManifestGenerator): +class PFBManifestGenerator(FileBasedManifestGenerator, + ClientSidePagingManifestGenerator): @classmethod def format(cls) -> ManifestFormat: @@ -1816,10 +1864,10 @@ def included_fields(self) -> list[FieldPath] | None: def _all_docs_sorted(self) -> Iterable[JSON]: request = self._create_request(self.entity_type) - request = request.params(preserve_order=True).sort('entity_id.keyword') - for hit in request.scan(): - doc = self._hit_to_doc(hit) - yield doc + # Need two sort fields to satisfy type constraints + sort = ('entity_id.keyword',) * 2 + hits = self._paginate_hits_sorted(request, sort) + return map(self._hit_to_doc, hits) def create_file(self) -> tuple[str, str | None]: transformers = self.service.transformer_types(self.catalog) @@ -1913,34 +1961,6 @@ class ReplicaKeys: hub_id: str replica_ids: list[str] - def _paginate_hits(self, - request_factory: Callable[[SortKey | None], Search] - ) -> Iterable[Hit]: - """ - Yield all hits in every page of Elasticsearch hits in responses to - requests that use client-side paging. - - :param request_factory: A callable that returns a prepared Elasticsearch - request for the given search-after key, with the - appropriate filters and sorting applied. The - returned request should yield one page worth of - hits, starting at the first page (if the argument - is None), or the hit right after the hit with - given search-after key - """ - search_after = None - while True: - request = request_factory(search_after) - response = request.execute() - if response.hits: - hit = None - for hit in response.hits: - yield hit - assert hit is not None - search_after = self._search_after(hit) - else: - break - def _list_replica_keys(self) -> Iterable[ReplicaKeys]: for hit in self._paginate_hits(self._create_paged_request): document_ids = [ @@ -1986,7 +2006,6 @@ def _join_replicas(self, keys: Iterable[ReplicaKeys]) -> Iterable[Hit]: {'terms': {'hub_ids.keyword': list(hub_ids)}}, {'terms': {'entity_id.keyword': list(replica_ids)}} ])) - request = request.extra(size=self.page_size) # `_id` is currently the only index field that is unique to each replica # document (and thus results in an unambiguous total ordering). However, @@ -1998,15 +2017,8 @@ def _join_replicas(self, keys: Iterable[ReplicaKeys]) -> Iterable[Hit]: # FIXME: ES DeprecationWarning for using _id as sort key # https://github.com/DataBiosphere/azul/issues/7290 # - request = request.sort('entity_id.keyword', '_id') - - def request_factory(search_after: SortKey | None) -> Search: - if search_after is None: - return request - else: - return request.extra(search_after=search_after) - - return self._paginate_hits(request_factory) + sort = ('entity_id.keyword', '_id') + return self._paginate_hits_sorted(request, sort) class JSONLVerbatimManifestGenerator(PagedManifestGenerator, diff --git a/src/azul/service/query_service.py b/src/azul/service/query_service.py index ec9d848cf8..2c9f3b82b4 100644 --- a/src/azul/service/query_service.py +++ b/src/azul/service/query_service.py @@ -51,6 +51,8 @@ ) from azul.es import ( ESClientFactory, + TemplateSearch, + ToJsonTemplate, ) from azul.indexer.document import ( DocumentType, @@ -231,23 +233,30 @@ def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query: Converts the given filters into an Elasticsearch DSL Query object. """ filter_list = [] + plugin = self.plugin + source_id_field_name = plugin.special_fields.source_id.name + source_id_field_path = plugin.field_mapping[source_id_field_name] for field_path, filter in self.prepared_filters.items(): if field_path not in skip_field_paths: + values: Sequence[PrimitiveJSON] | ToJsonTemplate operator, values = one(filter.items()) + original_values = values # Note that `is_not` is only used internally (for filtering by # inaccessible sources) if operator in ('is', 'is_not'): field_type = self.service.field_type(self.catalog, field_path) + if field_path == source_id_field_path: + values = ToJsonTemplate(param_name=source_id_field_name, value=values) if isinstance(field_type, Nested): term_queries = [] - for nested_field, nested_value in one(values).items(): + for nested_field, nested_value in one(original_values).items(): nested_body = {dotted(field_path, nested_field, 'keyword'): nested_value} term_queries.append(Q('term', **nested_body)) query = Q('nested', path=dotted(field_path), query=Q('bool', must=term_queries)) else: query = Q('terms', **{dotted(field_path, 'keyword'): values}) translated_none = field_type.to_index(None) - if translated_none in values: + if translated_none in original_values: # Note that at this point None values in filters have already # been translated e.g. {'is': ['~null']} and if the filter has a # None our query needs to find fields with None values as well @@ -691,12 +700,12 @@ def create_request(self, catalog: CatalogName, entity_type: str, doc_type: DocumentType = DocumentType.aggregate - ) -> Search: + ) -> TemplateSearch: """ Create an Elasticsearch request against the index containing documents of the given entity and document types, in the given catalog. """ - return Search(using=self._es_client, - index=str(IndexName.create(catalog=catalog, - qualifier=entity_type, - doc_type=doc_type))) + return TemplateSearch(using=self._es_client, + index=str(IndexName.create(catalog=catalog, + qualifier=entity_type, + doc_type=doc_type))) diff --git a/test/es_test_case.py b/test/es_test_case.py index 847d0179d1..66764d9baf 100644 --- a/test/es_test_case.py +++ b/test/es_test_case.py @@ -57,17 +57,32 @@ def setUpClass(cls): cls.es_client = ESClientFactory.get() cls._wait_for_es() - # Disable the automatic creation of indexes when documents are - # indexed. We create indexes explicitly before any documents are - # indexed so a missing index would be indicative of some sort of - # bug. We want to fail early in that situation. Automatically - # created indices have a only a default mapping, resulting in - # failure modes that are harder to diagnose. - # cls.es_client.cluster.put_settings(body={ 'persistent': { + # Disable the automatic creation of indexes when documents are + # indexed. We create indexes explicitly before any documents are + # indexed so a missing index would be indicative of some sort of + # bug. We want to fail early in that situation. Automatically + # created indices have a only a default mapping, resulting in + # failure modes that are harder to diagnose. + # 'action.auto_create_index': False, - 'action.destructive_requires_name': False + + # Allow wildcard deletions, making it possible to delete all + # indices in a single request. Speeds up deletion of indices + # between tests. + # + 'action.destructive_requires_name': False, + + # The service uses template queries when reading from the + # index which, during tests, can far exceed the default + # script compilation rate of 75/5m. Rendering a template + # query using mustache is a very cheap operation compared to + # other compilation contexts (e.g. generating bytecode for a + # painless script), so performance shouldn't be + # significantly affected. + # + 'script.context.template.max_compilations_rate': 'unlimited' } }) except BaseException: # no coverage diff --git a/test/service/test_request_builder.py b/test/service/test_request_builder.py index 61c40003cb..6815367857 100644 --- a/test/service/test_request_builder.py +++ b/test/service/test_request_builder.py @@ -5,9 +5,13 @@ import json import attr +from opensearchpy import ( + Search, +) from azul import ( CatalogName, + JSON, ) from azul.indexer.field import ( FieldTypes, @@ -96,7 +100,7 @@ def facets(self) -> Sequence[str]: 'constant_score': { 'filter': { 'terms': { - 'sources.id.keyword': [] + 'sources.id.keyword': '{{#toJson}}sourceId{{/toJson}}' } } } @@ -127,10 +131,6 @@ def test_create_request(self): } } sample_filter = {'entity_id': {'is': ['cbb998ce-ddaf-34fa-e163-d14b399c6b34']}} - # Need to work on a couple cases: - # - The empty case - # - The 1 filter case - # - The complex multiple filters case self._test_create_request(expected_output, sample_filter) def test_create_request_empty(self): @@ -149,38 +149,6 @@ def test_create_request_empty(self): sample_filter = {} self._test_create_request(expected_output, sample_filter, post_filter=False) - def test_create_request_complex(self): - """ - Tests creation of a complex request. - """ - expected_output = { - 'post_filter': { - 'bool': { - 'must': [ - { - 'constant_score': { - 'filter': { - 'terms': { - 'entity_id.keyword': [ - 'cbb998ce-ddaf-34fa-e163-d14b399c6b34' - ] - } - } - } - }, - self.sources_filter - ] - } - } - } - sample_filter = { - 'entity_id': - { - 'is': ['cbb998ce-ddaf-34fa-e163-d14b399c6b34'] - } - } - self._test_create_request(expected_output, sample_filter) - def test_create_request_missing_values(self): """ Tests creation of a request for facets that do not have a value @@ -314,18 +282,40 @@ def test_create_request_terms_and_missing_values(self): self._test_create_request(expected_output, sample_filter) def _test_create_request(self, - expected_output, - sample_filter, - post_filter=True + expected_output: JSON, + sample_filter: JSON, + post_filter: bool = True ): service = self.Service(self.MockPlugin()) + self._test_create_request_with_service(service, + expected_output, + sample_filter, + post_filter) + + def _test_create_request_with_service(self, + service: Service, + expected_output: JSON, + sample_filter: JSON, + post_filter: bool = True + ): filters = Filters(explicit=sample_filter, source_ids=set()) request = self._prepare_request(filters, post_filter, service) - expected_output = json.dumps(expected_output, sort_keys=True) - actual_output = json.dumps(request.to_dict(), sort_keys=True) - self.assertEqual(actual_output, expected_output) + expected_output = { + 'source': ( + json.dumps(expected_output, sort_keys=True) + .replace('"{{#toJson}}', '{{#toJson}}') + .replace('{{/toJson}}"', '{{/toJson}}') + ), + 'params': {'sourceId': []} + } + actual_request_body = request.to_dict() + self.assertEqual(expected_output, actual_request_body) - def _prepare_request(self, filters, post_filter, service): + def _prepare_request(self, + filters: Filters, + post_filter: bool, + service: Service + ) -> Search: entity_type = 'files' pipeline = service.create_chain(catalog=self.catalog, entity_type=entity_type, @@ -344,7 +334,7 @@ def test_create_aggregate(self): Tests creation of an ES aggregate """ expected_output = { - 'filter': { + 'post_filter': { 'bool': { 'must': [ self.sources_filter @@ -352,18 +342,29 @@ def test_create_aggregate(self): } }, 'aggs': { - 'myTerms': { - 'terms': { - 'field': 'path.to.foo.keyword', - 'size': 99999 + 'foo': { + 'filter': { + 'bool': { + 'must': [ + self.sources_filter + ] + } }, - 'meta': { - 'path': ['path', 'to', 'foo'] - } - }, - 'untagged': { - 'missing': { - 'field': 'path.to.foo.keyword' + 'aggs': { + 'myTerms': { + 'terms': { + 'field': 'path.to.foo.keyword', + 'size': 99999 + }, + 'meta': { + 'path': ['path', 'to', 'foo'] + } + }, + 'untagged': { + 'missing': { + 'field': 'path.to.foo.keyword' + } + } } } } @@ -391,11 +392,5 @@ def facets(self) -> Sequence[str]: return ['foo'] service = Service(MockPlugin()) - - filters = Filters(explicit={}, source_ids=set()) - post_filter = True - request = self._prepare_request(filters, post_filter, service) - aggregation = request.aggs['foo'] - expected_output = json.dumps(expected_output, sort_keys=True) - actual_output = json.dumps(aggregation.to_dict(), sort_keys=True) - self.assertEqual(actual_output, expected_output) + sample_filter = {} + self._test_create_request_with_service(service, expected_output, sample_filter)