diff --git a/batch/job_template.sh b/batch/job_template.sh index 1205bc8..e98eaa3 100644 --- a/batch/job_template.sh +++ b/batch/job_template.sh @@ -4,15 +4,15 @@ # # To test the script, run the following command from risk-slim directory: # -# `bash batch/job_template.sh` +# bash batch/job_template.sh # # To see a detailed list of all arguments that can be passed into risk_slim, use: # -# `python "batch/train_risk_slim.py --help` +# python batch/train_risk_slim.py --help # # or # -# `python2 "batch/train_risk_slim.py --help` +# python3 batch/train_risk_slim.py --help # # Recommended Directory Structure for Batch Computing: # @@ -30,6 +30,7 @@ # The values can be changed directly using a text editor, or programmatically using a tool such as # `jq` https://stedolan.github.io/jq/ + #directories repo_dir=$(pwd) data_dir="${repo_dir}/examples/data" #change to /batch/data/ for your own data @@ -88,4 +89,3 @@ python3 "${batch_dir}/train_risk_slim.py" \ --log "${log_file}" exit -W \ No newline at end of file diff --git a/riskslim/setup_functions.py b/riskslim/setup_functions.py index 731f1c5..0315897 100644 --- a/riskslim/setup_functions.py +++ b/riskslim/setup_functions.py @@ -95,7 +95,7 @@ def setup_loss_functions(data, coef_set, L0_max = None, loss_computation = None, L0_max = L0_max) - Z = np.require(Z, requirements=['F'], dtype = float) + Z = np.require(Z, requirements=['F'], dtype = np.float_) print_log("%d rows in lookup table" % (s_max - s_min + 1)) loss_value_tbl, prob_value_tbl, tbl_offset = get_loss_value_and_prob_tables(s_min, s_max) diff --git a/riskslim/solution_pool.py b/riskslim/solution_pool.py index e7ac26b..14ad42d 100644 --- a/riskslim/solution_pool.py +++ b/riskslim/solution_pool.py @@ -184,11 +184,11 @@ def sort(self): def map(self, mapfun, target = 'all'): assert callable(mapfun), 'map function must be callable' - if target is 'solutions': + if target == 'solutions': return list(map(mapfun, self.solutions)) - elif target is 'objvals': + elif target == 'objvals': return list(map(mapfun, self.objvals)) - elif target is 'all': + elif target == 'all': return list(map(mapfun, self.objvals, self.solutions)) else: raise ValueError('target must be either solutions, objvals, or all') @@ -314,4 +314,4 @@ def __repr__(self): def __str__(self): - return self.table() \ No newline at end of file + return self.table()