Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/github/gh-models/pkg/util"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
)

// ModelParameters represents the parameters that can be set for a model run.
Expand Down Expand Up @@ -188,6 +189,22 @@ func isPipe(r io.Reader) bool {
return false
}

// promptFile mirrors the format of .prompt.yml
type promptFile struct {
Name string `yaml:"name"`
Comment thread
sgoedecke marked this conversation as resolved.
Description string `yaml:"description"`
Model string `yaml:"model"`
ModelParameters struct {
MaxTokens *int `yaml:"maxTokens"`
Temperature *float64 `yaml:"temperature"`
TopP *float64 `yaml:"topP"`
} `yaml:"modelParameters"`
Messages []struct {
Role string `yaml:"role"`
Content string `yaml:"content"`
} `yaml:"messages"`
}
Comment thread
sgoedecke marked this conversation as resolved.

// NewRunCommand returns a new gh command for running a model.
func NewRunCommand(cfg *command.Config) *cobra.Command {
cmd := &cobra.Command{
Expand All @@ -208,6 +225,24 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
Example: "gh models run openai/gpt-4o-mini \"how many types of hyena are there?\"",
Args: cobra.ArbitraryArgs,
RunE: func(cmd *cobra.Command, args []string) error {
filePath, _ := cmd.Flags().GetString("file")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That _, is that an error? Should you maybe check it and return early if it's non-nil?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's probably wise to do that, but I think it's safe to follow up in a later change. The possible errors here are things like "you defined the file flag as an integer but are trying to GetString it", not really things that users could trigger.

var pf *promptFile
if filePath != "" {
b, err := os.ReadFile(filePath)
if err != nil {
return err
}
p := promptFile{}
if err := yaml.Unmarshal(b, &p); err != nil {
return err
}
pf = &p
// Inject model name as the first positional arg if user didn't supply one
if pf.Model != "" && len(args) == 0 {
args = append([]string{pf.Model}, args...)
}
}

cmdHandler := newRunCommandHandler(cmd, cfg, args)
if cmdHandler == nil {
return nil
Expand Down Expand Up @@ -248,12 +283,36 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
systemPrompt: systemPrompt,
}

// preload conversation & parameters from YAML
if pf != nil {
for _, m := range pf.Messages {
switch strings.ToLower(m.Role) {
case "system":
if conversation.systemPrompt == "" {
conversation.systemPrompt = m.Content
} else {
conversation.AddMessage(azuremodels.ChatMessageRoleSystem, m.Content)
}
case "user":
conversation.AddMessage(azuremodels.ChatMessageRoleUser, m.Content)
case "assistant":
conversation.AddMessage(azuremodels.ChatMessageRoleAssistant, m.Content)
}
}
}

mp := ModelParameters{}
err = mp.PopulateFromFlags(cmd.Flags())
if err != nil {
return err
}

if pf != nil {
mp.maxTokens = pf.ModelParameters.MaxTokens
mp.temperature = pf.ModelParameters.Temperature
mp.topP = pf.ModelParameters.TopP
}

Comment thread
sgoedecke marked this conversation as resolved.
for {
prompt := ""
if initialPrompt != "" {
Expand Down Expand Up @@ -369,6 +428,7 @@ func NewRunCommand(cfg *command.Config) *cobra.Command {
},
}

cmd.Flags().String("file", "", "Path to a .prompt.yml file.")
cmd.Flags().String("max-tokens", "", "Limit the maximum tokens for the model response.")
cmd.Flags().String("temperature", "", "Controls randomness in the response, use lower to be more deterministic.")
cmd.Flags().String("top-p", "", "Controls text diversity by selecting the most probable words until a set probability is reached.")
Expand Down
70 changes: 70 additions & 0 deletions cmd/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package run
import (
"bytes"
"context"
"os"
"regexp"
"testing"

Expand Down Expand Up @@ -80,4 +81,73 @@ func TestRun(t *testing.T) {
require.Regexp(t, regexp.MustCompile(`--top-p string\s+Controls text diversity by selecting the most probable words until a set probability is reached\.`), output)
require.Empty(t, errBuf.String())
})

t.Run("--file pre-loads YAML from file", func(t *testing.T) {
const yamlBody = `
name: Text Summarizer
description: Summarizes input text concisely
model: openai/test-model
modelParameters:
temperature: 0.5
messages:
- role: system
content: You are a text summarizer.
- role: user
content: Hello there!
`
tmp, err := os.CreateTemp(t.TempDir(), "*.prompt.yml")
require.NoError(t, err)
_, err = tmp.WriteString(yamlBody)
require.NoError(t, err)
require.NoError(t, tmp.Close())

client := azuremodels.NewMockClient()
modelSummary := &azuremodels.ModelSummary{
Name: "test-model",
Publisher: "openai",
Task: "chat-completion",
}
client.MockListModels = func(ctx context.Context) ([]*azuremodels.ModelSummary, error) {
return []*azuremodels.ModelSummary{modelSummary}, nil
}

var capturedReq azuremodels.ChatCompletionOptions
reply := "Summary - foo"
chatCompletion := azuremodels.ChatCompletion{
Choices: []azuremodels.ChatChoice{{
Message: &azuremodels.ChatChoiceMessage{
Content: util.Ptr(reply),
Role: util.Ptr(string(azuremodels.ChatMessageRoleAssistant)),
},
}},
}
client.MockGetChatCompletionStream = func(ctx context.Context, opt azuremodels.ChatCompletionOptions) (*azuremodels.ChatCompletionResponse, error) {
capturedReq = opt
return &azuremodels.ChatCompletionResponse{
Reader: sse.NewMockEventReader([]azuremodels.ChatCompletion{chatCompletion}),
}, nil
}

out := new(bytes.Buffer)
cfg := command.NewConfig(out, out, client, true, 100)
runCmd := NewRunCommand(cfg)
runCmd.SetArgs([]string{
"--file", tmp.Name(),
azuremodels.FormatIdentifier("openai", "test-model"),
"foo?",
})

_, err = runCmd.ExecuteC()
require.NoError(t, err)

require.Equal(t, 3, len(capturedReq.Messages))
require.Equal(t, "You are a text summarizer.", *capturedReq.Messages[0].Content)
require.Equal(t, "Hello there!", *capturedReq.Messages[1].Content)
require.Equal(t, "foo?", *capturedReq.Messages[2].Content)

require.NotNil(t, capturedReq.Temperature)
require.Equal(t, 0.5, *capturedReq.Temperature)

require.Contains(t, out.String(), reply) // response streamed to output
})
}
14 changes: 14 additions & 0 deletions s.prompt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: Text Summarizer
description: Summarizes input text concisely
model: openai/gpt-4o-mini
modelParameters:
temperature: 0.5
messages:
- role: system
content: You are a text summarizer. Your only job is to summarize text given to you.
- role: user
content: |
Summarize the given text, beginning with "Summary -":
<text>
{{input}}
</text>
Loading