// TokenCount holds the outcome of a token-counting request for a prompt.
type TokenCount struct {
	// InputTokens is the number of tokens counted for the input prompt.
	// It is serialized under the JSON key "input_tokens".
	InputTokens int `json:"input_tokens"`
}
+type TokenCounter interface { + // CountTokens counts the number of tokens in the given messages. + // Returns the token count or an error if counting fails. + CountTokens(ctx context.Context, prompt []Message, opts CallOptions) (*TokenCount, error) +} + diff --git a/aisdk/ai/count_tokens.go b/aisdk/ai/count_tokens.go new file mode 100644 index 00000000..0de25e53 --- /dev/null +++ b/aisdk/ai/count_tokens.go @@ -0,0 +1,74 @@ +package ai + +import ( + "context" + "fmt" + + "go.jetify.com/ai/api" +) + +// CountTokens counts the number of tokens in the given messages using the specified model. +// +// This function is useful for estimating costs and checking whether a prompt +// fits within the model's context window before making a generation request. +// +// Example usage: +// +// count, err := ai.CountTokens(ctx, messages, ai.WithModel(model)) +// if err != nil { +// // Handle error - model may not support token counting +// } +// fmt.Printf("Token count: %d\n", count.InputTokens) +// +// The function accepts [GenerateOption] arguments: +// +// ai.CountTokens(ctx, messages, ai.WithModel(model), ai.WithTools(tools...)) +func CountTokens(ctx context.Context, prompt []api.Message, opts ...GenerateOption) (*api.TokenCount, error) { + config := buildGenerateConfig(opts) + return countTokens(ctx, prompt, config) +} + +// CountTokensStr is a convenience function for counting tokens in a simple string. +// +// Example usage: +// +// count, err := ai.CountTokensStr(ctx, "Hello, world!", ai.WithModel(model)) +// if err != nil { +// // Handle error +// } +// fmt.Printf("Token count: %d\n", count.InputTokens) +// +// The string is automatically converted to a [api.UserMessage] before counting. +func CountTokensStr(ctx context.Context, text string, opts ...GenerateOption) (*api.TokenCount, error) { + msg := &api.UserMessage{ + Content: []api.ContentBlock{&api.TextBlock{Text: text}}, + } + return CountTokens(ctx, []api.Message{msg}, opts...) 
+} + +func countTokens(ctx context.Context, prompt []api.Message, opts GenerateOptions) (*api.TokenCount, error) { + // Check if the model implements TokenCounter + counter, ok := opts.Model.(api.TokenCounter) + if !ok { + return nil, api.NewUnsupportedFunctionalityError( + "token counting", + fmt.Sprintf("model %q does not support token counting", opts.Model.ModelID()), + ) + } + return counter.CountTokens(ctx, prompt, opts.CallOptions) +} + +func getEncodingForModel(modelID string) string { + if len(modelID) == 0 { + return "cl100k_base" + } + + if modelID == "gpt-4o" || modelID == "gpt-4o-mini" || modelID == "gpt-5" || + modelID == "o1-preview" || modelID == "o3-mini" || + (len(modelID) > 0 && modelID[0] == 'o') { + return "o200k_base" + } + + return "cl100k_base" +} + diff --git a/aisdk/ai/count_tokens_test.go b/aisdk/ai/count_tokens_test.go new file mode 100644 index 00000000..951091a2 --- /dev/null +++ b/aisdk/ai/count_tokens_test.go @@ -0,0 +1,61 @@ +package ai + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.jetify.com/ai/api" +) + +func TestCountTokens_UnsupportedModel(t *testing.T) { + ctx := context.Background() + messages := []api.Message{ + &api.UserMessage{ + Content: []api.ContentBlock{&api.TextBlock{Text: "Hello, world!"}}, + }, + } + + model := &mockLanguageModel{name: "unsupported-model"} + count, err := CountTokens(ctx, messages, WithModel(model)) + + assert.Nil(t, count) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token counting") + assert.Contains(t, err.Error(), "unsupported-model") +} + +func TestCountTokensStr_UnsupportedModel(t *testing.T) { + ctx := context.Background() + model := &mockLanguageModel{name: "unsupported-model"} + count, err := CountTokensStr(ctx, "Hello, world!", WithModel(model)) + + assert.Nil(t, count) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token counting") +} + +func TestGetEncodingForModel(t *testing.T) { + tests := []struct { + modelID string + expected 
string + }{ + {"gpt-4o", "o200k_base"}, + {"gpt-4o-mini", "o200k_base"}, + {"gpt-5", "o200k_base"}, + {"o1-preview", "o200k_base"}, + {"o3-mini", "o200k_base"}, + {"gpt-4", "cl100k_base"}, + {"gpt-4-turbo", "cl100k_base"}, + {"gpt-3.5-turbo", "cl100k_base"}, + {"unknown-model", "cl100k_base"}, + } + + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + result := getEncodingForModel(tt.modelID) + assert.Equal(t, tt.expected, result) + }) + } +} + diff --git a/aisdk/ai/examples/count_tokens/.env.example b/aisdk/ai/examples/count_tokens/.env.example new file mode 100644 index 00000000..daaa3741 --- /dev/null +++ b/aisdk/ai/examples/count_tokens/.env.example @@ -0,0 +1 @@ +ANTHROPIC_API_KEY="your-api-key-here" \ No newline at end of file diff --git a/aisdk/ai/examples/count_tokens/go.mod b/aisdk/ai/examples/count_tokens/go.mod new file mode 100644 index 00000000..84a79a93 --- /dev/null +++ b/aisdk/ai/examples/count_tokens/go.mod @@ -0,0 +1,20 @@ +module go.jetify.com/ai/examples/count_tokens + +go 1.24.0 + +require ( + github.com/anthropics/anthropic-sdk-go v1.19.0 + github.com/joho/godotenv v1.5.1 + go.jetify.com/ai v0.0.0 +) + +require ( + github.com/google/jsonschema-go v0.3.0 // indirect + github.com/openai/openai-go/v2 v2.7.1 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect +) + +replace go.jetify.com/ai => ../.. 
diff --git a/aisdk/ai/examples/count_tokens/go.sum b/aisdk/ai/examples/count_tokens/go.sum new file mode 100644 index 00000000..218429c2 --- /dev/null +++ b/aisdk/ai/examples/count_tokens/go.sum @@ -0,0 +1,37 @@ +github.com/anthropics/anthropic-sdk-go v1.19.0 h1:mO6E+ffSzLRvR/YUH9KJC0uGw0uV8GjISIuzem//3KE= +github.com/anthropics/anthropic-sdk-go v1.19.0/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q= +github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/openai/openai-go/v2 v2.7.1 h1:/tfvTJhfv7hTSL8mWwc5VL4WLLSDL5yn9VqVykdu9r8= +github.com/openai/openai-go/v2 v2.7.1/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 
+github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +go.jetify.com/pkg v0.0.0-20251201231142-abe4fc632859 h1:opdRo9847AH1/OmuXvWQUSO3gfnrfl7QaeS8dC3UYwg= +go.jetify.com/pkg v0.0.0-20251201231142-abe4fc632859/go.mod h1:qR6Mz3JVuEXEINbNIoDCMpKgkNG69mtCbDKbu4iB1GM= +go.jetify.com/sse v0.1.0 h1:zLIT5XFlUVuTl68bHalpFDYbfSfXJPkmAbtmBqIHl2Q= +go.jetify.com/sse v0.1.0/go.mod h1:zFADPn3Z0aZJe3+PbArGMGwe3oTwHxPZIwNILoRCmU8= +go.yaml.in/yaml/v4 v4.0.0-rc.3 h1:3h1fjsh1CTAPjW7q/EMe+C8shx5d8ctzZTrLcs/j8Go= +go.yaml.in/yaml/v4 v4.0.0-rc.3/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0= +gopkg.in/dnaeon/go-vcr.v4 v4.0.6 h1:PiJkrakkmzc5s7EfBnZOnyiLwi7o7A9fwPzN0X2uwe0= +gopkg.in/dnaeon/go-vcr.v4 v4.0.6/go.mod h1:sbq5oMEcM4PXngbcNbHhzfCP9OdZodLhrbRYoyg09HY= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/aisdk/ai/examples/count_tokens/main.go b/aisdk/ai/examples/count_tokens/main.go new file mode 100644 index 00000000..81173996 --- /dev/null +++ b/aisdk/ai/examples/count_tokens/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + anthropicsdk "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/joho/godotenv" + "go.jetify.com/ai" + "go.jetify.com/ai/api" + "go.jetify.com/ai/provider/anthropic" +) + +func main() { + if err := 
godotenv.Load(); err != nil { + log.Println("No .env file found, using environment variables") + } + + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + log.Fatal("ANTHROPIC_API_KEY environment variable is required") + } + + ctx := context.Background() + + client := anthropicsdk.NewClient(option.WithAPIKey(apiKey)) + model := anthropic.NewLanguageModel( + "claude-sonnet-4-20250514", + anthropic.WithClient(client), + ) + + // Test with a simple string + fmt.Println("=== CountTokensStr ===") + text := "Hello, world! How are you today?" + count, err := ai.CountTokensStr(ctx, text, ai.WithModel(model)) + if err != nil { + log.Fatalf("CountTokensStr error: %v", err) + } + fmt.Printf("Model: %s\n", model.ModelID()) + fmt.Printf("Message: %q\n", text) + fmt.Printf("Input tokens: %d\n\n", count.InputTokens) + + // Test with multiple messages + fmt.Println("=== CountTokens with multiple messages ===") + messages := []api.Message{ + &api.SystemMessage{Content: "You are a helpful assistant."}, + &api.UserMessage{ + Content: []api.ContentBlock{ + &api.TextBlock{Text: "What is the capital of France?"}, + }, + }, + &api.AssistantMessage{ + Content: []api.ContentBlock{ + &api.TextBlock{Text: "The capital of France is Paris."}, + }, + }, + &api.UserMessage{ + Content: []api.ContentBlock{ + &api.TextBlock{Text: "What about Germany?"}, + }, + }, + } + + count, err = ai.CountTokens(ctx, messages, ai.WithModel(model)) + if err != nil { + log.Fatalf("CountTokens error: %v", err) + } + fmt.Printf("Model: %s\n", model.ModelID()) + fmt.Println("Messages:") + fmt.Println(" [system] You are a helpful assistant.") + fmt.Println(" [user] What is the capital of France?") + fmt.Println(" [assistant] The capital of France is Paris.") + fmt.Println(" [user] What about Germany?") + fmt.Printf("Input tokens: %d\n", count.InputTokens) +} + diff --git a/aisdk/ai/provider/anthropic/llm.go b/aisdk/ai/provider/anthropic/llm.go index e7623c0a..0827ccbb 100644 --- 
a/aisdk/ai/provider/anthropic/llm.go +++ b/aisdk/ai/provider/anthropic/llm.go @@ -27,6 +27,7 @@ type LanguageModel struct { } var _ api.LanguageModel = &LanguageModel{} +var _ api.TokenCounter = &LanguageModel{} // NewLanguageModel creates a new Anthropic language model. func NewLanguageModel(modelID string, opts ...ModelOption) *LanguageModel { @@ -91,3 +92,86 @@ func (m *LanguageModel) Stream( ) (*api.StreamResponse, error) { return nil, api.NewUnsupportedFunctionalityError("streaming generation", "") } + +func (m *LanguageModel) CountTokens( + ctx context.Context, prompt []api.Message, opts api.CallOptions, +) (*api.TokenCount, error) { + anthropicPrompt, err := codec.EncodePrompt(prompt) + if err != nil { + return nil, err + } + + params := anthropic.BetaMessageCountTokensParams{ + Model: anthropic.Model(m.modelID), + } + + if len(anthropicPrompt.System) > 0 { + params.System = anthropic.BetaMessageCountTokensParamsSystemUnion{ + OfBetaTextBlockArray: anthropicPrompt.System, + } + } + if len(anthropicPrompt.Messages) > 0 { + params.Messages = anthropicPrompt.Messages + } + + if len(opts.Tools) > 0 { + tools, err := codec.EncodeTools(opts.Tools, opts.ToolChoice) + if err != nil { + return nil, err + } + if len(tools.Tools) > 0 { + countTokensTools := make([]anthropic.BetaMessageCountTokensParamsToolUnion, len(tools.Tools)) + for i, tool := range tools.Tools { + countTokensTool := anthropic.BetaMessageCountTokensParamsToolUnion{} + if tool.OfTool != nil { + countTokensTool.OfTool = tool.OfTool + } else if tool.OfBashTool20241022 != nil { + countTokensTool.OfBashTool20241022 = tool.OfBashTool20241022 + } else if tool.OfBashTool20250124 != nil { + countTokensTool.OfBashTool20250124 = tool.OfBashTool20250124 + } else if tool.OfCodeExecutionTool20250522 != nil { + countTokensTool.OfCodeExecutionTool20250522 = tool.OfCodeExecutionTool20250522 + } else if tool.OfCodeExecutionTool20250825 != nil { + countTokensTool.OfCodeExecutionTool20250825 = 
tool.OfCodeExecutionTool20250825 + } else if tool.OfComputerUseTool20241022 != nil { + countTokensTool.OfComputerUseTool20241022 = tool.OfComputerUseTool20241022 + } else if tool.OfMemoryTool20250818 != nil { + countTokensTool.OfMemoryTool20250818 = tool.OfMemoryTool20250818 + } else if tool.OfComputerUseTool20250124 != nil { + countTokensTool.OfComputerUseTool20250124 = tool.OfComputerUseTool20250124 + } else if tool.OfTextEditor20241022 != nil { + countTokensTool.OfTextEditor20241022 = tool.OfTextEditor20241022 + } else if tool.OfComputerUseTool20251124 != nil { + countTokensTool.OfComputerUseTool20251124 = tool.OfComputerUseTool20251124 + } else if tool.OfTextEditor20250124 != nil { + countTokensTool.OfTextEditor20250124 = tool.OfTextEditor20250124 + } else if tool.OfTextEditor20250429 != nil { + countTokensTool.OfTextEditor20250429 = tool.OfTextEditor20250429 + } else if tool.OfTextEditor20250728 != nil { + countTokensTool.OfTextEditor20250728 = tool.OfTextEditor20250728 + } else if tool.OfWebSearchTool20250305 != nil { + countTokensTool.OfWebSearchTool20250305 = tool.OfWebSearchTool20250305 + } else if tool.OfWebFetchTool20250910 != nil { + countTokensTool.OfWebFetchTool20250910 = tool.OfWebFetchTool20250910 + } else if tool.OfToolSearchToolBm25_20251119 != nil { + countTokensTool.OfToolSearchToolBm25_20251119 = tool.OfToolSearchToolBm25_20251119 + } else if tool.OfToolSearchToolRegex20251119 != nil { + countTokensTool.OfToolSearchToolRegex20251119 = tool.OfToolSearchToolRegex20251119 + } else if tool.OfMCPToolset != nil { + countTokensTool.OfMCPToolset = tool.OfMCPToolset + } + countTokensTools[i] = countTokensTool + } + params.Tools = countTokensTools + } + } + + result, err := m.client.Beta.Messages.CountTokens(ctx, params) + if err != nil { + return nil, err + } + + return &api.TokenCount{ + InputTokens: int(result.InputTokens), + }, nil +}