Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions internal/application/identra/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,18 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) {
Endpoint: github.Endpoint,
}

emailStore, storeErr := cache.NewRedisEmailCodeStore(10*time.Minute, redis.NewRDB(*cfg.RedisClient))
rdb := redis.NewRDB(*cfg.RedisClient)

emailStore, storeErr := cache.NewRedisEmailCodeStore(10*time.Minute, rdb)
if storeErr != nil {
return nil, fmt.Errorf("failed to initialize email code store: %w", storeErr)
}

oauthStore, storeErr := cache.NewRedisOAuthStateStore(stateTTL, rdb)
if storeErr != nil {
return nil, fmt.Errorf("failed to initialize oauth state store: %w", storeErr)
}

loginMaxAttempts := cfg.LoginMaxAttempts
if loginMaxAttempts <= 0 {
loginMaxAttempts = DefaultLoginMaxAttempts
Expand All @@ -127,7 +134,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) {
}

loginLimiter, loginLimiterErr := cache.NewRedisRateLimiter(
redis.NewRDB(*cfg.RedisClient),
rdb,
"identra:rl:login:",
loginMaxAttempts,
loginLockoutDuration,
Expand All @@ -146,7 +153,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) {
}

sendCodeLimiter, sendCodeLimiterErr := cache.NewRedisRateLimiter(
redis.NewRDB(*cfg.RedisClient),
rdb,
"identra:rl:send_code:",
sendCodeMaxAttempts,
sendCodeWindow,
Expand All @@ -159,7 +166,7 @@ func NewService(ctx context.Context, cfg Config) (*Service, error) {
userStore: userStore,
keyManager: km,
tokenCfg: tokenCfg,
oauthStateStore: oauth.NewInMemoryStateStore(stateTTL),
oauthStateStore: oauthStore,
emailCodeStore: emailStore,
githubOAuthConfig: githubCfg,
oauthFetchEmailIfMissing: cfg.OAuthFetchEmailIfMissing,
Expand Down Expand Up @@ -278,7 +285,10 @@ func (s *Service) GetOAuthAuthorizationURL(
slog.ErrorContext(ctx, "failed to generate oauth state", "error", err)
return nil, status.Error(codes.Internal, "failed to generate oauth state")
}
s.oauthStateStore.Add(state, provider, redirectURL)
if err := s.oauthStateStore.Add(ctx, state, provider, redirectURL); err != nil {
slog.ErrorContext(ctx, "failed to store oauth state", "error", err)
return nil, status.Error(codes.Internal, "failed to store oauth state")
}

authURL := oauthCfg.AuthCodeURL(state, oauth2.AccessTypeOffline)
return &identra_v1_pb.GetOAuthAuthorizationURLResponse{Url: authURL, State: state}, nil
Expand All @@ -295,7 +305,11 @@ func (s *Service) LoginByOAuth(
return nil, status.Error(codes.InvalidArgument, "state is required")
}

stateData, ok := s.oauthStateStore.Consume(req.GetState())
stateData, ok, err := s.oauthStateStore.Consume(ctx, req.GetState())
if err != nil {
slog.ErrorContext(ctx, "failed to consume oauth state", "error", err)
return nil, status.Error(codes.Internal, "failed to validate oauth state")
}
if !ok {
return nil, status.Error(codes.InvalidArgument, "invalid or expired state")
}
Expand Down Expand Up @@ -369,7 +383,11 @@ func (s *Service) BindUserByOAuth(
return nil, status.Error(codes.Internal, "failed to fetch user")
}

stateData, ok := s.oauthStateStore.Consume(req.GetState())
stateData, ok, err := s.oauthStateStore.Consume(ctx, req.GetState())
if err != nil {
slog.ErrorContext(ctx, "failed to consume oauth state", "error", err)
return nil, status.Error(codes.Internal, "failed to validate oauth state")
}
if !ok {
return nil, status.Error(codes.InvalidArgument, "invalid or expired state")
}
Expand Down
79 changes: 79 additions & 0 deletions internal/infrastructure/cache/redis_oauth_state_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package cache

import (
"context"
"encoding/json"
"errors"
"time"

"github.com/poly-workshop/identra/internal/infrastructure/oauth"
goredis "github.com/redis/go-redis/v9"
)
Comment on lines +3 to +11
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new file isn’t gofmt-formatted (imports and indentation). Please run gofmt so it matches existing Redis cache implementations (e.g., redis_email_code_store.go) and keeps diffs readable.

Copilot uses AI. Check for mistakes.

// NewRedisOAuthStateStore creates a Redis-backed OAuth state store.
// Expiry is enforced by Redis TTL; the ExpiresAt field of returned State is not populated.
func NewRedisOAuthStateStore(ttl time.Duration, rdb goredis.UniversalClient) (oauth.StateStore, error) {
if ttl <= 0 {
ttl = time.Minute
}
if rdb == nil {
return nil, errors.New("redis client is required for oauth state store")
}
return &redisOAuthStateStore{rdb: rdb, ttl: ttl, prefix: "identra:oauth_state:"}, nil
}

type redisOAuthStateStore struct {
rdb goredis.UniversalClient
ttl time.Duration
prefix string
}

type oauthStateValue struct {
Provider string `json:"provider"`
RedirectURL string `json:"redirect_url"`
}

func (s *redisOAuthStateStore) key(state string) string {
return s.prefix + state
}

// Add stores the state with its provider and redirect URL in Redis with a TTL.
func (s *redisOAuthStateStore) Add(ctx context.Context, state, provider, redirectURL string) error {
val, err := json.Marshal(oauthStateValue{Provider: provider, RedirectURL: redirectURL})
if err != nil {
return err
}
return s.rdb.Set(ctx, s.key(state), val, s.ttl).Err()
}

// consumeStateScript atomically retrieves and deletes the state key.
// Returns the value if found, or nil if not present.
var consumeStateScript = goredis.NewScript(`
local v = redis.call("GET", KEYS[1])
if not v then return nil end
redis.call("DEL", KEYS[1])
return v
`)

// Consume retrieves and atomically removes the state from Redis.
// Returns false (with no error) when the state is not found or has expired.
// ExpiresAt is not populated in the returned State because Redis enforces expiry via TTL.
func (s *redisOAuthStateStore) Consume(ctx context.Context, state string) (oauth.State, bool, error) {
res, err := consumeStateScript.Run(ctx, s.rdb, []string{s.key(state)}).Text()
if err != nil {
if errors.Is(err, goredis.Nil) {
return oauth.State{}, false, nil
}
return oauth.State{}, false, err
}

var val oauthStateValue
if err := json.Unmarshal([]byte(res), &val); err != nil {
return oauth.State{}, false, err
}

return oauth.State{
Provider: val.Provider,
RedirectURL: val.RedirectURL,
}, true, nil
}
19 changes: 12 additions & 7 deletions internal/infrastructure/oauth/state_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oauth

import (
"context"
"sync"
"time"
)
Comment on lines 3 to 7
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file isn’t gofmt-formatted (imports and block indentation lost). Please run gofmt on this file so it matches the rest of the Go codebase and avoids noisy diffs/lint failures.

Copilot uses AI. Check for mistakes.
Expand All @@ -14,8 +15,11 @@ type State struct {

// StateStore defines the interface for OAuth state storage.
type StateStore interface {
Add(state, provider, redirectURL string)
Consume(state string) (State, bool)
// Add stores a new state with its provider and redirect URL.
Add(ctx context.Context, state, provider, redirectURL string) error
// Consume returns the state details when valid and removes it from the store.
// Returns false when the state is not found or has expired.
Consume(ctx context.Context, state string) (State, bool, error)
}

type inMemoryStateStore struct {
Expand All @@ -36,7 +40,7 @@ func NewInMemoryStateStore(ttl time.Duration) StateStore {
}

// Add stores a new state with its provider and redirect URL.
func (s *inMemoryStateStore) Add(state, provider, redirectURL string) {
func (s *inMemoryStateStore) Add(_ context.Context, state, provider, redirectURL string) error {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -46,25 +50,26 @@ func (s *inMemoryStateStore) Add(state, provider, redirectURL string) {
RedirectURL: redirectURL,
ExpiresAt: time.Now().Add(s.ttl),
}
return nil
}

// Consume returns the state details when valid and removes it from the store.
func (s *inMemoryStateStore) Consume(state string) (State, bool) {
func (s *inMemoryStateStore) Consume(_ context.Context, state string) (State, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()

s.cleanupLocked()
value, ok := s.values[state]
if !ok {
return State{}, false
return State{}, false, nil
}
delete(s.values, state)

if time.Now().After(value.ExpiresAt) {
return State{}, false
return State{}, false, nil
}

return value, true
return value, true, nil
}

func (s *inMemoryStateStore) cleanupLocked() {
Expand Down
Loading