Skip to content

Commit eba2e99

Browse files
Vaghinak BasentsyanVaghinak Basentsyan
authored andcommitted
Merge remote-tracking branch 'origin/develop' into sdk_limitations
# Conflicts: # src/superannotate/lib/app/interface/sdk_interface.py # src/superannotate/lib/core/service_types.py
2 parents 1f59527 + 7bcd4be commit eba2e99

File tree

9 files changed

+186
-27
lines changed

9 files changed

+186
-27
lines changed

src/superannotate/lib/app/interface/sdk_interface.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from lib.core.exceptions import AppValidationException
4444
from lib.core.types import AttributeGroup
4545
from lib.core.types import ClassesJson
46+
from lib.core.types import MLModel
4647
from lib.core.types import Project
4748
from lib.infrastructure.controller import Controller
4849
from plotly.subplots import make_subplots
@@ -2269,7 +2270,6 @@ def download_image(
22692270
)
22702271
if response.errors:
22712272
raise AppException(response.errors)
2272-
logger.info(f"Downloaded image {image_name} to {local_dir_path} ")
22732273
return response.data
22742274

22752275

@@ -2734,7 +2734,7 @@ def stop_model_training(model: dict):
27342734

27352735
@Trackable
27362736
@validate_arguments
2737-
def download_model(model: dict, output_dir: Union[str, Path]):
2737+
def download_model(model: MLModel, output_dir: Union[str, Path]):
27382738
"""Downloads the neural network and related files
27392739
which are the <model_name>.pth/pkl. <model_name>.json, <model_name>.yaml, classes_mapper.json
27402740
@@ -2745,8 +2745,9 @@ def download_model(model: dict, output_dir: Union[str, Path]):
27452745
:return: the metadata of the model
27462746
:rtype: dict
27472747
"""
2748-
2749-
res = controller.download_ml_model(model_data=model, download_path=output_dir)
2748+
res = controller.download_ml_model(
2749+
model_data=model.dict(), download_path=output_dir
2750+
)
27502751
if res.errors:
27512752
logger.error("\n".join([str(error) for error in res.errors]))
27522753
else:

src/superannotate/lib/core/service_types.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Any
12
from typing import Dict
23
from typing import Optional
34
from typing import Union
5+
from typing import List
46

57
from pydantic import BaseModel
68
from pydantic import Extra
@@ -45,11 +47,35 @@ def __init__(self, **data):
4547
super().__init__(**data)
4648

4749

50+
class DownloadMLModelAuthData(BaseModel):
51+
access_key: str
52+
secret_key: str
53+
session_token: str
54+
region: str
55+
bucket: str
56+
paths: List[str]
57+
58+
class Config:
59+
extra = Extra.allow
60+
fields = {
61+
"access_key": "accessKeyId",
62+
"secret_key": "secretAccessKey",
63+
"session_token": "sessionToken",
64+
"region": "region",
65+
}
66+
67+
def __init__(self, **data):
68+
credentials = data["tokens"]
69+
data.update(credentials)
70+
del data["tokens"]
71+
super().__init__(**data)
72+
73+
4874
class ServiceResponse(BaseModel):
4975
status: int
5076
reason: str
5177
content: Union[bytes, str]
52-
data: Optional[Union[UserLimits, UploadAnnotationAuthData, ErrorMessage]]
78+
data: Any
5379

5480
def __init__(self, response, content_type):
5581
data = {

src/superannotate/lib/core/serviceproviders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,9 @@ def delete_model(self, team_id: int, model_id: int):
286286
def stop_model_training(self, team_id: int, model_id: int):
287287
raise NotImplementedError
288288

289-
def get_ml_model_download_tokens(self, team_id: int, model_id: int):
289+
def get_ml_model_download_tokens(
290+
self, team_id: int, model_id: int
291+
) -> ServiceResponse:
290292
raise NotImplementedError
291293

292294
def run_segmentation(

src/superannotate/lib/core/types.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,14 @@ class Project(BaseModel):
127127

128128
class Config:
129129
extra = Extra.allow
130+
131+
132+
class MLModel(BaseModel):
133+
name: NotEmptyStr
134+
id: int
135+
path: NotEmptyStr
136+
config_path: NotEmptyStr
137+
team_id: Optional[int]
138+
139+
class Config:
140+
extra = Extra.allow

src/superannotate/lib/core/usecases.py

Lines changed: 108 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,7 @@ def __init__(
22062206
image_name: str,
22072207
images: BaseManageableRepository,
22082208
destination: str,
2209+
annotation_classes: BaseManageableRepository,
22092210
):
22102211
super().__init__()
22112212
self._service = service
@@ -2214,6 +2215,7 @@ def __init__(
22142215
self._image_name = image_name
22152216
self._images = images
22162217
self._destination = destination
2218+
self._annotation_classes = annotation_classes
22172219

22182220
@property
22192221
def image_use_case(self):
@@ -2231,6 +2233,85 @@ def validate_project_type(self):
22312233
constances.LIMITED_FUNCTIONS[self._project.project_type]
22322234
)
22332235

2236+
@property
2237+
def annotation_classes_name_map(self) -> dict:
2238+
classes_data = defaultdict(dict)
2239+
annotation_classes = self._annotation_classes.get_all()
2240+
for annotation_class in annotation_classes:
2241+
class_info = {"id": annotation_class.uuid}
2242+
if annotation_class.attribute_groups:
2243+
for attribute_group in annotation_class.attribute_groups:
2244+
attribute_group_data = defaultdict(dict)
2245+
for attribute in attribute_group["attributes"]:
2246+
attribute_group_data[attribute["name"]] = attribute["id"]
2247+
class_info["attribute_groups"] = {
2248+
attribute_group["name"]: {
2249+
"id": attribute_group["id"],
2250+
"attributes": attribute_group_data,
2251+
}
2252+
}
2253+
classes_data[annotation_class.name] = class_info
2254+
return classes_data
2255+
2256+
def get_templates_mapping(self):
2257+
templates = self._service.get_templates(team_id=self._project.team_id).get(
2258+
"data", []
2259+
)
2260+
templates_map = {}
2261+
for template in templates:
2262+
templates_map[template["name"]] = template["id"]
2263+
return templates_map
2264+
2265+
def fill_classes_data(self, annotations: dict):
2266+
annotation_classes = self.annotation_classes_name_map
2267+
if "instances" not in annotations:
2268+
return
2269+
unknown_classes = {}
2270+
for annotation in [i for i in annotations["instances"] if "className" in i]:
2271+
if "className" not in annotation:
2272+
return
2273+
annotation_class_name = annotation["className"]
2274+
if annotation_class_name not in annotation_classes:
2275+
if annotation_class_name not in unknown_classes:
2276+
unknown_classes[annotation_class_name] = {
2277+
"id": -(len(unknown_classes) + 1),
2278+
"attribute_groups": {},
2279+
}
2280+
annotation_classes.update(unknown_classes)
2281+
templates = self.get_templates_mapping()
2282+
for annotation in (
2283+
i for i in annotations["instances"] if i.get("type", None) == "template"
2284+
):
2285+
annotation["templateId"] = templates.get(
2286+
annotation.get("templateName", ""), -1
2287+
)
2288+
2289+
for annotation in [i for i in annotations["instances"] if "className" in i]:
2290+
annotation_class_name = annotation["className"]
2291+
if annotation_class_name not in annotation_classes:
2292+
continue
2293+
annotation["classId"] = annotation_classes[annotation_class_name]["id"]
2294+
for attribute in annotation["attributes"]:
2295+
if (
2296+
attribute["groupName"]
2297+
not in annotation_classes[annotation_class_name]["attribute_groups"]
2298+
):
2299+
continue
2300+
attribute["groupId"] = annotation_classes[annotation_class_name][
2301+
"attribute_groups"
2302+
][attribute["groupName"]]["id"]
2303+
if (
2304+
attribute["name"]
2305+
not in annotation_classes[annotation_class_name][
2306+
"attribute_groups"
2307+
][attribute["groupName"]]["attributes"]
2308+
):
2309+
del attribute["groupId"]
2310+
continue
2311+
attribute["id"] = annotation_classes[annotation_class_name][
2312+
"attribute_groups"
2313+
][attribute["groupName"]]["attributes"]
2314+
22342315
def execute(self):
22352316
if self.is_valid():
22362317
data = {
@@ -2283,6 +2364,7 @@ def execute(self):
22832364
logger.info("There is no blue-map for the image.")
22842365

22852366
json_path = Path(self._destination) / data["annotation_json_filename"]
2367+
self.fill_classes_data(data["annotation_json"])
22862368
with open(json_path, "w") as f:
22872369
json.dump(data["annotation_json"], f, indent=4)
22882370

@@ -3039,6 +3121,7 @@ def __init__(
30393121
images: BaseManageableRepository,
30403122
classes: BaseManageableRepository,
30413123
backend_service_provider: SuerannotateServiceProvider,
3124+
annotation_classes: BaseReadOnlyRepository,
30423125
download_path: str,
30433126
image_variant: str = "original",
30443127
include_annotations: bool = False,
@@ -3065,6 +3148,7 @@ def __init__(
30653148
image_name=self._image.name,
30663149
images=images,
30673150
destination=download_path,
3151+
annotation_classes=annotation_classes,
30683152
)
30693153
self.get_annotation_classes_ues_case = GetAnnotationClassesUseCase(
30703154
classes=classes,
@@ -3098,31 +3182,35 @@ def validate_include_annotations(self):
30983182

30993183
def execute(self):
31003184
if self.is_valid():
3185+
fuse_image = None
3186+
annotations = None
3187+
31013188
image_bytes = self.get_image_use_case.execute().data
31023189
download_path = f"{self._download_path}/{self._image.name}"
31033190
if self._image_variant == "lores":
31043191
download_path = download_path + "___lores.jpg"
31053192
with open(download_path, "wb") as image_file:
31063193
image_file.write(image_bytes.getbuffer())
31073194

3108-
annotations = None
31093195
if self._include_annotations:
31103196
annotations = self.download_annotation_use_case.execute().data
31113197

3112-
fuse_image = None
31133198
if self._include_annotations and self._include_fuse:
31143199
classes = self.get_annotation_classes_ues_case.execute().data
3115-
fuse_image_use_case = CreateFuseImageUseCase(
3116-
project_type=constances.ProjectType.get_name(
3117-
self._project.project_type
3118-
),
3119-
image_path=download_path,
3120-
classes=[
3121-
annotation_class.to_dict() for annotation_class in classes
3122-
],
3123-
generate_overlay=self._include_overlay,
3200+
fuse_image = (
3201+
CreateFuseImageUseCase(
3202+
project_type=constances.ProjectType.get_name(
3203+
self._project.project_type
3204+
),
3205+
image_path=download_path,
3206+
classes=[
3207+
annotation_class.to_dict() for annotation_class in classes
3208+
],
3209+
generate_overlay=self._include_overlay,
3210+
)
3211+
.execute()
3212+
.data
31243213
)
3125-
fuse_image = fuse_image_use_case.execute().data
31263214

31273215
self._response.data = (
31283216
download_path,
@@ -3912,16 +4000,18 @@ def execute(self):
39124000
os.path.basename(self._model.config_path), metrics_name
39134001
)
39144002

3915-
download_token = self._backend_service.get_ml_model_download_tokens(
4003+
auth_response = self._backend_service.get_ml_model_download_tokens(
39164004
self._team_id, self._model.uuid
39174005
)
4006+
if not auth_response.ok:
4007+
raise AppException(auth_response.error)
39184008
s3_session = boto3.Session(
3919-
aws_access_key_id=download_token["tokens"]["accessKeyId"],
3920-
aws_secret_access_key=download_token["tokens"]["secretAccessKey"],
3921-
aws_session_token=download_token["tokens"]["sessionToken"],
3922-
region_name=download_token["tokens"]["region"],
4009+
aws_access_key_id=auth_response.data.access_key,
4010+
aws_secret_access_key=auth_response.data.secret_key,
4011+
aws_session_token=auth_response.data.session_token,
4012+
region_name=auth_response.data.region,
39234013
)
3924-
bucket = s3_session.resource("s3").Bucket(download_token["tokens"]["bucket"])
4014+
bucket = s3_session.resource("s3").Bucket(auth_response.data.bucket)
39254015

39264016
bucket.download_file(
39274017
self._model.config_path, os.path.join(self._download_path, "config.yaml")

src/superannotate/lib/infrastructure/controller.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,9 @@ def download_image_annotations(
994994
image_name=image_name,
995995
images=ImageRepository(service=self._backend_client),
996996
destination=destination,
997+
annotation_classes=AnnotationClassRepository(
998+
service=self._backend_client, project=project
999+
),
9971000
)
9981001
return use_case.execute()
9991002

@@ -1230,6 +1233,9 @@ def download_image(
12301233
include_annotations=include_annotations,
12311234
include_fuse=include_fuse,
12321235
include_overlay=include_overlay,
1236+
annotation_classes=AnnotationClassRepository(
1237+
service=self._backend_client, project=project
1238+
),
12331239
)
12341240
return use_case.execute()
12351241

src/superannotate/lib/infrastructure/services.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import lib.core as constance
1111
import requests.packages.urllib3
1212
from lib.core.exceptions import AppException
13+
from lib.core.service_types import DownloadMLModelAuthData
1314
from lib.core.service_types import ServiceResponse
1415
from lib.core.service_types import UploadAnnotationAuthData
1516
from lib.core.service_types import UserLimits
@@ -934,8 +935,12 @@ def get_ml_model_download_tokens(self, team_id: int, model_id: int):
934935
get_token_url = urljoin(
935936
self.api_url, self.URL_GET_ML_MODEL_DOWNLOAD_TOKEN.format(model_id)
936937
)
937-
res = self._request(get_token_url, "get", params={"team_id": team_id})
938-
return res.json()
938+
return self._request(
939+
get_token_url,
940+
"get",
941+
params={"team_id": team_id},
942+
content_type=DownloadMLModelAuthData,
943+
)
939944

940945
def run_segmentation(
941946
self, team_id: int, project_id: int, model_name: str, image_ids: list

src/superannotate/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "5.0.0b25"
1+
__version__ = "5.0.0b27"

tests/integration/test_interface.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,21 @@ def test_search_project(self):
9797
sa.set_image_annotation_status(self.PROJECT_NAME, self.EXAMPLE_IMAGE_1, "Completed")
9898
data = sa.search_projects(self.PROJECT_NAME, return_metadata=True, include_complete_image_count=True)
9999
self.assertIsNotNone(data[0]['completed_images_count'])
100+
101+
def test_overlay_fuse(self):
102+
sa.upload_image_to_project(self.PROJECT_NAME, f"{self.folder_path}/{self.EXAMPLE_IMAGE_1}")
103+
sa.create_annotation_classes_from_classes_json(self.PROJECT_NAME, f"{self.folder_path}/classes/classes.json")
104+
sa.upload_image_annotations(
105+
self.PROJECT_NAME, self.EXAMPLE_IMAGE_1, f"{self.folder_path}/{self.EXAMPLE_IMAGE_1}___objects.json"
106+
)
107+
with tempfile.TemporaryDirectory() as temp_dir:
108+
paths = sa.download_image(
109+
self.PROJECT_NAME,
110+
self.EXAMPLE_IMAGE_1,
111+
temp_dir,
112+
include_annotations=True,
113+
include_fuse=True,
114+
include_overlay=True,
115+
)
116+
self.assertIsNotNone(paths)
117+

0 commit comments

Comments
 (0)