diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index 8ce0c05..a468259 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -1245,6 +1245,9 @@ const docTemplate = `{ "id": { "type": "string" }, + "one_shot": { + "type": "boolean" + }, "rate_limit": { "type": "integer" }, @@ -1256,6 +1259,9 @@ const docTemplate = `{ }, "target": { "type": "string" + }, + "used": { + "type": "boolean" } } }, diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index 08a4855..13b9f09 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -1239,6 +1239,9 @@ "id": { "type": "string" }, + "one_shot": { + "type": "boolean" + }, "rate_limit": { "type": "integer" }, @@ -1250,6 +1253,9 @@ }, "target": { "type": "string" + }, + "used": { + "type": "boolean" } } }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index caffa55..a72749b 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -6,6 +6,8 @@ definitions: type: string id: type: string + one_shot: + type: boolean rate_limit: type: integer scope: @@ -14,6 +16,8 @@ definitions: type: string target: type: string + used: + type: boolean type: object orgs.Membership: properties: diff --git a/migrations/008_add_virtual_key_one_shot.sql b/migrations/008_add_virtual_key_one_shot.sql new file mode 100644 index 0000000..7a88381 --- /dev/null +++ b/migrations/008_add_virtual_key_one_shot.sql @@ -0,0 +1,2 @@ +ALTER TABLE virtual_keys ADD COLUMN IF NOT EXISTS one_shot BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE virtual_keys ADD COLUMN IF NOT EXISTS used BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/pkg/keys/virtualkey.go b/pkg/keys/virtualkey.go index 8c8b3fb..9a54d6e 100644 --- a/pkg/keys/virtualkey.go +++ b/pkg/keys/virtualkey.go @@ -11,6 +11,8 @@ type VirtualKey struct { Target string `json:"target" gorm:"not null"` RateLimit int `json:"rate_limit" gorm:"not null"` Source string `json:"source,omitempty" gorm:"size:16;default:''"` + OneShot bool `json:"one_shot,omitempty" gorm:"default:false"` + Used bool `json:"used,omitempty" gorm:"default:false"` } // SourceMCP is the source label for keys issued via the MCP tool. diff --git a/routes/mcp.go b/routes/mcp.go index 2b45cbd..aed1300 100644 --- a/routes/mcp.go +++ b/routes/mcp.go @@ -77,6 +77,7 @@ var availableTools = []mcpTool{ "service_name": {Type: "string", Description: "ID of the target service"}, "ttl_seconds": {Type: "integer", Description: "Key lifetime in seconds (default 3600)"}, "rate_limit": {Type: "integer", Description: "Max requests per minute (default 60)"}, + "one_shot": {Type: "boolean", Description: "If true, key is invalidated after first use"}, }, Required: []string{"service_name"}, }, @@ -163,6 +164,7 @@ func (s *Server) mcpRequestKey(w http.ResponseWriter, id any, raw json.RawMessag ServiceName string `json:"service_name"` TTLSeconds int `json:"ttl_seconds"` RateLimit int `json:"rate_limit"` + OneShot bool `json:"one_shot"` } if err := json.Unmarshal(raw, &args); err != nil { writeMCPError(w, id, mcpErrInvalid, "invalid arguments") @@ -196,6 +198,7 @@ func (s *Server) mcpRequestKey(w http.ResponseWriter, id any, raw json.RawMessag ExpiresAt: expiresAt, RateLimit: args.RateLimit, Source: keys.SourceMCP, + OneShot: args.OneShot, } if err := s.KeyStore.Create(k); err != nil { writeMCPError(w, id, mcpErrInternal, "failed to issue key") diff --git a/routes/v1/proxy.go b/routes/v1/proxy.go index 67debb3..b81fcd9 100644 --- a/routes/v1/proxy.go +++ b/routes/v1/proxy.go @@ -47,6 +47,11 @@ func (h *Handler) Proxy(w http.ResponseWriter, r *http.Request) { return } + if k.OneShot && k.Used { + writeError(w, "key already used", http.StatusUnauthorized) + return + } + if config.MetricsEnabled() { metrics.KeyUsageTotal.WithLabelValues(k.ID).Inc() } @@ -102,6 +107,13 @@ func (h *Handler) Proxy(w http.ResponseWriter, r *http.Request) { r.Header.Set("X-Bifrost-Agent-ID", k.ID) } + // Mark one-shot keys as used before forwarding — prevents replay even if + // the upstream returns an error. + if k.OneShot { + k.Used = true + h.KeyStore.Update(k.ID, k) + } + proxy := httputil.NewSingleHostReverseProxy(target) r.Host = target.Host proxy.ServeHTTP(w, r) diff --git a/tests/mcp_test.go b/tests/mcp_test.go index 0bcdbe9..86661c9 100644 --- a/tests/mcp_test.go +++ b/tests/mcp_test.go @@ -388,3 +388,28 @@ func TestProxyNoAgentIDForRegularKey(t *testing.T) { t.Errorf("X-Bifrost-Agent-ID should not be set for non-MCP keys, got %q", capturedAgentID) } } + +func TestMCPRequestKeyOneShot(t *testing.T) { + env := newTestEnv(t) + seedService(t, env, "svc-oneshot") + + resp := mcpCall(t, env, map[string]any{ + "jsonrpc": "2.0", + "id": 20, + "method": "tools/call", + "params": map[string]any{ + "name": "request_key", + "arguments": map[string]any{"service_name": "svc-oneshot", "one_shot": true}, + }, + }) + + result := resp["result"].(map[string]any) + vk := result["virtual_key"].(string) + k, err := env.Server.KeyStore.Get(vk) + if err != nil { + t.Fatalf("key not found: %v", err) + } + if !k.OneShot { + t.Error("expected one_shot=true on MCP-issued key") + } +} diff --git a/tests/proxy_test.go b/tests/proxy_test.go index b9a48e1..869fe92 100644 --- a/tests/proxy_test.go +++ b/tests/proxy_test.go @@ -1,6 +1,7 @@ package tests import ( + "encoding/json" "io" "net/http" "net/http/httptest" @@ -10,6 +11,7 @@ import ( "github.com/farovictor/bifrost/pkg/keys" "github.com/farovictor/bifrost/pkg/rootkeys" "github.com/farovictor/bifrost/pkg/services" + routes "github.com/farovictor/bifrost/routes" ) func TestProxy(t *testing.T) { @@ -170,3 +172,94 @@ func TestProxyScopeEnforcement(t *testing.T) { }) } } + +func seedProxyBackend(t *testing.T, s *routes.Server) (svcID string, backend *httptest.Server) { + t.Helper() + backend = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "ok") + })) + t.Cleanup(backend.Close) + rk := rootkeys.RootKey{ID: "rk-os-" + svcID, APIKey: "real"} + s.RootKeyStore.Create(rk) + svc := services.Service{ID: "svc-os", Endpoint: backend.URL, RootKeyID: rk.ID} + s.ServiceStore.Create(svc) + return "svc-os", backend +} + +func TestOneShotKeyUsedOnce(t *testing.T) { + s := newTestServer(t) + svcID, _ := seedProxyBackend(t, s) + + k := keys.VirtualKey{ID: "vk-oneshot", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 100, OneShot: true} + s.KeyStore.Create(k) + + router := setupRouter(s) + + // First request — should succeed. + req := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil) + req.Header.Set("X-Virtual-Key", k.ID) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", rr.Code) + } + + // Second request — key is now used, must be rejected. + req2 := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil) + req2.Header.Set("X-Virtual-Key", k.ID) + rr2 := httptest.NewRecorder() + router.ServeHTTP(rr2, req2) + if rr2.Code != http.StatusUnauthorized { + t.Fatalf("second request: expected 401, got %d", rr2.Code) + } + if msg := errorBody(t, rr2); msg != "key already used" { + t.Fatalf("unexpected error: %s", msg) + } +} + +func TestOneShotKeyAppearsInList(t *testing.T) { + env := newTestEnv(t) + svcID, _ := seedProxyBackend(t, env.Server) + + k := keys.VirtualKey{ID: "vk-os-list", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 10, OneShot: true} + env.Server.KeyStore.Create(k) + + req := httptest.NewRequest(http.MethodGet, "/v1/keys", nil) + env.Authorize(req) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + + var listed []keys.VirtualKey + if err := json.Unmarshal(rr.Body.Bytes(), &listed); err != nil { + t.Fatalf("decode: %v", err) + } + for _, lk := range listed { + if lk.ID == k.ID { + if !lk.OneShot { + t.Error("expected one_shot=true in list response") + } + return + } + } + t.Fatal("one-shot key not found in list") +} + +func TestRegularKeyNotAffectedByOneShotLogic(t *testing.T) { + s := newTestServer(t) + svcID, _ := seedProxyBackend(t, s) + + k := keys.VirtualKey{ID: "vk-regular-os", Target: svcID, Scope: keys.ScopeRead, ExpiresAt: time.Now().Add(time.Hour), RateLimit: 100} + s.KeyStore.Create(k) + + router := setupRouter(s) + + for i := 0; i < 3; i++ { + req := httptest.NewRequest(http.MethodGet, "/v1/proxy/path", nil) + req.Header.Set("X-Virtual-Key", k.ID) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code) + } + } +}