Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sdks/python/apache_beam/pvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def __init__(
self._tags = tags
self._main_tag = main_tag
self._transform = transform
self._tagged_output_types = (
transform.get_type_hints().tagged_output_types() if transform else {})
self._allow_unknown_tags = (
not tags if allow_unknown_tags is None else allow_unknown_tags)
# The ApplyPTransform instance for the application of the multi FlatMap
Expand Down Expand Up @@ -303,7 +305,7 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection:
pcoll = PCollection(
self._pipeline,
tag=tag,
element_type=typehints.Any,
element_type=self._tagged_output_types.get(tag, typehints.Any),
is_bounded=is_bounded)
# Transfer the producer from the DoOutputsTuple to the resulting
# PCollection.
Expand All @@ -323,15 +325,19 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection:
return pcoll


class TaggedOutput(object):
TagType = TypeVar('TagType', bound=str)
ValueType = TypeVar('ValueType')


class TaggedOutput(Generic[TagType, ValueType]):
"""An object representing a tagged value.

ParDo, Map, and FlatMap transforms can emit values on multiple outputs which
are distinguished by string tags. The DoFn will return plain values
if it wants to emit on the main output and TaggedOutput objects
if it wants to emit a value on a specific tagged output.
"""
def __init__(self, tag: str, value: Any) -> None:
def __init__(self, tag: TagType, value: ValueType) -> None:
if not isinstance(tag, str):
raise TypeError(
'Attempting to create a TaggedOutput with non-string tag %s' %
Expand Down
24 changes: 21 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,17 @@ def with_outputs(self, *tags, main=None, allow_unknown_tags=None):
raise ValueError(
'Main output tag %r must be different from side output tags %r.' %
(main, tags))
type_hints = self.get_type_hints()
declared_tags = set(type_hints.tagged_output_types().keys())
requested_tags = set(tags)

unknown = requested_tags - declared_tags
if unknown and declared_tags: # Only warn if type hints exist
logging.warning(
"Tags %s requested in with_outputs() but not declared "
"in type hints. Declared tags: %s",
unknown,
declared_tags)
return _MultiParDo(self, tags, main, allow_unknown_tags)

def _do_fn_info(self):
Expand Down Expand Up @@ -2120,8 +2131,10 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name
wrapper)
output_hint = type_hints.simple_output_type(label)
if output_hint:
tagged_output_types = type_hints.tagged_output_types()
wrapper = with_output_types(
typehints.Iterable[_strip_output_annotations(output_hint)])(
typehints.Iterable[_strip_output_annotations(output_hint)],
**tagged_output_types)(
wrapper)
# pylint: disable=protected-access
wrapper._argspec_fn = fn
Expand Down Expand Up @@ -2189,8 +2202,10 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
tagged_output_types = type_hints.tagged_output_types()
wrapper = with_output_types(
typehints.Iterable[_strip_output_annotations(output_hint)])(
typehints.Iterable[_strip_output_annotations(output_hint)],
**tagged_output_types)(
wrapper)

# Replace the first (args) component.
Expand Down Expand Up @@ -2261,7 +2276,10 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name
pass
output_hint = type_hints.simple_output_type(label)
if output_hint:
wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper)
tagged_output_types = type_hints.tagged_output_types()
wrapper = with_output_types(
_strip_output_annotations(output_hint), **tagged_output_types)(
wrapper)

# Replace the first (args) component.
modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
Expand Down
28 changes: 22 additions & 6 deletions sdks/python/apache_beam/transforms/ptransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,12 +414,15 @@ def with_input_types(self, input_type_hint):
input_type_hint, 'Type hints for a PTransform')
return super().with_input_types(input_type_hint)

def with_output_types(self, type_hint):
def with_output_types(self, type_hint, **tagged_type_hints):
"""Annotates the output type of a :class:`PTransform` with a type-hint.

Args:
type_hint (type): An instance of an allowed built-in type, a custom class,
or a :class:`~apache_beam.typehints.typehints.TypeConstraint`.
or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is
the type hint for the main output.
**tagged_type_hints: Type hints for tagged outputs. Each keyword argument
specifies the type for a tagged output e.g., ``errors=str``.

Raises:
TypeError: If **type_hint** is not a valid type-hint. See
Expand All @@ -430,10 +433,22 @@ def with_output_types(self, type_hint):
PTransform: A reference to the instance of this particular
:class:`PTransform` object. This allows chaining type-hinting related
methods.

Example::
result = pcoll | beam.ParDo(MyDoFn()).with_output_types(
int, # main output type
errors=str, # 'errors' tagged output type
warnings=str # 'warnings' tagged output type
).with_outputs('errors', 'warnings', main='main')
"""
type_hint = native_type_compatibility.convert_to_beam_type(type_hint)
validate_composite_type_param(type_hint, 'Type hints for a PTransform')
return super().with_output_types(type_hint)
for tag, hint in tagged_type_hints.items():
tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type(
hint)
validate_composite_type_param(
tagged_type_hints[tag], f'Tagged output type hint for {tag!r}')
return super().with_output_types(type_hint, **tagged_type_hints)

def with_resource_hints(self, **kwargs): # type: (...) -> PTransform
"""Adds resource hints to the :class:`PTransform`.
Expand Down Expand Up @@ -479,10 +494,11 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output):
if hints is None or not any(hints):
return
arg_hints, kwarg_hints = hints
if arg_hints and kwarg_hints:
# Output types can have kwargs for tagged output types.
if arg_hints and kwarg_hints and input_or_output != 'output':
raise TypeCheckError(
'PTransform cannot have both positional and keyword type hints '
'without overriding %s._type_check_%s()' %
'PTransform cannot have both positional and keyword input type hints'
' without overriding %s._type_check_%s()' %
(self.__class__, input_or_output))
root_hint = (
arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
Expand Down
Loading
Loading