diff --git a/src/main.py b/src/main.py index 3f3e366..79d669c 100644 --- a/src/main.py +++ b/src/main.py @@ -90,6 +90,10 @@ def get_bead_spread(i, coords, mass, radius, grid, voxel_size, n_breaks): return spread def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): + + if output_dir is None: + output_dir = args.output + models = round(args.models*coords.shape[0]) if args.models != 1: selected_models = np.random.choice(coords.shape[0], models, replace=False) @@ -116,15 +120,11 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): print('Bead Spread calculation done') # If not specified create a default output directory. - if not os.path.exists(args.output) and output_dir == None: - os.makedirs(args.output) - else: - args.output = output_dir - os.makedirs(args.output) + os.makedirs(output_dir, exist_ok=True) # Save the bead_spread values. if args.return_spread == 1: - with open(args.output + "/bead_spreads_cl" + str(args.classes) + ".txt", "w") as fl: + with open(output_dir + "/bead_spreads_cl" + str(args.classes) + ".txt", "w") as fl: for bs in bead_spread: fl.write('{:0.3f}'.format(bs)) fl.write("\n") @@ -138,9 +138,9 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): annot_df = pd.DataFrame(np.array(annotated_patches), columns = ['Bead', 'Bead Name', 'Type', 'Class', 'Patch']) annot_df['Patch'] = pd.to_numeric(annot_df["Patch"]) annot_df.sort_values(['Patch'], ascending=[True]) - annot_df.to_csv(args.output + '/annotations_cl' + str(args.classes) + '.txt', index=None) + annot_df.to_csv(output_dir + '/annotations_cl' + str(args.classes) + '.txt', index=None) - with open(args.output + "/low_prec_cl" + str(args.classes) + ".txt", "w") as fl: + with open(output_dir + "/low_prec_cl" + str(args.classes) + ".txt", "w") as fl: lev = 1 fl.write("Level" + "\t" + "Bead Indices" + "\t" + "Bead Names") fl.write("\n") @@ -154,7 +154,7 @@ def run_prism( coords, mass, radius, ps_names, args, output_dir = None ): fl.write("\n") lev=lev+1 - with open(args.output + "/high_prec_cl" + str(args.classes) + ".txt", "w") as fl: + with open(output_dir + "/high_prec_cl" + str(args.classes) + ".txt", "w") as fl: lev = 1 fl.write("Level" + "\t" + "Bead Indices" + "\t" + "Bead Names") fl.write("\n")