diff --git a/.mypy.ini b/.mypy.ini index dbc569f195..898a5a6287 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -71,6 +71,7 @@ modules = azul.service.manifest_service, azul.field_type, azul.source, + azul.service.query_service, packages = diff --git a/src/azul/lib/types.py b/src/azul/lib/types.py index 4d7dcf8b7c..7eba3e5990 100644 --- a/src/azul/lib/types.py +++ b/src/azul/lib/types.py @@ -194,6 +194,11 @@ def json_sorted(vs: Iterable[PrimitiveJSON]) -> MutableJSONArray: return sorted(vs, key=none_safe_key(none_last=True)) +def json_primitive(v: AnyJSON) -> PrimitiveJSON: + assert v is None or isinstance(v, (str, int, float, bool)), type(v) + return v + + def json_str(v: AnyMutableJSON | AnyJSON) -> str: return any_str(v) diff --git a/src/azul/service/query_service.py b/src/azul/service/query_service.py index 0a3c739cb9..feb4a8fbf0 100644 --- a/src/azul/service/query_service.py +++ b/src/azul/service/query_service.py @@ -8,10 +8,6 @@ from collections.abc import ( Iterable, Mapping, - Sequence, -) -from functools import ( - partial, ) import json import logging @@ -53,6 +49,7 @@ ) from azul.indexer.document import ( DocumentType, + FieldPath, IndexName, ) from azul.indexer.document_service import ( @@ -65,11 +62,22 @@ from azul.lib.types import ( AnyJSON, JSON, + JSONArray, JSONTypedDict, JSONs, MutableJSON, PrimitiveJSON, - json_list, + json_dict, + json_dict_of_dicts, + json_element_dicts, + json_element_strings, + json_int, + json_item_sequences, + json_list_of_dicts, + json_mapping, + json_primitive, + json_sequence, + json_sequence_of_mappings, json_str, ) from azul.opensearch import ( @@ -77,11 +85,11 @@ ) from azul.plugins import ( DocumentSlice, - FieldPath, MetadataPlugin, dotted, ) from azul.service import ( + FilterJSON, Filters, FiltersJSON, ) @@ -168,7 +176,7 @@ def wrap[R0](self, other: OpenSearchStage[R0, R1]) -> OpenSearchChain[R0, R1, R2 return OpenSearchChain(inner=other, outer=self) -TranslatedFilters = Mapping[FieldPath, Mapping[str, Sequence[PrimitiveJSON]]] +TranslatedFilters = Mapping[FieldPath, Mapping[str, JSONArray]] @attr.s(frozen=True, auto_attribs=True, kw_only=True) @@ -216,30 +224,38 @@ def _translate_filters(self, filters: FiltersJSON) -> TranslatedFilters: """ catalog = self.catalog field_mapping = self.plugin.field_mapping - translated_filters = {} - for field, filter in filters.items(): - field = field_mapping[field] - operator, values = one(filter.items()) - field_type = self.service.field_type(catalog, field) - values = field_type.filter(operator, values) - translated_filters[field] = {operator: list(values)} - return translated_filters - def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query: + def translate_filter(field_name: str, + filter: FilterJSON + ) -> tuple[FieldPath, Mapping[str, JSONArray]]: + field_path = field_mapping[field_name] + operator, values = one(filter.items()) + field_type = self.service.field_type(catalog, field_path) + # FIXME: remove `type: ignore` + # https://github.com/DataBiosphere/azul/issues/6821 + values: JSONArray = list(field_type.filter(operator, values)) # type: ignore + return field_path, {operator: values} + + return dict( + translate_filter(field, filter) + for field, filter in filters.items() + ) + + def prepare_query(self, skip_field_paths: tuple[FieldPath, ...] = ()) -> Query: """ Converts the given filters into an OpenSearch DSL Query object. """ filter_list = [] for field_path, filter in self.prepared_filters.items(): if field_path not in skip_field_paths: - operator, values = one(filter.items()) + operator, values = one(json_item_sequences(filter)) # Note that `is_not` is only used internally (for filtering by # inaccessible sources) if operator in ('is', 'is_not'): field_type = self.service.field_type(self.catalog, field_path) if isinstance(field_type, Nested): term_queries = [] - for nested_field, nested_value in one(values).items(): + for nested_field, nested_value in json_mapping(one(values)).items(): nested_body = {dotted(field_path, nested_field, 'keyword'): nested_value} term_queries.append(Q('term', **nested_body)) query = Q('nested', path=dotted(field_path), query=Q('bool', must=term_queries)) @@ -258,7 +274,7 @@ def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query: filter_list.append(query) elif operator in ('contains', 'within', 'intersects'): for value in values: - value = value | {'relation': operator} + value = {**json_mapping(value), 'relation': operator} filter_list.append(Q('range', **{dotted(field_path): value})) else: assert False @@ -307,7 +323,7 @@ def prepare_request(self, request: Search) -> Search: def process_response(self, response: MutableJSON) -> MutableJSON: try: - aggs = response['aggregations'] + aggs = json_dict(response['aggregations']) except KeyError: pass else: @@ -330,7 +346,7 @@ def _prepare_aggregation(self, *, facet: str, facet_path: FieldPath) -> Agg: nested_agg = agg.bucket(name='nested', agg_type='nested', path=dotted(facet_path)) - facet_path = dotted(facet_path, field_type.agg_property) + facet_path = (*facet_path, field_type.agg_property) else: nested_agg = agg # Make an inner agg that will contain the terms in question @@ -368,9 +384,9 @@ def annotate(agg: Agg): annotate(request.aggs[agg_name]) def _flatten_nested_aggs(self, aggs: MutableJSON): - for facet, agg in aggs.items(): + for facet, agg in json_dict_of_dicts(aggs).items(): try: - nested_agg = agg.pop('nested') + nested_agg = json_dict(agg.pop('nested')) except KeyError: pass else: @@ -382,26 +398,27 @@ def _translate_response_aggs(self, aggs: MutableJSON): OpenSearch response. """ - def translate(k, v: MutableJSON): + def translate(k: str, v: MutableJSON): try: buckets = v['buckets'] except KeyError: - for k, v in v.items(): - if isinstance(v, dict): - translate(k, v) + for ki, vi in v.items(): + if isinstance(vi, dict): + translate(ki, vi) else: try: - path = v['meta']['path'] + path = json_dict(v['meta'])['path'] except KeyError: pass else: - field_type = self.service.field_type(self.catalog, tuple(path)) - for bucket in buckets: + field_type = self.service.field_type(self.catalog, + tuple(json_element_strings(path))) + for bucket in json_element_dicts(buckets): bucket['key'] = field_type.from_index(bucket['key']) translate(k, bucket) for k, v in aggs.items(): - translate(k, v) + translate(k, json_dict(v)) def _populate_accessible(self, aggs: MutableJSON) -> None: # Because the value of the `accessible` field depends on the provided @@ -410,12 +427,14 @@ def _populate_accessible(self, aggs: MutableJSON) -> None: source_ids = self.filter_stage.filters.source_ids plugin = self.service.metadata_plugin(self.catalog) special_fields = plugin.special_fields - agg = aggs.pop(special_fields.source_id.name) + agg = json_dict(aggs.pop(special_fields.source_id.name)) counts_by_accessibility: dict[bool, int] = defaultdict(int) - for bucket in agg['myTerms']['buckets']: + terms = json_dict(agg['myTerms']) + buckets = json_list_of_dicts(terms['buckets']) + for bucket in buckets: accessible = bucket['key'] in source_ids - counts_by_accessibility[accessible] += bucket['doc_count'] - agg['myTerms']['buckets'] = [ + counts_by_accessibility[accessible] += json_int(bucket['doc_count']) + terms['buckets'] = [ {'key': accessible, 'doc_count': count} for accessible, count in counts_by_accessibility.items() ] @@ -465,8 +484,8 @@ def process_response(self, response: Response) -> MutableJSON: def sort_key_from_json(s: AnyJSON) -> SortKey: - a, b = json_list(s) - return a, json_str(b) + a, b = json_sequence(s) + return json_primitive(a), json_str(b) def sort_key_to_json(s: SortKey) -> AnyJSON: @@ -532,12 +551,12 @@ class PaginationStage(_OpenSearchStage[JSON, ResponseTriple]): def prepare_request(self, request: Search) -> Search: sort_order = self.pagination.order - sort_field = self.plugin.field_mapping[self.pagination.sort] - field_type = self.service.field_type(self.catalog, sort_field) + sort_field_path = self.plugin.field_mapping[self.pagination.sort] + field_type = self.service.field_type(self.catalog, sort_field_path) sort_mode = field_type.es_sort_mode - sort_field = dotted(sort_field, 'keyword') + sort_field = dotted(sort_field_path, 'keyword') - def sort(order): + def sort(order: str) -> tuple[JSON, JSON]: assert order in ('asc', 'desc'), order return ( { @@ -590,47 +609,48 @@ def process_response(self, response: JSON) -> ResponseTriple: """ Returns hits and pagination as dict """ - # The slice is necessary because we may have fetched an extra entry to - # determine if there is a previous or next page. - hits = self._extract_hits(response) + hits, total = self._extract_hits(response) + pagination = self._process_pagination(hits, total) hits = self._translate_hits(hits) - pagination = self._process_pagination(response) - aggregations = response.get('aggregations', {}) + aggregations = json_mapping(response.get('aggregations', {})) return hits, pagination, aggregations - def _extract_hits(self, response): - hits = response['hits']['hits'][0:self.pagination.size] - if self.pagination.search_before is not None: - hits = reversed(hits) - hits = [hit['_source'] for hit in hits] - return hits - - def _translate_hits(self, hits): - f = partial(self.service.translate_fields, self.catalog, forward=False) - hits = list(map(f, hits)) - return hits - - def _process_pagination(self, response: JSON) -> MutableJSON: - total = response['hits']['total'] + def _extract_hits(self, response: JSON) -> tuple[JSONs, int]: + hits = json_mapping(response['hits']) + total = json_mapping(hits['total']) # FIXME: Handle other relations # https://github.com/DataBiosphere/azul/issues/3770 assert total['relation'] == 'eq' - pages = -(-total['value'] // self.pagination.size) + return json_sequence_of_mappings(hits['hits']), json_int(total['value']) + + def _translate_hits(self, hits: JSONs) -> JSONs: + # The slice is necessary because we may have fetched an extra entry to + # determine if there is a previous or next page. + hits = hits[0:self.pagination.size] + hits = iter(hits) if self.pagination.search_before is None else reversed(hits) + return [ + self.service.translate_fields(self.catalog, + json_mapping(hit['_source']), + forward=False) + for hit in hits + ] + + def _process_pagination(self, hits: JSONs, total: int) -> ResponsePagination: + pages = -(-total // self.pagination.size) # ... else use search_after/search_before pagination - hits: JSONs = response['hits']['hits'] count = len(hits) if self.pagination.search_before is None: # hits are normal sorted if count > self.pagination.size: # There is an extra hit, indicating a next page. count -= 1 - search_after = tuple(hits[count - 1]['sort']) + search_after = sort_key_from_json(hits[count - 1]['sort']) else: # No next page search_after = None if self.pagination.search_after is not None: - search_before = tuple(hits[0]['sort']) + search_before = sort_key_from_json(hits[0]['sort']) else: search_before = None else: @@ -638,23 +658,23 @@ def _process_pagination(self, response: JSON) -> MutableJSON: if count > self.pagination.size: # There is an extra hit, indicating a previous page. count -= 1 - search_before = tuple(hits[count - 1]['sort']) + search_before = sort_key_from_json(hits[count - 1]['sort']) else: # No previous page search_before = None - search_after = tuple(hits[0]['sort']) + search_after = sort_key_from_json(hits[0]['sort']) pagination = self.pagination.advance(search_before=search_before, search_after=search_after) - def page_link(*, previous): + def page_link(*, previous: bool) -> str | None: url = pagination.link(previous=previous, catalog=self.catalog, filters=json.dumps(self.filters.explicit)) return None if url is None else str(url) return ResponsePagination(count=count, - total=total['value'], + total=total, size=pagination.size, next=page_link(previous=False), previous=page_link(previous=True),