From 0cb15367e355b594704c2f351b17dbb2c336bf62 Mon Sep 17 00:00:00 2001 From: Emma Garrett Date: Mon, 9 Dec 2024 22:54:46 -0500 Subject: [PATCH 1/2] script to call mlx with google flan t5 --- ml_mdm/language_models/call_mlx_lm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 ml_mdm/language_models/call_mlx_lm.py diff --git a/ml_mdm/language_models/call_mlx_lm.py b/ml_mdm/language_models/call_mlx_lm.py new file mode 100644 index 0000000..367c6a3 --- /dev/null +++ b/ml_mdm/language_models/call_mlx_lm.py @@ -0,0 +1,12 @@ +from typing import Optional +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +def call_mlx_lm(input_text: str, model_name: Optional[str] = "google/flan-t5-base") -> str: + tokenizer = T5Tokenizer.from_pretrained(model_name) + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + + input_ids = tokenizer.encode(input_text, return_tensors="pt") + output_ids = model.generate(input_ids) + + output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) + return output_text \ No newline at end of file From b4984a0a57babb0ccc309f151a64ca58f2998a29 Mon Sep 17 00:00:00 2001 From: Emma Garrett Date: Mon, 9 Dec 2024 23:01:25 -0500 Subject: [PATCH 2/2] added command line arguments --- ml_mdm/language_models/call_mlx_lm.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/ml_mdm/language_models/call_mlx_lm.py b/ml_mdm/language_models/call_mlx_lm.py index 367c6a3..2bedf83 100644 --- a/ml_mdm/language_models/call_mlx_lm.py +++ b/ml_mdm/language_models/call_mlx_lm.py @@ -1,12 +1,23 @@ from typing import Optional from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +import argparse def call_mlx_lm(input_text: str, model_name: Optional[str] = "google/flan-t5-base") -> str: - tokenizer = T5Tokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForSeq2SeqLM.from_pretrained(model_name) input_ids = tokenizer.encode(input_text, return_tensors="pt") output_ids = model.generate(input_ids) output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) - return output_text \ No newline at end of file + return output_text + +# command line arguments +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--input_text", type=str, required=True) + parser.add_argument("--model", type=str, default="google/flan-t5-base") + args = parser.parse_args() + + output_text = call_mlx_lm(args.input_text, args.model) + print(output_text) \ No newline at end of file