Skip to content

Commit 0ed5c40

Browse files
authored
Merge pull request #98 from superannotateai/SAS-3622
Sas 3622
2 parents ace2c49 + 69fafd9 commit 0ed5c40

File tree

4 files changed

+185
-57
lines changed

4 files changed

+185
-57
lines changed

superannotate/mixp/decorators.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,26 @@ class Trackable(object):
1717
def __init__(self, function):
1818
lock = Lock()
1919
self.function = function
20+
self._success = None
21+
self._caller_name = None
22+
self._func_name_to_track = None
2023
with lock:
2124
Trackable.registered.add(function.__name__)
2225
functools.update_wrapper(self, function)
2326

24-
def __call__(self, *args, **kwargs):
27+
def should_track(self):
28+
if self._caller_name not in Trackable.registered or self._func_name_to_track in always_trackable_func_names:
29+
return True
30+
return False
31+
32+
def track(self, *args, **kwargs):
2533
try:
26-
func_name_to_track = self.function.__name__
27-
caller_name = sys._getframe(1).f_code.co_name
28-
if caller_name not in Trackable.registered or func_name_to_track in always_trackable_func_names:
29-
data = getattr(parsers, func_name_to_track)(*args, **kwargs)
34+
if self.should_track():
35+
data = getattr(parsers, self._func_name_to_track)(*args, **kwargs)
3036
user_id = _api.user_id
3137
event_name = data['event_name']
3238
properties = data['properties']
39+
properties['Success'] = self._success
3340
default = get_default(
3441
_api.team_name,
3542
_api.user_id,
@@ -39,6 +46,19 @@ def __call__(self, *args, **kwargs):
3946
properties = {**default, **properties}
4047
if "pytest" not in sys.modules:
4148
mp.track(user_id, event_name, properties)
42-
except:
49+
except Exception as e:
4350
pass
44-
return self.function(*args, **kwargs)
51+
52+
53+
def __call__(self, *args, **kwargs):
54+
try:
55+
self._caller_name = sys._getframe(1).f_code.co_name
56+
self._func_name_to_track = self.function.__name__
57+
ret = self.function(*args, **kwargs)
58+
self._success = True
59+
self.track(*args, **kwargs)
60+
except Exception as e:
61+
self._success = False
62+
self.track(*args, **kwargs)
63+
raise e
64+
return ret

superannotate/ml/ml_funcs.py

Lines changed: 103 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from .defaults import DEFAULT_HYPERPARAMETERS, NON_PLOTABLE_KEYS
2323
from .utils import log_process, make_plotly_specs, reformat_metrics_json
2424
from ..db.utils import _get_boto_session_by_credentials
25+
from ..db.project_api import get_project_metadata_bare, get_folder_metadata
2526
from ..mixp.decorators import Trackable
27+
from ..db.images import get_project_root_folder_id
28+
from .ml_models import search_models
2629

2730
logger = logging.getLogger("superannotate-python-sdk")
2831
_api = API.get_instance()
@@ -182,30 +185,60 @@ def run_segmentation(project, images_list, model):
182185
return (succeded_imgs, failed_imgs)
183186

184187

188+
def _path_to_folder_id_project(path):
189+
parts = path.split('/')
190+
folder_id = None
191+
project = None
192+
if len(parts) == 1:
193+
project_name = parts[0]
194+
project = get_project_metadata_bare(project_name)
195+
folder_id = get_project_root_folder_id(project)
196+
elif len(parts) == 2:
197+
project_name, folder_name = parts
198+
project = get_project_metadata_bare(project_name)
199+
folder = get_folder_metadata(project=project_name, folder_name=folder_name)
200+
folder_id = folder['id']
201+
return folder_id, project
202+
203+
204+
def _get_completed_images_counts(project_ids):
205+
params = {
206+
"team_id": _api.team_id,
207+
"completedImagesCount": True
208+
}
209+
response = _api.send_request(
210+
req_type="PUT",
211+
path="/foldersByTeam",
212+
json_req={"project_ids": project_ids},
213+
params=params
214+
)
215+
return response.json()
216+
217+
185218
@Trackable
186-
@project_metadata
187-
@model_metadata
188219
def run_training(
189-
project,
190-
base_model,
191-
model_name,
192-
model_description,
193-
task,
194-
hyperparameters,
195-
log=False
220+
model_name,
221+
model_description,
222+
task,
223+
base_model,
224+
train_data,
225+
test_data,
226+
hyperparameters=None,
227+
log=False
196228
):
197229
"""Runs neural network training
198-
199-
:param project: project or list of projects that contain the training images
200-
:type project: str, dict or list of dict
201-
:param base_model: base model on which the new network will be trained
202-
:type base_model: str or dict
203230
:param model_name: name of the new model
204231
:type model_name: str
205232
:param model_description: description of the new model
206233
:type model_description: str
207234
:param task: The model training task
208235
:type task: str
236+
:param base_model: base model on which the new network will be trained
237+
:type base_model: str or dict
238+
:param train_data: train data folder id
239+
:type train_data: list of int
240+
:param test_data: test data folder id
241+
:type test_data: list of int
209242
:param hyperparameters: hyperparameters that should be used in training
210243
:type hyperparameters: dict
211244
:param log: If true will log training metrics in the stdout
@@ -214,62 +247,81 @@ def run_training(
214247
:rtype: dict
215248
"""
216249

217-
project_ids = None
218-
project_type = None
250+
train_folder_ids = []
251+
test_folder_ids = []
252+
projects = []
219253

220-
if isinstance(project, dict):
221-
project_ids = [project["id"]]
222-
project_type = project["type"]
223-
project = [project]
224-
else:
225-
project_ids = [x["id"] for x in project]
226-
types = (x["type"] for x in project)
227-
types = set(types)
228-
if len(types) != 1:
229-
logger.error(
230-
"All projects have to be of the same type. Either vector or pixel"
231-
)
232-
raise SABaseException(0, "Invalid project types")
233-
project_type = types.pop()
254+
for path in train_data:
255+
folder_id, project = _path_to_folder_id_project(path)
256+
train_folder_ids.append(folder_id)
257+
projects.append(project)
258+
259+
for path in test_data:
260+
folder_id, project = _path_to_folder_id_project(path)
261+
test_folder_ids.append(folder_id)
262+
projects.append(project)
234263

235-
for single_project in project:
236-
upload_state = upload_state_int_to_str(
237-
single_project.get("upload_state")
264+
if set(train_folder_ids) & set(test_folder_ids):
265+
raise SABaseException(
266+
0,
267+
"Avoid overlapping between training and test data."
238268
)
239-
if upload_state == "External":
240-
raise SABaseException(
241-
0,
242-
"The function does not support projects containing images attached with URLs"
243-
)
244269

245-
base_model = base_model.get(project_type, None)
246-
if not base_model:
270+
types = [i["type"] for i in projects]
271+
if len(set(types)) != 1:
272+
logger.error(
273+
"All projects have to be of the same type. Either vector or pixel"
274+
)
275+
raise SABaseException(0, "Invalid project types")
276+
277+
upload_states = set(i["upload_state"] for i in projects)
278+
if any([True for state in upload_states if "External" == upload_state_int_to_str(state)]):
279+
raise SABaseException(
280+
0,
281+
"The function does not support projects containing images attached with URLs"
282+
)
283+
if isinstance(base_model, dict):
284+
base_model = base_model['name']
285+
models = search_models(
286+
include_global=True, name=base_model
287+
)
288+
if not models:
289+
raise SABaseException(
290+
0,
291+
"The specifed model does not exist."
292+
)
293+
base_model = models[0]
294+
base_model_id = base_model['id']
295+
project_type = types[0]
296+
if not base_model['type'] == project_type:
247297
logger.error(
248298
"The base model has to be of the same type (vector or pixel) as the projects"
249299
)
250300
raise SABaseException(
251301
0,
252302
f"The type of provided projects is {project_type}, and does not correspond to the type of provided model"
253303
)
254-
255304
for item in DEFAULT_HYPERPARAMETERS:
256305
if item not in hyperparameters:
257306
hyperparameters[item] = DEFAULT_HYPERPARAMETERS[item]
258-
complete_image_count = 0
259-
for proj in project:
260-
complete_image_count += proj['rootFolderCompletedImagesCount']
261307

308+
project_ids = [i['id'] for i in projects]
309+
completed_images_data = _get_completed_images_counts(project_ids)
310+
complete_image_count = 0
311+
for folder in completed_images_data['data']:
312+
if folder['id'] in train_folder_ids:
313+
complete_image_count += folder['completedCount']
262314
hyperparameters["name"] = model_name
263315
hyperparameters["description"] = model_description
264316
hyperparameters["task"] = _MODEL_TRAINING_TASKS[task]
265-
hyperparameters["base_model_id"] = base_model["id"]
266-
hyperparameters["project_ids"] = project_ids
317+
hyperparameters["base_model_id"] = base_model_id
267318
hyperparameters["image_count"] = complete_image_count
268319
hyperparameters["project_type"] = project_type_str_to_int(project_type)
320+
hyperparameters["test_folder_ids"] = test_folder_ids
321+
hyperparameters["train_folder_ids"] = train_folder_ids
269322
params = {
270323
"team_id": _api.team_id,
271324
}
272-
273325
response = _api.send_request(
274326
req_type="POST",
275327
path="/ml_models",
@@ -332,7 +384,7 @@ def run_training(
332384
if answer in ['Y', 'y']:
333385
params = {'team_id': _api.team_id}
334386
json_req = {'training_status': 6}
335-
response = _api.send_request(
387+
_api.send_request(
336388
req_type='PUT',
337389
path=f'ml_model/{new_model_id}',
338390
params=params,
@@ -341,7 +393,7 @@ def run_training(
341393
logger.info("Model was successfully saved")
342394
pass
343395
else:
344-
delete_model(name)
396+
delete_model(model_name)
345397
logger.info('The model was not saved')
346398
is_training_finished = True
347399

@@ -380,6 +432,7 @@ def plot_model_metrics(metric_json_list):
380432
:param metric_json_list: list of <model_name>.json files
381433
:type metric_json_list: list of str
382434
"""
435+
383436
def plot_df(df, plottable_cols, figure, start_index=1):
384437
for row, metric in enumerate(plottable_cols, start_index):
385438
for model_df in df:
@@ -401,7 +454,7 @@ def get_plottable_cols(df):
401454
plottable_cols += [
402455
col_name
403456
for col_name in col_names if col_name not in plottable_cols and
404-
col_name not in NON_PLOTABLE_KEYS
457+
col_name not in NON_PLOTABLE_KEYS
405458
]
406459
return plottable_cols
407460

tests/test_neural_networks.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import time
2+
from pathlib import Path
3+
4+
import superannotate as sa
5+
6+
test_root = Path().resolve() / 'tests'
7+
project_name = "training"
8+
9+
10+
def test_run_training():
11+
export_path = test_root / 'consensus_benchmark' / 'consensus_test_data'
12+
if len(sa.search_projects(project_name)) != 0:
13+
sa.delete_project(project_name)
14+
time.sleep(2)
15+
16+
sa.create_project(project_name, "test bench", "Vector")
17+
time.sleep(2)
18+
for i in range(1, 4):
19+
sa.create_folder(project_name, "consensus_" + str(i))
20+
time.sleep(2)
21+
sa.create_annotation_classes_from_classes_json(
22+
project_name, export_path / 'classes' / 'classes.json'
23+
)
24+
sa.upload_images_from_folder_to_project(
25+
project_name, export_path / "images", annotation_status="Completed"
26+
)
27+
for i in range(1, 4):
28+
sa.upload_images_from_folder_to_project(
29+
project_name + '/consensus_' + str(i),
30+
export_path / "images",
31+
annotation_status="Completed"
32+
)
33+
sa.upload_annotations_from_folder_to_project(project_name, export_path)
34+
for i in range(1, 4):
35+
sa.upload_annotations_from_folder_to_project(
36+
project_name + '/consensus_' + str(i),
37+
export_path / ('consensus_' + str(i))
38+
)
39+
time.sleep(2)
40+
new_model = sa.run_training(
41+
"some name",
42+
"some desc",
43+
"Instance Segmentation for Vector Projects",
44+
"Instance Segmentation (trained on COCO)",
45+
[f"{project_name}/consensus_1"],
46+
[f"{project_name}/consensus_2"],
47+
{
48+
"base_lr": 0.02,
49+
"images_per_batch": 8
50+
},
51+
False
52+
)
53+
54+
assert "id" in new_model

tests/test_video.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_video(tmpdir):
2020
sa.delete_project(project)
2121

2222
project = sa.create_project(PROJECT_NAME1, "test", "Vector")
23+
print(project)
2324
time.sleep(1)
2425
sa.create_annotation_class(project, "fr", "#FFAAAA")
2526
time.sleep(1)

0 commit comments

Comments
 (0)