|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import warnings |
3 | 4 | from collections import defaultdict |
4 | 5 |
|
5 | 6 | from emmet.core.eos import EOSDoc |
6 | 7 |
|
7 | | -from mp_api.client.core import BaseRester |
| 8 | +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning |
8 | 9 | from mp_api.client.core.utils import validate_ids |
9 | 10 |
|
10 | 11 |
|
11 | 12 | class EOSRester(BaseRester): |
12 | 13 | suffix = "materials/eos" |
13 | 14 | document_model = EOSDoc # type: ignore |
14 | | - primary_key = "material_id" |
| 15 | + primary_key = "task_id" |
15 | 16 |
|
16 | 17 | def search( |
17 | 18 | self, |
18 | | - material_ids: str | list[str] | None = None, |
| 19 | + task_ids: str | list[str] | None = None, |
19 | 20 | energies: tuple[float, float] | None = None, |
20 | 21 | volumes: tuple[float, float] | None = None, |
21 | 22 | num_chunks: int | None = None, |
22 | 23 | chunk_size: int = 1000, |
23 | 24 | all_fields: bool = True, |
24 | 25 | fields: list[str] | None = None, |
| 26 | + **kwargs, |
25 | 27 | ) -> list[EOSDoc] | list[dict]: |
26 | 28 | """Query equations of state docs using a variety of search criteria. |
27 | 29 |
|
28 | 30 | Arguments: |
29 | | - material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs |
| 31 | + task_ids (str, List[str]): Search for equation of states associated with the specified task IDs |
30 | 32 | energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range. |
31 | 33 | volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range. |
32 | 34 | num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. |
33 | 35 | chunk_size (int): Number of data entries per chunk. |
34 | 36 | all_fields (bool): Whether to return all fields in the document. Defaults to True. |
35 | 37 | fields (List[str]): List of fields in EOSDoc to return data for. |
36 | 38 | Default is material_id only if all_fields is False. |
| 39 | + **kwargs : used for handling deprecated kwargs |
37 | 40 |
|
38 | 41 | Returns: |
39 | 42 | ([EOSDoc], [dict]) List of equations of state docs or dictionaries. |
40 | 43 | """ |
41 | 44 | query_params: dict = defaultdict(dict) |
42 | 45 |
|
43 | | - if material_ids: |
44 | | - if isinstance(material_ids, str): |
45 | | - material_ids = [material_ids] |
| 46 | + if "material_ids" in kwargs: |
| 47 | + if task_ids: |
| 48 | + raise MPRestError( |
| 49 | + "You have specified both `task_ids` and the deprecated `material_ids` tag. " |
| 50 | + "Please specify only `task_ids`." |
| 51 | + ) |
| 52 | + task_ids = kwargs.pop("material_ids") |
| 53 | + warnings.warn( |
| 54 | + "`material_id` has been replaced by `task_id` in the EOS endpoint. " |
| 55 | + "Please migrate to using the newer field name.", |
| 56 | + stacklevel=2, |
| 57 | + category=MPRestWarning, |
| 58 | + ) |
46 | 59 |
|
47 | | - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) |
| 60 | + if task_ids: |
| 61 | + query_params["material_ids"] = ",".join( |
| 62 | + validate_ids([task_ids] if isinstance(task_ids, str) else task_ids) |
| 63 | + ) |
48 | 64 |
|
49 | 65 | if volumes: |
50 | 66 | query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) |
|
0 commit comments