-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathutils.py
More file actions
169 lines (138 loc) · 5.16 KB
/
utils.py
File metadata and controls
169 lines (138 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright (c) 2020-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import re
import sys
import math
import time
import pickle
import random
import getpass
import argparse
import subprocess
import errno
import signal
from functools import wraps, partial
from logger import create_logger
# Lower-cased spellings that bool_flag maps to False and True respectively.
FALSY_STRINGS = {'0', 'false', 'off'}
TRUTHY_STRINGS = {'1', 'true', 'on'}

# Root directory for experiment dumps, used by get_dump_path when
# params.dump_path is the empty string.
DUMP_PATH = '/checkpoint/%s/dumped' % getpass.getuser()

# When False, to_cuda returns its arguments without calling .cuda() on them.
CUDA = True
class AttrDict(dict):
    """
    Dictionary whose entries can also be read and written as attributes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so that
        # `d.key` and `d['key']` always refer to the same storage.
        self.__dict__ = self
def bool_flag(s):
    """
    Parse a boolean argument given as a string on the command line.

    The comparison is case-insensitive. Returns True for any value in
    TRUTHY_STRINGS and False for any value in FALSY_STRINGS; every other
    value raises argparse.ArgumentTypeError.
    """
    lowered = s.lower()
    if lowered in TRUTHY_STRINGS:
        return True
    if lowered in FALSY_STRINGS:
        return False
    raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")
def initialize_exp(params):
    """
    Initialize the experiment:
        - resolve / create the dump directory (see get_dump_path)
        - dump the parameters to `<dump_path>/params.pkl`
        - rebuild the launching command and store it in `params.command`
        - create a logger writing to `<dump_path>/train.log`

    `params` must expose `dump_path`, `exp_name` and `exp_id`; it is updated
    in place. Returns the logger built by `create_logger`.

    Raises AssertionError if the experiment name is blank or if a command
    line argument contains a quote character that cannot be re-quoted.
    """
    # dump parameters
    get_dump_path(params)
    # A context manager guarantees the pickle file is flushed and closed,
    # instead of leaving the handle to the garbage collector.
    with open(os.path.join(params.dump_path, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    # get running command
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith('--'):
            # option names are emitted verbatim, so they must not hold quotes
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            # values are wrapped in single quotes unless they are plain words
            assert "'" not in x
            if re.match(r'^[a-zA-Z0-9_]+$', x):
                command.append("%s" % x)
            else:
                command.append("'%s'" % x)
    command = ' '.join(command)
    # the stored command pins the experiment ID so the run can be resumed
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # check experiment name
    assert len(params.exp_name.strip()) > 0

    # create a logger
    logger = create_logger(
        os.path.join(params.dump_path, 'train.log'),
        rank=getattr(params, 'global_rank', 0),
    )
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("Running command: %s" % command)
    logger.info("")
    return logger
def get_dump_path(params):
    """
    Create a directory to store the experiment.

    Reads `params.dump_path`, `params.exp_name` and `params.exp_id`.
    An empty `dump_path` falls back to DUMP_PATH. An empty `exp_id` is
    replaced by the Chronos or Slurm job ID found in the environment, or by
    a random 10-character identifier when neither is set.

    On return `params.exp_id` is set and `params.dump_path` is
    `<dump_path>/<exp_name>/<exp_id>`, which exists on disk.

    Raises AssertionError if the experiment name is empty, if both job ID
    variables are set, or if the job ID is not numeric. Raises OSError if a
    directory cannot be created.
    """
    params.dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path
    assert len(params.exp_name) > 0

    # create the sweep path if it does not exist
    # os.makedirs avoids shell interpolation of the path (spaces and
    # metacharacters are safe), reports failures, and tolerates another
    # process creating the directory at the same time.
    sweep_path = os.path.join(params.dump_path, params.exp_name)
    os.makedirs(sweep_path, exist_ok=True)

    # create an ID for the job if it is not given in the parameters.
    # if we run on the cluster, the job ID is the one of Chronos or Slurm.
    # otherwise, it is randomly generated
    if params.exp_id == '':
        chronos_job_id = os.environ.get('CHRONOS_JOB_ID')
        slurm_job_id = os.environ.get('SLURM_JOB_ID')
        assert chronos_job_id is None or slurm_job_id is None
        exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id
        if exp_id is None:
            chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
            # draw identifiers until one is unused inside the sweep folder
            while True:
                exp_id = ''.join(random.choice(chars) for _ in range(10))
                if not os.path.isdir(os.path.join(sweep_path, exp_id)):
                    break
        else:
            assert exp_id.isdigit()
        params.exp_id = exp_id

    # create the dump folder / update parameters
    params.dump_path = os.path.join(sweep_path, params.exp_id)
    os.makedirs(params.dump_path, exist_ok=True)
def to_cuda(*args):
    """
    Move tensors to CUDA.

    With the module-level CUDA flag enabled, returns a list holding
    `x.cuda()` for each argument, keeping None entries as None. With the
    flag disabled, the positional arguments are returned unchanged (as the
    original tuple).
    """
    if CUDA:
        return [x.cuda() if x is not None else None for x in args]
    return args
class TimeoutError(BaseException):
    """
    Raised by the `timeout` decorator when the time budget is exhausted.

    Derives from BaseException rather than Exception.
    NOTE(review): presumably so that broad `except Exception` handlers in
    decorated code do not swallow it - confirm with the callers.
    """
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """
    Build a decorator that interrupts the decorated call with TimeoutError
    after `seconds` seconds, using SIGALRM.

    - `seconds`: time budget passed to `signal.alarm` (whole seconds).
    - `error_message`: message carried by the raised TimeoutError.

    Calls can be nested: if an alarm is already pending when the wrapper
    starts, the new alarm never outlives it, and on exit the previous
    handler is re-installed and the previous alarm re-armed with the time
    it has left. When no alarm was pending, the alarm is cancelled and the
    previous SIGALRM handler is restored.

    Relies on `signal.signal` / `signal.alarm`, so the decorated function
    must be called from the main thread on a platform that has SIGALRM.
    """
    def decorator(func):

        def _handle_timeout(repeat_id, signum, frame):
            # Re-arm before raising: if the TimeoutError is caught inside
            # `func`, another alarm fires `seconds` later.
            signal.signal(signal.SIGALRM, partial(_handle_timeout, repeat_id + 1))
            signal.alarm(seconds)
            raise TimeoutError(error_message)

        def _restore_handler(previous):
            # signal.signal() returns None when the previous handler was not
            # installed from Python; such a handler cannot be re-installed.
            if previous is not None:
                signal.signal(signal.SIGALRM, previous)

        def wrapper(*args, **kwargs):
            old_signal = signal.signal(signal.SIGALRM, partial(_handle_timeout, 0))
            old_time_left = signal.alarm(seconds)
            assert type(old_time_left) is int and old_time_left >= 0
            if 0 < old_time_left < seconds:  # do not exceed previous timer
                signal.alarm(old_time_left)
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
            finally:
                if old_time_left == 0:
                    # No outer timer: cancel ours first, then give the
                    # handler back so the caller's SIGALRM setup survives.
                    signal.alarm(0)
                    _restore_handler(old_signal)
                else:
                    # Outer timer: restore its handler and re-arm it with
                    # the time it has left after this call.
                    sub = time.time() - start_time
                    _restore_handler(old_signal)
                    signal.alarm(max(0, math.ceil(old_time_left - sub)))
            return result

        return wraps(func)(wrapper)

    return decorator