Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions datawrapper/charts/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,12 @@ class ArrowChart(BaseChart):
description="The column to label by",
)

#: The column to group arrows by
groups_column: str | None = Field(
default=None,
description="The column to group arrows by",
)

#
# Features
#
Expand All @@ -180,13 +186,6 @@ class ArrowChart(BaseChart):
description="Label on the first arrow that shows column names",
)

#: Enables the group-by-column feature, works with "Group" field
group_by_column: bool = Field(
default=False,
alias="group-by-column",
description="Enables the group-by-column feature, works with 'Group' field",
)

@field_validator("replace_flags")
@classmethod
def validate_replace_flags(
Expand Down Expand Up @@ -224,7 +223,7 @@ def serialize_model(self) -> dict:
"range-extent": self.range_extent,
"value-label-format": self.value_label_format,
"color-by-column": bool(self.color_category),
"group-by-column": self.group_by_column,
"group-by-column": self.groups_column is not None,
"replace-flags": ReplaceFlags.serialize(self.replace_flags),
"show-arrow-key": self.arrow_key,
}
Expand All @@ -240,6 +239,8 @@ def serialize_model(self) -> dict:
axes_dict["colors"] = self.color_column
if self.label_column is not None:
axes_dict["labels"] = self.label_column
if self.groups_column is not None:
axes_dict["groups"] = self.groups_column

# Only add axes section if there are fields to include
if axes_dict:
Expand Down Expand Up @@ -319,10 +320,10 @@ def deserialize_model(cls, api_response: dict[str, Any]) -> dict[str, Any]:
init_data["color_column"] = axes["colors"]
if "labels" in axes:
init_data["label_column"] = axes["labels"]
if "groups" in axes:
init_data["groups_column"] = axes["groups"]

# Features
if "group-by-column" in visualize:
init_data["group_by_column"] = visualize["group-by-column"]
if "show-arrow-key" in visualize:
init_data["arrow_key"] = visualize["show-arrow-key"]

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_arrow_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_serialize_with_features(self):
data=pd.DataFrame({"x": [1, 2], "y": [10, 20], "z": [15, 25]}),
start_column="y",
end_column="z",
group_by_column=True,
groups_column="x", # Setting groups_column should enable group-by-column
arrow_key=True,
)

Expand Down Expand Up @@ -289,7 +289,7 @@ def mock_get(url):
# Verify features
assert chart.thick_arrows is True
assert chart.arrow_key is True
assert chart.group_by_column is True
assert chart.groups_column == "Gender"

# Verify sorting
assert chart.sort_ranges is False
Expand Down