Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ modules =
azul.service.manifest_service,
azul.field_type,
azul.source,
azul.service.query_service,


packages =
Expand Down
5 changes: 5 additions & 0 deletions src/azul/lib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def json_sorted(vs: Iterable[PrimitiveJSON]) -> MutableJSONArray:
return sorted(vs, key=none_safe_key(none_last=True))


def json_primitive(v: AnyJSON) -> PrimitiveJSON:
    """
    Pass the argument through unchanged after verifying that it is a
    primitive JSON value, i.e. ``None`` or an instance of ``str``, ``int``,
    ``float`` or ``bool``. Any other type fails with an ``AssertionError``
    whose message names the offending type.
    """
    if v is not None:
        assert isinstance(v, (str, int, float, bool)), type(v)
    return v


def json_str(v: AnyMutableJSON | AnyJSON) -> str:
return any_str(v)

Expand Down
156 changes: 88 additions & 68 deletions src/azul/service/query_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
from collections.abc import (
Iterable,
Mapping,
Sequence,
)
from functools import (
partial,
)
import json
import logging
Expand Down Expand Up @@ -53,6 +49,7 @@
)
from azul.indexer.document import (
DocumentType,
FieldPath,
IndexName,
)
from azul.indexer.document_service import (
Expand All @@ -65,23 +62,34 @@
from azul.lib.types import (
AnyJSON,
JSON,
JSONArray,
JSONTypedDict,
JSONs,
MutableJSON,
PrimitiveJSON,
json_list,
json_dict,
json_dict_of_dicts,
json_element_dicts,
json_element_strings,
json_int,
json_item_sequences,
json_list_of_dicts,
json_mapping,
json_primitive,
json_sequence,
json_sequence_of_mappings,
json_str,
)
from azul.opensearch import (
OpenSearchClientFactory,
)
from azul.plugins import (
DocumentSlice,
FieldPath,
MetadataPlugin,
dotted,
)
from azul.service import (
FilterJSON,
Filters,
FiltersJSON,
)
Expand Down Expand Up @@ -168,7 +176,7 @@ def wrap[R0](self, other: OpenSearchStage[R0, R1]) -> OpenSearchChain[R0, R1, R2
return OpenSearchChain(inner=other, outer=self)


TranslatedFilters = Mapping[FieldPath, Mapping[str, Sequence[PrimitiveJSON]]]
TranslatedFilters = Mapping[FieldPath, Mapping[str, JSONArray]]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think JSONArray is overly general so I'm not sure this is the best we can do here. Please just add a FIXME for 6821 pointing out that PrimitiveJSON is incorrect.



@attr.s(frozen=True, auto_attribs=True, kw_only=True)
Expand Down Expand Up @@ -216,30 +224,38 @@ def _translate_filters(self, filters: FiltersJSON) -> TranslatedFilters:
"""
catalog = self.catalog
field_mapping = self.plugin.field_mapping
translated_filters = {}
for field, filter in filters.items():
field = field_mapping[field]
operator, values = one(filter.items())
field_type = self.service.field_type(catalog, field)
values = field_type.filter(operator, values)
translated_filters[field] = {operator: list(values)}
return translated_filters

def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query:
def translate_filter(field_name: str,
filter: FilterJSON
) -> tuple[FieldPath, Mapping[str, JSONArray]]:
field_path = field_mapping[field_name]
operator, values = one(filter.items())
field_type = self.service.field_type(catalog, field_path)
# FIXME: remove `type: ignore`
# https://github.com/DataBiosphere/azul/issues/6821
values: JSONArray = list(field_type.filter(operator, values)) # type: ignore
return field_path, {operator: values}

return dict(
translate_filter(field, filter)
for field, filter in filters.items()
)

def prepare_query(self, skip_field_paths: tuple[FieldPath, ...] = ()) -> Query:
Comment on lines -219 to +244
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we agreed to only touch filter-related code when it's trivial narrowing? This is a non-trivial change, and it still requires a pragma so I don't know what it's buying us.

"""
Converts the given filters into an OpenSearch DSL Query object.
"""
filter_list = []
for field_path, filter in self.prepared_filters.items():
if field_path not in skip_field_paths:
operator, values = one(filter.items())
operator, values = one(json_item_sequences(filter))
# Note that `is_not` is only used internally (for filtering by
# inaccessible sources)
if operator in ('is', 'is_not'):
field_type = self.service.field_type(self.catalog, field_path)
if isinstance(field_type, Nested):
term_queries = []
for nested_field, nested_value in one(values).items():
for nested_field, nested_value in json_mapping(one(values)).items():
nested_body = {dotted(field_path, nested_field, 'keyword'): nested_value}
term_queries.append(Q('term', **nested_body))
query = Q('nested', path=dotted(field_path), query=Q('bool', must=term_queries))
Expand All @@ -258,7 +274,7 @@ def prepare_query(self, skip_field_paths: tuple[FieldPath] = ()) -> Query:
filter_list.append(query)
elif operator in ('contains', 'within', 'intersects'):
for value in values:
value = value | {'relation': operator}
value = {**json_mapping(value), 'relation': operator}
filter_list.append(Q('range', **{dotted(field_path): value}))
else:
assert False
Expand Down Expand Up @@ -307,7 +323,7 @@ def prepare_request(self, request: Search) -> Search:

def process_response(self, response: MutableJSON) -> MutableJSON:
try:
aggs = response['aggregations']
aggs = json_dict(response['aggregations'])
except KeyError:
pass
else:
Expand All @@ -330,7 +346,7 @@ def _prepare_aggregation(self, *, facet: str, facet_path: FieldPath) -> Agg:
nested_agg = agg.bucket(name='nested',
agg_type='nested',
path=dotted(facet_path))
facet_path = dotted(facet_path, field_type.agg_property)
facet_path = (*facet_path, field_type.agg_property)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a semantic change.

else:
nested_agg = agg
# Make an inner agg that will contain the terms in question
Expand Down Expand Up @@ -368,9 +384,9 @@ def annotate(agg: Agg):
annotate(request.aggs[agg_name])

def _flatten_nested_aggs(self, aggs: MutableJSON):
for facet, agg in aggs.items():
for facet, agg in json_dict_of_dicts(aggs).items():
try:
nested_agg = agg.pop('nested')
nested_agg = json_dict(agg.pop('nested'))
except KeyError:
pass
else:
Expand All @@ -382,26 +398,27 @@ def _translate_response_aggs(self, aggs: MutableJSON):
OpenSearch response.
"""

def translate(k, v: MutableJSON):
def translate(k: str, v: MutableJSON):
try:
buckets = v['buckets']
except KeyError:
for k, v in v.items():
if isinstance(v, dict):
translate(k, v)
for ki, vi in v.items():
if isinstance(vi, dict):
translate(ki, vi)
else:
try:
path = v['meta']['path']
path = json_dict(v['meta'])['path']
except KeyError:
pass
else:
field_type = self.service.field_type(self.catalog, tuple(path))
for bucket in buckets:
field_type = self.service.field_type(self.catalog,
tuple(json_element_strings(path)))
for bucket in json_element_dicts(buckets):
bucket['key'] = field_type.from_index(bucket['key'])
translate(k, bucket)

for k, v in aggs.items():
translate(k, v)
translate(k, json_dict(v))

def _populate_accessible(self, aggs: MutableJSON) -> None:
# Because the value of the `accessible` field depends on the provided
Expand All @@ -410,12 +427,14 @@ def _populate_accessible(self, aggs: MutableJSON) -> None:
source_ids = self.filter_stage.filters.source_ids
plugin = self.service.metadata_plugin(self.catalog)
special_fields = plugin.special_fields
agg = aggs.pop(special_fields.source_id.name)
agg = json_dict(aggs.pop(special_fields.source_id.name))
counts_by_accessibility: dict[bool, int] = defaultdict(int)
for bucket in agg['myTerms']['buckets']:
terms = json_dict(agg['myTerms'])
buckets = json_list_of_dicts(terms['buckets'])
for bucket in buckets:
accessible = bucket['key'] in source_ids
counts_by_accessibility[accessible] += bucket['doc_count']
agg['myTerms']['buckets'] = [
counts_by_accessibility[accessible] += json_int(bucket['doc_count'])
terms['buckets'] = [
{'key': accessible, 'doc_count': count}
for accessible, count in counts_by_accessibility.items()
]
Expand Down Expand Up @@ -465,8 +484,8 @@ def process_response(self, response: Response) -> MutableJSON:


def sort_key_from_json(s: AnyJSON) -> SortKey:
a, b = json_list(s)
return a, json_str(b)
a, b = json_sequence(s)
return json_primitive(a), json_str(b)


def sort_key_to_json(s: SortKey) -> AnyJSON:
Expand Down Expand Up @@ -532,12 +551,12 @@ class PaginationStage(_OpenSearchStage[JSON, ResponseTriple]):

def prepare_request(self, request: Search) -> Search:
sort_order = self.pagination.order
sort_field = self.plugin.field_mapping[self.pagination.sort]
field_type = self.service.field_type(self.catalog, sort_field)
sort_field_path = self.plugin.field_mapping[self.pagination.sort]
field_type = self.service.field_type(self.catalog, sort_field_path)
sort_mode = field_type.es_sort_mode
sort_field = dotted(sort_field, 'keyword')
sort_field = dotted(sort_field_path, 'keyword')
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't sort_field also a path, just in a different form? If so, then your variable naming choice is confusing.


def sort(order):
def sort(order: str) -> tuple[JSON, JSON]:
assert order in ('asc', 'desc'), order
return (
{
Expand Down Expand Up @@ -590,71 +609,72 @@ def process_response(self, response: JSON) -> ResponseTriple:
"""
Returns hits and pagination as dict
"""
# The slice is necessary because we may have fetched an extra entry to
# determine if there is a previous or next page.
hits = self._extract_hits(response)
hits, total = self._extract_hits(response)
pagination = self._process_pagination(hits, total)
hits = self._translate_hits(hits)
pagination = self._process_pagination(response)
aggregations = response.get('aggregations', {})
aggregations = json_mapping(response.get('aggregations', {}))
return hits, pagination, aggregations

def _extract_hits(self, response):
hits = response['hits']['hits'][0:self.pagination.size]
if self.pagination.search_before is not None:
hits = reversed(hits)
hits = [hit['_source'] for hit in hits]
return hits

def _translate_hits(self, hits):
f = partial(self.service.translate_fields, self.catalog, forward=False)
hits = list(map(f, hits))
return hits

def _process_pagination(self, response: JSON) -> MutableJSON:
total = response['hits']['total']
def _extract_hits(self, response: JSON) -> tuple[JSONs, int]:
hits = json_mapping(response['hits'])
total = json_mapping(hits['total'])
# FIXME: Handle other relations
# https://github.com/DataBiosphere/azul/issues/3770
assert total['relation'] == 'eq'
pages = -(-total['value'] // self.pagination.size)
return json_sequence_of_mappings(hits['hits']), json_int(total['value'])

def _translate_hits(self, hits: JSONs) -> JSONs:
    """
    Translate the ``_source`` document of each hit from its index
    representation back to the external field representation, preserving
    the order of the requested page.
    """
    # An extra hit may have been fetched solely to detect whether a
    # previous or next page exists, so trim to the requested page size.
    page = hits[:self.pagination.size]
    # When paging backwards (search_before), hits arrive reverse-sorted;
    # restore the natural page order before translating.
    if self.pagination.search_before is not None:
        page = list(reversed(page))
    translated = []
    for hit in page:
        source = json_mapping(hit['_source'])
        translated.append(self.service.translate_fields(self.catalog,
                                                        source,
                                                        forward=False))
    return translated

def _process_pagination(self, hits: JSONs, total: int) -> ResponsePagination:
pages = -(-total // self.pagination.size)

# ... else use search_after/search_before pagination
hits: JSONs = response['hits']['hits']
count = len(hits)
if self.pagination.search_before is None:
# hits are normal sorted
if count > self.pagination.size:
# There is an extra hit, indicating a next page.
count -= 1
search_after = tuple(hits[count - 1]['sort'])
search_after = sort_key_from_json(hits[count - 1]['sort'])
else:
# No next page
search_after = None
if self.pagination.search_after is not None:
search_before = tuple(hits[0]['sort'])
search_before = sort_key_from_json(hits[0]['sort'])
else:
search_before = None
else:
# hits are reverse sorted
if count > self.pagination.size:
# There is an extra hit, indicating a previous page.
count -= 1
search_before = tuple(hits[count - 1]['sort'])
search_before = sort_key_from_json(hits[count - 1]['sort'])
else:
# No previous page
search_before = None
search_after = tuple(hits[0]['sort'])
search_after = sort_key_from_json(hits[0]['sort'])

pagination = self.pagination.advance(search_before=search_before,
search_after=search_after)

def page_link(*, previous):
def page_link(*, previous: bool) -> str | None:
url = pagination.link(previous=previous,
catalog=self.catalog,
filters=json.dumps(self.filters.explicit))
return None if url is None else str(url)

return ResponsePagination(count=count,
total=total['value'],
total=total,
size=pagination.size,
next=page_link(previous=False),
previous=page_link(previous=True),
Expand Down
Loading