Skip to content

Commit 88385ac

Browse files
committed
Fix tests
1 parent 6bab737 commit 88385ac

File tree

4 files changed

+33
-18
lines changed

4 files changed

+33
-18
lines changed

superannotate/db/annotation_classes.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,6 @@ def download_annotation_classes_json(project, folder):
296296

297297

298298
def fill_class_and_attribute_names(annotations_json, annotation_classes_dict):
299-
if "instances" not in annotations_json:
300-
return
301299
for r in annotations_json["instances"]:
302300
if "classId" in r and r["classId"] in annotation_classes_dict:
303301
r["className"] = annotation_classes_dict[r["classId"]]["name"]

superannotate/db/images.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,17 @@ def get_image_annotations(project, image_name, project_type=None):
753753
raise SABaseException(response.status_code, response.text)
754754
res_json = response.json()
755755
fill_class_and_attribute_names(res_json, annotation_classes_dict)
756-
url = res["pixelSave"]["url"]
757-
annotation_mask_filename = url.rsplit('/', 1)[-1]
758-
headers = res["pixelSave"]["headers"]
759-
response = requests.get(url=url, headers=headers)
760-
if not response.ok:
761-
raise SABaseException(response.status_code, response.text)
762-
mask = io.BytesIO(response.content)
756+
if len(res_json["instances"]) != 0:
757+
url = res["pixelSave"]["url"]
758+
annotation_mask_filename = url.rsplit('/', 1)[-1]
759+
headers = res["pixelSave"]["headers"]
760+
response = requests.get(url=url, headers=headers)
761+
if not response.ok:
762+
raise SABaseException(response.status_code, response.text)
763+
mask = io.BytesIO(response.content)
764+
else:
765+
mask = None
766+
annotation_mask_filename = None
763767
return {
764768
"annotation_json": res_json,
765769
"annotation_json_filename": annotation_json_filename,
@@ -800,11 +804,14 @@ def download_image_annotations(project, image_name, local_dir_path):
800804
else:
801805
with open(json_path, "w") as f:
802806
json.dump(annotation["annotation_json"], f, indent=4)
803-
mask_path = Path(local_dir_path
804-
) / annotation["annotation_mask_filename"]
807+
if annotation["annotation_mask_filename"] is not None:
808+
mask_path = Path(local_dir_path
809+
) / annotation["annotation_mask_filename"]
810+
with open(mask_path, "wb") as f:
811+
f.write(annotation["annotation_mask"].getbuffer())
812+
else:
813+
mask_path = None
805814
return_filepaths.append(str(mask_path))
806-
with open(mask_path, "wb") as f:
807-
f.write(annotation["annotation_mask"].getbuffer())
808815

809816
return tuple(return_filepaths)
810817

superannotate/db/projects.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,15 @@ def upload_images_from_folder_to_project(
539539

540540

541541
def create_empty_annotation(size):
542-
return {"metadata": {'height': size[1], 'width': size[0]}}
542+
return {
543+
"metadata": {
544+
'height': size[1],
545+
'width': size[0]
546+
},
547+
"instances": [],
548+
"comments": [],
549+
"tags": []
550+
}
543551

544552

545553
def upload_image_array_to_s3(

tests/test_basic_images.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@ def test_basic_images(project_type, name, description, from_folder, tmpdir):
4242
sa.download_image(project, image_name, tmpdir, True)
4343
assert sa.get_image_preannotations(project, image_name
4444
)["preannotation_json_filename"] is None
45-
assert sa.get_image_annotations(project, image_name
46-
)["annotation_json_filename"] is None
45+
assert len(
46+
sa.get_image_annotations(project,
47+
image_name)["annotation_json"]["instances"]
48+
) == 0
4749
sa.download_image_annotations(project, image_name, tmpdir)
48-
assert len(list(Path(tmpdir).glob("*"))) == 1
50+
assert len(list(Path(tmpdir).glob("*"))) == 2
4951
sa.download_image_preannotations(project, image_name, tmpdir)
50-
assert len(list(Path(tmpdir).glob("*"))) == 1
52+
assert len(list(Path(tmpdir).glob("*"))) == 2
5153

5254
assert (Path(tmpdir) / image_name).is_file()
5355

0 commit comments

Comments
 (0)