Skip to content

Commit 2655407

Browse files
fix eos, task traj, add __dir__ to lazy import + fix attr access
1 parent 3f9bec4 commit 2655407

8 files changed

Lines changed: 73 additions & 29 deletions

File tree

mp_api/_test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def client_search_testing(
7878
doc = docs[0].model_dump()
7979
else:
8080
raise ValueError("No documents returned")
81+
print(doc)
8182

8283
for sub_field in sub_doc_fields:
8384
if sub_field in doc:

mp_api/client/core/client.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from itertools import chain, islice
2424
from json import JSONDecodeError
2525
from math import ceil
26+
from pathlib import Path
2627
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
2728
from urllib.parse import urljoin
2829

@@ -184,7 +185,7 @@ def __init__(
184185

185186
self.use_document_model = use_document_model
186187
self.mute_progress_bars = mute_progress_bars
187-
self.local_dataset_cache = local_dataset_cache
188+
self.local_dataset_cache = Path(local_dataset_cache)
188189
self.force_renew = force_renew
189190
self._query_builder = query_builder
190191

@@ -1436,12 +1437,7 @@ def _convert_to_model(
14361437
)
14371438

14381439
return [
1439-
data_model(
1440-
**{
1441-
field: raw_doc[field]
1442-
for field in set_fields.intersection(raw_doc)
1443-
}
1444-
)
1440+
data_model(**raw_doc)
14451441
for raw_doc in (data if is_list else chain([first_doc], data))
14461442
]
14471443

@@ -1464,7 +1460,14 @@ def _generate_returned_model(
14641460
set of str: set_fields, fields_not_requested)
14651461
"""
14661462
model_fields = self.document_model.model_fields
1467-
set_fields = set(doc).intersection(model_fields)
1463+
aliases = {
1464+
anno.alias: field for field, anno in model_fields.items() if anno.alias
1465+
}
1466+
set_fields = (
1467+
set(doc)
1468+
.intersection(model_fields)
1469+
.union({aliases[k] for k in set(doc).intersection(aliases)})
1470+
)
14681471
unset_fields = set(model_fields).difference(set_fields)
14691472
user_requested_fields: list[str] = requested_fields or []
14701473
fields_not_requested = unset_fields.difference(user_requested_fields)

mp_api/client/core/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def validate_ids(id_list: list[str]) -> list[str]:
116116
" data for all IDs and filter locally."
117117
)
118118

119-
[validate_identifier(idx, serialize=False) for idx in id_list]
120-
return [getattr(idx, "string", str(idx)) for idx in id_list]
119+
validated = [validate_identifier(idx, serialize=False) for idx in id_list]
120+
return [getattr(idx, "string", str(idx)) for idx in validated]
121121

122122

123123
def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str:
@@ -241,6 +241,14 @@ def __getattr__(self, v: str) -> Any:
241241
if hasattr(self._imported, v):
242242
return getattr(self._imported, v)
243243

244+
raise AttributeError(
245+
f"{self._module_name}{'.' + self._class_name if self._class_name else ''} "
246+
f"has no attribute {v}"
247+
)
248+
249+
def __dir__(self) -> list[str]:
250+
return self._obj.__dir__()
251+
244252

245253
class MPDataset:
246254
"""Convenience wrapper for pyarrow datasets stored on disk."""

mp_api/client/routes/materials/electronic_structure.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ def get_bandstructure_from_material_id(
326326
material_ids=material_id, fields=["bandstructure"]
327327
)
328328
if not bs_doc:
329-
raise MPRestError("No electronic structure data found.")
329+
raise MPRestError(
330+
f"No electronic structure data found for material ID {material_id}."
331+
)
330332

331333
if (_bs_data := bs_doc[0]["bandstructure"]) is None:
332334
raise MPRestError(
@@ -349,7 +351,9 @@ def get_bandstructure_from_material_id(
349351
material_ids=material_id, fields=["dos"]
350352
)
351353
):
352-
raise MPRestError("No electronic structure data found.")
354+
raise MPRestError(
355+
f"No electronic structure data found for material ID {material_id}."
356+
)
353357

354358
if (_bs_data := bs_doc[0]["dos"]) is None:
355359
raise MPRestError(
@@ -538,7 +542,9 @@ def get_dos_from_material_id(self, material_id: str) -> Dos:
538542
if not (
539543
dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"])
540544
):
541-
return None
545+
raise MPRestError(
546+
f"No electronic structure data found for material ID {material_id}."
547+
)
542548

543549
if not (dos_data := dos_doc[0].get("dos")):
544550
raise MPRestError(f"No density of states data found for {material_id}")

mp_api/client/routes/materials/eos.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,66 @@
11
from __future__ import annotations
22

3+
import warnings
34
from collections import defaultdict
45

56
from emmet.core.eos import EOSDoc
67

7-
from mp_api.client.core import BaseRester
8+
from mp_api.client.core import BaseRester, MPRestError, MPRestWarning
89
from mp_api.client.core.utils import validate_ids
910

1011

1112
class EOSRester(BaseRester):
1213
suffix = "materials/eos"
1314
document_model = EOSDoc # type: ignore
14-
primary_key = "material_id"
15+
primary_key = "task_id"
1516

1617
def search(
1718
self,
18-
material_ids: str | list[str] | None = None,
19+
task_ids: str | list[str] | None = None,
1920
energies: tuple[float, float] | None = None,
2021
volumes: tuple[float, float] | None = None,
2122
num_chunks: int | None = None,
2223
chunk_size: int = 1000,
2324
all_fields: bool = True,
2425
fields: list[str] | None = None,
26+
**kwargs,
2527
) -> list[EOSDoc] | list[dict]:
2628
"""Query equations of state docs using a variety of search criteria.
2729
2830
Arguments:
29-
material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs
31+
task_ids (str, List[str]): Search for equations of state associated with the specified task IDs.
3032
energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range.
3133
volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range.
3234
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
3335
chunk_size (int): Number of data entries per chunk.
3436
all_fields (bool): Whether to return all fields in the document. Defaults to True.
3537
fields (List[str]): List of fields in EOSDoc to return data for.
3638
Default is task_id only if all_fields is False.
39+
**kwargs: Used to accept deprecated keyword arguments such as `material_ids` (treated as `task_ids`, with a warning).
3740
3841
Returns:
3942
([EOSDoc], [dict]) List of equations of state docs or dictionaries.
4043
"""
4144
query_params: dict = defaultdict(dict)
4245

43-
if material_ids:
44-
if isinstance(material_ids, str):
45-
material_ids = [material_ids]
46+
if "material_ids" in kwargs:
47+
if task_ids:
48+
raise MPRestError(
49+
"You have specified both `task_ids` and the deprecated `material_ids` tag. "
50+
"Please specify only `task_ids`."
51+
)
52+
task_ids = kwargs.pop("material_ids")
53+
warnings.warn(
54+
"`material_id` has been replaced by `task_id` in the EOS endpoint. "
55+
"Please migrate to using the newer field name.",
56+
stacklevel=2,
57+
category=MPRestWarning,
58+
)
4659

47-
query_params.update({"material_ids": ",".join(validate_ids(material_ids))})
60+
if task_ids:
61+
query_params["material_ids"] = ",".join(
62+
validate_ids([task_ids] if isinstance(task_ids, str) else task_ids)
63+
)
4864

4965
if volumes:
5066
query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]})

mp_api/client/routes/materials/tasks.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,8 @@ def get_trajectory(
3939
"""
4040
as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1]
4141
predicate = (
42-
f"WHERE run_type='{str(run_type)}' AND identifier='{as_alpha}'"
43-
if run_type
44-
else f"WHERE identifier='{as_alpha}'"
45-
)
42+
f"WHERE run_type='{str(run_type)}' AND " if run_type else "WHERE "
43+
) + f"identifier='{as_alpha}'"
4644

4745
traj_lbl, traj_tbl = self._get_delta_table(
4846
"materialsproject-parsed",
@@ -53,7 +51,6 @@ def get_trajectory(
5351
query = f"""
5452
SELECT *
5553
FROM {traj_lbl}
56-
WHERE identifier='{as_alpha}'
5754
{predicate};
5855
"""
5956

tests/client/materials/test_electronic_structure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def test_bs_client():
104104
with pytest.raises(MPRestError, match="No electronic structure data found."):
105105
_ = bs_rester.get_bandstructure_from_material_id("mp-0")
106106

107-
with pytest.raises(MPRestError, match="No object found"):
107+
with pytest.raises(MPRestError, match="No bandstructure data found"):
108108
_ = bs_rester.get_bandstructure_from_task_id("mp-0")
109109

110110

tests/client/materials/test_eos.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from mp_api._test_utils import client_search_testing, requires_api_key
66

7+
from mp_api.client.core.exceptions import MPRestError, MPRestWarning
78
from mp_api.client.routes.materials.eos import EOSRester
89

910

@@ -26,9 +27,9 @@ def rester():
2627

2728
sub_doc_fields: list = []
2829

29-
alt_name_dict: dict = {"material_ids": "material_id"}
30+
alt_name_dict: dict = {"task_ids": "task_id"}
3031

31-
custom_field_tests: dict = {"material_ids": ["mp-149"]}
32+
custom_field_tests: dict = {"task_ids": ["mp-149"]}
3233

3334

3435
@requires_api_key
@@ -42,3 +43,15 @@ def test_client(rester):
4243
custom_field_tests=custom_field_tests,
4344
sub_doc_fields=sub_doc_fields,
4445
)
46+
47+
48+
@requires_api_key
49+
def test_warnings_errors(rester):
50+
51+
with pytest.warns(
52+
MPRestWarning, match="`material_id` has been replaced by `task_id`"
53+
):
54+
rester.search(material_ids=["mp-149"], num_chunks=1, chunk_size=1)
55+
56+
with pytest.raises(MPRestError, match="You have specified both"):
57+
rester.search(material_ids=["mp-149"], task_ids=["mp-1"])

0 commit comments

Comments
 (0)