From ed63f70c87fb71430eecfdbde4874b963aaa3260 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 07:41:04 +0000 Subject: [PATCH 1/4] migrate OAuth state store from in-memory to Redis Agent-Logs-Url: https://github.com/poly-workshop/identra/sessions/49272a53-7a4a-43de-b5f6-1ad1af5036c1 Co-authored-by: slhmy <31381093+slhmy@users.noreply.github.com> --- internal/application/identra/service.go | 32 +++++-- .../cache/redis_oauth_state_store.go | 77 +++++++++++++++ internal/infrastructure/oauth/state_store.go | 95 ++++++++++--------- 3 files changed, 152 insertions(+), 52 deletions(-) create mode 100644 internal/infrastructure/cache/redis_oauth_state_store.go diff --git a/internal/application/identra/service.go b/internal/application/identra/service.go index 6d6ad20..b159299 100644 --- a/internal/application/identra/service.go +++ b/internal/application/identra/service.go @@ -112,11 +112,18 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { Endpoint: github.Endpoint, } - emailStore, storeErr := cache.NewRedisEmailCodeStore(10*time.Minute, redis.NewRDB(*cfg.RedisClient)) + rdb := redis.NewRDB(*cfg.RedisClient) + + emailStore, storeErr := cache.NewRedisEmailCodeStore(10*time.Minute, rdb) if storeErr != nil { return nil, fmt.Errorf("failed to initialize email code store: %w", storeErr) } + oauthStore, storeErr := cache.NewRedisOAuthStateStore(stateTTL, rdb) + if storeErr != nil { + return nil, fmt.Errorf("failed to initialize oauth state store: %w", storeErr) + } + loginMaxAttempts := cfg.LoginMaxAttempts if loginMaxAttempts <= 0 { loginMaxAttempts = DefaultLoginMaxAttempts @@ -127,7 +134,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { } loginLimiter, loginLimiterErr := cache.NewRedisRateLimiter( - redis.NewRDB(*cfg.RedisClient), + rdb, "identra:rl:login:", loginMaxAttempts, loginLockoutDuration, @@ -146,7 +153,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { } sendCodeLimiter, sendCodeLimiterErr := cache.NewRedisRateLimiter( - redis.NewRDB(*cfg.RedisClient), + rdb, "identra:rl:send_code:", sendCodeMaxAttempts, sendCodeWindow, @@ -159,7 +166,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) { userStore: userStore, keyManager: km, tokenCfg: tokenCfg, - oauthStateStore: oauth.NewInMemoryStateStore(stateTTL), + oauthStateStore: oauthStore, emailCodeStore: emailStore, githubOAuthConfig: githubCfg, oauthFetchEmailIfMissing: cfg.OAuthFetchEmailIfMissing, @@ -278,7 +285,10 @@ func (s *Service) GetOAuthAuthorizationURL( slog.ErrorContext(ctx, "failed to generate oauth state", "error", err) return nil, status.Error(codes.Internal, "failed to generate oauth state") } - s.oauthStateStore.Add(state, provider, redirectURL) + if err := s.oauthStateStore.Add(ctx, state, provider, redirectURL); err != nil { + slog.ErrorContext(ctx, "failed to store oauth state", "error", err) + return nil, status.Error(codes.Internal, "failed to store oauth state") + } authURL := oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOffline) return &identra_v1_pb.GetOAuthAuthorizationURLResponse{Url: authURL, State: state}, nil @@ -295,7 +305,11 @@ func (s *Service) LoginByOAuth( return nil, status.Error(codes.InvalidArgument, "state is required") } - stateData, ok := s.oauthStateStore.Consume(req.GetState()) + stateData, ok, err := s.oauthStateStore.Consume(ctx, req.GetState()) + if err != nil { + slog.ErrorContext(ctx, "failed to consume oauth state", "error", err) + return nil, status.Error(codes.Internal, "failed to validate oauth state") + } if !ok { return nil, status.Error(codes.InvalidArgument, "invalid or expired state") } @@ -369,7 +383,11 @@ func (s *Service) BindUserByOAuth( return nil, status.Error(codes.Internal, "failed to fetch user") } - stateData, ok := s.oauthStateStore.Consume(req.GetState()) + stateData, ok, err := s.oauthStateStore.Consume(ctx, req.GetState()) + if err != nil { + slog.ErrorContext(ctx, "failed to consume oauth state", "error", err) + return nil, status.Error(codes.Internal, "failed to validate oauth state") + } if !ok { return nil, status.Error(codes.InvalidArgument, "invalid or expired state") } diff --git a/internal/infrastructure/cache/redis_oauth_state_store.go b/internal/infrastructure/cache/redis_oauth_state_store.go new file mode 100644 index 0000000..c7b593e --- /dev/null +++ b/internal/infrastructure/cache/redis_oauth_state_store.go @@ -0,0 +1,77 @@ +package cache + +import ( +"context" +"encoding/json" +"errors" +"time" + +"github.com/poly-workshop/identra/internal/infrastructure/oauth" +goredis "github.com/redis/go-redis/v9" +) + +// NewRedisOAuthStateStore creates a Redis-backed OAuth state store. +func NewRedisOAuthStateStore(ttl time.Duration, rdb goredis.UniversalClient) (oauth.StateStore, error) { +if ttl <= 0 { +ttl = 10 * time.Minute +} +if rdb == nil { +return nil, errors.New("redis client is required for oauth state store") +} +return &redisOAuthStateStore{rdb: rdb, ttl: ttl, prefix: "identra:oauth_state:"}, nil +} + +type redisOAuthStateStore struct { +rdb goredis.UniversalClient +ttl time.Duration +prefix string +} + +type oauthStateValue struct { +Provider string `json:"provider"` +RedirectURL string `json:"redirect_url"` +} + +func (s *redisOAuthStateStore) key(state string) string { +return s.prefix + state +} + +// Add stores the state with its provider and redirect URL in Redis with a TTL. +func (s *redisOAuthStateStore) Add(ctx context.Context, state, provider, redirectURL string) error { +val, err := json.Marshal(oauthStateValue{Provider: provider, RedirectURL: redirectURL}) +if err != nil { +return err +} +return s.rdb.Set(ctx, s.key(state), val, s.ttl).Err() +} + +// consumeStateScript atomically retrieves and deletes the state key. +// Returns the value if found, or nil if not present. +var consumeStateScript = goredis.NewScript(` +local v = redis.call("GET", KEYS[1]) +if not v then return false end +redis.call("DEL", KEYS[1]) +return v +`) + +// Consume retrieves and atomically removes the state from Redis. +// Returns false (with no error) when the state is not found or has expired. +func (s *redisOAuthStateStore) Consume(ctx context.Context, state string) (oauth.State, bool, error) { +res, err := consumeStateScript.Run(ctx, s.rdb, []string{s.key(state)}).Text() +if err != nil { +if errors.Is(err, goredis.Nil) { +return oauth.State{}, false, nil +} +return oauth.State{}, false, err +} + +var val oauthStateValue +if err := json.Unmarshal([]byte(res), &val); err != nil { +return oauth.State{}, false, err +} + +return oauth.State{ +Provider: val.Provider, +RedirectURL: val.RedirectURL, +}, true, nil +} diff --git a/internal/infrastructure/oauth/state_store.go b/internal/infrastructure/oauth/state_store.go index 71c765f..a1da1a3 100644 --- a/internal/infrastructure/oauth/state_store.go +++ b/internal/infrastructure/oauth/state_store.go @@ -1,77 +1,82 @@ package oauth import ( - "sync" - "time" +"context" +"sync" +"time" ) // State represents an OAuth state entry. type State struct { - Provider string - RedirectURL string - ExpiresAt time.Time +Provider string +RedirectURL string +ExpiresAt time.Time } // StateStore defines the interface for OAuth state storage. type StateStore interface { - Add(state, provider, redirectURL string) - Consume(state string) (State, bool) +// Add stores a new state with its provider and redirect URL. +Add(ctx context.Context, state, provider, redirectURL string) error +// Consume returns the state details when valid and removes it from the store. +// Returns false when the state is not found or has expired. +Consume(ctx context.Context, state string) (State, bool, error) } type inMemoryStateStore struct { - mu sync.Mutex - ttl time.Duration - values map[string]State +mu sync.Mutex +ttl time.Duration +values map[string]State } // NewInMemoryStateStore creates an in-memory OAuth state store. func NewInMemoryStateStore(ttl time.Duration) StateStore { - if ttl <= 0 { - ttl = time.Minute - } - return &inMemoryStateStore{ - ttl: ttl, - values: make(map[string]State), - } +if ttl <= 0 { +ttl = time.Minute +} +return &inMemoryStateStore{ +ttl: ttl, +values: make(map[string]State), +} } // Add stores a new state with its provider and redirect URL. -func (s *inMemoryStateStore) Add(state, provider, redirectURL string) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *inMemoryStateStore) Add(_ context.Context, state, provider, redirectURL string) error { +s.mu.Lock() +defer s.mu.Unlock() - s.cleanupLocked() - s.values[state] = State{ - Provider: provider, - RedirectURL: redirectURL, - ExpiresAt: time.Now().Add(s.ttl), - } +s.cleanupLocked() +s.values[state] = State{ +Provider: provider, +RedirectURL: redirectURL, +ExpiresAt: time.Now().Add(s.ttl), +} +return nil } // Consume returns the state details when valid and removes it from the store. -func (s *inMemoryStateStore) Consume(state string) (State, bool) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *inMemoryStateStore) Consume(_ context.Context, state string) (State, bool, error) { +s.mu.Lock() +defer s.mu.Unlock() - s.cleanupLocked() - value, ok := s.values[state] - if !ok { - return State{}, false - } - delete(s.values, state) +s.cleanupLocked() +value, ok := s.values[state] +if !ok { +return State{}, false, nil +} +delete(s.values, state) - if time.Now().After(value.ExpiresAt) { - return State{}, false - } +if time.Now().After(value.ExpiresAt) { +return State{}, false, nil +} - return value, true +return value, true, nil } func (s *inMemoryStateStore) cleanupLocked() { - now := time.Now() - for key, value := range s.values { - if now.After(value.ExpiresAt) { - delete(s.values, key) - } - } +now := time.Now() +for key, value := range s.values { +if now.After(value.ExpiresAt) { +delete(s.values, key) +} +} } From e91ab34f23cc6438333698ff1fa24968a1ff7f13 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 07:43:53 +0000 Subject: [PATCH 2/4] fix: align default TTL and document ExpiresAt in Redis state store Agent-Logs-Url: https://github.com/poly-workshop/identra/sessions/49272a53-7a4a-43de-b5f6-1ad1af5036c1 Co-authored-by: slhmy <31381093+slhmy@users.noreply.github.com> --- internal/infrastructure/cache/redis_oauth_state_store.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/infrastructure/cache/redis_oauth_state_store.go b/internal/infrastructure/cache/redis_oauth_state_store.go index c7b593e..b17c939 100644 --- a/internal/infrastructure/cache/redis_oauth_state_store.go +++ b/internal/infrastructure/cache/redis_oauth_state_store.go @@ -11,9 +11,10 @@ goredis "github.com/redis/go-redis/v9" ) // NewRedisOAuthStateStore creates a Redis-backed OAuth state store. +// Expiry is enforced by Redis TTL; the ExpiresAt field of returned State is not populated. func NewRedisOAuthStateStore(ttl time.Duration, rdb goredis.UniversalClient) (oauth.StateStore, error) { if ttl <= 0 { -ttl = 10 * time.Minute +ttl = time.Minute } if rdb == nil { return nil, errors.New("redis client is required for oauth state store") @@ -56,6 +57,7 @@ return v // Consume retrieves and atomically removes the state from Redis. // Returns false (with no error) when the state is not found or has expired. +// ExpiresAt is not populated in the returned State because Redis enforces expiry via TTL. func (s *redisOAuthStateStore) Consume(ctx context.Context, state string) (oauth.State, bool, error) { res, err := consumeStateScript.Run(ctx, s.rdb, []string{s.key(state)}).Text() if err != nil { From 3ae271cf61d1dad7e8e9448b264d462eed2f3e15 Mon Sep 17 00:00:00 2001 From: slhmy Date: Wed, 15 Apr 2026 16:05:39 +0800 Subject: [PATCH 3/4] Update internal/infrastructure/cache/redis_oauth_state_store.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/infrastructure/cache/redis_oauth_state_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/infrastructure/cache/redis_oauth_state_store.go b/internal/infrastructure/cache/redis_oauth_state_store.go index b17c939..5198eb6 100644 --- a/internal/infrastructure/cache/redis_oauth_state_store.go +++ b/internal/infrastructure/cache/redis_oauth_state_store.go @@ -50,7 +50,7 @@ return s.rdb.Set(ctx, s.key(state), val, s.ttl).Err() // Returns the value if found, or nil if not present. var consumeStateScript = goredis.NewScript(` local v = redis.call("GET", KEYS[1]) -if not v then return false end +if not v then return nil end redis.call("DEL", KEYS[1]) return v `) From 52c479d6289d93cdcb6376b0dff955a02bdd995c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 08:09:26 +0000 Subject: [PATCH 4/4] style: run gofmt on state_store.go and redis_oauth_state_store.go Agent-Logs-Url: https://github.com/poly-workshop/identra/sessions/1181e47e-b283-40fb-a256-5038f22c2498 Co-authored-by: slhmy <31381093+slhmy@users.noreply.github.com> --- .../cache/redis_oauth_state_store.go | 78 +++++++-------- internal/infrastructure/oauth/state_store.go | 96 +++++++++---------- 2 files changed, 87 insertions(+), 87 deletions(-) diff --git a/internal/infrastructure/cache/redis_oauth_state_store.go b/internal/infrastructure/cache/redis_oauth_state_store.go index 5198eb6..a743c41 100644 --- a/internal/infrastructure/cache/redis_oauth_state_store.go +++ b/internal/infrastructure/cache/redis_oauth_state_store.go @@ -1,49 +1,49 @@ package cache import ( -"context" -"encoding/json" -"errors" -"time" + "context" + "encoding/json" + "errors" + "time" -"github.com/poly-workshop/identra/internal/infrastructure/oauth" -goredis "github.com/redis/go-redis/v9" + "github.com/poly-workshop/identra/internal/infrastructure/oauth" + goredis "github.com/redis/go-redis/v9" ) // NewRedisOAuthStateStore creates a Redis-backed OAuth state store. // Expiry is enforced by Redis TTL; the ExpiresAt field of returned State is not populated. func NewRedisOAuthStateStore(ttl time.Duration, rdb goredis.UniversalClient) (oauth.StateStore, error) { -if ttl <= 0 { -ttl = time.Minute -} -if rdb == nil { -return nil, errors.New("redis client is required for oauth state store") -} -return &redisOAuthStateStore{rdb: rdb, ttl: ttl, prefix: "identra:oauth_state:"}, nil + if ttl <= 0 { + ttl = time.Minute + } + if rdb == nil { + return nil, errors.New("redis client is required for oauth state store") + } + return &redisOAuthStateStore{rdb: rdb, ttl: ttl, prefix: "identra:oauth_state:"}, nil } type redisOAuthStateStore struct { -rdb goredis.UniversalClient -ttl time.Duration -prefix string + rdb goredis.UniversalClient + ttl time.Duration + prefix string } type oauthStateValue struct { -Provider string `json:"provider"` -RedirectURL string `json:"redirect_url"` + Provider string `json:"provider"` + RedirectURL string `json:"redirect_url"` } func (s *redisOAuthStateStore) key(state string) string { -return s.prefix + state + return s.prefix + state } // Add stores the state with its provider and redirect URL in Redis with a TTL. func (s *redisOAuthStateStore) Add(ctx context.Context, state, provider, redirectURL string) error { -val, err := json.Marshal(oauthStateValue{Provider: provider, RedirectURL: redirectURL}) -if err != nil { -return err -} -return s.rdb.Set(ctx, s.key(state), val, s.ttl).Err() + val, err := json.Marshal(oauthStateValue{Provider: provider, RedirectURL: redirectURL}) + if err != nil { + return err + } + return s.rdb.Set(ctx, s.key(state), val, s.ttl).Err() } // consumeStateScript atomically retrieves and deletes the state key. @@ -59,21 +59,21 @@ return v // Returns false (with no error) when the state is not found or has expired. // ExpiresAt is not populated in the returned State because Redis enforces expiry via TTL. func (s *redisOAuthStateStore) Consume(ctx context.Context, state string) (oauth.State, bool, error) { -res, err := consumeStateScript.Run(ctx, s.rdb, []string{s.key(state)}).Text() -if err != nil { -if errors.Is(err, goredis.Nil) { -return oauth.State{}, false, nil -} -return oauth.State{}, false, err -} + res, err := consumeStateScript.Run(ctx, s.rdb, []string{s.key(state)}).Text() + if err != nil { + if errors.Is(err, goredis.Nil) { + return oauth.State{}, false, nil + } + return oauth.State{}, false, err + } -var val oauthStateValue -if err := json.Unmarshal([]byte(res), &val); err != nil { -return oauth.State{}, false, err -} + var val oauthStateValue + if err := json.Unmarshal([]byte(res), &val); err != nil { + return oauth.State{}, false, err + } -return oauth.State{ -Provider: val.Provider, -RedirectURL: val.RedirectURL, -}, true, nil + return oauth.State{ + Provider: val.Provider, + RedirectURL: val.RedirectURL, + }, true, nil } diff --git a/internal/infrastructure/oauth/state_store.go b/internal/infrastructure/oauth/state_store.go index a1da1a3..256da93 100644 --- a/internal/infrastructure/oauth/state_store.go +++ b/internal/infrastructure/oauth/state_store.go @@ -1,82 +1,82 @@ package oauth import ( -"context" -"sync" -"time" + "context" + "sync" + "time" ) // State represents an OAuth state entry. type State struct { -Provider string -RedirectURL string -ExpiresAt time.Time + Provider string + RedirectURL string + ExpiresAt time.Time } // StateStore defines the interface for OAuth state storage. type StateStore interface { -// Add stores a new state with its provider and redirect URL. -Add(ctx context.Context, state, provider, redirectURL string) error -// Consume returns the state details when valid and removes it from the store. -// Returns false when the state is not found or has expired. -Consume(ctx context.Context, state string) (State, bool, error) + // Add stores a new state with its provider and redirect URL. + Add(ctx context.Context, state, provider, redirectURL string) error + // Consume returns the state details when valid and removes it from the store. + // Returns false when the state is not found or has expired. + Consume(ctx context.Context, state string) (State, bool, error) } type inMemoryStateStore struct { -mu sync.Mutex -ttl time.Duration -values map[string]State + mu sync.Mutex + ttl time.Duration + values map[string]State } // NewInMemoryStateStore creates an in-memory OAuth state store. func NewInMemoryStateStore(ttl time.Duration) StateStore { -if ttl <= 0 { -ttl = time.Minute -} -return &inMemoryStateStore{ -ttl: ttl, -values: make(map[string]State), -} + if ttl <= 0 { + ttl = time.Minute + } + return &inMemoryStateStore{ + ttl: ttl, + values: make(map[string]State), + } } // Add stores a new state with its provider and redirect URL. func (s *inMemoryStateStore) Add(_ context.Context, state, provider, redirectURL string) error { -s.mu.Lock() -defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() -s.cleanupLocked() -s.values[state] = State{ -Provider: provider, -RedirectURL: redirectURL, -ExpiresAt: time.Now().Add(s.ttl), -} -return nil + s.cleanupLocked() + s.values[state] = State{ + Provider: provider, + RedirectURL: redirectURL, + ExpiresAt: time.Now().Add(s.ttl), + } + return nil } // Consume returns the state details when valid and removes it from the store. func (s *inMemoryStateStore) Consume(_ context.Context, state string) (State, bool, error) { -s.mu.Lock() -defer s.mu.Unlock() + s.mu.Lock() + defer s.mu.Unlock() -s.cleanupLocked() -value, ok := s.values[state] -if !ok { -return State{}, false, nil -} -delete(s.values, state) + s.cleanupLocked() + value, ok := s.values[state] + if !ok { + return State{}, false, nil + } + delete(s.values, state) -if time.Now().After(value.ExpiresAt) { -return State{}, false, nil -} + if time.Now().After(value.ExpiresAt) { + return State{}, false, nil + } -return value, true, nil + return value, true, nil } func (s *inMemoryStateStore) cleanupLocked() { -now := time.Now() -for key, value := range s.values { -if now.After(value.ExpiresAt) { -delete(s.values, key) -} -} + now := time.Now() + for key, value := range s.values { + if now.After(value.ExpiresAt) { + delete(s.values, key) + } + } }