diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index ca9a662d399e..1cd220cc2566 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -246,6 +246,8 @@ def __init__( self._tags = tags self._main_tag = main_tag self._transform = transform + self._tagged_output_types = ( + transform.get_type_hints().tagged_output_types() if transform else {}) self._allow_unknown_tags = ( not tags if allow_unknown_tags is None else allow_unknown_tags) # The ApplyPTransform instance for the application of the multi FlatMap @@ -303,7 +305,7 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection: pcoll = PCollection( self._pipeline, tag=tag, - element_type=typehints.Any, + element_type=self._tagged_output_types.get(tag, typehints.Any), is_bounded=is_bounded) # Transfer the producer from the DoOutputsTuple to the resulting # PCollection. @@ -323,7 +325,11 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection: return pcoll -class TaggedOutput(object): +TagType = TypeVar('TagType', bound=str) +ValueType = TypeVar('ValueType') + + +class TaggedOutput(Generic[TagType, ValueType]): """An object representing a tagged value. ParDo, Map, and FlatMap transforms can emit values on multiple outputs which @@ -331,7 +337,7 @@ class TaggedOutput(object): if it wants to emit on the main output and TaggedOutput objects if it wants to emit a value on a specific tagged output. 
""" - def __init__(self, tag: str, value: Any) -> None: + def __init__(self, tag: TagType, value: ValueType) -> None: if not isinstance(tag, str): raise TypeError( 'Attempting to create a TaggedOutput with non-string tag %s' % diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index ea11bca9474d..19c96fa51f2b 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1834,6 +1834,17 @@ def with_outputs(self, *tags, main=None, allow_unknown_tags=None): raise ValueError( 'Main output tag %r must be different from side output tags %r.' % (main, tags)) + type_hints = self.get_type_hints() + declared_tags = set(type_hints.tagged_output_types().keys()) + requested_tags = set(tags) + + unknown = requested_tags - declared_tags + if unknown and declared_tags: # Only warn if type hints exist + logging.warning( + "Tags %s requested in with_outputs() but not declared " + "in type hints. Declared tags: %s", + unknown, + declared_tags) return _MultiParDo(self, tags, main, allow_unknown_tags) def _do_fn_info(self): @@ -2120,8 +2131,10 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name wrapper) output_hint = type_hints.simple_output_type(label) if output_hint: + tagged_output_types = type_hints.tagged_output_types() wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)])( + typehints.Iterable[_strip_output_annotations(output_hint)], + **tagged_output_types)( wrapper) # pylint: disable=protected-access wrapper._argspec_fn = fn @@ -2189,8 +2202,10 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: + tagged_output_types = type_hints.tagged_output_types() wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)])( + typehints.Iterable[_strip_output_annotations(output_hint)], + **tagged_output_types)( wrapper) # Replace 
the first (args) component. @@ -2261,7 +2276,10 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: - wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper) + tagged_output_types = type_hints.tagged_output_types() + wrapper = with_output_types( + _strip_output_annotations(output_hint), **tagged_output_types)( + wrapper) # Replace the first (args) component. modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:] diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 94e9a0644d04..d5985b6212df 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -414,12 +414,15 @@ def with_input_types(self, input_type_hint): input_type_hint, 'Type hints for a PTransform') return super().with_input_types(input_type_hint) - def with_output_types(self, type_hint): + def with_output_types(self, type_hint, **tagged_type_hints): """Annotates the output type of a :class:`PTransform` with a type-hint. Args: type_hint (type): An instance of an allowed built-in type, a custom class, - or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. + or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is + the type hint for the main output. + **tagged_type_hints: Type hints for tagged outputs. Each keyword argument + specifies the type for a tagged output e.g., ``errors=str``. Raises: TypeError: If **type_hint** is not a valid type-hint. See @@ -430,10 +433,22 @@ def with_output_types(self, type_hint): PTransform: A reference to the instance of this particular :class:`PTransform` object. This allows chaining type-hinting related methods. 
+ + Example:: + result = pcoll | beam.ParDo(MyDoFn()).with_output_types( + int, # main output type + errors=str, # 'errors' tagged output type + warnings=str # 'warnings' tagged output type + ).with_outputs('errors', 'warnings', main='main') """ type_hint = native_type_compatibility.convert_to_beam_type(type_hint) validate_composite_type_param(type_hint, 'Type hints for a PTransform') - return super().with_output_types(type_hint) + for tag, hint in tagged_type_hints.items(): + tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type( + hint) + validate_composite_type_param( + tagged_type_hints[tag], f'Tagged output type hint for {tag!r}') + return super().with_output_types(type_hint, **tagged_type_hints) def with_resource_hints(self, **kwargs): # type: (...) -> PTransform """Adds resource hints to the :class:`PTransform`. @@ -479,10 +494,11 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): if hints is None or not any(hints): return arg_hints, kwarg_hints = hints - if arg_hints and kwarg_hints: + # Output types can have kwargs for tagged output types. 
+ if arg_hints and kwarg_hints and input_or_output != 'output': raise TypeCheckError( - 'PTransform cannot have both positional and keyword type hints ' - 'without overriding %s._type_check_%s()' % + 'PTransform cannot have both positional and keyword input type hints' + ' without overriding %s._type_check_%s()' % (self.__class__, input_or_output)) root_hint = ( arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 2d2f7981dd29..106765d76ba4 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -79,6 +79,7 @@ def foo((a, b)): # pytype: skip-file +import collections.abc import inspect import itertools import logging @@ -89,12 +90,16 @@ def foo((a, b)): from typing import Dict from typing import Iterable from typing import List +from typing import Literal from typing import NamedTuple from typing import Optional from typing import Tuple from typing import TypeVar from typing import Union +from typing import get_args +from typing import get_origin +from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type @@ -180,6 +185,104 @@ def disable_type_annotations(): TRACEBACK_LIMIT = 5 +_NO_MAIN_TYPE = object() + + +def _is_union_type(origin): + """Check if a type origin is a Union (typing.Union or types.UnionType).""" + return origin is Union or origin is types.UnionType + + +def _tag_and_type(t): + """Extract tag name and value type from TaggedOutput[Literal['tag'], Type]. + + Returns raw Python types - conversion to beam types happens in + _extract_output_types. 
+ """ + args = get_args(t) + if len(args) != 2: + raise TypeError( + f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}") + + literal_type, value_type = args + + if get_origin(literal_type) is not Literal: + raise TypeError( + f"First type parameter of TaggedOutput must be Literal['tag_name'], " + f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]") + + tag_string = get_args(literal_type)[0] + return tag_string, value_type + + +def _extract_main_and_tagged(t): + """Extract main type and tagged types from a type annotation. + + Returns: + (main_type, tagged_dict) where main_type is the type without TaggedOutput + annotations (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag + names to their types. + """ + if get_origin(t) is TaggedOutput: + tag, typ = _tag_and_type(t) + return _NO_MAIN_TYPE, {tag: typ} + + if t is TaggedOutput: + logging.warning( + "TaggedOutput in return type must include type parameters: " + "TaggedOutput[Literal['tag_name'], ValueType]. " + "Bare TaggedOutput falling back to Any.") + return _NO_MAIN_TYPE, {} + + if not _is_union_type(get_origin(t)): + return t, {} + + main_types = [] + tagged_types = {} + for arg in get_args(t): + if get_origin(arg) is TaggedOutput: + tag, typ = _tag_and_type(arg) + tagged_types[tag] = typ + elif arg is TaggedOutput: + logging.warning( + "TaggedOutput in return type must include type parameters: " + "TaggedOutput[Literal['tag_name'], ValueType]. " + "Bare TaggedOutput falling back to Any.") + else: + main_types.append(arg) + + if len(main_types) == 0: + main_type = _NO_MAIN_TYPE + elif len(main_types) == 1: + main_type = main_types[0] + else: + main_type = Union[tuple(main_types)] + + return main_type, tagged_types + + +def _extract_output_types(return_annotation): + """Parse return annotation into (main_types, tagged_types). + + For tagged outputs to be extracted from generator/iterator functions, + users must explicitly use Iterable[T | TaggedOutput[...]] as return type. 
+ + Returns raw Python types. Conversion to beam types happens in from_callable. + """ + if return_annotation == inspect.Signature.empty: + return [Any], {} + + # Iterable[T | TaggedOutput[...]] + if get_origin(return_annotation) is collections.abc.Iterable: + yield_type = get_args(return_annotation)[0] + clean_yield, tagged_types = _extract_main_and_tagged(yield_type) + clean_main = Any if clean_yield is _NO_MAIN_TYPE else clean_yield + return [Iterable[clean_main]], tagged_types + + # T | TaggedOutput (or plain type with no tags) + main_type, tagged_types = _extract_main_and_tagged(return_annotation) + main = Any if main_type is _NO_MAIN_TYPE else main_type + return [main], tagged_types class IOTypeHints(NamedTuple): @@ -273,11 +376,14 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: param.VAR_POSITIONAL], \ 'Unsupported Parameter kind: %s' % param.kind input_args.append(convert_to_beam_type(param.annotation)) - output_args = [] - if signature.return_annotation != signature.empty: - output_args.append(convert_to_beam_type(signature.return_annotation)) - else: - output_args.append(typehints.Any) + + output_args, output_kwargs = _extract_output_types( + signature.return_annotation) + output_args = [convert_to_beam_type(t) for t in output_args] + output_kwargs = { + k: convert_to_beam_type(v) + for k, v in output_kwargs.items() + } name = getattr(fn, '__name__', '') msg = ['from_callable(%s)' % name, ' signature: %s' % signature] @@ -287,7 +393,7 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: (fn.__code__.co_filename, fn.__code__.co_firstlineno)) return IOTypeHints( input_types=(tuple(input_args), input_kwargs), - output_types=(tuple(output_args), {}), + output_types=(tuple(output_args), output_kwargs), origin=cls._make_origin([], tb=False, msg=msg)) def with_input_types(self, *args, **kwargs) -> 'IOTypeHints': @@ -308,18 +414,24 @@ def with_output_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': def 
def simple_output_type(self, context):
  """Return the single positional output type hint, if any.

  Args:
    context: Description of the caller, interpolated into the error message.

  Returns:
    The sole positional output hint, or None when ``_has_output_types()``
    is false.

  Raises:
    TypeError: If output hints exist but there is not exactly one positional
      hint.
  """
  if not self._has_output_types():
    return None
  # Keyword hints describe tagged outputs and are deliberately ignored here;
  # tagged_output_types() exposes them.
  positional, _ = self.output_types
  if len(positional) == 1:
    return positional[0]
  raise TypeError(
      'Expected single output type hint for %s but got: %s' %
      (context, self.output_types))

def tagged_output_types(self):
  """Return the keyword (per-tag) output hints, or ``{}`` if there are none.

  The mapping returned when hints exist is the second element of
  ``output_types`` itself, not a copy.
  """
  if self._has_output_types():
    _, per_tag_hints = self.output_types
    return per_tag_hints
  return {}

def has_simple_output_type(self):
  """Whether there's a single positional output type."""
  hints = self.output_types
  if not hints:
    # Mirror ``a and b``: a falsy ``output_types`` is handed back unchanged.
    return hints
  return len(hints[0]) == 1
return self._replace( - output_types=((typehints.Any, ), {}), + output_types=((typehints.Any, ), tagged_output_types), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) yielded_type = typehints.get_yielded_type(output_type) return self._replace( - output_types=((yielded_type, ), {}), + output_types=((yielded_type, ), tagged_output_types), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints': @@ -782,7 +895,7 @@ def annotate_input_types(f): def with_output_types(*return_type_hint: Any, - **kwargs: Any) -> Callable[[T], T]: + **tagged_type_hints: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints for return values(s). This decorator will type-check the return value(s) of the decorated function. @@ -822,18 +935,34 @@ def parse_ints(ints): def negate(p): return not p if p else p + For DoFns with tagged outputs, you can specify type hints for each tag: + + .. testcode:: + from apache_beam.typehints import with_input_types, with_output_types + @with_output_types(int, errors=str, warnings=str) + class MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', 'Negative value') + elif element == 0: + yield beam.pvalue.TaggedOutput('warnings', 'Zero value') + else: + yield element + Args: *return_type_hint: A type-hint specifying the proper return type of the function. This argument should either be a built-in Python type or an instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint` created by 'indexing' a :class:`~apache_beam.typehints.typehints.CompositeTypeHint`. - **kwargs: Not used. + **tagged_type_hints: Type hints for tagged outputs. Each keyword argument + specifies the type for a tagged output, e.g., ``errors=str``. + Raises: - :class:`ValueError`: If any kwarg parameters are passed in, - or the length of **return_type_hint** is greater than ``1``. 
Or if the - inner wrapper function isn't passed a function object. + :class:`ValueError`: If the length of **return_type_hint** is greater + than ``1``. Or if the inner wrapper function isn't passed a function + object. :class:`TypeCheckError`: If the **return_type_hint** object is in invalid type-hint. @@ -841,11 +970,6 @@ def negate(p): The original function decorated such that it enforces type-hint constraints for all return values. """ - if kwargs: - raise ValueError( - "All arguments for the 'returns' decorator must be " - "positional arguments.") - if len(return_type_hint) != 1: raise ValueError( "'returns' accepts only a single positional argument. In " @@ -854,13 +978,20 @@ def negate(p): return_type_hint = native_type_compatibility.convert_to_beam_type( return_type_hint[0]) - validate_composite_type_param( return_type_hint, error_msg_prefix='All type hint arguments') + converted_tag_hints = {} + for tag, hint in tagged_type_hints.items(): + converted_hint = native_type_compatibility.convert_to_beam_type(hint) + validate_composite_type_param( + converted_hint, 'Tagged output type hint for %r' % tag) + converted_tag_hints[tag] = converted_hint + def annotate_output_types(f): th = getattr(f, '_type_hints', IOTypeHints.empty()) - f._type_hints = th.with_output_types(return_type_hint) # pylint: disable=protected-access + f._type_hints = th.with_output_types( # pylint: disable=protected-access + return_type_hint, **converted_tag_hints) return f return annotate_output_types diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index a2909b4e545f..626c2ffbd497 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -24,6 +24,7 @@ import unittest from apache_beam import Map +from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import Any from apache_beam.typehints import Dict from apache_beam.typehints import 
List @@ -262,6 +263,75 @@ def fn(a: int) -> int: th = decorators.IOTypeHints.from_callable(fn) self.assertRegex(th.debug_str(), r'unknown') + def test_from_callable_no_tagged_output(self): + def fn(x: int) -> str: + return str(x) + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((str, ), {})) + + def fn2(x: int) -> typing.Iterable[str]: + yield str(x) + + th = decorators.IOTypeHints.from_callable(fn2) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((typehints.Iterable[str], ), {})) + + def test_from_callable_tagged_output_union(self): + def fn( + x: int + ) -> int | str | TaggedOutput[typing.Literal['errors'], float + | str] | TaggedOutput[ + typing.Literal['warnings'], str]: + return x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, + ((typehints.Union[int, str], ), { + 'errors': typehints.Union[float, str], 'warnings': str + })) + + def test_from_callable_tagged_output_iterable(self): + def fn( + x: int + ) -> typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]]: + yield x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, ((typehints.Iterable[int], ), { + 'errors': str + })) + + def test_from_callable_tagged_output_multiple_tags(self): + def fn( + x: int + ) -> ( + int | TaggedOutput[typing.Literal['errors'], str] | + TaggedOutput[typing.Literal['warnings'], str]): + return x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, ((int, ), { + 'errors': str, 'warnings': str + })) + + def test_from_callable_tagged_output_only(self): + def fn(x: int) -> TaggedOutput[typing.Literal['errors'], str]: + pass + + th = decorators.IOTypeHints.from_callable(fn) + 
self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((Any, ), { + 'errors': str + })) + def test_getcallargs_forhints(self): def fn( a: int, @@ -426,5 +496,90 @@ def fn2(a: int) -> int: _ = ['a', 'b', 'c'] | Map(fn2) # Doesn't raise - no input type hints. +class TaggedOutputExtractionTest(unittest.TestCase): + """Tests for TaggedOutput extraction helper functions.""" + def test_extract_main_and_tagged_simple_type(self): + main, tagged = decorators._extract_main_and_tagged(int) + self.assertEqual(main, int) + self.assertEqual(tagged, {}) + + def test_extract_main_and_tagged_tagged_output_only(self): + t = TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_main_and_tagged(t) + self.assertIs(main, decorators._NO_MAIN_TYPE) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_main_and_tagged_union(self): + t = int | TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_main_and_tagged(t) + self.assertEqual(main, int) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_main_and_tagged_union_multiple_tagged(self): + t = ( + int | TaggedOutput[typing.Literal['errors'], str] + | TaggedOutput[typing.Literal['warnings'], str]) + main, tagged = decorators._extract_main_and_tagged(t) + self.assertEqual(main, int) + self.assertEqual(tagged, {'errors': str, 'warnings': str}) + + def test_extract_main_and_tagged_union_multiple_main_types(self): + t = (int | str | TaggedOutput[typing.Literal['errors'], bytes]) + main, tagged = decorators._extract_main_and_tagged(t) + # Main type should be Union[int, str] + self.assertEqual(typing.get_origin(main), typing.Union) + self.assertIn(int, typing.get_args(main)) + self.assertIn(str, typing.get_args(main)) + self.assertEqual(tagged, {'errors': bytes}) + + def test_extract_output_types_empty_signature(self): + import inspect + main, tagged = decorators._extract_output_types(inspect.Signature.empty) + self.assertEqual(main, 
[typing.Any]) + self.assertEqual(tagged, {}) + + def test_extract_output_types_simple_type(self): + main, tagged = decorators._extract_output_types(int) + self.assertEqual(main, [int]) + self.assertEqual(tagged, {}) + + def test_extract_output_types_union_with_tagged(self): + t = int | TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [int]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_iterable_with_tagged(self): + t = typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Iterable[int]]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_list_with_tagged_not_extracted(self): + t = typing.List[int | TaggedOutput[typing.Literal['errors'], str]] + _, tagged = decorators._extract_output_types(t) + # The whole type is converted as-is. Users should use Iterable instead. 
+ self.assertEqual(tagged, {}) + + def test_extract_output_types_tagged_only(self): + t = TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Any]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_iterable_tagged_only(self): + t = typing.Iterable[TaggedOutput[typing.Literal['errors'], str]] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Iterable[typing.Any]]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_bare_tagged_excluded(self): + with self.assertLogs(level='WARNING') as cm: + main, tagged = decorators._extract_output_types(str | TaggedOutput) + self.assertIn('Bare TaggedOutput falling back to Any', cm.output[0]) + self.assertEqual(main, [str]) + self.assertEqual(tagged, {}) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py new file mode 100644 index 000000000000..5dfae1b7e3dd --- /dev/null +++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py @@ -0,0 +1,356 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for tagged output type hints. + +This tests the implementation of type hints for tagged outputs via three styles: + +1. Decorator style: + @with_output_types(int, errors=str, warnings=str) + class MyDoFn(beam.DoFn): + ... + +2. Method chain style: + beam.ParDo(MyDoFn()).with_output_types(int, errors=str) + +3. Function annotation style: + def fn(element) -> int | TaggedOutput[Literal['errors'], str]: + ... +""" + +# pytype: skip-file + +import unittest +from typing import Iterable +from typing import Literal +from typing import Union + +import apache_beam as beam +from apache_beam.pvalue import TaggedOutput +from apache_beam.typehints import with_output_types +from apache_beam.typehints.decorators import IOTypeHints + + +class IOTypeHintsTaggedOutputTest(unittest.TestCase): + """Tests for IOTypeHints.tagged_output_types() accessor.""" + def test_empty_hints_returns_empty_dict(self): + empty = IOTypeHints.empty() + self.assertEqual(empty.tagged_output_types(), {}) + + def test_with_tagged_types(self): + hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str) + self.assertEqual( + hints.tagged_output_types(), { + 'errors': str, 'warnings': str + }) + + def test_simple_output_type_with_tagged_types(self): + """simple_output_type() should still return main type when tags present.""" + hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str) + self.assertEqual(hints.simple_output_type('test'), int) + + hints = IOTypeHints.empty().with_output_types( + Union[int, str], errors=str, warnings=str) + self.assertEqual(hints.simple_output_type('test'), Union[int, str]) + + def test_without_tagged_types(self): + """Without tagged types, tagged_output_types() returns empty dict.""" + hints = IOTypeHints.empty().with_output_types(int) + self.assertEqual(hints.tagged_output_types(), {}) + 
self.assertEqual(hints.simple_output_type('test'), int) + + +class DecoratorStyleTaggedOutputTest(unittest.TestCase): + """Tests for @with_output_types decorator style across all transforms.""" + def test_pardo_decorator_pipeline(self): + """Test that tagged types propagate through ParDo pipeline.""" + @with_output_types(int, errors=str) + class MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(MyDoFn()).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_decorator_pipeline(self): + """Test that tagged types propagate through Map.""" + @with_output_types(int, errors=str) + def mapfn(element): + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmap_decorator_pipeline(self): + """Test that tagged types propagate through FlatMap.""" + @with_output_types(Iterable[int], errors=str) + def flatmapfn(element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_maptuple_decorator_pipeline(self): + """Test that tagged types propagate through MapTuple.""" + @with_output_types(int, errors=str) + def maptuplefn(key, value): + if 
value < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + return value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.MapTuple(maptuplefn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmaptuple_decorator_pipeline(self): + """Test that tagged types propagate through FlatMapTuple.""" + @with_output_types(Iterable[int], errors=str) + def flatmaptuplefn(key, value): + if value < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + yield value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.FlatMapTuple(flatmaptuplefn).with_outputs( + 'errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +class ChainStyleTaggedOutputTest(unittest.TestCase): + """Tests for .with_output_types() method chain style across all transforms.""" + def test_pardo_chain_pipeline(self): + """Test ParDo with chained type hints.""" + class SimpleDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(SimpleDoFn()).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_chain_pipeline(self): + """Test Map with chained type hints.""" + def mapfn(element): + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_output_types(int, 
errors=str).with_outputs( + 'errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmap_chain_pipeline(self): + """Test FlatMap with chained type hints. + + Note: For FlatMap.with_output_types(), specify the element type directly + (int), not wrapped in Iterable. The transform handles iteration internally. + """ + def flatmapfn(element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_maptuple_chain_pipeline(self): + """Test MapTuple with chained type hints.""" + def maptuplefn(key, value): + if value < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + return value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.MapTuple(maptuplefn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmaptuple_chain_pipeline(self): + """Test FlatMapTuple with chained type hints. + + Note: For FlatMapTuple.with_output_types(), specify the element type + directly (int), not wrapped in Iterable. 
+ """ + def flatmaptuplefn(key, value): + if value < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + yield value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.FlatMapTuple(flatmaptuplefn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +class AnnotationStyleTaggedOutputTest(unittest.TestCase): + """Tests for function annotation style across all transforms.""" + def test_map_annotation_union(self): + """Test Map with Union[int, TaggedOutput[...]] annotation.""" + def mapfn(element: int) -> int | TaggedOutput[Literal['errors'], str]: + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_annotation_multiple_tags(self): + """Test Map with multiple TaggedOutput types in annotation.""" + def mapfn( + element: int + ) -> int | TaggedOutput[Literal['errors'], + str] | TaggedOutput[Literal['warnings'], str]: + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + elif element == 0: + return beam.pvalue.TaggedOutput('warnings', 'Zero value') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', 'warnings', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + self.assertEqual(results.warnings.element_type, str) + + def test_flatmap_annotation_iterable(self): + """Test FlatMap with Iterable[int | TaggedOutput[...]] annotation.""" + 
def flatmapfn( + element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_pardo_annotation_process_method(self): + """Test DoFn with process method annotation.""" + class AnnotatedDoFn(beam.DoFn): + def process( + self, + element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(AnnotatedDoFn()).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 0bbc21f6739c..cec830380087 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1421,7 +1421,7 @@ def unused_foo(): return 5, 'bar' def test_no_kwargs_accepted(self): - with self.assertRaisesRegex(ValueError, r'must be positional'): + with self.assertRaisesRegex(ValueError, r'single positional argument'): @with_output_types(m=int) def unused_foo():