def udf_init_once(func: Callable) -> Callable:
    """Mark a function to be executed once during pre-fork initialization.

    This decorator has the same signature as the server-side
    ``_snowflake.udf_init_once``. It is a no-op for local invocation: the
    function is only tagged, never called here. Functions decorated with
    ``@udf_init_once`` run in the head worker process before individual
    workers are forked. The initialized state (e.g., loaded models, computed
    lookup tables) is shared across all workers via Copy-On-Write.

    The decorated function must:

    - Accept zero arguments
    - Be a callable (function, lambda, or callable object) that supports
      attribute assignment, because the marker is stored on the object itself

    Multiple ``@udf_init_once`` functions are executed in the order they are
    defined.

    Example::

        from snowflake.snowpark.functions import udf_init_once

        model = None

        @udf_init_once
        def load_model():
            global model
            model = 42

    Use this decorator in handler files registered via
    :meth:`~snowflake.snowpark.udf.UDFRegistration.register_from_file`.
    On the Snowflake server the ``_snowflake.udf_init_once`` implementation
    is used instead; this client-side definition provides the same API so
    that handler files can be tested locally.

    Args:
        func: The init function to decorate. Must be callable and accept zero
            arguments.

    Returns:
        The decorated function, unchanged apart from the ``_sf_init_once``
        marker attribute set to ``True``.

    Raises:
        TypeError: If ``func`` is not callable, if its signature cannot be
            introspected, if it declares any parameters, or if the marker
            attribute cannot be set on it (e.g. bound methods or objects
            using ``__slots__``).

    See Also:
        - :func:`udf`
        - :meth:`~snowflake.snowpark.udf.UDFRegistration.register_from_file`
    """
    if not callable(func):
        raise TypeError(
            f"@udf_init_once target must be callable, got {type(func).__name__}"
        )

    try:
        sig = inspect.signature(func)
    except ValueError as err:
        # inspect.signature raises ValueError when no signature can be
        # provided (some builtins); report it like every other validation
        # failure instead of leaking an unrelated-looking error.
        raise TypeError(
            f"@udf_init_once could not inspect the signature of {func!r}"
        ) from err

    if sig.parameters:
        raise TypeError(
            f"@udf_init_once function must take 0 arguments, got {len(sig.parameters)}"
        )

    try:
        func._sf_init_once = True
    except AttributeError as err:
        # Bound methods and __slots__-based callables pass the checks above
        # but cannot carry the marker attribute.
        raise TypeError(
            "@udf_init_once target must support attribute assignment, "
            f"got {type(func).__name__}"
        ) from err
    return func
index 0000000000..2c89a42a7a --- /dev/null +++ b/tests/resources/test_udf_dir/test_udf_init_once_file.py @@ -0,0 +1,17 @@ +try: + from _snowflake import udf_init_once +except ModuleNotFoundError: + from snowflake.snowpark.functions import udf_init_once + + +_multiplier = 1 + + +@udf_init_once +def setup(): + global _multiplier + _multiplier = 10 + + +def multiply(x): + return x * _multiplier diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index 6b18c6aa9f..b3a9acc943 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -19,6 +19,7 @@ get_stage_parts, get_temp_type_for_object, get_udf_upload_prefix, + is_cloud_path, is_snowflake_quoted_id_case_insensitive, is_snowflake_unquoted_suffix_case_insensitive, is_sql_select_statement, @@ -30,7 +31,6 @@ validate_object_name, warning, zip_file_or_directory_to_stream, - is_cloud_path, ) from tests.utils import IS_WINDOWS, TestFiles @@ -114,23 +114,23 @@ def test_calculate_checksum(): else: assert ( calculate_checksum(test_files.test_udf_directory) - == "3a2607ef293801f59e7840f5be423d4a55edfe2ac732775dcfda01205df377f0" + == "d472060e6d717517a3f9c7048ac43fc0c646467a6b3772f63355071a65ad8ecf" ) assert ( calculate_checksum(test_files.test_udf_directory, algorithm="md5") - == "b72b61c8d5639fff8aa9a80278dba60f" + == "322ad29c4018a1375f61b670196cf902" ) # Validate that hashes are different when reading whole dir. # Using a sufficiently small chunk size so that the hashes differ. 
assert ( calculate_checksum(test_files.test_udf_directory, chunk_size=128) - == "c071de824a67c083edad45c2b18729e17c50f1b13be980140437063842ea2469" + == "40d4811752a978779518082f0312f8823cb46dd84d12c68a7bb456c388f38df4" ) assert ( calculate_checksum( test_files.test_udf_directory, chunk_size=128, whole_file_hash=True ) - == "3a2607ef293801f59e7840f5be423d4a55edfe2ac732775dcfda01205df377f0" + == "d472060e6d717517a3f9c7048ac43fc0c646467a6b3772f63355071a65ad8ecf" ) @@ -256,6 +256,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): [ "test_udf_dir/", "test_udf_dir/test_another_udf_file.py", + "test_udf_dir/test_udf_init_once_file.py", "test_udf_dir/test_pandas_udf_file.py", "test_udf_dir/test_udf_file.py", ], @@ -270,6 +271,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): [ "test_udf_dir/", "test_udf_dir/test_another_udf_file.py", + "test_udf_dir/test_udf_init_once_file.py", "test_udf_dir/test_pandas_udf_file.py", "test_udf_dir/test_udf_file.py", ], @@ -285,6 +287,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/", "resources/test_udf_dir/", "resources/test_udf_dir/test_another_udf_file.py", + "resources/test_udf_dir/test_udf_init_once_file.py", "resources/test_udf_dir/test_pandas_udf_file.py", "resources/test_udf_dir/test_udf_file.py", ], @@ -381,6 +384,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/test_sp_dir/test_table_sp_file.py", "resources/test_udf_dir/", "resources/test_udf_dir/test_another_udf_file.py", + "resources/test_udf_dir/test_udf_init_once_file.py", "resources/test_udf_dir/test_pandas_udf_file.py", "resources/test_udf_dir/test_udf_file.py", "resources/test_udtf_dir/", diff --git a/tests/unit/test_udf_utils.py b/tests/unit/test_udf_utils.py index 32a1626eb9..337af3c101 100644 --- a/tests/unit/test_udf_utils.py +++ b/tests/unit/test_udf_utils.py @@ -79,6 +79,37 @@ def test_generate_python_code_exception(): assert "Source code comment could 
def test_udf_init_once_decorator_simple():
    """Applying udf_init_once to a zero-argument function tags it and returns it as-is."""
    from snowflake.snowpark.functions import udf_init_once

    def my_init():
        pass

    # Explicit application is equivalent to the ``@udf_init_once`` syntax.
    my_init = udf_init_once(my_init)

    # The marker lives on the init function itself.
    marker = my_init._sf_init_once
    assert marker is True
    # No wrapper is introduced, so the original name is still visible.
    assert my_init.__name__ == "my_init"


def test_udf_init_once_rejects_non_zero_arg_function():
    """A function that declares a parameter is refused with a TypeError."""
    from snowflake.snowpark.functions import udf_init_once

    def bad_init(x):
        return x

    with pytest.raises(TypeError, match="must take 0 arguments"):
        udf_init_once(bad_init)


def test_udf_init_once_rejects_non_callable():
    """A non-callable target (here a plain string) is refused with a TypeError."""
    from snowflake.snowpark.functions import udf_init_once

    not_callable = "not_a_function"
    with pytest.raises(TypeError, match="must be callable"):
        udf_init_once(not_callable)