diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py
index a384edae..fa546bca 100644
--- a/datawrapper/__init__.py
+++ b/datawrapper/__init__.py
@@ -63,6 +63,12 @@
ValueLabelMode,
ValueLabelPlacement,
)
+from datawrapper.charts.mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from datawrapper.exceptions import (
FailedRequestError,
InvalidRequestError,
@@ -123,6 +129,10 @@
"ValueLabelMode",
"ValueLabelPlacement",
"get_country_flag",
+ "CustomRangeMixin",
+ "CustomTicksMixin",
+ "GridFormatMixin",
+ "GridDisplayMixin",
"FailedRequestError",
"InvalidRequestError",
"RateLimitError",
diff --git a/datawrapper/__main__.py b/datawrapper/__main__.py
index 1e67728b..901af9f5 100644
--- a/datawrapper/__main__.py
+++ b/datawrapper/__main__.py
@@ -1156,6 +1156,10 @@ def export_chart(
) -> Path | Image:
"""Exports a chart, table, or map.
+ .. deprecated::
+ Use the object-oriented chart classes instead (e.g., BarChart, LineChart).
+ This method will be removed in a future version.
+
Parameters
----------
chart_id : str
@@ -1210,6 +1214,14 @@ def export_chart(
Path | Image
The file path to the exported image or an Image object displaying the image.
"""
+ warnings.warn(
+ "export_chart() is deprecated and will be removed in a future version. "
+ "Use the object-oriented chart classes instead. "
+ "Example: chart = BarChart.get(chart_id='abc123'); png_data = chart.export_png(); Path('chart.png').write_bytes(png_data)",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
_query = {
"unit": unit,
"mode": mode,
diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py
index 449fd282..c33d30eb 100644
--- a/datawrapper/charts/__init__.py
+++ b/datawrapper/charts/__init__.py
@@ -33,6 +33,12 @@
ValueLabelPlacement,
)
from .line import AreaFill, Line, LineChart, LineSymbol, LineValueLabel
+from .mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from .models import (
Annotate,
ColumnFormat,
@@ -56,6 +62,10 @@
"Annotate",
"ColumnFormat",
"ColumnFormatList",
+ "CustomRangeMixin",
+ "CustomTicksMixin",
+ "GridFormatMixin",
+ "GridDisplayMixin",
"ArrowHead",
"ConnectorLineType",
"DateFormat",
diff --git a/datawrapper/charts/area.py b/datawrapper/charts/area.py
index f70920f0..732f46b0 100644
--- a/datawrapper/charts/area.py
+++ b/datawrapper/charts/area.py
@@ -7,23 +7,28 @@
from .base import BaseChart
from .enums import (
DateFormat,
- GridDisplay,
GridLabelAlign,
GridLabelPosition,
LineInterpolation,
NumberFormat,
PlotHeightMode,
)
+from .mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from .serializers import (
ColorCategory,
- CustomRange,
- CustomTicks,
ModelListSerializer,
PlotHeight,
)
-class AreaChart(BaseChart):
+class AreaChart(
+ GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart
+):
"""A base class for the Datawrapper API's area chart."""
model_config = ConfigDict(
@@ -59,70 +64,6 @@ class AreaChart(BaseChart):
description="The type of datawrapper chart to create",
)
- #
- # Horizontal axis (X-axis)
- #
-
- #: The custom range for the x axis
- custom_range_x: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-x",
- description="The custom range for the x axis",
- )
-
- #: The custom ticks for the x axis
- custom_ticks_x: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-x",
- description="The custom ticks for the x axis",
- )
-
- #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings)
- x_grid_format: DateFormat | NumberFormat | str = Field(
- default="auto",
- alias="x-grid-format",
- description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the x grid
- x_grid: GridDisplay | str = Field(
- default="off",
- alias="x-grid",
- description="Whether to show the x grid. The 'on' setting shows lines.",
- )
-
- #
- # Vertical axis (Y-axis)
- #
-
- #: The custom range for the y axis
- custom_range_y: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-y",
- description="The custom range for the y axis",
- )
-
- #: The custom ticks for the y axis
- custom_ticks_y: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-y",
- description="The custom ticks for the y axis",
- )
-
- #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings)
- y_grid_format: DateFormat | NumberFormat | str = Field(
- default="",
- alias="y-grid-format",
- description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the y grid
- y_grid: GridDisplay | str = Field(
- default="on",
- alias="y-grid",
- description="Whether to show the y grid. The 'on' setting shows lines.",
- )
-
#: The labeling of the y grid labels
y_grid_labels: GridLabelPosition | str = Field(
default="auto",
@@ -313,52 +254,48 @@ def serialize_model(self) -> dict:
model = super().serialize_model()
# Add chart specific properties
- model["metadata"]["visualize"].update(
- {
- # Horizontal axis
- "custom-range-x": CustomRange.serialize(self.custom_range_x),
- "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x),
- "x-grid-format": self.x_grid_format,
- "x-grid": self.x_grid,
- # Vertical axis
- "custom-range-y": CustomRange.serialize(self.custom_range_y),
- "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y),
- "y-grid-format": self.y_grid_format,
- "y-grid": self.y_grid,
- "y-grid-labels": self.y_grid_labels,
- "y-grid-label-align": self.y_grid_label_align,
- # Customize areas
- "area-opacity": self.area_opacity,
- "base-color": self.base_color,
- "interpolation": self.interpolation,
- "sort-areas": self.sort_areas,
- "stack-areas": self.stack_areas,
- "stack-to-100": self.stack_to_100,
- "area-separator-lines": self.area_separator_lines,
- "area-separator-color": self.area_separator_color,
- # Customize specific layers
- "color-category": ColorCategory.serialize(self.color_category),
- # Labels
- "show-color-key": self.show_color_key,
- # Tooltips
- "show-tooltips": self.show_tooltips,
- "tooltip-x-format": self.tooltip_x_format,
- "tooltip-number-format": self.tooltip_number_format,
- # Appearance
- **PlotHeight.serialize(
- self.plot_height_mode,
- self.plot_height_fixed,
- self.plot_height_ratio,
- ),
- # Annotations
- "text-annotations": ModelListSerializer.serialize(
- self.text_annotations, TextAnnotation
- ),
- "range-annotations": ModelListSerializer.serialize(
- self.range_annotations, RangeAnnotation
- ),
- }
- )
+ visualize_data = {
+ # Horizontal and vertical axis (from mixins)
+ **self._serialize_grid_config(),
+ **self._serialize_grid_format(),
+ **self._serialize_custom_range(),
+ **self._serialize_custom_ticks(),
+ # Vertical axis (chart-specific)
+ "y-grid-labels": self.y_grid_labels,
+ "y-grid-label-align": self.y_grid_label_align,
+ # Customize areas
+ "area-opacity": self.area_opacity,
+ "base-color": self.base_color,
+ "interpolation": self.interpolation,
+ "sort-areas": self.sort_areas,
+ "stack-areas": self.stack_areas,
+ "stack-to-100": self.stack_to_100,
+ "area-separator-lines": self.area_separator_lines,
+ "area-separator-color": self.area_separator_color,
+ # Customize specific layers
+ "color-category": ColorCategory.serialize(self.color_category),
+ # Labels
+ "show-color-key": self.show_color_key,
+ # Tooltips
+ "show-tooltips": self.show_tooltips,
+ "tooltip-x-format": self.tooltip_x_format,
+ "tooltip-number-format": self.tooltip_number_format,
+ # Appearance
+ **PlotHeight.serialize(
+ self.plot_height_mode,
+ self.plot_height_fixed,
+ self.plot_height_ratio,
+ ),
+ # Annotations
+ "text-annotations": ModelListSerializer.serialize(
+ self.text_annotations, TextAnnotation
+ ),
+ "range-annotations": ModelListSerializer.serialize(
+ self.range_annotations, RangeAnnotation
+ ),
+ }
+
+ model["metadata"]["visualize"].update(visualize_data)
# Return the serialized data
return model
@@ -380,30 +317,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
metadata = api_response.get("metadata", {})
visualize = metadata.get("visualize", {})
- # Horizontal axis (X-axis)
- init_data["custom_range_x"] = CustomRange.deserialize(
- visualize.get("custom-range-x")
- )
- init_data["custom_ticks_x"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-x", "")
- )
- if "x-grid-format" in visualize:
- init_data["x_grid_format"] = visualize["x-grid-format"]
- if "x-grid" in visualize:
- init_data["x_grid"] = visualize["x-grid"]
-
- # Vertical axis (Y-axis)
- init_data["custom_range_y"] = CustomRange.deserialize(
- visualize.get("custom-range-y")
- )
- init_data["custom_ticks_y"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-y", "")
- )
+ # Horizontal and vertical axis (from mixins)
+ init_data.update(cls._deserialize_grid_config(visualize))
+ init_data.update(cls._deserialize_grid_format(visualize))
+ init_data.update(cls._deserialize_custom_range(visualize))
+ init_data.update(cls._deserialize_custom_ticks(visualize))
- if "y-grid-format" in visualize:
- init_data["y_grid_format"] = visualize["y-grid-format"]
- if "y-grid" in visualize:
- init_data["y_grid"] = visualize["y-grid"]
+ # Vertical axis (chart-specific)
if "y-grid-labels" in visualize:
init_data["y_grid_labels"] = visualize["y-grid-labels"]
if "y-grid-label-align" in visualize:
diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py
index 84741233..184fe4ca 100644
--- a/datawrapper/charts/base.py
+++ b/datawrapper/charts/base.py
@@ -1,11 +1,10 @@
import os
import warnings
from io import StringIO
-from pathlib import Path
from typing import Any, Literal
import pandas as pd
-from IPython.display import IFrame, Image
+from IPython.display import IFrame
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
from datawrapper.__main__ import Datawrapper
@@ -733,99 +732,219 @@ def publish(
# Return self for chaining
return self
- def export(
+ def export_png(
self,
- unit: str = "px",
- mode: str = "rgb",
- width: int = 400,
- height: int | str | None = None,
+ *,
+ width: int | None = None,
+ height: int | None = None,
plain: bool = False,
zoom: int = 2,
+ transparent: bool = False,
+ border_width: int = 0,
+ border_color: str | None = None,
+ access_token: str | None = None,
+ timeout: int = 30,
+ ) -> bytes:
+ """Export chart as PNG and return the raw bytes.
+
+ Args:
+ width: Width of visualization in pixels. If not specified, uses chart width.
+ height: Height of visualization in pixels. If not specified, uses chart height.
+ plain: If True, exports only the visualization without header/footer.
+ zoom: Scale multiplier for PNG resolution (e.g., 2 = 2x resolution).
+ transparent: If True, exports with transparent background.
+ border_width: Margin around visualization in pixels.
+ border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color.
+ access_token: Optional Datawrapper API access token.
+ timeout: Timeout for the API request in seconds.
+
+ Returns:
+ Raw PNG image data as bytes.
+
+ Raises:
+ ValueError: If no chart_id is set.
+ Exception: If the API request fails.
+
+ Example:
+ >>> chart = LineChart.get(chart_id="abc123")
+ >>> png_data = chart.export_png(zoom=3, transparent=True)
+ >>> Path("chart.png").write_bytes(png_data)
+ """
+ if not self.chart_id:
+ raise ValueError(
+ "No chart_id set. Use create() first or set chart_id manually."
+ )
+
+ client = self._get_client(access_token)
+
+ # Build query parameters with PNG-specific defaults
+ params = {
+ "unit": "px",
+ "mode": "rgb",
+ "plain": str(plain).lower(),
+ "zoom": str(zoom),
+ "transparent": str(transparent).lower(),
+ "borderWidth": str(border_width),
+ }
+
+ if width is not None:
+ params["width"] = str(width)
+ if height is not None:
+ params["height"] = str(height)
+ if border_color is not None:
+ params["borderColor"] = border_color
+
+ # Make the API request
+ response = client.get(
+ f"{client._CHARTS_URL}/{self.chart_id}/export/png",
+ params=params,
+ timeout=timeout,
+ )
+
+ # Return raw bytes
+ if isinstance(response, bytes):
+ return response
+ raise ValueError(f"Unexpected response type from API: {type(response)}")
+
+ def export_pdf(
+ self,
+ *,
+ width: int | None = None,
+ height: int | None = None,
+ plain: bool = False,
+ unit: Literal["px", "mm", "inch"] = "px",
+ mode: Literal["rgb", "cmyk"] = "rgb",
scale: int = 1,
- border_width: int = 20,
+ border_width: int = 0,
border_color: str | None = None,
- transparent: bool = False,
- download: bool = False,
- full_vector: bool = False,
- ligatures: bool = True,
- logo: str = "auto",
- logo_id: str | None = None,
- dark: bool = False,
- output: str = "png",
- filepath: str = "./image.png",
- display: bool = False,
access_token: str | None = None,
- ) -> Path | Image:
- """Export the chart to an image file.
+ timeout: int = 30,
+ ) -> bytes:
+ """Export chart as PDF and return the raw bytes.
Args:
- unit: One of px, mm, inch. Defines the unit in which the borderwidth, height,
- and width will be measured in, by default "px"
- mode: One of rgb or cmyk. Which color mode the output should be in,
- by default "rgb"
- width: Width of visualization. If not specified, it takes the chart width,
- by default 400
- height: Height of visualization. Can be a number or "auto", by default None
- plain: Defines if only the visualization should be exported (True), or if it should
- include header and footer as well (False), by default False
- zoom: Defines the multiplier for the png size, by default 2
- scale: Defines the multiplier for the pdf size, by default 1
- border_width: Margin around the visualization, by default 20
- border_color: Color of the border around the visualization, by default None
- transparent: Set to True to export your visualization with a transparent background,
- by default False
- download: Whether to trigger a download, by default False
- full_vector: Export as full vector graphic (for supported formats), by default False
- ligatures: Enable typographic ligatures, by default True
- logo: Logo display setting. One of "auto", "on", or "off", by default "auto"
- logo_id: Custom logo ID to use, by default None
- dark: Export in dark mode, by default False
- output: One of png, pdf, or svg, by default "png"
- filepath: Name/filepath to save output in, by default "./image.png"
- display: Whether to display the exported image as output in the notebook cell,
- by default False
+ width: Width of visualization. If not specified, uses chart width.
+ height: Height of visualization. If not specified, uses chart height.
+ plain: If True, exports only the visualization without header/footer.
+ unit: Unit for measurements: "px", "mm", or "inch".
+ mode: Color mode: "rgb" or "cmyk".
+ scale: Scale multiplier for PDF resolution.
+ border_width: Margin around visualization.
+ border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color.
access_token: Optional Datawrapper API access token.
- If not provided, will use DATAWRAPPER_ACCESS_TOKEN environment variable.
+ timeout: Timeout for the API request in seconds.
Returns:
- The file path to the exported image or an Image object displaying the image.
+ Raw PDF document data as bytes.
Raises:
- ValueError: If no chart_id is set or no access token is available.
+ ValueError: If no chart_id is set or invalid parameters provided.
Exception: If the API request fails.
+
+ Example:
+ >>> chart = BarChart.get(chart_id="abc123")
+ >>> pdf_data = chart.export_pdf(unit="mm", mode="cmyk")
+ >>> Path("chart.pdf").write_bytes(pdf_data)
"""
if not self.chart_id:
raise ValueError(
"No chart_id set. Use create() first or set chart_id manually."
)
- # Get the client
+ # Validate parameters
+ if unit not in ("px", "mm", "inch"):
+ raise ValueError(f"Invalid unit: {unit}. Must be 'px', 'mm', or 'inch'.")
+ if mode not in ("rgb", "cmyk"):
+ raise ValueError(f"Invalid mode: {mode}. Must be 'rgb' or 'cmyk'.")
+
client = self._get_client(access_token)
- # Call the export_chart method from the client
- return client.export_chart(
- chart_id=self.chart_id,
- unit=unit,
- mode=mode,
- width=width,
- height=height,
- plain=plain,
- zoom=zoom,
- scale=scale,
- border_width=border_width,
- border_color=border_color,
- transparent=transparent,
- download=download,
- full_vector=full_vector,
- ligatures=ligatures,
- logo=logo,
- logo_id=logo_id,
- dark=dark,
- output=output,
- filepath=filepath,
- display=display,
+ # Build query parameters
+ params = {
+ "unit": unit,
+ "mode": mode,
+ "plain": str(plain).lower(),
+ "scale": str(scale),
+ "borderWidth": str(border_width),
+ }
+
+ if width is not None:
+ params["width"] = str(width)
+ if height is not None:
+ params["height"] = str(height)
+ if border_color is not None:
+ params["borderColor"] = border_color
+
+ # Make the API request
+ response = client.get(
+ f"{client._CHARTS_URL}/{self.chart_id}/export/pdf",
+ params=params,
+ )
+
+ # Return raw bytes
+ if isinstance(response, bytes):
+ return response
+ raise ValueError(f"Unexpected response type from API: {type(response)}")
+
+ def export_svg(
+ self,
+ *,
+ width: int | None = None,
+ height: int | None = None,
+ plain: bool = False,
+ access_token: str | None = None,
+ timeout: int = 30,
+ ) -> bytes:
+ """Export chart as SVG and return the raw bytes.
+
+ Args:
+ width: Width of visualization. If not specified, uses chart width.
+ height: Height of visualization. If not specified, uses chart height.
+ plain: If True, exports only the visualization without header/footer.
+ access_token: Optional Datawrapper API access token.
+ timeout: Timeout for the API request in seconds.
+
+ Returns:
+ Raw SVG document data as bytes.
+
+ Raises:
+ ValueError: If no chart_id is set.
+ Exception: If the API request fails.
+
+ Example:
+ >>> chart = ColumnChart.get(chart_id="abc123")
+ >>> svg_data = chart.export_svg(plain=True)
+ >>> Path("chart.svg").write_bytes(svg_data)
+ """
+ if not self.chart_id:
+ raise ValueError(
+ "No chart_id set. Use create() first or set chart_id manually."
+ )
+
+ client = self._get_client(access_token)
+
+ # Build query parameters
+ params = {
+ "plain": str(plain).lower(),
+ }
+
+ if width is not None:
+ params["width"] = str(width)
+ if height is not None:
+ params["height"] = str(height)
+
+ # Make the API request
+ response = client.get(
+ f"{client._CHARTS_URL}/{self.chart_id}/export/svg",
+ params=params,
+ timeout=timeout,
)
+ # Return raw bytes
+ if isinstance(response, bytes):
+ return response
+ raise ValueError(f"Unexpected response type from API: {type(response)}")
+
def delete(self, access_token: str | None = None) -> bool:
"""Delete the chart via the Datawrapper API.
diff --git a/datawrapper/charts/column.py b/datawrapper/charts/column.py
index 3c97d3ac..81e7b0e5 100644
--- a/datawrapper/charts/column.py
+++ b/datawrapper/charts/column.py
@@ -7,7 +7,6 @@
from .base import BaseChart
from .enums import (
DateFormat,
- GridDisplay,
GridLabelAlign,
GridLabelPosition,
NumberFormat,
@@ -15,10 +14,14 @@
ValueLabelDisplay,
ValueLabelPlacement,
)
+from .mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from .serializers import (
ColorCategory,
- CustomRange,
- CustomTicks,
ModelListSerializer,
NegativeColor,
PlotHeight,
@@ -26,7 +29,9 @@
)
-class ColumnChart(BaseChart):
+class ColumnChart(
+ GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart
+):
"""A base class for the Datawrapper API's column chart."""
model_config = ConfigDict(
@@ -62,69 +67,9 @@ class ColumnChart(BaseChart):
)
#
- # Horizontal axis (X-axis)
- #
-
- #: The custom range for the x axis
- custom_range_x: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-x",
- description="The custom range for the x axis",
- )
-
- #: The custom ticks for the x axis
- custom_ticks_x: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-x",
- description="The custom ticks for the x axis",
- )
-
- #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings)
- x_grid_format: DateFormat | NumberFormat | str = Field(
- default="auto",
- alias="x-grid-format",
- description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the x grid
- x_grid: GridDisplay | str = Field(
- default="off",
- alias="x-grid",
- description="Whether to show the x grid",
- )
-
- #
- # Vertical axis (Y-axis)
+ # Vertical axis (Y-axis) - chart-specific fields
#
- #: The custom range for the y axis
- custom_range_y: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-y",
- description="The custom range for the y axis",
- )
-
- #: The custom ticks for the y axis
- custom_ticks_y: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-y",
- description="The custom ticks for the y axis",
- )
-
- #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings)
- y_grid_format: DateFormat | NumberFormat | str = Field(
- default="",
- alias="y-grid-format",
- description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the y grid lines
- y_grid: bool = Field(
- default=True,
- alias="y-grid",
- description="Whether to show the y grid lines",
- )
-
#: The labeling of the y grid labels
y_grid_labels: GridLabelPosition | str = Field(
default="outside",
@@ -266,6 +211,114 @@ def validate_plot_height_mode(cls, v: PlotHeightMode | str) -> PlotHeightMode |
raise ValueError(f"Invalid value: {v}. Must be one of {valid_values}")
return v
+ @classmethod
+ def _deserialize_grid_config(cls, visualize: dict) -> dict:
+ """Override to handle ColumnChart-specific grid fields.
+
+ ColumnChart uses different API fields than other charts:
+ - x_grid: Parsed from 'grid-lines-x' dict (not 'x-grid' string)
+ - y_grid: Parsed from 'grid-lines' boolean (not 'y-grid' string)
+ """
+ result = {}
+
+ # Parse grid-lines-x (dict with type/enabled)
+ if "grid-lines-x" in visualize:
+ grid_lines_x = visualize["grid-lines-x"]
+ if isinstance(grid_lines_x, dict):
+ enabled = grid_lines_x.get("enabled", False)
+ grid_type = grid_lines_x.get("type", "")
+ result["x_grid"] = grid_type if enabled else "off"
+
+ # Parse grid-lines (boolean)
+ if "grid-lines" in visualize:
+ result["y_grid"] = visualize["grid-lines"]
+
+ return result
+
+ def _serialize_grid_config(self) -> dict:
+ """Override to add ColumnChart-specific grid-lines field.
+
+ ColumnChart uses both the standard y-grid field (from mixin) and an
+ additional grid-lines boolean field that mirrors the y-grid on/off state.
+ """
+ # Get the standard grid config from the mixin
+ result = super()._serialize_grid_config()
+
+ # Add the ColumnChart-specific grid-lines boolean field
+ # This mirrors the y_grid on/off state
+ if self.y_grid is not None:
+ from .enums import GridDisplay
+
+ # Convert to boolean: "on" or True -> True, "off" or False -> False
+ if isinstance(self.y_grid, GridDisplay):
+ result["grid-lines"] = self.y_grid == GridDisplay.ON
+ elif isinstance(self.y_grid, bool):
+ result["grid-lines"] = self.y_grid
+ elif isinstance(self.y_grid, str):
+ result["grid-lines"] = self.y_grid.lower() == "on"
+
+ return result
+
+ def _serialize_custom_range(self) -> dict:
+ """Override to handle ColumnChart-specific field naming.
+
+ ColumnChart uses 'custom-range' (not 'custom-range-y') for Y-axis custom range.
+ """
+ # Get the standard custom range config from the mixin
+ result = super()._serialize_custom_range()
+
+ # Rename custom-range-y to custom-range for ColumnChart
+ if "custom-range-y" in result:
+ result["custom-range"] = result.pop("custom-range-y")
+
+ return result
+
+ @classmethod
+ def _deserialize_custom_range(cls, visualize: dict) -> dict:
+ """Override to handle ColumnChart-specific field naming.
+
+ ColumnChart uses 'custom-range' (not 'custom-range-y') for Y-axis custom range.
+ """
+ # Create a modified visualize dict with renamed field
+ modified_visualize = visualize.copy()
+ if "custom-range" in modified_visualize:
+ modified_visualize["custom-range-y"] = modified_visualize.pop(
+ "custom-range"
+ )
+
+ # Call the parent deserializer with the modified dict
+ return super()._deserialize_custom_range(modified_visualize)
+
+ def _serialize_custom_ticks(self) -> dict:
+ """Override to handle ColumnChart-specific field naming.
+
+ ColumnChart uses 'custom-ticks' (not 'custom-ticks-y') for Y-axis custom ticks.
+ """
+ # Get the standard custom ticks config from the mixin
+ result = super()._serialize_custom_ticks()
+
+ # Rename custom-ticks-y to custom-ticks for ColumnChart
+ if "custom-ticks-y" in result:
+ result["custom-ticks"] = result.pop("custom-ticks-y")
+
+ return result
+
+ @classmethod
+ def _deserialize_custom_ticks(cls, visualize: dict) -> dict:
+ """Override to handle ColumnChart-specific field naming.
+
+ ColumnChart uses 'custom-ticks' (not 'custom-ticks-y') for Y-axis custom ticks.
+ """
+ # Create a modified visualize dict with renamed field
+ modified_visualize = visualize.copy()
+ if "custom-ticks" in modified_visualize:
+ modified_visualize["custom-ticks-y"] = modified_visualize.pop(
+ "custom-ticks"
+ )
+
+ # Call the parent deserializer with the modified dict
+ return super()._deserialize_custom_ticks(modified_visualize)
+
@model_serializer
def serialize_model(self) -> dict:
"""Serialize the model to a dictionary."""
@@ -273,60 +326,51 @@ def serialize_model(self) -> dict:
model = super().serialize_model()
# Add chart specific properties
- model["metadata"]["visualize"].update(
- {
- # Horizontal axis
- "custom-range-x": CustomRange.serialize(self.custom_range_x),
- "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x),
- "x-grid-format": self.x_grid_format,
- "grid-lines-x": {
- "type": "" if self.x_grid == "off" else self.x_grid,
- "enabled": self.x_grid != "off",
- },
- # Vertical axis
- "custom-range": CustomRange.serialize(self.custom_range_y),
- "custom-ticks": CustomTicks.serialize(self.custom_ticks_y),
- "y-grid-format": self.y_grid_format,
- "grid-lines": self.y_grid,
- "yAxisLabels": {
- "enabled": self.y_grid_labels != "off",
- "alignment": self.y_grid_label_align,
- "placement": ""
- if self.y_grid_labels == "off"
- else self.y_grid_labels,
- },
- # Appearance
- "base-color": self.base_color,
- "negativeColor": NegativeColor.serialize(self.negative_color),
- "bar-padding": self.bar_padding,
- "color-category": ColorCategory.serialize(
- self.color_category,
- self.category_labels,
- self.category_order,
- ),
- "color-by-column": bool(self.color_category),
- **PlotHeight.serialize(
- self.plot_height_mode,
- self.plot_height_fixed,
- self.plot_height_ratio,
- ),
- # Labels
- "show-color-key": self.show_color_key,
- **ValueLabels.serialize(
- show=self.show_value_labels,
- format_str=self.value_labels_format,
- placement=self.value_labels_placement,
- chart_type="column",
- ),
- # Annotations
- "text-annotations": ModelListSerializer.serialize(
- self.text_annotations, TextAnnotation
- ),
- "range-annotations": ModelListSerializer.serialize(
- self.range_annotations, RangeAnnotation
- ),
- }
- )
+ visualize_data = {
+ # Horizontal and vertical axis (from mixins)
+ **self._serialize_grid_config(),
+ **self._serialize_grid_format(),
+ **self._serialize_custom_range(),
+ **self._serialize_custom_ticks(),
+ # Vertical axis (chart-specific)
+ "yAxisLabels": {
+ "enabled": self.y_grid_labels != "off",
+ "alignment": self.y_grid_label_align,
+ "placement": "" if self.y_grid_labels == "off" else self.y_grid_labels,
+ },
+ # Appearance
+ "base-color": self.base_color,
+ "negativeColor": NegativeColor.serialize(self.negative_color),
+ "bar-padding": self.bar_padding,
+ "color-category": ColorCategory.serialize(
+ self.color_category,
+ self.category_labels,
+ self.category_order,
+ ),
+ "color-by-column": bool(self.color_category),
+ **PlotHeight.serialize(
+ self.plot_height_mode,
+ self.plot_height_fixed,
+ self.plot_height_ratio,
+ ),
+ # Labels
+ "show-color-key": self.show_color_key,
+ **ValueLabels.serialize(
+ show=self.show_value_labels,
+ format_str=self.value_labels_format,
+ placement=self.value_labels_placement,
+ chart_type="column",
+ ),
+ # Annotations
+ "text-annotations": ModelListSerializer.serialize(
+ self.text_annotations, TextAnnotation
+ ),
+ "range-annotations": ModelListSerializer.serialize(
+ self.range_annotations, RangeAnnotation
+ ),
+ }
+
+ model["metadata"]["visualize"].update(visualize_data)
# Return the serialized data
return model
@@ -349,37 +393,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
metadata = api_response.get("metadata", {})
visualize = metadata.get("visualize", {})
- # Horizontal axis (X-axis)
- init_data["custom_range_x"] = CustomRange.deserialize(
- visualize.get("custom-range-x")
- )
- init_data["custom_ticks_x"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-x", "")
- )
- if "x-grid-format" in visualize:
- init_data["x_grid_format"] = visualize["x-grid-format"]
-
- # Parse grid-lines-x
- if "grid-lines-x" in visualize:
- grid_lines_x = visualize["grid-lines-x"]
- if isinstance(grid_lines_x, dict):
- enabled = grid_lines_x.get("enabled", False)
- grid_type = grid_lines_x.get("type", "")
- init_data["x_grid"] = grid_type if enabled else "off"
-
- # Vertical axis (Y-axis)
- init_data["custom_range_y"] = CustomRange.deserialize(
- visualize.get("custom-range")
- )
- init_data["custom_ticks_y"] = CustomTicks.deserialize(
- visualize.get("custom-ticks", "")
- )
- if "y-grid-format" in visualize:
- init_data["y_grid_format"] = visualize["y-grid-format"]
- if "grid-lines" in visualize:
- init_data["y_grid"] = visualize["grid-lines"]
+ # Horizontal and vertical axis (from mixins)
+ init_data.update(cls._deserialize_grid_config(visualize))
+ init_data.update(cls._deserialize_grid_format(visualize))
+ init_data.update(cls._deserialize_custom_range(visualize))
+ init_data.update(cls._deserialize_custom_ticks(visualize))
- # Parse yAxisLabels
+ # Vertical axis (chart-specific) - Parse yAxisLabels
if "yAxisLabels" in visualize:
y_axis_labels = visualize["yAxisLabels"]
if isinstance(y_axis_labels, dict):
diff --git a/datawrapper/charts/line.py b/datawrapper/charts/line.py
index 0fc467c1..ee6103d0 100644
--- a/datawrapper/charts/line.py
+++ b/datawrapper/charts/line.py
@@ -13,7 +13,6 @@
from .base import BaseChart
from .enums import (
DateFormat,
- GridDisplay,
GridLabelAlign,
GridLabelPosition,
LineDash,
@@ -25,10 +24,14 @@
SymbolShape,
SymbolStyle,
)
+from .mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from .serializers import (
ColorCategory,
- CustomRange,
- CustomTicks,
ModelListSerializer,
PlotHeight,
)
@@ -391,7 +394,9 @@ def deserialize_model(cls, line_name: str, line_config: dict) -> dict[str, Any]:
return init_dict
-class LineChart(BaseChart):
+class LineChart(
+ GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart
+):
"""A base class for the Datawrapper API's line chart."""
model_config = ConfigDict(
@@ -429,66 +434,12 @@ class LineChart(BaseChart):
#
# Horizontal axis (X-axis)
#
-
- #: The custom range for the x axis
- custom_range_x: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-x",
- description="The custom range for the x axis",
- )
-
- #: The custom ticks for the x axis
- custom_ticks_x: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-x",
- description="The custom ticks for the x axis",
- )
-
- #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings)
- x_grid_format: DateFormat | NumberFormat | str = Field(
- default="auto",
- alias="x-grid-format",
- description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the x grid
- x_grid: GridDisplay | str = Field(
- default="off",
- alias="x-grid",
- description="Whether to show the x grid. The 'on' setting shows lines.",
- )
+ # Note: x_grid, x_grid_format, custom_range_x, custom_ticks_x inherited from mixins
#
# Vertical axis (Y-axis)
#
-
- #: The custom range for the y axis
- custom_range_y: list[Any] | tuple[Any, Any] = Field(
- default_factory=lambda: ["", ""],
- alias="custom-range-y",
- description="The custom range for the y axis",
- )
-
- #: The custom ticks for the y axis
- custom_ticks_y: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-y",
- description="The custom ticks for the y axis",
- )
-
- #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings)
- y_grid_format: DateFormat | NumberFormat | str = Field(
- default="",
- alias="y-grid-format",
- description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the y grid
- y_grid: GridDisplay | str = Field(
- default="on",
- alias="y-grid",
- description="Whether to show the y grid. The 'on' setting shows lines.",
- )
+ # Note: y_grid, y_grid_format, custom_range_y, custom_ticks_y inherited from mixins
#: The labeling of the y grid labels
y_grid_labels: GridLabelPosition | str = Field(
@@ -700,56 +651,52 @@ def serialize_model(self) -> dict:
model = super().serialize_model()
# Add chart specific properties
- model["metadata"]["visualize"].update(
- {
- # Horizontal axis
- "custom-range-x": CustomRange.serialize(self.custom_range_x),
- "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x),
- "x-grid-format": self.x_grid_format,
- "x-grid": self.x_grid,
- # Vertical axis
- "custom-range-y": CustomRange.serialize(self.custom_range_y),
- "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y),
- "y-grid-format": self.y_grid_format,
- "y-grid": self.y_grid,
- "y-grid-labels": self.y_grid_labels,
- "y-grid-label-align": self.y_grid_label_align,
- "scale-y": self.scale_y,
- "y-grid-subdivide": self.y_grid_subdivide,
- # Customize lines
- "base-color": self.base_color,
- "interpolation": self.interpolation,
- "connector-lines": self.connector_lines,
- "color-category": ColorCategory.serialize(self.color_category),
- # Labels
- "stack-color-legend": self.stack_color_legend,
- "label-colors": self.label_colors,
- "label-margin": self.label_margin,
- "value-labels-format": self.value_labels_format,
- "value-label-colors": self.value_label_colors,
- # Tooltips
- "show-tooltips": self.show_tooltips,
- "tooltip-x-format": self.tooltip_x_format,
- "tooltip-number-format": self.tooltip_number_format,
- # Appearance
- **PlotHeight.serialize(
- self.plot_height_mode,
- self.plot_height_fixed,
- self.plot_height_ratio,
- ),
- # Initialize empty structures
- "lines": {},
- "text-annotations": ModelListSerializer.serialize(
- self.text_annotations, TextAnnotation
- ),
- "range-annotations": ModelListSerializer.serialize(
- self.range_annotations, RangeAnnotation
- ),
- "custom-area-fills": ModelListSerializer.serialize(
- self.area_fills, AreaFill
- ),
- }
- )
+ visualize_data = {
+ # Horizontal axis (from mixins)
+ **self._serialize_grid_config(),
+ **self._serialize_grid_format(),
+ **self._serialize_custom_range(),
+ **self._serialize_custom_ticks(),
+ # Vertical axis (chart-specific)
+ "y-grid-labels": self.y_grid_labels,
+ "y-grid-label-align": self.y_grid_label_align,
+ "scale-y": self.scale_y,
+ "y-grid-subdivide": self.y_grid_subdivide,
+ # Customize lines
+ "base-color": self.base_color,
+ "interpolation": self.interpolation,
+ "connector-lines": self.connector_lines,
+ "color-category": ColorCategory.serialize(self.color_category),
+ # Labels
+ "stack-color-legend": self.stack_color_legend,
+ "label-colors": self.label_colors,
+ "label-margin": self.label_margin,
+ "value-labels-format": self.value_labels_format,
+ "value-label-colors": self.value_label_colors,
+ # Tooltips
+ "show-tooltips": self.show_tooltips,
+ "tooltip-x-format": self.tooltip_x_format,
+ "tooltip-number-format": self.tooltip_number_format,
+ # Appearance
+ **PlotHeight.serialize(
+ self.plot_height_mode,
+ self.plot_height_fixed,
+ self.plot_height_ratio,
+ ),
+ # Initialize empty structures
+ "lines": {},
+ "text-annotations": ModelListSerializer.serialize(
+ self.text_annotations, TextAnnotation
+ ),
+ "range-annotations": ModelListSerializer.serialize(
+ self.range_annotations, RangeAnnotation
+ ),
+ "custom-area-fills": ModelListSerializer.serialize(
+ self.area_fills, AreaFill
+ ),
+ }
+
+ model["metadata"]["visualize"].update(visualize_data)
# Add line configurations
for line_obj in self.lines:
@@ -784,30 +731,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
metadata = api_response.get("metadata", {})
visualize = metadata.get("visualize", {})
- # Horizontal axis (X-axis)
- init_data["custom_range_x"] = CustomRange.deserialize(
- visualize.get("custom-range-x")
- )
- init_data["custom_ticks_x"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-x", "")
- )
- if "x-grid-format" in visualize:
- init_data["x_grid_format"] = visualize["x-grid-format"]
- if "x-grid" in visualize:
- init_data["x_grid"] = visualize["x-grid"]
-
- # Vertical axis (Y-axis)
- init_data["custom_range_y"] = CustomRange.deserialize(
- visualize.get("custom-range-y")
- )
- init_data["custom_ticks_y"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-y", "")
- )
+ # Horizontal and vertical axis (from mixins)
+ init_data.update(cls._deserialize_grid_config(visualize))
+ init_data.update(cls._deserialize_grid_format(visualize))
+ init_data.update(cls._deserialize_custom_range(visualize))
+ init_data.update(cls._deserialize_custom_ticks(visualize))
- if "y-grid-format" in visualize:
- init_data["y_grid_format"] = visualize["y-grid-format"]
- if "y-grid" in visualize:
- init_data["y_grid"] = visualize["y-grid"]
+ # Vertical axis (chart-specific)
if "y-grid-labels" in visualize:
init_data["y_grid_labels"] = visualize["y-grid-labels"]
if "y-grid-label-align" in visualize:
diff --git a/datawrapper/charts/mixins.py b/datawrapper/charts/mixins.py
new file mode 100644
index 00000000..94300fe7
--- /dev/null
+++ b/datawrapper/charts/mixins.py
@@ -0,0 +1,256 @@
+"""Mixin classes for shared chart visualization patterns."""
+
+from typing import Any
+
+from pydantic import Field
+
+from .enums import DateFormat, GridDisplay, NumberFormat
+from .serializers import CustomRange, CustomTicks
+
+
+class GridDisplayMixin:
+ """Mixin for charts that support grid display configuration.
+
+ Provides x_grid and y_grid fields for controlling grid line visibility,
+ along with serialization/deserialization methods.
+
+ Default values:
+ - x_grid: "off" (no vertical grid lines by default)
+ - y_grid: "on" (horizontal grid lines shown by default)
+
+ Supports backwards compatibility with boolean values:
+ - True → "on" during serialization
+ - False → "off" during serialization
+ - API "on" → True during deserialization
+ - API "off" → False during deserialization
+ """
+
+ x_grid: GridDisplay | str | bool | None = Field(
+ default="off",
+ description="X-axis grid display setting. Controls vertical grid lines.",
+ )
+ y_grid: GridDisplay | str | bool | None = Field(
+ default="on",
+ description="Y-axis grid display setting. Controls horizontal grid lines.",
+ )
+
+ def _serialize_grid_config(self) -> dict:
+ """Serialize grid configuration to API format.
+
+ Handles conversion of boolean values to "on"/"off" strings for backwards compatibility.
+
+ Returns:
+ dict: Grid configuration in API format with keys:
+ - x-grid: X-axis grid display setting
+ - y-grid: Y-axis grid display setting
+ """
+ result = {}
+ if self.x_grid is not None:
+ # Handle boolean values for backwards compatibility
+ if isinstance(self.x_grid, bool):
+ result["x-grid"] = "on" if self.x_grid else "off"
+ elif isinstance(self.x_grid, GridDisplay):
+ result["x-grid"] = self.x_grid.value
+ else:
+ result["x-grid"] = self.x_grid
+ if self.y_grid is not None:
+ # Handle boolean values for backwards compatibility
+ if isinstance(self.y_grid, bool):
+ result["y-grid"] = "on" if self.y_grid else "off"
+ elif isinstance(self.y_grid, GridDisplay):
+ result["y-grid"] = self.y_grid.value
+ else:
+ result["y-grid"] = self.y_grid
+ return result
+
+ @classmethod
+ def _deserialize_grid_config(cls, visualize: dict) -> dict:
+ """Deserialize grid configuration from API format.
+
+ Preserves original API values without conversion.
+
+ Args:
+ visualize: The visualize section from API response
+
+ Returns:
+ dict: Grid configuration in Python format with keys:
+ - x_grid: X-axis grid display setting (preserves API type)
+ - y_grid: Y-axis grid display setting (preserves API type)
+ """
+ result = {}
+ if "x-grid" in visualize:
+ result["x_grid"] = visualize["x-grid"]
+ if "y-grid" in visualize:
+ result["y_grid"] = visualize["y-grid"]
+ return result
+
+
+class GridFormatMixin:
+ """Mixin for charts that support grid label formatting.
+
+ Provides x_grid_format and y_grid_format fields for controlling how grid labels
+ are displayed, along with serialization/deserialization methods.
+
+ Used by: LineChart, AreaChart, ColumnChart, MultipleColumnChart, ScatterPlot
+ """
+
+ x_grid_format: DateFormat | NumberFormat | str | None = Field(
+ default=None,
+ description="Format string for X-axis grid labels. Supports date and number formats.",
+ )
+ y_grid_format: NumberFormat | str | None = Field(
+ default=None,
+ description="Format string for Y-axis grid labels. Supports number formats.",
+ )
+
+ def _serialize_grid_format(self) -> dict:
+ """Serialize grid format configuration to API format.
+
+ Returns:
+ dict: Grid format configuration in API format with keys:
+ - x-grid-format: X-axis grid label format
+ - y-grid-format: Y-axis grid label format
+ """
+ result = {}
+ if self.x_grid_format is not None:
+ result["x-grid-format"] = (
+ self.x_grid_format.value
+ if isinstance(self.x_grid_format, (DateFormat, NumberFormat))
+ else self.x_grid_format
+ )
+ if self.y_grid_format is not None:
+ result["y-grid-format"] = (
+ self.y_grid_format.value
+ if isinstance(self.y_grid_format, NumberFormat)
+ else self.y_grid_format
+ )
+ return result
+
+ @classmethod
+ def _deserialize_grid_format(cls, visualize: dict) -> dict:
+ """Deserialize grid format configuration from API format.
+
+ Args:
+ visualize: The visualize section from API response
+
+ Returns:
+ dict: Grid format configuration in Python format with keys:
+ - x_grid_format: X-axis grid label format
+ - y_grid_format: Y-axis grid label format
+ """
+ result = {}
+ if "x-grid-format" in visualize:
+ result["x_grid_format"] = visualize["x-grid-format"]
+ if "y-grid-format" in visualize:
+ result["y_grid_format"] = visualize["y-grid-format"]
+ return result
+
+
+class CustomRangeMixin:
+ """Mixin for charts that support custom axis ranges.
+
+ Provides custom_range_x and custom_range_y fields for setting explicit min/max
+ values for axes, along with serialization/deserialization methods.
+ """
+
+ custom_range_x: list[Any] | tuple[Any, Any] | None = Field(
+ default=None,
+ description="Custom range for X-axis as [min, max]. Overrides automatic range calculation.",
+ )
+ custom_range_y: list[Any] | tuple[Any, Any] | None = Field(
+ default=None,
+ description="Custom range for Y-axis as [min, max]. Overrides automatic range calculation.",
+ )
+
+ def _serialize_custom_range(self) -> dict:
+ """Serialize custom range configuration to API format.
+
+ Returns:
+ dict: Custom range configuration in API format with keys:
+ - custom-range-x: X-axis custom range [min, max]
+ - custom-range-y: Y-axis custom range [min, max]
+ """
+ result = {}
+ if self.custom_range_x is not None:
+ result["custom-range-x"] = CustomRange.serialize(self.custom_range_x)
+ if self.custom_range_y is not None:
+ result["custom-range-y"] = CustomRange.serialize(self.custom_range_y)
+ return result
+
+ @classmethod
+ def _deserialize_custom_range(cls, visualize: dict) -> dict:
+ """Deserialize custom range configuration from API format.
+
+ Args:
+ visualize: The visualize section from API response
+
+ Returns:
+ dict: Custom range configuration in Python format with keys:
+ - custom_range_x: X-axis custom range [min, max]
+ - custom_range_y: Y-axis custom range [min, max]
+ """
+ result = {}
+ if "custom-range-x" in visualize:
+ result["custom_range_x"] = CustomRange.deserialize(
+ visualize["custom-range-x"]
+ )
+ if "custom-range-y" in visualize:
+ result["custom_range_y"] = CustomRange.deserialize(
+ visualize["custom-range-y"]
+ )
+ return result
+
+
+class CustomTicksMixin:
+ """Mixin for charts that support custom tick marks.
+
+ Provides custom_ticks_x and custom_ticks_y fields for setting explicit tick mark
+ positions on axes, along with serialization/deserialization methods.
+ """
+
+ custom_ticks_x: list[Any] | None = Field(
+ default=None,
+ description="Custom tick mark positions for X-axis. List of values where ticks should appear.",
+ )
+ custom_ticks_y: list[Any] | None = Field(
+ default=None,
+ description="Custom tick mark positions for Y-axis. List of values where ticks should appear.",
+ )
+
+ def _serialize_custom_ticks(self) -> dict:
+ """Serialize custom ticks configuration to API format.
+
+ Returns:
+ dict: Custom ticks configuration in API format with keys:
+ - custom-ticks-x: X-axis custom tick positions
+ - custom-ticks-y: Y-axis custom tick positions
+ """
+ result = {}
+ if self.custom_ticks_x is not None:
+ result["custom-ticks-x"] = CustomTicks.serialize(self.custom_ticks_x)
+ if self.custom_ticks_y is not None:
+ result["custom-ticks-y"] = CustomTicks.serialize(self.custom_ticks_y)
+ return result
+
+ @classmethod
+ def _deserialize_custom_ticks(cls, visualize: dict) -> dict:
+ """Deserialize custom ticks configuration from API format.
+
+ Args:
+ visualize: The visualize section from API response
+
+ Returns:
+ dict: Custom ticks configuration in Python format with keys:
+ - custom_ticks_x: X-axis custom tick positions
+ - custom_ticks_y: Y-axis custom tick positions
+ """
+ result = {}
+ if "custom-ticks-x" in visualize:
+ result["custom_ticks_x"] = CustomTicks.deserialize(
+ visualize["custom-ticks-x"]
+ )
+ if "custom-ticks-y" in visualize:
+ result["custom_ticks_y"] = CustomTicks.deserialize(
+ visualize["custom-ticks-y"]
+ )
+ return result
diff --git a/datawrapper/charts/multiple_column.py b/datawrapper/charts/multiple_column.py
index aaf8d737..3a7fb765 100644
--- a/datawrapper/charts/multiple_column.py
+++ b/datawrapper/charts/multiple_column.py
@@ -15,10 +15,14 @@
ValueLabelDisplay,
ValueLabelPlacement,
)
+from .mixins import (
+ CustomRangeMixin,
+ CustomTicksMixin,
+ GridDisplayMixin,
+ GridFormatMixin,
+)
from .serializers import (
ColorCategory,
- CustomRange,
- CustomTicks,
ModelListSerializer,
NegativeColor,
PlotHeight,
@@ -26,7 +30,9 @@
)
-class MultipleColumnChart(BaseChart):
+class MultipleColumnChart(
+ GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart
+):
"""A base class for the Datawrapper API's multiple column chart."""
model_config = ConfigDict(
@@ -131,34 +137,6 @@ class MultipleColumnChart(BaseChart):
# Horizontal axis
#
- #: The custom range for the x axis
- custom_range_x: tuple[Any, Any] | list[Any] = Field(
- default=("", ""),
- alias="custom-range-x",
- description="The custom range for the x axis",
- )
-
- #: The custom ticks for the x axis
- custom_ticks_x: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-x",
- description="The custom ticks for the x axis",
- )
-
- #: The formatting for the x grid labels
- x_grid_format: str = Field(
- default="auto",
- alias="x-grid-format",
- description="The formatting for the x grid labels",
- )
-
- #: Whether to show the x grid
- x_grid: GridDisplay | str = Field(
- default="off",
- alias="x-grid",
- description="Whether to show the x grid",
- )
-
#: The labeling of the x axis
x_grid_labels: Literal["on", "off"] = Field(
default="on",
@@ -177,34 +155,6 @@ class MultipleColumnChart(BaseChart):
# Vertical axis
#
- #: The custom range for the y axis
- custom_range_y: tuple[Any, Any] | list[Any] = Field(
- default=("", ""),
- alias="custom-range-y",
- description="The custom range for the y axis",
- )
-
- #: The custom ticks for the y axis
- custom_ticks_y: list[Any] = Field(
- default_factory=list,
- alias="custom-ticks-y",
- description="The custom ticks for the y axis",
- )
-
- #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings)
- y_grid_format: DateFormat | NumberFormat | str = Field(
- default="",
- alias="y-grid-format",
- description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.",
- )
-
- #: Whether to show the y grid lines
- y_grid: bool = Field(
- default=True,
- alias="y-grid",
- description="Whether to show the y grid lines",
- )
-
#: The labeling of the y grid labels
y_grid_labels: GridLabelPosition | str = Field(
default="outside",
@@ -385,77 +335,74 @@ def serialize_model(self) -> dict:
model = super().serialize_model()
# Add chart specific properties to visualize section
- model["metadata"]["visualize"].update(
- {
- # Layout
- "gridLayout": self.grid_layout,
- "gridColumnCount": self.grid_column,
- "gridColumnCountMobile": self.grid_column_mobile,
- "gridColumnMinWidth": self.grid_column_width,
- "gridRowHeightFixed": self.grid_row_height,
- "sort": {
- "enabled": self.sort,
- "reverse": self.sort_reverse,
- "by": self.sort_by,
- },
- # Horizontal axis
- "custom-range-x": CustomRange.serialize(self.custom_range_x),
- "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x),
- "x-grid-format": self.x_grid_format,
- "x-grid-labels": self.x_grid_labels,
- "x-grid": self.x_grid_all,
- "grid-lines-x": {
- "type": "" if self.x_grid == "off" else self.x_grid,
- "enabled": self.x_grid != "off",
- },
- # Vertical axis
- "custom-range-y": CustomRange.serialize(self.custom_range_y),
- "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y),
- "y-grid-format": self.y_grid_format,
- "grid-lines": self.y_grid,
- "yAxisLabels": {
- "enabled": self.y_grid_labels != "off",
- "alignment": self.y_grid_label_align,
- "placement": ""
- if self.y_grid_labels == "off"
- else self.y_grid_labels,
- },
- # Appearance
- "base-color": self.base_color,
- "negativeColor": NegativeColor.serialize(self.negative_color),
- "bar-padding": self.bar_padding,
- "color-category": ColorCategory.serialize(self.color_category),
- "color-by-column": bool(self.color_category),
- **PlotHeight.serialize(
- self.plot_height_mode,
- self.plot_height_fixed,
- self.plot_height_ratio,
- ),
- "panels": {panel["column"]: panel for panel in self.panels},
- # Tooltips
- "show-tooltips": self.show_tooltips,
- "syncMultipleTooltips": self.sync_multiple_tooltips,
- "tooltip-number-format": self.tooltip_number_format,
- # Labels
- "show-color-key": self.show_color_key,
- "label-colors": self.label_colors,
- "label-margin": self.label_margin,
- **ValueLabels.serialize(
- self.show_value_labels,
- self.value_labels_format,
- placement=self.value_labels_placement,
- chart_type="multiple-column",
- ),
- "xGridLabelAllColumns": self.x_grid_label_all,
- # Annotations
- "text-annotations": ModelListSerializer.serialize(
- self.text_annotations, TextAnnotation
- ),
- "range-annotations": ModelListSerializer.serialize(
- self.range_annotations, RangeAnnotation
- ),
- }
- )
+ visualize_data = {
+ # Layout
+ "gridLayout": self.grid_layout,
+ "gridColumnCount": self.grid_column,
+ "gridColumnCountMobile": self.grid_column_mobile,
+ "gridColumnMinWidth": self.grid_column_width,
+ "gridRowHeightFixed": self.grid_row_height,
+ "sort": {
+ "enabled": self.sort,
+ "reverse": self.sort_reverse,
+ "by": self.sort_by,
+ },
+ # Horizontal and vertical axis (from mixins)
+ **self._serialize_grid_config(),
+ **self._serialize_grid_format(),
+ **self._serialize_custom_range(),
+ **self._serialize_custom_ticks(),
+ # Horizontal axis (chart-specific)
+ "x-grid-labels": self.x_grid_labels,
+ "x-grid": self.x_grid_all,
+ "grid-lines-x": {
+ "type": "" if self.x_grid == "off" else self.x_grid,
+ "enabled": self.x_grid != "off",
+ },
+ # Vertical axis (chart-specific)
+ "grid-lines": self.y_grid,
+ "yAxisLabels": {
+ "enabled": self.y_grid_labels != "off",
+ "alignment": self.y_grid_label_align,
+ "placement": "" if self.y_grid_labels == "off" else self.y_grid_labels,
+ },
+ # Appearance
+ "base-color": self.base_color,
+ "negativeColor": NegativeColor.serialize(self.negative_color),
+ "bar-padding": self.bar_padding,
+ "color-category": ColorCategory.serialize(self.color_category),
+ "color-by-column": bool(self.color_category),
+ **PlotHeight.serialize(
+ self.plot_height_mode,
+ self.plot_height_fixed,
+ self.plot_height_ratio,
+ ),
+ "panels": {panel["column"]: panel for panel in self.panels},
+ # Tooltips
+ "show-tooltips": self.show_tooltips,
+ "syncMultipleTooltips": self.sync_multiple_tooltips,
+ "tooltip-number-format": self.tooltip_number_format,
+ # Labels
+ "show-color-key": self.show_color_key,
+ "label-colors": self.label_colors,
+ "label-margin": self.label_margin,
+ **ValueLabels.serialize(
+ self.show_value_labels,
+ self.value_labels_format,
+ placement=self.value_labels_placement,
+ chart_type="multiple-column",
+ ),
+ "xGridLabelAllColumns": self.x_grid_label_all,
+ # Annotations
+ "text-annotations": ModelListSerializer.serialize(
+ self.text_annotations, TextAnnotation
+ ),
+ "range-annotations": ModelListSerializer.serialize(
+ self.range_annotations, RangeAnnotation
+ ),
+ }
+
+ model["metadata"]["visualize"].update(visualize_data)
# Return the serialized data
return model
@@ -501,15 +448,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
init_data["sort_reverse"] = False
init_data["sort_by"] = "end"
- # Horizontal axis
- init_data["custom_range_x"] = CustomRange.deserialize(
- visualize.get("custom-range-x")
- )
- init_data["custom_ticks_x"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-x", "")
- )
- if "x-grid-format" in visualize:
- init_data["x_grid_format"] = visualize["x-grid-format"]
+ # Horizontal and vertical axis (from mixins)
+ init_data.update(cls._deserialize_grid_config(visualize))
+ init_data.update(cls._deserialize_grid_format(visualize))
+ init_data.update(cls._deserialize_custom_range(visualize))
+ init_data.update(cls._deserialize_custom_ticks(visualize))
+
+ # Horizontal axis (chart-specific)
if "x-grid-labels" in visualize:
init_data["x_grid_labels"] = visualize["x-grid-labels"]
if "x-grid" in visualize:
@@ -525,16 +470,7 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
else:
init_data["x_grid"] = "off"
- # Vertical axis
- init_data["custom_range_y"] = CustomRange.deserialize(
- visualize.get("custom-range-y")
- )
- init_data["custom_ticks_y"] = CustomTicks.deserialize(
- visualize.get("custom-ticks-y", "")
- )
- if "y-grid-format" in visualize:
- init_data["y_grid_format"] = visualize["y-grid-format"]
-
+ # Vertical axis (chart-specific)
# Parse grid-lines (can be bool or string "show")
if "grid-lines" in visualize:
grid_lines_val = visualize["grid-lines"]
diff --git a/datawrapper/charts/serializers/base.py b/datawrapper/charts/serializers/base.py
new file mode 100644
index 00000000..e7fc0bb8
--- /dev/null
+++ b/datawrapper/charts/serializers/base.py
@@ -0,0 +1,69 @@
+"""Base serializer class for all serialization utilities."""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class BaseSerializer(ABC):
+ """Abstract base class for serialization utilities.
+
+ This class defines the standard interface that all serializer utilities
+ should implement. It provides a consistent pattern for converting between
+ Python objects and Datawrapper API JSON formats.
+
+ All serializer classes should inherit from this base class and implement
+ the serialize() and deserialize() methods.
+
+ Example:
+ >>> class CustomRange(BaseSerializer):
+ ... @staticmethod
+ ... def serialize(range_values: list[Any] | tuple[Any, Any]) -> list[Any]:
+ ... # Implementation here
+ ... pass
+ ...
+ ... @staticmethod
+ ... def deserialize(range_list: list[Any] | None) -> list[Any] | None:
+ ... # Implementation here
+ ... pass
+ """
+
+ @staticmethod
+ @abstractmethod
+ def serialize(*args: Any, **kwargs: Any) -> Any:
+ """Convert Python objects to Datawrapper API format.
+
+ This method should be implemented by subclasses to handle the
+ conversion from Python objects to the format expected by the
+ Datawrapper API.
+
+ Args:
+ *args: Positional arguments specific to the serializer
+ **kwargs: Keyword arguments specific to the serializer
+
+ Returns:
+ Any: The serialized data in API format
+
+ Raises:
+ NotImplementedError: If not implemented by subclass
+ """
+ raise NotImplementedError("Subclasses must implement serialize()")
+
+ @staticmethod
+ @abstractmethod
+ def deserialize(*args: Any, **kwargs: Any) -> Any:
+ """Convert Datawrapper API format to Python objects.
+
+ This method should be implemented by subclasses to handle the
+ conversion from the Datawrapper API format to Python objects.
+
+ Args:
+ *args: Positional arguments specific to the serializer
+ **kwargs: Keyword arguments specific to the serializer
+
+ Returns:
+ Any: The deserialized data as Python objects
+
+ Raises:
+ NotImplementedError: If not implemented by subclass
+ """
+ raise NotImplementedError("Subclasses must implement deserialize()")
diff --git a/datawrapper/charts/serializers/color_category.py b/datawrapper/charts/serializers/color_category.py
index b67deb72..4211a8f6 100644
--- a/datawrapper/charts/serializers/color_category.py
+++ b/datawrapper/charts/serializers/color_category.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class ColorCategory:
+
+class ColorCategory(BaseSerializer):
"""Utility class for serializing and deserializing color category structures."""
@staticmethod
diff --git a/datawrapper/charts/serializers/custom_range.py b/datawrapper/charts/serializers/custom_range.py
index 6edf41ba..355c98e3 100644
--- a/datawrapper/charts/serializers/custom_range.py
+++ b/datawrapper/charts/serializers/custom_range.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class CustomRange:
+
+class CustomRange(BaseSerializer):
"""Utility class for serializing and deserializing custom axis ranges."""
@staticmethod
diff --git a/datawrapper/charts/serializers/custom_ticks.py b/datawrapper/charts/serializers/custom_ticks.py
index 159ec3d6..aec871ca 100644
--- a/datawrapper/charts/serializers/custom_ticks.py
+++ b/datawrapper/charts/serializers/custom_ticks.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class CustomTicks:
+
+class CustomTicks(BaseSerializer):
"""Utility class for serializing and deserializing custom tick marks."""
@staticmethod
diff --git a/datawrapper/charts/serializers/negative_color.py b/datawrapper/charts/serializers/negative_color.py
index 84538e4d..88004f04 100644
--- a/datawrapper/charts/serializers/negative_color.py
+++ b/datawrapper/charts/serializers/negative_color.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class NegativeColor:
+
+class NegativeColor(BaseSerializer):
"""Utility class for serializing and deserializing negative color configuration.
The Datawrapper API uses a nested object format for the negativeColor field:
diff --git a/datawrapper/charts/serializers/plot_height.py b/datawrapper/charts/serializers/plot_height.py
index c3e89485..f86cf71c 100644
--- a/datawrapper/charts/serializers/plot_height.py
+++ b/datawrapper/charts/serializers/plot_height.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class PlotHeight:
+
+class PlotHeight(BaseSerializer):
"""Utility class for serializing and deserializing plot height configuration.
The Datawrapper API uses three separate fields for plot height:
diff --git a/datawrapper/charts/serializers/replace_flags.py b/datawrapper/charts/serializers/replace_flags.py
index d76019f6..adf6877c 100644
--- a/datawrapper/charts/serializers/replace_flags.py
+++ b/datawrapper/charts/serializers/replace_flags.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class ReplaceFlags:
+
+class ReplaceFlags(BaseSerializer):
"""Utility class for serializing and deserializing replace-flags configuration.
The Datawrapper API uses a nested object format for the replace-flags field:
diff --git a/datawrapper/charts/serializers/value_labels.py b/datawrapper/charts/serializers/value_labels.py
index ddeb8b1a..2ad86229 100644
--- a/datawrapper/charts/serializers/value_labels.py
+++ b/datawrapper/charts/serializers/value_labels.py
@@ -1,7 +1,9 @@
from typing import Any
+from .base import BaseSerializer
-class ValueLabels:
+
+class ValueLabels(BaseSerializer):
"""Utility class for serializing and deserializing value label configuration.
Different chart types use different API formats for value labels:
diff --git a/docs/index.md b/docs/index.md
index a6685f83..4fc9066b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -46,6 +46,7 @@ caption: API Reference
user-guide/api/main-client.rst
user-guide/api/chart-classes.rst
user-guide/api/models.rst
+user-guide/api/mixins.rst
user-guide/api/enums.rst
user-guide/api/exceptions.rst
```
diff --git a/docs/user-guide/advanced/exporting.md b/docs/user-guide/advanced/exporting.md
deleted file mode 100644
index 74a04419..00000000
--- a/docs/user-guide/advanced/exporting.md
+++ /dev/null
@@ -1,54 +0,0 @@
-# Exporting Charts
-
-Export Datawrapper charts in various formats including PNG, PDF, and SVG. In many cases, exporting can be done directly through the chart object methods and these methods should be considered deprecated. However, for advanced use cases or when working directly with the Datawrapper API, the following examples demonstrate how to export charts using the client. Where possible, prefer using the chart object's export methods.
-
-## Export as PNG
-
-Export a chart as a PNG image:
-
-```python
-client.export_chart(
- chart_id="abc123",
- output="png",
- filepath="chart.png",
- display=True # Opens the image after saving
-)
-```
-
-## Export as PDF
-
-Export a chart as a PDF:
-
-```python
-client.export_chart(
- chart_id="abc123",
- output="pdf",
- filepath="chart.pdf"
-)
-```
-
-## Export as SVG
-
-Export a chart as an SVG:
-
-```python
-client.export_chart(
- chart_id="abc123",
- output="svg",
- filepath="chart.svg"
-)
-```
-
-## Export with Custom Dimensions
-
-Specify custom dimensions for the export:
-
-```python
-client.export_chart(
- chart_id="abc123",
- output="png",
- filepath="chart.png",
- width=1200,
- height=800
-)
-```
diff --git a/docs/user-guide/api/mixins.rst b/docs/user-guide/api/mixins.rst
new file mode 100644
index 00000000..e1238fb8
--- /dev/null
+++ b/docs/user-guide/api/mixins.rst
@@ -0,0 +1,95 @@
+Mixins
+======
+
+The mixins module provides reusable functionality that can be shared across multiple chart types. These mixins handle common chart configuration patterns like grid display, formatting, and axis customization.
+
+.. currentmodule:: datawrapper.charts.mixins
+
+Grid Configuration
+------------------
+
+GridDisplayMixin
+~~~~~~~~~~~~~~~
+
+Controls the visibility of grid lines on chart axes.
+
+.. autoclass:: GridDisplayMixin
+ :members:
+ :show-inheritance:
+
+**Example:**
+
+.. code-block:: python
+
+ import datawrapper as dw
+
+ chart = dw.LineChart(
+ title="Temperature Trends",
+ x_grid=dw.GridDisplay.OFF,
+ y_grid=dw.GridDisplay.ON
+ )
+
+GridFormatMixin
+~~~~~~~~~~~~~~~
+
+Controls the formatting of grid labels on chart axes.
+
+.. autoclass:: GridFormatMixin
+ :members:
+ :show-inheritance:
+
+**Example:**
+
+.. code-block:: python
+
+ import datawrapper as dw
+
+ chart = dw.LineChart(
+ title="Sales Over Time",
+ x_grid_format=dw.DateFormat.MONTH_ABBREVIATED_WITH_YEAR,
+ y_grid_format=dw.NumberFormat.THOUSANDS_SEPARATOR
+ )
+
+Axis Customization
+------------------
+
+CustomRangeMixin
+~~~~~~~~~~~~~~~~
+
+Sets custom minimum and maximum values for chart axes.
+
+.. autoclass:: CustomRangeMixin
+ :members:
+ :show-inheritance:
+
+**Example:**
+
+.. code-block:: python
+
+ import datawrapper as dw
+
+ chart = dw.ColumnChart(
+ title="Revenue by Quarter",
+ custom_range_y=[0, 1000000] # Set Y-axis from 0 to 1M
+ )
+
+CustomTicksMixin
+~~~~~~~~~~~~~~~~
+
+Sets custom tick mark positions on chart axes.
+
+.. autoclass:: CustomTicksMixin
+ :members:
+ :show-inheritance:
+
+**Example:**
+
+.. code-block:: python
+
+ import datawrapper as dw
+
+ chart = dw.LineChart(
+ title="Monthly Data",
+ custom_ticks_x=["Jan", "Apr", "Jul", "Oct"],
+ custom_ticks_y=[0, 25, 50, 75, 100]
+ )
diff --git a/docs/user-guide/chart-operations.md b/docs/user-guide/chart-operations.md
index b00d928d..9ec770c7 100644
--- a/docs/user-guide/chart-operations.md
+++ b/docs/user-guide/chart-operations.md
@@ -138,3 +138,19 @@ png_url = chart.get_png_url()
html = f''
```
+
+## Exporting a chart in multiple formats
+
+You can export charts in various formats such as PNG, PDF, and SVG using the chart object's export methods:
+
+```python
+# Get the data in bytes
+png_data = chart.export_png(width=800, height=600)
+pdf_data = chart.export_pdf(mode="cmyk")
+svg_data = chart.export_svg(plain=True)
+
+# Save to disk
+Path("chart.png").write_bytes(png_data)
+Path("chart.pdf").write_bytes(pdf_data)
+Path("chart.svg").write_bytes(svg_data)
+```
diff --git a/tests/integration/test_base_export.py b/tests/integration/test_base_export.py
index 327b159c..a938a520 100644
--- a/tests/integration/test_base_export.py
+++ b/tests/integration/test_base_export.py
@@ -1,59 +1,489 @@
-"""Test the export method on BaseChart."""
+"""Integration tests for BaseChart export methods.
+
+These tests use mocked API calls to verify the export_png, export_pdf, and export_svg
+methods work correctly without requiring actual API access.
+"""
from unittest.mock import MagicMock, patch
import pytest
-from datawrapper.charts.base import BaseChart
+from datawrapper.charts import BarChart
+
+
+class TestExportPNG:
+ """Test PNG export functionality."""
+
+ def test_export_png_success(self):
+ """Test successful PNG export with default parameters."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PNG_DATA"
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_png()
+
+ # Verify
+ assert result == b"PNG_DATA"
+ mock_get_client.assert_called_once()
+ # Get the mock client that was created
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ assert (
+ call_args[0][0]
+ == "https://api.datawrapper.de/v3/charts/abc123/export/png"
+ )
+ assert call_args[1]["params"]["unit"] == "px"
+ assert call_args[1]["params"]["mode"] == "rgb"
+ assert call_args[1]["params"]["plain"] == "false"
+
+ def test_export_png_with_all_parameters(self):
+ """Test PNG export with all parameters specified."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PNG_DATA"
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with all parameters
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_png(
+ width=800,
+ height=600,
+ plain=True,
+ zoom=3,
+ transparent=True,
+ border_width=10,
+ border_color="#FF0000",
+ )
+
+ # Verify
+ assert result == b"PNG_DATA"
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ params = call_args[1]["params"]
+ assert params["width"] == "800"
+ assert params["height"] == "600"
+ assert params["plain"] == "true"
+ assert params["zoom"] == "3"
+ assert params["transparent"] == "true"
+ assert params["borderWidth"] == "10"
+ assert params["borderColor"] == "#FF0000"
+
+ def test_export_png_no_chart_id(self):
+ """Test that export_png raises ValueError when no chart_id is set."""
+ chart = BarChart(title="Test Chart")
+ with pytest.raises(ValueError, match="No chart_id set"):
+ chart.export_png()
+
+ def test_export_png_custom_access_token(self):
+ """Test PNG export with custom access token."""
+
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PNG_DATA"
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with custom token
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_png(access_token="custom_token")
+
+ # Verify
+ assert result == b"PNG_DATA"
+ mock_get_client.assert_called_once()
+
+
+class TestExportPDF:
+ """Test PDF export functionality."""
+
+ def test_export_pdf_success(self):
+ """Test successful PDF export with default parameters."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PDF_DATA"
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_pdf()
+
+ # Verify
+ assert result == b"PDF_DATA"
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ assert (
+ call_args[0][0]
+ == "https://api.datawrapper.de/v3/charts/abc123/export/pdf"
+ )
+ assert call_args[1]["params"]["unit"] == "px"
+ assert call_args[1]["params"]["mode"] == "rgb"
+ assert call_args[1]["params"]["plain"] == "false"
+
+ def test_export_pdf_with_all_parameters(self):
+ """Test PDF export with all parameters specified."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PDF_DATA"
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with all parameters
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_pdf(
+ width=800,
+ height=600,
+ plain=True,
+ unit="mm",
+ mode="cmyk",
+ scale=2,
+ border_width=10,
+ border_color="#FF0000",
+ )
+
+ # Verify
+ assert result == b"PDF_DATA"
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ params = call_args[1]["params"]
+ assert params["width"] == "800"
+ assert params["height"] == "600"
+ assert params["plain"] == "true"
+ assert params["unit"] == "mm"
+ assert params["mode"] == "cmyk"
+ assert params["scale"] == "2"
+ assert params["borderWidth"] == "10"
+ assert params["borderColor"] == "#FF0000"
+
+ def test_export_pdf_no_chart_id(self):
+ """Test that export_pdf raises ValueError when no chart_id is set."""
+ chart = BarChart(title="Test Chart")
+ with pytest.raises(ValueError, match="No chart_id set"):
+ chart.export_pdf()
+
+ def test_export_pdf_custom_access_token(self):
+ """Test PDF export with custom access token."""
+
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PDF_DATA"
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with custom token
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_pdf(access_token="custom_token")
+
+ # Verify
+ assert result == b"PDF_DATA"
+ mock_get_client.assert_called_once()
+
+
+class TestExportSVG:
+ """Tests for the export_svg method."""
+
+ def test_export_svg_success(self):
+ """Test successful SVG export with default parameters."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b""
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_svg()
+
+ # Verify
+ assert result == b""
+ assert isinstance(result, bytes)
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ assert (
+ call_args[0][0]
+ == "https://api.datawrapper.de/v3/charts/abc123/export/svg"
+ )
+
+ def test_export_svg_with_all_parameters(self):
+ """Test SVG export with all optional parameters."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b""
+ created_clients.append(mock_client)
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with all parameters
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_svg(width=800, height=600, plain=True)
+
+ # Verify
+ assert result == b""
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ url = call_args[0][0]
+ params = call_args[1]["params"]
+ assert url == "https://api.datawrapper.de/v3/charts/abc123/export/svg"
+ assert params["width"] == "800"
+ assert params["height"] == "600"
+ assert params["plain"] == "true"
+
+ def test_export_svg_no_chart_id(self):
+ """Test SVG export raises error when no chart_id is set."""
+ chart = BarChart(title="Test Chart")
+ with pytest.raises(ValueError, match="No chart_id set"):
+ chart.export_svg()
+
+ def test_export_svg_custom_access_token(self):
+ """Test SVG export with custom access token."""
+
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b""
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with custom token
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_svg(access_token="custom_token")
+
+ # Verify
+ assert result == b""
+ mock_get_client.assert_called_once()
+
+
+class TestExportMethodComparison:
+ """Tests comparing the new export methods with the legacy export method."""
+
+ def test_export_png_vs_legacy_export(self):
+ """Test that export_png produces same result as legacy export method."""
+
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PNG_IMAGE_DATA"
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+
+ # Export using new method
+ result_new = chart.export_png(width=800, height=600)
+
+ # Verify both produce bytes
+ assert isinstance(result_new, bytes)
+ assert result_new == b"PNG_IMAGE_DATA"
+ mock_get_client.assert_called_once()
+
+ def test_all_export_methods_return_bytes(self):
+ """Test that all export methods return bytes."""
+ # Create mock client factory that returns different data for each call
+ call_count = [0]
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ # Return different data based on call count
+ if call_count[0] == 0:
+ mock_client.get.return_value = b"PNG_DATA"
+ elif call_count[0] == 1:
+ mock_client.get.return_value = b"PDF_DATA"
+ else:
+ mock_client.get.return_value = b"SVG_DATA"
+ call_count[0] += 1
+ return mock_client
+
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+
+ # Test all export methods
+ png_result = chart.export_png()
+ pdf_result = chart.export_pdf()
+ svg_result = chart.export_svg()
+
+ # Verify all return bytes
+ assert isinstance(png_result, bytes)
+ assert isinstance(pdf_result, bytes)
+ assert isinstance(svg_result, bytes)
+ assert png_result == b"PNG_DATA"
+ assert pdf_result == b"PDF_DATA"
+ assert svg_result == b"SVG_DATA"
+ assert mock_get_client.call_count == 3
+
+
+class TestExportParameterValidation:
+ """Tests for parameter validation and formatting in export methods."""
+
+ def test_export_png_boolean_parameters(self):
+ """Test that boolean parameters are correctly formatted as lowercase strings."""
+ # Create mock client factory with closure to capture instance
+ created_clients = []
+
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PNG_DATA"
+ created_clients.append(mock_client)
+ return mock_client
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart and export with boolean parameters
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ result = chart.export_png(plain=True, transparent=False)
-def test_base_chart_export_method_exists():
- """Test that the export method exists on BaseChart."""
- chart = BaseChart(chart_type="d3-lines", title="Test Chart")
- assert hasattr(chart, "export")
- assert callable(chart.export)
+ # Verify
+ assert result == b"PNG_DATA"
+ mock_get_client.assert_called_once()
+ mock_client = created_clients[0]
+ call_args = mock_client.get.call_args
+ params = call_args[1]["params"]
+ # Verify boolean parameters are lowercase strings
+ assert params["plain"] == "true"
+ assert params["transparent"] == "false"
+ def test_export_pdf_unit_parameter(self):
+ """Test that unit parameter accepts valid values (px, mm, in)."""
-def test_base_chart_export_requires_chart_id():
- """Test that export raises ValueError when chart_id is not set."""
- chart = BaseChart(chart_type="d3-lines", title="Test Chart")
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PDF_DATA"
+ return mock_client
- with pytest.raises(ValueError, match="No chart_id set"):
- chart.export()
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
+ # Test with different unit values
+ for unit in ["px", "mm", "inch"]:
+ result = chart.export_pdf(unit=unit)
+ assert result == b"PDF_DATA"
-def test_base_chart_export_with_chart_id(tmp_path):
- """Test that export works when chart_id is set."""
- # Create a chart with a chart_id
- chart = BaseChart(
- chart_type="d3-lines",
- title="Test Export Chart",
- data=[{"x": 1, "y": 2}, {"x": 2, "y": 4}],
- )
- chart.chart_id = "test123"
+ # Verify all three calls were made
+ assert mock_get_client.call_count == 3
- # Mock the client and its export_chart method
- mock_client = MagicMock()
- output_file = tmp_path / "test_export.png"
- mock_client.export_chart.return_value = output_file
+ def test_export_pdf_mode_parameter(self):
+ """Test that mode parameter accepts valid values (rgb, cmyk)."""
- with patch.object(chart, "_get_client", return_value=mock_client):
- # Export to a temporary file
- result = chart.export(
- output="png",
- filepath=str(output_file),
- display=False,
- )
+ # Create mock client factory
+ def mock_client_factory(access_token=None):
+ mock_client = MagicMock()
+ mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts"
+ mock_client.get.return_value = b"PDF_DATA"
+ return mock_client
- # Verify the client method was called
- mock_client.export_chart.assert_called_once()
+ with patch(
+ "datawrapper.charts.base.BaseChart._get_client",
+ side_effect=mock_client_factory,
+ ) as mock_get_client:
+ # Create chart
+ chart = BarChart(title="Test Chart")
+ chart.chart_id = "abc123"
- # Verify the chart_id was passed
- call_kwargs = mock_client.export_chart.call_args.kwargs
- assert call_kwargs["chart_id"] == "test123"
- assert call_kwargs["output"] == "png"
- assert call_kwargs["filepath"] == str(output_file)
- assert call_kwargs["display"] is False
+ # Test with different mode values
+ for mode in ["rgb", "cmyk"]:
+ result = chart.export_pdf(mode=mode)
+ assert result == b"PDF_DATA"
- # Verify the result
- assert result == output_file
+ # Verify both calls were made
+ assert mock_get_client.call_count == 2