Description
When an sklearn-compatible estimator exposes a method with extra keyword arguments (e.g. `predict(X, my_param=100)`), the Snowflake Model Registry's `SKLModelHandler` wraps it in a closure that only accepts `(self, X)` and does not forward any kwargs. As a result, `ParamSpec` parameters registered in the model signature never reach the underlying estimator, and calling the registered method with `params` fails with a `TypeError`.
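The failure mode can be reproduced in plain Python without Snowflake. The sketch below is not the library's code: `fn_factory_dropping` and `fn_factory_forwarding` are hypothetical stand-ins for the generated `fn_factory.<locals>.fn` closure seen in the traceback, assuming it calls the estimator method with `X` only. The forwarding variant shows one possible fix, accepting `**kwargs` and passing them through to the wrapped method.

```python
import numpy as np


class EstimatorWithKwarg:
    """Stand-in for the estimator in the report: predict() takes an extra kwarg."""

    def predict(self, X, *, my_param=365):
        return np.full(len(X), fill_value=my_param, dtype=float)


def fn_factory_dropping(raw_model, method_name):
    # Hypothetical model of the current behavior: the generated method only
    # accepts (self, X), so Python rejects any keyword argument.
    def fn(self, X):
        return getattr(raw_model, method_name)(X)

    return fn


def fn_factory_forwarding(raw_model, method_name):
    # Hypothetical fix: accept keyword arguments and pass them unchanged
    # to the wrapped estimator method.
    def fn(self, X, **kwargs):
        return getattr(raw_model, method_name)(X, **kwargs)

    return fn


estimator = EstimatorWithKwarg()
X = np.array([[10], [20]])
method_params = {"my_param": 100}

dropping = fn_factory_dropping(estimator, "predict")
try:
    # `self` is unused here, so None stands in for the custom model instance
    dropping(None, X, **method_params)
except TypeError as exc:
    print("dropping wrapper raised:", exc)

forwarding = fn_factory_forwarding(estimator, "predict")
print("forwarding wrapper returned:", forwarding(None, X, **method_params))
```

With the forwarding variant, omitting `my_param` still falls back to the estimator's own default, so existing callers that pass no params are unaffected.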
Minimal reproducible example:
```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from snowflake.ml.model import model_signature
from snowflake.ml.model.model_signature import DataType, ParamSpec
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs({<redacted>}).create()
registry = Registry(session=session)


class EstimatorWithKwarg(BaseEstimator):
    """Trivial estimator whose predict() accepts an extra `my_param` kwarg."""

    def fit(self, X, y=None):
        return self

    def predict(self, X, *, my_param=365):
        # Return `my_param` to verify it was forwarded
        return np.full(len(X), fill_value=my_param, dtype=float)


estimator = EstimatorWithKwarg().fit(pd.DataFrame({"a": [1, 2, 3]}))

# Verify the estimator itself works with the my_param kwarg
X = pd.DataFrame({"a": [10, 20]})
y_hat = estimator.predict(X, my_param=180)
y_hat  # array([180., 180.])

predict_sig = model_signature.infer_signature(
    X,
    y_hat,
    output_feature_names=["pred"],
    params=[ParamSpec(name="my_param", dtype=DataType.INT32, default_value=100)],
)

registry.log_model(
    model=estimator,
    model_name="deleteme",
    sample_input_data=X,
    signatures={"predict": predict_sig},
)

model_ref = registry.get_model("deleteme")
mv = model_ref.version("default")
mv.run(X, function_name="predict", params={"my_param": 100})
```

The final `mv.run` call fails with:

```
SnowparkSQLException: (1304): 01c2d5a2-081b-3567-0026-55033fabff02: 100357 (P0000): Python Interpreter Error:
Traceback (most recent call last):
  File "/home/udf/10789520488747366/predict.py", line 91, in infer
    predictions_df = runner(input_df, **method_params)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: SKLModelHandler.convert_as_custom_model.<locals>._create_custom_model.<locals>.fn_factory.<locals>.fn() got an unexpected keyword argument 'my_param' in function PREDICT with handler predict.infer
```