Skip to content

ENH: Serialize tt_function #73

@MImmesberger

Description

@MImmesberger

Is your feature request related to a problem?

The output of main(main_target="tt_function", ...) cannot be serialized using either pickle or cloudpickle. This implies that users cannot pickle objects that take tt_functions as inputs. For example, the tt_function cannot be a product of the workflow manager pytask or be an input of the model instance of pylcm if the workflow is handled via pytask.

Cause: The load_module() function registers policy modules with relative paths from the policy ROOT_PATH (e.g., orc_hunting_bounty.orc_hunting_bounty instead of mettsim.middle_earth.orc_hunting_bounty.orc_hunting_bounty).

Potential solutions

Making this work with plain pickle is probably out of the question.

Using cloudpickle, I had some success with cloudpickle.register_pickle_by_value (see PR #72). The issue is that this approach is heavily context-dependent: it works when run as a plain Python file, but it fails when run in a pytest environment or in a Jupyter notebook. Here is the corresponding function:

def cloudpickle_main_output(obj: object, root=middle_earth.ROOT_PATH) -> bytes:
    """Cloudpickle ``obj`` after registering modules under ``root`` by value.

    Every entry in ``sys.modules`` whose ``__file__`` lies under ``root`` is
    passed to ``cloudpickle.register_pickle_by_value`` before ``obj`` is
    dumped. The registrations are not undone afterwards, so they remain
    active for later ``cloudpickle`` calls in the same process.

    Parameters
    ----------
    obj
        The object to pickle.
    root
        Directory used to select modules; a module is registered when its
        ``__file__`` is relative to this path.

    Returns
    -------
    The bytes produced by ``cloudpickle.dumps(obj)``.
    """
    # Iterate over a snapshot: ``sys.modules`` may change size while we look
    # at it (attribute access on a module can trigger imports, and other
    # threads may import concurrently), which would raise ``RuntimeError``
    # when iterating the live dictionary view.
    for mod in tuple(sys.modules.values()):
        # ``sys.modules`` may hold ``None`` placeholders; skip them.
        if mod is None:
            continue
        mod_file = getattr(mod, "__file__", None)
        if mod_file and Path(mod_file).is_relative_to(root):
            cloudpickle.register_pickle_by_value(mod)

    return cloudpickle.dumps(obj)

dill works in notebooks using the above function (replacing cloudpickle.register_pickle_by_value with dill.register).

Reproducer

1. Pickle failure

import pickle                                                                                                                                 
import cloudpickle                                                                                                                            
import numpy                                                                                                                                  
from mettsim import middle_earth                                                                                                              
from ttsim import main, OrigPolicyObjects, TTTargets                                                                                          
from ttsim.main_args import InputData                                                                                                         
                                                                                                                                              
# Prepare input data: a flat mapping from tuple keys to numpy arrays.
# Every array has two entries (presumably one per person -- confirm against
# the ttsim input conventions).
data = {
    ("age",): numpy.array([30, 30]),
    ("kin_id",): numpy.array([0, 0]),
    ("p_id",): numpy.array([0, 1]),
    ("p_id_parent_1",): numpy.array([-1, -1]),
    ("p_id_parent_2",): numpy.array([-1, -1]),
    ("p_id_spouse",): numpy.array([1, 0]),
    ("parent_is_noble",): numpy.array([False, False]),
    ("payroll_tax", "child_tax_credit", "p_id_recipient"): numpy.array([-1, -1]),
    ("payroll_tax", "income", "gross_wage_y"): numpy.array([10000.0, 0.0]),
    ("wealth",): numpy.array([0.0, 0.0]),
}

# Get the tt_function for the middle_earth policy environment, requesting
# the single target payroll_tax -> amount_y.
tt_func = main(
    main_target="tt_function",
    policy_date_str="2025-01-01",
    input_data=InputData.flat(data),
    orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    tt_targets=TTTargets.tree({"payroll_tax": {"amount_y": None}}),
)

# Both fail. When this is executed top to bottom as a script, the first
# exception stops execution before the second call is reached; run the two
# calls separately to observe both errors.
pickle.dumps(tt_func)      # PicklingError: Can't pickle local object 'tt_function.<locals>.wrapper'
cloudpickle.dumps(tt_func) # PicklingError: Can't pickle <class 'orc_hunting_bounty...'>: No module named 'orc_hunting_bounty'

2. Fix that runs in a Python file but not via pytest or in a notebook

import numpy
import cloudpickle
from mettsim import middle_earth
from ttsim import main, OrigPolicyObjects, TTTargets, cloudpickle_main_output
from ttsim.main_args import InputData
from pathlib import Path
import sys

def cloudpickle_main_output(obj: object, root=middle_earth.ROOT_PATH) -> bytes:
    """Cloudpickle an object that references policy modules.

    Policy modules are loaded with non-importable relative paths, which causes
    cloudpickle to fail when it tries to verify classes can be imported.
    This function registers modules for "pickle by value" (embedding class
    definitions directly in the pickle) before pickling the object.

    The registration is a process-wide side effect: this function does not
    unregister the modules again after pickling.

    Parameters
    ----------
    obj
        The object to pickle, typically a tt_function or other main output.
    root
        The ROOT_PATH of the policy environment. Every loaded module whose
        ``__file__`` is relative to this path is registered.

    Returns
    -------
    The pickled bytes.

    Example
    -------
    >>> from mettsim import middle_earth
    >>> from ttsim import main, cloudpickle_main_output
    >>>
    >>> tt_func = main(main_target="tt_function", ...)
    >>> pickled = cloudpickle_main_output(tt_func, middle_earth.ROOT_PATH)
    """
    # Iterate over a snapshot: ``sys.modules`` may change size while we look
    # at it (attribute access on a module can trigger imports, and other
    # threads may import concurrently), which would raise ``RuntimeError``
    # when iterating the live dictionary view.
    for mod in tuple(sys.modules.values()):
        # ``sys.modules`` may hold ``None`` placeholders; skip them.
        if mod is None:
            continue
        mod_file = getattr(mod, "__file__", None)
        if mod_file and Path(mod_file).is_relative_to(root):
            cloudpickle.register_pickle_by_value(mod)

    return cloudpickle.dumps(obj)

# Input data: a flat mapping from tuple keys to numpy arrays with two
# entries each (presumably one per person -- confirm against the ttsim
# input conventions).
data = {
    ("age",): numpy.array([30, 30]),
    ("kin_id",): numpy.array([0, 0]),
    ("p_id",): numpy.array([0, 1]),
    ("p_id_parent_1",): numpy.array([-1, -1]),
    ("p_id_parent_2",): numpy.array([-1, -1]),
    ("p_id_spouse",): numpy.array([1, 0]),
    ("parent_is_noble",): numpy.array([False, False]),
    ("payroll_tax", "child_tax_credit", "p_id_recipient"):
        numpy.array([-1, -1]),
    ("payroll_tax", "income", "gross_wage_y"): numpy.array([10000.0, 0.0]),
    ("wealth",): numpy.array([0.0, 0.0]),
}
# Build the function to be serialized (numpy backend).
tt_func = main(
    main_target="tt_function",
    policy_date_str="2025-01-01",
    input_data=InputData.flat(data),
    orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    tt_targets=TTTargets.tree({"payroll_tax": {"amount_y": None}}),
    backend="numpy",
)
# Same arguments as above, but requesting the processed data that is later
# passed to tt_func.
processed_data = main(
    main_target="processed_data",
    policy_date_str="2025-01-01",
    input_data=InputData.flat(data),
    orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    tt_targets=TTTargets.tree({"payroll_tax": {"amount_y": None}}),
    backend="numpy",
)
# Round trip: pickle with the by-value helper, then load with cloudpickle.
pickled = cloudpickle_main_output(tt_func, middle_earth.ROOT_PATH)
unpickled_func = cloudpickle.loads(pickled)
# Call the original and the restored function on the same input.
# NOTE(review): the two results are computed but never compared here; add an
# equality check to actually confirm the round trip preserves behavior.
original_result = tt_func(processed_data)
restored_result = unpickled_func(processed_data)

3. Dill works in a notebook

import numpy
from mettsim import middle_earth
from ttsim import main, OrigPolicyObjects, TTTargets, cloudpickle_main_output
from ttsim.main_args import InputData
from pathlib import Path
import sys
import dill

def pickle_main_output(obj: object, root=middle_earth.ROOT_PATH) -> bytes:
    """Serialize ``obj`` with dill after visiting the modules under ``root``.

    Parameters
    ----------
    obj
        The object to pickle, e.g. the tt_function returned by ``main``.
    root
        Directory used to select modules; a module is visited when its
        ``__file__`` is relative to this path.

    Returns
    -------
    The bytes produced by ``dill.dumps(obj)``.
    """
    # Iterate over a snapshot: ``sys.modules`` may change size while we look
    # at it (attribute access on a module can trigger imports, and other
    # threads may import concurrently), which would raise ``RuntimeError``
    # when iterating the live dictionary view.
    for mod in tuple(sys.modules.values()):
        # ``sys.modules`` may hold ``None`` placeholders; skip them.
        if mod is None:
            continue
        mod_file = getattr(mod, "__file__", None)
        if mod_file and Path(mod_file).is_relative_to(root):
            # NOTE(review): ``dill.register`` looks like a decorator factory
            # for Pickler dispatch entries keyed by type. Its return value is
            # discarded here, so this call may have no effect on how ``obj``
            # is pickled, and dill may succeed on its own -- TODO confirm.
            dill.register(mod)

    return dill.dumps(obj)

# Input data: a flat mapping from tuple keys to numpy arrays with two
# entries each (presumably one per person -- confirm against the ttsim
# input conventions).
data = {
    ("age",): numpy.array([30, 30]),
    ("kin_id",): numpy.array([0, 0]),
    ("p_id",): numpy.array([0, 1]),
    ("p_id_parent_1",): numpy.array([-1, -1]),
    ("p_id_parent_2",): numpy.array([-1, -1]),
    ("p_id_spouse",): numpy.array([1, 0]),
    ("parent_is_noble",): numpy.array([False, False]),
    ("payroll_tax", "child_tax_credit", "p_id_recipient"):
        numpy.array([-1, -1]),
    ("payroll_tax", "income", "gross_wage_y"): numpy.array([10000.0, 0.0]),
    ("wealth",): numpy.array([0.0, 0.0]),
}
# Build the function to be serialized (numpy backend).
tt_func = main(
    main_target="tt_function",
    policy_date_str="2025-01-01",
    input_data=InputData.flat(data),
    orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    tt_targets=TTTargets.tree({"payroll_tax": {"amount_y": None}}),
    backend="numpy",
)
# Same arguments as above, but requesting the processed data that is later
# passed to tt_func.
processed_data = main(
    main_target="processed_data",
    policy_date_str="2025-01-01",
    input_data=InputData.flat(data),
    orig_policy_objects=OrigPolicyObjects.root(middle_earth.ROOT_PATH),
    tt_targets=TTTargets.tree({"payroll_tax": {"amount_y": None}}),
    backend="numpy",
)
# Round trip: pickle with the dill-based helper, then load with dill.
pickled = pickle_main_output(tt_func, middle_earth.ROOT_PATH)
unpickled_func = dill.loads(pickled)
# Call the original and the restored function on the same input.
# NOTE(review): the two results are computed but never compared here; add an
# equality check to actually confirm the round trip preserves behavior.
original_result = tt_func(processed_data)
restored_result = unpickled_func(processed_data)

Metadata

Metadata

Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions