diff --git a/datawrapper/charts/arrow.py b/datawrapper/charts/arrow.py
index 7d05648..871b625 100644
--- a/datawrapper/charts/arrow.py
+++ b/datawrapper/charts/arrow.py
@@ -49,6 +49,13 @@ class ArrowChart(BaseChart):
# Customize arrows
#
+ #: The base color for the arrows
+ base_color: str | int = Field(
+ default=0,
+ alias="base-color",
+ description="The base color for the arrows",
+ )
+
#: A mapping of layer names to colors
color_category: dict[str, str] = Field(
default_factory=dict,
@@ -152,6 +159,20 @@ class ArrowChart(BaseChart):
description="The column that arrows should end at",
)
+ #: The axes to color by
+ axis_colors: str = Field(
+ default="",
+ alias="axis-colors",
+ description="The axes to color by",
+ )
+
+ #: The axes to label by
+ axis_labels: str = Field(
+ default="",
+ alias="axis-labels",
+ description="The axes to label by",
+ )
+
#
# Features
#
@@ -203,6 +224,7 @@ def serialize_model(self) -> dict:
"y-grid": self.y_grid,
"reverse-order": self.reverse_order,
"thick-arrows": self.thick_arrows,
+ "base-color": self.base_color,
"color-category": ColorCategory.serialize(self.color_category),
"range-value-labels": self.range_value_labels,
"sort-range": {
@@ -224,6 +246,10 @@ def serialize_model(self) -> dict:
"start": self.axis_start,
"end": self.axis_end,
}
+ if self.axis_colors:
+ model["metadata"]["axes"]["colors"] = self.axis_colors
+ if self.axis_labels:
+ model["metadata"]["axes"]["labels"] = self.axis_labels
# Return the serialized data
return model
@@ -254,6 +280,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
if "thick-arrows" in visualize:
init_data["thick_arrows"] = visualize["thick-arrows"]
+ # Base color
+ if "base-color" in visualize:
+ init_data["base_color"] = visualize["base-color"]
+
# Parse color-category using utility
color_data = ColorCategory.deserialize(visualize.get("color-category"))
init_data["color_category"] = color_data["color_category"]
@@ -291,6 +321,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
init_data["axis_start"] = axes["start"]
if "end" in axes:
init_data["axis_end"] = axes["end"]
+ if "colors" in axes:
+ init_data["axis_colors"] = axes["colors"]
+ if "labels" in axes:
+ init_data["axis_labels"] = axes["labels"]
# Features
if "color-by-column" in visualize:
diff --git a/docs/user-guide/charts/arrow-charts.md b/docs/user-guide/charts/arrow-charts.md
index aa406e2..ee87461 100644
--- a/docs/user-guide/charts/arrow-charts.md
+++ b/docs/user-guide/charts/arrow-charts.md
@@ -1,47 +1,53 @@
# ArrowChart
-## Example
+This example, drawn from the Datawrapper documentation, demonstrates how to create an arrow chart with customized sorting and highlighted elements.
-This example demonstrates how to create an arrow chart showing income inequality (Gini index) before and after taxes across different countries. The chart highlights how tax policies affect income distribution.
+
```python
import pandas as pd
import datawrapper as dw
# Load data from GitHub
-url = "https://raw.githubusercontent.com/palewire/datawrapper-api-classes/main/tests/samples/arrow/inequality.csv"
-df = pd.read_csv(url)
+url = "https://raw.githubusercontent.com/chekos/datawrapper/main/tests/samples/arrow/inequality.csv"
+df = pd.read_csv(url, sep="\t")
# Create arrow chart
chart = dw.ArrowChart(
# Chart title
title="Many European countries bring income inequality down with taxes. The US and Mexico: Not so much.",
+ # The description line with a bit of HTML
+ intro="Income inequality (gini index) in selected OECD countries in 2014, before and after taxes. A gini index of 0 means that every household earns exactly the same income, while an index of 1 means that one household in the country makes all the income. The lower the Gini index, the more equal the income is distributed in a country.",
# Data source attribution
source_name="OECD",
+ # The byline
+ byline="Lisa Charlotte Rost, Datawrapper",
# Pass the DataFrame
data=df,
# Start column (Gini before taxes)
axis_start="Gini before taxes",
# End column (Gini after taxes)
axis_end="Gini after taxes",
- # Custom Y-axis range
- custom_range_y=[0.15, 0.6],
- # Y-axis grid format (thousands separator with optional decimals)
- y_grid_format=dw.NumberFormat.THOUSANDS_WITH_OPTIONAL_DECIMALS,
- # Value label format (one decimal place)
- value_label_format=dw.NumberFormat.ONE_DECIMAL,
+ # Custom X-axis range
+ range_extent="custom",
+ custom_range=[0.15, 0.6],
+ # Value label format (three decimal places)
+ value_label_format=dw.NumberFormat.THREE_DECIMALS,
# Sort by the start column
- sort_by="Gini before taxes",
+ sort_by="end",
# Enable sorting
sort_ranges=True,
- # Label position (right side)
- labeling="right",
# Show arrow key/legend
- show_arrow_key=True,
+ arrow_key=True,
+ # Set the default arrow color
+ base_color="rgb(196, 148, 67)",
# Highlight specific countries in red
+ axis_colors="Country",
+ axis_labels="Country",
+ color_by_column=True,
color_category={
- "Mexico": "#c71e1d",
- "United States": "#c71e1d"
+ "Mexico": "#c71e1d",
+ "United States": "#c71e1d"
}
)
@@ -49,9 +55,6 @@ chart = dw.ArrowChart(
chart.create()
```
-
-
## Reference
```{eval-rst}
diff --git a/tests/integration/test_arrow_chart.py b/tests/integration/test_arrow_chart.py
index 73be098..81430a5 100644
--- a/tests/integration/test_arrow_chart.py
+++ b/tests/integration/test_arrow_chart.py
@@ -27,6 +27,22 @@ def load_sample_csv(filename: str) -> str:
class TestArrowChartCreation:
"""Tests for ArrowChart creation and serialization."""
+ def test_serialize_with_axis_colors_and_labels(self):
+ """Test serializing with axis_colors and axis_labels."""
+ chart = ArrowChart(
+ title="Test",
+ data=pd.DataFrame({"Country": ["A", "B"], "Start": [1, 2], "End": [3, 4]}),
+ axis_start="Start",
+ axis_end="End",
+ axis_colors="Country",
+ axis_labels="Country",
+ )
+
+ serialized = chart.serialize_model()
+
+ assert serialized["metadata"]["axes"]["colors"] == "Country"
+ assert serialized["metadata"]["axes"]["labels"] == "Country"
+
def test_create_basic_arrow_chart(self):
"""Test creating a basic arrow chart."""
chart = ArrowChart(