@@ -2206,6 +2206,7 @@ def __init__(
22062206 image_name : str ,
22072207 images : BaseManageableRepository ,
22082208 destination : str ,
2209+ annotation_classes : BaseManageableRepository ,
22092210 ):
22102211 super ().__init__ ()
22112212 self ._service = service
@@ -2214,6 +2215,7 @@ def __init__(
22142215 self ._image_name = image_name
22152216 self ._images = images
22162217 self ._destination = destination
2218+ self ._annotation_classes = annotation_classes
22172219
22182220 @property
22192221 def image_use_case (self ):
@@ -2231,6 +2233,85 @@ def validate_project_type(self):
22312233 constances .LIMITED_FUNCTIONS [self ._project .project_type ]
22322234 )
22332235
2236+ @property
2237+ def annotation_classes_name_map (self ) -> dict :
2238+ classes_data = defaultdict (dict )
2239+ annotation_classes = self ._annotation_classes .get_all ()
2240+ for annotation_class in annotation_classes :
2241+ class_info = {"id" : annotation_class .uuid }
2242+ if annotation_class .attribute_groups :
2243+ for attribute_group in annotation_class .attribute_groups :
2244+ attribute_group_data = defaultdict (dict )
2245+ for attribute in attribute_group ["attributes" ]:
2246+ attribute_group_data [attribute ["name" ]] = attribute ["id" ]
2247+ class_info ["attribute_groups" ] = {
2248+ attribute_group ["name" ]: {
2249+ "id" : attribute_group ["id" ],
2250+ "attributes" : attribute_group_data ,
2251+ }
2252+ }
2253+ classes_data [annotation_class .name ] = class_info
2254+ return classes_data
2255+
2256+ def get_templates_mapping (self ):
2257+ templates = self ._service .get_templates (team_id = self ._project .team_id ).get (
2258+ "data" , []
2259+ )
2260+ templates_map = {}
2261+ for template in templates :
2262+ templates_map [template ["name" ]] = template ["id" ]
2263+ return templates_map
2264+
2265+ def fill_classes_data (self , annotations : dict ):
2266+ annotation_classes = self .annotation_classes_name_map
2267+ if "instances" not in annotations :
2268+ return
2269+ unknown_classes = {}
2270+ for annotation in [i for i in annotations ["instances" ] if "className" in i ]:
2271+ if "className" not in annotation :
2272+ return
2273+ annotation_class_name = annotation ["className" ]
2274+ if annotation_class_name not in annotation_classes :
2275+ if annotation_class_name not in unknown_classes :
2276+ unknown_classes [annotation_class_name ] = {
2277+ "id" : - (len (unknown_classes ) + 1 ),
2278+ "attribute_groups" : {},
2279+ }
2280+ annotation_classes .update (unknown_classes )
2281+ templates = self .get_templates_mapping ()
2282+ for annotation in (
2283+ i for i in annotations ["instances" ] if i .get ("type" , None ) == "template"
2284+ ):
2285+ annotation ["templateId" ] = templates .get (
2286+ annotation .get ("templateName" , "" ), - 1
2287+ )
2288+
2289+ for annotation in [i for i in annotations ["instances" ] if "className" in i ]:
2290+ annotation_class_name = annotation ["className" ]
2291+ if annotation_class_name not in annotation_classes :
2292+ continue
2293+ annotation ["classId" ] = annotation_classes [annotation_class_name ]["id" ]
2294+ for attribute in annotation ["attributes" ]:
2295+ if (
2296+ attribute ["groupName" ]
2297+ not in annotation_classes [annotation_class_name ]["attribute_groups" ]
2298+ ):
2299+ continue
2300+ attribute ["groupId" ] = annotation_classes [annotation_class_name ][
2301+ "attribute_groups"
2302+ ][attribute ["groupName" ]]["id" ]
2303+ if (
2304+ attribute ["name" ]
2305+ not in annotation_classes [annotation_class_name ][
2306+ "attribute_groups"
2307+ ][attribute ["groupName" ]]["attributes" ]
2308+ ):
2309+ del attribute ["groupId" ]
2310+ continue
2311+ attribute ["id" ] = annotation_classes [annotation_class_name ][
2312+ "attribute_groups"
2313+ ][attribute ["groupName" ]]["attributes" ]
2314+
22342315 def execute (self ):
22352316 if self .is_valid ():
22362317 data = {
@@ -2283,6 +2364,7 @@ def execute(self):
22832364 logger .info ("There is no blue-map for the image." )
22842365
22852366 json_path = Path (self ._destination ) / data ["annotation_json_filename" ]
2367+ self .fill_classes_data (data ["annotation_json" ])
22862368 with open (json_path , "w" ) as f :
22872369 json .dump (data ["annotation_json" ], f , indent = 4 )
22882370
@@ -3039,6 +3121,7 @@ def __init__(
30393121 images : BaseManageableRepository ,
30403122 classes : BaseManageableRepository ,
30413123 backend_service_provider : SuerannotateServiceProvider ,
3124+ annotation_classes : BaseReadOnlyRepository ,
30423125 download_path : str ,
30433126 image_variant : str = "original" ,
30443127 include_annotations : bool = False ,
@@ -3065,6 +3148,7 @@ def __init__(
30653148 image_name = self ._image .name ,
30663149 images = images ,
30673150 destination = download_path ,
3151+ annotation_classes = annotation_classes ,
30683152 )
30693153 self .get_annotation_classes_ues_case = GetAnnotationClassesUseCase (
30703154 classes = classes ,
@@ -3098,31 +3182,35 @@ def validate_include_annotations(self):
30983182
30993183 def execute (self ):
31003184 if self .is_valid ():
3185+ fuse_image = None
3186+ annotations = None
3187+
31013188 image_bytes = self .get_image_use_case .execute ().data
31023189 download_path = f"{ self ._download_path } /{ self ._image .name } "
31033190 if self ._image_variant == "lores" :
31043191 download_path = download_path + "___lores.jpg"
31053192 with open (download_path , "wb" ) as image_file :
31063193 image_file .write (image_bytes .getbuffer ())
31073194
3108- annotations = None
31093195 if self ._include_annotations :
31103196 annotations = self .download_annotation_use_case .execute ().data
31113197
3112- fuse_image = None
31133198 if self ._include_annotations and self ._include_fuse :
31143199 classes = self .get_annotation_classes_ues_case .execute ().data
3115- fuse_image_use_case = CreateFuseImageUseCase (
3116- project_type = constances .ProjectType .get_name (
3117- self ._project .project_type
3118- ),
3119- image_path = download_path ,
3120- classes = [
3121- annotation_class .to_dict () for annotation_class in classes
3122- ],
3123- generate_overlay = self ._include_overlay ,
3200+ fuse_image = (
3201+ CreateFuseImageUseCase (
3202+ project_type = constances .ProjectType .get_name (
3203+ self ._project .project_type
3204+ ),
3205+ image_path = download_path ,
3206+ classes = [
3207+ annotation_class .to_dict () for annotation_class in classes
3208+ ],
3209+ generate_overlay = self ._include_overlay ,
3210+ )
3211+ .execute ()
3212+ .data
31243213 )
3125- fuse_image = fuse_image_use_case .execute ().data
31263214
31273215 self ._response .data = (
31283216 download_path ,
@@ -3912,16 +4000,18 @@ def execute(self):
39124000 os .path .basename (self ._model .config_path ), metrics_name
39134001 )
39144002
3915- download_token = self ._backend_service .get_ml_model_download_tokens (
4003+ auth_response = self ._backend_service .get_ml_model_download_tokens (
39164004 self ._team_id , self ._model .uuid
39174005 )
4006+ if not auth_response .ok :
4007+ raise AppException (auth_response .error )
39184008 s3_session = boto3 .Session (
3919- aws_access_key_id = download_token [ "tokens" ][ "accessKeyId" ] ,
3920- aws_secret_access_key = download_token [ "tokens" ][ "secretAccessKey" ] ,
3921- aws_session_token = download_token [ "tokens" ][ "sessionToken" ] ,
3922- region_name = download_token [ "tokens" ][ " region" ] ,
4009+ aws_access_key_id = auth_response . data . access_key ,
4010+ aws_secret_access_key = auth_response . data . secret_key ,
4011+ aws_session_token = auth_response . data . session_token ,
4012+ region_name = auth_response . data . region ,
39234013 )
3924- bucket = s3_session .resource ("s3" ).Bucket (download_token [ "tokens" ][ " bucket" ] )
4014+ bucket = s3_session .resource ("s3" ).Bucket (auth_response . data . bucket )
39254015
39264016 bucket .download_file (
39274017 self ._model .config_path , os .path .join (self ._download_path , "config.yaml" )
0 commit comments