From 8de506466f8de9f004e5ee4ad396056f7644072c Mon Sep 17 00:00:00 2001 From: FelixAbrahamsson Date: Mon, 19 May 2025 15:30:21 +0200 Subject: [PATCH] feature: support tag annotations --- next_cvat/__init__.py | 13 +- next_cvat/annotations.py | 30 ++++- next_cvat/types/__init__.py | 2 + next_cvat/types/image_annotation.py | 6 + next_cvat/types/tag.py | 13 ++ tests/test_tag.py | 193 ++++++++++++++++++++++++++++ 6 files changed, 255 insertions(+), 2 deletions(-) create mode 100644 next_cvat/types/tag.py create mode 100644 tests/test_tag.py diff --git a/next_cvat/__init__.py b/next_cvat/__init__.py index bfb6388..6682968 100644 --- a/next_cvat/__init__.py +++ b/next_cvat/__init__.py @@ -4,4 +4,15 @@ from .annotations import Annotations from .client import Client -from .types import Attribute, Box, Label, LabelAttribute, Mask, Polygon, Project, Task +from .types import ( + Attribute, + Box, + ImageAnnotation, + Label, + LabelAttribute, + Mask, + Polygon, + Project, + Tag, + Task, +) diff --git a/next_cvat/annotations.py b/next_cvat/annotations.py index 73ab7af..3e463b1 100644 --- a/next_cvat/annotations.py +++ b/next_cvat/annotations.py @@ -18,6 +18,7 @@ Polygon, Polyline, Project, + Tag, Task, ) @@ -182,6 +183,21 @@ def from_path( Ellipse(**ellipse.attrib, attributes=ellipse_attributes) ) + # Parse tags + tags = [] + for tag in image.findall("tag"): + tag_attributes = [ + Attribute(name=attr.get("name"), value=attr.text) + for attr in tag.findall("attribute") + ] + tags.append( + Tag( + label=tag.get("label"), + source=tag.get("source", "manual"), + attributes=tag_attributes, + ) + ) + # Get job_id from task_job_mapping if available task_id = image.get("task_id") job_id = task_job_mapping.get(task_id) if task_id else None @@ -200,6 +216,7 @@ def from_path( masks=masks, polylines=polylines, ellipses=ellipses, + tags=tags, ) ) @@ -365,6 +382,18 @@ def save_xml_(self, path: Union[str, Path]) -> Annotations: attr_elem.set("name", attr.name) attr_elem.text = attr.value + # Add tags + for tag in image.tags: + tag_elem = ElementTree.SubElement(image_elem, "tag") + tag_elem.set("label", tag.label) + tag_elem.set("source", tag.source) + + if tag.attributes: + for attr in tag.attributes: + attr_elem = ElementTree.SubElement(tag_elem, "attribute") + attr_elem.set("name", attr.name) + attr_elem.text = attr.value + root.append(image_elem) # Create XML tree and save to file @@ -465,7 +494,6 @@ def get_images_from_completed_tasks(self) -> List[ImageAnnotation]: completed_task_ids = self.get_completed_task_ids() return [image for image in self.images if image.task_id in completed_task_ids] - def create_cvat_link(self, image_name: str) -> str: """Create a CVAT link for the given image name. diff --git a/next_cvat/types/__init__.py b/next_cvat/types/__init__.py index 3449b7b..f585d4b 100644 --- a/next_cvat/types/__init__.py +++ b/next_cvat/types/__init__.py @@ -9,6 +9,7 @@ from .polygon import Polygon from .polyline import Polyline from .project import Project +from .tag import Tag from .task import Task __all__ = [ @@ -23,5 +24,6 @@ "Polygon", "Polyline", "Project", + "Tag", "Task", ] diff --git a/next_cvat/types/image_annotation.py b/next_cvat/types/image_annotation.py index 765a90f..7d16afe 100644 --- a/next_cvat/types/image_annotation.py +++ b/next_cvat/types/image_annotation.py @@ -9,6 +9,7 @@ from .mask import Mask from .polygon import Polygon from .polyline import Polyline +from .tag import Tag class ImageAnnotation(BaseModel): @@ -29,6 +30,7 @@ class ImageAnnotation(BaseModel): masks: List of mask annotations polylines: List of polyline annotations ellipses: List of ellipse annotations + tags: List of tag annotations Example: ```python @@ -48,6 +50,9 @@ class ImageAnnotation(BaseModel): ], ellipses=[ Ellipse(label="defect", cx=500, cy=600, rx=50, ry=30) + ], + tags=[ + Tag(label="interesting", source="manual", attributes=[]) ] ) ``` @@ -65,3 +70,4 @@ class ImageAnnotation(BaseModel): masks: List[Mask] = [] polylines: List[Polyline] = [] ellipses: List[Ellipse] = [] + tags: List[Tag] = [] diff --git a/next_cvat/types/tag.py b/next_cvat/types/tag.py new file mode 100644 index 0000000..e696b88 --- /dev/null +++ b/next_cvat/types/tag.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import List + +from pydantic import BaseModel + +from .attribute import Attribute + + +class Tag(BaseModel): + label: str + source: str + attributes: List[Attribute] diff --git a/tests/test_tag.py b/tests/test_tag.py new file mode 100644 index 0000000..3c3e7fc --- /dev/null +++ b/tests/test_tag.py @@ -0,0 +1,193 @@ +import tempfile +from pathlib import Path +from xml.etree import ElementTree + +import pytest + +import next_cvat +from next_cvat import Annotations, Attribute, ImageAnnotation, Project, Tag, Task + + +def test_tag_creation(): + """Test that a tag can be created.""" + tag = Tag(label="no-crack", source="manual", attributes=[]) + assert tag.label == "no-crack" + assert tag.source == "manual" + assert tag.attributes == [] + + +def test_tag_with_attributes(): + """Test that a tag can be created with attributes.""" + tag = Tag( + label="no-crack", + source="manual", + attributes=[Attribute(name="confidence", value="0.95")], + ) + assert tag.label == "no-crack" + assert tag.source == "manual" + assert len(tag.attributes) == 1 + assert tag.attributes[0].name == "confidence" + assert tag.attributes[0].value == "0.95" + + +def test_load_and_save_tags(): + """Test that tags can be loaded from and saved to XML.""" + # Create test XML with a tag + xml_content = """ + + 1.1 + + + 123 + Test Project + 2021-01-01T00:00:00Z + 2021-01-01T00:00:00Z + + + + + + + + 0.95 + + + + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + xml_path = Path(tmp_dir) / "annotations.xml" + with open(xml_path, "w") as f: + f.write(xml_content) + + # Load annotations + annotations = Annotations.from_path(xml_path) + + # Check that tag was loaded + assert len(annotations.images) == 1 + image = annotations.images[0] + assert len(image.tags) == 1 + tag = image.tags[0] + assert tag.label == "no-crack" + assert tag.source == "manual" + assert len(tag.attributes) == 1 + assert tag.attributes[0].name == "confidence" + assert tag.attributes[0].value == "0.95" + + # Save annotations + output_path = Path(tmp_dir) / "output.xml" + annotations.save_xml_(output_path) + + # Parse saved XML and check tag + tree = ElementTree.parse(output_path) + root = tree.getroot() + + image_elem = root.find("image") + assert image_elem is not None + + tag_elem = image_elem.find("tag") + assert tag_elem is not None + assert tag_elem.get("label") == "no-crack" + assert tag_elem.get("source") == "manual" + + attr_elem = tag_elem.find("attribute") + assert attr_elem is not None + assert attr_elem.get("name") == "confidence" + assert attr_elem.text == "0.95" + + +def test_add_tag_to_annotations(): + """Test adding a tag to existing annotations.""" + # Create a simple annotations object + project = Project( + id="123", + name="Test Project", + created="2021-01-01T00:00:00Z", + updated="2021-01-01T00:00:00Z", + labels=[], + ) + + image = ImageAnnotation( + id="1", + name="image1.jpg", + width=800, + height=600, + task_id="1", + ) + + annotations = Annotations( + version="1.1", + project=project, + tasks=[Task(task_id="1", name="Task 1")], + images=[image], + ) + + # Add a tag + tag = Tag(label="no-crack", source="manual", attributes=[]) + annotations.images[0].tags.append(tag) + + # Check that tag was added + assert len(annotations.images[0].tags) == 1 + assert annotations.images[0].tags[0].label == "no-crack" + + # Save and load annotations + with tempfile.TemporaryDirectory() as tmp_dir: + xml_path = Path(tmp_dir) / "annotations.xml" + annotations.save_xml_(xml_path) + + loaded_annotations = Annotations.from_path(xml_path) + + # Check that tag is still there + assert len(loaded_annotations.images[0].tags) == 1 + assert loaded_annotations.images[0].tags[0].label == "no-crack" + + +def test_with_real_example(): + """Test using the example from the annotations.xml file.""" + # Create test XML with the provided example + xml_content = """ + + 1.1 + + + 123 + Test Project + 2021-01-01T00:00:00Z + 2021-01-01T00:00:00Z + + + + + + + + + + + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + xml_path = Path(tmp_dir) / "annotations.xml" + with open(xml_path, "w") as f: + f.write(xml_content) + + # Load annotations + annotations = Annotations.from_path(xml_path) + + # Check that tag was loaded + assert len(annotations.images) == 1 + image = annotations.images[0] + assert image.name == "some_image.png" + assert len(image.tags) == 1 + tag = image.tags[0] + assert tag.label == "no-crack" + assert tag.source == "manual" + assert len(tag.attributes) == 0