SKLModelHandler drops extra keyword arguments in predict method #209

@robert-norberg

Description

When an sklearn-compatible estimator exposes a method with extra keyword arguments (e.g. predict(X, my_param=100)), the Snowflake Model Registry's SKLModelHandler wraps it in a closure that only accepts (self, X), so the kwargs are dropped from the wrapped method's signature. As a result, ParamSpec parameters registered in the model signature are never forwarded to the underlying estimator, and passing them at inference time raises a TypeError (see the traceback at the end of the example).
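To make the failure mode concrete, here is a minimal sketch that reproduces the same pattern without Snowflake. The helper `fn_factory_fixed_signature` is hypothetical and only imitates the shape suggested by the traceback (an inner `fn(self, X)` created by a factory); it is not the actual SKLModelHandler code.

```python
import numpy as np


class EstimatorWithKwarg:
    """Stand-in estimator whose predict() takes an extra keyword argument."""

    def predict(self, X, *, my_param=365):
        return np.full(len(X), fill_value=my_param, dtype=float)


def fn_factory_fixed_signature(raw_model, method_name):
    """Hypothetical wrapper: the inner function accepts only (self, X)."""

    def fn(self, X):
        # No **kwargs here, so extra parameters cannot reach the estimator.
        return getattr(raw_model, method_name)(X)

    return fn


wrapped = fn_factory_fixed_signature(EstimatorWithKwarg(), "predict")

# Without params, the estimator's own default (365) is used.
default_result = wrapped(None, [1, 2])

# With params, the call fails before the estimator is ever reached.
try:
    wrapped(None, [1, 2], my_param=100)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

In this sketch the wrapper, not the estimator, rejects `my_param`, which matches the `fn() got an unexpected keyword argument` message in the report.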

Minimal reproducible example:

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from snowflake.ml.model import model_signature
from snowflake.ml.model.model_signature import DataType, ParamSpec
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs({<redacted>}).create()
registry = Registry(session=session)

class EstimatorWithKwarg(BaseEstimator):
    """Trivial estimator whose predict() accepts an extra `my_param` kwarg."""

    def fit(self, X, y=None):
        return self

    def predict(self, X, *, my_param=365):
        # Return `my_param` to verify it was forwarded
        return np.full(len(X), fill_value=my_param, dtype=float)

estimator = EstimatorWithKwarg().fit(pd.DataFrame({"a": [1, 2, 3]}))

# Verify the estimator itself works with my_param kwarg
X = pd.DataFrame({"a": [10, 20]})
y_hat = estimator.predict(X, my_param=180)
y_hat
# Output: array([180., 180.])

predict_sig = model_signature.infer_signature(
    X,
    y_hat,
    output_feature_names=["pred"],
    params=[ParamSpec(name="my_param", dtype=DataType.INT32, default_value=100)],
)

registry.log_model(
    model=estimator,
    model_name="deleteme",
    sample_input_data=X,
    signatures={"predict": predict_sig},
)
model_ref = registry.get_model("deleteme")
mv = model_ref.version("default")
mv.run(X, function_name="predict", params={"my_param": 100})
SnowparkSQLException: (1304): 01c2d5a2-081b-3567-0026-55033fabff02: 100357 (P0000): Python Interpreter Error:
Traceback (most recent call last):
  File "/home/udf/10789520488747366/predict.py", line 91, in infer
    predictions_df = runner(input_df, **method_params)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: SKLModelHandler.convert_as_custom_model.<locals>._create_custom_model.<locals>.fn_factory.<locals>.fn() got an unexpected keyword argument 'my_param' in function PREDICT with handler predict.infer
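One possible direction for a fix, sketched under assumptions: the wrapper could accept `**kwargs` and pass them through to the estimator method, after checking them against the method's signature. The name `fn_factory_forwarding` is hypothetical, and this is not a patch against the real handler, whose internals may differ.

```python
import inspect

import numpy as np


class EstimatorWithKwarg:
    """Stand-in estimator whose predict() takes an extra keyword argument."""

    def predict(self, X, *, my_param=365):
        return np.full(len(X), fill_value=my_param, dtype=float)


def fn_factory_forwarding(raw_model, method_name):
    """Hypothetical wrapper that forwards keyword arguments to the estimator."""
    target = getattr(raw_model, method_name)
    sig_params = inspect.signature(target).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig_params.values()
    )

    def fn(self, X, **kwargs):
        # Reject names the estimator method does not declare, unless it
        # takes **kwargs itself, so typos fail with a clear message.
        if not accepts_var_kw:
            unknown = sorted(set(kwargs) - set(sig_params))
            if unknown:
                raise TypeError(
                    f"{method_name}() does not accept parameters: {unknown}"
                )
        return target(X, **kwargs)

    return fn


forwarding = fn_factory_forwarding(EstimatorWithKwarg(), "predict")
forwarded_result = forwarding(None, [10, 20], my_param=100)
fallback_result = forwarding(None, [10, 20])
```

With a wrapper of this shape, a call like the `mv.run(..., params={"my_param": 100})` above would be expected to return 100 for every row instead of raising, while calls without params would still fall back to the estimator's default.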
