Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### New Features

- Added support for the `array_union_agg` function in the `snowflake.snowpark.functions` module.
- Added the `udf_init_once` decorator in `snowflake.snowpark.functions` for marking functions to be executed once during pre-fork initialization on Snowflake workers, matching the server-side `_snowflake.udf_init_once` API.

#### Bug Fixes

Expand Down
56 changes: 56 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
<BLANKLINE>
"""
import functools
import inspect
import sys
import typing
from functools import reduce
Expand Down Expand Up @@ -10315,6 +10316,61 @@ def _inner(*args, **kwargs):
return _decorator


def udf_init_once(func: Callable) -> Callable:
    """Mark a function to be executed once during pre-fork initialization.

    This decorator has the same signature as the server-side ``_snowflake.udf_init_once``.
    It is a no-op for local invocation. Functions decorated with ``@udf_init_once`` run in
    the head worker process before individual workers are forked. The initialized state
    (e.g., loaded models, computed lookup tables) is shared across all workers via
    Copy-On-Write.

    The decorated function must:
    - Be callable with zero arguments (required parameters are rejected;
      parameters with defaults, ``*args``, and ``**kwargs`` are allowed)
    - Be a callable (function, lambda, or callable object)

    Multiple ``@udf_init_once`` functions are executed in the order they are defined.

    Example::

        from snowflake.snowpark.functions import udf_init_once

        model = None

        @udf_init_once
        def load_model():
            global model
            model = 42

    Use this decorator in handler files registered via
    :meth:`~snowflake.snowpark.udf.UDFRegistration.register_from_file`.
    On the Snowflake server the ``_snowflake.udf_init_once`` implementation
    is used instead; this client-side definition provides the same API so
    that handler files can be tested locally.

    Args:
        func: The init function to decorate. Must be callable with zero arguments.

    Returns:
        The decorated function, unchanged.

    Raises:
        TypeError: If ``func`` is not callable or cannot be invoked with
            zero arguments.

    See Also:
        - :func:`udf`
        - :meth:`~snowflake.snowpark.udf.UDFRegistration.register_from_file`
    """
    if not callable(func):
        raise TypeError(
            f"@udf_init_once target must be callable, got {type(func).__name__}"
        )
    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # Some builtins / C-implemented callables expose no introspectable
        # signature; accept them as-is rather than failing the decoration.
        sig = None
    if sig is not None:
        try:
            # The real requirement is "invocable with zero arguments", which
            # also admits defaulted parameters, *args, and **kwargs — not
            # merely "has an empty parameter list".
            sig.bind()
        except TypeError:
            raise TypeError(
                f"@udf_init_once function must take 0 arguments, got {len(sig.parameters)}"
            ) from None
    # Marker attribute read by the worker runtime to schedule pre-fork execution.
    func._sf_init_once = True
    return func


@publicapi
def pandas_udtf(
handler: Optional[Callable] = None,
Expand Down
17 changes: 17 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2684,3 +2684,20 @@ def add_series(s1, s2):
)
res = df.select(add_series("a", "b").alias("result")).collect()
assert [row.RESULT for row in res] == [3, 30, 300, 15]


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="UDF init_once is not supported in Local Testing",
)
def test_udf_init_once_register_from_file(session):
    """Test @udf_init_once in a handler file used with register_from_file."""
    handler_path = "tests/resources/test_udf_dir/test_udf_init_once_file.py"
    multiply_udf = session.udf.register_from_file(
        file_path=handler_path,
        func_name="multiply",
        return_type=IntegerType(),
        input_types=[IntegerType()],
    )
    df = session.create_dataframe([[1], [2], [3]], schema=["a"])
    rows = df.select(multiply_udf("a").alias("result")).collect()
    # setup() runs once pre-fork and sets the multiplier to 10.
    assert sorted(row.RESULT for row in rows) == [10, 20, 30]
17 changes: 17 additions & 0 deletions tests/resources/test_udf_dir/test_udf_init_once_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
try:
from _snowflake import udf_init_once
except ModuleNotFoundError:
from snowflake.snowpark.functions import udf_init_once


# Default multiplier; overwritten to 10 by setup() during pre-fork initialization.
_multiplier = 1


@udf_init_once
def setup():
    """Initialize ``_multiplier`` once before worker processes are forked.

    On the server this runs in the head worker via ``_snowflake.udf_init_once``;
    locally the client-side decorator is a no-op and this runs normally.
    """
    global _multiplier
    _multiplier = 10


def multiply(x):
    """Return *x* scaled by the module-level multiplier set in ``setup``."""
    factor = _multiplier
    return factor * x
14 changes: 9 additions & 5 deletions tests/unit/scala/test_utils_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
get_stage_parts,
get_temp_type_for_object,
get_udf_upload_prefix,
is_cloud_path,
is_snowflake_quoted_id_case_insensitive,
is_snowflake_unquoted_suffix_case_insensitive,
is_sql_select_statement,
Expand All @@ -30,7 +31,6 @@
validate_object_name,
warning,
zip_file_or_directory_to_stream,
is_cloud_path,
)
from tests.utils import IS_WINDOWS, TestFiles

Expand Down Expand Up @@ -114,23 +114,23 @@ def test_calculate_checksum():
else:
assert (
calculate_checksum(test_files.test_udf_directory)
== "3a2607ef293801f59e7840f5be423d4a55edfe2ac732775dcfda01205df377f0"
== "d472060e6d717517a3f9c7048ac43fc0c646467a6b3772f63355071a65ad8ecf"
)
assert (
calculate_checksum(test_files.test_udf_directory, algorithm="md5")
== "b72b61c8d5639fff8aa9a80278dba60f"
== "322ad29c4018a1375f61b670196cf902"
)
# Validate that hashes are different when reading whole dir.
# Using a sufficiently small chunk size so that the hashes differ.
assert (
calculate_checksum(test_files.test_udf_directory, chunk_size=128)
== "c071de824a67c083edad45c2b18729e17c50f1b13be980140437063842ea2469"
== "40d4811752a978779518082f0312f8823cb46dd84d12c68a7bb456c388f38df4"
)
assert (
calculate_checksum(
test_files.test_udf_directory, chunk_size=128, whole_file_hash=True
)
== "3a2607ef293801f59e7840f5be423d4a55edfe2ac732775dcfda01205df377f0"
== "d472060e6d717517a3f9c7048ac43fc0c646467a6b3772f63355071a65ad8ecf"
)


Expand Down Expand Up @@ -256,6 +256,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
[
"test_udf_dir/",
"test_udf_dir/test_another_udf_file.py",
"test_udf_dir/test_udf_init_once_file.py",
"test_udf_dir/test_pandas_udf_file.py",
"test_udf_dir/test_udf_file.py",
],
Expand All @@ -270,6 +271,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
[
"test_udf_dir/",
"test_udf_dir/test_another_udf_file.py",
"test_udf_dir/test_udf_init_once_file.py",
"test_udf_dir/test_pandas_udf_file.py",
"test_udf_dir/test_udf_file.py",
],
Expand All @@ -285,6 +287,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/",
"resources/test_udf_dir/",
"resources/test_udf_dir/test_another_udf_file.py",
"resources/test_udf_dir/test_udf_init_once_file.py",
"resources/test_udf_dir/test_pandas_udf_file.py",
"resources/test_udf_dir/test_udf_file.py",
],
Expand Down Expand Up @@ -381,6 +384,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files):
"resources/test_sp_dir/test_table_sp_file.py",
"resources/test_udf_dir/",
"resources/test_udf_dir/test_another_udf_file.py",
"resources/test_udf_dir/test_udf_init_once_file.py",
"resources/test_udf_dir/test_pandas_udf_file.py",
"resources/test_udf_dir/test_udf_file.py",
"resources/test_udtf_dir/",
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,37 @@ def test_generate_python_code_exception():
assert "Source code comment could not be generated" in generated_code


def test_udf_init_once_decorator_simple():
    """Test @udf_init_once as simple decorator (same API as _snowflake.udf_init_once)."""
    from snowflake.snowpark.functions import udf_init_once

    def my_init():
        pass

    decorated = udf_init_once(my_init)

    # The decorator flags the function itself (no wrapper), so the marker
    # attribute is set and the original name survives unchanged.
    assert decorated._sf_init_once is True
    assert decorated.__name__ == "my_init"


def test_udf_init_once_rejects_non_zero_arg_function():
    """A function requiring a positional argument cannot be an init function."""
    from snowflake.snowpark.functions import udf_init_once

    def bad_init(x):
        return x

    with pytest.raises(TypeError, match="must take 0 arguments"):
        udf_init_once(bad_init)


def test_udf_init_once_rejects_non_callable():
    """Non-callable targets are rejected with a descriptive TypeError."""
    from snowflake.snowpark.functions import udf_init_once

    not_callable = "not_a_function"
    with pytest.raises(TypeError, match="must be callable"):
        udf_init_once(not_callable)


@pytest.mark.parametrize(
"packages",
[
Expand Down
Loading