-
Notifications
You must be signed in to change notification settings - Fork 14
Closed
Description
When a model is registered via Registry.log_model() with enable_explainability=True, the auto-generated EXPLAIN table function's signature uses Snowpark ML types (INT8, INT64, DOUBLE) that do not exist as distinct types in Snowflake SQL. The EXPLAIN function performs strict type matching, so it rejects the Snowflake-native types that are the only ones available. This means that TABLE(MODEL(...)!EXPLAIN(...)) always fails.
import contextlib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from snowflake.ml.model import model_signature
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session
# Open a Snowpark session (connection parameters elided) and bind a model
# registry client to it.  The example model below mixes dtypes
# (int64, float64, object/string, bool).
session = Session.builder.configs(...).create()
registry = Registry(session=session)
# Build a deterministic synthetic training set: 100 rows mixing integer,
# float, boolean, and string columns (the dtype mix that triggers the bug).
rng = np.random.default_rng(42)
n = 100
feature_columns = {
    "age": rng.integers(20, 80, size=n).astype(np.int64),
    "score": rng.standard_normal(n),
    "flag": rng.choice([True, False], size=n),
    "category": rng.choice(["a", "b", "c"], size=n),
}
X = pd.DataFrame(feature_columns)
# Binary target, independent of the features (only the dtypes matter here).
y = rng.integers(0, 2, size=n)
# Preprocess: one-hot encode the string column and standardize the numeric
# score, then fit a small random forest on top of the transformed features.
one_hot = OneHotEncoder(categories=[["a", "b", "c"]], sparse_output=False)
preprocessor = ColumnTransformer(
    transformers=[
        ("categorical_features", one_hot, ["category"]),
        ("continuous_features", StandardScaler(), ["score"]),
    ]
)
pipeline = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("model", RandomForestClassifier(n_estimators=10, random_state=42)),
    ]
)
pipeline.fit(X, y)
# Infer the signature from the training data — this inference step is what
# produces the Snowpark ML types (INT64, DOUBLE, ...) in question.
# Fix: the pasted REPL output that was fused onto the print() line has been
# separated out so the script is valid Python; the observed output follows
# below this snippet.
sample_output = pd.DataFrame({"prediction": pipeline.predict(X[:5])})
sig = model_signature.infer_signature(X[:5], sample_output)
print("Inferred signature input types:")
for feat in sig.inputs:
    print(f" {feat.name}: {feat.as_snowpark_type()}")
age: LongType()
score: DoubleType()
flag: BooleanType()
category: StringType()
MODEL_NAME = "repro_explain_type_bug"
MODEL_VERSION = "v1"

# Drop any leftover model from a previous run so log_model starts clean
# (best-effort: ignore the failure when the model does not exist).
with contextlib.suppress(Exception):
    registry.delete_model(MODEL_NAME)

# Register the model with explainability enabled — this is what auto-generates
# the EXPLAIN table function whose strict signature triggers the reported
# type-mismatch error.
# Fix: the prose "predict works in python:" that was fused after the closing
# parenthesis has been moved into a comment so the script parses.
mv = registry.log_model(
    model=pipeline,
    model_name=MODEL_NAME,
    version_name=MODEL_VERSION,
    sample_input_data=X,
    signatures={"predict": sig},
    options={
        "enable_explainability": True,
    },
)
# predict works in python:
# Calling predict through the Python API works; the fused output column
# header ("prediction") that broke the pasted line is now a comment.
test_df = X[:5]
mv.run(test_df, function_name="predict")
# -> DataFrame with a "prediction" column; observed values follow below.
0
1
1
1
0
It also works in SQL:
# PREDICT also works when invoked from SQL as a scalar model method.
# Fix: the fused output header ("MODEL_PRED") that was glued onto the
# .to_pandas() line has been moved to a comment; the SQL string itself is
# unchanged.
pred_in_sql = session.sql("""
WITH test_df AS (
SELECT *
FROM VALUES
(25, 0.678914, TRUE, 'c'),
(66, 0.067579, FALSE, 'c'),
(59, 0.289119, TRUE, 'a'),
(46, 0.631288, FALSE, 'c'),
(45, -1.457156, FALSE, 'b')
AS t(age, score, flag, category)
)
SELECT MODEL(repro_explain_type_bug)!PREDICT(age, score, flag, category):prediction::NUMBER AS model_pred
FROM test_df;
""")
pred_in_sql.to_pandas()
# -> MODEL_PRED column; observed values follow below.
0
1
1
1
0
explain works in python:
# explain works through the Python API: it returns a pandas DataFrame with
# 4 explanation columns and 5 rows (the fused "<produces ...>" note that
# broke this pasted line is now this comment).
mv.run(test_df, function_name="explain")
But explain fails in SQL:
# Calling EXPLAIN as a SQL table function fails: the table function performs
# strict type matching, and the Snowflake-native argument types produced by
# VALUES (NUMBER, BOOLEAN, VARCHAR) are rejected.
# Fix: the exception text that was fused onto the .to_pandas() line is
# preserved below as comments so the script parses.
explain_in_sql = session.sql("""
WITH test_df AS (
SELECT *
FROM VALUES
(25, 0.678914, TRUE, 'c'),
(66, 0.067579, FALSE, 'c'),
(59, 0.289119, TRUE, 'a'),
(46, 0.631288, FALSE, 'c'),
(45, -1.457156, FALSE, 'b')
AS t(age, score, flag, category)
)
SELECT *
FROM test_df,
TABLE(MODEL(repro_explain_type_bug)!EXPLAIN(age, score, flag, category));
""")
explain_in_sql.to_pandas()
# Raises:
# SnowparkSQLException: (1304): 01c2d9dc-0a1b-3f25-0026-55033fc11336:
#   001044 (42P13): SQL compilation error: error line 13 at position 6
#   Invalid argument types for function
#   'REPRO_EXPLAIN_TYPE_BUG!"DEFAULT.EXPLAIN"':
#   (NUMBER(2,0), NUMBER(7,6), BOOLEAN, VARCHAR(1))
The predict and explain function signatures are the same.
# Fix: the pasted repr of the return value was fused onto the call (reading
# as a bogus subscript); it is preserved verbatim below as comments.
mv.show_functions()
# Output:
# [{'name': 'EXPLAIN', 'target_method': 'explain',
#   'target_method_function_type': 'TABLE_FUNCTION', 'signature': ModelSignature(
#     inputs=[
#         FeatureSpec(dtype=DataType.INT64, name='age', nullable=True),
#         FeatureSpec(dtype=DataType.DOUBLE, name='score', nullable=True),
#         FeatureSpec(dtype=DataType.BOOL, name='flag', nullable=True),
#         FeatureSpec(dtype=DataType.STRING, name='category', nullable=True)
#     ],
#     outputs=[
#         FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_a_explanation', nullable=True),
#         FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_b_explanation', nullable=True),
#         FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_c_explanation', nullable=True),
#         FeatureSpec(dtype=DataType.DOUBLE, name='continuous_features__score_explanation', nullable=True)
#     ],
#     params=[]
#   ), 'is_partitioned': True},
#  {'name': 'PREDICT', 'target_method': 'predict',
#   'target_method_function_type': 'FUNCTION', 'signature': ModelSignature(
#     inputs=[
#         FeatureSpec(dtype=DataType.INT64, name='age', nullable=True),
#         FeatureSpec(dtype=DataType.DOUBLE, name='score', nullable=True),
#         FeatureSpec(dtype=DataType.BOOL, name='flag', nullable=True),
#         FeatureSpec(dtype=DataType.STRING, name='category', nullable=True)
#     ],
#     outputs=[
#         FeatureSpec(dtype=DataType.INT64, name='prediction', nullable=True)
#     ],
#     params=[]
#   ), 'is_partitioned': False}]
This appears to be because table functions perform stricter argument-type matching than scalar functions in Snowflake, so the generated EXPLAIN table function rejects the Snowflake-native types (NUMBER, BOOLEAN, VARCHAR) that PREDICT, a scalar function, accepts — though I have not confirmed this.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels