Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions atomgpt/inverse_models/inverse_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ase.constraints import ExpCellFilter
import time
from jarvis.core.atoms import ase_to_atoms
import re # CHANGE: used to detect formula-like filenames

parser = argparse.ArgumentParser(
description="Atomistic Generative Pre-trained Transformer"
Expand Down Expand Up @@ -111,6 +112,9 @@ def predict(
load_in_4bit=None, # temp_config["load_in_4bit"]
verbose=True, # temp_config["load_in_4bit"]
):
# Track whether the user provided a formula via CLI
user_provided_formula = formula is not None

print("config_path", config_path)
if output_dir is not None:
config_name = os.path.join(output_dir, "config.json")
Expand Down Expand Up @@ -144,6 +148,8 @@ def predict(
print("Model used:", model_name)
print("config used:", config_path)
print("formula:", formula)
print("dat_path:", dat_path)

model = None
tokenizer = None
try:
Expand All @@ -163,20 +169,28 @@ def predict(
model_name, gguf_file=filename
)
pass

atoms_arr = []
lines = []

# If formula isn't given but dat_path is, run single-file inference (no pred_csv needed)
if formula is None:
# if dat_path is None:
f = open(pred_csv, "r")
lines = f.read().splitlines()
f.close()
if dat_path is not None:
lines = [dat_path]
else:
f = open(pred_csv, "r")
lines = f.read().splitlines()
f.close()
else:
if dat_path is not None:
lines = [dat_path]
lines = [formula]

mem = []

# CHANGE: simple heuristic for "formula-like" strings (e.g., LaB6, SiO2)
formula_like_re = re.compile(r"^([A-Z][a-z]?\d*)+$")

for i in lines:
prompt = i
if ".dat" in i or dat_path is not None:
Expand All @@ -185,6 +199,11 @@ def predict(
fname_csv = os.path.join(parent, i)
else:
fname_csv = dat_path

# Determine whether the DAT filename itself looks like a chemical formula
fname_stem = Path(fname_csv).stem
filename_looks_like_formula = bool(formula_like_re.fullmatch(fname_stem))

_formula, x, y = load_exp_file(
filename=fname_csv,
intvl=intvl,
Expand All @@ -198,17 +217,40 @@ def predict(
formula = str(_formula.split("/")[-1].split(".dat")[0])
except Exception:
pass
prompt = (
"The chemical formula is "
+ formula
+ " The "
+ temp_config["prop"]
+ " is "
+ y_new_str
+ ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types."
)

# CHANGE: inject formula sentence only if:
# - user provided --formula, OR
# - DAT filename itself is a formula (e.g., LaB6.dat)
if user_provided_formula and formula is not None:
prompt = (
"The chemical formula is "
+ formula
+ " The "
+ temp_config["prop"]
+ " is "
+ y_new_str
+ ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types."
)
elif (not user_provided_formula) and filename_looks_like_formula:
prompt = (
"The chemical formula is "
+ fname_stem
+ " The "
+ temp_config["prop"]
+ " is "
+ y_new_str
+ ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types."
)
else:
prompt = (
"The "
+ temp_config["prop"]
+ " is "
+ y_new_str
+ ". Generate atomic structure description with lattice lengths, angles, coordinates and atom types."
)
else:
if formula is not None:
if user_provided_formula and formula is not None:
prompt = (
"The chemical formula is "
+ formula
Expand Down Expand Up @@ -239,7 +281,6 @@ def predict(
info["prompt"] = prompt
info["error"] = "Invalid structure returned by AtomGPT (None)."
mem.append(info)
# skip the rest of the loop for this entry
continue

if verbose:
Expand All @@ -255,6 +296,7 @@ def predict(
info["prompt"] = prompt
info["atoms"] = gen_mat.to_dict()
mem.append(info)

dumpjson(data=mem, filename=fname)
return model, tokenizer, temp_config

Expand All @@ -273,3 +315,4 @@ def predict(
prop_val=args.prop_val,
background_subs=args.background_subs,
)

Loading