diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 6b35fabdc..7df557efc 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -304,12 +304,17 @@ def find_structure( use_document_model=False, ).get("data") - if len(results) > 1: # type: ignore + if not results: + return [] + + material_ids = validate_ids([doc["material_id"] for doc in results]) + + if len(material_ids) > 1: # type: ignore if not allow_multiple_results: raise ValueError( "Multiple matches found for this combination of tolerances, but " "`allow_multiple_results` set to False." ) - return results # type: ignore + return material_ids # type: ignore - return results[0]["material_id"] if (results and results[0]) else [] + return material_ids[0] diff --git a/tests/test_mprester.py b/tests/test_mprester.py index dbb5e3950..10c98e008 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -96,16 +96,14 @@ def test_get_structures(self, mpr): structs = mpr.get_structures("Mn-O", final=False) assert len(structs) > 0 - @pytest.mark.skip(reason="Endpoint issues") def test_find_structure(self, mpr): path = os.path.join(MAPIClientSettings().TEST_FILES, "Si_mp_149.cif") - with open(path) as file: - data = mpr.find_structure(path) - assert len(data) > 0 + data = mpr.find_structure(path) + assert isinstance(data, str) and data == "mp-149" - s = CifParser(file).get_structures()[0] - data = mpr.find_structure(s) - assert len(data) > 0 + s = CifParser(path).get_structures()[0] + data = mpr.find_structure(s) + assert isinstance(data, str) and data == "mp-149" def test_get_bandstructure_by_material_id(self, mpr): bs = mpr.get_bandstructure_by_material_id("mp-149")