2222from .defaults import DEFAULT_HYPERPARAMETERS , NON_PLOTABLE_KEYS
2323from .utils import log_process , make_plotly_specs , reformat_metrics_json
2424from ..db .utils import _get_boto_session_by_credentials
25+ from ..db .project_api import get_project_metadata_bare , get_folder_metadata
2526from ..mixp .decorators import Trackable
27+ from ..db .images import get_project_root_folder_id
28+ from .ml_models import search_models
2629
2730logger = logging .getLogger ("superannotate-python-sdk" )
2831_api = API .get_instance ()
@@ -182,30 +185,60 @@ def run_segmentation(project, images_list, model):
182185 return (succeded_imgs , failed_imgs )
183186
184187
188+ def _path_to_folder_id_project (path ):
189+ parts = path .split ('/' )
190+ folder_id = None
191+ project = None
192+ if len (parts ) == 1 :
193+ project_name = parts [0 ]
194+ project = get_project_metadata_bare (project_name )
195+ folder_id = get_project_root_folder_id (project )
196+ elif len (parts ) == 2 :
197+ project_name , folder_name = parts
198+ project = get_project_metadata_bare (project_name )
199+ folder = get_folder_metadata (project = project_name , folder_name = folder_name )
200+ folder_id = folder ['id' ]
201+ return folder_id , project
202+
203+
204+ def _get_completed_images_counts (project_ids ):
205+ params = {
206+ "team_id" : _api .team_id ,
207+ "completedImagesCount" : True
208+ }
209+ response = _api .send_request (
210+ req_type = "PUT" ,
211+ path = "/foldersByTeam" ,
212+ json_req = {"project_ids" : project_ids },
213+ params = params
214+ )
215+ return response .json ()
216+
217+
185218@Trackable
186- @project_metadata
187- @model_metadata
188219def run_training (
189- project ,
190- base_model ,
191- model_name ,
192- model_description ,
193- task ,
194- hyperparameters ,
195- log = False
220+ model_name ,
221+ model_description ,
222+ task ,
223+ base_model ,
224+ train_data ,
225+ test_data ,
226+ hyperparameters = None ,
227+ log = False
196228):
197229 """Runs neural network training
198-
199- :param project: project or list of projects that contain the training images
200- :type project: str, dict or list of dict
201- :param base_model: base model on which the new network will be trained
202- :type base_model: str or dict
203230 :param model_name: name of the new model
204231 :type model_name: str
205232 :param model_description: description of the new model
206233 :type model_description: str
207234 :param task: The model training task
208235 :type task: str
236+ :param base_model: base model on which the new network will be trained
237+ :type base_model: str or dict
238+ :param train_data: train data folder id
239+ :type train_data: list of int
240+ :param test_data: test data folder id
241+ :type test_data: list of int
209242 :param hyperparameters: hyperparameters that should be used in training
210243 :type hyperparameters: dict
211244 :param log: If true will log training metrics in the stdout
@@ -214,62 +247,81 @@ def run_training(
214247 :rtype: dict
215248 """
216249
217- project_ids = None
218- project_type = None
250+ train_folder_ids = []
251+ test_folder_ids = []
252+ projects = []
219253
220- if isinstance (project , dict ):
221- project_ids = [project ["id" ]]
222- project_type = project ["type" ]
223- project = [project ]
224- else :
225- project_ids = [x ["id" ] for x in project ]
226- types = (x ["type" ] for x in project )
227- types = set (types )
228- if len (types ) != 1 :
229- logger .error (
230- "All projects have to be of the same type. Either vector or pixel"
231- )
232- raise SABaseException (0 , "Invalid project types" )
233- project_type = types .pop ()
254+ for path in train_data :
255+ folder_id , project = _path_to_folder_id_project (path )
256+ train_folder_ids .append (folder_id )
257+ projects .append (project )
258+
259+ for path in test_data :
260+ folder_id , project = _path_to_folder_id_project (path )
261+ test_folder_ids .append (folder_id )
262+ projects .append (project )
234263
235- for single_project in project :
236- upload_state = upload_state_int_to_str (
237- single_project .get ("upload_state" )
264+ if set (train_folder_ids ) & set (test_folder_ids ):
265+ raise SABaseException (
266+ 0 ,
267+ "Avoid overlapping between training and test data."
238268 )
239- if upload_state == "External" :
240- raise SABaseException (
241- 0 ,
242- "The function does not support projects containing images attached with URLs"
243- )
244269
245- base_model = base_model .get (project_type , None )
246- if not base_model :
270+ types = [i ["type" ] for i in projects ]
271+ if len (set (types )) != 1 :
272+ logger .error (
273+ "All projects have to be of the same type. Either vector or pixel"
274+ )
275+ raise SABaseException (0 , "Invalid project types" )
276+
277+ upload_states = set (i ["upload_state" ] for i in projects )
278+ if any ([True for state in upload_states if "External" == upload_state_int_to_str (state )]):
279+ raise SABaseException (
280+ 0 ,
281+ "The function does not support projects containing images attached with URLs"
282+ )
283+ if isinstance (base_model , dict ):
284+ base_model = base_model ['name' ]
285+ models = search_models (
286+ include_global = True , name = base_model
287+ )
288+ if not models :
289+ raise SABaseException (
290+ 0 ,
291+ "The specifed model does not exist."
292+ )
293+ base_model = models [0 ]
294+ base_model_id = base_model ['id' ]
295+ project_type = types [0 ]
296+ if not base_model ['type' ] == project_type :
247297 logger .error (
248298 "The base model has to be of the same type (vector or pixel) as the projects"
249299 )
250300 raise SABaseException (
251301 0 ,
252302 f"The type of provided projects is { project_type } , and does not correspond to the type of provided model"
253303 )
254-
255304 for item in DEFAULT_HYPERPARAMETERS :
256305 if item not in hyperparameters :
257306 hyperparameters [item ] = DEFAULT_HYPERPARAMETERS [item ]
258- complete_image_count = 0
259- for proj in project :
260- complete_image_count += proj ['rootFolderCompletedImagesCount' ]
261307
308+ project_ids = [i ['id' ] for i in projects ]
309+ completed_images_data = _get_completed_images_counts (project_ids )
310+ complete_image_count = 0
311+ for folder in completed_images_data ['data' ]:
312+ if folder ['id' ] in train_folder_ids :
313+ complete_image_count += folder ['completedCount' ]
262314 hyperparameters ["name" ] = model_name
263315 hyperparameters ["description" ] = model_description
264316 hyperparameters ["task" ] = _MODEL_TRAINING_TASKS [task ]
265- hyperparameters ["base_model_id" ] = base_model ["id" ]
266- hyperparameters ["project_ids" ] = project_ids
317+ hyperparameters ["base_model_id" ] = base_model_id
267318 hyperparameters ["image_count" ] = complete_image_count
268319 hyperparameters ["project_type" ] = project_type_str_to_int (project_type )
320+ hyperparameters ["test_folder_ids" ] = test_folder_ids
321+ hyperparameters ["train_folder_ids" ] = train_folder_ids
269322 params = {
270323 "team_id" : _api .team_id ,
271324 }
272-
273325 response = _api .send_request (
274326 req_type = "POST" ,
275327 path = "/ml_models" ,
@@ -332,7 +384,7 @@ def run_training(
332384 if answer in ['Y' , 'y' ]:
333385 params = {'team_id' : _api .team_id }
334386 json_req = {'training_status' : 6 }
335- response = _api .send_request (
387+ _api .send_request (
336388 req_type = 'PUT' ,
337389 path = f'ml_model/{ new_model_id } ' ,
338390 params = params ,
@@ -341,7 +393,7 @@ def run_training(
341393 logger .info ("Model was successfully saved" )
342394 pass
343395 else :
344- delete_model (name )
396+ delete_model (model_name )
345397 logger .info ('The model was not saved' )
346398 is_training_finished = True
347399
@@ -380,6 +432,7 @@ def plot_model_metrics(metric_json_list):
380432 :param metric_json_list: list of <model_name>.json files
381433 :type metric_json_list: list of str
382434 """
435+
383436 def plot_df (df , plottable_cols , figure , start_index = 1 ):
384437 for row , metric in enumerate (plottable_cols , start_index ):
385438 for model_df in df :
@@ -401,7 +454,7 @@ def get_plottable_cols(df):
401454 plottable_cols += [
402455 col_name
403456 for col_name in col_names if col_name not in plottable_cols and
404- col_name not in NON_PLOTABLE_KEYS
457+ col_name not in NON_PLOTABLE_KEYS
405458 ]
406459 return plottable_cols
407460
0 commit comments