diff --git a/datawrapper/charts/annos.py b/datawrapper/charts/annos.py index 340e9c1..b01dd97 100644 --- a/datawrapper/charts/annos.py +++ b/datawrapper/charts/annos.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from .enums import ArrowHead, ConnectorLineType, LineInterpolation, StrokeWidth @@ -321,8 +321,10 @@ class AreaFill(BaseModel): #: The line to fill upwards to to_column: str = Field(alias="to", description="The line to fill upwards to") - #: The color of the fill - color: str = Field(default="#15607a", description="The color of the fill") + #: The color of the fill (hex string or palette index) + color: str | int = Field( + default=0, description="The color of the fill (hex string or palette index)" + ) #: The opacity of the fill opacity: float = Field(default=0.3, description="The opacity of the fill") @@ -334,11 +336,11 @@ class AreaFill(BaseModel): description="Whether to use different colors when there are negative values", ) - #: The color of the fill when it is negative - color_negative: str = Field( - default="#cc0000", + #: The color of the fill when it is negative (hex string or palette index, None = disabled) + color_negative: str | int | None = Field( + default=None, alias="colorNegative", - description="The color of the fill when it is negative", + description="The color of the fill when it is negative (hex string or palette index, None = disabled)", ) #: The interpolation method to use when drawing lines @@ -360,21 +362,39 @@ def validate_interpolation( ) return v + @model_validator(mode="after") + def auto_enable_mixed_colors(self) -> "AreaFill": + """Auto-enable use_mixed_colors when color_negative is provided. + + If a user provides a color_negative value (not None), + automatically enable use_mixed_colors to make the feature work as expected. + """ + # Only auto-enable if color_negative is provided and use_mixed_colors is False + if self.color_negative is not None and not self.use_mixed_colors: + self.use_mixed_colors = True + return self + def serialize_model(self) -> dict: """Serialize the model to a dictionary for the Datawrapper API. Note: The 'id' field is not included in the output as it's used as the dict key. + Only includes colorNegative if it's not None. """ - return { + result = { "from": self.from_column, "to": self.to_column, "color": self.color, "opacity": self.opacity, "useMixedColors": self.use_mixed_colors, - "colorNegative": self.color_negative, "interpolation": self.interpolation, } + # Only include colorNegative if it's provided (not None) + if self.color_negative is not None: + result["colorNegative"] = self.color_negative + + return result + @classmethod def deserialize_model(cls, api_data: dict[str, dict] | list | None) -> list[dict]: """Deserialize area fills from API response format.