From f74c60aadeb713ebe4ee3ba570d619e7d4ae15a0 Mon Sep 17 00:00:00 2001
From: Sean Goedecke
Date: Thu, 8 May 2025 01:41:22 +0000
Subject: [PATCH] Support loading prompt from yml file

---
 cmd/run/run.go      | 60 ++++++++++++++++++++++++++++++++++++++
 cmd/run/run_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++++
 s.prompt.yml        | 14 +++++++++
 3 files changed, 144 insertions(+)
 create mode 100644 s.prompt.yml

diff --git a/cmd/run/run.go b/cmd/run/run.go
index 4c60f034..3c3c6964 100644
--- a/cmd/run/run.go
+++ b/cmd/run/run.go
@@ -21,6 +21,7 @@ import (
 	"github.com/github/gh-models/pkg/util"
 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"
+	"gopkg.in/yaml.v3"
 )
 
 // ModelParameters represents the parameters that can be set for a model run.
@@ -188,6 +189,22 @@ func isPipe(r io.Reader) bool {
 	return false
 }
 
+// promptFile mirrors the format of .prompt.yml
+type promptFile struct {
+	Name            string `yaml:"name"`
+	Description     string `yaml:"description"`
+	Model           string `yaml:"model"`
+	ModelParameters struct {
+		MaxTokens   *int     `yaml:"maxTokens"`
+		Temperature *float64 `yaml:"temperature"`
+		TopP        *float64 `yaml:"topP"`
+	} `yaml:"modelParameters"`
+	Messages []struct {
+		Role    string `yaml:"role"`
+		Content string `yaml:"content"`
+	} `yaml:"messages"`
+}
+
 // NewRunCommand returns a new gh command for running a model.
 func NewRunCommand(cfg *command.Config) *cobra.Command {
 	cmd := &cobra.Command{
@@ -208,6 +225,24 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 		Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"",
 		Args:    cobra.ArbitraryArgs,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			filePath, _ := cmd.Flags().GetString("file")
+			var pf *promptFile
+			if filePath != "" {
+				b, err := os.ReadFile(filePath)
+				if err != nil {
+					return err
+				}
+				p := promptFile{}
+				if err := yaml.Unmarshal(b, &p); err != nil {
+					return err
+				}
+				pf = &p
+				// Inject model name as the first positional arg if user didn't supply one
+				if pf.Model != "" && len(args) == 0 {
+					args = append([]string{pf.Model}, args...)
+				}
+			}
+
 			cmdHandler := newRunCommandHandler(cmd, cfg, args)
 			if cmdHandler == nil {
 				return nil
@@ -248,12 +283,36 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 				systemPrompt: systemPrompt,
 			}
 
+			// preload conversation & parameters from YAML
+			if pf != nil {
+				for _, m := range pf.Messages {
+					switch strings.ToLower(m.Role) {
+					case "system":
+						if conversation.systemPrompt == "" {
+							conversation.systemPrompt = m.Content
+						} else {
+							conversation.AddMessage(azuremodels.ChatMessageRoleSystem, m.Content)
+						}
+					case "user":
+						conversation.AddMessage(azuremodels.ChatMessageRoleUser, m.Content)
+					case "assistant":
+						conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, m.Content)
+					}
+				}
+			}
+
 			mp := ModelParameters{}
 			err = mp.PopulateFromFlags(cmd.Flags())
 			if err != nil {
 				return err
 			}
 
+			if pf != nil {
+				mp.maxTokens = pf.ModelParameters.MaxTokens
+				mp.temperature = pf.ModelParameters.Temperature
+				mp.topP = pf.ModelParameters.TopP
+			}
+
 			for {
 				prompt := ""
 				if initialPrompt != "" {
@@ -369,6 +428,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
 		},
 	}
 
+	cmd.Flags().String("file", "", "Path to a .prompt.yml file.")
 	cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.")
 	cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
 	cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")
diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go
index 5b88cfa6..ca20ca10 100644
--- a/cmd/run/run_test.go
+++ b/cmd/run/run_test.go
@@ -3,6 +3,7 @@ package run
 import (
 	"bytes"
 	"context"
+	"os"
 	"regexp"
 	"testing"
 
@@ -80,4 +81,73 @@ func TestRun(t *testing.T) {
 		require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output)
 		require.Empty(t, errBuf.String())
 	})
+
+	t.Run("--file pre-loads YAML from file", func(t *testing.T) {
+		const yamlBody = `
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/test-model
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer.
+  - role: user
+    content: Hello there!
+`
+		tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
+		require.NoError(t, err)
+		_, err = tmp.WriteString(yamlBody)
+		require.NoError(t, err)
+		require.NoError(t, tmp.Close())
+
+		client := azuremodels.NewMockClient()
+		modelSummary := &azuremodels.ModelSummary{
+			Name:      "test-model",
+			Publisher: "openai",
+			Task:      "chat-completion",
+		}
+		client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
+			return []*azuremodels.ModelSummary{modelSummary}, nil
+		}
+
+		var capturedReq azuremodels.ChatCompletionOptions
+		reply := "Summary - foo"
+		chatCompletion := azuremodels.ChatCompletion{
+			Choices: []azuremodels.ChatChoice{{
+				Message: &azuremodels.ChatChoiceMessage{
+					Content: util.Ptr(reply),
+					Role:    util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
+				},
+			}},
+		}
+		client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
+			capturedReq = opt
+			return &azuremodels.ChatCompletionResponse{
+				Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
+			}, nil
+		}
+
+		out := new(bytes.Buffer)
+		cfg := command.NewConfig(out, out, client, true, 100)
+		runCmd := NewRunCommand(cfg)
+		runCmd.SetArgs([]string{
+			"--file", tmp.Name(),
+			azuremodels.FormatIdentifier("openai", "test-model"),
+			"foo?",
+		})
+
+		_, err = runCmd.ExecuteC()
+		require.NoError(t, err)
+
+		require.Equal(t, 3, len(capturedReq.Messages))
+		require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
+		require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
+		require.Equal(t, "foo?", *capturedReq.Messages[2].Content)
+
+		require.NotNil(t, capturedReq.Temperature)
+		require.Equal(t, 0.5, *capturedReq.Temperature)
+
+		require.Contains(t, out.String(), reply) // response streamed to output
+	})
 }
diff --git a/s.prompt.yml b/s.prompt.yml
new file mode 100644
index 00000000..b8b577f2
--- /dev/null
+++ b/s.prompt.yml
@@ -0,0 +1,14 @@
+name: Text Summarizer
+description: Summarizes input text concisely
+model: openai/gpt-4o-mini
+modelParameters:
+  temperature: 0.5
+messages:
+  - role: system
+    content: You are a text summarizer. Your only job is to summarize text given to you.
+  - role: user
+    content: |
+      Summarize the given text, beginning with "Summary -":
+
+      {{input}}
+
\ No newline at end of file
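
Usage note (not part of the patch): given the --file flag and the sample s.prompt.yml added above, the feature would presumably be invoked as sketched below. Per the RunE logic in the patch, the file's model: field is injected as the model argument only when no positional arguments are given; otherwise the first positional argument is still treated as the model name and any further arguments form the initial prompt, exactly as the test exercises. Note the patch does not substitute the {{input}} template placeholder; that content is sent as-is.

    # Model, parameters, and seed messages all come from the YAML file
    gh models run --file s.prompt.yml

    # Or name the model explicitly and append an initial user prompt,
    # mirroring the arguments used in the test above
    gh models run --file s.prompt.yml openai/gpt-4o-mini "Text to summarize"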