Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/swagger/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,9 @@ const docTemplate = `{
"id": {
"type": "string"
},
"one_shot": {
"type": "boolean"
},
"rate_limit": {
"type": "integer"
},
Expand All @@ -1256,6 +1259,9 @@ const docTemplate = `{
},
"target": {
"type": "string"
},
"used": {
"type": "boolean"
}
}
},
Expand Down
6 changes: 6 additions & 0 deletions docs/swagger/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1239,6 +1239,9 @@
"id": {
"type": "string"
},
"one_shot": {
"type": "boolean"
},
"rate_limit": {
"type": "integer"
},
Expand All @@ -1250,6 +1253,9 @@
},
"target": {
"type": "string"
},
"used": {
"type": "boolean"
}
}
},
Expand Down
4 changes: 4 additions & 0 deletions docs/swagger/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ definitions:
type: string
id:
type: string
one_shot:
type: boolean
rate_limit:
type: integer
scope:
Expand All @@ -14,6 +16,8 @@ definitions:
type: string
target:
type: string
used:
type: boolean
type: object
orgs.Membership:
properties:
Expand Down
2 changes: 2 additions & 0 deletions migrations/008_add_virtual_key_one_shot.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE virtual_keys ADD COLUMN IF NOT EXISTS one_shot BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE virtual_keys ADD COLUMN IF NOT EXISTS used BOOLEAN NOT NULL DEFAULT FALSE;
2 changes: 2 additions & 0 deletions pkg/keys/virtualkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type VirtualKey struct {
Target string `json:"target" gorm:"not null"`
RateLimit int `json:"rate_limit" gorm:"not null"`
Source string `json:"source,omitempty" gorm:"size:16;default:''"`
OneShot bool `json:"one_shot,omitempty" gorm:"default:false"`
Used bool `json:"used,omitempty" gorm:"default:false"`
}

// SourceMCP is the source label for keys issued via the MCP tool.
Expand Down
3 changes: 3 additions & 0 deletions routes/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ var availableTools = []mcpTool{
"service_name": {Type: "string", Description: "ID of the target service"},
"ttl_seconds": {Type: "integer", Description: "Key lifetime in seconds (default 3600)"},
"rate_limit": {Type: "integer", Description: "Max requests per minute (default 60)"},
"one_shot": {Type: "boolean", Description: "If true, key is invalidated after first use"},
},
Required: []string{"service_name"},
},
Expand Down Expand Up @@ -163,6 +164,7 @@ func (s *Server) mcpRequestKey(w http.ResponseWriter, id any, raw json.RawMessag
ServiceName string `json:"service_name"`
TTLSeconds int `json:"ttl_seconds"`
RateLimit int `json:"rate_limit"`
OneShot bool `json:"one_shot"`
}
if err := json.Unmarshal(raw, &args); err != nil {
writeMCPError(w, id, mcpErrInvalid, "invalid arguments")
Expand Down Expand Up @@ -196,6 +198,7 @@ func (s *Server) mcpRequestKey(w http.ResponseWriter, id any, raw json.RawMessag
ExpiresAt: expiresAt,
RateLimit: args.RateLimit,
Source: keys.SourceMCP,
OneShot: args.OneShot,
}
if err := s.KeyStore.Create(k); err != nil {
writeMCPError(w, id, mcpErrInternal, "failed to issue key")
Expand Down
12 changes: 12 additions & 0 deletions routes/v1/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ func (h *Handler) Proxy(w http.ResponseWriter, r *http.Request) {
return
}

if k.OneShot && k.Used {
writeError(w, "key already used", http.StatusUnauthorized)
return
}

if config.MetricsEnabled() {
metrics.KeyUsageTotal.WithLabelValues(k.ID).Inc()
}
Expand Down Expand Up @@ -102,6 +107,13 @@ func (h *Handler) Proxy(w http.ResponseWriter, r *http.Request) {
r.Header.Set("X-Bifrost-Agent-ID", k.ID)
}

// Mark one-shot keys as used before forwarding — prevents replay even if
// the upstream returns an error.
if k.OneShot {
k.Used = true
h.KeyStore.Update(k.ID, k)
}

proxy := httputil.NewSingleHostReverseProxy(target)
r.Host = target.Host
proxy.ServeHTTP(w, r)
Expand Down
25 changes: 25 additions & 0 deletions tests/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,28 @@ func TestProxyNoAgentIDForRegularKey(t *testing.T) {
t.Errorf("X-Bifrost-Agent-ID should not be set for non-MCP keys, got %q", capturedAgentID)
}
}

func TestMCPRequestKeyOneShot(t *testing.T) {
env := newTestEnv(t)
seedService(t, env, "svc-oneshot")

resp := mcpCall(t, env, map[string]any{
"jsonrpc": "2.0",
"id": 20,
"method": "tools/call",
"params": map[string]any{
"name": "request_key",
"arguments": map[string]any{"service_name": "svc-oneshot", "one_shot": true},
},
})

result := resp["result"].(map[string]any)
vk := result["virtual_key"].(string)
k, err := env.Server.KeyStore.Get(vk)
if err != nil {
t.Fatalf("key not found: %v", err)
}
if !k.OneShot {
t.Error("expected one_shot=true on MCP-issued key")
}
}
93 changes: 93 additions & 0 deletions tests/proxy_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/farovictor/bifrost/pkg/keys"
"github.com/farovictor/bifrost/pkg/rootkeys"
"github.com/farovictor/bifrost/pkg/services"
routes "github.com/farovictor/bifrost/routes"
)

func TestProxy(t *testing.T) {
Expand Down Expand Up @@ -170,3 +172,94 @@ func TestProxyScopeEnforcement(t *testing.T) {
})
}
}

func seedProxyBackend(t *testing.T, s *routes.Server) (svcID string, backend *httptest.Server) {
t.Helper()
backend = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "ok")
}))
t.Cleanup(backend.Close)
rk := rootkeys.RootKey{ID: "rk-os-" + svcID, APIKey: "real"}
s.RootKeyStore.Create(rk)
svc := services.Service{ID: "svc-os", Endpoint: backend.URL, RootKeyID: rk.ID}
s.ServiceStore.Create(svc)
return "svc-os", backend
}

func TestOneShotKeyUsedOnce(t *testing.T) {
s := newTestServer(t)
svcID, _ := seedProxyBackend(t, s)

k := keys.VirtualKey{ID: "vk-oneshot", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 100, OneShot: true}
s.KeyStore.Create(k)

router := setupRouter(s)

// First request — should succeed.
req := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil)
req.Header.Set("X-Virtual-Key", k.ID)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("first request: expected 200, got %d", rr.Code)
}

// Second request — key is now used, must be rejected.
req2 := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil)
req2.Header.Set("X-Virtual-Key", k.ID)
rr2 := httptest.NewRecorder()
router.ServeHTTP(rr2, req2)
if rr2.Code != http.StatusUnauthorized {
t.Fatalf("second request: expected 401, got %d", rr2.Code)
}
if msg := errorBody(t, rr2); msg != "key already used" {
t.Fatalf("unexpected error: %s", msg)
}
}

func TestOneShotKeyAppearsInList(t *testing.T) {
env := newTestEnv(t)
svcID, _ := seedProxyBackend(t, env.Server)

k := keys.VirtualKey{ID: "vk-os-list", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 10, OneShot: true}
env.Server.KeyStore.Create(k)

req := httptest.NewRequest(http.MethodGet, "/v1/keys", nil)
env.Authorize(req)
rr := httptest.NewRecorder()
env.Router.ServeHTTP(rr, req)

var listed []keys.VirtualKey
if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil {
t.Fatalf("decode: %v", err)
}
for _, lk := range listed {
if lk.ID == k.ID {
if !lk.OneShot {
t.Error("expected one_shot=true in list response")
}
return
}
}
t.Fatal("one-shot key not found in list")
}

func TestRegularKeyNotAffectedByOneShotLogic(t *testing.T) {
s := newTestServer(t)
svcID, _ := seedProxyBackend(t, s)

k := keys.VirtualKey{ID: "vk-regular-os", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 100}
s.KeyStore.Create(k)

router := setupRouter(s)

for i := 0; i < 3; i++ {
req := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil)
req.Header.Set("X-Virtual-Key", k.ID)
rr := httptest.NewRecorder()
router.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code)
}
}
}
Loading