Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions datawrapper/charts/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class ArrowChart(BaseChart):
# Customize arrows
#

#: The base color for the arrows
base_color: str | int = Field(
default=0,
alias="base-color",
description="The base color for the arrows",
)

#: A mapping of layer names to colors
color_category: dict[str, str] = Field(
default_factory=dict,
Expand Down Expand Up @@ -152,6 +159,20 @@ class ArrowChart(BaseChart):
description="The column that arrows should end at",
)

#: The axes to color by
axis_colors: str = Field(
default="",
alias="axis-colors",
description="The axes to color by",
)

#: The axes to label by
axis_labels: str = Field(
default="",
alias="axis-labels",
description="The axes to label by",
)

#
# Features
#
Expand Down Expand Up @@ -203,6 +224,7 @@ def serialize_model(self) -> dict:
"y-grid": self.y_grid,
"reverse-order": self.reverse_order,
"thick-arrows": self.thick_arrows,
"base-color": self.base_color,
"color-category": ColorCategory.serialize(self.color_category),
"range-value-labels": self.range_value_labels,
"sort-range": {
Expand All @@ -224,6 +246,10 @@ def serialize_model(self) -> dict:
"start": self.axis_start,
"end": self.axis_end,
}
if self.axis_colors:
model["metadata"]["axes"]["colors"] = self.axis_colors
if self.axis_labels:
model["metadata"]["axes"]["labels"] = self.axis_labels

# Return the serialized data
return model
Expand Down Expand Up @@ -254,6 +280,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
if "thick-arrows" in visualize:
init_data["thick_arrows"] = visualize["thick-arrows"]

# Base color
if "base-color" in visualize:
init_data["base_color"] = visualize["base-color"]

# Parse color-category using utility
color_data = ColorCategory.deserialize(visualize.get("color-category"))
init_data["color_category"] = color_data["color_category"]
Expand Down Expand Up @@ -291,6 +321,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
init_data["axis_start"] = axes["start"]
if "end" in axes:
init_data["axis_end"] = axes["end"]
if "colors" in axes:
init_data["axis_colors"] = axes["colors"]
if "labels" in axes:
init_data["axis_labels"] = axes["labels"]

# Features
if "color-by-column" in visualize:
Expand Down
41 changes: 22 additions & 19 deletions docs/user-guide/charts/arrow-charts.md
Original file line number Diff line number Diff line change
@@ -1,57 +1,60 @@
# ArrowChart

## Example
This example, drawn from <a href="https://www.datawrapper.de/charts/arrow-plot">the Datawrapper documentation</a>, demonstrates how to create an arrow chart with customized sorting and highlighted elements.

This example demonstrates how to create an arrow chart showing income inequality (Gini index) before and after taxes across different countries. The chart highlights how tax policies affect income distribution.
<iframe title="Many European countries bring income inequality down with taxes. The US and Mexico: Not so much." aria-label="Arrow Plot" id="datawrapper-chart-W0zuU" src="https://datawrapper.dwcdn.net/W0zuU/1/" scrolling="no" frameborder="0" style="width: 0; min-width: 100% !important; border: none; margin-bottom: 20px;" height="416" data-external="1"></iframe><script type="text/javascript">window.addEventListener("message",function(a){if(void 0!==a.data["datawrapper-height"]){var e=document.querySelectorAll("iframe");for(var t in a.data["datawrapper-height"])for(var r,i=0;r=e[i];i++)if(r.contentWindow===a.source){var d=a.data["datawrapper-height"][t]+"px";r.style.height=d}}});</script>

```python
import pandas as pd
import datawrapper as dw

# Load data from GitHub
url = "https://raw.githubusercontent.com/palewire/datawrapper-api-classes/main/tests/samples/arrow/inequality.csv"
df = pd.read_csv(url)
url = "https://raw.githubusercontent.com/chekos/datawrapper/main/tests/samples/arrow/inequality.csv"
df = pd.read_csv(url, sep="\t")

# Create arrow chart
chart = dw.ArrowChart(
# Chart title
title="Many European countries bring income inequality down with taxes. The US and Mexico: Not so much.",
# The description line with a bit of HTML
intro="Income inequality (gini index) in selected OECD countries in 2014, before and after taxes. A gini index of 0 means that every household earns exactly the same income, while an index of 1 means that one household in the country makes all the income. <b>The lower the Gini index, the more equal the income is distributed in a country.</b>",
# Data source attribution
source_name="OECD",
# The byline
byline="Lisa Charlotte Rost, Datawrapper",
# Pass the DataFrame
data=df,
# Start column (Gini before taxes)
axis_start="Gini before taxes",
# End column (Gini after taxes)
axis_end="Gini after taxes",
# Custom Y-axis range
custom_range_y=[0.15, 0.6],
# Y-axis grid format (thousands separator with optional decimals)
y_grid_format=dw.NumberFormat.THOUSANDS_WITH_OPTIONAL_DECIMALS,
# Value label format (one decimal place)
value_label_format=dw.NumberFormat.ONE_DECIMAL,
# Custom X-axis range
range_extent="custom",
custom_range=[0.15, 0.6],
# Value label format (three decimal places)
value_label_format=dw.NumberFormat.THREE_DECIMALS,
# Sort by the start column
sort_by="Gini before taxes",
sort_by="end",
# Enable sorting
sort_ranges=True,
# Label position (right side)
labeling="right",
# Show arrow key/legend
show_arrow_key=True,
arrow_key=True,
# Set the default arrow color
base_color="rgb(196, 148, 67)",
# Highlight specific countries in red
axis_colors="Country",
axis_labels="Country",
color_by_column=True,
color_category={
"Mexico": "#c71e1d",
"United States": "#c71e1d"
"<b>Mexico</b>": "#c71e1d",
"<b>United States</b>": "#c71e1d"
}
)

# Create the chart in Datawrapper
chart.create()
```

<iframe title="Many European countries bring income inequality down with taxes. The US and Mexico: Not so much." aria-label="Arrow Plot" id="datawrapper-chart-cjX4C" src="https://datawrapper.dwcdn.net/cjX4C/1/" scrolling="no" frameborder="0" style="width: 0; min-width: 100% !important; border: none;" height="656" data-external="1"></iframe><script type="text/javascript">!function(){"use strict";window.addEventListener("message",(function(a){if(void 0!==a.data["datawrapper-height"]){var e=document.querySelectorAll("iframe");for(var t in a.data["datawrapper-height"])for(var r=0;r<e.length;r++)if(e[r].contentWindow===a.source){var i=a.data["datawrapper-height"][t]+"px";e[r].style.height=i}}}))}();
</script>

## Reference

```{eval-rst}
Expand Down
16 changes: 16 additions & 0 deletions tests/integration/test_arrow_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ def load_sample_csv(filename: str) -> str:
class TestArrowChartCreation:
"""Tests for ArrowChart creation and serialization."""

def test_serialize_with_axis_colors_and_labels(self):
"""Test serializing with axis_colors and axis_labels."""
chart = ArrowChart(
title="Test",
data=pd.DataFrame({"Country": ["A", "B"], "Start": [1, 2], "End": [3, 4]}),
axis_start="Start",
axis_end="End",
axis_colors="Country",
axis_labels="Country",
)

serialized = chart.serialize_model()

assert serialized["metadata"]["axes"]["colors"] == "Country"
assert serialized["metadata"]["axes"]["labels"] == "Country"

def test_create_basic_arrow_chart(self):
"""Test creating a basic arrow chart."""
chart = ArrowChart(
Expand Down