diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index e2127b22..17b62bfa 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -659,7 +659,7 @@ def _query_resource( use_document_model: if None, will defer to the self.use_document_model attribute num_chunks: Maximum number of chunks of data to yield. None will yield all possible. chunk_size: Number of data entries per chunk. - timeout : Time in seconds to wait until a request timeout error is thrown + timeout (float or None): Time in seconds to wait until a request timeout error is thrown Returns: A Resource, a dict with two keys, "data" containing a list of documents, and @@ -805,14 +805,16 @@ def _query_resource( except RequestException as ex: raise MPRestError(str(ex)) - def _submit_requests( # noqa + def _submit_requests( self, - url, - criteria, - use_document_model, - chunk_size, - num_chunks=None, - timeout=None, + url: str, + criteria: dict[str, Any], + use_document_model: bool, + chunk_size: int | None, + num_chunks: int | None = None, + timeout: int | None = None, + max_batch_size: int = 100, + norecur: bool = False, ) -> dict: """Handle submitting requests sequentially with pagination. @@ -820,12 +822,15 @@ def _submit_requests( # noqa split them into multiple sequential requests and combine results. Arguments: - criteria: dictionary of criteria to filter down - url: url used to make request - use_document_model: if None, will defer to the self.use_document_model attribute - num_chunks: Maximum number of chunks of data to yield. None will yield all possible. - chunk_size: Number of data entries per chunk. - timeout: Time in seconds to wait until a request timeout error is thrown + url (str): url used to make request + criteria (dict of str): dictionary of criteria to filter down + use_document_model (bool): whether to use the document model + num_chunks (int or None): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int or None): Number of data entries per chunk. + timeout (int or None): Time in seconds to wait until a request timeout error is thrown + max_batch_size (int) : Maximum size of a batch when retrieving batches in parallel + norecur (bool) : Whether to forbid recursive splitting of a query field + when a direct query fails Returns: Dictionary containing data and metadata @@ -884,14 +889,6 @@ def _submit_requests( # noqa timeout=timeout, ) - # Check if we got 0 results - some parameters are silently ignored by the API - # when passed as comma-separated values, so we need to split them anyway - if total_num_docs == 0 and len(split_values) > 1: - # Treat this the same as a 422 error - split into batches - raise MPRestError( - "Got 0 results for comma-separated parameter, will try splitting" - ) - # If successful, continue with normal pagination data_chunks = [data["data"]] total_data: dict[str, Any] = {"data": []} @@ -903,18 +900,26 @@ def _submit_requests( # noqa # Continue with pagination if needed (handled below) except MPRestError as e: - # If we get 422 or 414 error, or 0 results for comma-separated params, split into batches - if any(trace in str(e) for trace in ("422", "414", "Got 0 results")): + # If we get 422 or 414 error, split into batches + if not norecur and any( + trace in str(e) + for trace in ( + "422", + "414", + ) + ): total_data = {"data": []} total_num_docs = 0 data_chunks = [] # Batch the split values to reduce number of requests # Use batches of up to 100 values to balance URL length and request count - batch_size = min(100, max(1, len(split_values) // 10)) + num_batches = min( + max_batch_size, max(1, len(split_values) // max_batch_size) + ) + batch_size = min(len(split_values), max_batch_size) # Setup progress bar for split parameter requests - num_batches = ceil(len(split_values) / batch_size) pbar_message = f"Retrieving {len(split_values)} {split_param} values in {num_batches} batches" pbar = ( tqdm( @@ -938,6 +943,7 @@ def _submit_requests( # noqa chunk_size=chunk_size, num_chunks=num_chunks, timeout=timeout, + norecur=len(batch) <= max_batch_size, ) data_chunks.append(result["data"]) @@ -979,6 +985,12 @@ def _submit_requests( # noqa if "meta" in data: total_data["meta"] = data["meta"] + # otherwise, paginate sequentially + if chunk_size is None or chunk_size < 1: + raise ValueError( + "A positive chunk size must be provided to enable pagination" + ) + # Get max number of response pages max_pages = ( num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size) @@ -998,7 +1010,7 @@ def _submit_requests( # noqa desc=pbar_message, total=num_docs_needed, ) - if not self.mute_progress_bars + if not self.mute_progress_bars and total_num_docs > 0 else None ) @@ -1018,10 +1030,6 @@ def _submit_requests( # noqa pbar.close() return new_total_data - # otherwise, paginate sequentially - if chunk_size is None: - raise ValueError("A chunk size must be provided to enable pagination") - # Warning to select specific fields only for many results if criteria.get("_all_fields", False) and (total_num_docs / chunk_size > 10): warnings.warn( diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index ece9d7bc..c2dfefde 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1359,23 +1359,17 @@ def get_download_info( return meta, urls def _check_get_download_info_url_by_task_id(self, prefix, task_ids) -> list[str]: - nomad_exist_task_ids: list[str] = [] prefix = prefix.replace("/raw/query", "/repo/") - for task_id in task_ids: - url = prefix + task_id - if self._check_nomad_exist(url): - nomad_exist_task_ids.append(task_id) - return nomad_exist_task_ids + return [ + task_id for task_id in task_ids if self._check_nomad_exist(prefix + task_id) + ] @staticmethod def _check_nomad_exist(url) -> bool: response = get(url=url) if response.status_code != 200: return False - content = load_json(response.text) - if content["pagination"]["total"] == 0: - return False - return True + return load_json(response.text)["pagination"]["total"] != 0 @staticmethod def _print_help_message(nomad_exist_task_ids, task_ids, file_patterns, calc_types): diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 1769781e..7cb2f2db 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -2,6 +2,7 @@ import warnings from collections import defaultdict +from itertools import chain, product from emmet.core.summary import HasProps, SummaryDoc from emmet.core.symmetry import CrystalSystem @@ -200,8 +201,9 @@ def search( # noqa: D417 mmnd_inv = {v: k for k, v in min_max_name_dict.items() if k != v} # Set user query params from `locals` + _locals = locals() user_settings = { - k: v for k, v in locals().items() if k in min_max_name_dict and v + k: v for k, v in _locals.items() if k in min_max_name_dict and v } # Check to see if user specified _search fields using **kwargs, @@ -328,10 +330,11 @@ def _csrc(x): "spacegroup_number": 230, "spacegroup_symbol": 230, } + batched_symm_query = {} for k, cardinality in symm_cardinality.items(): - if isinstance(symm_vals := locals().get(k), list | tuple | set): + if isinstance(symm_vals := _locals.get(k), list | tuple | set): if len(symm_vals) < cardinality // 2: - query_params.update({k: ",".join(str(v) for v in symm_vals)}) + batched_symm_query[k] = symm_vals else: raise MPRestError( f"Querying `{k}` by a list of values is only " @@ -378,6 +381,24 @@ def _csrc(x): if query_params[entry] is not None } + if batched_symm_query: + ordered_symm_key = sorted(batched_symm_query) + return list( + chain.from_iterable( + self._search( # type: ignore[return-value] + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, + **{sk: symm_params[i] for i, sk in enumerate(ordered_symm_key)}, + ) + for symm_params in product( + *[batched_symm_query[k] for k in ordered_symm_key] + ) + ) + ) + return super()._search( # type: ignore[return-value] num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/tests/client/test_core_client.py b/tests/client/test_core_client.py index 89c83c67..2fffae46 100644 --- a/tests/client/test_core_client.py +++ b/tests/client/test_core_client.py @@ -94,3 +94,21 @@ def test_warnings_exceptions(): with pytest.raises(MPRestError, match="Number of chunks must be greater than zero"): MaterialsRester()._get_all_documents({}, num_chunks=-1) + + +def test_regression_batching(): + # See https://github.com/materialsproject/api/pull/1077 + # This test ensures that queries with batched input + # that return no results do not infinitely recur + + # This test should simply run if there is no regression + + num_idxs = 100 + assert ( + len( + MaterialsRester().search( + material_ids=[f"mp-{idx}" for idx in range(num_idxs)], deprecated=True + ) + ) + <= num_idxs + )