Skip to content

Commit 272d2fa

Browse files
committed
add optional run_type kwarg to get_trajectory
1 parent 6ce233b commit 272d2fa

1 file changed

Lines changed: 13 additions & 2 deletions

File tree

mp_api/client/routes/materials/tasks.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from emmet.core.mpid import MPID, AlphaID
99
from emmet.core.tasks import CoreTaskDoc
1010
from emmet.core.trajectory import RelaxTrajectory
11+
from emmet.core.vasp.calc_types import RunType
1112

1213
from mp_api.client.core import BaseRester, MPRestError
1314
from mp_api.client.core.utils import validate_ids
@@ -24,18 +25,28 @@ class TaskRester(BaseRester):
2425
primary_key: str = "task_id"
2526
delta_backed = True
2627

27-
def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]:
28+
def get_trajectory(
29+
self, task_id: MPID | AlphaID | str, run_type: str | RunType | None = None
30+
) -> dict[str, Any]:
2831
"""Returns a Trajectory object containing the geometry of the
2932
material throughout a calculation. This is most useful for
3033
observing how a material relaxes during a geometry optimization.
3134
3235
Args:
3336
task_id (str, MPID, AlphaID): Task ID
37+
run_type (str, RunType): Task run type
3438
3539
Returns:
3640
dict representing emmet.core.trajectory.RelaxTrajectory
3741
"""
3842
as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1]
43+
44+
predicate = (
45+
f"WHERE run_type='{str(run_type)}' AND identifier='{as_alpha}'"
46+
if run_type
47+
else f"WHERE identifier='{as_alpha}'"
48+
)
49+
3950
traj_tbl = DeltaTable(
4051
"s3a://materialsproject-parsed/core/trajectories/",
4152
storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"},
@@ -48,7 +59,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]:
4859
f"""
4960
SELECT *
5061
FROM traj
51-
WHERE identifier='{as_alpha}'
62+
{predicate};
5263
"""
5364
)
5465
.read_all()

0 commit comments

Comments
 (0)