88from emmet .core .mpid import MPID , AlphaID
99from emmet .core .tasks import CoreTaskDoc
1010from emmet .core .trajectory import RelaxTrajectory
11+ from emmet .core .vasp .calc_types import RunType
1112
1213from mp_api .client .core import BaseRester , MPRestError
1314from mp_api .client .core .utils import validate_ids
@@ -24,18 +25,28 @@ class TaskRester(BaseRester):
2425 primary_key : str = "task_id"
2526 delta_backed = True
2627
27- def get_trajectory (self , task_id : MPID | AlphaID | str ) -> dict [str , Any ]:
28+ def get_trajectory (
29+ self , task_id : MPID | AlphaID | str , run_type : str | RunType | None = None
30+ ) -> dict [str , Any ]:
2831 """Returns a Trajectory object containing the geometry of the
2932 material throughout a calculation. This is most useful for
3033 observing how a material relaxes during a geometry optimization.
3134
3235 Args:
3336 task_id (str, MPID, AlphaID): Task ID
37+ run_type (str, RunType): Task run type
3438
3539 Returns:
3640 dict representing emmet.core.trajectory.RelaxTrajectory
3741 """
3842 as_alpha = str (AlphaID (task_id , padlen = 8 )).split ("-" )[- 1 ]
43+
44+ predicate = (
45+ f"WHERE run_type='{ str (run_type )} ' AND identifier='{ as_alpha } '"
46+ if run_type
47+ else f"WHERE identifier='{ as_alpha } '"
48+ )
49+
3950 traj_tbl = DeltaTable (
4051 "s3a://materialsproject-parsed/core/trajectories/" ,
4152 storage_options = {"AWS_SKIP_SIGNATURE" : "true" , "AWS_REGION" : "us-east-1" },
@@ -48,7 +59,7 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]:
4859 f"""
4960 SELECT *
5061 FROM traj
51- WHERE identifier=' { as_alpha } '
62+ { predicate } ;
5263 """
5364 )
5465 .read_all ()
0 commit comments