diff --git a/config/config.go b/config/config.go index e95629d..8f0ea14 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "os" "strconv" "strings" + "time" ) // ServerPort returns the port the HTTP server should listen on. @@ -197,6 +198,18 @@ func StaticAPIKey() string { return key } +// TokenTTL returns the default lifetime for bearer tokens issued by the server. +// It reads BIFROST_TOKEN_TTL (any Go duration string, e.g. "1h", "30m") and +// defaults to 24h when unset or unparseable. +func TokenTTL() time.Duration { + if v := os.Getenv("BIFROST_TOKEN_TTL"); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return 24 * time.Hour +} + // TrackTokens reports whether token counting is enabled via BIFROST_TRACK_TOKENS=true. func TrackTokens() bool { switch os.Getenv("BIFROST_TRACK_TOKENS") { diff --git a/docs/swagger/docs.go b/docs/swagger/docs.go index 8608da4..a79c821 100644 --- a/docs/swagger/docs.go +++ b/docs/swagger/docs.go @@ -1354,7 +1354,10 @@ const docTemplate = `{ "BearerAuth": [] } ], - "description": "Accepts a valid bearer token and returns a new one with a fresh 24h expiry.", + "description": "Accepts a valid bearer token and returns a new one. An optional JSON body may specify a \"ttl\" field (e.g. \"1h\"); falls back to BIFROST_TOKEN_TTL or 24h.", + "consumes": [ + "application/json" + ], "produces": [ "application/json" ], @@ -1362,6 +1365,16 @@ const docTemplate = `{ "users" ], "summary": "Refresh bearer token", + "parameters": [ + { + "description": "Optional TTL override: {\\", + "name": "body", + "in": "body", + "schema": { + "type": "object" + } + } + ], "responses": { "200": { "description": "New token: {\\\"token\\\":\\\"...\\\"}", @@ -1369,6 +1382,12 @@ const docTemplate = `{ "type": "object" } }, + "400": { + "description": "invalid ttl", + "schema": { + "$ref": "#/definitions/routes.ErrorResponse" + } + }, "401": { "description": "invalid or expired token", "schema": { @@ -1612,6 +1631,10 @@ const docTemplate = `{ "role": { "type": "string", "example": "member" + }, + "ttl": { + "type": "string", + "example": "1h" } } }, diff --git a/docs/swagger/swagger.json b/docs/swagger/swagger.json index 2beb0e0..a913bce 100644 --- a/docs/swagger/swagger.json +++ b/docs/swagger/swagger.json @@ -1348,7 +1348,10 @@ "BearerAuth": [] } ], - "description": "Accepts a valid bearer token and returns a new one with a fresh 24h expiry.", + "description": "Accepts a valid bearer token and returns a new one. An optional JSON body may specify a \"ttl\" field (e.g. \"1h\"); falls back to BIFROST_TOKEN_TTL or 24h.", + "consumes": [ + "application/json" + ], "produces": [ "application/json" ], @@ -1356,6 +1359,16 @@ "users" ], "summary": "Refresh bearer token", + "parameters": [ + { + "description": "Optional TTL override: {\\", + "name": "body", + "in": "body", + "schema": { + "type": "object" + } + } + ], "responses": { "200": { "description": "New token: {\\\"token\\\":\\\"...\\\"}", @@ -1363,6 +1376,12 @@ "type": "object" } }, + "400": { + "description": "invalid ttl", + "schema": { + "$ref": "#/definitions/routes.ErrorResponse" + } + }, "401": { "description": "invalid or expired token", "schema": { @@ -1606,6 +1625,10 @@ "role": { "type": "string", "example": "member" + }, + "ttl": { + "type": "string", + "example": "1h" } } }, diff --git a/docs/swagger/swagger.yaml b/docs/swagger/swagger.yaml index ce70276..5bd5615 100644 --- a/docs/swagger/swagger.yaml +++ b/docs/swagger/swagger.yaml @@ -67,6 +67,9 @@ definitions: role: example: member type: string + ttl: + example: 1h + type: string type: object routes.CreateUserResponse: properties: @@ -1040,8 +1043,17 @@ paths: - setup /v1/token/refresh: post: - description: Accepts a valid bearer token and returns a new one with a fresh - 24h expiry. + consumes: + - application/json + description: Accepts a valid bearer token and returns a new one. An optional + JSON body may specify a "ttl" field (e.g. "1h"); falls back to BIFROST_TOKEN_TTL + or 24h. + parameters: + - description: 'Optional TTL override: {\' + in: body + name: body + schema: + type: object produces: - application/json responses: @@ -1049,6 +1061,10 @@ paths: description: 'New token: {\"token\":\"...\"}' schema: type: object + "400": + description: invalid ttl + schema: + $ref: '#/definitions/routes.ErrorResponse' "401": description: invalid or expired token schema: diff --git a/routes/setup.go b/routes/setup.go index 6efeb96..5ccf2b9 100644 --- a/routes/setup.go +++ b/routes/setup.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" + "github.com/farovictor/bifrost/config" "github.com/farovictor/bifrost/pkg/orgs" "github.com/farovictor/bifrost/pkg/users" "github.com/farovictor/bifrost/pkg/utils" @@ -86,7 +87,7 @@ func (s *Server) Setup(w http.ResponseWriter, r *http.Request) { return } - token, err := buildAuthToken(u.ID, o.ID) + token, err := buildAuthToken(u.ID, o.ID, config.TokenTTL()) if err != nil { writeError(w, "internal error", http.StatusInternalServerError) return diff --git a/routes/users.go b/routes/users.go index db14ac8..e31974f 100644 --- a/routes/users.go +++ b/routes/users.go @@ -2,10 +2,12 @@ package routes import ( "encoding/json" + "fmt" "net/http" "strings" "time" + "github.com/farovictor/bifrost/config" "github.com/farovictor/bifrost/pkg/auth" "github.com/farovictor/bifrost/pkg/logging" "github.com/farovictor/bifrost/pkg/orgs" @@ -20,6 +22,7 @@ type CreateUserRequest struct { OrgID string `json:"org_id" example:""` OrgName string `json:"org_name" example:"Acme"` Role string `json:"role" example:"member"` + TTL string `json:"ttl" example:"1h"` } // CreateUserResponse is returned on successful user creation. @@ -126,7 +129,12 @@ func (s *Server) CreateUser(w http.ResponseWriter, r *http.Request) { } } - token, err := buildAuthToken(u.ID, orgID) + ttl, err := parseTTL(req.TTL) + if err != nil { + writeError(w, "invalid ttl: use a Go duration string (e.g. \"1h\", \"30m\")", http.StatusBadRequest) + return + } + token, err := buildAuthToken(u.ID, orgID, ttl) if err != nil { writeError(w, "internal error", http.StatusInternalServerError) return @@ -143,13 +151,16 @@ func (s *Server) CreateUser(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(resp) } -// RefreshToken handles POST /token/refresh and issues a fresh 24h token. +// RefreshToken handles POST /token/refresh and issues a fresh token. // // @Summary Refresh bearer token -// @Description Accepts a valid bearer token and returns a new one with a fresh 24h expiry. +// @Description Accepts a valid bearer token and returns a new one. An optional JSON body may specify a "ttl" field (e.g. "1h"); falls back to BIFROST_TOKEN_TTL or 24h. // @Tags users +// @Accept json // @Produce json +// @Param body body object false "Optional TTL override: {\"ttl\":\"1h\"}" // @Success 200 {object} object "New token: {\"token\":\"...\"}" +// @Failure 400 {object} ErrorResponse "invalid ttl" // @Failure 401 {object} ErrorResponse "invalid or expired token" // @Failure 500 {object} ErrorResponse // @Security BearerAuth @@ -166,7 +177,20 @@ func (s *Server) RefreshToken(w http.ResponseWriter, r *http.Request) { writeError(w, "unauthorized", http.StatusUnauthorized) return } - token, err := buildAuthToken(tok.UserID, tok.OrgID) + + var body struct { + TTL string `json:"ttl"` + } + // Body is optional — ignore decode errors. + json.NewDecoder(r.Body).Decode(&body) //nolint:errcheck + + ttl, err := parseTTL(body.TTL) + if err != nil { + writeError(w, "invalid ttl: use a Go duration string (e.g. \"1h\", \"30m\")", http.StatusBadRequest) + return + } + + token, err := buildAuthToken(tok.UserID, tok.OrgID, ttl) if err != nil { writeError(w, "internal error", http.StatusInternalServerError) return @@ -177,11 +201,25 @@ func (s *Server) RefreshToken(w http.ResponseWriter, r *http.Request) { }{Token: token}) } -func buildAuthToken(userID, orgID string) (string, error) { +// parseTTL converts a duration string to a time.Duration, falling back to +// config.TokenTTL() when s is empty. Returns an error when s is non-empty +// but unparseable or <= 0. +func parseTTL(s string) (time.Duration, error) { + if s == "" { + return config.TokenTTL(), nil + } + d, err := time.ParseDuration(s) + if err != nil || d <= 0 { + return 0, fmt.Errorf("invalid duration %q", s) + } + return d, nil +} + +func buildAuthToken(userID, orgID string, ttl time.Duration) (string, error) { t := auth.AuthToken{ UserID: userID, OrgID: orgID, - ExpiresAt: time.Now().Add(24 * time.Hour), + ExpiresAt: time.Now().Add(ttl), } return auth.Sign(t) } diff --git a/tests/routes_test.go b/tests/routes_test.go index 9e4de99..ac6c667 100644 --- a/tests/routes_test.go +++ b/tests/routes_test.go @@ -29,6 +29,7 @@ func setupRouter(s *routes.Server) http.Handler { r.With(rl.OrgCtxMiddleware(s.MembershipStore)).Get("/user", s.GetUserInfo) r.With(rl.OrgCtxMiddleware(s.MembershipStore)).Post("/user/rootkeys", s.CreateRootKey) + r.Post("/token/refresh", s.RefreshToken) r.Post("/service-token", s.ServiceToken) r.With(rl.RateLimitMiddleware(s.KeyStore)).Handle("/proxy/*", http.HandlerFunc(v1h.Proxy)) diff --git a/tests/token_ttl_test.go b/tests/token_ttl_test.go new file mode 100644 index 0000000..1b17235 --- /dev/null +++ b/tests/token_ttl_test.go @@ -0,0 +1,130 @@ +package tests + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/farovictor/bifrost/pkg/auth" +) + +// createUserWithTTL posts to /v1/users with an optional ttl field and returns +// the decoded token string. +func createUserWithTTL(t *testing.T, env *TestEnv, name, email, ttl string) string { + t.Helper() + payload := map[string]string{"name": name, "email": email} + if ttl != "" { + payload["ttl"] = ttl + } + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/v1/users", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // CreateUser requires a bearer token (OrgCtxMiddleware) + req.Header.Set("Authorization", "Bearer "+env.Token) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("create user: expected 201, got %d: %s", rr.Code, rr.Body.String()) + } + var resp struct { + Token string `json:"token"` + } + json.Unmarshal(rr.Body.Bytes(), &resp) + return resp.Token +} + +func TestCreateUser_DefaultTTL(t *testing.T) { + env := newTestEnv(t) + token := createUserWithTTL(t, env, "Bob", "bob@example.com", "") + + tok, err := auth.Verify(token) + if err != nil { + t.Fatalf("verify token: %v", err) + } + ttl := time.Until(tok.ExpiresAt) + // Default is 24h — allow ±5 minutes for test execution time. + if ttl < 23*time.Hour+55*time.Minute || ttl > 24*time.Hour+5*time.Minute { + t.Errorf("expected ~24h TTL, got %v", ttl) + } +} + +func TestCreateUser_CustomTTL(t *testing.T) { + env := newTestEnv(t) + token := createUserWithTTL(t, env, "Carol", "carol@example.com", "30m") + + tok, err := auth.Verify(token) + if err != nil { + t.Fatalf("verify token: %v", err) + } + ttl := time.Until(tok.ExpiresAt) + if ttl < 29*time.Minute || ttl > 31*time.Minute { + t.Errorf("expected ~30m TTL, got %v", ttl) + } +} + +func TestCreateUser_InvalidTTL(t *testing.T) { + env := newTestEnv(t) + body := `{"name":"Dave","email":"dave@example.com","ttl":"notaduration"}` + req := httptest.NewRequest(http.MethodPost, "/v1/users", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+env.Token) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid ttl, got %d", rr.Code) + } +} + +func TestRefreshToken_DefaultTTL(t *testing.T) { + env := newTestEnv(t) + req := httptest.NewRequest(http.MethodPost, "/v1/token/refresh", nil) + req.Header.Set("Authorization", "Bearer "+env.Token) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + var resp struct{ Token string `json:"token"` } + json.Unmarshal(rr.Body.Bytes(), &resp) + tok, _ := auth.Verify(resp.Token) + ttl := time.Until(tok.ExpiresAt) + if ttl < 23*time.Hour+55*time.Minute || ttl > 24*time.Hour+5*time.Minute { + t.Errorf("expected ~24h TTL on refresh, got %v", ttl) + } +} + +func TestRefreshToken_CustomTTL(t *testing.T) { + env := newTestEnv(t) + body := `{"ttl":"15m"}` + req := httptest.NewRequest(http.MethodPost, "/v1/token/refresh", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+env.Token) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rr.Code, rr.Body.String()) + } + var resp struct{ Token string `json:"token"` } + json.Unmarshal(rr.Body.Bytes(), &resp) + tok, _ := auth.Verify(resp.Token) + ttl := time.Until(tok.ExpiresAt) + if ttl < 14*time.Minute || ttl > 16*time.Minute { + t.Errorf("expected ~15m TTL on refresh, got %v", ttl) + } +} + +func TestRefreshToken_InvalidTTL(t *testing.T) { + env := newTestEnv(t) + body := `{"ttl":"bad"}` + req := httptest.NewRequest(http.MethodPost, "/v1/token/refresh", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+env.Token) + rr := httptest.NewRecorder() + env.Router.ServeHTTP(rr, req) + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid ttl on refresh, got %d", rr.Code) + } +}