From f31b8dd19fb285bccd4b85d7938186adafdf4788 Mon Sep 17 00:00:00 2001 From: ccamp104 Date: Fri, 2 Jan 2026 17:22:15 -0500 Subject: [PATCH] if dat_file is inputted and formula is not, have formula not be injected into prompt if filename is not a formula --- atomgpt/inverse_models/inverse_predict.py | 73 ++++++++++++++++++----- 1 file changed, 58 insertions(+), 15 deletions(-) diff --git a/atomgpt/inverse_models/inverse_predict.py b/atomgpt/inverse_models/inverse_predict.py index 63054cf..da46085 100644 --- a/atomgpt/inverse_models/inverse_predict.py +++ b/atomgpt/inverse_models/inverse_predict.py @@ -12,6 +12,7 @@ from ase.constraints import ExpCellFilter import time from jarvis.core.atoms import ase_to_atoms +import re # CHANGE: used to detect formula-like filenames parser = argparse.ArgumentParser( description="Atomistic Generative Pre-trained Transformer" @@ -111,6 +112,9 @@ def predict( load_in_4bit=None, # temp_config["load_in_4bit"] verbose=True, # temp_config["load_in_4bit"] ): + # Track whether the user provided a formula via CLI + user_provided_formula = formula is not None + print("config_path", config_path) if output_dir is not None: config_name = os.path.join(output_dir, "config.json") @@ -144,6 +148,8 @@ def predict( print("Model used:", model_name) print("config used:", config_path) print("formula:", formula) + print("dat_path:", dat_path) + model = None tokenizer = None try: @@ -163,13 +169,18 @@ def predict( model_name, gguf_file=filename ) pass + atoms_arr = [] lines = [] + + # If formula isn't given but dat_path is, run single-file inference (no pred_csv needed) if formula is None: - # if dat_path is None: - f = open(pred_csv, "r") - lines = f.read().splitlines() - f.close() + if dat_path is not None: + lines = [dat_path] + else: + f = open(pred_csv, "r") + lines = f.read().splitlines() + f.close() else: if dat_path is not None: lines = [dat_path] @@ -177,6 +188,9 @@ def predict( mem = [] + # CHANGE: simple heuristic for "formula-like" strings (e.g., LaB6, SiO2) + formula_like_re = re.compile(r"^([A-Z][a-z]?\d*)+$") + for i in lines: prompt = i if ".dat" in i or dat_path is not None: @@ -185,6 +199,11 @@ def predict( fname_csv = os.path.join(parent, i) else: fname_csv = dat_path + + # Determine whether the DAT filename itself looks like a chemical formula + fname_stem = Path(fname_csv).stem + filename_looks_like_formula = bool(formula_like_re.fullmatch(fname_stem)) + _formula, x, y = load_exp_file( filename=fname_csv, intvl=intvl, @@ -198,17 +217,40 @@ def predict( formula = str(_formula.split("/")[-1].split(".dat")[0]) except Exception: pass - prompt = ( - "The chemical formula is " - + formula - + " The " - + temp_config["prop"] - + " is " - + y_new_str - + ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types." - ) + + # CHANGE: inject formula sentence only if: + # - user provided --formula, OR + # - DAT filename itself is a formula (e.g., LaB6.dat) + if user_provided_formula and formula is not None: + prompt = ( + "The chemical formula is " + + formula + + " The " + + temp_config["prop"] + + " is " + + y_new_str + + ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types." + ) + elif (not user_provided_formula) and filename_looks_like_formula: + prompt = ( + "The chemical formula is " + + fname_stem + + " The " + + temp_config["prop"] + + " is " + + y_new_str + + ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types." + ) + else: + prompt = ( + "The " + + temp_config["prop"] + + " is " + + y_new_str + + ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types." + ) else: - if formula is not None: + if user_provided_formula and formula is not None: prompt = ( "The chemical formula is " + formula @@ -239,7 +281,6 @@ def predict( info["prompt"] = prompt info["error"] = "Invalid structure returned by AtomGPT (None)." mem.append(info) - # skip the rest of the loop for this entry continue if verbose: @@ -255,6 +296,7 @@ def predict( info["prompt"] = prompt info["atoms"] = gen_mat.to_dict() mem.append(info) + dumpjson(data=mem, filename=fname) return model, tokenizer, temp_config @@ -273,3 +315,4 @@ def predict( prop_val=args.prop_val, background_subs=args.background_subs, ) +