Skip to content

EXPLAIN table function rejects Snowflake native types for registered models #210

@robert-norberg

Description

@robert-norberg

When a model is registered via Registry.log_model() with enable_explainability=True, the auto-generated EXPLAIN table function's signature uses Snowpark ML types (INT8, INT64, DOUBLE) that do not exist as distinct types in Snowflake SQL. The EXPLAIN function performs strict type matching, so it rejects the Snowflake-native types that are the only ones available. This means that TABLE(MODEL(...)!EXPLAIN(...)) always fails.

import contextlib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from snowflake.ml.model import model_signature
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

session = Session.builder.configs(...).create()

registry = Registry(session=session)
# Example model with mixed dtypes (int64, float64, object/string, bool)
rng = np.random.default_rng(42)
n = 100

X = pd.DataFrame(
    {
        "age": rng.integers(20, 80, size=n).astype(np.int64),
        "score": rng.standard_normal(n),
        "flag": rng.choice([True, False], size=n),
        "category": rng.choice(["a", "b", "c"], size=n),
    }
)
y = rng.integers(0, 2, size=n)

preprocessor = ColumnTransformer(
    transformers=[
        ("categorical_features", OneHotEncoder(categories=[["a", "b", "c"]], sparse_output=False), ["category"]),
        ("continuous_features", StandardScaler(), ["score"]),
    ]
)

pipeline = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("model", RandomForestClassifier(n_estimators=10, random_state=42)),
    ]
)

pipeline.fit(X, y)
# Infer signature from the training data (this is what produces INT64, DOUBLE, etc.)
sample_output = pd.DataFrame({"prediction": pipeline.predict(X[:5])})
sig = model_signature.infer_signature(X[:5], sample_output)

print("Inferred signature input types:")
for feat in sig.inputs:
    print(f"  {feat.name}: {feat.as_snowpark_type()}")
Inferred signature input types:
  age: LongType()
  score: DoubleType()
  flag: BooleanType()
  category: StringType()
MODEL_NAME = "repro_explain_type_bug"
MODEL_VERSION = "v1"

# delete if exists
with contextlib.suppress(Exception):
    registry.delete_model(MODEL_NAME)

# register the model with explainability enabled
mv = registry.log_model(
    model=pipeline,
    model_name=MODEL_NAME,
    version_name=MODEL_VERSION,
    sample_input_data=X,
    signatures={"predict": sig},
    options={
        "enable_explainability": True,
    },
)

predict works in python:

test_df = X[:5]
mv.run(test_df, function_name="predict")
prediction
0
1
1
1
0

It also works in SQL:

pred_in_sql = session.sql("""
WITH test_df AS (
    SELECT *
    FROM VALUES
    (25, 0.678914, TRUE, 'c'),
    (66, 0.067579, FALSE, 'c'),
    (59, 0.289119, TRUE, 'a'),
    (46, 0.631288, FALSE, 'c'),
    (45, -1.457156, FALSE, 'b')
    AS t(age, score, flag, category)
)
SELECT MODEL(repro_explain_type_bug)!PREDICT(age, score, flag, category):prediction::NUMBER AS model_pred
FROM test_df;
""")
pred_in_sql.to_pandas()
MODEL_PRED
0
1
1
1
0

explain works in python:

mv.run(test_df, function_name="explain")
<produces a pandas DataFrame with 4 columns and 5 rows>

But explain fails in SQL:

explain_in_sql = session.sql("""
WITH test_df AS (
    SELECT *
    FROM VALUES
    (25, 0.678914, TRUE, 'c'),
    (66, 0.067579, FALSE, 'c'),
    (59, 0.289119, TRUE, 'a'),
    (46, 0.631288, FALSE, 'c'),
    (45, -1.457156, FALSE, 'b')
    AS t(age, score, flag, category)
)
SELECT *
FROM test_df,
TABLE(MODEL(repro_explain_type_bug)!EXPLAIN(age, score, flag, category));
""")
explain_in_sql.to_pandas()
SnowparkSQLException: (1304): 01c2d9dc-0a1b-3f25-0026-55033fc11336: 001044 (42P13): SQL compilation error: error line 13 at position 6
Invalid argument types for function 'REPRO_EXPLAIN_TYPE_BUG!"DEFAULT.EXPLAIN"': (NUMBER(2,0), NUMBER(7,6), BOOLEAN, VARCHAR(1))

The predict and explain function signatures are the same.

mv.show_functions()
[{'name': 'EXPLAIN', 'target_method': 'explain', 'target_method_function_type': 'TABLE_FUNCTION', 'signature': ModelSignature(
    inputs=[
        FeatureSpec(dtype=DataType.INT64, name='age', nullable=True),
        FeatureSpec(dtype=DataType.DOUBLE, name='score', nullable=True),
        FeatureSpec(dtype=DataType.BOOL, name='flag', nullable=True),
        FeatureSpec(dtype=DataType.STRING, name='category', nullable=True)
    ],
    outputs=[
        FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_a_explanation', nullable=True),
        FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_b_explanation', nullable=True),
        FeatureSpec(dtype=DataType.DOUBLE, name='categorical_features__category_c_explanation', nullable=True),
        FeatureSpec(dtype=DataType.DOUBLE, name='continuous_features__score_explanation', nullable=True)
    ],
    params=[]
), 'is_partitioned': True}, {'name': 'PREDICT', 'target_method': 'predict', 'target_method_function_type': 'FUNCTION', 'signature': ModelSignature(
    inputs=[
        FeatureSpec(dtype=DataType.INT64, name='age', nullable=True),
        FeatureSpec(dtype=DataType.DOUBLE, name='score', nullable=True),
        FeatureSpec(dtype=DataType.BOOL, name='flag', nullable=True),
        FeatureSpec(dtype=DataType.STRING, name='category', nullable=True)
    ],
    outputs=[
        FeatureSpec(dtype=DataType.INT64, name='prediction', nullable=True)
    ],
    params=[]
), 'is_partitioned': False}]

I think it has something to do with table functions having more strict type checking than regular scalar functions in Snowflake? I'm not sure.

Metadata

Metadata

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions