diff --git a/datawrapper/__init__.py b/datawrapper/__init__.py index 45cbfc5..d459953 100644 --- a/datawrapper/__init__.py +++ b/datawrapper/__init__.py @@ -64,6 +64,7 @@ SymbolDisplay, SymbolShape, SymbolStyle, + TextAlign, ValueLabelAlignment, ValueLabelDisplay, ValueLabelMode, @@ -136,6 +137,7 @@ "SymbolDisplay", "SymbolShape", "SymbolStyle", + "TextAlign", "ValueLabelAlignment", "ValueLabelDisplay", "ValueLabelMode", diff --git a/datawrapper/charts/__init__.py b/datawrapper/charts/__init__.py index d9775ef..19d26d2 100644 --- a/datawrapper/charts/__init__.py +++ b/datawrapper/charts/__init__.py @@ -36,6 +36,7 @@ SymbolDisplay, SymbolShape, SymbolStyle, + TextAlign, ValueLabelAlignment, ValueLabelDisplay, ValueLabelMode, @@ -102,6 +103,7 @@ "SymbolDisplay", "SymbolShape", "SymbolStyle", + "TextAlign", "Transform", "Describe", "Logo", diff --git a/datawrapper/charts/annos.py b/datawrapper/charts/annos.py index 026a9e2..0e61569 100644 --- a/datawrapper/charts/annos.py +++ b/datawrapper/charts/annos.py @@ -8,6 +8,7 @@ LineInterpolation, StrokeType, StrokeWidth, + TextAlign, ) @@ -88,12 +89,33 @@ def validate_stroke(cls, v: StrokeWidth | int) -> StrokeWidth | int: ) #: The style of the circle at the end of the connector line - circle_style: Literal["solid", "dashed"] = Field( + circle_style: StrokeType | str = Field( default="solid", alias="circleStyle", description="The style of the circle at the end of the connector line", ) + @field_validator("circle_style") + @classmethod + def validate_circle_style(cls, v: StrokeType | str) -> StrokeType | str: + """Validate that circle_style is either solid or dashed (not dotted). + + Handles both string and enum inputs. DOTTED is not allowed. + """ + # Handle enum inputs + if isinstance(v, StrokeType): + if v not in [StrokeType.SOLID, StrokeType.DASHED]: + raise ValueError( + f"Invalid circle style: {v.value}. Must be either 'solid' or 'dashed'" + ) + # Handle string inputs + elif isinstance(v, str): + if v not in ["solid", "dashed"]: + raise ValueError( + f"Invalid circle style: {v}. Must be either 'solid' or 'dashed'" + ) + return v + #: The radius of the circle at the end of the connector line circle_radius: int = Field( default=15, @@ -164,10 +186,22 @@ class TextAnnotation(BaseModel): text: str = Field(min_length=1, description="The text to display") #: The alignment of the text - align: Literal["tl", "tc", "tr", "ml", "mc", "mr", "bl", "bc", "br"] = Field( + align: TextAlign | str = Field( default="tl", description="The alignment of the text" ) + @field_validator("align") + @classmethod + def validate_align(cls, v: TextAlign | str) -> TextAlign | str: + """Validate that align is a valid TextAlign value.""" + if isinstance(v, str): + valid_values = [e.value for e in TextAlign] + if v not in valid_values: + raise ValueError( + f"Invalid text alignment: {v}. Must be one of {valid_values}" + ) + return v + #: The color of the text color: str | bool = Field( default=False, # If you don't set a color, it will default to the Datawrapper standard @@ -180,6 +214,16 @@ class TextAnnotation(BaseModel): description="The width of the text as a percentage of the chart width", ) + @field_validator("width") + @classmethod + def validate_width(cls, v: float) -> float: + """Validate that width is between 0.0 and 100.0.""" + if not 0.0 <= v <= 100.0: + raise ValueError( + f"Invalid width: {v}. Must be between 0.0 and 100.0 (inclusive)" + ) + return v + #: Whether or not to italicize the text italic: bool = Field( default=False, description="Whether or not to italicize the text" @@ -329,6 +373,16 @@ class AreaFill(BaseModel): #: The opacity of the fill opacity: float = Field(default=0.3, description="The opacity of the fill") + @field_validator("opacity") + @classmethod + def validate_opacity(cls, v: float) -> float: + """Validate that opacity is between 0.0 and 1.0.""" + if not 0.0 <= v <= 1.0: + raise ValueError( + f"Invalid opacity: {v}. Must be between 0.0 and 1.0 (inclusive)" + ) + return v + #: Whether to use different colors when there are negative values use_mixed_colors: bool = Field( default=False, diff --git a/datawrapper/charts/enums/__init__.py b/datawrapper/charts/enums/__init__.py index c4a69be..7b611b9 100644 --- a/datawrapper/charts/enums/__init__.py +++ b/datawrapper/charts/enums/__init__.py @@ -19,6 +19,7 @@ ScatterSize, ) from .symbol_shape import SymbolDisplay, SymbolShape, SymbolStyle +from .text_align import TextAlign from .value_label import ( ValueLabelAlignment, ValueLabelDisplay, @@ -50,6 +51,7 @@ "SymbolDisplay", "SymbolShape", "SymbolStyle", + "TextAlign", "ValueLabelAlignment", "ValueLabelDisplay", "ValueLabelMode", diff --git a/datawrapper/charts/enums/text_align.py b/datawrapper/charts/enums/text_align.py new file mode 100644 index 0000000..2da6ac8 --- /dev/null +++ b/datawrapper/charts/enums/text_align.py @@ -0,0 +1,50 @@ +"""Text alignment enums for annotations.""" + +from enum import Enum + + +class TextAlign(str, Enum): + """Text alignment positions for annotations. + + Represents a 3x3 grid of alignment positions combining vertical and horizontal alignment. + + Examples: + >>> from datawrapper.charts import TextAnnotation, TextAlign + >>> anno = TextAnnotation( + ... text="Top left corner", x=10, y=20, align=TextAlign.TOP_LEFT + ... ) + >>> anno.align + + + >>> # Using raw string (backwards compatible) + >>> anno = TextAnnotation(text="Center", x=50, y=50, align="mc") + >>> anno.align + 'mc' + """ + + #: Top-left alignment + TOP_LEFT = "tl" + + #: Top-center alignment + TOP_CENTER = "tc" + + #: Top-right alignment + TOP_RIGHT = "tr" + + #: Middle-left alignment + MIDDLE_LEFT = "ml" + + #: Middle-center alignment + MIDDLE_CENTER = "mc" + + #: Middle-right alignment + MIDDLE_RIGHT = "mr" + + #: Bottom-left alignment + BOTTOM_LEFT = "bl" + + #: Bottom-center alignment + BOTTOM_CENTER = "bc" + + #: Bottom-right alignment + BOTTOM_RIGHT = "br" diff --git a/docs/user-guide/api/enums.rst b/docs/user-guide/api/enums.rst index 34e6dda..cef9c41 100644 --- a/docs/user-guide/api/enums.rst +++ b/docs/user-guide/api/enums.rst @@ -227,6 +227,16 @@ SymbolStyle .. enum-table:: datawrapper.charts.enums.SymbolStyle +TextAlign +--------- + +.. code-block:: python + + import datawrapper as dw + chart = dw.TextAnnotation(align=dw.TextAlign.TOP_LEFT) + +.. enum-table:: datawrapper.charts.enums.TextAlign + ValueLabelAlignment ------------------- diff --git a/tests/unit/test_area_fill_opacity_validator.py b/tests/unit/test_area_fill_opacity_validator.py new file mode 100644 index 0000000..bb0b6cd --- /dev/null +++ b/tests/unit/test_area_fill_opacity_validator.py @@ -0,0 +1,94 @@ +"""Unit tests for AreaFill opacity field validator.""" + +import pytest +from pydantic import ValidationError + +from datawrapper.charts.annos import AreaFill + + +def test_opacity_valid_default(): + """Test that the default opacity value (0.3) is valid.""" + fill = AreaFill(**{"from": "baseline", "to": "new"}) + assert fill.opacity == 0.3 + + +def test_opacity_valid_zero(): + """Test that opacity=0.0 is valid (minimum boundary).""" + fill = AreaFill(**{"from": "baseline", "to": "new", "opacity": 0.0}) + assert fill.opacity == 0.0 + + +def test_opacity_valid_one(): + """Test that opacity=1.0 is valid (maximum boundary).""" + fill = AreaFill(**{"from": "baseline", "to": "new", "opacity": 1.0}) + assert fill.opacity == 1.0 + + +def test_opacity_valid_middle(): + """Test that a middle value like 0.5 is valid.""" + fill = AreaFill(**{"from": "baseline", "to": "new", "opacity": 0.5}) + assert fill.opacity == 0.5 + + +def test_opacity_valid_decimal(): + """Test that decimal values are valid.""" + fill = AreaFill(**{"from": "baseline", "to": "new", "opacity": 0.75}) + assert fill.opacity == 0.75 + + +def test_opacity_invalid_negative(): + """Test that negative opacity values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": -0.1}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: -0.1" in str(error.get("ctx", {}).get("error", "")) + assert "Must be between 0.0 and 1.0" in str(error.get("ctx", {}).get("error", "")) + + +def test_opacity_invalid_over_one(): + """Test that opacity values over 1.0 raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": 1.1}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: 1.1" in str(error.get("ctx", {}).get("error", "")) + assert "Must be between 0.0 and 1.0" in str(error.get("ctx", {}).get("error", "")) + + +def test_opacity_invalid_large_value(): + """Test that very large opacity values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": 2.0}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: 2.0" in str(error.get("ctx", {}).get("error", "")) + assert "Must be between 0.0 and 1.0" in str(error.get("ctx", {}).get("error", "")) + + +def test_opacity_invalid_very_negative(): + """Test that very negative opacity values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": -1.0}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: -1.0" in str(error.get("ctx", {}).get("error", "")) + assert "Must be between 0.0 and 1.0" in str(error.get("ctx", {}).get("error", "")) + + +def test_opacity_boundary_just_below_zero(): + """Test that opacity just below 0.0 raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": -0.01}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: -0.01" in str(error.get("ctx", {}).get("error", "")) + + +def test_opacity_boundary_just_above_one(): + """Test that opacity just above 1.0 raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + AreaFill(**{"from": "baseline", "to": "new", "opacity": 1.01}) + + error = exc_info.value.errors()[0] + assert "Invalid opacity: 1.01" in str(error.get("ctx", {}).get("error", "")) diff --git a/tests/unit/test_connector_line_circle_style_validator.py b/tests/unit/test_connector_line_circle_style_validator.py new file mode 100644 index 0000000..eee3e50 --- /dev/null +++ b/tests/unit/test_connector_line_circle_style_validator.py @@ -0,0 +1,120 @@ +"""Unit tests for ConnectorLine circle_style field validator.""" + +import pytest +from pydantic import ValidationError + +from datawrapper.charts.annos import ConnectorLine +from datawrapper.charts.enums.annos import StrokeType + + +def test_circle_style_valid_default(): + """Test that the default circle_style value ('solid') is valid.""" + connector = ConnectorLine() + assert connector.circle_style == "solid" + + +def test_circle_style_valid_enum_solid(): + """Test that StrokeType.SOLID enum value is valid.""" + connector = ConnectorLine(circleStyle=StrokeType.SOLID) + assert connector.circle_style == StrokeType.SOLID + + +def test_circle_style_valid_enum_dashed(): + """Test that StrokeType.DASHED enum value is valid.""" + connector = ConnectorLine(circleStyle=StrokeType.DASHED) + assert connector.circle_style == StrokeType.DASHED + + +def test_circle_style_valid_string_solid(): + """Test that 'solid' string value is valid.""" + connector = ConnectorLine(circleStyle="solid") + assert connector.circle_style == "solid" + + +def test_circle_style_valid_string_dashed(): + """Test that 'dashed' string value is valid.""" + connector = ConnectorLine(circleStyle="dashed") + assert connector.circle_style == "dashed" + + +def test_circle_style_invalid_enum_dotted(): + """Test that StrokeType.DOTTED enum value raises ValidationError. + + Note: Pydantic converts enum values to strings before validation, + so the error message will show 'dotted' not 'StrokeType.DOTTED'. + """ + with pytest.raises(ValidationError) as exc_info: + ConnectorLine(circleStyle=StrokeType.DOTTED) + + error = exc_info.value.errors()[0] + # Pydantic converts the enum to its string value before validation + assert "Invalid circle style: dotted" in str(error.get("ctx", {}).get("error", "")) + assert "Must be either 'solid' or 'dashed'" in str( + error.get("ctx", {}).get("error", "") + ) + + +def test_circle_style_invalid_string_dotted(): + """Test that 'dotted' string value raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + ConnectorLine(circleStyle="dotted") + + error = exc_info.value.errors()[0] + assert "Invalid circle style: dotted" in str(error.get("ctx", {}).get("error", "")) + assert "Must be either 'solid' or 'dashed'" in str( + error.get("ctx", {}).get("error", "") + ) + + +def test_circle_style_invalid_string_random(): + """Test that invalid string values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + ConnectorLine(circleStyle="invalid") + + error = exc_info.value.errors()[0] + assert "Invalid circle style: invalid" in str(error.get("ctx", {}).get("error", "")) + assert "Must be either 'solid' or 'dashed'" in str( + error.get("ctx", {}).get("error", "") + ) + + +def test_circle_style_invalid_string_empty(): + """Test that empty string raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + ConnectorLine(circleStyle="") + + error = exc_info.value.errors()[0] + assert "Invalid circle style:" in str(error.get("ctx", {}).get("error", "")) + assert "Must be either 'solid' or 'dashed'" in str( + error.get("ctx", {}).get("error", "") + ) + + +def test_circle_style_case_sensitive(): + """Test that circle_style validation is case-sensitive.""" + with pytest.raises(ValidationError) as exc_info: + ConnectorLine(circleStyle="SOLID") + + error = exc_info.value.errors()[0] + assert "Invalid circle style: SOLID" in str(error.get("ctx", {}).get("error", "")) + + +def test_circle_style_with_other_fields(): + """Test that circle_style validation works when other fields are set.""" + connector = ConnectorLine(type="straight", stroke=2, circleStyle=StrokeType.DASHED) + assert connector.circle_style == StrokeType.DASHED + assert connector.type == "straight" + assert connector.stroke == 2 + + +def test_circle_style_serialization_enum(): + """Test that enum values serialize correctly.""" + connector = ConnectorLine(circleStyle=StrokeType.SOLID) + # Pydantic should serialize the enum to its string value + assert connector.model_dump()["circle_style"] == "solid" + + +def test_circle_style_serialization_string(): + """Test that string values serialize correctly.""" + connector = ConnectorLine(circleStyle="dashed") + assert connector.model_dump()["circle_style"] == "dashed" diff --git a/tests/unit/test_text_align_enum.py b/tests/unit/test_text_align_enum.py new file mode 100644 index 0000000..d270ef0 --- /dev/null +++ b/tests/unit/test_text_align_enum.py @@ -0,0 +1,80 @@ +"""Unit tests for TextAlign enum.""" + +import pytest +from pydantic import ValidationError + +from datawrapper import TextAlign +from datawrapper.charts import TextAnnotation + + +def test_text_align_enum_values(): + """Test that TextAlign enum has all expected values.""" + assert TextAlign.TOP_LEFT == "tl" + assert TextAlign.TOP_CENTER == "tc" + assert TextAlign.TOP_RIGHT == "tr" + assert TextAlign.MIDDLE_LEFT == "ml" + assert TextAlign.MIDDLE_CENTER == "mc" + assert TextAlign.MIDDLE_RIGHT == "mr" + assert TextAlign.BOTTOM_LEFT == "bl" + assert TextAlign.BOTTOM_CENTER == "bc" + assert TextAlign.BOTTOM_RIGHT == "br" + + +def test_text_align_enum_count(): + """Test that TextAlign enum has exactly 9 values.""" + assert len(TextAlign) == 9 + + +def test_text_annotation_with_enum(): + """Test TextAnnotation accepts TextAlign enum values.""" + anno = TextAnnotation(text="Test annotation", x=10, y=20, align=TextAlign.TOP_LEFT) + assert anno.align == TextAlign.TOP_LEFT + + +def test_text_annotation_with_string(): + """Test TextAnnotation accepts valid string values for backwards compatibility.""" + anno = TextAnnotation(text="Test annotation", x=10, y=20, align="tc") + assert anno.align == "tc" + + +def test_text_annotation_invalid_string(): + """Test TextAnnotation rejects invalid string values.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test annotation", x=10, y=20, align="invalid") + assert "Invalid text alignment" in str(exc_info.value) + + +def test_text_annotation_all_enum_values(): + """Test TextAnnotation accepts all TextAlign enum values.""" + for align_value in TextAlign: + anno = TextAnnotation(text="Test", x=0, y=0, align=align_value) + assert anno.align == align_value + + +def test_text_annotation_all_string_values(): + """Test TextAnnotation accepts all valid string values.""" + valid_strings = ["tl", "tc", "tr", "ml", "mc", "mr", "bl", "bc", "br"] + for align_str in valid_strings: + anno = TextAnnotation(text="Test", x=0, y=0, align=align_str) + assert anno.align == align_str + + +def test_text_align_import_from_top_level(): + """Test that TextAlign can be imported from top-level datawrapper package.""" + from datawrapper import TextAlign as TopLevelTextAlign + + assert TopLevelTextAlign.TOP_LEFT == "tl" + + +def test_text_align_import_from_charts(): + """Test that TextAlign can be imported from datawrapper.charts.""" + from datawrapper.charts import TextAlign as ChartsTextAlign + + assert ChartsTextAlign.TOP_LEFT == "tl" + + +def test_text_align_import_from_enums(): + """Test that TextAlign can be imported from datawrapper.charts.enums.""" + from datawrapper.charts.enums import TextAlign as EnumsTextAlign + + assert EnumsTextAlign.TOP_LEFT == "tl" diff --git a/tests/unit/test_text_annotation_width_validator.py b/tests/unit/test_text_annotation_width_validator.py new file mode 100644 index 0000000..cd5a9c7 --- /dev/null +++ b/tests/unit/test_text_annotation_width_validator.py @@ -0,0 +1,94 @@ +"""Unit tests for TextAnnotation width field validator.""" + +import pytest +from pydantic import ValidationError + +from datawrapper.charts.annos import TextAnnotation + + +def test_width_valid_default(): + """Test that the default width value (33.3) is valid.""" + anno = TextAnnotation(text="Test", x=0, y=0) + assert anno.width == 33.3 + + +def test_width_valid_zero(): + """Test that width=0.0 is valid (minimum boundary).""" + anno = TextAnnotation(text="Test", x=0, y=0, width=0.0) + assert anno.width == 0.0 + + +def test_width_valid_hundred(): + """Test that width=100.0 is valid (maximum boundary).""" + anno = TextAnnotation(text="Test", x=0, y=0, width=100.0) + assert anno.width == 100.0 + + +def test_width_valid_middle(): + """Test that a middle value like 50.0 is valid.""" + anno = TextAnnotation(text="Test", x=0, y=0, width=50.0) + assert anno.width == 50.0 + + +def test_width_valid_decimal(): + """Test that decimal values are valid.""" + anno = TextAnnotation(text="Test", x=0, y=0, width=25.5) + assert anno.width == 25.5 + + +def test_width_invalid_negative(): + """Test that negative width values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=-1.0) + + error = exc_info.value.errors()[0] + assert "Invalid width: -1.0" in str(error["ctx"]["error"]) + assert "Must be between 0.0 and 100.0" in str(error["ctx"]["error"]) + + +def test_width_invalid_over_hundred(): + """Test that width values over 100.0 raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=100.1) + + error = exc_info.value.errors()[0] + assert "Invalid width: 100.1" in str(error["ctx"]["error"]) + assert "Must be between 0.0 and 100.0" in str(error["ctx"]["error"]) + + +def test_width_invalid_large_value(): + """Test that very large width values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=500.0) + + error = exc_info.value.errors()[0] + assert "Invalid width: 500.0" in str(error["ctx"]["error"]) + assert "Must be between 0.0 and 100.0" in str(error["ctx"]["error"]) + + +def test_width_invalid_very_negative(): + """Test that very negative width values raise ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=-50.0) + + error = exc_info.value.errors()[0] + assert "Invalid width: -50.0" in str(error["ctx"]["error"]) + assert "Must be between 0.0 and 100.0" in str(error["ctx"]["error"]) + + +def test_width_boundary_just_below_zero(): + """Test that width just below 0.0 raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=-0.1) + + error = exc_info.value.errors()[0] + assert "Invalid width: -0.1" in str(error["ctx"]["error"]) + + +def test_width_boundary_just_above_hundred(): + """Test that width just above 100.0 raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + TextAnnotation(text="Test", x=0, y=0, width=100.01) + + error = exc_info.value.errors()[0] + assert "Invalid width: 100.01" in str(error["ctx"]["error"])