Skip to content

Commit ecfbbcd

Browse files
refactor get stability test to not rely on golden test data
1 parent d2d43f4 commit ecfbbcd

3 files changed

Lines changed: 107 additions & 84 deletions

File tree

mp_api/client/mprester.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,32 +1601,34 @@ def get_stability(
16011601
}
16021602
chemsys_str = "-".join(sorted(str(ele) for ele in chemsys))
16031603

1604-
thermo_type = (
1605-
ThermoType(thermo_type) if isinstance(thermo_type, str) else thermo_type
1604+
thermo_type_valid_str: str = (
1605+
ThermoType(thermo_type).value
1606+
if (isinstance(thermo_type, str) and thermo_type != "r2SCAN")
1607+
else str(thermo_type)
16061608
)
16071609

16081610
corrector: Compatibility | None = None
1609-
if thermo_type == ThermoType.GGA_GGA_U:
1611+
if thermo_type_valid_str == ThermoType.GGA_GGA_U.value:
16101612
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
16111613

16121614
corrector = MaterialsProject2020Compatibility()
16131615

1614-
elif thermo_type == ThermoType.GGA_GGA_U_R2SCAN:
1616+
elif thermo_type_valid_str == ThermoType.GGA_GGA_U_R2SCAN.value:
16151617
from pymatgen.entries.mixing_scheme import MaterialsProjectDFTMixingScheme
16161618

16171619
corrector = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN")
16181620

16191621
try:
16201622
pd = self.materials.thermo.get_phase_diagram_from_chemsys(
1621-
chemsys_str, thermo_type=thermo_type
1623+
chemsys_str, thermo_type=thermo_type_valid_str
16221624
)
16231625
except MPRestError:
16241626
pd = None
16251627

16261628
if not pd:
16271629
warnings.warn(
16281630
f"No phase diagram data available for chemical system {chemsys_str} "
1629-
f"and thermo type {thermo_type}.",
1631+
f"and thermo type {thermo_type_valid_str}.",
16301632
category=MPRestWarning,
16311633
stacklevel=2,
16321634
)

mp_api/client/routes/materials/thermo.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from typing import TYPE_CHECKING
45

56
import numpy as np
67
from emmet.core.thermo import ThermoDoc, validate_thermo_id
@@ -14,12 +15,45 @@
1415
from mp_api.client.core.exceptions import MPRestError
1516
from mp_api.client.core.utils import validate_ids
1617

18+
if TYPE_CHECKING:
19+
from collections.abc import Sequence
20+
21+
from enums import Enum
22+
1723

1824
class ThermoRester(BaseRester):
1925
suffix = "materials/thermo"
2026
document_model = ThermoDoc # type: ignore
2127
primary_key = "material_id"
2228

29+
@staticmethod
30+
def _check_thermo_types(thermo_types: Sequence[str | Enum]) -> set[str]:
31+
"""Check if a user has input any invalid thermo types.
32+
33+
Args:
34+
thermo_types (Sequence of str or Enum) : list of thermo types
35+
the user has queried for
36+
37+
phase-diagram tbl has "r2SCAN", not "R2SCAN"
38+
mixing of ThermoType/RunType in emmet -_-
39+
TODO: coerce upstream? allow case-insensitivity in emmet?
40+
41+
Returns:
42+
set of str: validated thermo types
43+
44+
Raises:
45+
ValueError if any invalid thermo types are input
46+
"""
47+
t_types: set[str] = {t if isinstance(t, str) else t.value for t in thermo_types}
48+
t_types = {"r2SCAN" if t == "R2SCAN" else t for t in t_types}
49+
valid_types = {"r2SCAN", *map(str, ThermoType.__members__.values())}
50+
51+
if invalid_types := t_types - valid_types:
52+
raise ValueError(
53+
f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}"
54+
)
55+
return t_types
56+
2357
def search(
2458
self,
2559
material_ids: str | list[str] | None = None,
@@ -57,7 +91,7 @@ def search(
5791
material_ids (List[str]): List of Materials Project IDs to return data for.
5892
thermo_ids (List[str]): List of thermo IDs to return data for. This is a combination of the Materials
5993
Project ID and thermo type (e.g. mp-149_GGA_GGA+U).
60-
thermo_types (List[ThermoType]): List of thermo types to return data for (e.g. ThermoType.GGA_GGA_U).
94+
thermo_types (List[ThermoType or str]): List of thermo/run types to return data for (e.g. ThermoType.GGA_GGA_U).
6195
num_elements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider.
6296
total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider.
6397
uncorrected_energy (Tuple[float,float]): Minimum and maximum uncorrected total
@@ -105,13 +139,9 @@ def search(
105139
)
106140

107141
if thermo_types:
108-
t_types = {t if isinstance(t, str) else t.value for t in thermo_types}
109-
valid_types = {*map(str, ThermoType.__members__.values())}
110-
if invalid_types := t_types - valid_types:
111-
raise ValueError(
112-
f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}"
113-
)
114-
query_params.update({"thermo_types": ",".join(t_types)})
142+
query_params.update(
143+
{"thermo_types": ",".join(self._check_thermo_types(thermo_types))}
144+
)
115145

116146
if num_elements:
117147
if isinstance(num_elements, int):
@@ -168,12 +198,7 @@ def get_phase_diagram_from_chemsys(
168198
Returns:
169199
(PhaseDiagram): Pymatgen phase diagram object.
170200
"""
171-
t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value
172-
valid_types = {*map(str, ThermoType.__members__.values())}
173-
if invalid_types := {t_type} - valid_types:
174-
raise ValueError(
175-
f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}"
176-
)
201+
validated_thermo_type = self._check_thermo_types([thermo_type]).pop()
177202

178203
sorted_chemsys = "-".join(sorted(chemsys.split("-")))
179204
version = self.db_version.replace(".", "-")
@@ -182,18 +207,12 @@ def get_phase_diagram_from_chemsys(
182207
"materialsproject-build", "objects/phase-diagrams", label="phase_diagrams"
183208
)
184209

185-
# phase-diagram tbl has r2SCAN, not R2SCAN
186-
# mixing of ThermoType/RunType in emmet -_-
187-
# TODO: coerce upstream? allow case-insensitivity in emmet?
188-
if thermo_type == ThermoType.R2SCAN:
189-
thermo_type = "r2SCAN"
190-
191210
query = f"""
192211
SELECT phase_diagram
193212
FROM {pd_lbl}
194213
WHERE chemsys='{sorted_chemsys}'
195214
AND version='{version}'
196-
AND thermo_type='{thermo_type}'
215+
AND thermo_type='{validated_thermo_type}'
197216
"""
198217
table = self._query_delta_single(query)
199218
as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict")

tests/client/test_mprester.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -564,81 +564,83 @@ def test_get_cohesive_energy(self):
564564
mpr.get_cohesive_energy("mp-1")
565565

566566
@pytest.mark.parametrize(
567-
"chemsys, thermo_type",
568-
[
569-
[("Fe", "P"), "GGA_GGA+U"],
570-
[("Li", "S"), ThermoType.GGA_GGA_U_R2SCAN],
571-
[("Ni", "Se"), ThermoType.R2SCAN],
572-
[("Ni", "Kr"), "R2SCAN"],
573-
],
567+
"thermo_type", ["GGA_GGA+U", ThermoType.GGA_GGA_U_R2SCAN, "r2SCAN"]
574568
)
575-
def test_get_stability(self, chemsys, thermo_type):
569+
def test_get_stability(self, thermo_type):
576570
"""
577571
This test is adapted from the pymatgen one - the scope is broadened
578572
to include more diverse chemical environments and thermo types which
579573
reflect the scope of the current MP database.
580574
"""
581575
with MPRester() as mpr:
582-
entries = mpr.get_entries_in_chemsys(
583-
chemsys, additional_criteria={"thermo_types": [thermo_type]}
584-
)
585576

586-
no_compound_entries = all(
587-
len(entry.composition.elements) == 1 for entry in entries
588-
)
577+
# No golden test data. Always test on fetched thermo data
578+
chemsys_to_test: set[str] = {
579+
doc.chemsys
580+
for doc in mpr.materials.thermo.search(
581+
thermo_types=[thermo_type],
582+
num_elements=2,
583+
num_chunks=1,
584+
chunk_size=4,
585+
fields=["chemsys"],
586+
)
587+
}
588+
589+
for chemsys in chemsys_to_test:
589590

590-
modified_entries = [
591-
ComputedEntry(
592-
entry.composition,
593-
entry.uncorrected_energy + 0.01,
594-
parameters=entry.parameters,
595-
entry_id=f"mod_{entry.entry_id}",
591+
entries = mpr.get_entries_in_chemsys(
592+
chemsys, additional_criteria={"thermo_types": [thermo_type]}
596593
)
597-
for entry in entries
598-
if entry.composition.reduced_formula in ["Fe2P", "".join(chemsys)]
599-
]
600594

601-
if len(modified_entries) == 0:
602-
# create fake entry to get PD retrieval to fail
603595
modified_entries = [
604596
ComputedEntry(
605-
"".join(chemsys),
606-
np.average([entry.energy for entry in entries]),
607-
entry_id=f"hypothetical",
597+
entry.composition,
598+
entry.uncorrected_energy + 0.01,
599+
parameters=entry.parameters,
600+
entry_id=f"mod_{entry.entry_id}",
608601
)
602+
for entry in entries
603+
if entry.entry_id == entries[0].entry_id
609604
]
610605

611-
if no_compound_entries:
612-
with pytest.warns(
613-
MPRestWarning, match="No phase diagram data available"
606+
if (
607+
all(len(entry.composition.elements) == 1 for entry in entries)
608+
and chemsys.count("-") > 0
614609
):
615-
mpr.get_stability(modified_entries, thermo_type=thermo_type)
616-
return
617-
618-
else:
619-
rester_ehulls = mpr.get_stability(
620-
modified_entries, thermo_type=thermo_type
621-
)
610+
# For a multi-element chemsys with no multinaries, only elementals,
611+
# there should be no phase diagram data available.
612+
with pytest.warns(
613+
MPRestWarning, match="No phase diagram data available"
614+
):
615+
mpr.get_stability(modified_entries, thermo_type=thermo_type)
616+
return
617+
618+
else:
619+
rester_ehulls = mpr.get_stability(
620+
modified_entries, thermo_type=thermo_type
621+
)
622622

623-
all_entries = entries + modified_entries
624-
625-
compat = None
626-
if thermo_type == "GGA_GGA+U":
627-
compat = MaterialsProject2020Compatibility()
628-
elif thermo_type == "GGA_GGA+U_R2SCAN":
629-
compat = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN")
630-
631-
if compat:
632-
all_entries = compat.process_entries(all_entries)
633-
634-
pd = PhaseDiagram(all_entries)
635-
for entry in all_entries:
636-
if str(entry.entry_id).startswith("mod"):
637-
for dct in rester_ehulls:
638-
if dct["entry_id"] == entry.entry_id:
639-
data = dct
640-
break
641-
assert pd.get_e_above_hull(entry) == pytest.approx(data["e_above_hull"])
623+
all_entries = entries + modified_entries
624+
625+
compat = None
626+
if thermo_type == "GGA_GGA+U":
627+
compat = MaterialsProject2020Compatibility()
628+
elif thermo_type == "GGA_GGA+U_R2SCAN":
629+
compat = MaterialsProjectDFTMixingScheme(run_type_2="r2SCAN")
630+
631+
if compat:
632+
all_entries = compat.process_entries(all_entries)
633+
634+
pd = PhaseDiagram(all_entries)
635+
for entry in all_entries:
636+
if str(entry.entry_id).startswith("mod"):
637+
for dct in rester_ehulls:
638+
if dct["entry_id"] == entry.entry_id:
639+
data = dct
640+
break
641+
assert pd.get_e_above_hull(entry) == pytest.approx(
642+
data["e_above_hull"]
643+
)
642644

643645
@pytest.mark.parametrize(
644646
"mpid, working_ion, thermo_type",

0 commit comments

Comments
 (0)