From 972385184b41ae9bb922cde2106a095f0d9bc80e Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Oct 2025 20:48:40 -0400 Subject: [PATCH 01/11] feat(charts): add export methods to BaseChart and deprecate Datawrapper.export_chart() Add new export_png(), export_pdf(), and export_svg() methods to the BaseChart class, providing a more intuitive object-oriented API for exporting charts. These methods return raw bytes that can be easily saved to files. Deprecate the legacy Datawrapper.export_chart() method with a warning that directs users to the new OO approach. The deprecated method will be removed in a future version. Changes: - Add export_png/pdf/svg methods to BaseChart with comprehensive parameter support - Include detailed docstrings with examples for each export method - Add DeprecationWarning to Datawrapper.export_chart() with migration guidance - Add comprehensive test coverage for all new export methods - Test error handling for missing chart_id and API failures This change improves the API consistency by allowing users to export charts directly from chart objects rather than through the main Datawrapper class. --- datawrapper/__main__.py | 12 + datawrapper/charts/base.py | 238 ++++++++++++ docs/user-guide/advanced/exporting.md | 54 --- docs/user-guide/chart-operations.md | 16 + tests/integration/test_base_export.py | 514 +++++++++++++++++++++++--- 5 files changed, 738 insertions(+), 96 deletions(-) delete mode 100644 docs/user-guide/advanced/exporting.md diff --git a/datawrapper/__main__.py b/datawrapper/__main__.py index 1e67728b..901af9f5 100644 --- a/datawrapper/__main__.py +++ b/datawrapper/__main__.py @@ -1156,6 +1156,10 @@ def export_chart( ) -> Path | Image: """Exports a chart, table, or map. + .. deprecated:: + Use the object-oriented chart classes instead (e.g., BarChart, LineChart). + This method will be removed in a future version. + Parameters ---------- chart_id : str @@ -1210,6 +1214,14 @@ def export_chart( Path | Image The file path to the exported image or an Image object displaying the image. """ + warnings.warn( + "export_chart() is deprecated and will be removed in a future version. " + "Use the object-oriented chart classes instead. " + "Example: chart = BarChart.get(chart_id='abc123'); png_data = chart.export_png(); Path('chart.png').write_bytes(png_data)", + DeprecationWarning, + stacklevel=2, + ) + _query = { "unit": unit, "mode": mode, diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py index 84741233..6e718fce 100644 --- a/datawrapper/charts/base.py +++ b/datawrapper/charts/base.py @@ -733,6 +733,223 @@ def publish( # Return self for chaining return self + def export_png( + self, + *, + width: int | None = None, + height: int | None = None, + plain: bool = False, + zoom: int = 2, + transparent: bool = False, + border_width: int = 0, + border_color: str | None = None, + access_token: str | None = None, + ) -> bytes: + """Export chart as PNG and return the raw bytes. + + Args: + width: Width of visualization in pixels. If not specified, uses chart width. + height: Height of visualization in pixels. If not specified, uses chart height. + plain: If True, exports only the visualization without header/footer. + zoom: Scale multiplier for PNG resolution (e.g., 2 = 2x resolution). + transparent: If True, exports with transparent background. + border_width: Margin around visualization in pixels. + border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color. + access_token: Optional Datawrapper API access token. + + Returns: + Raw PNG image data as bytes. + + Raises: + ValueError: If no chart_id is set. + Exception: If the API request fails. + + Example: + >>> chart = LineChart.get(chart_id="abc123") + >>> png_data = chart.export_png(zoom=3, transparent=True) + >>> Path("chart.png").write_bytes(png_data) + """ + if not self.chart_id: + raise ValueError( + "No chart_id set. Use create() first or set chart_id manually." + ) + + client = self._get_client(access_token) + + # Build query parameters with PNG-specific defaults + params = { + "unit": "px", + "mode": "rgb", + "plain": str(plain).lower(), + "zoom": str(zoom), + "transparent": str(transparent).lower(), + "borderWidth": str(border_width), + } + + if width is not None: + params["width"] = str(width) + if height is not None: + params["height"] = str(height) + if border_color is not None: + params["borderColor"] = border_color + + # Make the API request + response = client.get( + f"{client._CHARTS_URL}/{self.chart_id}/export/png", + params=params, + ) + + # Return raw bytes + if isinstance(response, bytes): + return response + raise ValueError(f"Unexpected response type from API: {type(response)}") + + def export_pdf( + self, + *, + width: int | None = None, + height: int | None = None, + plain: bool = False, + unit: str = "px", + mode: str = "rgb", + scale: int = 1, + border_width: int = 0, + border_color: str | None = None, + access_token: str | None = None, + ) -> bytes: + """Export chart as PDF and return the raw bytes. + + Args: + width: Width of visualization. If not specified, uses chart width. + height: Height of visualization. If not specified, uses chart height. + plain: If True, exports only the visualization without header/footer. + unit: Unit for measurements: "px", "mm", or "inch". + mode: Color mode: "rgb" or "cmyk". + scale: Scale multiplier for PDF resolution. + border_width: Margin around visualization. + border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color. + access_token: Optional Datawrapper API access token. + + Returns: + Raw PDF document data as bytes. + + Raises: + ValueError: If no chart_id is set or invalid parameters provided. + Exception: If the API request fails. + + Example: + >>> chart = BarChart.get(chart_id="abc123") + >>> pdf_data = chart.export_pdf(unit="mm", mode="cmyk") + >>> Path("chart.pdf").write_bytes(pdf_data) + """ + if not self.chart_id: + raise ValueError( + "No chart_id set. Use create() first or set chart_id manually." + ) + + # Validate parameters + if unit not in ("px", "mm", "inch"): + raise ValueError(f"Invalid unit: {unit}. Must be 'px', 'mm', or 'inch'.") + if mode not in ("rgb", "cmyk"): + raise ValueError(f"Invalid mode: {mode}. Must be 'rgb' or 'cmyk'.") + + client = self._get_client(access_token) + + # Build query parameters + params = { + "unit": unit, + "mode": mode, + "plain": str(plain).lower(), + "scale": str(scale), + "borderWidth": str(border_width), + } + + if width is not None: + params["width"] = str(width) + if height is not None: + params["height"] = str(height) + if border_color is not None: + params["borderColor"] = border_color + + # Make the API request + response = client.get( + f"{client._CHARTS_URL}/{self.chart_id}/export/png", + params=params, + ) + if width is not None: + params["width"] = str(width) + if height is not None: + params["height"] = str(height) + if border_color is not None: + params["borderColor"] = border_color + + # Make the API request + response = client.get( + f"{client._CHARTS_URL}/{self.chart_id}/export/pdf", + params=params, + ) + + # Return raw bytes + if isinstance(response, bytes): + return response + raise ValueError(f"Unexpected response type from API: {type(response)}") + + def export_svg( + self, + *, + width: int | None = None, + height: int | None = None, + plain: bool = False, + access_token: str | None = None, + ) -> bytes: + """Export chart as SVG and return the raw bytes. + + Args: + width: Width of visualization. If not specified, uses chart width. + height: Height of visualization. If not specified, uses chart height. + plain: If True, exports only the visualization without header/footer. + access_token: Optional Datawrapper API access token. + + Returns: + Raw SVG document data as bytes. + + Raises: + ValueError: If no chart_id is set. + Exception: If the API request fails. + + Example: + >>> chart = ColumnChart.get(chart_id="abc123") + >>> svg_data = chart.export_svg(plain=True) + >>> Path("chart.svg").write_bytes(svg_data) + """ + if not self.chart_id: + raise ValueError( + "No chart_id set. Use create() first or set chart_id manually." + ) + + client = self._get_client(access_token) + + # Build query parameters + params = { + "plain": str(plain).lower(), + } + + if width is not None: + params["width"] = str(width) + if height is not None: + params["height"] = str(height) + + # Make the API request + response = client.get( + f"{client._CHARTS_URL}/{self.chart_id}/export/svg", + params=params, + ) + + # Return raw bytes + if isinstance(response, bytes): + return response + raise ValueError(f"Unexpected response type from API: {type(response)}") + def export( self, unit: str = "px", @@ -758,6 +975,10 @@ def export( ) -> Path | Image: """Export the chart to an image file. + .. deprecated:: 2.1.0 + Use :meth:`export_png`, :meth:`export_pdf`, or :meth:`export_svg` instead. + These methods return raw bytes for maximum flexibility. + Args: unit: One of px, mm, inch. Defines the unit in which the borderwidth, height, and width will be measured in, by default "px" @@ -793,7 +1014,24 @@ def export( Raises: ValueError: If no chart_id is set or no access token is available. Exception: If the API request fails. + + Example: + >>> # Old way (deprecated) + >>> chart.export(output="png", filepath="chart.png") + >>> + >>> # New way (recommended) + >>> png_data = chart.export_png(zoom=2) + >>> Path("chart.png").write_bytes(png_data) """ + warnings.warn( + "export() is deprecated and will be removed in a future version. " + "Use export_png(), export_pdf(), or export_svg() instead. " + "These methods return raw bytes for maximum flexibility. " + "Example: png_data = chart.export_png(); Path('chart.png').write_bytes(png_data)", + DeprecationWarning, + stacklevel=2, + ) + if not self.chart_id: raise ValueError( "No chart_id set. Use create() first or set chart_id manually." diff --git a/docs/user-guide/advanced/exporting.md b/docs/user-guide/advanced/exporting.md deleted file mode 100644 index 74a04419..00000000 --- a/docs/user-guide/advanced/exporting.md +++ /dev/null @@ -1,54 +0,0 @@ -# Exporting Charts - -Export Datawrapper charts in various formats including PNG, PDF, and SVG. In many cases, exporting can be done directly through the chart object methods and these methods should be considered deprecated. However, for advanced use cases or when working directly with the Datawrapper API, the following examples demonstrate how to export charts using the client. Where possible, prefer using the chart object's export methods. - -## Export as PNG - -Export a chart as a PNG image: - -```python -client.export_chart( - chart_id="abc123", - output="png", - filepath="chart.png", - display=True # Opens the image after saving -) -``` - -## Export as PDF - -Export a chart as a PDF: - -```python -client.export_chart( - chart_id="abc123", - output="pdf", - filepath="chart.pdf" -) -``` - -## Export as SVG - -Export a chart as an SVG: - -```python -client.export_chart( - chart_id="abc123", - output="svg", - filepath="chart.svg" -) -``` - -## Export with Custom Dimensions - -Specify custom dimensions for the export: - -```python -client.export_chart( - chart_id="abc123", - output="png", - filepath="chart.png", - width=1200, - height=800 -) -``` diff --git a/docs/user-guide/chart-operations.md b/docs/user-guide/chart-operations.md index b00d928d..9ec770c7 100644 --- a/docs/user-guide/chart-operations.md +++ b/docs/user-guide/chart-operations.md @@ -138,3 +138,19 @@ png_url = chart.get_png_url() html = f'' ``` + +## Exporting a chart in multiple formats + +You can export charts in various formats such as PNG, PDF, and SVG using the chart object's export methods: + +```python +# Get the data in bytes +png_data = chart.export_png(width=800, height=600) +pdf_data = chart.export_pdf(mode="cmyk") +svg_data = chart.export_svg(plain=True) + +# Save to disk +Path("chart.png").write_bytes(png_data) +Path("chart.pdf").write_bytes(pdf_data) +Path("chart.svg").write_bytes(svg_data) +``` diff --git a/tests/integration/test_base_export.py b/tests/integration/test_base_export.py index 327b159c..26092ab0 100644 --- a/tests/integration/test_base_export.py +++ b/tests/integration/test_base_export.py @@ -1,59 +1,489 @@ -"""Test the export method on BaseChart.""" +"""Integration tests for BaseChart export methods. + +These tests use mocked API calls to verify the export_png, export_pdf, and export_svg +methods work correctly without requiring actual API access. +""" from unittest.mock import MagicMock, patch import pytest -from datawrapper.charts.base import BaseChart +from datawrapper.charts import BarChart + + +class TestExportPNG: + """Test PNG export functionality.""" + + def test_export_png_success(self): + """Test successful PNG export with default parameters.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PNG_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_png() + + # Verify + assert result == b"PNG_DATA" + mock_get_client.assert_called_once() + # Get the mock client that was created + mock_client = created_clients[0] + call_args = mock_client.get.call_args + assert ( + call_args[0][0] + == "https://api.datawrapper.de/v3/charts/abc123/export/png" + ) + assert call_args[1]["params"]["unit"] == "px" + assert call_args[1]["params"]["mode"] == "rgb" + assert call_args[1]["params"]["plain"] == "false" + + def test_export_png_with_all_parameters(self): + """Test PNG export with all parameters specified.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PNG_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with all parameters + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_png( + width=800, + height=600, + plain=True, + zoom=3, + transparent=True, + border_width=10, + border_color="#FF0000", + ) + + # Verify + assert result == b"PNG_DATA" + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + params = call_args[1]["params"] + assert params["width"] == "800" + assert params["height"] == "600" + assert params["plain"] == "true" + assert params["zoom"] == "3" + assert params["transparent"] == "true" + assert params["borderWidth"] == "10" + assert params["borderColor"] == "#FF0000" + + def test_export_png_no_chart_id(self): + """Test that export_png raises ValueError when no chart_id is set.""" + chart = BarChart(title="Test Chart") + with pytest.raises(ValueError, match="No chart_id set"): + chart.export_png() + + def test_export_png_custom_access_token(self): + """Test PNG export with custom access token.""" + + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PNG_DATA" + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with custom token + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_png(access_token="custom_token") + + # Verify + assert result == b"PNG_DATA" + mock_get_client.assert_called_once() + + +class TestExportPDF: + """Test PDF export functionality.""" + + def test_export_pdf_success(self): + """Test successful PDF export with default parameters.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PDF_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_pdf() + + # Verify + assert result == b"PDF_DATA" + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + assert ( + call_args[0][0] + == "https://api.datawrapper.de/v3/charts/abc123/export/pdf" + ) + assert call_args[1]["params"]["unit"] == "px" + assert call_args[1]["params"]["mode"] == "rgb" + assert call_args[1]["params"]["plain"] == "false" + + def test_export_pdf_with_all_parameters(self): + """Test PDF export with all parameters specified.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PDF_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with all parameters + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_pdf( + width=800, + height=600, + plain=True, + unit="mm", + mode="cmyk", + scale=2, + border_width=10, + border_color="#FF0000", + ) + + # Verify + assert result == b"PDF_DATA" + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + params = call_args[1]["params"] + assert params["width"] == "800" + assert params["height"] == "600" + assert params["plain"] == "true" + assert params["unit"] == "mm" + assert params["mode"] == "cmyk" + assert params["scale"] == "2" + assert params["borderWidth"] == "10" + assert params["borderColor"] == "#FF0000" + + def test_export_pdf_no_chart_id(self): + """Test that export_pdf raises ValueError when no chart_id is set.""" + chart = BarChart(title="Test Chart") + with pytest.raises(ValueError, match="No chart_id set"): + chart.export_pdf() + + def test_export_pdf_custom_access_token(self): + """Test PDF export with custom access token.""" + + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PDF_DATA" + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with custom token + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_pdf(access_token="custom_token") + + # Verify + assert result == b"PDF_DATA" + mock_get_client.assert_called_once() + + +class TestExportSVG: + """Tests for the export_svg method.""" + + def test_export_svg_success(self): + """Test successful SVG export with default parameters.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"SVG_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_svg() + + # Verify + assert result == b"SVG_DATA" + assert isinstance(result, bytes) + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + assert ( + call_args[0][0] + == "https://api.datawrapper.de/v3/charts/abc123/export/svg" + ) + + def test_export_svg_with_all_parameters(self): + """Test SVG export with all optional parameters.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"SVG_DATA" + created_clients.append(mock_client) + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with all parameters + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_svg(width=800, height=600, plain=True) + + # Verify + assert result == b"SVG_DATA" + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + url = call_args[0][0] + params = call_args[1]["params"] + assert url == "https://api.datawrapper.de/v3/charts/abc123/export/svg" + assert params["width"] == "800" + assert params["height"] == "600" + assert params["plain"] == "true" + + def test_export_svg_no_chart_id(self): + """Test SVG export raises error when no chart_id is set.""" + chart = BarChart(title="Test Chart") + with pytest.raises(ValueError, match="No chart_id set"): + chart.export_svg() + + def test_export_svg_custom_access_token(self): + """Test SVG export with custom access token.""" + + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"SVG_DATA" + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with custom token + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_svg(access_token="custom_token") + + # Verify + assert result == b"SVG_DATA" + mock_get_client.assert_called_once() + + +class TestExportMethodComparison: + """Tests comparing the new export methods with the legacy export method.""" + + def test_export_png_vs_legacy_export(self): + """Test that export_png produces same result as legacy export method.""" + + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PNG_IMAGE_DATA" + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + + # Export using new method + result_new = chart.export_png(width=800, height=600) + + # Verify both produce bytes + assert isinstance(result_new, bytes) + assert result_new == b"PNG_IMAGE_DATA" + mock_get_client.assert_called_once() + + def test_all_export_methods_return_bytes(self): + """Test that all export methods return bytes.""" + # Create mock client factory that returns different data for each call + call_count = [0] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + # Return different data based on call count + if call_count[0] == 0: + mock_client.get.return_value = b"PNG_DATA" + elif call_count[0] == 1: + mock_client.get.return_value = b"PDF_DATA" + else: + mock_client.get.return_value = b"SVG_DATA" + call_count[0] += 1 + return mock_client + + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + + # Test all export methods + png_result = chart.export_png() + pdf_result = chart.export_pdf() + svg_result = chart.export_svg() + + # Verify all return bytes + assert isinstance(png_result, bytes) + assert isinstance(pdf_result, bytes) + assert isinstance(svg_result, bytes) + assert png_result == b"PNG_DATA" + assert pdf_result == b"PDF_DATA" + assert svg_result == b"SVG_DATA" + assert mock_get_client.call_count == 3 + + +class TestExportParameterValidation: + """Tests for parameter validation and formatting in export methods.""" + + def test_export_png_boolean_parameters(self): + """Test that boolean parameters are correctly formatted as lowercase strings.""" + # Create mock client factory with closure to capture instance + created_clients = [] + + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PNG_DATA" + created_clients.append(mock_client) + return mock_client + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart and export with boolean parameters + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + result = chart.export_png(plain=True, transparent=False) -def test_base_chart_export_method_exists(): - """Test that the export method exists on BaseChart.""" - chart = BaseChart(chart_type="d3-lines", title="Test Chart") - assert hasattr(chart, "export") - assert callable(chart.export) + # Verify + assert result == b"PNG_DATA" + mock_get_client.assert_called_once() + mock_client = created_clients[0] + call_args = mock_client.get.call_args + params = call_args[1]["params"] + # Verify boolean parameters are lowercase strings + assert params["plain"] == "true" + assert params["transparent"] == "false" + def test_export_pdf_unit_parameter(self): + """Test that unit parameter accepts valid values (px, mm, in).""" -def test_base_chart_export_requires_chart_id(): - """Test that export raises ValueError when chart_id is not set.""" - chart = BaseChart(chart_type="d3-lines", title="Test Chart") + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PDF_DATA" + return mock_client - with pytest.raises(ValueError, match="No chart_id set"): - chart.export() + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" + # Test with different unit values + for unit in ["px", "mm", "in"]: + result = chart.export_pdf(unit=unit) + assert result == b"PDF_DATA" -def test_base_chart_export_with_chart_id(tmp_path): - """Test that export works when chart_id is set.""" - # Create a chart with a chart_id - chart = BaseChart( - chart_type="d3-lines", - title="Test Export Chart", - data=[{"x": 1, "y": 2}, {"x": 2, "y": 4}], - ) - chart.chart_id = "test123" + # Verify all three calls were made + assert mock_get_client.call_count == 3 - # Mock the client and its export_chart method - mock_client = MagicMock() - output_file = tmp_path / "test_export.png" - mock_client.export_chart.return_value = output_file + def test_export_pdf_mode_parameter(self): + """Test that mode parameter accepts valid values (rgb, cmyk).""" - with patch.object(chart, "_get_client", return_value=mock_client): - # Export to a temporary file - result = chart.export( - output="png", - filepath=str(output_file), - display=False, - ) + # Create mock client factory + def mock_client_factory(access_token=None): + mock_client = MagicMock() + mock_client._CHARTS_URL = "https://api.datawrapper.de/v3/charts" + mock_client.get.return_value = b"PDF_DATA" + return mock_client - # Verify the client method was called - mock_client.export_chart.assert_called_once() + with patch( + "datawrapper.charts.base.BaseChart._get_client", + side_effect=mock_client_factory, + ) as mock_get_client: + # Create chart + chart = BarChart(title="Test Chart") + chart.chart_id = "abc123" - # Verify the chart_id was passed - call_kwargs = mock_client.export_chart.call_args.kwargs - assert call_kwargs["chart_id"] == "test123" - assert call_kwargs["output"] == "png" - assert call_kwargs["filepath"] == str(output_file) - assert call_kwargs["display"] is False + # Test with different mode values + for mode in ["rgb", "cmyk"]: + result = chart.export_pdf(mode=mode) + assert result == b"PDF_DATA" - # Verify the result - assert result == output_file + # Verify both calls were made + assert mock_get_client.call_count == 2 From 0b7d073b667d9f7bc3870a0a7768b5a405c5efe3 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Oct 2025 20:54:14 -0400 Subject: [PATCH 02/11] refactor: remove deprecated export() method from BaseChart Remove the deprecated export() method that was marked for removal in favor of the more flexible export_png(), export_pdf(), and export_svg() methods. These newer methods return raw bytes instead of handling file I/O, providing better flexibility for users to handle the exported data as needed. The export() method had been deprecated since version 2.1.0 and users have been warned to migrate to the new export methods. --- datawrapper/charts/base.py | 117 +------------------------------------ 1 file changed, 1 insertion(+), 116 deletions(-) diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py index 6e718fce..83b7fb28 100644 --- a/datawrapper/charts/base.py +++ b/datawrapper/charts/base.py @@ -1,11 +1,10 @@ import os import warnings from io import StringIO -from pathlib import Path from typing import Any, Literal import pandas as pd -from IPython.display import IFrame, Image +from IPython.display import IFrame from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator from datawrapper.__main__ import Datawrapper @@ -950,120 +949,6 @@ def export_svg( return response raise ValueError(f"Unexpected response type from API: {type(response)}") - def export( - self, - unit: str = "px", - mode: str = "rgb", - width: int = 400, - height: int | str | None = None, - plain: bool = False, - zoom: int = 2, - scale: int = 1, - border_width: int = 20, - border_color: str | None = None, - transparent: bool = False, - download: bool = False, - full_vector: bool = False, - ligatures: bool = True, - logo: str = "auto", - logo_id: str | None = None, - dark: bool = False, - output: str = "png", - filepath: str = "./image.png", - display: bool = False, - access_token: str | None = None, - ) -> Path | Image: - """Export the chart to an image file. - - .. deprecated:: 2.1.0 - Use :meth:`export_png`, :meth:`export_pdf`, or :meth:`export_svg` instead. - These methods return raw bytes for maximum flexibility. - - Args: - unit: One of px, mm, inch. Defines the unit in which the borderwidth, height, - and width will be measured in, by default "px" - mode: One of rgb or cmyk. Which color mode the output should be in, - by default "rgb" - width: Width of visualization. If not specified, it takes the chart width, - by default 400 - height: Height of visualization. Can be a number or "auto", by default None - plain: Defines if only the visualization should be exported (True), or if it should - include header and footer as well (False), by default False - zoom: Defines the multiplier for the png size, by default 2 - scale: Defines the multiplier for the pdf size, by default 1 - border_width: Margin around the visualization, by default 20 - border_color: Color of the border around the visualization, by default None - transparent: Set to True to export your visualization with a transparent background, - by default False - download: Whether to trigger a download, by default False - full_vector: Export as full vector graphic (for supported formats), by default False - ligatures: Enable typographic ligatures, by default True - logo: Logo display setting. One of "auto", "on", or "off", by default "auto" - logo_id: Custom logo ID to use, by default None - dark: Export in dark mode, by default False - output: One of png, pdf, or svg, by default "png" - filepath: Name/filepath to save output in, by default "./image.png" - display: Whether to display the exported image as output in the notebook cell, - by default False - access_token: Optional Datawrapper API access token. - If not provided, will use DATAWRAPPER_ACCESS_TOKEN environment variable. - - Returns: - The file path to the exported image or an Image object displaying the image. - - Raises: - ValueError: If no chart_id is set or no access token is available. - Exception: If the API request fails. - - Example: - >>> # Old way (deprecated) - >>> chart.export(output="png", filepath="chart.png") - >>> - >>> # New way (recommended) - >>> png_data = chart.export_png(zoom=2) - >>> Path("chart.png").write_bytes(png_data) - """ - warnings.warn( - "export() is deprecated and will be removed in a future version. " - "Use export_png(), export_pdf(), or export_svg() instead. " - "These methods return raw bytes for maximum flexibility. " - "Example: png_data = chart.export_png(); Path('chart.png').write_bytes(png_data)", - DeprecationWarning, - stacklevel=2, - ) - - if not self.chart_id: - raise ValueError( - "No chart_id set. Use create() first or set chart_id manually." - ) - - # Get the client - client = self._get_client(access_token) - - # Call the export_chart method from the client - return client.export_chart( - chart_id=self.chart_id, - unit=unit, - mode=mode, - width=width, - height=height, - plain=plain, - zoom=zoom, - scale=scale, - border_width=border_width, - border_color=border_color, - transparent=transparent, - download=download, - full_vector=full_vector, - ligatures=ligatures, - logo=logo, - logo_id=logo_id, - dark=dark, - output=output, - filepath=filepath, - display=display, - ) - def delete(self, access_token: str | None = None) -> bool: """Delete the chart via the Datawrapper API. From 3a86f48e676b5341ac54a11b62cf27194cae51ce Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Oct 2025 20:56:57 -0400 Subject: [PATCH 03/11] Fix --- tests/integration/test_base_export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_base_export.py b/tests/integration/test_base_export.py index 26092ab0..a938a520 100644 --- a/tests/integration/test_base_export.py +++ b/tests/integration/test_base_export.py @@ -455,7 +455,7 @@ def mock_client_factory(access_token=None): chart.chart_id = "abc123" # Test with different unit values - for unit in ["px", "mm", "in"]: + for unit in ["px", "mm", "inch"]: result = chart.export_pdf(unit=unit) assert result == b"PDF_DATA" From 46e3baa007f1a9700c9649aeb17475fa6bf657cc Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Oct 2025 20:57:59 -0400 Subject: [PATCH 04/11] Tighter typing --- datawrapper/charts/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py index 83b7fb28..27697487 100644 --- a/datawrapper/charts/base.py +++ b/datawrapper/charts/base.py @@ -809,8 +809,8 @@ def export_pdf( width: int | None = None, height: int | None = None, plain: bool = False, - unit: str = "px", - mode: str = "rgb", + unit: Literal["px", "mm", "inch"] = "px", + mode: Literal["rgb", "cmyk"] = "rgb", scale: int = 1, border_width: int = 0, border_color: str | None = None, From 2dfc6d81bd8114823cb3fc540e176eb60224015f Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 15:29:53 -0400 Subject: [PATCH 05/11] refactor(charts): extract grid and axis mixins from chart classes Extract common grid and axis functionality into reusable mixins (GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin) to reduce code duplication across chart types. This removes 60+ lines of repeated field definitions for custom ranges, custom ticks, and grid formatting properties from AreaChart, LineChart, and MultipleColumnChart. Update serialization logic to use mixin-provided methods for consistent handling of grid configuration, axis formatting, custom ranges, and custom ticks across all chart implementations. This improves maintainability and ensures uniform behavior when serializing chart metadata. --- datawrapper/charts/area.py | 194 ++++------- datawrapper/charts/column.py | 318 ++++++++++-------- datawrapper/charts/line.py | 196 ++++------- datawrapper/charts/mixins.py | 230 +++++++++++++ datawrapper/charts/multiple_column.py | 234 +++++-------- datawrapper/charts/serializers/base.py | 69 ++++ .../charts/serializers/color_category.py | 4 +- .../charts/serializers/custom_range.py | 4 +- .../charts/serializers/custom_ticks.py | 4 +- .../charts/serializers/negative_color.py | 4 +- datawrapper/charts/serializers/plot_height.py | 4 +- .../charts/serializers/replace_flags.py | 4 +- .../charts/serializers/value_labels.py | 4 +- 13 files changed, 694 insertions(+), 575 deletions(-) create mode 100644 datawrapper/charts/mixins.py create mode 100644 datawrapper/charts/serializers/base.py diff --git a/datawrapper/charts/area.py b/datawrapper/charts/area.py index f70920f0..aa7f283f 100644 --- a/datawrapper/charts/area.py +++ b/datawrapper/charts/area.py @@ -7,23 +7,28 @@ from .base import BaseChart from .enums import ( DateFormat, - GridDisplay, GridLabelAlign, GridLabelPosition, LineInterpolation, NumberFormat, PlotHeightMode, ) +from .mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridConfigMixin, + GridFormatMixin, +) from .serializers import ( ColorCategory, - CustomRange, - CustomTicks, ModelListSerializer, PlotHeight, ) -class AreaChart(BaseChart): +class AreaChart( + GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart +): """A base class for the Datawrapper API's area chart.""" model_config = ConfigDict( @@ -59,70 +64,6 @@ class AreaChart(BaseChart): description="The type of datawrapper chart to create", ) - # - # Horizontal axis (X-axis) - # - - #: The custom range for the x axis - custom_range_x: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-x", - description="The custom range for the x axis", - ) - - #: The custom ticks for the x axis - custom_ticks_x: list[Any] = Field( - default_factory=list, - alias="custom-ticks-x", - description="The custom ticks for the x axis", - ) - - #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings) - x_grid_format: DateFormat | NumberFormat | str = Field( - default="auto", - alias="x-grid-format", - description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the x grid - x_grid: GridDisplay | str = Field( - default="off", - alias="x-grid", - description="Whether to show the x grid. The 'on' setting shows lines.", - ) - - # - # Vertical axis (Y-axis) - # - - #: The custom range for the y axis - custom_range_y: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-y", - description="The custom range for the y axis", - ) - - #: The custom ticks for the y axis - custom_ticks_y: list[Any] = Field( - default_factory=list, - alias="custom-ticks-y", - description="The custom ticks for the y axis", - ) - - #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings) - y_grid_format: DateFormat | NumberFormat | str = Field( - default="", - alias="y-grid-format", - description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the y grid - y_grid: GridDisplay | str = Field( - default="on", - alias="y-grid", - description="Whether to show the y grid. The 'on' setting shows lines.", - ) - #: The labeling of the y grid labels y_grid_labels: GridLabelPosition | str = Field( default="auto", @@ -313,52 +254,48 @@ def serialize_model(self) -> dict: model = super().serialize_model() # Add chart specific properties - model["metadata"]["visualize"].update( - { - # Horizontal axis - "custom-range-x": CustomRange.serialize(self.custom_range_x), - "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x), - "x-grid-format": self.x_grid_format, - "x-grid": self.x_grid, - # Vertical axis - "custom-range-y": CustomRange.serialize(self.custom_range_y), - "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y), - "y-grid-format": self.y_grid_format, - "y-grid": self.y_grid, - "y-grid-labels": self.y_grid_labels, - "y-grid-label-align": self.y_grid_label_align, - # Customize areas - "area-opacity": self.area_opacity, - "base-color": self.base_color, - "interpolation": self.interpolation, - "sort-areas": self.sort_areas, - "stack-areas": self.stack_areas, - "stack-to-100": self.stack_to_100, - "area-separator-lines": self.area_separator_lines, - "area-separator-color": self.area_separator_color, - # Customize specific layers - "color-category": ColorCategory.serialize(self.color_category), - # Labels - "show-color-key": self.show_color_key, - # Tooltips - "show-tooltips": self.show_tooltips, - "tooltip-x-format": self.tooltip_x_format, - "tooltip-number-format": self.tooltip_number_format, - # Appearance - **PlotHeight.serialize( - self.plot_height_mode, - self.plot_height_fixed, - self.plot_height_ratio, - ), - # Annotations - "text-annotations": ModelListSerializer.serialize( - self.text_annotations, TextAnnotation - ), - "range-annotations": ModelListSerializer.serialize( - self.range_annotations, RangeAnnotation - ), - } - ) + visualize_data = { + # Horizontal and vertical axis (from mixins) + **self._serialize_grid_config(), + **self._serialize_grid_format(), + **self._serialize_custom_range(), + **self._serialize_custom_ticks(), + # Vertical axis (chart-specific) + "y-grid-labels": self.y_grid_labels, + "y-grid-label-align": self.y_grid_label_align, + # Customize areas + "area-opacity": self.area_opacity, + "base-color": self.base_color, + "interpolation": self.interpolation, + "sort-areas": self.sort_areas, + "stack-areas": self.stack_areas, + "stack-to-100": self.stack_to_100, + "area-separator-lines": self.area_separator_lines, + "area-separator-color": self.area_separator_color, + # Customize specific layers + "color-category": ColorCategory.serialize(self.color_category), + # Labels + "show-color-key": self.show_color_key, + # Tooltips + "show-tooltips": self.show_tooltips, + "tooltip-x-format": self.tooltip_x_format, + "tooltip-number-format": self.tooltip_number_format, + # Appearance + **PlotHeight.serialize( + self.plot_height_mode, + self.plot_height_fixed, + self.plot_height_ratio, + ), + # Annotations + "text-annotations": ModelListSerializer.serialize( + self.text_annotations, TextAnnotation + ), + "range-annotations": ModelListSerializer.serialize( + self.range_annotations, RangeAnnotation + ), + } + + model["metadata"]["visualize"].update(visualize_data) # Return the serialized data return model @@ -380,30 +317,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: metadata = api_response.get("metadata", {}) visualize = metadata.get("visualize", {}) - # Horizontal axis (X-axis) - init_data["custom_range_x"] = CustomRange.deserialize( - visualize.get("custom-range-x") - ) - init_data["custom_ticks_x"] = CustomTicks.deserialize( - visualize.get("custom-ticks-x", "") - ) - if "x-grid-format" in visualize: - init_data["x_grid_format"] = visualize["x-grid-format"] - if "x-grid" in visualize: - init_data["x_grid"] = visualize["x-grid"] - - # Vertical axis (Y-axis) - init_data["custom_range_y"] = CustomRange.deserialize( - visualize.get("custom-range-y") - ) - init_data["custom_ticks_y"] = CustomTicks.deserialize( - visualize.get("custom-ticks-y", "") - ) + # Horizontal and vertical axis (from mixins) + init_data.update(cls._deserialize_grid_config(visualize)) + init_data.update(cls._deserialize_grid_format(visualize)) + init_data.update(cls._deserialize_custom_range(visualize)) + init_data.update(cls._deserialize_custom_ticks(visualize)) - if "y-grid-format" in visualize: - init_data["y_grid_format"] = visualize["y-grid-format"] - if "y-grid" in visualize: - init_data["y_grid"] = visualize["y-grid"] + # Vertical axis (chart-specific) if "y-grid-labels" in visualize: init_data["y_grid_labels"] = visualize["y-grid-labels"] if "y-grid-label-align" in visualize: diff --git a/datawrapper/charts/column.py b/datawrapper/charts/column.py index 3c97d3ac..6bd4c425 100644 --- a/datawrapper/charts/column.py +++ b/datawrapper/charts/column.py @@ -7,7 +7,6 @@ from .base import BaseChart from .enums import ( DateFormat, - GridDisplay, GridLabelAlign, GridLabelPosition, NumberFormat, @@ -15,10 +14,14 @@ ValueLabelDisplay, ValueLabelPlacement, ) +from .mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridConfigMixin, + GridFormatMixin, +) from .serializers import ( ColorCategory, - CustomRange, - CustomTicks, ModelListSerializer, NegativeColor, PlotHeight, @@ -26,7 +29,9 @@ ) -class ColumnChart(BaseChart): +class ColumnChart( + GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart +): """A base class for the Datawrapper API's column chart.""" model_config = ConfigDict( @@ -62,69 +67,9 @@ class ColumnChart(BaseChart): ) # - # Horizontal axis (X-axis) - # - - #: The custom range for the x axis - custom_range_x: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-x", - description="The custom range for the x axis", - ) - - #: The custom ticks for the x axis - custom_ticks_x: list[Any] = Field( - default_factory=list, - alias="custom-ticks-x", - description="The custom ticks for the x axis", - ) - - #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings) - x_grid_format: DateFormat | NumberFormat | str = Field( - default="auto", - alias="x-grid-format", - description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the x grid - x_grid: GridDisplay | str = Field( - default="off", - alias="x-grid", - description="Whether to show the x grid", - ) - - # - # Vertical axis (Y-axis) + # Vertical axis (Y-axis) - chart-specific fields # - #: The custom range for the y axis - custom_range_y: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-y", - description="The custom range for the y axis", - ) - - #: The custom ticks for the y axis - custom_ticks_y: list[Any] = Field( - default_factory=list, - alias="custom-ticks-y", - description="The custom ticks for the y axis", - ) - - #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings) - y_grid_format: DateFormat | NumberFormat | str = Field( - default="", - alias="y-grid-format", - description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the y grid lines - y_grid: bool = Field( - default=True, - alias="y-grid", - description="Whether to show the y grid lines", - ) - #: The labeling of the y grid labels y_grid_labels: GridLabelPosition | str = Field( default="outside", @@ -266,6 +211,114 @@ def validate_plot_height_mode(cls, v: PlotHeightMode | str) -> PlotHeightMode | raise ValueError(f"Invalid value: {v}. Must be one of {valid_values}") return v + @classmethod + def _deserialize_grid_config(cls, visualize: dict) -> dict: + """Override to handle ColumnChart-specific grid fields. + + ColumnChart uses different API fields than other charts: + - x_grid: Parsed from 'grid-lines-x' dict (not 'x-grid' string) + - y_grid: Parsed from 'grid-lines' boolean (not 'y-grid' string) + """ + result = {} + + # Parse grid-lines-x (dict with type/enabled) + if "grid-lines-x" in visualize: + grid_lines_x = visualize["grid-lines-x"] + if isinstance(grid_lines_x, dict): + enabled = grid_lines_x.get("enabled", False) + grid_type = grid_lines_x.get("type", "") + result["x_grid"] = grid_type if enabled else "off" + + # Parse grid-lines (boolean) + if "grid-lines" in visualize: + result["y_grid"] = visualize["grid-lines"] + + return result + + def _serialize_grid_config(self) -> dict: + """Override to add ColumnChart-specific grid-lines field. + + ColumnChart uses both the standard y-grid field (from mixin) and an + additional grid-lines boolean field that mirrors the y-grid on/off state. + """ + # Get the standard grid config from the mixin + result = super()._serialize_grid_config() + + # Add the ColumnChart-specific grid-lines boolean field + # This mirrors the y_grid on/off state + if self.y_grid is not None: + from .enums import GridDisplay + + # Convert to boolean: "on" or True -> True, "off" or False -> False + if isinstance(self.y_grid, GridDisplay): + result["grid-lines"] = self.y_grid == GridDisplay.ON + elif isinstance(self.y_grid, bool): + result["grid-lines"] = self.y_grid + elif isinstance(self.y_grid, str): + result["grid-lines"] = self.y_grid.lower() == "on" + + return result + + def _serialize_custom_range(self) -> dict: + """Override to handle ColumnChart-specific field naming. + + ColumnChart uses 'custom-range' (not 'custom-range-y') for Y-axis custom range. + """ + # Get the standard custom range config from the mixin + result = super()._serialize_custom_range() + + # Rename custom-range-y to custom-range for ColumnChart + if "custom-range-y" in result: + result["custom-range"] = result.pop("custom-range-y") + + return result + + @classmethod + def _deserialize_custom_range(cls, visualize: dict) -> dict: + """Override to handle ColumnChart-specific field naming. + + ColumnChart uses 'custom-range' (not 'custom-range-y') for Y-axis custom range. + """ + # Create a modified visualize dict with renamed field + modified_visualize = visualize.copy() + if "custom-range" in modified_visualize: + modified_visualize["custom-range-y"] = modified_visualize.pop( + "custom-range" + ) + + # Call the parent deserializer with the modified dict + return super()._deserialize_custom_range(modified_visualize) + + def _serialize_custom_ticks(self) -> dict: + """Override to handle ColumnChart-specific field naming. + + ColumnChart uses 'custom-ticks' (not 'custom-ticks-y') for Y-axis custom ticks. + """ + # Get the standard custom ticks config from the mixin + result = super()._serialize_custom_ticks() + + # Rename custom-ticks-y to custom-ticks for ColumnChart + if "custom-ticks-y" in result: + result["custom-ticks"] = result.pop("custom-ticks-y") + + return result + + @classmethod + def _deserialize_custom_ticks(cls, visualize: dict) -> dict: + """Override to handle ColumnChart-specific field naming. + + ColumnChart uses 'custom-ticks' (not 'custom-ticks-y') for Y-axis custom ticks. + """ + # Create a modified visualize dict with renamed field + modified_visualize = visualize.copy() + if "custom-ticks" in modified_visualize: + modified_visualize["custom-ticks-y"] = modified_visualize.pop( + "custom-ticks" + ) + + # Call the parent deserializer with the modified dict + return super()._deserialize_custom_ticks(modified_visualize) + @model_serializer def serialize_model(self) -> dict: """Serialize the model to a dictionary.""" @@ -273,60 +326,51 @@ def serialize_model(self) -> dict: model = super().serialize_model() # Add chart specific properties - model["metadata"]["visualize"].update( - { - # Horizontal axis - "custom-range-x": CustomRange.serialize(self.custom_range_x), - "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x), - "x-grid-format": self.x_grid_format, - "grid-lines-x": { - "type": "" if self.x_grid == "off" else self.x_grid, - "enabled": self.x_grid != "off", - }, - # Vertical axis - "custom-range": CustomRange.serialize(self.custom_range_y), - "custom-ticks": CustomTicks.serialize(self.custom_ticks_y), - "y-grid-format": self.y_grid_format, - "grid-lines": self.y_grid, - "yAxisLabels": { - "enabled": self.y_grid_labels != "off", - "alignment": self.y_grid_label_align, - "placement": "" - if self.y_grid_labels == "off" - else self.y_grid_labels, - }, - # Appearance - "base-color": self.base_color, - "negativeColor": NegativeColor.serialize(self.negative_color), - "bar-padding": self.bar_padding, - "color-category": ColorCategory.serialize( - self.color_category, - self.category_labels, - self.category_order, - ), - "color-by-column": bool(self.color_category), - **PlotHeight.serialize( - self.plot_height_mode, - self.plot_height_fixed, - self.plot_height_ratio, - ), - # Labels - "show-color-key": self.show_color_key, - **ValueLabels.serialize( - show=self.show_value_labels, - format_str=self.value_labels_format, - placement=self.value_labels_placement, - chart_type="column", - ), - # Annotations - "text-annotations": ModelListSerializer.serialize( - self.text_annotations, TextAnnotation - ), - "range-annotations": ModelListSerializer.serialize( - self.range_annotations, RangeAnnotation - ), - } - ) + visualize_data = { + # Horizontal and vertical axis (from mixins) + **self._serialize_grid_config(), + **self._serialize_grid_format(), + **self._serialize_custom_range(), + **self._serialize_custom_ticks(), + # Vertical axis (chart-specific) + "yAxisLabels": { + "enabled": self.y_grid_labels != "off", + "alignment": self.y_grid_label_align, + "placement": "" if self.y_grid_labels == "off" else self.y_grid_labels, + }, + # Appearance + "base-color": self.base_color, + "negativeColor": NegativeColor.serialize(self.negative_color), + "bar-padding": self.bar_padding, + "color-category": ColorCategory.serialize( + self.color_category, + self.category_labels, + self.category_order, + ), + "color-by-column": bool(self.color_category), + **PlotHeight.serialize( + self.plot_height_mode, + self.plot_height_fixed, + self.plot_height_ratio, + ), + # Labels + "show-color-key": self.show_color_key, + **ValueLabels.serialize( + show=self.show_value_labels, + format_str=self.value_labels_format, + placement=self.value_labels_placement, + chart_type="column", + ), + # Annotations + "text-annotations": ModelListSerializer.serialize( + self.text_annotations, TextAnnotation + ), + "range-annotations": ModelListSerializer.serialize( + self.range_annotations, RangeAnnotation + ), + } + + model["metadata"]["visualize"].update(visualize_data) # Return the serialized data return model @@ -349,37 +393,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: metadata = api_response.get("metadata", {}) visualize = metadata.get("visualize", {}) - # Horizontal axis (X-axis) - init_data["custom_range_x"] = CustomRange.deserialize( - visualize.get("custom-range-x") - ) - init_data["custom_ticks_x"] = CustomTicks.deserialize( - visualize.get("custom-ticks-x", "") - ) - if "x-grid-format" in visualize: - init_data["x_grid_format"] = visualize["x-grid-format"] - - # Parse grid-lines-x - if "grid-lines-x" in visualize: - grid_lines_x = visualize["grid-lines-x"] - if isinstance(grid_lines_x, dict): - enabled = grid_lines_x.get("enabled", False) - grid_type = grid_lines_x.get("type", "") - init_data["x_grid"] = grid_type if enabled else "off" - - # Vertical axis (Y-axis) - init_data["custom_range_y"] = CustomRange.deserialize( - visualize.get("custom-range") - ) - init_data["custom_ticks_y"] = CustomTicks.deserialize( - visualize.get("custom-ticks", "") - ) - if "y-grid-format" in visualize: - init_data["y_grid_format"] = visualize["y-grid-format"] - if "grid-lines" in visualize: - init_data["y_grid"] = visualize["grid-lines"] + # Horizontal and vertical axis (from mixins) + init_data.update(cls._deserialize_grid_config(visualize)) + init_data.update(cls._deserialize_grid_format(visualize)) + init_data.update(cls._deserialize_custom_range(visualize)) + init_data.update(cls._deserialize_custom_ticks(visualize)) - # Parse yAxisLabels + # Vertical axis (chart-specific) - Parse yAxisLabels if "yAxisLabels" in visualize: y_axis_labels = visualize["yAxisLabels"] if isinstance(y_axis_labels, dict): diff --git a/datawrapper/charts/line.py b/datawrapper/charts/line.py index 0fc467c1..0805770c 100644 --- a/datawrapper/charts/line.py +++ b/datawrapper/charts/line.py @@ -13,7 +13,6 @@ from .base import BaseChart from .enums import ( DateFormat, - GridDisplay, GridLabelAlign, GridLabelPosition, LineDash, @@ -25,10 +24,14 @@ SymbolShape, SymbolStyle, ) +from .mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridConfigMixin, + GridFormatMixin, +) from .serializers import ( ColorCategory, - CustomRange, - CustomTicks, ModelListSerializer, PlotHeight, ) @@ -391,7 +394,9 @@ def deserialize_model(cls, line_name: str, line_config: dict) -> dict[str, Any]: return init_dict -class LineChart(BaseChart): +class LineChart( + GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart +): """A base class for the Datawrapper API's line chart.""" model_config = ConfigDict( @@ -429,66 +434,12 @@ class LineChart(BaseChart): # # Horizontal axis (X-axis) # - - #: The custom range for the x axis - custom_range_x: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-x", - description="The custom range for the x axis", - ) - - #: The custom ticks for the x axis - custom_ticks_x: list[Any] = Field( - default_factory=list, - alias="custom-ticks-x", - description="The custom ticks for the x axis", - ) - - #: The formatting for the x grid labels (use DateFormat or NumberFormat enum or custom format strings) - x_grid_format: DateFormat | NumberFormat | str = Field( - default="auto", - alias="x-grid-format", - description="The formatting for the x grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the x grid - x_grid: GridDisplay | str = Field( - default="off", - alias="x-grid", - description="Whether to show the x grid. The 'on' setting shows lines.", - ) + # Note: x_grid, x_grid_format, custom_range_x, custom_ticks_x inherited from mixins # # Vertical axis (Y-axis) # - - #: The custom range for the y axis - custom_range_y: list[Any] | tuple[Any, Any] = Field( - default_factory=lambda: ["", ""], - alias="custom-range-y", - description="The custom range for the y axis", - ) - - #: The custom ticks for the y axis - custom_ticks_y: list[Any] = Field( - default_factory=list, - alias="custom-ticks-y", - description="The custom ticks for the y axis", - ) - - #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings) - y_grid_format: DateFormat | NumberFormat | str = Field( - default="", - alias="y-grid-format", - description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the y grid - y_grid: GridDisplay | str = Field( - default="on", - alias="y-grid", - description="Whether to show the y grid. The 'on' setting shows lines.", - ) + # Note: y_grid, y_grid_format, custom_range_y, custom_ticks_y inherited from mixins #: The labeling of the y grid labels y_grid_labels: GridLabelPosition | str = Field( @@ -700,56 +651,52 @@ def serialize_model(self) -> dict: model = super().serialize_model() # Add chart specific properties - model["metadata"]["visualize"].update( - { - # Horizontal axis - "custom-range-x": CustomRange.serialize(self.custom_range_x), - "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x), - "x-grid-format": self.x_grid_format, - "x-grid": self.x_grid, - # Vertical axis - "custom-range-y": CustomRange.serialize(self.custom_range_y), - "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y), - "y-grid-format": self.y_grid_format, - "y-grid": self.y_grid, - "y-grid-labels": self.y_grid_labels, - "y-grid-label-align": self.y_grid_label_align, - "scale-y": self.scale_y, - "y-grid-subdivide": self.y_grid_subdivide, - # Customize lines - "base-color": self.base_color, - "interpolation": self.interpolation, - "connector-lines": self.connector_lines, - "color-category": ColorCategory.serialize(self.color_category), - # Labels - "stack-color-legend": self.stack_color_legend, - "label-colors": self.label_colors, - "label-margin": self.label_margin, - "value-labels-format": self.value_labels_format, - "value-label-colors": self.value_label_colors, - # Tooltips - "show-tooltips": self.show_tooltips, - "tooltip-x-format": self.tooltip_x_format, - "tooltip-number-format": self.tooltip_number_format, - # Appearance - **PlotHeight.serialize( - self.plot_height_mode, - self.plot_height_fixed, - self.plot_height_ratio, - ), - # Initialize empty structures - "lines": {}, - "text-annotations": ModelListSerializer.serialize( - self.text_annotations, TextAnnotation - ), - "range-annotations": ModelListSerializer.serialize( - self.range_annotations, RangeAnnotation - ), - "custom-area-fills": ModelListSerializer.serialize( - self.area_fills, AreaFill - ), - } - ) + visualize_data = { + # Horizontal axis (from mixins) + **self._serialize_grid_config(), + **self._serialize_grid_format(), + **self._serialize_custom_range(), + **self._serialize_custom_ticks(), + # Vertical axis (chart-specific) + "y-grid-labels": self.y_grid_labels, + "y-grid-label-align": self.y_grid_label_align, + "scale-y": self.scale_y, + "y-grid-subdivide": self.y_grid_subdivide, + # Customize lines + "base-color": self.base_color, + "interpolation": self.interpolation, + "connector-lines": self.connector_lines, + "color-category": ColorCategory.serialize(self.color_category), + # Labels + "stack-color-legend": self.stack_color_legend, + "label-colors": self.label_colors, + "label-margin": self.label_margin, + "value-labels-format": self.value_labels_format, + "value-label-colors": self.value_label_colors, + # Tooltips + "show-tooltips": self.show_tooltips, + "tooltip-x-format": self.tooltip_x_format, + "tooltip-number-format": self.tooltip_number_format, + # Appearance + **PlotHeight.serialize( + self.plot_height_mode, + self.plot_height_fixed, + self.plot_height_ratio, + ), + # Initialize empty structures + "lines": {}, + "text-annotations": ModelListSerializer.serialize( + self.text_annotations, TextAnnotation + ), + "range-annotations": ModelListSerializer.serialize( + self.range_annotations, RangeAnnotation + ), + "custom-area-fills": ModelListSerializer.serialize( + self.area_fills, AreaFill + ), + } + + model["metadata"]["visualize"].update(visualize_data) # Add line configurations for line_obj in self.lines: @@ -784,30 +731,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: metadata = api_response.get("metadata", {}) visualize = metadata.get("visualize", {}) - # Horizontal axis (X-axis) - init_data["custom_range_x"] = CustomRange.deserialize( - visualize.get("custom-range-x") - ) - init_data["custom_ticks_x"] = CustomTicks.deserialize( - visualize.get("custom-ticks-x", "") - ) - if "x-grid-format" in visualize: - init_data["x_grid_format"] = visualize["x-grid-format"] - if "x-grid" in visualize: - init_data["x_grid"] = visualize["x-grid"] - - # Vertical axis (Y-axis) - init_data["custom_range_y"] = CustomRange.deserialize( - visualize.get("custom-range-y") - ) - init_data["custom_ticks_y"] = CustomTicks.deserialize( - visualize.get("custom-ticks-y", "") - ) + # Horizontal and vertical axis (from mixins) + init_data.update(cls._deserialize_grid_config(visualize)) + init_data.update(cls._deserialize_grid_format(visualize)) + init_data.update(cls._deserialize_custom_range(visualize)) + init_data.update(cls._deserialize_custom_ticks(visualize)) - if "y-grid-format" in visualize: - init_data["y_grid_format"] = visualize["y-grid-format"] - if "y-grid" in visualize: - init_data["y_grid"] = visualize["y-grid"] + # Vertical axis (chart-specific) if "y-grid-labels" in visualize: init_data["y_grid_labels"] = visualize["y-grid-labels"] if "y-grid-label-align" in visualize: diff --git a/datawrapper/charts/mixins.py b/datawrapper/charts/mixins.py new file mode 100644 index 00000000..c2c56a68 --- /dev/null +++ b/datawrapper/charts/mixins.py @@ -0,0 +1,230 @@ +"""Mixin classes for shared chart visualization patterns.""" + +from typing import Any + +from datawrapper.charts.enums import DateFormat, GridDisplay, NumberFormat +from datawrapper.charts.serializers import CustomRange, CustomTicks + + +class GridConfigMixin: + """Mixin for charts that support grid display configuration. + + Provides x_grid and y_grid fields for controlling grid line visibility, + along with serialization/deserialization methods. + + Default values: + - x_grid: "off" (no vertical grid lines by default) + - y_grid: "on" (horizontal grid lines shown by default) + + Supports backwards compatibility with boolean values: + - True → "on" during serialization + - False → "off" during serialization + - API "on" → True during deserialization + - API "off" → False during deserialization + """ + + x_grid: GridDisplay | str | bool | None = "off" + y_grid: GridDisplay | str | bool | None = "on" + + def _serialize_grid_config(self) -> dict: + """Serialize grid configuration to API format. + + Handles conversion of boolean values to "on"/"off" strings for backwards compatibility. + + Returns: + dict: Grid configuration in API format with keys: + - x-grid: X-axis grid display setting + - y-grid: Y-axis grid display setting + """ + result = {} + if self.x_grid is not None: + # Handle boolean values for backwards compatibility + if isinstance(self.x_grid, bool): + result["x-grid"] = "on" if self.x_grid else "off" + elif isinstance(self.x_grid, GridDisplay): + result["x-grid"] = self.x_grid.value + else: + result["x-grid"] = self.x_grid + if self.y_grid is not None: + # Handle boolean values for backwards compatibility + if isinstance(self.y_grid, bool): + result["y-grid"] = "on" if self.y_grid else "off" + elif isinstance(self.y_grid, GridDisplay): + result["y-grid"] = self.y_grid.value + else: + result["y-grid"] = self.y_grid + return result + + @classmethod + def _deserialize_grid_config(cls, visualize: dict) -> dict: + """Deserialize grid configuration from API format. + + Preserves original API values without conversion. + + Args: + visualize: The visualize section from API response + + Returns: + dict: Grid configuration in Python format with keys: + - x_grid: X-axis grid display setting (preserves API type) + - y_grid: Y-axis grid display setting (preserves API type) + """ + result = {} + if "x-grid" in visualize: + result["x_grid"] = visualize["x-grid"] + if "y-grid" in visualize: + result["y_grid"] = visualize["y-grid"] + return result + + +class GridFormatMixin: + """Mixin for charts that support grid label formatting. + + Provides x_grid_format and y_grid_format fields for controlling how grid labels + are displayed, along with serialization/deserialization methods. + + Used by: LineChart, AreaChart, ColumnChart, MultipleColumnChart, ScatterPlot + """ + + x_grid_format: DateFormat | NumberFormat | str | None = None + y_grid_format: NumberFormat | str | None = None + + def _serialize_grid_format(self) -> dict: + """Serialize grid format configuration to API format. + + Returns: + dict: Grid format configuration in API format with keys: + - x-grid-format: X-axis grid label format + - y-grid-format: Y-axis grid label format + """ + result = {} + if self.x_grid_format is not None: + result["x-grid-format"] = ( + self.x_grid_format.value + if isinstance(self.x_grid_format, (DateFormat, NumberFormat)) + else self.x_grid_format + ) + if self.y_grid_format is not None: + result["y-grid-format"] = ( + self.y_grid_format.value + if isinstance(self.y_grid_format, NumberFormat) + else self.y_grid_format + ) + return result + + @classmethod + def _deserialize_grid_format(cls, visualize: dict) -> dict: + """Deserialize grid format configuration from API format. + + Args: + visualize: The visualize section from API response + + Returns: + dict: Grid format configuration in Python format with keys: + - x_grid_format: X-axis grid label format + - y_grid_format: Y-axis grid label format + """ + result = {} + if "x-grid-format" in visualize: + result["x_grid_format"] = visualize["x-grid-format"] + if "y-grid-format" in visualize: + result["y_grid_format"] = visualize["y-grid-format"] + return result + + +class CustomRangeMixin: + """Mixin for charts that support custom axis ranges. + + Provides custom_range_x and custom_range_y fields for setting explicit min/max + values for axes, along with serialization/deserialization methods. + """ + + custom_range_x: list[Any] | tuple[Any, Any] | None = None + custom_range_y: list[Any] | tuple[Any, Any] | None = None + + def _serialize_custom_range(self) -> dict: + """Serialize custom range configuration to API format. + + Returns: + dict: Custom range configuration in API format with keys: + - custom-range-x: X-axis custom range [min, max] + - custom-range-y: Y-axis custom range [min, max] + """ + result = {} + if self.custom_range_x is not None: + result["custom-range-x"] = CustomRange.serialize(self.custom_range_x) + if self.custom_range_y is not None: + result["custom-range-y"] = CustomRange.serialize(self.custom_range_y) + return result + + @classmethod + def _deserialize_custom_range(cls, visualize: dict) -> dict: + """Deserialize custom range configuration from API format. + + Args: + visualize: The visualize section from API response + + Returns: + dict: Custom range configuration in Python format with keys: + - custom_range_x: X-axis custom range [min, max] + - custom_range_y: Y-axis custom range [min, max] + """ + result = {} + if "custom-range-x" in visualize: + result["custom_range_x"] = CustomRange.deserialize( + visualize["custom-range-x"] + ) + if "custom-range-y" in visualize: + result["custom_range_y"] = CustomRange.deserialize( + visualize["custom-range-y"] + ) + return result + + +class CustomTicksMixin: + """Mixin for charts that support custom tick marks. + + Provides custom_ticks_x and custom_ticks_y fields for setting explicit tick mark + positions on axes, along with serialization/deserialization methods. + """ + + custom_ticks_x: list[Any] | None = None + custom_ticks_y: list[Any] | None = None + + def _serialize_custom_ticks(self) -> dict: + """Serialize custom ticks configuration to API format. + + Returns: + dict: Custom ticks configuration in API format with keys: + - custom-ticks-x: X-axis custom tick positions + - custom-ticks-y: Y-axis custom tick positions + """ + result = {} + if self.custom_ticks_x is not None: + result["custom-ticks-x"] = CustomTicks.serialize(self.custom_ticks_x) + if self.custom_ticks_y is not None: + result["custom-ticks-y"] = CustomTicks.serialize(self.custom_ticks_y) + return result + + @classmethod + def _deserialize_custom_ticks(cls, visualize: dict) -> dict: + """Deserialize custom ticks configuration from API format. + + Args: + visualize: The visualize section from API response + + Returns: + dict: Custom ticks configuration in Python format with keys: + - custom_ticks_x: X-axis custom tick positions + - custom_ticks_y: Y-axis custom tick positions + """ + result = {} + if "custom-ticks-x" in visualize: + result["custom_ticks_x"] = CustomTicks.deserialize( + visualize["custom-ticks-x"] + ) + if "custom-ticks-y" in visualize: + result["custom_ticks_y"] = CustomTicks.deserialize( + visualize["custom-ticks-y"] + ) + return result diff --git a/datawrapper/charts/multiple_column.py b/datawrapper/charts/multiple_column.py index aaf8d737..99a82440 100644 --- a/datawrapper/charts/multiple_column.py +++ b/datawrapper/charts/multiple_column.py @@ -15,10 +15,14 @@ ValueLabelDisplay, ValueLabelPlacement, ) +from .mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridConfigMixin, + GridFormatMixin, +) from .serializers import ( ColorCategory, - CustomRange, - CustomTicks, ModelListSerializer, NegativeColor, PlotHeight, @@ -26,7 +30,9 @@ ) -class MultipleColumnChart(BaseChart): +class MultipleColumnChart( + GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart +): """A base class for the Datawrapper API's multiple column chart.""" model_config = ConfigDict( @@ -131,34 +137,6 @@ class MultipleColumnChart(BaseChart): # Horizontal axis # - #: The custom range for the x axis - custom_range_x: tuple[Any, Any] | list[Any] = Field( - default=("", ""), - alias="custom-range-x", - description="The custom range for the x axis", - ) - - #: The custom ticks for the x axis - custom_ticks_x: list[Any] = Field( - default_factory=list, - alias="custom-ticks-x", - description="The custom ticks for the x axis", - ) - - #: The formatting for the x grid labels - x_grid_format: str = Field( - default="auto", - alias="x-grid-format", - description="The formatting for the x grid labels", - ) - - #: Whether to show the x grid - x_grid: GridDisplay | str = Field( - default="off", - alias="x-grid", - description="Whether to show the x grid", - ) - #: The labeling of the x axis x_grid_labels: Literal["on", "off"] = Field( default="on", @@ -177,34 +155,6 @@ class MultipleColumnChart(BaseChart): # Vertical axis # - #: The custom range for the y axis - custom_range_y: tuple[Any, Any] | list[Any] = Field( - default=("", ""), - alias="custom-range-y", - description="The custom range for the y axis", - ) - - #: The custom ticks for the y axis - custom_ticks_y: list[Any] = Field( - default_factory=list, - alias="custom-ticks-y", - description="The custom ticks for the y axis", - ) - - #: The formatting for the y grid labels (use DateFormat or NumberFormat enum or custom format strings) - y_grid_format: DateFormat | NumberFormat | str = Field( - default="", - alias="y-grid-format", - description="The formatting for the y grid labels. Use DateFormat for temporal data, NumberFormat for numeric data, or provide custom format strings.", - ) - - #: Whether to show the y grid lines - y_grid: bool = Field( - default=True, - alias="y-grid", - description="Whether to show the y grid lines", - ) - #: The labeling of the y grid labels y_grid_labels: GridLabelPosition | str = Field( default="outside", @@ -385,77 +335,74 @@ def serialize_model(self) -> dict: model = super().serialize_model() # Add chart specific properties to visualize section - model["metadata"]["visualize"].update( - { - # Layout - "gridLayout": self.grid_layout, - "gridColumnCount": self.grid_column, - "gridColumnCountMobile": self.grid_column_mobile, - "gridColumnMinWidth": self.grid_column_width, - "gridRowHeightFixed": self.grid_row_height, - "sort": { - "enabled": self.sort, - "reverse": self.sort_reverse, - "by": self.sort_by, - }, - # Horizontal axis - "custom-range-x": CustomRange.serialize(self.custom_range_x), - "custom-ticks-x": CustomTicks.serialize(self.custom_ticks_x), - "x-grid-format": self.x_grid_format, - "x-grid-labels": self.x_grid_labels, - "x-grid": self.x_grid_all, - "grid-lines-x": { - "type": "" if self.x_grid == "off" else self.x_grid, - "enabled": self.x_grid != "off", - }, - # Vertical axis - "custom-range-y": CustomRange.serialize(self.custom_range_y), - "custom-ticks-y": CustomTicks.serialize(self.custom_ticks_y), - "y-grid-format": self.y_grid_format, - "grid-lines": self.y_grid, - "yAxisLabels": { - "enabled": self.y_grid_labels != "off", - "alignment": self.y_grid_label_align, - "placement": "" - if self.y_grid_labels == "off" - else self.y_grid_labels, - }, - # Appearance - "base-color": self.base_color, - "negativeColor": NegativeColor.serialize(self.negative_color), - "bar-padding": self.bar_padding, - "color-category": ColorCategory.serialize(self.color_category), - "color-by-column": bool(self.color_category), - **PlotHeight.serialize( - self.plot_height_mode, - self.plot_height_fixed, - self.plot_height_ratio, - ), - "panels": {panel["column"]: panel for panel in self.panels}, - # Tooltips - "show-tooltips": self.show_tooltips, - "syncMultipleTooltips": self.sync_multiple_tooltips, - "tooltip-number-format": self.tooltip_number_format, - # Labels - "show-color-key": self.show_color_key, - "label-colors": self.label_colors, - "label-margin": self.label_margin, - **ValueLabels.serialize( - self.show_value_labels, - self.value_labels_format, - placement=self.value_labels_placement, - chart_type="multiple-column", - ), - "xGridLabelAllColumns": self.x_grid_label_all, - # Annotations - "text-annotations": ModelListSerializer.serialize( - self.text_annotations, TextAnnotation - ), - "range-annotations": ModelListSerializer.serialize( - self.range_annotations, RangeAnnotation - ), - } - ) + visualize_data = { + # Layout + "gridLayout": self.grid_layout, + "gridColumnCount": self.grid_column, + "gridColumnCountMobile": self.grid_column_mobile, + "gridColumnMinWidth": self.grid_column_width, + "gridRowHeightFixed": self.grid_row_height, + "sort": { + "enabled": self.sort, + "reverse": self.sort_reverse, + "by": self.sort_by, + }, + # Horizontal and vertical axis (from mixins) + **self._serialize_grid_config(), + **self._serialize_grid_format(), + **self._serialize_custom_range(), + **self._serialize_custom_ticks(), + # Horizontal axis (chart-specific) + "x-grid-labels": self.x_grid_labels, + "x-grid": self.x_grid_all, + "grid-lines-x": { + "type": "" if self.x_grid == "off" else self.x_grid, + "enabled": self.x_grid != "off", + }, + # Vertical axis (chart-specific) + "grid-lines": self.y_grid, + "yAxisLabels": { + "enabled": self.y_grid_labels != "off", + "alignment": self.y_grid_label_align, + "placement": "" if self.y_grid_labels == "off" else self.y_grid_labels, + }, + # Appearance + "base-color": self.base_color, + "negativeColor": NegativeColor.serialize(self.negative_color), + "bar-padding": self.bar_padding, + "color-category": ColorCategory.serialize(self.color_category), + "color-by-column": bool(self.color_category), + **PlotHeight.serialize( + self.plot_height_mode, + self.plot_height_fixed, + self.plot_height_ratio, + ), + "panels": {panel["column"]: panel for panel in self.panels}, + # Tooltips + "show-tooltips": self.show_tooltips, + "syncMultipleTooltips": self.sync_multiple_tooltips, + "tooltip-number-format": self.tooltip_number_format, + # Labels + "show-color-key": self.show_color_key, + "label-colors": self.label_colors, + "label-margin": self.label_margin, + **ValueLabels.serialize( + self.show_value_labels, + self.value_labels_format, + placement=self.value_labels_placement, + chart_type="multiple-column", + ), + "xGridLabelAllColumns": self.x_grid_label_all, + # Annotations + "text-annotations": ModelListSerializer.serialize( + self.text_annotations, TextAnnotation + ), + "range-annotations": ModelListSerializer.serialize( + self.range_annotations, RangeAnnotation + ), + } + + model["metadata"]["visualize"].update(visualize_data) # Return the serialized data return model @@ -501,15 +448,13 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: init_data["sort_reverse"] = False init_data["sort_by"] = "end" - # Horizontal axis - init_data["custom_range_x"] = CustomRange.deserialize( - visualize.get("custom-range-x") - ) - init_data["custom_ticks_x"] = CustomTicks.deserialize( - visualize.get("custom-ticks-x", "") - ) - if "x-grid-format" in visualize: - init_data["x_grid_format"] = visualize["x-grid-format"] + # Horizontal and vertical axis (from mixins) + init_data.update(cls._deserialize_grid_config(visualize)) + init_data.update(cls._deserialize_grid_format(visualize)) + init_data.update(cls._deserialize_custom_range(visualize)) + init_data.update(cls._deserialize_custom_ticks(visualize)) + + # Horizontal axis (chart-specific) if "x-grid-labels" in visualize: init_data["x_grid_labels"] = visualize["x-grid-labels"] if "x-grid" in visualize: @@ -525,16 +470,7 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]: else: init_data["x_grid"] = "off" - # Vertical axis - init_data["custom_range_y"] = CustomRange.deserialize( - visualize.get("custom-range-y") - ) - init_data["custom_ticks_y"] = CustomTicks.deserialize( - visualize.get("custom-ticks-y", "") - ) - if "y-grid-format" in visualize: - init_data["y_grid_format"] = visualize["y-grid-format"] - + # Vertical axis (chart-specific) # Parse grid-lines (can be bool or string "show") if "grid-lines" in visualize: grid_lines_val = visualize["grid-lines"] diff --git a/datawrapper/charts/serializers/base.py b/datawrapper/charts/serializers/base.py new file mode 100644 index 00000000..e7fc0bb8 --- /dev/null +++ b/datawrapper/charts/serializers/base.py @@ -0,0 +1,69 @@ +"""Base serializer class for all serialization utilities.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class BaseSerializer(ABC): + """Abstract base class for serialization utilities. + + This class defines the standard interface that all serializer utilities + should implement. It provides a consistent pattern for converting between + Python objects and Datawrapper API JSON formats. + + All serializer classes should inherit from this base class and implement + the serialize() and deserialize() methods. + + Example: + >>> class CustomRange(BaseSerializer): + ... @staticmethod + ... def serialize(range_values: list[Any] | tuple[Any, Any]) -> list[Any]: + ... # Implementation here + ... pass + ... + ... @staticmethod + ... def deserialize(range_list: list[Any] | None) -> list[Any] | None: + ... # Implementation here + ... pass + """ + + @staticmethod + @abstractmethod + def serialize(*args: Any, **kwargs: Any) -> Any: + """Convert Python objects to Datawrapper API format. + + This method should be implemented by subclasses to handle the + conversion from Python objects to the format expected by the + Datawrapper API. + + Args: + *args: Positional arguments specific to the serializer + **kwargs: Keyword arguments specific to the serializer + + Returns: + Any: The serialized data in API format + + Raises: + NotImplementedError: If not implemented by subclass + """ + raise NotImplementedError("Subclasses must implement serialize()") + + @staticmethod + @abstractmethod + def deserialize(*args: Any, **kwargs: Any) -> Any: + """Convert Datawrapper API format to Python objects. + + This method should be implemented by subclasses to handle the + conversion from the Datawrapper API format to Python objects. + + Args: + *args: Positional arguments specific to the serializer + **kwargs: Keyword arguments specific to the serializer + + Returns: + Any: The deserialized data as Python objects + + Raises: + NotImplementedError: If not implemented by subclass + """ + raise NotImplementedError("Subclasses must implement deserialize()") diff --git a/datawrapper/charts/serializers/color_category.py b/datawrapper/charts/serializers/color_category.py index b67deb72..4211a8f6 100644 --- a/datawrapper/charts/serializers/color_category.py +++ b/datawrapper/charts/serializers/color_category.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class ColorCategory: + +class ColorCategory(BaseSerializer): """Utility class for serializing and deserializing color category structures.""" @staticmethod diff --git a/datawrapper/charts/serializers/custom_range.py b/datawrapper/charts/serializers/custom_range.py index 6edf41ba..355c98e3 100644 --- a/datawrapper/charts/serializers/custom_range.py +++ b/datawrapper/charts/serializers/custom_range.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class CustomRange: + +class CustomRange(BaseSerializer): """Utility class for serializing and deserializing custom axis ranges.""" @staticmethod diff --git a/datawrapper/charts/serializers/custom_ticks.py b/datawrapper/charts/serializers/custom_ticks.py index 159ec3d6..aec871ca 100644 --- a/datawrapper/charts/serializers/custom_ticks.py +++ b/datawrapper/charts/serializers/custom_ticks.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class CustomTicks: + +class CustomTicks(BaseSerializer): """Utility class for serializing and deserializing custom tick marks.""" @staticmethod diff --git a/datawrapper/charts/serializers/negative_color.py b/datawrapper/charts/serializers/negative_color.py index 84538e4d..88004f04 100644 --- a/datawrapper/charts/serializers/negative_color.py +++ b/datawrapper/charts/serializers/negative_color.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class NegativeColor: + +class NegativeColor(BaseSerializer): """Utility class for serializing and deserializing negative color configuration. The Datawrapper API uses a nested object format for the negativeColor field: diff --git a/datawrapper/charts/serializers/plot_height.py b/datawrapper/charts/serializers/plot_height.py index c3e89485..f86cf71c 100644 --- a/datawrapper/charts/serializers/plot_height.py +++ b/datawrapper/charts/serializers/plot_height.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class PlotHeight: + +class PlotHeight(BaseSerializer): """Utility class for serializing and deserializing plot height configuration. The Datawrapper API uses three separate fields for plot height: diff --git a/datawrapper/charts/serializers/replace_flags.py b/datawrapper/charts/serializers/replace_flags.py index d76019f6..adf6877c 100644 --- a/datawrapper/charts/serializers/replace_flags.py +++ b/datawrapper/charts/serializers/replace_flags.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class ReplaceFlags: + +class ReplaceFlags(BaseSerializer): """Utility class for serializing and deserializing replace-flags configuration. The Datawrapper API uses a nested object format for the replace-flags field: diff --git a/datawrapper/charts/serializers/value_labels.py b/datawrapper/charts/serializers/value_labels.py index ddeb8b1a..2ad86229 100644 --- a/datawrapper/charts/serializers/value_labels.py +++ b/datawrapper/charts/serializers/value_labels.py @@ -1,7 +1,9 @@ from typing import Any +from .base import BaseSerializer -class ValueLabels: + +class ValueLabels(BaseSerializer): """Utility class for serializing and deserializing value label configuration. Different chart types use different API formats for value labels: From 87cf04c15ca582f178d84aefa8e3d057e5868683 Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 15:48:34 -0400 Subject: [PATCH 06/11] docs: add mixins API reference documentation Add comprehensive documentation for the chart mixins module, covering grid configuration (GridConfigMixin, GridFormatMixin) and axis customization (CustomRangeMixin, CustomTicksMixin). Includes usage examples for each mixin demonstrating common chart configuration patterns like controlling grid visibility, formatting labels, and customizing axis ranges and tick marks. --- docs/user-guide/api/mixins.rst | 95 ++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 docs/user-guide/api/mixins.rst diff --git a/docs/user-guide/api/mixins.rst b/docs/user-guide/api/mixins.rst new file mode 100644 index 00000000..180b837f --- /dev/null +++ b/docs/user-guide/api/mixins.rst @@ -0,0 +1,95 @@ +Mixins +====== + +The mixins module provides reusable functionality that can be shared across multiple chart types. These mixins handle common chart configuration patterns like grid display, formatting, and axis customization. + +.. currentmodule:: datawrapper.charts.mixins + +Grid Configuration +------------------ + +GridConfigMixin +~~~~~~~~~~~~~~~ + +Controls the visibility of grid lines on chart axes. + +.. autoclass:: GridConfigMixin + :members: + :show-inheritance: + +**Example:** + +.. code-block:: python + + from datawrapper.charts import LineChart, GridDisplay + + chart = LineChart( + title="Temperature Trends", + x_grid=GridDisplay.OFF, + y_grid=GridDisplay.ON + ) + +GridFormatMixin +~~~~~~~~~~~~~~~ + +Controls the formatting of grid labels on chart axes. + +.. autoclass:: GridFormatMixin + :members: + :show-inheritance: + +**Example:** + +.. code-block:: python + + from datawrapper.charts import LineChart, NumberFormat, DateFormat + + chart = LineChart( + title="Sales Over Time", + x_grid_format=DateFormat.MONTH_ABBREVIATED_WITH_YEAR, + y_grid_format=NumberFormat.THOUSANDS_SEPARATOR + ) + +Axis Customization +------------------ + +CustomRangeMixin +~~~~~~~~~~~~~~~~ + +Sets custom minimum and maximum values for chart axes. + +.. autoclass:: CustomRangeMixin + :members: + :show-inheritance: + +**Example:** + +.. code-block:: python + + from datawrapper.charts import ColumnChart + + chart = ColumnChart( + title="Revenue by Quarter", + custom_range_y=[0, 1000000] # Set Y-axis from 0 to 1M + ) + +CustomTicksMixin +~~~~~~~~~~~~~~~~ + +Sets custom tick mark positions on chart axes. + +.. autoclass:: CustomTicksMixin + :members: + :show-inheritance: + +**Example:** + +.. code-block:: python + + from datawrapper.charts import LineChart + + chart = LineChart( + title="Monthly Data", + custom_ticks_x=["Jan", "Apr", "Jul", "Oct"], + custom_ticks_y=[0, 25, 50, 75, 100] + ) From c7bb38facc5542a4e3fe5f7d8387e3a16f7fe108 Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 15:53:47 -0400 Subject: [PATCH 07/11] feat(charts): export chart mixins in public API Export CustomRangeMixin, CustomTicksMixin, GridConfigMixin, and GridFormatMixin from the main datawrapper package to make them available for public use. These mixins provide functionality for customizing chart axes, grid display, and formatting. Also updated documentation to use the public import pattern (import datawrapper as dw) instead of importing directly from submodules, making examples more consistent with recommended usage. --- datawrapper/__init__.py | 10 ++++++++++ datawrapper/charts/__init__.py | 5 +++++ docs/index.md | 1 + docs/user-guide/api/mixins.rst | 24 ++++++++++++------------ 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py index a384edae..28d0b9cc 100644 --- a/datawrapper/__init__.py +++ b/datawrapper/__init__.py @@ -63,6 +63,12 @@ ValueLabelMode, ValueLabelPlacement, ) +from datawrapper.charts.mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridConfigMixin, + GridFormatMixin, +) from datawrapper.exceptions import ( FailedRequestError, InvalidRequestError, @@ -123,6 +129,10 @@ "ValueLabelMode", "ValueLabelPlacement", "get_country_flag", + "CustomRangeMixin", + "CustomTicksMixin", + "GridFormatMixin", + "GridConfigMixin", "FailedRequestError", "InvalidRequestError", "RateLimitError", diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py index 449fd282..9aa312f1 100644 --- a/datawrapper/charts/__init__.py +++ b/datawrapper/charts/__init__.py @@ -33,6 +33,7 @@ ValueLabelPlacement, ) from .line import AreaFill, Line, LineChart, LineSymbol, LineValueLabel +from .mixins import CustomRangeMixin, CustomTicksMixin, GridConfigMixin, GridFormatMixin from .models import ( Annotate, ColumnFormat, @@ -56,6 +57,10 @@ "Annotate", "ColumnFormat", "ColumnFormatList", + "CustomRangeMixin", + "CustomTicksMixin", + "GridFormatMixin", + "GridConfigMixin", "ArrowHead", "ConnectorLineType", "DateFormat", diff --git a/docs/index.md b/docs/index.md index a6685f83..4fc9066b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -46,6 +46,7 @@ caption: API Reference user-guide/api/main-client.rst user-guide/api/chart-classes.rst user-guide/api/models.rst +user-guide/api/mixins.rst user-guide/api/enums.rst user-guide/api/exceptions.rst ``` diff --git a/docs/user-guide/api/mixins.rst b/docs/user-guide/api/mixins.rst index 180b837f..29e5f0ff 100644 --- a/docs/user-guide/api/mixins.rst +++ b/docs/user-guide/api/mixins.rst @@ -21,12 +21,12 @@ Controls the visibility of grid lines on chart axes. .. code-block:: python - from datawrapper.charts import LineChart, GridDisplay + import datawrapper as dw - chart = LineChart( + chart = dw.LineChart( title="Temperature Trends", - x_grid=GridDisplay.OFF, - y_grid=GridDisplay.ON + x_grid=dw.GridDisplay.OFF, + y_grid=dw.GridDisplay.ON ) GridFormatMixin @@ -42,12 +42,12 @@ Controls the formatting of grid labels on chart axes. .. code-block:: python - from datawrapper.charts import LineChart, NumberFormat, DateFormat + import datawrapper as dw - chart = LineChart( + chart = dw.LineChart( title="Sales Over Time", - x_grid_format=DateFormat.MONTH_ABBREVIATED_WITH_YEAR, - y_grid_format=NumberFormat.THOUSANDS_SEPARATOR + x_grid_format=dw.DateFormat.MONTH_ABBREVIATED_WITH_YEAR, + y_grid_format=dw.NumberFormat.THOUSANDS_SEPARATOR ) Axis Customization @@ -66,9 +66,9 @@ Sets custom minimum and maximum values for chart axes. .. code-block:: python - from datawrapper.charts import ColumnChart + import datawrapper as dw - chart = ColumnChart( + chart = dw.ColumnChart( title="Revenue by Quarter", custom_range_y=[0, 1000000] # Set Y-axis from 0 to 1M ) @@ -86,9 +86,9 @@ Sets custom tick mark positions on chart axes. .. code-block:: python - from datawrapper.charts import LineChart + import datawrapper as dw - chart = LineChart( + chart = dw.LineChart( title="Monthly Data", custom_ticks_x=["Jan", "Apr", "Jul", "Oct"], custom_ticks_y=[0, 25, 50, 75, 100] From bd9a69028df6b34f43e43706fa87c5d7845844dc Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 15:58:48 -0400 Subject: [PATCH 08/11] feat: add Field descriptions to chart mixin attributes Add Pydantic Field descriptors with detailed descriptions to all attributes in chart mixins (GridConfigMixin, GridFormatMixin, CustomRangeMixin, and CustomTicksMixin). This improves API documentation and provides better context for users about what each field controls and how it should be used. Changes: - Import Field from pydantic - Replace simple default assignments with Field() descriptors - Add descriptive documentation for x_grid, y_grid, x_grid_format, y_grid_format, custom_range_x, custom_range_y, custom_ticks_x, and custom_ticks_y attributes - Maintain existing default values and type hints --- datawrapper/charts/mixins.py | 42 +++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/datawrapper/charts/mixins.py b/datawrapper/charts/mixins.py index c2c56a68..21432169 100644 --- a/datawrapper/charts/mixins.py +++ b/datawrapper/charts/mixins.py @@ -2,6 +2,8 @@ from typing import Any +from pydantic import Field + from datawrapper.charts.enums import DateFormat, GridDisplay, NumberFormat from datawrapper.charts.serializers import CustomRange, CustomTicks @@ -23,8 +25,14 @@ class GridConfigMixin: - API "off" → False during deserialization """ - x_grid: GridDisplay | str | bool | None = "off" - y_grid: GridDisplay | str | bool | None = "on" + x_grid: GridDisplay | str | bool | None = Field( + default="off", + description="X-axis grid display setting. Controls vertical grid lines.", + ) + y_grid: GridDisplay | str | bool | None = Field( + default="on", + description="Y-axis grid display setting. Controls horizontal grid lines.", + ) def _serialize_grid_config(self) -> dict: """Serialize grid configuration to API format. @@ -86,8 +94,14 @@ class GridFormatMixin: Used by: LineChart, AreaChart, ColumnChart, MultipleColumnChart, ScatterPlot """ - x_grid_format: DateFormat | NumberFormat | str | None = None - y_grid_format: NumberFormat | str | None = None + x_grid_format: DateFormat | NumberFormat | str | None = Field( + default=None, + description="Format string for X-axis grid labels. Supports date and number formats.", + ) + y_grid_format: NumberFormat | str | None = Field( + default=None, + description="Format string for Y-axis grid labels. Supports number formats.", + ) def _serialize_grid_format(self) -> dict: """Serialize grid format configuration to API format. @@ -139,8 +153,14 @@ class CustomRangeMixin: values for axes, along with serialization/deserialization methods. """ - custom_range_x: list[Any] | tuple[Any, Any] | None = None - custom_range_y: list[Any] | tuple[Any, Any] | None = None + custom_range_x: list[Any] | tuple[Any, Any] | None = Field( + default=None, + description="Custom range for X-axis as [min, max]. Overrides automatic range calculation.", + ) + custom_range_y: list[Any] | tuple[Any, Any] | None = Field( + default=None, + description="Custom range for Y-axis as [min, max]. Overrides automatic range calculation.", + ) def _serialize_custom_range(self) -> dict: """Serialize custom range configuration to API format. @@ -188,8 +208,14 @@ class CustomTicksMixin: positions on axes, along with serialization/deserialization methods. """ - custom_ticks_x: list[Any] | None = None - custom_ticks_y: list[Any] | None = None + custom_ticks_x: list[Any] | None = Field( + default=None, + description="Custom tick mark positions for X-axis. List of values where ticks should appear.", + ) + custom_ticks_y: list[Any] | None = Field( + default=None, + description="Custom tick mark positions for Y-axis. List of values where ticks should appear.", + ) def _serialize_custom_ticks(self) -> dict: """Serialize custom ticks configuration to API format. From 80e3e822a485cb73f1547c64007a44283a9302f1 Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 16:02:57 -0400 Subject: [PATCH 09/11] refactor: rename GridConfigMixin to GridDisplayMixin Rename GridConfigMixin to GridDisplayMixin across the codebase to better reflect its purpose of controlling grid display visibility (x_grid and y_grid fields) rather than general grid configuration. This improves code clarity and naming consistency. Changes: - Renamed class GridConfigMixin to GridDisplayMixin in mixins.py - Updated all imports and references across chart modules (area, column, line, multiple_column) - Updated __all__ exports in __init__.py files - Updated documentation references --- datawrapper/__init__.py | 4 ++-- datawrapper/charts/__init__.py | 9 +++++++-- datawrapper/charts/area.py | 4 ++-- datawrapper/charts/column.py | 4 ++-- datawrapper/charts/line.py | 4 ++-- datawrapper/charts/mixins.py | 6 +++--- datawrapper/charts/multiple_column.py | 4 ++-- docs/user-guide/api/mixins.rst | 4 ++-- 8 files changed, 22 insertions(+), 17 deletions(-) diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py index 28d0b9cc..fa546bca 100644 --- a/datawrapper/__init__.py +++ b/datawrapper/__init__.py @@ -66,7 +66,7 @@ from datawrapper.charts.mixins import ( CustomRangeMixin, CustomTicksMixin, - GridConfigMixin, + GridDisplayMixin, GridFormatMixin, ) from datawrapper.exceptions import ( @@ -132,7 +132,7 @@ "CustomRangeMixin", "CustomTicksMixin", "GridFormatMixin", - "GridConfigMixin", + "GridDisplayMixin", "FailedRequestError", "InvalidRequestError", "RateLimitError", diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py index 9aa312f1..c33d30eb 100644 --- a/datawrapper/charts/__init__.py +++ b/datawrapper/charts/__init__.py @@ -33,7 +33,12 @@ ValueLabelPlacement, ) from .line import AreaFill, Line, LineChart, LineSymbol, LineValueLabel -from .mixins import CustomRangeMixin, CustomTicksMixin, GridConfigMixin, GridFormatMixin +from .mixins import ( + CustomRangeMixin, + CustomTicksMixin, + GridDisplayMixin, + GridFormatMixin, +) from .models import ( Annotate, ColumnFormat, @@ -60,7 +65,7 @@ "CustomRangeMixin", "CustomTicksMixin", "GridFormatMixin", - "GridConfigMixin", + "GridDisplayMixin", "ArrowHead", "ConnectorLineType", "DateFormat", diff --git a/datawrapper/charts/area.py b/datawrapper/charts/area.py index aa7f283f..732f46b0 100644 --- a/datawrapper/charts/area.py +++ b/datawrapper/charts/area.py @@ -16,7 +16,7 @@ from .mixins import ( CustomRangeMixin, CustomTicksMixin, - GridConfigMixin, + GridDisplayMixin, GridFormatMixin, ) from .serializers import ( @@ -27,7 +27,7 @@ class AreaChart( - GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart + GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart ): """A base class for the Datawrapper API's area chart.""" diff --git a/datawrapper/charts/column.py b/datawrapper/charts/column.py index 6bd4c425..81e7b0e5 100644 --- a/datawrapper/charts/column.py +++ b/datawrapper/charts/column.py @@ -17,7 +17,7 @@ from .mixins import ( CustomRangeMixin, CustomTicksMixin, - GridConfigMixin, + GridDisplayMixin, GridFormatMixin, ) from .serializers import ( @@ -30,7 +30,7 @@ class ColumnChart( - GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart + GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart ): """A base class for the Datawrapper API's column chart.""" diff --git a/datawrapper/charts/line.py b/datawrapper/charts/line.py index 0805770c..ee6103d0 100644 --- a/datawrapper/charts/line.py +++ b/datawrapper/charts/line.py @@ -27,7 +27,7 @@ from .mixins import ( CustomRangeMixin, CustomTicksMixin, - GridConfigMixin, + GridDisplayMixin, GridFormatMixin, ) from .serializers import ( @@ -395,7 +395,7 @@ def deserialize_model(cls, line_name: str, line_config: dict) -> dict[str, Any]: class LineChart( - GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart + GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart ): """A base class for the Datawrapper API's line chart.""" diff --git a/datawrapper/charts/mixins.py b/datawrapper/charts/mixins.py index 21432169..94300fe7 100644 --- a/datawrapper/charts/mixins.py +++ b/datawrapper/charts/mixins.py @@ -4,11 +4,11 @@ from pydantic import Field -from datawrapper.charts.enums import DateFormat, GridDisplay, NumberFormat -from datawrapper.charts.serializers import CustomRange, CustomTicks +from .enums import DateFormat, GridDisplay, NumberFormat +from .serializers import CustomRange, CustomTicks -class GridConfigMixin: +class GridDisplayMixin: """Mixin for charts that support grid display configuration. Provides x_grid and y_grid fields for controlling grid line visibility, diff --git a/datawrapper/charts/multiple_column.py b/datawrapper/charts/multiple_column.py index 99a82440..3a7fb765 100644 --- a/datawrapper/charts/multiple_column.py +++ b/datawrapper/charts/multiple_column.py @@ -18,7 +18,7 @@ from .mixins import ( CustomRangeMixin, CustomTicksMixin, - GridConfigMixin, + GridDisplayMixin, GridFormatMixin, ) from .serializers import ( @@ -31,7 +31,7 @@ class MultipleColumnChart( - GridConfigMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart + GridDisplayMixin, GridFormatMixin, CustomRangeMixin, CustomTicksMixin, BaseChart ): """A base class for the Datawrapper API's multiple column chart.""" diff --git a/docs/user-guide/api/mixins.rst b/docs/user-guide/api/mixins.rst index 29e5f0ff..e1238fb8 100644 --- a/docs/user-guide/api/mixins.rst +++ b/docs/user-guide/api/mixins.rst @@ -8,12 +8,12 @@ The mixins module provides reusable functionality that can be shared across mult Grid Configuration ------------------ -GridConfigMixin +GridDisplayMixin ~~~~~~~~~~~~~~~ Controls the visibility of grid lines on chart axes. -.. autoclass:: GridConfigMixin +.. autoclass:: GridDisplayMixin :members: :show-inheritance: From 475301b4692d0366396ee91e5fbb47b305d3a390 Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 16:06:35 -0400 Subject: [PATCH 10/11] feat(charts): add configurable timeout parameter to export methods Add timeout parameter (default 30s) to export_as_png(), export_as_pdf(), and export_as_svg() methods to allow users to configure request timeouts for chart export operations. This prevents indefinite hangs when the Datawrapper API is slow or unresponsive during export generation. --- datawrapper/charts/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py index 27697487..c8c220b9 100644 --- a/datawrapper/charts/base.py +++ b/datawrapper/charts/base.py @@ -743,6 +743,7 @@ def export_png( border_width: int = 0, border_color: str | None = None, access_token: str | None = None, + timeout: int = 30, ) -> bytes: """Export chart as PNG and return the raw bytes. @@ -755,6 +756,7 @@ def export_png( border_width: Margin around visualization in pixels. border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color. access_token: Optional Datawrapper API access token. + timeout: Timeout for the API request in seconds. Returns: Raw PNG image data as bytes. @@ -796,6 +798,7 @@ def export_png( response = client.get( f"{client._CHARTS_URL}/{self.chart_id}/export/png", params=params, + timeout=timeout, ) # Return raw bytes @@ -815,6 +818,7 @@ def export_pdf( border_width: int = 0, border_color: str | None = None, access_token: str | None = None, + timeout: int = 30, ) -> bytes: """Export chart as PDF and return the raw bytes. @@ -828,6 +832,7 @@ def export_pdf( border_width: Margin around visualization. border_color: Color of the border (e.g., "#FFFFFF"). If not specified, uses chart background color. access_token: Optional Datawrapper API access token. + timeout: Timeout for the API request in seconds. Returns: Raw PDF document data as bytes. @@ -874,6 +879,7 @@ def export_pdf( response = client.get( f"{client._CHARTS_URL}/{self.chart_id}/export/png", params=params, + timeout=timeout, ) if width is not None: params["width"] = str(width) @@ -900,6 +906,7 @@ def export_svg( height: int | None = None, plain: bool = False, access_token: str | None = None, + timeout: int = 30, ) -> bytes: """Export chart as SVG and return the raw bytes. @@ -908,6 +915,7 @@ def export_svg( height: Height of visualization. If not specified, uses chart height. plain: If True, exports only the visualization without header/footer. access_token: Optional Datawrapper API access token. + timeout: Timeout for the API request in seconds. Returns: Raw SVG document data as bytes. @@ -942,6 +950,7 @@ def export_svg( response = client.get( f"{client._CHARTS_URL}/{self.chart_id}/export/svg", params=params, + timeout=timeout, ) # Return raw bytes From 6ecbdc8f53ff26507871789e9669e2af4cf63d7d Mon Sep 17 00:00:00 2001 From: palewire Date: Tue, 28 Oct 2025 16:10:30 -0400 Subject: [PATCH 11/11] refactor(charts): remove duplicate parameter handling in PNG export Remove duplicate code block that was setting width, height, and border_color parameters and making the API request in the PNG export method. The parameters are already being set and the request is being made in the subsequent code block, making this duplication unnecessary. --- datawrapper/charts/base.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/datawrapper/charts/base.py b/datawrapper/charts/base.py index c8c220b9..184fe4ca 100644 --- a/datawrapper/charts/base.py +++ b/datawrapper/charts/base.py @@ -868,19 +868,6 @@ def export_pdf( "borderWidth": str(border_width), } - if width is not None: - params["width"] = str(width) - if height is not None: - params["height"] = str(height) - if border_color is not None: - params["borderColor"] = border_color - - # Make the API request - response = client.get( - f"{client._CHARTS_URL}/{self.chart_id}/export/png", - params=params, - timeout=timeout, - ) if width is not None: params["width"] = str(width) if height is not None: