-
Notifications
You must be signed in to change notification settings - Fork 4
Cover query service with mypy (#6821) #7916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
dc71dd4
ba3d151
39fecf1
9c7ff6f
3395aa1
de9cd39
510aa4b
03a9859
59295f0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,10 +8,6 @@ | |
| from collections.abc import ( | ||
| Iterable, | ||
| Mapping, | ||
| Sequence, | ||
| ) | ||
| from functools import ( | ||
| partial, | ||
| ) | ||
| import json | ||
| import logging | ||
|
|
@@ -53,6 +49,7 @@ | |
| ) | ||
| from azul.indexer.document import ( | ||
| DocumentType, | ||
| FieldPath, | ||
| IndexName, | ||
| ) | ||
| from azul.indexer.document_service import ( | ||
|
|
@@ -65,23 +62,34 @@ | |
| from azul.lib.types import ( | ||
| AnyJSON, | ||
| JSON, | ||
| JSONArray, | ||
| JSONTypedDict, | ||
| JSONs, | ||
| MutableJSON, | ||
| PrimitiveJSON, | ||
| json_list, | ||
| json_dict, | ||
| json_dict_of_dicts, | ||
| json_element_dicts, | ||
| json_element_strings, | ||
| json_int, | ||
| json_item_sequences, | ||
| json_list_of_dicts, | ||
| json_mapping, | ||
| json_primitive, | ||
| json_sequence, | ||
| json_sequence_of_mappings, | ||
| json_str, | ||
| ) | ||
| from azul.opensearch import ( | ||
| OpenSearchClientFactory, | ||
| ) | ||
| from azul.plugins import ( | ||
| DocumentSlice, | ||
| FieldPath, | ||
| MetadataPlugin, | ||
| dotted, | ||
| ) | ||
| from azul.service import ( | ||
| FilterJSON, | ||
| Filters, | ||
| FiltersJSON, | ||
| ) | ||
|
|
@@ -168,7 +176,7 @@ def wrap[R0](self, other: OpenSearchStage[R0, R1]) -> OpenSearchChain[R0, R1, R2 | |
| return OpenSearchChain(inner=other, outer=self) | ||
|
|
||
|
|
||
| TranslatedFilters = Mapping[FieldPath, Mapping[str, Sequence[PrimitiveJSON]]] | ||
| TranslatedFilters = Mapping[FieldPath, Mapping[str, JSONArray]] | ||
|
|
||
|
|
||
| @attr.s(frozen=True, auto_attribs=True, kw_only=True) | ||
|
|
@@ -216,30 +224,38 @@ def _translate_filters(self, filters: FiltersJSON) -> TranslatedFilters: | |
| """ | ||
| catalog = self.catalog | ||
| field_mapping = self.plugin.field_mapping | ||
| translated_filters = {} | ||
| for field, filter in filters.items(): | ||
| field = field_mapping[field] | ||
| operator, values = one(filter.items()) | ||
| field_type = self.service.field_type(catalog, field) | ||
| values = field_type.filter(operator, values) | ||
| translated_filters[field] = {operator: list(values)} | ||
| return translated_filters | ||
|
|
||
| def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query: | ||
| def translate_filter(field_name: str, | ||
| filter: FilterJSON | ||
| ) -> tuple[FieldPath, Mapping[str, JSONArray]]: | ||
| field_path = field_mapping[field_name] | ||
| operator, values = one(filter.items()) | ||
| field_type = self.service.field_type(catalog, field_path) | ||
| # FIXME: remove `type: ignore` | ||
| # https://github.com/DataBiosphere/azul/issues/6821 | ||
| values: JSONArray = list(field_type.filter(operator, values)) # type: ignore | ||
| return field_path, {operator: values} | ||
|
|
||
| return dict( | ||
| translate_filter(field, filter) | ||
| for field, filter in filters.items() | ||
| ) | ||
|
|
||
| def prepare_query(self, skip_field_paths: tuple[FieldPath, ...] = ()) -> Query: | ||
|
Comment on lines
-219
to
+244
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought we agreed to only touch filter-related code when it's trivial narrowing? This is a non-trivial change, and it still requires a pragma so I don't know what it's buying us. |
||
| """ | ||
| Converts the given filters into an OpenSearch DSL Query object. | ||
| """ | ||
| filter_list = [] | ||
| for field_path, filter in self.prepared_filters.items(): | ||
| if field_path not in skip_field_paths: | ||
| operator, values = one(filter.items()) | ||
| operator, values = one(json_item_sequences(filter)) | ||
| # Note that `is_not` is only used internally (for filtering by | ||
| # inaccessible sources) | ||
| if operator in ('is', 'is_not'): | ||
| field_type = self.service.field_type(self.catalog, field_path) | ||
| if isinstance(field_type, Nested): | ||
| term_queries = [] | ||
| for nested_field, nested_value in one(values).items(): | ||
| for nested_field, nested_value in json_mapping(one(values)).items(): | ||
| nested_body = {dotted(field_path, nested_field, 'keyword'): nested_value} | ||
| term_queries.append(Q('term', **nested_body)) | ||
| query = Q('nested', path=dotted(field_path), query=Q('bool', must=term_queries)) | ||
|
|
@@ -258,7 +274,7 @@ def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query: | |
| filter_list.append(query) | ||
| elif operator in ('contains', 'within', 'intersects'): | ||
| for value in values: | ||
| value = value | {'relation': operator} | ||
| value = {**json_mapping(value), 'relation': operator} | ||
| filter_list.append(Q('range', **{dotted(field_path): value})) | ||
| else: | ||
| assert False | ||
|
|
@@ -307,7 +323,7 @@ def prepare_request(self, request: Search) -> Search: | |
|
|
||
| def process_response(self, response: MutableJSON) -> MutableJSON: | ||
| try: | ||
| aggs = response['aggregations'] | ||
| aggs = json_dict(response['aggregations']) | ||
| except KeyError: | ||
| pass | ||
| else: | ||
|
|
@@ -330,7 +346,7 @@ def _prepare_aggregation(self, *, facet: str, facet_path: FieldPath) -> Agg: | |
| nested_agg = agg.bucket(name='nested', | ||
| agg_type='nested', | ||
| path=dotted(facet_path)) | ||
| facet_path = dotted(facet_path, field_type.agg_property) | ||
| facet_path = (*facet_path, field_type.agg_property) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks like a semantic change. |
||
| else: | ||
| nested_agg = agg | ||
| # Make an inner agg that will contain the terms in question | ||
|
|
@@ -368,9 +384,9 @@ def annotate(agg: Agg): | |
| annotate(request.aggs[agg_name]) | ||
|
|
||
| def _flatten_nested_aggs(self, aggs: MutableJSON): | ||
| for facet, agg in aggs.items(): | ||
| for facet, agg in json_dict_of_dicts(aggs).items(): | ||
| try: | ||
| nested_agg = agg.pop('nested') | ||
| nested_agg = json_dict(agg.pop('nested')) | ||
| except KeyError: | ||
| pass | ||
| else: | ||
|
|
@@ -382,26 +398,27 @@ def _translate_response_aggs(self, aggs: MutableJSON): | |
| OpenSearch response. | ||
| """ | ||
|
|
||
| def translate(k, v: MutableJSON): | ||
| def translate(k: str, v: MutableJSON): | ||
| try: | ||
| buckets = v['buckets'] | ||
| except KeyError: | ||
| for k, v in v.items(): | ||
| if isinstance(v, dict): | ||
| translate(k, v) | ||
| for ki, vi in v.items(): | ||
| if isinstance(vi, dict): | ||
| translate(ki, vi) | ||
| else: | ||
| try: | ||
| path = v['meta']['path'] | ||
| path = json_dict(v['meta'])['path'] | ||
| except KeyError: | ||
| pass | ||
| else: | ||
| field_type = self.service.field_type(self.catalog, tuple(path)) | ||
| for bucket in buckets: | ||
| field_type = self.service.field_type(self.catalog, | ||
| tuple(json_element_strings(path))) | ||
| for bucket in json_element_dicts(buckets): | ||
| bucket['key'] = field_type.from_index(bucket['key']) | ||
| translate(k, bucket) | ||
|
|
||
| for k, v in aggs.items(): | ||
| translate(k, v) | ||
| translate(k, json_dict(v)) | ||
|
|
||
| def _populate_accessible(self, aggs: MutableJSON) -> None: | ||
| # Because the value of the `accessible` field depends on the provided | ||
|
|
@@ -410,12 +427,14 @@ def _populate_accessible(self, aggs: MutableJSON) -> None: | |
| source_ids = self.filter_stage.filters.source_ids | ||
| plugin = self.service.metadata_plugin(self.catalog) | ||
| special_fields = plugin.special_fields | ||
| agg = aggs.pop(special_fields.source_id.name) | ||
| agg = json_dict(aggs.pop(special_fields.source_id.name)) | ||
| counts_by_accessibility: dict[bool, int] = defaultdict(int) | ||
| for bucket in agg['myTerms']['buckets']: | ||
| terms = json_dict(agg['myTerms']) | ||
| buckets = json_list_of_dicts(terms['buckets']) | ||
| for bucket in buckets: | ||
| accessible = bucket['key'] in source_ids | ||
| counts_by_accessibility[accessible] += bucket['doc_count'] | ||
| agg['myTerms']['buckets'] = [ | ||
| counts_by_accessibility[accessible] += json_int(bucket['doc_count']) | ||
| terms['buckets'] = [ | ||
| {'key': accessible, 'doc_count': count} | ||
| for accessible, count in counts_by_accessibility.items() | ||
| ] | ||
|
|
@@ -465,8 +484,8 @@ def process_response(self, response: Response) -> MutableJSON: | |
|
|
||
|
|
||
| def sort_key_from_json(s: AnyJSON) -> SortKey: | ||
| a, b = json_list(s) | ||
| return a, json_str(b) | ||
| a, b = json_sequence(s) | ||
| return json_primitive(a), json_str(b) | ||
|
|
||
|
|
||
| def sort_key_to_json(s: SortKey) -> AnyJSON: | ||
|
|
@@ -532,12 +551,12 @@ class PaginationStage(_OpenSearchStage[JSON, ResponseTriple]): | |
|
|
||
| def prepare_request(self, request: Search) -> Search: | ||
| sort_order = self.pagination.order | ||
| sort_field = self.plugin.field_mapping[self.pagination.sort] | ||
| field_type = self.service.field_type(self.catalog, sort_field) | ||
| sort_field_path = self.plugin.field_mapping[self.pagination.sort] | ||
| field_type = self.service.field_type(self.catalog, sort_field_path) | ||
| sort_mode = field_type.es_sort_mode | ||
| sort_field = dotted(sort_field, 'keyword') | ||
| sort_field = dotted(sort_field_path, 'keyword') | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't |
||
|
|
||
| def sort(order): | ||
| def sort(order: str) -> tuple[JSON, JSON]: | ||
| assert order in ('asc', 'desc'), order | ||
| return ( | ||
| { | ||
|
|
@@ -590,71 +609,72 @@ def process_response(self, response: JSON) -> ResponseTriple: | |
| """ | ||
| Returns hits and pagination as dict | ||
| """ | ||
| # The slice is necessary because we may have fetched an extra entry to | ||
| # determine if there is a previous or next page. | ||
| hits = self._extract_hits(response) | ||
| hits, total = self._extract_hits(response) | ||
| pagination = self._process_pagination(hits, total) | ||
| hits = self._translate_hits(hits) | ||
| pagination = self._process_pagination(response) | ||
| aggregations = response.get('aggregations', {}) | ||
| aggregations = json_mapping(response.get('aggregations', {})) | ||
| return hits, pagination, aggregations | ||
|
|
||
| def _extract_hits(self, response): | ||
| hits = response['hits']['hits'][0:self.pagination.size] | ||
| if self.pagination.search_before is not None: | ||
| hits = reversed(hits) | ||
| hits = [hit['_source'] for hit in hits] | ||
| return hits | ||
|
|
||
| def _translate_hits(self, hits): | ||
| f = partial(self.service.translate_fields, self.catalog, forward=False) | ||
| hits = list(map(f, hits)) | ||
| return hits | ||
|
|
||
| def _process_pagination(self, response: JSON) -> MutableJSON: | ||
| total = response['hits']['total'] | ||
| def _extract_hits(self, response: JSON) -> tuple[JSONs, int]: | ||
| hits = json_mapping(response['hits']) | ||
| total = json_mapping(hits['total']) | ||
| # FIXME: Handle other relations | ||
| # https://github.com/DataBiosphere/azul/issues/3770 | ||
| assert total['relation'] == 'eq' | ||
| pages = -(-total['value'] // self.pagination.size) | ||
| return json_sequence_of_mappings(hits['hits']), json_int(total['value']) | ||
|
|
||
| def _translate_hits(self, hits: JSONs) -> JSONs: | ||
| # The slice is necessary because we may have fetched an extra entry to | ||
| # determine if there is a previous or next page. | ||
| hits = hits[0:self.pagination.size] | ||
| hits = iter(hits) if self.pagination.search_before is None else reversed(hits) | ||
| return [ | ||
| self.service.translate_fields(self.catalog, | ||
| json_mapping(hit['_source']), | ||
| forward=False) | ||
| for hit in hits | ||
| ] | ||
|
|
||
| def _process_pagination(self, hits: JSONs, total: int) -> ResponsePagination: | ||
| pages = -(-total // self.pagination.size) | ||
|
|
||
| # ... else use search_after/search_before pagination | ||
| hits: JSONs = response['hits']['hits'] | ||
| count = len(hits) | ||
| if self.pagination.search_before is None: | ||
| # hits are normal sorted | ||
| if count > self.pagination.size: | ||
| # There is an extra hit, indicating a next page. | ||
| count -= 1 | ||
| search_after = tuple(hits[count - 1]['sort']) | ||
| search_after = sort_key_from_json(hits[count - 1]['sort']) | ||
| else: | ||
| # No next page | ||
| search_after = None | ||
| if self.pagination.search_after is not None: | ||
| search_before = tuple(hits[0]['sort']) | ||
| search_before = sort_key_from_json(hits[0]['sort']) | ||
| else: | ||
| search_before = None | ||
| else: | ||
| # hits are reverse sorted | ||
| if count > self.pagination.size: | ||
| # There is an extra hit, indicating a previous page. | ||
| count -= 1 | ||
| search_before = tuple(hits[count - 1]['sort']) | ||
| search_before = sort_key_from_json(hits[count - 1]['sort']) | ||
| else: | ||
| # No previous page | ||
| search_before = None | ||
| search_after = tuple(hits[0]['sort']) | ||
| search_after = sort_key_from_json(hits[0]['sort']) | ||
|
|
||
| pagination = self.pagination.advance(search_before=search_before, | ||
| search_after=search_after) | ||
|
|
||
| def page_link(*, previous): | ||
| def page_link(*, previous: bool) -> str | None: | ||
| url = pagination.link(previous=previous, | ||
| catalog=self.catalog, | ||
| filters=json.dumps(self.filters.explicit)) | ||
| return None if url is None else str(url) | ||
|
|
||
| return ResponsePagination(count=count, | ||
| total=total['value'], | ||
| total=total, | ||
| size=pagination.size, | ||
| next=page_link(previous=False), | ||
| previous=page_link(previous=True), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think JSONArray is overly general so I'm not sure this is the best we can do here. Please just add a FIXME for 6821 pointing out that PrimitiveJSON is incorrect.