diff --git a/docs/modules.md b/docs/modules.md index a5d646b3..89b547ef 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -334,14 +334,22 @@ on the data of the snapshot. #### Snapshots Correlation Hook +There are two correlation hooks available: + +- [`register_correlation_hook`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook] +- [`register_correlation_hook_with_master_record`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook_with_master_record] + +Both do the same thing, but as the naming suggests, the latter also provides a master record. The [`register_correlation_hook`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook] method expects a callable with the following signature: `Callable[[str, dict], Union[None, list[DataPointTask]]]`, where the first argument is the entity type, and the second is a dict containing the current values of the entity and its linked entities. -The method can optionally return a list of DataPointTask objects to be inserted into the system. +The [`register_correlation_hook_with_master_record`][dp3.common.callback_registrar.CallbackRegistrar.register_correlation_hook_with_master_record] method expects a callable with the following signature: +`Callable[[str, dict, dict], Union[None, list[DataPointTask]]]` - the first two arguments are identical (entity type and dict with current values), but there is also a third argument: a dictionary of values stored in the master record of the entity. +The method (applicable to both variants) can optionally return a list of `DataPointTask` objects to be inserted into the system. As correlation hooks can depend on each other, the hook inputs and outputs must be specified -using the depends_on and may_change arguments. Both arguments are lists of lists of strings, +using the `depends_on` and `may_change` arguments. Both arguments are lists of lists of strings, where each list of strings is a path from the specified entity type to individual attributes (even on linked entities). For example, if the entity type is `test_entity_type`, and the hook depends on the attribute `test_attr_type1`, the path is simply `[["test_attr_type1"]]`. If the hook depends on the attribute `test_attr_type1` @@ -351,9 +359,21 @@ of an entity linked using `test_attr_link`, the path will be `[["test_attr_link def correlation_hook(entity_type: str, values: dict): ... +def correlation_hook_with_master_record(entity_type: str, values: dict, master_record: dict): + ... + +# Without master record registrar.register_correlation_hook( correlation_hook, "test_entity_type", [["test_attr_type1"]], [["test_attr_type2"]] ) + +# Or with master record +registrar.register_correlation_hook_with_master_record( + correlation_hook_with_master_record, + "test_entity_type", + [["test_attr_type1"]], + [["test_attr_type2"]] +) ``` The order of running callbacks is determined automatically, based on the dependencies. diff --git a/dp3/common/callback_registrar.py b/dp3/common/callback_registrar.py index ba13c299..e885610b 100644 --- a/dp3/common/callback_registrar.py +++ b/dp3/common/callback_registrar.py @@ -1,5 +1,5 @@ import logging -from functools import partial +from functools import partial, wraps from logging import Logger from typing import Callable, Union @@ -365,6 +365,47 @@ def register_correlation_hook( may_change: each item should specify an attribute that `hook` may change. specification format is identical to `depends_on`. + Raises: + ValueError: On failure of specification validation. + """ + + # Ignore master record for this variant of the hook + @wraps(hook) + def wrapped_hook(e: str, s: dict, _m: dict): + return hook(e, s) + + self._snap_shooter.register_correlation_hook( + wrapped_hook, entity_type, depends_on, may_change + ) + + def register_correlation_hook_with_master_record( + self, + hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], + entity_type: str, + depends_on: list[list[str]], + may_change: list[list[str]], + ): + """ + Registers passed hook to be called during snapshot creation. + + Identical to `register_correlation_hook`, but the hook also receives the master record. + + Binds hook to specified entity_type (though same hook can be bound multiple times). + + `entity_type` and attribute specifications are validated, `ValueError` is raised on failure. + + Args: + hook: `hook` callable should have the signature + `hook(entity_type: str, current_values: dict, master_record: dict)`. + where `current_values` includes linked entities. + Can optionally return a list of DataPointTask objects to perform. + entity_type: specifies entity type + depends_on: each item should specify an attribute that is depended on + in the form of a path from the specified entity_type to individual attributes + (even on linked entities). + may_change: each item should specify an attribute that `hook` may change. + specification format is identical to `depends_on`. + Raises: ValueError: On failure of specification validation. """ diff --git a/dp3/snapshots/snapshooter.py b/dp3/snapshots/snapshooter.py index 713d4426..46d70770 100644 --- a/dp3/snapshots/snapshooter.py +++ b/dp3/snapshots/snapshooter.py @@ -185,7 +185,7 @@ def register_timeseries_hook( def register_correlation_hook( self, - hook: Callable[[str, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -193,13 +193,16 @@ def register_correlation_hook( """ Registers passed hook to be called during snapshot creation. + Common implementation for hooks with and without master record. + Binds hook to specified entity_type (though same hook can be bound multiple times). `entity_type` and attribute specifications are validated, `ValueError` is raised on failure. Args: hook: `hook` callable should expect entity type as str - and its current values, including linked entities, as dict + and its current values, including linked entities, as dict; + and its master record as dict. Can optionally return a list of DataPointTask objects to perform. entity_type: specifies entity type depends_on: each item should specify an attribute that is depended on @@ -457,9 +460,12 @@ def make_linkless_snapshot(self, entity_type: str, master_record: dict, time: da self.run_timeseries_processing(entity_type, master_record) values = self.get_values_at_time(entity_type, master_record, time) self.add_mirrored_links(entity_type, values) - entity_values = {(entity_type, master_record["_id"]): values} + entity_id = master_record["_id"] + entity_values = {(entity_type, entity_id): values} - tasks = self._correlation_hooks.run(entity_values) + tasks = self._correlation_hooks.run( + entity_values, {(entity_type, entity_id): master_record} + ) for task in tasks: self.task_queue_writer.put_task(task) @@ -499,6 +505,7 @@ def make_snapshot(self, task: Snapshot): The resulting snapshots are saved into DB. """ entity_values = {} + entity_master_records = {} for entity_type, entity_id in task.entities: record = self.db.get_master_record(entity_type, entity_id) or {"_id": entity_id} if not self.config.keep_empty and len(record) == 1: @@ -508,9 +515,10 @@ def make_snapshot(self, task: Snapshot): values = self.get_values_at_time(entity_type, record, task.time) self.add_mirrored_links(entity_type, values) entity_values[entity_type, entity_id] = values + entity_master_records[entity_type, entity_id] = record self.link_loaded_entities(entity_values) - created_tasks = self._correlation_hooks.run(entity_values) + created_tasks = self._correlation_hooks.run(entity_values, entity_master_records) for created_task in created_tasks: self.task_queue_writer.put_task(created_task) diff --git a/dp3/snapshots/snapshot_hooks.py b/dp3/snapshots/snapshot_hooks.py index d58e1c1e..dcf437f5 100644 --- a/dp3/snapshots/snapshot_hooks.py +++ b/dp3/snapshots/snapshot_hooks.py @@ -84,7 +84,7 @@ def __init__(self, log: logging.Logger, model_spec: ModelSpec, elog: EventGroupT def register( self, - hook: Callable[[str, dict], Union[None, list[DataPointTask]]], + hook: Callable[[str, dict, dict], Union[None, list[DataPointTask]]], entity_type: str, depends_on: list[list[str]], may_change: list[list[str]], @@ -97,8 +97,9 @@ def register( If entity_type and attribute specifications are validated and ValueError is raised on failure. Args: - hook: `hook` callable should expect entity type as str - and its current values, including linked entities, as dict. + hook: `hook` callable should expect entity type as str; + its current values, including linked entities, as dict; + and its master record as dict. Can optionally return a list of DataPointTask objects to perform. entity_type: specifies entity type depends_on: each item should specify an attribute that is depended on @@ -191,7 +192,7 @@ def _resolve_entities_in_path(self, base_entity: str, path: list[str]) -> list[t position = entity_attributes[position.relation_to] return resolved_path - def run(self, entities: dict) -> list[DataPointTask]: + def run(self, entities: dict, entity_master_records: dict) -> list[DataPointTask]: """Runs registered hooks.""" entity_types = {etype for etype, _ in entities} hook_subset = [ @@ -200,8 +201,11 @@ def run(self, entities: dict) -> list[DataPointTask]: topological_order = self._dependency_graph.topological_order hook_subset.sort(key=lambda x: topological_order.index(x[0])) entities_by_etype = defaultdict(dict) + entity_master_records_by_etype = defaultdict(dict) for (etype, eid), values in entities.items(): entities_by_etype[etype][eid] = values + for (etype, eid), mr in entity_master_records.items(): + entity_master_records_by_etype[etype][eid] = mr created_tasks = [] @@ -209,9 +213,10 @@ def run(self, entities: dict) -> list[DataPointTask]: for hook_id, hook, etype in hook_subset: short_id = hook_id if len(hook_id) < 160 else self._short_hook_ids[hook_id] for eid, entity_values in entities_by_etype[etype].items(): + entity_master_record = entity_master_records_by_etype[etype].get(eid, {}) self.log.debug("Running hook %s on entity %s", short_id, eid) try: - tasks = hook(etype, entity_values) + tasks = hook(etype, entity_values, entity_master_record) if tasks is not None and tasks: created_tasks.extend(tasks) except Exception as e: diff --git a/tests/modules/test_module.py b/tests/modules/test_module.py index 3640a3e4..a7079f9c 100644 --- a/tests/modules/test_module.py +++ b/tests/modules/test_module.py @@ -14,6 +14,22 @@ def modify_value(_: str, record: dict, attr: str, value): record[attr] = value +def use_master_record( + _: str, record: dict, master_record: dict, target_attr: str, source_attr: str +): + """Hook that uses master record to copy a value from master to snapshot. + + Only applies when source attribute in master record has value starting with "master_" + to avoid interfering with other test cases. + """ + if source_attr in master_record: + # Get the value from master record + master_value = master_record[source_attr].get("v", None) + if master_value is not None and str(master_value).startswith("master_"): + # Append a suffix to demonstrate master record was used + record[target_attr] = f"{master_value}_from_master" + + dummy_hook_abc = update_wrapper(partial(modify_value, attr="data2", value="abc"), modify_value) dummy_hook_def = update_wrapper(partial(modify_value, attr="data1", value="def"), modify_value) @@ -67,3 +83,15 @@ def __init__( depends_on=[], may_change=[["data1"]], ) + + # Testing register_correlation_hook_with_master_record + # This hook should copy data1 from master record to data2 with a suffix + registrar.register_correlation_hook_with_master_record( + update_wrapper( + partial(use_master_record, target_attr="data4", source_attr="data3"), + use_master_record, + ), + "A", + depends_on=[], + may_change=[["data4"]], + ) diff --git a/tests/test_api/test_snapshots.py b/tests/test_api/test_snapshots.py index 8bdd4119..ffae83ff 100644 --- a/tests/test_api/test_snapshots.py +++ b/tests/test_api/test_snapshots.py @@ -42,6 +42,9 @@ def make_dp(type, id, attr, v, time=False): make_dp("C", "c1", "ds", {"eid": "d1"}, time=True), make_dp("C", "c1", "data1", "inita"), make_dp("C", "c1", "data2", "inita"), + # For test_master_record_hook (A-423) + make_dp("A", 423, "data3", "master_test"), + make_dp("A", 423, "data4", "placeholder"), ] res = cls.push_datapoints(entity_datapoints) if res.status_code != 200: @@ -79,3 +82,17 @@ def test_hook_dependency_value_forwarding(self): for snapshot in data.snapshots: self.assertEqual(snapshot["data1"], "modifd") self.assertEqual(snapshot["data2"], "modifc") + + def test_master_record_hook(self): + """ + Test that hooks registered via register_correlation_hook_with_master_record + correctly receive the master record parameter. + """ + # Entity A-423 has data1="master_test" in its master record + # The master record hook should copy data1 from master and append "_from_master" + data = self.get_entity_data("entity/A/423", EntityEidData) + self.assertGreater(len(data.snapshots), 0) + for snapshot in data.snapshots: + self.assertEqual(snapshot["data3"], "master_test") + # The hook should have set data4 to data3 from master record + "_from_master" + self.assertEqual(snapshot["data4"], "master_test_from_master") diff --git a/tests/test_common/test_snapshots.py b/tests/test_common/test_snapshots.py index 2d67a421..9adff033 100644 --- a/tests/test_common/test_snapshots.py +++ b/tests/test_common/test_snapshots.py @@ -14,7 +14,7 @@ from dp3.snapshots.snapshot_hooks import SnapshotCorrelationHookContainer -def modify_value(_: str, record: dict, attr: str, value): +def modify_value(_: str, record: dict, _master_record: dict, attr: str, value): record[attr] = value @@ -37,7 +37,7 @@ def test_basic_function(self): hook=dummy_hook_abc, entity_type="A", depends_on=[["data1"]], may_change=[["data2"]] ) values = {} - self.container.run({("A", "a1"): values}) + self.container.run({("A", "a1"): values}, {}) self.assertEqual(values["data2"], "abc") def test_circular_dependency_error(self): diff --git a/tests/test_config/db_entities/A.yml b/tests/test_config/db_entities/A.yml index d602cd2f..7ebfb92c 100644 --- a/tests/test_config/db_entities/A.yml +++ b/tests/test_config/db_entities/A.yml @@ -17,6 +17,18 @@ attribs: type: plain data_type: string + data3: + name: data3 + description: entity data + type: plain + data_type: string + + data4: + name: data4 + description: entity data + type: plain + data_type: string + as: name: As description: Link to other A entities