From b3ab362a7504871471e80360a72a1966ebb53157 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 28 Oct 2025 19:37:39 -0400 Subject: [PATCH] feat(charts): add specialized annotation subclasses for X/Y lines and ranges Add four new annotation classes that extend RangeAnnotation with more specific interfaces: - XRangeAnnotation: horizontal range between two x positions - YRangeAnnotation: vertical range between two y positions - XLineAnnotation: vertical line at a specific x position - YLineAnnotation: horizontal line at a specific y position These subclasses automatically set appropriate type and display values, and include validation to ensure required position parameters are provided. Also adds field validators to RangeAnnotation for type, display, and opacity values. This improves the API ergonomics by providing clearer, more intuitive classes for common annotation patterns while maintaining backward compatibility with the base RangeAnnotation class. --- datawrapper/__init__.py | 8 + datawrapper/charts/__init__.py | 14 +- datawrapper/charts/annos.py | 108 ++++++ docs/user-guide/api/models.rst | 20 + .../unit/test_range_annotation_subclasses.py | 352 ++++++++++++++++++ 5 files changed, 501 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_range_annotation_subclasses.py diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py index bf23b14..45cbfc5 100644 --- a/datawrapper/__init__.py +++ b/datawrapper/__init__.py @@ -35,6 +35,10 @@ StackedBarChart, TextAnnotation, Transform, + XLineAnnotation, + XRangeAnnotation, + YLineAnnotation, + YRangeAnnotation, ) from datawrapper.charts.enums import ( ArrowHead, @@ -104,6 +108,10 @@ "StackedBarChart", "TextAnnotation", "RangeAnnotation", + "XRangeAnnotation", + "YRangeAnnotation", + "XLineAnnotation", + "YLineAnnotation", "ConnectorLine", "ArrowHead", "ConnectorLineType", diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py index 045cd69..d9775ef 100644 --- a/datawrapper/charts/__init__.py +++ b/datawrapper/charts/__init__.py @@ -1,4 +1,12 @@ -from .annos import ConnectorLine, RangeAnnotation, TextAnnotation +from .annos import ( + ConnectorLine, + RangeAnnotation, + TextAnnotation, + XLineAnnotation, + XRangeAnnotation, + YLineAnnotation, + YRangeAnnotation, +) from .area import AreaChart from .arrow import ArrowChart from .bar import BarChart, BarOverlay @@ -60,6 +68,10 @@ "ConnectorLine", "RangeAnnotation", "TextAnnotation", + "XLineAnnotation", + "XRangeAnnotation", + "YLineAnnotation", + "YRangeAnnotation", "Annotate", "ColumnFormat", "ColumnFormatList", diff --git a/datawrapper/charts/annos.py b/datawrapper/charts/annos.py index 40878a3..026a9e2 100644 --- a/datawrapper/charts/annos.py +++ b/datawrapper/charts/annos.py @@ -441,6 +441,14 @@ class RangeAnnotation(BaseModel): default="x", description="The axis of the annotation" ) + @field_validator("type") + @classmethod + def validate_type(cls, v: str) -> str: + """Validate that type is either 'x' or 'y'.""" + if v not in ["x", "y"]: + raise ValueError(f"Invalid type: {v}. Must be either 'x' or 'y'") + return v + #: The color of the annotation color: str = Field(default="#989898", description="The color of the annotation") @@ -449,9 +457,25 @@ class RangeAnnotation(BaseModel): default="range", description="The display style of the annotation" ) + @field_validator("display") + @classmethod + def validate_display(cls, v: str) -> str: + """Validate that display is either 'line' or 'range'.""" + if v not in ["line", "range"]: + raise ValueError(f"Invalid display: {v}. Must be either 'line' or 'range'") + return v + #: The opacity of the annotation opacity: int = Field(default=50, description="The opacity of the annotation") + @field_validator("opacity") + @classmethod + def validate_opacity(cls, v: int) -> int: + """Validate that opacity is between 0 and 100.""" + if not 0 <= v <= 100: + raise ValueError(f"Invalid opacity: {v}. Must be between 0 and 100") + return v + #: The first x position (required for type="x" annotations) x0: Any | None = Field( default=None, @@ -555,3 +579,87 @@ def deserialize_model(cls, api_data: dict[str, dict] | None) -> list[dict]: return [] return [{**anno_data, "id": anno_id} for anno_id, anno_data in api_data.items()] + + +class XRangeAnnotation(RangeAnnotation): + """A horizontal range annotation (shaded area between two x positions). + + Automatically sets type="x" and display="range". + Requires both x0 and x1 to be provided. + """ + + def __init__(self, **data): + """Initialize with type="x" and display="range" automatically set.""" + data.setdefault("type", "x") + data.setdefault("display", "range") + super().__init__(**data) + + @model_validator(mode="after") + def validate_x_positions_required(self) -> "XRangeAnnotation": + """Validate that both x0 and x1 are provided.""" + if self.x0 is None or self.x1 is None: + raise ValueError("XRangeAnnotation requires both x0 and x1 to be set") + return self + + +class YRangeAnnotation(RangeAnnotation): + """A vertical range annotation (shaded area between two y positions). + + Automatically sets type="y" and display="range". + Requires both y0 and y1 to be provided. + """ + + def __init__(self, **data): + """Initialize with type="y" and display="range" automatically set.""" + data.setdefault("type", "y") + data.setdefault("display", "range") + super().__init__(**data) + + @model_validator(mode="after") + def validate_y_positions_required(self) -> "YRangeAnnotation": + """Validate that both y0 and y1 are provided.""" + if self.y0 is None or self.y1 is None: + raise ValueError("YRangeAnnotation requires both y0 and y1 to be set") + return self + + +class XLineAnnotation(RangeAnnotation): + """A vertical line annotation at a specific x position. + + Automatically sets type="x" and display="line". + Requires x0 to be provided. + """ + + def __init__(self, **data): + """Initialize with type="x" and display="line" automatically set.""" + data.setdefault("type", "x") + data.setdefault("display", "line") + super().__init__(**data) + + @model_validator(mode="after") + def validate_x0_required(self) -> "XLineAnnotation": + """Validate that x0 is provided.""" + if self.x0 is None: + raise ValueError("XLineAnnotation requires x0 to be set") + return self + + +class YLineAnnotation(RangeAnnotation): + """A horizontal line annotation at a specific y position. + + Automatically sets type="y" and display="line". + Requires y0 to be provided. + """ + + def __init__(self, **data): + """Initialize with type="y" and display="line" automatically set.""" + data.setdefault("type", "y") + data.setdefault("display", "line") + super().__init__(**data) + + @model_validator(mode="after") + def validate_y0_required(self) -> "YLineAnnotation": + """Validate that y0 is provided.""" + if self.y0 is None: + raise ValueError("YLineAnnotation requires y0 to be set") + return self diff --git a/docs/user-guide/api/models.rst b/docs/user-guide/api/models.rst index 127a6bc..993f3b9 100644 --- a/docs/user-guide/api/models.rst +++ b/docs/user-guide/api/models.rst @@ -17,6 +17,26 @@ Annotations :members: :show-inheritance: +.. autoclass:: XRangeAnnotation + :members: + :show-inheritance: + +.. autoclass:: YRangeAnnotation + :members: + :show-inheritance: + +.. autoclass:: XLineAnnotation + :members: + :show-inheritance: + +.. autoclass:: YLineAnnotation + :members: + :show-inheritance: + +.. autoclass:: AreaFill + :members: + :show-inheritance: + .. autoclass:: ConnectorLine :members: :show-inheritance: diff --git a/tests/unit/test_range_annotation_subclasses.py b/tests/unit/test_range_annotation_subclasses.py new file mode 100644 index 0000000..6836fb8 --- /dev/null +++ b/tests/unit/test_range_annotation_subclasses.py @@ -0,0 +1,352 @@ +"""Test RangeAnnotation subclasses and validators.""" + +import pytest +from pydantic import ValidationError + +from datawrapper.charts.annos import ( + RangeAnnotation, + XLineAnnotation, + XRangeAnnotation, + YLineAnnotation, + YRangeAnnotation, +) +from datawrapper.charts.enums import StrokeType, StrokeWidth + + +class TestRangeAnnotationValidators: + """Test validators on the base RangeAnnotation class.""" + + def test_type_validator_accepts_valid_values(self): + """Test that type validator accepts 'x' and 'y'.""" + anno_x = RangeAnnotation(type="x", x0=0, x1=10) + assert anno_x.type == "x" + + anno_y = RangeAnnotation(type="y", y0=0, y1=10) + assert anno_y.type == "y" + + def test_type_validator_rejects_invalid_values(self): + """Test that type validator rejects invalid values.""" + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(type="z", x0=0, x1=10) + # Pydantic's Literal validation happens before custom validator + assert "Input should be 'x' or 'y'" in str(exc_info.value) + + def test_display_validator_accepts_valid_values(self): + """Test that display validator accepts 'line' and 'range'.""" + anno_line = RangeAnnotation(display="line", x0=0) + assert anno_line.display == "line" + + anno_range = RangeAnnotation(display="range", x0=0, x1=10) + assert anno_range.display == "range" + + def test_display_validator_rejects_invalid_values(self): + """Test that display validator rejects invalid values.""" + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(display="invalid", x0=0) + # Pydantic's Literal validation happens before custom validator + assert "Input should be 'line' or 'range'" in str(exc_info.value) + + def test_opacity_validator_accepts_valid_range(self): + """Test that opacity validator accepts values 0-100.""" + # Test boundary values + anno_0 = RangeAnnotation(opacity=0, x0=0) + assert anno_0.opacity == 0 + + anno_100 = RangeAnnotation(opacity=100, x0=0) + assert anno_100.opacity == 100 + + # Test middle value + anno_50 = RangeAnnotation(opacity=50, x0=0) + assert anno_50.opacity == 50 + + def test_opacity_validator_rejects_out_of_range(self): + """Test that opacity validator rejects values outside 0-100.""" + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(opacity=-1, x0=0) + assert "Invalid opacity: -1" in str(exc_info.value) + assert "Must be between 0 and 100" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(opacity=101, x0=0) + assert "Invalid opacity: 101" in str(exc_info.value) + assert "Must be between 0 and 100" in str(exc_info.value) + + def test_stroke_type_validator_accepts_enum_values(self): + """Test that stroke_type validator accepts StrokeType enum values.""" + anno = RangeAnnotation(stroke_type=StrokeType.DASHED, x0=0) + assert anno.stroke_type == "dashed" + + def test_stroke_type_validator_accepts_valid_strings(self): + """Test that stroke_type validator accepts valid string values.""" + anno = RangeAnnotation(stroke_type="dotted", x0=0) + assert anno.stroke_type == "dotted" + + def test_stroke_type_validator_rejects_invalid_strings(self): + """Test that stroke_type validator rejects invalid string values.""" + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(stroke_type="invalid", x0=0) + assert "Invalid stroke type: invalid" in str(exc_info.value) + + def test_stroke_width_validator_accepts_enum_values(self): + """Test that stroke_width validator accepts StrokeWidth enum values.""" + anno = RangeAnnotation(stroke_width=StrokeWidth.THICK, x0=0) + assert anno.stroke_width == 3 + + def test_stroke_width_validator_accepts_valid_ints(self): + """Test that stroke_width validator accepts valid int values.""" + anno = RangeAnnotation(stroke_width=2, x0=0) + assert anno.stroke_width == 2 + + def test_stroke_width_validator_rejects_invalid_ints(self): + """Test that stroke_width validator rejects invalid int values.""" + with pytest.raises(ValidationError) as exc_info: + RangeAnnotation(stroke_width=99, x0=0) + assert "Invalid stroke width: 99" in str(exc_info.value) + + +class TestXRangeAnnotation: + """Test XRangeAnnotation subclass.""" + + def test_automatic_field_setting(self): + """Test that type and display are automatically set.""" + anno = XRangeAnnotation(x0=0, x1=10) + assert anno.type == "x" + assert anno.display == "range" + + def test_requires_both_x_positions(self): + """Test that both x0 and x1 are required.""" + # Should work with both positions + anno = XRangeAnnotation(x0=0, x1=10) + assert anno.x0 == 0 + assert anno.x1 == 10 + + # Should fail without x0 + with pytest.raises(ValidationError) as exc_info: + XRangeAnnotation(x1=10) + assert "requires both x0 and x1" in str(exc_info.value) + + # Should fail without x1 + with pytest.raises(ValidationError) as exc_info: + XRangeAnnotation(x0=0) + assert "requires both x0 and x1" in str(exc_info.value) + + # Should fail without both + with pytest.raises(ValidationError) as exc_info: + XRangeAnnotation() + assert "requires both x0 and x1" in str(exc_info.value) + + def test_accepts_custom_color_and_opacity(self): + """Test that custom color and opacity can be set.""" + anno = XRangeAnnotation(x0=0, x1=10, color="#ff0000", opacity=75) + assert anno.color == "#ff0000" + assert anno.opacity == 75 + + def test_serialization(self): + """Test that serialization works correctly.""" + anno = XRangeAnnotation(x0=0, x1=10, color="#ff0000") + serialized = anno.serialize_model() + assert serialized["type"] == "x" + assert serialized["display"] == "range" + assert serialized["position"]["x0"] == 0 + assert serialized["position"]["x1"] == 10 + assert serialized["color"] == "#ff0000" + + def test_cannot_override_type_or_display(self): + """Test that type and display cannot be overridden to invalid values.""" + # Type is automatically set to "x", so trying to set it to "y" should still result in "x" + anno = XRangeAnnotation(x0=0, x1=10, type="x") + assert anno.type == "x" + + # Display is automatically set to "range" + anno = XRangeAnnotation(x0=0, x1=10, display="range") + assert anno.display == "range" + + +class TestYRangeAnnotation: + """Test YRangeAnnotation subclass.""" + + def test_automatic_field_setting(self): + """Test that type and display are automatically set.""" + anno = YRangeAnnotation(y0=0, y1=10) + assert anno.type == "y" + assert anno.display == "range" + + def test_requires_both_y_positions(self): + """Test that both y0 and y1 are required.""" + # Should work with both positions + anno = YRangeAnnotation(y0=0, y1=10) + assert anno.y0 == 0 + assert anno.y1 == 10 + + # Should fail without y0 + with pytest.raises(ValidationError) as exc_info: + YRangeAnnotation(y1=10) + assert "requires both y0 and y1" in str(exc_info.value) + + # Should fail without y1 + with pytest.raises(ValidationError) as exc_info: + YRangeAnnotation(y0=0) + assert "requires both y0 and y1" in str(exc_info.value) + + # Should fail without both + with pytest.raises(ValidationError) as exc_info: + YRangeAnnotation() + assert "requires both y0 and y1" in str(exc_info.value) + + def test_accepts_custom_color_and_opacity(self): + """Test that custom color and opacity can be set.""" + anno = YRangeAnnotation(y0=0, y1=10, color="#00ff00", opacity=25) + assert anno.color == "#00ff00" + assert anno.opacity == 25 + + def test_serialization(self): + """Test that serialization works correctly.""" + anno = YRangeAnnotation(y0=0, y1=10, color="#00ff00") + serialized = anno.serialize_model() + assert serialized["type"] == "y" + assert serialized["display"] == "range" + assert serialized["position"]["y0"] == 0 + assert serialized["position"]["y1"] == 10 + assert serialized["color"] == "#00ff00" + + +class TestXLineAnnotation: + """Test XLineAnnotation subclass.""" + + def test_automatic_field_setting(self): + """Test that type and display are automatically set.""" + anno = XLineAnnotation(x0=5) + assert anno.type == "x" + assert anno.display == "line" + + def test_requires_x0_position(self): + """Test that x0 is required.""" + # Should work with x0 + anno = XLineAnnotation(x0=5) + assert anno.x0 == 5 + + # Should fail without x0 + with pytest.raises(ValidationError) as exc_info: + XLineAnnotation() + assert "requires x0 to be set" in str(exc_info.value) + + def test_x1_is_optional(self): + """Test that x1 is optional for line annotations.""" + anno = XLineAnnotation(x0=5) + assert anno.x1 is None + + # Can also provide x1 if desired (though not typical for line) + anno_with_x1 = XLineAnnotation(x0=5, x1=10) + assert anno_with_x1.x0 == 5 + assert anno_with_x1.x1 == 10 + + def test_accepts_stroke_customization(self): + """Test that stroke type and width can be customized.""" + anno = XLineAnnotation( + x0=5, stroke_type=StrokeType.DASHED, stroke_width=StrokeWidth.THICK + ) + assert anno.stroke_type == "dashed" + assert anno.stroke_width == 3 + + def test_serialization(self): + """Test that serialization works correctly.""" + anno = XLineAnnotation(x0=5, color="#0000ff", stroke_type="dotted") + serialized = anno.serialize_model() + assert serialized["type"] == "x" + assert serialized["display"] == "line" + assert serialized["position"]["x0"] == 5 + assert "x1" not in serialized["position"] # x1 should not be included if None + assert serialized["color"] == "#0000ff" + assert serialized["strokeType"] == "dotted" + + +class TestYLineAnnotation: + """Test YLineAnnotation subclass.""" + + def test_automatic_field_setting(self): + """Test that type and display are automatically set.""" + anno = YLineAnnotation(y0=5) + assert anno.type == "y" + assert anno.display == "line" + + def test_requires_y0_position(self): + """Test that y0 is required.""" + # Should work with y0 + anno = YLineAnnotation(y0=5) + assert anno.y0 == 5 + + # Should fail without y0 + with pytest.raises(ValidationError) as exc_info: + YLineAnnotation() + assert "requires y0 to be set" in str(exc_info.value) + + def test_y1_is_optional(self): + """Test that y1 is optional for line annotations.""" + anno = YLineAnnotation(y0=5) + assert anno.y1 is None + + # Can also provide y1 if desired (though not typical for line) + anno_with_y1 = YLineAnnotation(y0=5, y1=10) + assert anno_with_y1.y0 == 5 + assert anno_with_y1.y1 == 10 + + def test_accepts_stroke_customization(self): + """Test that stroke type and width can be customized.""" + anno = YLineAnnotation( + y0=5, stroke_type=StrokeType.DOTTED, stroke_width=StrokeWidth.MEDIUM + ) + assert anno.stroke_type == "dotted" + assert anno.stroke_width == 2 + + def test_serialization(self): + """Test that serialization works correctly.""" + anno = YLineAnnotation(y0=5, color="#ff00ff", stroke_width=3) + serialized = anno.serialize_model() + assert serialized["type"] == "y" + assert serialized["display"] == "line" + assert serialized["position"]["y0"] == 5 + assert "y1" not in serialized["position"] # y1 should not be included if None + assert serialized["color"] == "#ff00ff" + assert serialized["strokeWidth"] == 3 + + +class TestSubclassInheritance: + """Test that subclasses properly inherit validators from parent.""" + + def test_xrange_inherits_opacity_validator(self): + """Test that XRangeAnnotation inherits opacity validator.""" + with pytest.raises(ValidationError) as exc_info: + XRangeAnnotation(x0=0, x1=10, opacity=150) + assert "Invalid opacity: 150" in str(exc_info.value) + + def test_yrange_inherits_type_validator(self): + """Test that YRangeAnnotation inherits type validator.""" + # Type is auto-set to "y", but validator should still work + anno = YRangeAnnotation(y0=0, y1=10) + assert anno.type == "y" + + def test_xline_inherits_stroke_validators(self): + """Test that XLineAnnotation inherits stroke validators.""" + with pytest.raises(ValidationError) as exc_info: + XLineAnnotation(x0=5, stroke_type="invalid") + assert "Invalid stroke type: invalid" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + XLineAnnotation(x0=5, stroke_width=99) + assert "Invalid stroke width: 99" in str(exc_info.value) + + def test_yline_inherits_all_validators(self): + """Test that YLineAnnotation inherits all validators from parent.""" + # Test opacity validator + with pytest.raises(ValidationError) as exc_info: + YLineAnnotation(y0=5, opacity=200) + assert "Invalid opacity: 200" in str(exc_info.value) + + # Test stroke validators + with pytest.raises(ValidationError) as exc_info: + YLineAnnotation(y0=5, stroke_type="bad") + assert "Invalid stroke type: bad" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + YLineAnnotation(y0=5, stroke_width=100) + assert "Invalid stroke width: 100" in str(exc_info.value)