Skip to content

Commit 1d1d7c6

Browse files
author
iscai-msft
committed
update unions serializer to get around pyright issue
1 parent cdcedf5 commit 1d1d7c6

2 files changed

Lines changed: 13 additions & 7 deletions

File tree

packages/http-client-python/generator/pygen/codegen/models/model_type.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,9 @@ def __init__(
7777
self.cross_language_definition_id: Optional[str] = self.yaml_data.get("crossLanguageDefinitionId")
7878
self.usage: int = self.yaml_data.get("usage", UsageFlags.Input.value | UsageFlags.Output.value)
7979
self.client_namespace: str = self.yaml_data.get("clientNamespace", code_model.namespace)
80-
self.is_typed_dict_only: bool = (
81-
self.yaml_data.get("typedDictOnly", False)
82-
or self.name in code_model.options.get("typed-dict-only-models", [])
83-
)
80+
self.is_typed_dict_only: bool = self.yaml_data.get(
81+
"typedDictOnly", False
82+
) or self.name in code_model.options.get("typed-dict-only-models", [])
8483

8584
@property
8685
def is_usage_output(self) -> bool:
@@ -396,8 +395,7 @@ def imports(self, **kwargs: Any) -> FileImport:
396395
ImportType.LOCAL,
397396
)
398397
elif serialize_namespace_type in [NamespaceType.TYPES_FILE, NamespaceType.UNIONS_FILE] or (
399-
serialize_namespace_type == NamespaceType.MODEL
400-
and kwargs.get("called_by_property", False)
398+
serialize_namespace_type == NamespaceType.MODEL and kwargs.get("called_by_property", False)
401399
):
402400
file_import.add_submodule_import(
403401
relative_path,

packages/http-client-python/generator/pygen/codegen/serializers/unions_serializer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,19 @@ def imports(self) -> FileImport:
4343
serialize_namespace_type=NamespaceType.UNIONS_FILE,
4444
)
4545
)
46+
for model in self.discriminated_base_models:
47+
for subtype in model.discriminated_subtypes.values():
48+
file_import.merge(
49+
subtype.imports(
50+
serialize_namespace=self.serialize_namespace,
51+
serialize_namespace_type=NamespaceType.UNIONS_FILE,
52+
)
53+
)
4654
return file_import
4755

4856
def discriminated_subtypes_union(self, model: ModelType) -> str:
4957
subtypes = list(model.discriminated_subtypes.values())
50-
subtype_names = [s.name for s in subtypes]
58+
subtype_names = [s.type_annotation(skip_quote=True) for s in subtypes]
5159
return f"{model.name} = Union[{', '.join(subtype_names)}]"
5260

5361
def serialize(self) -> str:

0 commit comments

Comments
 (0)