diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index c1b22bc3b4f..3c9342cbfbf 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -75,6 +75,7 @@ #include #include #include +#include #include #include @@ -102,6 +103,10 @@ DEFINE_string(model_path, "", "Speculator model.pte path."); DEFINE_string(data_path, "", "Tensor data (.ptd) path for the CUDA backend."); DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path."); DEFINE_string(prompt, "Explain why the sky is blue.", "Prompt text."); +DEFINE_string( + prompt_file, + "", + "Read the prompt text from this file instead of --prompt."); DEFINE_bool(raw_prompt, false, "Skip the Gemma 4 IT chat template."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id (-1 to skip; Gemma convention: 2)."); @@ -315,6 +320,16 @@ int main(int argc, char** argv) { } std::string prompt_text = FLAGS_prompt; + if (!FLAGS_prompt_file.empty()) { + std::ifstream f(FLAGS_prompt_file); + if (!f) { + ET_LOG( + Error, "Failed to open prompt file: %s", FLAGS_prompt_file.c_str()); + return 1; + } + prompt_text = std::string( + std::istreambuf_iterator(f), std::istreambuf_iterator()); + } if (!FLAGS_raw_prompt) { prompt_text = FLAGS_chat_prefix + prompt_text + FLAGS_chat_suffix; }