Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions mp_api/client/routes/materials/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def search( # noqa: D417
poisson_ratio: tuple[float, float] | None = None,
possible_species: list[str] | None = None,
shape_factor: tuple[float, float] | None = None,
spacegroup_number: int | None = None,
spacegroup_symbol: str | None = None,
spacegroup_number: int | list[int] | None = None,
spacegroup_symbol: str | list[str] | None = None,
surface_energy_anisotropy: tuple[float, float] | None = None,
theoretical: bool | None = None,
total_energy: tuple[float, float] | None = None,
Expand All @@ -81,7 +81,7 @@ def search( # noqa: D417
band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider.
chemsys (str, List[str]): A chemical system or list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
crystal_system (CrystalSystem): Crystal system of material.
crystal_system (CrystalSystem or list[CrystalSystem]): Crystal system(s) of the materials.
density (Tuple[float,float]): Minimum and maximum density to consider.
deprecated (bool): Whether the material is tagged as deprecated.
e_electronic (Tuple[float,float]): Minimum and maximum electronic dielectric constant to consider.
Expand Down Expand Up @@ -128,8 +128,8 @@ def search( # noqa: D417
poisson_ratio (Tuple[float,float]): Minimum and maximum value to consider for Poisson's ratio.
possible_species (List(str)): List of element symbols appended with oxidation states. (e.g. Cr2+,O2-)
shape_factor (Tuple[float,float]): Minimum and maximum shape factor values to consider.
spacegroup_number (int): Space group number of material.
spacegroup_symbol (str): Space group symbol of the material in international short symbol notation.
spacegroup_number (int or list[int]): Space group number(s) of materials.
spacegroup_symbol (str or list[str]): Space group symbol(s) of the materials in international short symbol notation.
surface_energy_anisotropy (Tuple[float,float]): Minimum and maximum surface energy anisotropy values
to consider.
theoretical: (bool): Whether the material is theoretical.
Expand Down Expand Up @@ -319,13 +319,25 @@ def _csrc(x):
if possible_species is not None:
query_params.update({"possible_species": ",".join(possible_species)})

query_params.update(
{
"crystal_system": crystal_system,
"spacegroup_number": spacegroup_number,
"spacegroup_symbol": spacegroup_symbol,
}
)
symm_cardinality = {
"crystal_system": 7,
"spacegroup_number": 230,
"spacegroup_symbol": 230,
}
for k, cardinality in symm_cardinality.items():
if hasattr(symm_vals := locals().get(k), "__len__") and not isinstance(
symm_vals, str
):
if len(symm_vals) < cardinality // 2:
query_params.update({k: ",".join(str(v) for v in symm_vals)})
else:
raise ValueError(
f"Querying `{k}` by a list of values is only "
f"supported for up to {cardinality//2 - 1} values. "
f"For your query, retrieve all data first and then filter on `{k}`."
)
else:
query_params.update({k: symm_vals})

if is_stable is not None:
query_params.update({"is_stable": is_stable})
Expand Down
40 changes: 40 additions & 0 deletions tests/materials/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,43 @@ def test_client():
custom_field_tests=custom_field_tests,
sub_doc_fields=[],
)


@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_list_like_input():
search_method = SummaryRester().search

# These are specifically chosen for the low representation in MP
# Specifically, these are the four least-represented space groups
# with at least one member
sparse_sgn = (93, 101, 172, 179, 211)
docs_by_number = search_method(
spacegroup_number=sparse_sgn, fields=["material_id", "symmetry"]
)
assert {doc.symmetry.number for doc in docs_by_number} == set(sparse_sgn)

sparse_symbols = {doc.symmetry.symbol for doc in docs_by_number}
docs_by_symbol = search_method(
spacegroup_symbol=sparse_symbols, fields=["material_id", "symmetry"]
)
assert {doc.symmetry.symbol for doc in docs_by_symbol} == sparse_symbols
assert {doc.material_id for doc in docs_by_symbol} == {
doc.material_id for doc in docs_by_number
}

# also chosen for very low document count
crys_sys = ["Hexagonal", "Cubic"]
assert {
doc.symmetry.crystal_system
for doc in search_method(elements=["Ar"], crystal_system=crys_sys)
} == set(crys_sys)

# should fail - we don't support querying by so many list values
with pytest.raises(ValueError, match="retrieve all data first and then filter"):
_ = search_method(spacegroup_number=list(range(1, 231)))

with pytest.raises(ValueError, match="retrieve all data first and then filter"):
_ = search_method(spacegroup_number=["null" for _ in range(230)])

with pytest.raises(ValueError, match="retrieve all data first and then filter"):
_ = search_method(crystal_system=list(CrystalSystem))