diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py index fa546bc..a549d60 100644 --- a/datawrapper/__init__.py +++ b/datawrapper/__init__.py @@ -29,6 +29,7 @@ LineSymbol, LineValueLabel, MultipleColumnChart, + MultipleColumnRangeAnnotation, RangeAnnotation, ScatterPlot, StackedBarChart, @@ -97,6 +98,7 @@ "AreaChart", "ArrowChart", "MultipleColumnChart", + "MultipleColumnRangeAnnotation", "ScatterPlot", "StackedBarChart", "TextAnnotation", diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py index c33d30e..4fa1293 100644 --- a/datawrapper/charts/__init__.py +++ b/datawrapper/charts/__init__.py @@ -51,7 +51,7 @@ Transform, Visualize, ) -from .multiple_column import MultipleColumnChart +from .multiple_column import MultipleColumnChart, MultipleColumnRangeAnnotation from .scatter import ScatterPlot from .stacked_bar import StackedBarChart @@ -111,6 +111,7 @@ "AreaChart", "ArrowChart", "MultipleColumnChart", + "MultipleColumnRangeAnnotation", "ScatterPlot", "StackedBarChart", ) diff --git a/datawrapper/charts/annos.py b/datawrapper/charts/annos.py index b01dd97..18c4dc0 100644 --- a/datawrapper/charts/annos.py +++ b/datawrapper/charts/annos.py @@ -255,12 +255,11 @@ def serialize_model(self) -> dict: return model @classmethod - def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict]: + def deserialize_model(cls, api_data: dict[str, dict] | None) -> list[dict]: """Deserialize annotations from API response format. Args: - api_data: Dictionary mapping UUID keys to annotation data, - or a list, or None + api_data: Dictionary mapping UUID keys to annotation data, or None Returns: List of annotation dicts with 'id' field preserved @@ -268,27 +267,22 @@ def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict if not api_data: return [] - # Handle dict format (UUID keys -> annotation data) - if isinstance(api_data, dict): - result = [] - for anno_id, anno_data in api_data.items(): - # Create a copy to avoid modifying the original - anno_dict = {**anno_data, "id": anno_id} - - # Handle connector line deserialization (enabled by presence pattern) - if "connectorLine" in anno_dict: - connector = anno_dict["connectorLine"] - if isinstance(connector, dict): - # If enabled is False or missing, set to None (disabled) - if not connector.get("enabled", False): - anno_dict["connectorLine"] = None - # Otherwise keep the connector line object (enabled) - - result.append(anno_dict) - return result - - # Handle list format (already deserialized or legacy) - return list(api_data) + result = [] + for anno_id, anno_data in api_data.items(): + # Create a copy to avoid modifying the original + anno_dict = {**anno_data, "id": anno_id} + + # Handle connector line deserialization (enabled by presence pattern) + if "connectorLine" in anno_dict: + connector = anno_dict["connectorLine"] + if isinstance(connector, dict): + # If enabled is False or missing, set to None (disabled) + if not connector.get("enabled", False): + anno_dict["connectorLine"] = None + # Otherwise keep the connector line object (enabled) + + result.append(anno_dict) + return result class AreaFill(BaseModel): @@ -396,12 +390,11 @@ def serialize_model(self) -> dict: return result @classmethod - def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict]: + def deserialize_model(cls, api_data: dict[str, dict] | None) -> list[dict]: """Deserialize area fills from API response format. Args: - api_data: Dictionary mapping UUID keys to area fill data, - or a list, or None + api_data: Dictionary mapping UUID keys to area fill data, or None Returns: List of area fill dicts with 'id' field preserved @@ -409,14 +402,7 @@ def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict if not api_data: return [] - # Handle dict format (UUID keys -> area fill data) - if isinstance(api_data, dict): - return [ - {**fill_data, "id": fill_id} for fill_id, fill_data in api_data.items() - ] - - # Handle list format (already deserialized or legacy) - return list(api_data) + return [{**fill_data, "id": fill_id} for fill_id, fill_data in api_data.items()] class RangeAnnotation(BaseModel): @@ -538,12 +524,11 @@ def serialize_model(self) -> dict: } @classmethod - def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict]: + def deserialize_model(cls, api_data: dict[str, dict] | None) -> list[dict]: """Deserialize annotations from API response format. Args: - api_data: Dictionary mapping UUID keys to annotation data, - or a list, or None + api_data: Dictionary mapping UUID keys to annotation data, or None Returns: List of annotation dicts with 'id' field preserved @@ -551,11 +536,4 @@ def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict if not api_data: return [] - # Handle dict format (UUID keys -> annotation data) - if isinstance(api_data, dict): - return [ - {**anno_data, "id": anno_id} for anno_id, anno_data in api_data.items() - ] - - # Handle list format (already deserialized or legacy) - return list(api_data) + return [{**anno_data, "id": anno_id} for anno_id, anno_data in api_data.items()] diff --git a/datawrapper/charts/multiple_column.py b/datawrapper/charts/multiple_column.py index cd6b407..e603daa 100644 --- a/datawrapper/charts/multiple_column.py +++ b/datawrapper/charts/multiple_column.py @@ -30,6 +30,89 @@ ) +class MultipleColumnRangeAnnotation(RangeAnnotation): + """Range annotation with additional fields specific to MultipleColumnChart. + + This subclass extends RangeAnnotation to support multi-panel charts where + annotations can be associated with specific plots/panels. + + Attributes: + plot: Which plot/panel this annotation applies to (e.g., "Paris", "London") + showInAllPlots: Whether to show this annotation in all plots (defaults to True) + """ + + #: Which plot/panel this annotation applies to + plot: str | None = Field( + default=None, + description="Which plot/panel this annotation applies to", + ) + + #: Whether to show this annotation in all plots + show_in_all_plots: bool = Field( + default=True, + alias="showInAllPlots", + description="Whether to show this annotation in all plots", + ) + + def serialize_model(self) -> dict: + """Serialize the annotation to API format. + + Extends the base RangeAnnotation serialization to include: + - plot field inside the position object + - showInAllPlots field at the top level + + Returns: + Dictionary in Datawrapper API format + """ + result = super().serialize_model() + + # Add plot to position object if specified + if self.plot is not None: + result["position"]["plot"] = self.plot + + # Always include showInAllPlots at top level + result["showInAllPlots"] = self.show_in_all_plots + + return result + + @classmethod + def deserialize_model(cls, api_data: dict[str, dict] | None) -> list[dict]: + """Parse API response to extract MultipleColumnRangeAnnotation data. + + Handles the API format where: + - plot is inside the position object + - showInAllPlots is at the top level + + Args: + api_data: API response data (dict with UUID keys) + + Returns: + List of dictionaries that can initialize MultipleColumnRangeAnnotation instances + """ + if not api_data: + return [] + + result = [] + for anno_id, anno_data in api_data.items(): + # Extract position data + position = anno_data.get("position", {}) + plot = position.get("plot") if isinstance(position, dict) else None + + # Extract showInAllPlots (defaults to True) + show_in_all = anno_data.get("showInAllPlots", True) + + # Build annotation dict with id + anno_dict = {**anno_data, "id": anno_id} + + # Add MultipleColumnChart-specific fields + if plot is not None: + anno_dict["plot"] = plot + anno_dict["show_in_all_plots"] = show_in_all + + result.append(anno_dict) + return result + + class MultipleColumnChart( GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart ): @@ -398,7 +481,7 @@ def serialize_model(self) -> dict: self.text_annotations, TextAnnotation ), "range-annotations": ModelListSerializer.serialize( - self.range_annotations, RangeAnnotation + self.range_annotations, MultipleColumnRangeAnnotation ), } @@ -551,8 +634,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: init_data["text_annotations"] = TextAnnotation.deserialize_model( visualize.get("text-annotations") ) - init_data["range_annotations"] = RangeAnnotation.deserialize_model( - visualize.get("range-annotations") + init_data["range_annotations"] = ( + MultipleColumnRangeAnnotation.deserialize_model( + visualize.get("range-annotations") + ) ) return init_data diff --git a/docs/user-guide/charts/multiple-column-charts.md b/docs/user-guide/charts/multiple-column-charts.md index 1877f34..bc46da9 100644 --- a/docs/user-guide/charts/multiple-column-charts.md +++ b/docs/user-guide/charts/multiple-column-charts.md @@ -84,3 +84,8 @@ chart.create() ```{eval-rst} .. parameter-table:: datawrapper.charts.MultipleColumnChart +``` + +```{eval-rst} +.. parameter-table:: datawrapper.charts.MultipleColumnRangeAnnotation +```