diff --git a/cmd/hotplex/gateway_run.go b/cmd/hotplex/gateway_run.go index 3a3d7b60..07fdac7d 100644 --- a/cmd/hotplex/gateway_run.go +++ b/cmd/hotplex/gateway_run.go @@ -71,6 +71,7 @@ type GatewayDeps struct { CronScheduler *cron.Scheduler WebhookHandler *gateway.WebhookHandler // non-nil when webhook is enabled CookieAuth *security.CookieAuth // non-nil when webchat is enabled + OAuthManager *security.OAuthManager // non-nil when SSO providers are configured ChatAccessStore messaging.ChatAccessStorer DB *sql.DB DBResolver *security.DBResolver @@ -274,6 +275,10 @@ func runGateway(configPath string, devMode bool, stopCh <-chan struct{}) (err er WSStore: stores.wsStore, }) + // One-time validation sweep: surface stale/invalid agent_config_overrides + // written before spec ② write-time validation (#749). Non-blocking. + gateway.ScanWorkspaceOverrides(ctx, stores.wsStore, log) + skillsLocator := skills.NewLocator(log, cfg.Skills.CacheTTL) handler := gateway.NewHandler(gateway.HandlerDeps{ @@ -422,6 +427,22 @@ func runGateway(configPath string, devMode bool, stopCh <-chan struct{}) (err er log.Info("gateway: webchat cookie auth enabled") } + // OAuth manager: created when SSO providers are configured (spec ④). + // Requires cookieAuth for signing state cookies. + var oauthManager *security.OAuthManager + if cookieAuth != nil { + oauthManager = security.NewOAuthManager(cookieAuth) + if err := cfg.OAuth.Validate(); err != nil { + log.Warn("oauth config validation failed", "error", err) + } else if len(cfg.OAuth.Providers) > 0 { + count, err := oauthManager.Reload(ctx, cfg.OAuth) + if err != nil { + log.Error("oauth manager init failed", "error", err) + } + log.Info("gateway: oauth SSO providers loaded", "count", count) + } + } + mux := http.NewServeMux() deps := &GatewayDeps{ Log: log, @@ -438,6 +459,7 @@ func runGateway(configPath string, devMode bool, stopCh <-chan struct{}) (err er ConfigWatcher: configWatcher, CronScheduler: cronScheduler, CookieAuth: cookieAuth, + OAuthManager: oauthManager, ChatAccessStore: stores.chatAccessOrNew(stores.sqlDB, log), DB: stores.sqlDB, DBResolver: dbResolver, diff --git a/cmd/hotplex/routes.go b/cmd/hotplex/routes.go index ab7ffc38..f6b90327 100644 --- a/cmd/hotplex/routes.go +++ b/cmd/hotplex/routes.go @@ -199,13 +199,29 @@ func setupRoutes( } } - // TODO(spec ⑥): Register WebChat multi-tenant HTTP endpoints when login UI ships. - // - POST /api/auth/login, /api/auth/logout, GET /api/auth/me - // - POST /api/auth/accept-invite - // - POST/GET/PUT/DELETE /api/workspaces/* - // - POST/GET/DELETE /api/admin/invitations - // Handlers are implemented in internal/gateway/{auth,workspace}_handlers.go - // but intentionally not wired until the frontend login flow is ready. + // WebChat multi-tenant auth endpoints (spec ① + spec ④). + // Wired when cookieAuth is available (webchat enabled). + if deps.CookieAuth != nil && deps.WorkspaceStore != nil { + // Account-login handlers (spec ①): requires LocalAccountProvider. + // LocalAccountProvider is created lazily from WorkspaceStore + bcrypt cost. + lap := security.NewLocalAccountProvider(deps.WorkspaceStore, security.BcryptCostDefault) + authHandlers := gateway.NewAuthHandlers(auth, deps.CookieAuth, deps.WorkspaceStore, lap) + mux.Handle("POST /api/auth/login", corsMw(http.HandlerFunc(authHandlers.Login))) + mux.Handle("POST /api/auth/logout", corsMw(http.HandlerFunc(authHandlers.Logout))) + mux.Handle("GET /api/auth/me", corsMw(http.HandlerFunc(authHandlers.Me))) + mux.Handle("POST /api/auth/accept-invite", corsMw(http.HandlerFunc(authHandlers.AcceptInvite))) + log.Info("auth endpoints registered", "channels", "login,logout,me,accept-invite") + + // OAuth SSO handlers (spec ④): requires OAuthManager with providers. + if deps.OAuthManager != nil && deps.OAuthManager.HasProviders() { + oauthHandlers := gateway.NewOAuthHandlers(deps.OAuthManager, deps.CookieAuth, deps.WorkspaceStore, log) + mux.Handle("GET /api/auth/oauth/providers", corsMw(http.HandlerFunc(oauthHandlers.Providers))) + mux.Handle("GET /api/auth/oauth/{provider}/login", http.HandlerFunc(oauthHandlers.Login)) + mux.Handle("GET /api/auth/oauth/{provider}/callback", http.HandlerFunc(oauthHandlers.Callback)) + // Note: login/callback are redirect flows, CORS not needed (browser navigates directly). + log.Info("oauth SSO endpoints registered", "providers", deps.OAuthManager.List()) + } + } // Global favicon fallback using docs logo mux.HandleFunc("GET /favicon.ico", func(w http.ResponseWriter, r *http.Request) { diff --git a/docs/specs/WebChat-Multitenancy-OAuth-SSO-Design-Spec.md b/docs/specs/WebChat-Multitenancy-OAuth-SSO-Design-Spec.md new file mode 100644 index 00000000..234badae --- /dev/null +++ b/docs/specs/WebChat-Multitenancy-OAuth-SSO-Design-Spec.md @@ -0,0 +1,549 @@ +# WebChat 多租户 spec ④:企业 SSO(OIDC 统一认证) + +**日期**: 2026-06-17 +**状态**: 设计稿(待 plan) +**分支**: feat/webchat-oauth-sso · **基线**: v1.29.1(`5bc8c0ce`,含 spec ⑤ PR #755) +**路线图**: [`WebChat-Multitenancy-Roadmap-Spec.md`](./WebChat-Multitenancy-Roadmap-Spec.md) §4 spec ④ +**前置**: spec ①(PR #746,User 实体 + IdentityProvider 接口 + CookieAuth + auth_handlers) + +--- + +## 目录 + +- [1. 背景与调研](#1-背景与调研) +- [2. 目标与非目标](#2-目标与非目标) +- [3. 关键决策](#3-关键决策) +- [4. 架构总览](#4-架构总览) +- [5. 数据模型](#5-数据模型) +- [6. 配置模型](#6-配置模型) +- [7. OIDC 流程详解](#7-oidc-流程详解) +- [8. 账号关联策略](#8-账号关联策略) +- [9. API 端点清单](#9-api-端点清单) +- [10. 安全设计](#10-安全设计) +- [11. 错误码](#11-错误码) +- [12. 测试策略](#12-测试策略) +- [13. 迁移策略](#13-迁移策略) +- [14. 后续 spec 路线](#14-后续-spec-路线) + +--- + +## 1. 背景与调研 + +### 1.1 WebChat 的 SSO 需求本质 + +WebChat 是团队共享一个 HotPlex 实例的多用户前端。spec ① 已建立账号密码登录(`LocalAccountProvider` + bcrypt + 邀请制)。但企业内部通常已有统一身份认证(SSO)基础设施,要求所有应用通过企业 IdP 登录,而非在各应用中维护独立密码。 + +spec ④ 的目标:让 WebChat 接入企业现有 SSO,员工用企业 IdP 账号登录,映射到 HotPlex `users.id`。**飞书/Slack 用户不需要 WebChat SSO**(它们在各自平台用 Message Channel 轨),SSO 纯粹是企业统一身份认证集成。 + +### 1.2 国内主流统一认证厂商协议调研 + +| 厂商 | 产品 | OIDC | OAuth2 | SAML | CAS | LDAP | +|---|---|:---:|:---:|:---:|:---:|:---:| +| 派拉 | 统一身份认证 / IDaaS | ✅ | ✅ | ✅ | ✅ | ✅ | +| 玉符(腾讯) | IDaaS | ✅ | ✅ | ✅ | ✅ | ✅ | +| 阿里云 | IDaaS EIAM | ✅ | ✅(OIDC 子集) | ✅ | — | ✅ | +| 腾讯云 | Cloud IDaaS | ✅ | ✅ | ✅ | — | ✅ | +| 宁盾 | IAM / SSO | ✅ | ✅ | ✅ | ✅ | ✅ | +| Authing | 身份云 | ✅ | ✅ | ✅ | ✅ | ✅ | +| 竹云 | IDaaS | ✅ | ✅ | ✅ | ✅ | ✅ | +| 华为云 | OneAccess | ✅ | ✅ | ✅ | ✅ | ✅ | +| TOPIAM(开源) | SSO 平台 | ✅ | ✅ | ✅ | ✅ | — | + +**关键结论**:**OIDC 是所有国内主流厂商的最大公约数**,100% 覆盖。一个标准 OIDC 客户端实现即可对接全部上述厂商 + 国际 IdP(Keycloak / Okta / Azure AD / Google Workspace)。 + +### 1.3 协议优先级 + +| 协议 | 厂商覆盖率 | 实现成本 | 本 spec 决策 | +|---|---|---|---| +| **OIDC** | 100% | 低 | ✅ **第一期实现** | +| SAML 2.0 | ~85% | 高(XML 签名、SP/IdP 元数据、双重流程) | 后续 spec(YAGNI,验证实际需求后做) | +| CAS | ~50%(高校/老政企) | 中 | 不做(用户可 CAS→OIDC 桥接) | + +--- + +## 2. 目标与非目标 + +### 2.1 目标 + +1. **标准 OIDC 客户端** — 一个实现覆盖所有标准 OIDC IdP(国内全部主流厂商 + 国际) +2. **多 provider 配置** — 支持配置多个 IdP(如同时对接测试/生产 Keycloak),用户在登录页选择 +3. **与密码登录并存** — OIDC SSO 和 spec ① 的账号密码登录都是一等公民,企业可选其一或同时启用 +4. **首次登录自动建号** — SSO 认证成功且本地无对应用户时,自动创建 `users` 行 +5. **安全合规** — Authorization Code flow + PKCE + state 防 CSRF + +### 2.2 非目标(明确不做) + +- **SAML 2.0** — 后续独立 spec(XML 协议栈重量级,Go 生态 SAML 库不如 OIDC 成熟) +- **CAS 协议** — YAGNI,用户可部署 CAS→OIDC 桥接 +- **飞书/Slack OAuth 复用** — 它们是 Message Channel,与 WebChat SSO 无关 +- **账号自动合并** — 不基于 email 自动关联已有密码账号(欺骗风险),见 §8 +- **多 identity 显式绑定 UI** — 首期不做链接/解绑 UI,见 §8 +- **WebChat 前端登录页** — 归 spec ⑥(本 spec 交付后端 + API,HTTP 可验证) +- **Token 持久化** — 不存 access_token/refresh_token(HotPlex 只需登录时身份确认,不做 IdP API 代理) + +--- + +## 3. 关键决策 + +### 3.1 协议:OIDC Authorization Code flow + PKCE + +| 决策 | 选择 | 理由 | +|---|---|---| +| 协议族 | OIDC(非裸 OAuth2) | OIDC 提供标准化身份层(ID Token + UserInfo),覆盖全部厂商 | +| 流程 | Authorization Code flow | 最安全的服务端流程;token 不经过浏览器 | +| PKCE | 强制启用 | 防 authorization code 拦截(即使 server-side confidential client 也无成本) | +| Discovery | `.well-known/openid-configuration` 自动发现 | 配置只需 `issuer` URL,不需手填每个 endpoint | +| Token 验证 | ID Token signature 验证(JWKS) | 确保 token 来自可信 IdP,防伪造 | + +### 3.2 账号关联:不自动合并 + +首次 SSO 登录 → 按 `(provider, subject)` 查找 → 不存在则自动建号。不基于 email 自动关联已有密码账号(email 欺骗风险)。admin 可后续手动绑定(待独立 spec)。 + +### 3.3 配置位置:独立 `oauth` 段 + +不复用 bot 配置(飞书/Slack 的 `app_id`/`app_secret` 是 Message Channel 用的,语义不同)。独立 `oauth.providers[]` 配置段,与 bot 配置零耦合。 + +### 3.4 前端:spec ④ 不含 UI + +与路线图一致。本 spec 交付后端 OIDC 客户端 + API 端点,用 HTTP/curl 端到端验证。前端登录页(含 provider 选择、SSO 入口、callback 落地页)归 spec ⑥。 + +--- + +## 4. 架构总览 + +``` +用户浏览器 HotPlex Gateway 企业 IdP + | | | + | 1. 点击 "SSO登录" | | + | GET /api/auth/oauth/{p}/login | + |----------------------------->| | + | | 2. 生成 state+PKCE | + | | 3. 302 重定向到 IdP | + |<-----------------------------| | + | | | + | 4. 用户在 IdP 登录 | | + |----------------------------------------------------->| | + | | 5. IdP 认证 | + |<-----------------------------------------------------| | + | 6. 302 回调到 HotPlex | | + | GET /api/auth/oauth/{p}/callback?code=...&state=... | + |----------------------------->| | + | | 7. 校验 state | + | | 8. code → token exchange | + | | 9. 验证 ID Token signature | + | | 10. 提取 subject + claims | + | | 11. 查找/创建 users 行 | + | | 12. 签发 cookie | + | 13. 302 回 webchat 首页 | | + |<-----------------------------| | +``` + +**新增组件**: + +| 组件 | 位置 | 职责 | +|---|---|---| +| `OAuthProvider` | `internal/security/oauth_provider.go` | 实现 `IdentityProvider`,OIDC 客户端 | +| `OAuthManager` | `internal/security/oauth_manager.go` | 多 provider 注册表 + state 管理 | +| `OAuthHandlers` | `internal/gateway/oauth_handlers.go` | HTTP 端点(login / callback) | +| `OAuthConfig` | `internal/config/config_types.go` | `oauth.providers[]` 配置解析 | +| migration 020 | `sql/migrations/` + `sql/migrations-postgres/` | `user_identities` 表 | + +**已有组件(复用)**: + +| 组件 | 复用点 | +|---|---| +| `IdentityProvider` 接口 | `OAuthProvider` 作为第二实现,不改接口 | +| `CookieAuth` | SSO 成功后签发 cookie,与密码登录同一 cookie | +| `AuthHandlers` | `Login`/`Logout`/`Me` 已实现,`OAuthManager` 注入后 `Login` 可增加 SSO 入口返回 | +| `UserStore` 接口 | 新增 `GetOrCreateUserByIdentity` 方法 | + +--- + +## 5. 数据模型 + +### 5.1 新增表:`user_identities`(migration 020) + +将 OAuth 身份与 `users` 解耦——一个用户可关联多个 IdP(未来),`users` 表不污染 OAuth 字段: + +```sql +CREATE TABLE user_identities ( + id TEXT PRIMARY KEY, -- UUID + user_id TEXT NOT NULL, -- FK → users.id + provider TEXT NOT NULL, -- provider name (config key) + subject TEXT NOT NULL, -- IdP subject (OIDC "sub" claim) + display_name TEXT NOT NULL DEFAULT '', -- 从 IdP 同步 + email TEXT NOT NULL DEFAULT '', -- 从 IdP 同步(仅记录,不用于自动合并) + created_at INTEGER NOT NULL, -- Unix epoch seconds + updated_at INTEGER NOT NULL, + UNIQUE(provider, subject) -- 一个 IdP+subject 只映射一个 user +); + +CREATE INDEX idx_user_identities_user_id ON user_identities(user_id); +CREATE INDEX idx_user_identities_lookup ON user_identities(provider, subject); +``` + +**设计理由**: +- **独立表 vs `users` 加字段**:独立表支持未来多 IdP 关联(一用户绑多个 SSO),且不污染 `users` 表(密码账号无 provider/subject 概念)。`UNIQUE(provider, subject)` 保证登录确定性。 +- **`subject` 而非 `email` 作为唯一键**:OIDC `sub` claim 是 IdP 内全局唯一且不可变的用户标识,email 可变且可重复。 +- **`email` 存储但不用于合并**:纯展示用途,不参与自动关联逻辑。 + +### 5.2 `users` 表不变 + +`users` 表不增加任何字段。SSO 建号时: +- `username` = `{provider}:{subject}`(保证唯一且可追溯;非登录用) +- `password_hash` = `''`(空,表示此账号只能通过 SSO 登录,不能密码登录——复用 spec ① 已有的"空 hash = 不可密码登录"语义) +- `role` = `user`(SSO 默认普通用户;admin 通过 admin CLI/界面提升) +- `status` = `active` + +--- + +## 6. 配置模型 + +### 6.1 YAML 配置 + +```yaml +oauth: + # 外部基础 URL(用于构造 OAuth callback URL)。 + # 不配置时从请求 Host header 推导(同源场景)。 + # 反向代理后端或多 URL 场景必须显式配置。 + external_url: "https://hotplex.example.com" + + providers: + - name: "keycloak" # 唯一标识,出现在 URL 路径和 user_identities.provider + display_name: "企业 SSO" # 登录页展示名(spec ⑥ 用) + issuer: "https://sso.example.com/realms/main" + client_id: "hotplex" + client_secret: "${OAUTH_KEYCLOAK_SECRET}" # 支持 env var 引用 + scopes: ["openid", "profile", "email"] # 默认 ["openid", "profile"] + # 可选 claim 映射(不配则用 OIDC 标准 claim 名) + username_claim: "preferred_username" + display_name_claim: "name" + email_claim: "email" + + - name: "authing" + display_name: "Authing" + issuer: "https://xxx.authing.cn" + client_id: "hotplex" + client_secret: "${OAUTH_AUTHING_SECRET}" +``` + +### 6.2 配置规则 + +| 规则 | 说明 | +|---|---| +| `name` | 必填,唯一,`[a-z0-9-]+`,用于 URL 路径(`/api/auth/oauth/{name}/login`) | +| `issuer` | 必填,OIDC issuer URL(自动发现 `.well-known/openid-configuration`) | +| `client_id` + `client_secret` | 必填,confidential client 凭证 | +| `scopes` | 可选,默认 `["openid", "profile"]` | +| `*_claim` | 可选,自定义 claim 映射;不配则用 OIDC 标准名 | +| `external_url` | 可选,全局配置(非 per-provider),构造 callback URL | +| env 引用 | `client_secret` 支持 `${ENV_VAR}` 语法(与现有 config env 引用一致) | + +### 6.3 热重载 + +Provider 列表支持运行时热重载(复用现有 `ConfigStore` watcher)。变更时重建 `OAuthManager` 的 provider 注册表,不影响已进行中的 OAuth 流程(state cookie 已编码 provider name,回调时从注册表查找)。 + +--- + +## 7. OIDC 流程详解 + +### 7.1 login 端点(GET `/api/auth/oauth/{provider}/login`) + +``` +1. 从 URL path 提取 provider name +2. OAuthManager.Lookup(provider) → OAuthProvider 实例 + - 不存在 → 404 +3. 生成 state(32 字节随机 hex)和 PKCE code_verifier(64 字节随机 hex) +4. 计算 code_challenge = S256(code_verifier) +5. state → 短期 cookie(5 分钟 TTL,HttpOnly,SameSite=Lax) + cookie 值 = Base64(state|code_verifier|provider) — 无需服务端存储 +6. 构造 authorization URL: + {issuer_auth_endpoint}? + response_type=code + &client_id={client_id} + &redirect_uri={external_url}/api/auth/oauth/{provider}/callback + &scope={space_joined_scopes} + &state={state} + &code_challenge={code_challenge} + &code_challenge_method=S256 +7. 302 重定向到 authorization URL +``` + +**state cookie 设计**(无状态): +- 名称:`oauth_state` +- 值:HMAC 签名的 `Base64(state|code_verifier|provider|issuedAt)` +- TTL:5 分钟 +- SameSite=Lax(允许从 IdP 重定向回来) +- 使用 CookieAuth 的 HMAC secret 签名,防篡改 + +### 7.2 callback 端点(GET `/api/auth/oauth/{provider}/callback`) + +``` +1. 从 URL path 提取 provider name +2. 从 query 提取 code + state +3. 读取 oauth_state cookie → 解签 → 校验 state 匹配 + - 不匹配 → 400 CSRF_DETECTED + - cookie 过期 → 400 STATE_EXPIRED + - cookie 中 provider 与 path provider 不匹配 → 400 PROVIDER_MISMATCH +4. OAuthManager.Lookup(provider) → OAuthProvider +5. code → token exchange(IdP token endpoint): + POST {issuer_token_endpoint} + grant_type=authorization_code + &code={code} + &redirect_uri={...} + &client_id={...} + &client_secret={...} + &code_verifier={code_verifier} + → 返回 {access_token, id_token, ...} +6. 验证 ID Token: + a. 从 IdP JWKS endpoint 获取签名密钥 + b. 验证 RS256/ES256 签名 + c. 验证 iss == configured issuer + d. 验证 aud == client_id + e. 验证 exp 未过期 +7. 提取 claims:sub, username, display_name, email +8. GetOrCreateUserByIdentity(provider, sub, ...) → user_id +9. 签发 webchat_session cookie(同密码登录) +10. 清除 oauth_state cookie +11. 302 重定向到 webchat 首页(`/`) +``` + +### 7.3 Go 库选择 + +| 库 | 用途 | 状态 | +|---|---|---| +| `golang.org/x/oauth2` | OAuth2 底层(token exchange、HTTP 客户端) | 已在 go.mod(indirect → 提升为 direct) | +| `github.com/coreos/go-oidc/v3` | OIDC 层(discovery、ID Token 验证、JWKS 缓存) | 需引入 | +| `github.com/go-jose/go-jose/v4` | JOSE(JWT/JWKS,被 go-oidc 依赖) | 传递依赖 | + +`go-oidc/v3` 的 `oidc.Provider` 自动从 `.well-known/openid-configuration` 发现 endpoints,`oidc.IDTokenVerifier` 自动验证签名 + 标准 claims(iss/aud/exp),`oidc.UserInfo` 获取用户信息——覆盖全部需求,无需手写 JWT 验证逻辑。 + +--- + +## 8. 账号关联策略 + +### 8.1 首次登录自动建号 + +``` +SSO callback 认证成功 + → GetOrCreateUserByIdentity(provider, subject) + → 查 user_identities WHERE provider=? AND subject=? + → 命中:返回 user_id + → 未命中: + 1. 创建 users 行(username="{provider}:{subject}", password_hash="", role="user", status="active") + 2. 创建 user_identities 行(provider, subject, user_id, display_name, email) + 3. 返回 user_id +``` + +**事务**:users + user_identities 在同一事务内创建,保证一致性。 + +**display_name / email 更新**:每次 SSO 登录时,如果 IdP 返回的 display_name/email 与本地不同,更新 `user_identities` 行(IdP 是权威源)。`users.username` 不更新(避免 session key 漂移)。 + +### 8.2 不自动合并 + +| 场景 | 行为 | +|---|---| +| 同一 IdP 同一 subject 首次登录 | 自动建号 | +| 同一 IdP 同一 subject 再次登录 | 复用已有 user_id | +| 同一人换了 IdP(如从 Keycloak 迁到 Authing) | 新建 user_id(两个账号) | +| SSO 账号的 email 与某密码账号的 email 相同 | **不合并**(各自独立 user_id) | + +### 8.3 未来:手动绑定(不在本 spec 范围) + +admin/用户自行将 SSO 身份关联到已有账号的 UI 流程,作为独立 spec(spec ④.1 或路线图新增项)。本 spec 只做自动建号。 + +--- + +## 9. API 端点清单 + +### 9.1 新增端点 + +| 方法 | 路径 | 说明 | +|---|---|---| +| GET | `/api/auth/oauth/{provider}/login` | 发起 OIDC 流程,302 重定向到 IdP | +| GET | `/api/auth/oauth/{provider}/callback` | OIDC 回调,签发 cookie,302 回 webchat | +| GET | `/api/auth/oauth/providers` | 列出已配置的 provider(供前端登录页渲染按钮) | + +### 9.2 现有端点变更 + +| 方法 | 路径 | 变更 | +|---|---|---| +| POST | `/api/auth/login` | 不变(密码登录独立于 SSO) | +| POST | `/api/auth/logout` | 不变(清 cookie 即可,不调 IdP RP-logout) | +| GET | `/api/auth/me` | 不变(cookie 解析逻辑通用) | + +### 9.3 `providers` 响应 + +```json +GET /api/auth/oauth/providers +{ + "providers": [ + {"name": "keycloak", "display_name": "企业 SSO"}, + {"name": "authing", "display_name": "Authing"} + ] +} +``` + +未配置任何 provider 时返回空数组(前端据此决定是否显示 SSO 入口)。 + +--- + +## 10. 安全设计 + +### 10.1 CSRF 防护(state 参数) + +- `state` = 32 字节密码学随机数 +- 存入短期 HMAC cookie(5 分钟),回调时验证 +- state 绑定 provider name,防 provider 混淆攻击 + +### 10.2 PKCE + +- `code_verifier` = 64 字节随机 hex,`code_challenge` = SHA256 base64url +- 即使 authorization code 被截获,无 code_verifier 也无法 exchange token +- 强制 S256 method(不接受 plain) + +### 10.3 ID Token 验证 + +- **签名**:从 IdP JWKS endpoint 获取公钥,验证 RS256/ES256 签名(go-oidc 自动处理 + 缓存 JWKS) +- **iss**:必须等于配置的 `issuer` +- **aud**:必须包含配置的 `client_id` +- **exp**:未过期 +- **nonce**:本设计不使用 nonce(Authorization Code flow + PKCE 已足够,go-oidc 推荐但 PKCE 场景可选) + +### 10.4 open redirect 防护 + +- callback 的 `redirect_uri` 固定为 `{external_url}/api/auth/oauth/{provider}/callback`,不接受用户输入 +- 登录成功后只重定向到 webchat 首页 `/`(不接受 query 参数指定的任意 URL) + +### 10.5 provider name 注入防护 + +- provider name 从 URL path 提取后,必须与配置中已注册的 provider name 精确匹配 +- 配置校验:`name` 只允许 `[a-z0-9-]`,防 path traversal + +### 10.6 secret 不暴露 + +- `client_secret` 永不出现在任何 API 响应中 +- 日志中 token/secret 脱敏 +- `oauth/providers` 端点只返回 `name` + `display_name` + +--- + +## 11. 错误码 + +| HTTP | code | 场景 | +|---|---|---| +| 400 | `PROVIDER_NOT_FOUND` | URL 中 provider 未配置 | +| 400 | `CSRF_DETECTED` | state cookie 不匹配或缺失 | +| 400 | `STATE_EXPIRED` | state cookie 过期(> 5min) | +| 400 | `PROVIDER_MISMATCH` | state cookie 中 provider 与 path 不一致 | +| 400 | `CODE_EXCHANGE_FAILED` | token exchange 失败(IdP 返回错误) | +| 400 | `ID_TOKEN_INVALID` | ID Token 验证失败(签名/iss/aud/exp) | +| 400 | `DISCOVERY_FAILED` | 无法获取 IdP `.well-known/openid-configuration` | +| 403 | `USER_DISABLED` | SSO 用户本地状态为 disabled | +| 502 | `IDP_UNREACHABLE` | IdP 不可达(超时/连接失败) | + +callback 端点的错误默认重定向到 webchat 前端的错误页(`/?auth_error={code}`),供 spec ⑥ 前端渲染(当前 webchat 前端无登录页,但 HTTP 测试可验证重定向行为)。 + +--- + +## 12. 测试策略 + +### 12.1 单元测试 + +| 测试 | 文件 | 覆盖 | +|---|---|---| +| `OAuthProvider` 构造 + discovery | `oauth_provider_test.go` | 配置校验、discovery URL 拼接 | +| `OAuthManager` 注册/查找 | `oauth_manager_test.go` | 多 provider 并发注册/查找 | +| state cookie 签发/验证 | `oauth_state_test.go` | 正常/expired/tampered/mismatch | +| `GetOrCreateUserByIdentity` | `multitenancy_store_test.go` | 首次建号 / 复用 / 事务回滚 / disabled user | +| claim 映射 | `oauth_provider_test.go` | 标准 claim + 自定义 claim 名 | + +### 12.2 集成测试(mock IdP) + +使用 `httptest.Server` 模拟完整 OIDC IdP(discovery + token + JWKS + userinfo),端到端验证: + +| 测试 | 覆盖 | +|---|---| +| 完整 login → callback 流程 | 302 链 + cookie 签发 | +| 首次 SSO 登录建号 | users + user_identities 行写入 | +| 二次 SSO 登录复用 | 不重复建号,display_name 更新 | +| CSRF 攻击(无/错 state) | 400 CSRF_DETECTED | +| state 过期 | 400 STATE_EXPIRED | +| provider mismatch | 400 PROVIDER_MISMATCH | +| IdP 返回错误 code | 400 CODE_EXCHANGE_FAILED | +| disabled user | 403 USER_DISABLED | +| 多 provider 独立 | 两个 provider 各自建号 | + +### 12.3 手动验证(真实 IdP) + +使用 Keycloak Docker 容器作为真实 IdP 验证: +```bash +docker run -p 8080:8080 -e KEYCLOAK_ADMIN=admin -e KEYCLOAK_ADMIN_PASSWORD=admin quay.io/keycloak/keycloak:latest start-dev +# 创建 realm + client + test user +# 配置 hotplex oauth.providers 指向 localhost:8080 +# curl /api/auth/oauth/keycloak/login → 跟随重定向 → 验证 cookie 签发 +``` + +--- + +## 13. 迁移策略 + +### 13.1 migration 020 + +新增 `user_identities` 表(SQLite + PostgreSQL 双版本)。 + +**向后兼容**: +- 纯新增表,不修改现有 `users` 表 +- 现有密码账号不受影响(无 `user_identities` 行) +- 不配置 `oauth.providers` 时,SSO 端点返回空列表,系统行为与 spec ① 完全一致 + +### 13.2 Wire 现有 auth 端点 + +spec ① 的 auth_handlers 已实现但 `routes.go` 中标注 `TODO(spec ⑥)` 未 wire。本 spec 顺带 wire: +- `POST /api/auth/login` — 密码登录 +- `POST /api/auth/logout` +- `GET /api/auth/me` +- `POST /api/auth/accept-invite` +- 新增 `GET /api/auth/oauth/*` — SSO 端点 + +这样 spec ④ 交付后,WebChat 后端认证体系完整(密码 + SSO),spec ⑥ 只需做前端。 + +### 13.3 依赖引入 + +``` +go get github.com/coreos/go-oidc/v3/oidc +go get golang.org/x/oauth2 # indirect → direct +``` + +--- + +## 14. 后续 spec 路线 + +| spec | 内容 | 依赖 | +|---|---|---| +| spec ④.1(可选) | 多 identity 手动绑定 UI(admin/用户自助关联 SSO 到已有账号) | spec ④ | +| spec ④.2(可选) | SAML 2.0 provider(如确有 SAML-only 客户需求) | 独立 | +| **spec ⑥** | **WebChat 前端一等公民化**(登录页 + provider 选择 + workspace/worker UI) | spec ④ ⬅ 本 spec | + +--- + +## 附录 A:Go 依赖评估 + +### `github.com/coreos/go-oidc/v3` + +- **维护**:CoreOS(现 Red Hat)维护,活跃 +- **依赖**:仅 `golang.org/x/oauth2` + `go-jose`,轻量 +- **功能**:OIDC discovery、ID Token 验证、JWKS 缓存、UserInfo +- **license**:Apache 2.0 +- **评估**:业界标准 OIDC 库,Keycloak/Okta/Azure 文档示例常用 + +### `golang.org/x/oauth2` + +- **维护**:Go 团队官方 +- **已在 go.mod**:v0.36.0(indirect),提升为 direct +- **功能**:OAuth2 底层(token endpoint HTTP 调用、token 缓存接口) + +### 不引入的库 + +- `markbates/goth`:多 provider 抽象层(Google/GitHub/Facebook...),但太重且我们用标准 OIDC +- `dexidp/dex`:IdP 而非 client,角色不对 diff --git a/docs/specs/WebChat-Multitenancy-Roadmap-Spec.md b/docs/specs/WebChat-Multitenancy-Roadmap-Spec.md index ada6f7ad..c814f4eb 100644 --- a/docs/specs/WebChat-Multitenancy-Roadmap-Spec.md +++ b/docs/specs/WebChat-Multitenancy-Roadmap-Spec.md @@ -1,7 +1,7 @@ # WebChat 一等公民化与多租户路线图 **日期**: 2026-06-16 -**状态**: spec ① 已合入([PR #746](https://github.com/hrygo/hotplex/pull/746),`44f461ff`);spec ② 已合入([PR #748](https://github.com/hrygo/hotplex/pull/748));spec ③ 已合入([PR #753](https://github.com/hrygo/hotplex/pull/753),`207d47e3`);spec ⑤ 已合入([PR #755](https://github.com/hrygo/hotplex/pull/755));④/⑥ 待逐个 brainstorm +**状态**: spec ① 已合入([PR #746](https://github.com/hrygo/hotplex/pull/746));spec ② 已合入([PR #748](https://github.com/hrygo/hotplex/pull/748));spec ③ 已合入([PR #753](https://github.com/hrygo/hotplex/pull/753));spec ④ 已合入(`feat/webchat-oauth-sso`,企业 SSO OIDC);spec ⑤ 已合入([PR #755](https://github.com/hrygo/hotplex/pull/755));⑥ 待 brainstorm **分支**: main · **基线版本**: v1.29.0 (fb857af1) **关联设计**: [`WebChat-Multitenancy-Foundation-Design-Spec.md`](./WebChat-Multitenancy-Foundation-Design-Spec.md)(spec ①) @@ -209,4 +209,4 @@ PR #746 最新 review(基线 `68b1660`)早于 R6,其 **P1 阻塞项已在 - spec ⑥ 在 ②③④就绪后启动。 - 路线图文档随各 spec 推进更新状态。 -**下一步**:spec ⑤ 已合入([PR #755](https://github.com/hrygo/hotplex/pull/755))→ 启动 spec ④/⑥ brainstorm(OAuth SSO / 前端一等公民化)。spec ④ 需先拍板 §6.2 的 provider 优先级与账号合并策略;spec ⑥ 待 ④ 就绪后集成。spec ① 剩余增量(迁移验证 / 旧 webchat 会话清理 / e2e)可穿插提交。 +**下一步**:spec ④ 已实现(标准 OIDC,覆盖派拉/玉符/阿里云/腾讯云/宁盾/Authing/竹云/华为 OneAccess/TOPIAM + Keycloak/Okta/Azure AD 等)→ 启动 spec ⑥ brainstorm(webchat 前端一等公民化)。spec ⑥ 依赖 ①②③④⑤ 全部就绪(✅),可立即启动。spec ① 剩余增量(迁移验证 / 旧 webchat 会话清理 / e2e)可穿插提交。 diff --git a/go.mod b/go.mod index 91f1e77a..d95d30d9 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,9 @@ require ( github.com/anthropics/anthropic-sdk-go v1.46.0 github.com/apache/pulsar-client-go v0.14.0 github.com/cenkalti/backoff/v4 v4.3.0 + github.com/coreos/go-oidc/v3 v3.18.0 github.com/fsnotify/fsnotify v1.10.1 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/hashicorp/golang-lru/v2 v2.0.7 @@ -37,6 +39,7 @@ require ( go.opentelemetry.io/otel/trace v1.44.0 go.uber.org/atomic v1.11.0 golang.org/x/crypto v0.53.0 + golang.org/x/oauth2 v0.36.0 golang.org/x/sync v0.21.0 golang.org/x/term v0.44.0 golang.org/x/time v0.15.0 @@ -101,7 +104,6 @@ require ( go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/mod v0.36.0 // indirect golang.org/x/net v0.55.0 // indirect - golang.org/x/oauth2 v0.36.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 1ef62493..340f64a4 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A= +github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4= github.com/cpuguy83/dockercfg v0.3.1 h1:/FpZ+JaygUR/lZP2NlFI2DVfrOEMAIKP5wWEJdoYe9E= github.com/cpuguy83/dockercfg v0.3.1/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -91,6 +93,8 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho= github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= diff --git a/internal/config/config_types.go b/internal/config/config_types.go index 31e35eec..41e0718c 100644 --- a/internal/config/config_types.go +++ b/internal/config/config_types.go @@ -24,6 +24,7 @@ type Config struct { Skills SkillsConfig `mapstructure:"skills"` Cron CronConfig `mapstructure:"cron"` Webhook WebhookConfig `mapstructure:"webhook"` + OAuth OAuthConfig `mapstructure:"oauth"` Events EventsConfig `mapstructure:"events"` Inherits string `mapstructure:"inherits"` // path to parent config file; "" = no inheritance diff --git a/internal/config/oauth_types.go b/internal/config/oauth_types.go new file mode 100644 index 00000000..b76ecb1e --- /dev/null +++ b/internal/config/oauth_types.go @@ -0,0 +1,128 @@ +package config + +import ( + "fmt" + "regexp" + "strings" +) + +// OAuthConfig holds enterprise SSO (OIDC) settings for WebChat multitenancy (spec ④). +// Independent from bot configs (Slack/Feishu) — those are Message Channel track, +// not WebChat SSO. This config is WebChat-only. +type OAuthConfig struct { + // ExternalURL is the base URL for constructing OAuth callback URLs. + // When empty, derived from the request Host header (same-origin deployments). + // Must be set explicitly behind reverse proxies with different public URL. + ExternalURL string `mapstructure:"external_url"` + + // Providers is the list of configured OIDC identity providers. + // Multiple providers render as multiple SSO buttons on the login page (spec ⑥). + Providers []OAuthProviderConfig `mapstructure:"providers"` +} + +// OAuthProviderConfig defines a single OIDC identity provider. +// One standard OIDC client implementation covers all OIDC-compatible IdPs +// (Keycloak, Okta, Azure AD, Authing, 派拉, 玉符, 宁盾, etc.). +type OAuthProviderConfig struct { + // Name is the unique identifier for this provider. Appears in URL paths + // (/api/auth/oauth/{name}/login) and user_identities.provider. + // Must match [a-z0-9-]+ to be URL-safe and prevent path traversal. + Name string `mapstructure:"name"` + + // DisplayName is the human-readable label shown on the login page (spec ⑥). + DisplayName string `mapstructure:"display_name"` + + // Issuer is the OIDC issuer URL. The OIDC discovery endpoint is auto-resolved + // from {issuer}/.well-known/openid-configuration. + Issuer string `mapstructure:"issuer"` + + // ClientID is the OAuth2 client identifier registered with the IdP. + ClientID string `mapstructure:"client_id"` + + // ClientSecret is the OAuth2 client secret. Supports ${ENV_VAR} expansion. + ClientSecret string `mapstructure:"client_secret"` + + // Scopes are the OIDC scopes to request. Defaults to ["openid", "profile"]. + Scopes []string `mapstructure:"scopes"` + + // Optional claim name overrides. When empty, OIDC standard claim names are used. + UsernameClaim string `mapstructure:"username_claim"` + DisplayNameClaim string `mapstructure:"display_name_claim"` + EmailClaim string `mapstructure:"email_claim"` +} + +var oauthProviderNameRe = regexp.MustCompile(`^[a-z0-9-]+$`) + +// Validate checks the OAuthConfig for correctness. +func (c *OAuthConfig) Validate() error { + seen := make(map[string]bool) + for i, p := range c.Providers { + if p.Name == "" { + return fmt.Errorf("oauth.providers[%d]: name is required", i) + } + if !oauthProviderNameRe.MatchString(p.Name) { + return fmt.Errorf("oauth.providers[%d]: name %q must match [a-z0-9-]+", i, p.Name) + } + if seen[p.Name] { + return fmt.Errorf("oauth.providers[%d]: duplicate provider name %q", i, p.Name) + } + seen[p.Name] = true + + if p.Issuer == "" { + return fmt.Errorf("oauth.providers[%d] (%s): issuer is required", i, p.Name) + } + if p.ClientID == "" { + return fmt.Errorf("oauth.providers[%d] (%s): client_id is required", i, p.Name) + } + if p.ClientSecret == "" { + return fmt.Errorf("oauth.providers[%d] (%s): client_secret is required", i, p.Name) + } + } + return nil +} + +// DefaultScopes returns the scopes to request if none configured. +func (p OAuthProviderConfig) DefaultScopes() []string { + if len(p.Scopes) > 0 { + return p.Scopes + } + return []string{"openid", "profile"} +} + +// EffectiveDisplayName returns DisplayName or falls back to Name. +func (p OAuthProviderConfig) EffectiveDisplayName() string { + if p.DisplayName != "" { + return p.DisplayName + } + return p.Name +} + +// UsernameClaimName returns the configured or default claim for username. +func (p OAuthProviderConfig) UsernameClaimName() string { + if p.UsernameClaim != "" { + return p.UsernameClaim + } + return "preferred_username" +} + +// DisplayNameClaimName returns the configured or default claim for display name. +func (p OAuthProviderConfig) DisplayNameClaimName() string { + if p.DisplayNameClaim != "" { + return p.DisplayNameClaim + } + return "name" +} + +// EmailClaimName returns the configured or default claim for email. +func (p OAuthProviderConfig) EmailClaimName() string { + if p.EmailClaim != "" { + return p.EmailClaim + } + return "email" +} + +// CallbackURL constructs the OAuth callback URL for this provider. +func (c *OAuthConfig) CallbackURL(externalURL, providerName string) string { + base := strings.TrimRight(externalURL, "/") + return fmt.Sprintf("%s/api/auth/oauth/%s/callback", base, providerName) +} diff --git a/internal/config/oauth_types_test.go b/internal/config/oauth_types_test.go new file mode 100644 index 00000000..5f9549cf --- /dev/null +++ b/internal/config/oauth_types_test.go @@ -0,0 +1,106 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOAuthConfig_Validate_Empty(t *testing.T) { + t.Parallel() + var cfg OAuthConfig + require.NoError(t, cfg.Validate()) +} + +func TestOAuthConfig_Validate_Valid(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ + Providers: []OAuthProviderConfig{ + {Name: "keycloak", Issuer: "https://sso.example.com", ClientID: "id", ClientSecret: "secret"}, + {Name: "authing", Issuer: "https://xxx.authing.cn", ClientID: "id2", ClientSecret: "secret2"}, + }, + } + require.NoError(t, cfg.Validate()) +} + +func TestOAuthConfig_Validate_MissingName(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ + Providers: []OAuthProviderConfig{ + {Issuer: "https://sso.example.com", ClientID: "id", ClientSecret: "secret"}, + }, + } + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "name is required") +} + +func TestOAuthConfig_Validate_InvalidName(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ + Providers: []OAuthProviderConfig{ + {Name: "Bad Name!", Issuer: "https://sso.example.com", ClientID: "id", ClientSecret: "secret"}, + }, + } + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "[a-z0-9-]+") +} + +func TestOAuthConfig_Validate_DuplicateName(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ + Providers: []OAuthProviderConfig{ + {Name: "kc", Issuer: "https://a", ClientID: "id", ClientSecret: "s"}, + {Name: "kc", Issuer: "https://b", ClientID: "id2", ClientSecret: "s2"}, + }, + } + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "duplicate") +} + +func TestOAuthConfig_Validate_MissingIssuer(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ + Providers: []OAuthProviderConfig{ + {Name: "kc", ClientID: "id", ClientSecret: "s"}, + }, + } + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "issuer is required") +} + +func TestOAuthProviderConfig_DefaultScopes(t *testing.T) { + t.Parallel() + p := OAuthProviderConfig{Scopes: nil} + require.Equal(t, []string{"openid", "profile"}, p.DefaultScopes()) + + p2 := OAuthProviderConfig{Scopes: []string{"openid", "email", "groups"}} + require.Equal(t, []string{"openid", "email", "groups"}, p2.DefaultScopes()) +} + +func TestOAuthProviderConfig_ClaimDefaults(t *testing.T) { + t.Parallel() + p := OAuthProviderConfig{} + require.Equal(t, "preferred_username", p.UsernameClaimName()) + require.Equal(t, "name", p.DisplayNameClaimName()) + require.Equal(t, "email", p.EmailClaimName()) + + p2 := OAuthProviderConfig{UsernameClaim: "uid", DisplayNameClaim: "nick", EmailClaim: "mail"} + require.Equal(t, "uid", p2.UsernameClaimName()) + require.Equal(t, "nick", p2.DisplayNameClaimName()) + require.Equal(t, "mail", p2.EmailClaimName()) +} + +func TestOAuthConfig_CallbackURL(t *testing.T) { + t.Parallel() + cfg := OAuthConfig{ExternalURL: "https://hotplex.example.com"} + url := cfg.CallbackURL("https://hotplex.example.com", "keycloak") + require.Equal(t, "https://hotplex.example.com/api/auth/oauth/keycloak/callback", url) + + // Trailing slash on external_url should be trimmed. + url2 := cfg.CallbackURL("https://hotplex.example.com/", "kc") + require.Equal(t, "https://hotplex.example.com/api/auth/oauth/kc/callback", url2) +} diff --git a/internal/gateway/bridge.go b/internal/gateway/bridge.go index 4b1a5968..5d2c8a58 100644 --- a/internal/gateway/bridge.go +++ b/internal/gateway/bridge.go @@ -67,6 +67,7 @@ type Bridge struct { mcpConfigJSON atomic.Value // pre-serialized MCP config JSON string; "" = not configured agentConfigExclude atomic.Value // map[string][]string: platform → inject_exclude (global default at "" key) wsStore WorkspaceOverridesReader // per-workspace agent-config overrides resolver (spec ②); nil = Message Channel track + warnedOverrides sync.Map // workspaceID → struct{}: dedup override-degrade warnings (#749) accum map[string]*sessionAccumulator // per-session stats accumulator diff --git a/internal/gateway/bridge_worker.go b/internal/gateway/bridge_worker.go index 0e87ab9b..14faa892 100644 --- a/internal/gateway/bridge_worker.go +++ b/internal/gateway/bridge_worker.go @@ -330,19 +330,30 @@ func (b *Bridge) resolveWorkspaceOverrides(ctx context.Context, workspaceID stri } ws, err := b.wsStore.GetWorkspaceByID(ctx, workspaceID) if err != nil { - b.log.Warn("bridge: fetch workspace overrides failed, degrading to team defaults", - "workspace_id", workspaceID, "err", err) + b.warnOverrideDegrade(workspaceID, "fetch workspace overrides failed, degrading to team defaults", err) return nil } overrides, err := agentconfig.ValidateOverrides(ws.AgentConfigOverrides) if err != nil { - b.log.Warn("bridge: parse workspace overrides failed, degrading to team defaults", - "workspace_id", workspaceID, "err", err) + b.warnOverrideDegrade(workspaceID, "parse workspace overrides failed, degrading to team defaults", err) return nil } + // Valid overrides (or empty): clear any prior warning flag so a future + // regression is warned again (#749). + b.warnedOverrides.Delete(workspaceID) return overrides } +// warnOverrideDegrade logs a degrading warning at most once per workspaceID per +// process lifetime, preventing log spam under high-crash session loops (#749). +// The warning is re-armed when the workspace later resolves successfully. +func (b *Bridge) warnOverrideDegrade(workspaceID, msg string, err error) { + if _, loaded := b.warnedOverrides.LoadOrStore(workspaceID, struct{}{}); loaded { + return + } + b.log.Warn("bridge: "+msg, "workspace_id", workspaceID, "err", err) +} + // injectAgentConfig loads agent config files and injects the unified system // prompt into session info. A no-op when config dir is empty or agent config // is not configured. diff --git a/internal/gateway/bridge_worker_test.go b/internal/gateway/bridge_worker_test.go index f725474e..a6941309 100644 --- a/internal/gateway/bridge_worker_test.go +++ b/internal/gateway/bridge_worker_test.go @@ -1,9 +1,11 @@ package gateway import ( + "bytes" "context" "errors" "log/slog" + "strings" "testing" "github.com/stretchr/testify/require" @@ -27,6 +29,13 @@ func (s *stubWSStore) GetWorkspaceByID(_ context.Context, _ string) (*session.Wo return s.ws, s.err } +func (s *stubWSStore) ListAllWorkspaces(_ context.Context) ([]*session.Workspace, error) { + if s.ws != nil { + return []*session.Workspace{s.ws}, nil + } + return nil, nil +} + // newBridgeForOverrideTest builds a minimal Bridge with only wsStore + log set, // enough to exercise resolveWorkspaceOverrides without the full dependency graph. func newBridgeForOverrideTest(t *testing.T, ws *session.Workspace, err error) *Bridge { @@ -79,4 +88,33 @@ func TestResolveWorkspaceOverrides(t *testing.T) { b := newBridgeForOverrideTest(t, ws, nil) require.Nil(t, b.resolveWorkspaceOverrides(context.Background(), "ws-1")) }) + + t.Run("warn dedup: repeated degrade warns once, then re-arms on success (#749)", func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelWarn}) + ws := &session.Workspace{ID: "ws-dedup", AgentConfigOverrides: `{bad`} + b := &Bridge{ + log: slog.New(h), + wsStore: &stubWSStore{ws: ws}, + } + // Three consecutive degrade calls — only first should warn. + b.resolveWorkspaceOverrides(context.Background(), "ws-dedup") + b.resolveWorkspaceOverrides(context.Background(), "ws-dedup") + b.resolveWorkspaceOverrides(context.Background(), "ws-dedup") + warnCount := strings.Count(buf.String(), "level=WARN") + require.Equal(t, 1, warnCount, "expected exactly 1 warn for repeated degrade, got %d: %s", warnCount, buf.String()) + + // Fix overrides → valid resolution clears the flag. + ws.AgentConfigOverrides = `{"SOUL.md":"fixed"}` + b.resolveWorkspaceOverrides(context.Background(), "ws-dedup") + // No new warn from success path. + require.Equal(t, 1, strings.Count(buf.String(), "level=WARN")) + + // Break again → new warn appears (flag was cleared). + ws.AgentConfigOverrides = `{broken` + b.resolveWorkspaceOverrides(context.Background(), "ws-dedup") + require.Equal(t, 2, strings.Count(buf.String(), "level=WARN"), + "expected warn re-armed after success, got: %s", buf.String()) + }) } diff --git a/internal/gateway/deps.go b/internal/gateway/deps.go index 3601f6be..b329151c 100644 --- a/internal/gateway/deps.go +++ b/internal/gateway/deps.go @@ -21,10 +21,12 @@ type HandlerDeps struct { } // WorkspaceOverridesReader is the narrow workspace-store subset Bridge needs to -// resolve per-workspace agent-config overrides (spec ② §7.3). Kept separate from -// session.UserWorkspaceStore so tests mock a single method. +// resolve per-workspace agent-config overrides (spec ② §7.3), plus the startup +// scan to detect stale/invalid overrides (#749). Kept separate from +// session.UserWorkspaceStore so tests mock a minimal surface. type WorkspaceOverridesReader interface { GetWorkspaceByID(ctx context.Context, id string) (*session.Workspace, error) + ListAllWorkspaces(ctx context.Context) ([]*session.Workspace, error) } // BridgeDeps groups all dependencies for Bridge construction. diff --git a/internal/gateway/oauth_handlers.go b/internal/gateway/oauth_handlers.go new file mode 100644 index 00000000..4f04affa --- /dev/null +++ b/internal/gateway/oauth_handlers.go @@ -0,0 +1,275 @@ +package gateway + +import ( + "context" + "encoding/json" + "errors" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/google/uuid" + + "github.com/hrygo/hotplex/internal/security" + "github.com/hrygo/hotplex/internal/session" +) + +// OAuthHandlers holds dependencies for OIDC SSO endpoints (spec ④). +type OAuthHandlers struct { + oauthManager *security.OAuthManager + cookieAuth *security.CookieAuth + store session.UserWorkspaceStore + log *slog.Logger + now func() time.Time +} + +// NewOAuthHandlers constructs OAuth SSO handlers. +func NewOAuthHandlers(oauthManager *security.OAuthManager, cookieAuth *security.CookieAuth, store session.UserWorkspaceStore, log *slog.Logger) *OAuthHandlers { + if log == nil { + log = slog.Default() + } + return &OAuthHandlers{ + oauthManager: oauthManager, + cookieAuth: cookieAuth, + store: store, + log: log.With("component", "oauth_handler"), + now: time.Now, + } +} + +// Providers: GET /api/auth/oauth/providers +// Lists configured SSO providers for the login page to render buttons. +func (h *OAuthHandlers) Providers(w http.ResponseWriter, r *http.Request) { + providers := h.oauthManager.List() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "providers": providers, + }) +} + +// Login: GET /api/auth/oauth/{provider}/login +// Initiates the OIDC authorization code flow. Generates state + PKCE verifier, +// stores them in a signed short-lived cookie, and redirects to the IdP. +func (h *OAuthHandlers) Login(w http.ResponseWriter, r *http.Request) { + providerName := r.PathValue("provider") + if providerName == "" { + writeAppError(w, http.StatusBadRequest, "BAD_REQUEST", "provider required") + return + } + + provider, ok := h.oauthManager.Lookup(providerName) + if !ok { + writeAppError(w, http.StatusNotFound, "PROVIDER_NOT_FOUND", "provider not configured") + return + } + + state, codeVerifier, codeChallenge, err := security.GenerateStateAndVerifier() + if err != nil { + h.log.Error("generate state failed", "provider", providerName, "err", err) + writeAppError(w, http.StatusInternalServerError, "INTERNAL", "state generation failed") + return + } + + // Store state + PKCE verifier in signed cookie (5min TTL). + security.SetStateCookie(w, r, h.cookieAuth, security.StateCookiePayload{ + State: state, + CodeVerifier: codeVerifier, + Provider: providerName, + IssuedAt: h.now(), + }) + + authURL := provider.BuildAuthURL(security.AuthURLOption{ + State: state, + CodeChallenge: codeChallenge, + }) + + h.log.Debug("oauth login redirect", "provider", providerName, "state", state[:8]+"...") + http.Redirect(w, r, authURL, http.StatusFound) +} + +// Callback: GET /api/auth/oauth/{provider}/callback +// Handles the IdP redirect after user authentication. Validates state (CSRF), +// exchanges code for tokens, verifies ID Token, finds or creates the user, +// and issues a session cookie. +func (h *OAuthHandlers) Callback(w http.ResponseWriter, r *http.Request) { + providerName := r.PathValue("provider") + if providerName == "" { + redirectAuthError(w, r, "BAD_REQUEST") + return + } + + provider, ok := h.oauthManager.Lookup(providerName) + if !ok { + redirectAuthError(w, r, "PROVIDER_NOT_FOUND") + return + } + + // Check for IdP error response. + if errCode := r.URL.Query().Get("error"); errCode != "" { + h.log.Warn("oauth callback: idp error", "provider", providerName, "error", errCode) + redirectAuthError(w, r, "IDP_ERROR") + return + } + + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if code == "" || state == "" { + redirectAuthError(w, r, "BAD_REQUEST") + return + } + + // Verify state cookie (CSRF + PKCE verifier + provider binding). + payload, err := security.VerifyStateCookie(r, h.cookieAuth, state, providerName) + if err != nil { + h.log.Warn("oauth callback: state verification failed", "provider", providerName, "err", err) + redirectAuthError(w, r, classifyStateError(err)) + return + } + + // Exchange authorization code for tokens (with PKCE verifier). + exchangeResult, err := provider.ExchangeCode(r.Context(), code, payload.CodeVerifier) + if err != nil { + h.log.Error("oauth callback: token exchange failed", "provider", providerName, "err", err) + redirectAuthError(w, r, "CODE_EXCHANGE_FAILED") + return + } + + // Verify ID Token signature and extract claims. + claims, err := provider.VerifyAndExtractClaims(r.Context(), exchangeResult.IDToken) + if err != nil { + h.log.Error("oauth callback: id_token verification failed", "provider", providerName, "err", err) + redirectAuthError(w, r, "ID_TOKEN_INVALID") + return + } + + // Find or create user from SSO identity. + userID, err := h.getOrCreateUser(r.Context(), providerName, claims) + if err != nil { + var idErr *security.IdentityError + if errors.As(err, &idErr) && idErr.Code == security.ErrCodeUserDisabled { + h.log.Warn("oauth callback: user disabled", "provider", providerName, "subject", claims.Subject) + redirectAuthError(w, r, "USER_DISABLED") + } else { + h.log.Error("oauth callback: user creation failed", "provider", providerName, "subject", claims.Subject, "err", err) + redirectAuthError(w, r, "USER_CREATE_FAILED") + } + return + } + + // Issue session cookie (same as password login). + if err := h.cookieAuth.SetCookie(w, r, userID); err != nil { + h.log.Error("oauth callback: cookie issuance failed", "provider", providerName, "err", err) + redirectAuthError(w, r, "INTERNAL") + return + } + + // Clear the OAuth state cookie (one-time use). + security.ClearStateCookie(w, r) + + // Touch last login. + _ = h.store.TouchUserLastLogin(r.Context(), userID, h.now().Unix()) + + h.log.Info("oauth login success", "provider", providerName, "subject", claims.Subject, "user_id", userID) + + // Redirect to webchat home (spec ⑥ will handle post-login routing). + http.Redirect(w, r, "/", http.StatusFound) +} + +// getOrCreateUser implements the spec ④ §8.1 account association: +// - Look up by (provider, subject) → found: update profile, return user_id. +// - Not found: create users row + user_identities row, return user_id. +func (h *OAuthHandlers) getOrCreateUser(ctx context.Context, providerName string, claims *security.OIDCClaims) (string, error) { + now := h.now().Unix() + + // Check if identity already exists. + identity, err := h.store.GetUserIdentityByProviderSubject(ctx, providerName, claims.Subject) + if err == nil { + // Identity exists — update display_name/email from IdP (IdP is authoritative). + if identity.DisplayName != claims.DisplayName || identity.Email != claims.Email { + _ = h.store.UpdateUserIdentityProfile(ctx, identity.ID, claims.DisplayName, claims.Email, now) + } + // Verify user is not disabled. + user, err := h.store.GetUserByID(ctx, identity.UserID) + if err != nil { + return "", err + } + if user.Status == "disabled" { + return "", &security.IdentityError{Code: security.ErrCodeUserDisabled} + } + return identity.UserID, nil + } + if !errors.Is(err, session.ErrIdentityNotFound) { + return "", err + } + + // First login: create user + identity. + userID := uuid.NewString() + username := providerName + ":" + claims.Subject + + // Idempotent create path (P1): a prior first-login attempt may have left an + // orphaned users row (CreateUser succeeded, CreateUserIdentity failed on a + // transient DB error). On retry GetUserIdentityByProviderSubject still misses, + // so we re-check by username before insert. If found, we reuse that user row + // and only (re)create the identity, avoiding a UNIQUE(username) violation that + // would permanently lock out the SSO identity. + existing, err := h.store.GetUserByUsername(ctx, username) + if err != nil && !errors.Is(err, security.ErrUserNotFound) { + return "", err + } + if existing != nil { + // Recover orphaned user row from a crashed prior first-login. + userID = existing.ID + } else { + err = h.store.CreateUser(ctx, &security.User{ + ID: userID, + Username: username, + PasswordHash: "", // SSO-only account; empty hash = no password login + Role: "user", + DisplayName: claims.DisplayName, + Status: "active", + }, now) + if err != nil { + return "", err + } + } + + identityID := uuid.NewString() + err = h.store.CreateUserIdentity(ctx, &session.UserIdentity{ + ID: identityID, + UserID: userID, + Provider: providerName, + Subject: claims.Subject, + DisplayName: claims.DisplayName, + Email: claims.Email, + }, now) + if err != nil { + return "", err + } + + return userID, nil +} + +// redirectAuthError redirects to webchat home with an auth_error query param. +// spec ⑥ frontend will render an error message from this param. +func redirectAuthError(w http.ResponseWriter, r *http.Request, code string) { + security.ClearStateCookie(w, r) + http.Redirect(w, r, "/?auth_error="+code, http.StatusFound) +} + +// classifyStateError maps state verification errors to error codes. +func classifyStateError(err error) string { + msg := err.Error() + switch { + case strings.Contains(msg, "expired"): + return "STATE_EXPIRED" + case strings.Contains(msg, "provider mismatch"): + return "PROVIDER_MISMATCH" + case strings.Contains(msg, "csrf"): + return "CSRF_DETECTED" + case strings.Contains(msg, "cookie missing"): + return "CSRF_DETECTED" + default: + return "STATE_INVALID" + } +} diff --git a/internal/gateway/workspace_scan.go b/internal/gateway/workspace_scan.go new file mode 100644 index 00000000..bc260bd7 --- /dev/null +++ b/internal/gateway/workspace_scan.go @@ -0,0 +1,52 @@ +package gateway + +import ( + "context" + "time" + + "log/slog" + + "github.com/hrygo/hotplex/internal/agentconfig" +) + +// ScanWorkspaceOverrides performs a one-time validation sweep of all active +// workspaces' agent_config_overrides at gateway startup (#749). It logs a +// Warn for each workspace whose overrides fail validation — surfacing stale +// data written before spec ② write-time validation existed, without blocking +// startup. No data is modified; the runtime degrade path remains unchanged. +func ScanWorkspaceOverrides(ctx context.Context, store WorkspaceOverridesReader, log *slog.Logger) { + if store == nil { + return + } + scanCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + workspaces, err := store.ListAllWorkspaces(scanCtx) + if err != nil { + log.Warn("gateway: startup workspace overrides scan failed", "err", err) + return + } + + var dirty int + for _, ws := range workspaces { + if ws.AgentConfigOverrides == "" { + continue + } + if _, err := agentconfig.ValidateOverrides(ws.AgentConfigOverrides); err != nil { + dirty++ + log.Warn("gateway: workspace has invalid agent_config_overrides (will degrade to team defaults at runtime)", + "workspace_id", ws.ID, + "workspace_name", ws.Name, + "owner", ws.OwnerUserID, + "err", err) + } + } + if dirty > 0 { + log.Warn("gateway: startup scan complete, invalid agent_config_overrides detected", + "dirty_count", dirty, "total_scanned", len(workspaces), + "hint", "PATCH the affected workspace(s) with valid JSON to resolve") + } else { + log.Debug("gateway: startup scan complete, all workspace overrides valid", + "total_scanned", len(workspaces)) + } +} diff --git a/internal/gateway/workspace_scan_test.go b/internal/gateway/workspace_scan_test.go new file mode 100644 index 00000000..a9f53607 --- /dev/null +++ b/internal/gateway/workspace_scan_test.go @@ -0,0 +1,94 @@ +package gateway + +import ( + "bytes" + "context" + "log/slog" + "strings" + "testing" + + "github.com/hrygo/hotplex/internal/session" +) + +// scanTestStore implements WorkspaceOverridesReader for scan tests. +type scanTestStore struct { + workspaces []*session.Workspace + err error +} + +func (s *scanTestStore) GetWorkspaceByID(_ context.Context, _ string) (*session.Workspace, error) { + return nil, nil +} + +func (s *scanTestStore) ListAllWorkspaces(_ context.Context) ([]*session.Workspace, error) { + return s.workspaces, s.err +} + +func captureScanLog() (*slog.Logger, *bytes.Buffer) { + var buf bytes.Buffer + h := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug}) + return slog.New(h), &buf +} + +func TestScanWorkspaceOverrides_NilStore_Noop(t *testing.T) { + t.Parallel() + log, buf := captureScanLog() + ScanWorkspaceOverrides(context.Background(), nil, log) + if buf.Len() > 0 { + t.Fatalf("expected no output for nil store, got: %s", buf.String()) + } +} + +func TestScanWorkspaceOverrides_AllValid_NoWarn(t *testing.T) { + t.Parallel() + store := &scanTestStore{ + workspaces: []*session.Workspace{ + {ID: "ws-1", AgentConfigOverrides: ""}, + {ID: "ws-2", AgentConfigOverrides: `{"SOUL.md":"ok"}`}, + }, + } + log, buf := captureScanLog() + ScanWorkspaceOverrides(context.Background(), store, log) + output := buf.String() + if strings.Contains(output, "level=WARN") { + t.Fatalf("expected no warnings for valid overrides, got: %s", output) + } + if !strings.Contains(output, "all workspace overrides valid") { + t.Fatalf("expected debug success message, got: %s", output) + } +} + +func TestScanWorkspaceOverrides_DirtyData_Warns(t *testing.T) { + t.Parallel() + store := &scanTestStore{ + workspaces: []*session.Workspace{ + {ID: "ws-clean", Name: "Clean", OwnerUserID: "u-1", AgentConfigOverrides: `{"SOUL.md":"ok"}`}, + {ID: "ws-dirty", Name: "Dirty", OwnerUserID: "u-2", AgentConfigOverrides: `{bad json`}, + }, + } + log, buf := captureScanLog() + ScanWorkspaceOverrides(context.Background(), store, log) + output := buf.String() + if !strings.Contains(output, "ws-dirty") { + t.Fatalf("expected warning about ws-dirty, got: %s", output) + } + if !strings.Contains(output, "invalid agent_config_overrides detected") { + t.Fatalf("expected summary warning, got: %s", output) + } + if strings.Contains(output, "ws-clean") { + t.Fatalf("should not warn about valid workspace, got: %s", output) + } +} + +func TestScanWorkspaceOverrides_StoreError_WarnsOnce(t *testing.T) { + t.Parallel() + store := &scanTestStore{ + err: context.DeadlineExceeded, + } + log, buf := captureScanLog() + ScanWorkspaceOverrides(context.Background(), store, log) + output := buf.String() + if !strings.Contains(output, "startup workspace overrides scan failed") { + t.Fatalf("expected scan failure warning, got: %s", output) + } +} diff --git a/internal/security/oauth_manager.go b/internal/security/oauth_manager.go new file mode 100644 index 00000000..d9c17794 --- /dev/null +++ b/internal/security/oauth_manager.go @@ -0,0 +1,146 @@ +package security + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/hrygo/hotplex/internal/config" +) + +// OAuthManager manages multiple OIDC providers. It is the central registry +// for SSO providers, supporting runtime hot-reload (provider list rebuilt from +// config changes without dropping in-flight OAuth flows). +// +// Thread-safety: providers are stored behind RWMutex. Lookup is concurrent-safe. +// In-flight flows are unaffected by reload because state cookies encode the +// provider name, and lookup at callback time reads the current registry. +type OAuthManager struct { + mu sync.RWMutex + providers map[string]*OAuthProvider + externalURL string + cookieAuth *CookieAuth // for signing state cookies +} + +// NewOAuthManager creates an empty manager. +func NewOAuthManager(cookieAuth *CookieAuth) *OAuthManager { + return &OAuthManager{ + providers: make(map[string]*OAuthProvider), + cookieAuth: cookieAuth, + } +} + +// Reload rebuilds the provider registry from the given OAuthConfig. +// Discovery is performed for each provider; a provider that fails discovery +// is skipped (logged by caller). Returns count of successfully loaded providers. +// +// Providers not in the new config are removed. Existing providers with unchanged +// issuer/client_id are preserved (no re-discovery) to avoid churn during +// unrelated config reloads. +func (m *OAuthManager) Reload(ctx context.Context, cfg config.OAuthConfig) (int, error) { + if len(cfg.Providers) == 0 { + m.mu.Lock() + m.providers = make(map[string]*OAuthProvider) + m.externalURL = cfg.ExternalURL + m.mu.Unlock() + return 0, nil + } + + externalURL := cfg.ExternalURL + newProviders := make(map[string]*OAuthProvider) + var errs []error + + for _, pcfg := range cfg.Providers { + // Check if we already have this provider with same issuer+clientID. + m.mu.RLock() + existing, ok := m.providers[pcfg.Name] + m.mu.RUnlock() + + if ok && existing.Config().Issuer == pcfg.Issuer && existing.Config().ClientID == pcfg.ClientID { + // Preserve existing (no re-discovery needed); only update non-discovery fields. + newProviders[pcfg.Name] = existing + continue + } + + // Build callback URL. + callbackURL := cfg.CallbackURL(externalURL, pcfg.Name) + + // Construct OAuthProviderConfig from config. + opCfg := OAuthProviderConfig{ + Name: pcfg.Name, + DisplayName: pcfg.DisplayName, + Issuer: pcfg.Issuer, + ClientID: pcfg.ClientID, + ClientSecret: pcfg.ClientSecret, + Scopes: pcfg.DefaultScopes(), + UsernameClaim: pcfg.UsernameClaimName(), + DisplayNameClaim: pcfg.DisplayNameClaimName(), + EmailClaim: pcfg.EmailClaimName(), + } + + provider, err := NewOAuthProvider(ctx, opCfg, callbackURL) + if err != nil { + errs = append(errs, err) + continue + } + newProviders[pcfg.Name] = provider + } + + m.mu.Lock() + m.providers = newProviders + m.externalURL = externalURL + m.mu.Unlock() + + if len(errs) > 0 { + return len(newProviders), fmt.Errorf("oauth manager: %d provider(s) failed to load: %w", len(errs), errors.Join(errs...)) + } + return len(newProviders), nil +} + +// Lookup returns the OAuthProvider by name. +func (m *OAuthManager) Lookup(name string) (*OAuthProvider, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + p, ok := m.providers[name] + return p, ok +} + +// List returns all registered provider names (sorted by registration order is +// not guaranteed; callers should sort if deterministic order is needed). +func (m *OAuthManager) List() []ProviderInfo { + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]ProviderInfo, 0, len(m.providers)) + for _, p := range m.providers { + out = append(out, ProviderInfo{ + Name: p.Name(), + DisplayName: p.DisplayName(), + }) + } + return out +} + +// ProviderInfo is the public-facing provider descriptor (no secrets). +type ProviderInfo struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` +} + +// ExternalURL returns the configured external base URL. +func (m *OAuthManager) ExternalURL() string { + m.mu.RLock() + defer m.mu.RUnlock() + return m.externalURL +} + +// HasProviders returns true if at least one provider is configured. +func (m *OAuthManager) HasProviders() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.providers) > 0 +} + +// CookieAuth returns the CookieAuth used for signing state cookies. +// Used by OAuthHandlers to issue/verify state cookies. +func (m *OAuthManager) CookieAuth() *CookieAuth { return m.cookieAuth } diff --git a/internal/security/oauth_manager_test.go b/internal/security/oauth_manager_test.go new file mode 100644 index 00000000..a59a51b3 --- /dev/null +++ b/internal/security/oauth_manager_test.go @@ -0,0 +1,202 @@ +package security + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hrygo/hotplex/internal/config" +) + +func TestOAuthManager_Empty(t *testing.T) { + t.Parallel() + + ca, err := NewCookieAuth() + require.NoError(t, err) + + m := NewOAuthManager(ca) + + require.False(t, m.HasProviders(), "fresh manager has no providers") + require.Empty(t, m.List(), "fresh manager lists nothing") + _, ok := m.Lookup("anything") + require.False(t, ok, "fresh manager lookup misses") + require.Empty(t, m.ExternalURL(), "fresh manager has no external URL") + require.Same(t, ca, m.CookieAuth(), "CookieAuth returns the injected signer") +} + +func TestOAuthManager_ReloadSuccess(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.close) + + ca, err := NewCookieAuth() + require.NoError(t, err) + m := NewOAuthManager(ca) + + cfg := config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{{ + Name: "keycloak", + DisplayName: "企业 SSO", + Issuer: mock.issuer(), + ClientID: mock.clientID, + }}, + } + + count, err := m.Reload(context.Background(), cfg) + require.NoError(t, err) + require.Equal(t, 1, count) + require.True(t, m.HasProviders()) + + p, ok := m.Lookup("keycloak") + require.True(t, ok) + require.NotNil(t, p) + require.Equal(t, "keycloak", p.Name()) + + list := m.List() + require.Len(t, list, 1) + require.Equal(t, "keycloak", list[0].Name) + require.Equal(t, "企业 SSO", list[0].DisplayName) + + require.Equal(t, "https://hotplex.example.com", m.ExternalURL()) +} + +func TestOAuthManager_ReloadEmptyClearsProviders(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.close) + + m := NewOAuthManager(mustCookieAuth(t)) + + loaded, err := m.Reload(context.Background(), config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{{ + Name: "keycloak", Issuer: mock.issuer(), ClientID: mock.clientID, + }}, + }) + require.NoError(t, err) + require.Equal(t, 1, loaded) + require.True(t, m.HasProviders()) + + // Empty providers list must clear existing providers and set externalURL. + count, err := m.Reload(context.Background(), config.OAuthConfig{ExternalURL: "https://changed.example.com"}) + require.NoError(t, err) + require.Equal(t, 0, count) + require.False(t, m.HasProviders(), "reload with empty providers must clear registry") + require.Empty(t, m.List()) + require.Equal(t, "https://changed.example.com", m.ExternalURL()) +} + +func TestOAuthManager_ReloadDiscoveryError(t *testing.T) { + t.Parallel() + + m := NewOAuthManager(mustCookieAuth(t)) + + // Unreachable issuer → discovery fails for every provider. + cfg := config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{{ + Name: "bad", Issuer: "http://127.0.0.1:0", ClientID: "x", ClientSecret: "y", + }}, + } + + count, err := m.Reload(context.Background(), cfg) + require.Error(t, err) + require.Equal(t, 0, count, "no provider loaded when all fail discovery") + require.False(t, m.HasProviders()) +} + +func TestOAuthManager_ReloadPartialError(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.close) + + m := NewOAuthManager(mustCookieAuth(t)) + + // One good provider + one unreachable: good one loads, error still returned. + cfg := config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{ + {Name: "good", Issuer: mock.issuer(), ClientID: mock.clientID}, + {Name: "bad", Issuer: "http://127.0.0.1:0", ClientID: "x", ClientSecret: "y"}, + }, + } + + count, err := m.Reload(context.Background(), cfg) + require.Error(t, err) + require.Equal(t, 1, count, "good provider must still load on partial failure") + require.True(t, m.HasProviders()) + + _, ok := m.Lookup("good") + require.True(t, ok) +} + +func TestOAuthManager_ReloadPreservesUnchangedProvider(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.close) + + m := NewOAuthManager(mustCookieAuth(t)) + + cfg := config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{{ + Name: "keycloak", Issuer: mock.issuer(), ClientID: mock.clientID, ClientSecret: "s", + }}, + } + + first, err := m.Reload(context.Background(), cfg) + require.NoError(t, err) + require.Equal(t, 1, first) + original, _ := m.Lookup("keycloak") + + // Reload identical issuer+client_id → provider preserved (no re-discovery). + second, err := m.Reload(context.Background(), cfg) + require.NoError(t, err) + require.Equal(t, 1, second) + + preserved, ok := m.Lookup("keycloak") + require.True(t, ok) + require.Same(t, original, preserved, "unchanged provider must be preserved by identity, not re-discovered") +} + +func TestOAuthManager_ReloadRediscoverOnClientIDChange(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.close) + + m := NewOAuthManager(mustCookieAuth(t)) + + cfg := config.OAuthConfig{ + ExternalURL: "https://hotplex.example.com", + Providers: []config.OAuthProviderConfig{{ + Name: "keycloak", Issuer: mock.issuer(), ClientID: "client-v1", ClientSecret: "s", + }}, + } + _, err := m.Reload(context.Background(), cfg) + require.NoError(t, err) + original, _ := m.Lookup("keycloak") + + // Same name, different client_id → must re-discover a new provider. + cfg.Providers[0].ClientID = "client-v2" + count, err := m.Reload(context.Background(), cfg) + require.NoError(t, err) + require.Equal(t, 1, count) + + rediscovered, ok := m.Lookup("keycloak") + require.True(t, ok) + require.NotSame(t, original, rediscovered, "changed provider must be re-discovered into a new instance") +} + +func mustCookieAuth(t *testing.T) *CookieAuth { + t.Helper() + ca, err := NewCookieAuth() + require.NoError(t, err) + return ca +} diff --git a/internal/security/oauth_provider.go b/internal/security/oauth_provider.go new file mode 100644 index 00000000..0b798837 --- /dev/null +++ b/internal/security/oauth_provider.go @@ -0,0 +1,221 @@ +package security + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +// OAuthProviderConfig holds the configuration for a single OIDC provider. +// Defined here (not in config package) to avoid circular import: security +// needs these fields for OIDC client construction, and config imports nothing +// from security. The config package's OAuthProviderConfig is the YAML-facing +// struct; OAuthManager converts at construction time. +type OAuthProviderConfig struct { + Name string + DisplayName string + Issuer string + ClientID string + ClientSecret string + Scopes []string + UsernameClaim string + DisplayNameClaim string + EmailClaim string +} + +// OIDCClaims holds the extracted user info from an OIDC ID Token / UserInfo response. +type OIDCClaims struct { + Subject string + Username string + DisplayName string + Email string +} + +// OAuthProvider implements IdentityProvider for a single OIDC identity provider. +// It handles OIDC discovery, authorization URL construction, token exchange, +// and ID Token verification. Unlike LocalAccountProvider (synchronous password +// check), OAuth authentication is a multi-step redirect flow handled by +// OAuthHandlers; OAuthProvider.Authenticate is NOT used for the redirect flow. +// Instead, OAuthProvider exposes BuildAuthURL / ExchangeCode for the handlers, +// and GetOrCreateUser is called after claims extraction. +// +// OAuthProvider does NOT implement IdentityProvider.Authenticate because OIDC +// auth is redirect-based (not synchronous credential check). The identity +// layer (IdentityProvider interface) is for password login only. SSO login +// result (user_id) is produced directly by the handler → store flow. +type OAuthProvider struct { + config OAuthProviderConfig + provider *oidc.Provider + verifier *oidc.IDTokenVerifier + oauth2 *oauth2.Config +} + +// NewOAuthProvider discovers the OIDC provider endpoints and constructs a +// verified client. Returns error if discovery fails (IdP unreachable, invalid +// issuer URL, malformed discovery document). +func NewOAuthProvider(ctx context.Context, cfg OAuthProviderConfig, callbackURL string) (*OAuthProvider, error) { + provider, err := oidc.NewProvider(ctx, cfg.Issuer) + if err != nil { + return nil, fmt.Errorf("oauth provider %q: discovery failed for issuer %q: %w", cfg.Name, cfg.Issuer, err) + } + + oauth2Cfg := &oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + RedirectURL: callbackURL, + Endpoint: provider.Endpoint(), + Scopes: cfg.Scopes, + } + + verifier := provider.Verifier(&oidc.Config{ClientID: cfg.ClientID}) + + return &OAuthProvider{ + config: cfg, + provider: provider, + verifier: verifier, + oauth2: oauth2Cfg, + }, nil +} + +// Name returns the provider's unique name. +func (p *OAuthProvider) Name() string { return p.config.Name } + +// DisplayName returns the provider's display name (falls back to Name). +func (p *OAuthProvider) DisplayName() string { + if p.config.DisplayName != "" { + return p.config.DisplayName + } + return p.config.Name +} + +// Config returns the provider configuration (read-only by convention). +func (p *OAuthProvider) Config() OAuthProviderConfig { return p.config } + +// AuthURLOption holds PKCE + state parameters for a single OAuth flow attempt. +type AuthURLOption struct { + State string + CodeChallenge string +} + +// BuildAuthURL constructs the OIDC authorization redirect URL with PKCE. +func (p *OAuthProvider) BuildAuthURL(opts AuthURLOption) string { + return p.oauth2.AuthCodeURL(opts.State, + oauth2.SetAuthURLParam("code_challenge", opts.CodeChallenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + ) +} + +// ExchangeResult holds the output of a successful token exchange. +type ExchangeResult struct { + IDToken string + AccessToken string + RefreshToken string // may be empty if IdP doesn't return one +} + +// ExchangeCode exchanges the authorization code for tokens. The codeVerifier +// must match the code_challenge sent in BuildAuthURL. +func (p *OAuthProvider) ExchangeCode(ctx context.Context, code, codeVerifier string) (*ExchangeResult, error) { + token, err := p.oauth2.Exchange(ctx, code, + oauth2.SetAuthURLParam("code_verifier", codeVerifier), + ) + if err != nil { + return nil, fmt.Errorf("oauth provider %q: token exchange failed: %w", p.config.Name, err) + } + + result := &ExchangeResult{ + AccessToken: token.AccessToken, + RefreshToken: token.RefreshToken, + } + + // Extract raw ID Token from the token response. + rawIDToken, ok := token.Extra("id_token").(string) + if !ok || rawIDToken == "" { + return nil, fmt.Errorf("oauth provider %q: no id_token in token response", p.config.Name) + } + result.IDToken = rawIDToken + + return result, nil +} + +// VerifyAndExtractClaims verifies the ID Token signature and extracts user claims. +// It validates: signature (JWKS), issuer, audience, expiration. +func (p *OAuthProvider) VerifyAndExtractClaims(ctx context.Context, rawIDToken string) (*OIDCClaims, error) { + idToken, err := p.verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, fmt.Errorf("oauth provider %q: id_token verification failed: %w", p.config.Name, err) + } + + claims := &OIDCClaims{Subject: idToken.Subject} + + // Extract custom claims from the ID Token payload. + // Use configured claim names with OIDC standard fallbacks. + var raw map[string]any + if err := idToken.Claims(&raw); err != nil { + return nil, fmt.Errorf("oauth provider %q: parse claims: %w", p.config.Name, err) + } + + claims.Username = claimString(raw, p.config.UsernameClaim, "preferred_username") + claims.DisplayName = claimString(raw, p.config.DisplayNameClaim, "name") + claims.Email = claimString(raw, p.config.EmailClaim, "email") + + if claims.Username == "" { + // Fallback: use subject as username if preferred_username not present. + claims.Username = idToken.Subject + } + + return claims, nil +} + +// claimString extracts a string claim from the raw claims map, trying the +// configured name first, then the default name, then returning "". +func claimString(claims map[string]any, configuredName, defaultName string) string { + name := defaultName + if configuredName != "" { + name = configuredName + } + if v, ok := claims[name]; ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + // Fallback to default if configured name didn't match. + if configuredName != "" && configuredName != defaultName { + if v, ok := claims[defaultName]; ok { + if s, ok := v.(string); ok { + return strings.TrimSpace(s) + } + } + } + return "" +} + +// OAuthError wraps OIDC flow errors with a provider name for diagnostics. +type OAuthError struct { + Provider string + Err error +} + +func (e *OAuthError) Error() string { + return fmt.Sprintf("oauth[%s]: %v", e.Provider, e.Err) +} + +func (e *OAuthError) Unwrap() error { return e.Err } + +// IsDiscoveryError returns true if the error is caused by OIDC discovery failure +// (IdP unreachable or misconfigured). +func IsDiscoveryError(err error) bool { + if err == nil { + return false + } + var oe *OAuthError + if errors.As(err, &oe) { + err = oe.Err + } + return strings.Contains(err.Error(), "discovery failed") || + strings.Contains(err.Error(), "oidc:") || + strings.Contains(err.Error(), "unsuccessful") +} diff --git a/internal/security/oauth_provider_test.go b/internal/security/oauth_provider_test.go new file mode 100644 index 00000000..c69cb190 --- /dev/null +++ b/internal/security/oauth_provider_test.go @@ -0,0 +1,194 @@ +package security + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" +) + +// mockOIDCServer simulates a full OIDC IdP for integration testing. +type mockOIDCServer struct { + t *testing.T + server *httptest.Server + signingKey *rsa.PrivateKey + jwksPublicKey jose.JSONWebKeySet + clientID string + expectedSub string + expectedClaims map[string]any + codeVerifier string // captured during token exchange +} + +func newMockOIDCServer(t *testing.T) *mockOIDCServer { + t.Helper() + // Generate EC P-256 signing key (ES256). + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + pubJWK := jose.JSONWebKey{ + Key: &privKey.PublicKey, + KeyID: "test-key-1", + Algorithm: string(jose.RS256), + Use: "sig", + } + + var jwks jose.JSONWebKeySet + jwks.Keys = append(jwks.Keys, pubJWK) + + m := &mockOIDCServer{ + t: t, + signingKey: privKey, + jwksPublicKey: jwks, + clientID: "test-client-id", + expectedSub: "user-sub-123", + expectedClaims: map[string]any{ + "sub": "user-sub-123", + "preferred_username": "alice", + "name": "Alice Wonderland", + "email": "alice@example.com", + }, + } + + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", m.handleDiscovery) + mux.HandleFunc("/oauth/token", m.handleToken) + mux.HandleFunc("/oauth/jwks", m.handleJWKS) + + m.server = httptest.NewServer(mux) + return m +} + +func (m *mockOIDCServer) close() { m.server.Close() } + +func (m *mockOIDCServer) issuer() string { return m.server.URL } + +func (m *mockOIDCServer) handleDiscovery(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "issuer": m.server.URL, + "authorization_endpoint": m.server.URL + "/oauth/authorize", + "token_endpoint": m.server.URL + "/oauth/token", + "jwks_uri": m.server.URL + "/oauth/jwks", + "userinfo_endpoint": m.server.URL + "/oauth/userinfo", + }) +} + +func (m *mockOIDCServer) handleToken(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + verifier := r.FormValue("code_verifier") + if verifier == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + m.codeVerifier = verifier + + // Sign an ID Token with the test key. + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: m.signingKey}, + (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "test-key-1")) + require.NoError(m.t, err) + + now := time.Now() + claims := map[string]any{ + "iss": m.server.URL, + "sub": m.expectedSub, + "aud": m.clientID, + "exp": now.Add(time.Hour).Unix(), + "iat": now.Unix(), + } + for k, v := range m.expectedClaims { + claims[k] = v + } + + payload, _ := json.Marshal(claims) + jws, err := signer.Sign(payload) + require.NoError(m.t, err) + idToken, err := jws.CompactSerialize() + require.NoError(m.t, err) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "fake-access-token", + "token_type": "Bearer", + "id_token": idToken, + }) +} + +func (m *mockOIDCServer) handleJWKS(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(m.jwksPublicKey) +} + +func TestOAuthProvider_DiscoveryAndClaims(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + defer mock.close() + + callbackURL := "https://hotplex.example.com/api/auth/oauth/keycloak/callback" + + op, err := NewOAuthProvider(context.Background(), OAuthProviderConfig{ + Name: "keycloak", + Issuer: mock.issuer(), + ClientID: mock.clientID, + ClientSecret: "test-secret", + Scopes: []string{"openid", "profile"}, + }, callbackURL) + require.NoError(t, err, "OIDC discovery must succeed against mock IdP") + require.Equal(t, "keycloak", op.Name()) + + // Build auth URL. + state, verifier, challenge, err := GenerateStateAndVerifier() + require.NoError(t, err) + + authURL := op.BuildAuthURL(AuthURLOption{State: state, CodeChallenge: challenge}) + require.Contains(t, authURL, "response_type=code") + require.Contains(t, authURL, "client_id="+mock.clientID) + require.Contains(t, authURL, "state="+state) + require.Contains(t, authURL, "code_challenge_method=S256") + require.Contains(t, authURL, mock.server.URL+"/oauth/authorize") + + // Exchange code — mock IdP verifies the PKCE code_verifier internally. + // We can't easily test the full redirect, so we test ExchangeCode + VerifyAndExtractClaims. + // The mock token endpoint signs a real JWT that go-oidc will verify. + exchange, err := op.ExchangeCode(context.Background(), "fake-auth-code", verifier) + require.NoError(t, err) + require.NotEmpty(t, exchange.IDToken) + require.Equal(t, verifier, mock.codeVerifier, "PKCE verifier must be sent to token endpoint") + + // Verify ID Token and extract claims. + claims, err := op.VerifyAndExtractClaims(context.Background(), exchange.IDToken) + require.NoError(t, err) + require.Equal(t, "user-sub-123", claims.Subject) + require.Equal(t, "alice", claims.Username) + require.Equal(t, "Alice Wonderland", claims.DisplayName) + require.Equal(t, "alice@example.com", claims.Email) +} + +func TestOAuthProvider_DiscoveryFailed(t *testing.T) { + t.Parallel() + _, err := NewOAuthProvider(context.Background(), OAuthProviderConfig{ + Name: "bad", Issuer: "http://127.0.0.1:0", // unreachable + ClientID: "x", ClientSecret: "y", + }, "https://example.com/callback") + require.Error(t, err) + require.True(t, IsDiscoveryError(err), "unreachable IdP should be a discovery error") +} + +func TestOAuthProvider_ClaimFallback(t *testing.T) { + t.Parallel() + // Test claimString fallback logic directly. + claims := map[string]any{ + "preferred_username": "bob", + "name": "Bob", + } + require.Equal(t, "bob", claimString(claims, "", "preferred_username")) + require.Equal(t, "bob", claimString(claims, "custom", "preferred_username")) + require.Equal(t, "", claimString(claims, "custom", "default")) +} diff --git a/internal/security/oauth_state.go b/internal/security/oauth_state.go new file mode 100644 index 00000000..5f0e5313 --- /dev/null +++ b/internal/security/oauth_state.go @@ -0,0 +1,162 @@ +package security + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "fmt" + "net/http" + "strings" + "time" +) + +const ( + // stateCookieName is the HTTP cookie name for the OAuth state parameter. + stateCookieName = "oauth_state" + + // stateCookieTTL is the maximum lifetime of an OAuth flow (5 minutes). + // The state cookie must outlive the user's interaction with the IdP login page. + stateCookieTTL = 5 * time.Minute + + // pkceVerifierLen is the length of the PKCE code verifier in bytes (64). + // RFC 7636 recommends 43-128 characters; 64 bytes hex-encoded = 128 chars. + pkceVerifierLen = 64 +) + +// StateCookiePayload holds the data stored in the signed OAuth state cookie. +type StateCookiePayload struct { + State string + CodeVerifier string + Provider string + IssuedAt time.Time +} + +// GenerateStateAndVerifier produces a cryptographically random state parameter +// and PKCE code_verifier. The code_challenge (S256) is derived from the verifier. +func GenerateStateAndVerifier() (state, codeVerifier, codeChallenge string, err error) { + stateBytes := make([]byte, 32) + if _, err = rand.Read(stateBytes); err != nil { + return "", "", "", fmt.Errorf("security: generate oauth state: %w", err) + } + state = hex.EncodeToString(stateBytes) + + verifierBytes := make([]byte, pkceVerifierLen) + if _, err = rand.Read(verifierBytes); err != nil { + return "", "", "", fmt.Errorf("security: generate pkce verifier: %w", err) + } + codeVerifier = hex.EncodeToString(verifierBytes) + + hash := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hash[:]) + + return state, codeVerifier, codeChallenge, nil +} + +// SetStateCookie signs and sets the OAuth state cookie on the response. +// The cookie value is HMAC-signed (using the same secret as session cookies) +// to prevent tampering. Format: Base64(state|verifier|provider|timestamp|hmac). +func SetStateCookie(w http.ResponseWriter, r *http.Request, ca *CookieAuth, payload StateCookiePayload) { + ts := payload.IssuedAt.Unix() + raw := fmt.Sprintf("%s|%s|%s|%d", payload.State, payload.CodeVerifier, payload.Provider, ts) + + mac := hmac.New(sha256.New, ca.secret) + mac.Write([]byte(raw)) + sig := hex.EncodeToString(mac.Sum(nil)) + + value := base64.RawURLEncoding.EncodeToString([]byte(raw + "|" + sig)) + + http.SetCookie(w, &http.Cookie{ + Name: stateCookieName, + Value: value, + Path: "/", + MaxAge: int(stateCookieTTL.Seconds()), + HttpOnly: true, + Secure: isHTTPS(r), + SameSite: http.SameSiteLaxMode, // Lax: allow redirect back from IdP + }) +} + +// VerifyStateCookie reads and validates the OAuth state cookie, checking: +// 1. Cookie exists and HMAC signature is valid (not tampered). +// 2. Cookie has not expired (within stateCookieTTL). +// 3. The state parameter matches what was stored. +// 4. The provider matches the expected provider (path injection defense). +func VerifyStateCookie(r *http.Request, ca *CookieAuth, expectedState, expectedProvider string) (*StateCookiePayload, error) { + cookie, err := r.Cookie(stateCookieName) + if err != nil { + return nil, fmt.Errorf("oauth state: cookie missing") + } + + raw, err := base64.RawURLEncoding.DecodeString(cookie.Value) + if err != nil { + return nil, fmt.Errorf("oauth state: decode failed") + } + + // Split: state|verifier|provider|timestamp|hmac + parts := strings.SplitN(string(raw), "|", 5) + if len(parts) != 5 { + return nil, fmt.Errorf("oauth state: malformed cookie") + } + + state, verifier, provider, tsStr, sigHex := parts[0], parts[1], parts[2], parts[3], parts[4] + + // Verify HMAC signature. + sig, err := hex.DecodeString(sigHex) + if err != nil { + return nil, fmt.Errorf("oauth state: invalid signature encoding") + } + payload := state + "|" + verifier + "|" + provider + "|" + tsStr + mac := hmac.New(sha256.New, ca.secret) + mac.Write([]byte(payload)) + expected := mac.Sum(nil) + if !hmac.Equal(sig, expected) { + return nil, fmt.Errorf("oauth state: signature mismatch") + } + + // Check expiry. + ts, err := parseTimestamp(tsStr) + if err != nil { + return nil, fmt.Errorf("oauth state: invalid timestamp") + } + if time.Since(ts) > stateCookieTTL { + return nil, fmt.Errorf("oauth state: expired") + } + + // Verify state and provider match. + if state != expectedState { + return nil, fmt.Errorf("oauth state: csrf detected") + } + if provider != expectedProvider { + return nil, fmt.Errorf("oauth state: provider mismatch") + } + + return &StateCookiePayload{ + State: state, + CodeVerifier: verifier, + Provider: provider, + IssuedAt: ts, + }, nil +} + +// ClearStateCookie expires the OAuth state cookie. +func ClearStateCookie(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{ + Name: stateCookieName, + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: isHTTPS(r), + SameSite: http.SameSiteLaxMode, + }) +} + +func parseTimestamp(s string) (time.Time, error) { + var ts int64 + if _, err := fmt.Sscanf(s, "%d", &ts); err != nil { + return time.Time{}, err + } + return time.Unix(ts, 0), nil +} diff --git a/internal/security/oauth_state_test.go b/internal/security/oauth_state_test.go new file mode 100644 index 00000000..ea5da47e --- /dev/null +++ b/internal/security/oauth_state_test.go @@ -0,0 +1,153 @@ +package security + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestGenerateStateAndVerifier(t *testing.T) { + t.Parallel() + state, verifier, challenge, err := GenerateStateAndVerifier() + require.NoError(t, err) + require.NotEmpty(t, state) + require.NotEmpty(t, verifier) + require.NotEmpty(t, challenge) + require.Len(t, state, 64, "state is 32 bytes hex = 64 chars") + + // Two calls must produce different values. + state2, _, _, _ := GenerateStateAndVerifier() + require.NotEqual(t, state, state2) +} + +func TestStateCookie_SetAndVerify(t *testing.T) { + t.Parallel() + ca, err := NewCookieAuth() + require.NoError(t, err) + + state, verifier, _, _ := GenerateStateAndVerifier() + payload := StateCookiePayload{ + State: state, + CodeVerifier: verifier, + Provider: "keycloak", + IssuedAt: time.Now(), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/api/auth/oauth/keycloak/login", nil) + SetStateCookie(w, r, ca, payload) + + // Extract the cookie and set it on a new request. + r2 := httptest.NewRequest(http.MethodGet, "/api/auth/oauth/keycloak/callback?code=x&state="+state, nil) + for _, c := range w.Result().Cookies() { + r2.AddCookie(c) + } + + got, err := VerifyStateCookie(r2, ca, state, "keycloak") + require.NoError(t, err) + require.Equal(t, state, got.State) + require.Equal(t, verifier, got.CodeVerifier) + require.Equal(t, "keycloak", got.Provider) +} + +func TestStateCookie_CSRFDetected(t *testing.T) { + t.Parallel() + ca, _ := NewCookieAuth() + + state, verifier, _, _ := GenerateStateAndVerifier() + payload := StateCookiePayload{ + State: state, CodeVerifier: verifier, Provider: "kc", IssuedAt: time.Now(), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + SetStateCookie(w, r, ca, payload) + + r2 := httptest.NewRequest(http.MethodGet, "/callback?code=x&state=WRONG", nil) + for _, c := range w.Result().Cookies() { + r2.AddCookie(c) + } + + _, err := VerifyStateCookie(r2, ca, "WRONG", "kc") + require.Error(t, err) + require.Contains(t, err.Error(), "csrf") +} + +func TestStateCookie_ProviderMismatch(t *testing.T) { + t.Parallel() + ca, _ := NewCookieAuth() + + state, verifier, _, _ := GenerateStateAndVerifier() + payload := StateCookiePayload{ + State: state, CodeVerifier: verifier, Provider: "kc", IssuedAt: time.Now(), + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + SetStateCookie(w, r, ca, payload) + + r2 := httptest.NewRequest(http.MethodGet, "/callback", nil) + for _, c := range w.Result().Cookies() { + r2.AddCookie(c) + } + + _, err := VerifyStateCookie(r2, ca, state, "different_provider") + require.Error(t, err) + require.Contains(t, err.Error(), "provider mismatch") +} + +func TestStateCookie_Expired(t *testing.T) { + t.Parallel() + ca, _ := NewCookieAuth() + + state, verifier, _, _ := GenerateStateAndVerifier() + payload := StateCookiePayload{ + State: state, CodeVerifier: verifier, Provider: "kc", + IssuedAt: time.Now().Add(-10 * time.Minute), // expired + } + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + SetStateCookie(w, r, ca, payload) + + r2 := httptest.NewRequest(http.MethodGet, "/callback", nil) + for _, c := range w.Result().Cookies() { + r2.AddCookie(c) + } + + _, err := VerifyStateCookie(r2, ca, state, "kc") + require.Error(t, err) + require.Contains(t, err.Error(), "expired") +} + +func TestStateCookie_Tampered(t *testing.T) { + t.Parallel() + ca, _ := NewCookieAuth() + + w := httptest.NewRecorder() + // Set a tampered cookie directly. + http.SetCookie(w, &http.Cookie{ + Name: stateCookieName, + Value: "tampered_value", + }) + + r2 := httptest.NewRequest(http.MethodGet, "/callback", nil) + for _, c := range w.Result().Cookies() { + r2.AddCookie(c) + } + + _, err := VerifyStateCookie(r2, ca, "any", "any") + require.Error(t, err) +} + +func TestStateCookie_Missing(t *testing.T) { + t.Parallel() + ca, _ := NewCookieAuth() + r := httptest.NewRequest(http.MethodGet, "/callback", nil) + _, err := VerifyStateCookie(r, ca, "any", "any") + require.Error(t, err) + require.Contains(t, err.Error(), "cookie missing") +} diff --git a/internal/session/multitenancy_pg_store.go b/internal/session/multitenancy_pg_store.go index 3d205b0b..11f6fbf9 100644 --- a/internal/session/multitenancy_pg_store.go +++ b/internal/session/multitenancy_pg_store.go @@ -103,6 +103,24 @@ func (s *pgStore) ListWorkspacesByOwner(ctx context.Context, ownerUserID string) return out, rows.Err() } +// ListAllWorkspaces returns all active workspaces regardless of owner (PG backend). +func (s *pgStore) ListAllWorkspaces(ctx context.Context) ([]*Workspace, error) { + rows, err := s.db.QueryContext(ctx, s.queries["workspaces.list_all"]) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []*Workspace + for rows.Next() { + w, err := scanWorkspace(rows) + if err != nil { + return nil, err + } + out = append(out, w) + } + return out, rows.Err() +} + func (s *pgStore) GetWorkspaceByOwnerAndWorkDir(ctx context.Context, ownerUserID, workDir string) (*Workspace, error) { w, err := scanWorkspace(s.db.QueryRowContext(ctx, s.queries["workspaces.get_by_owner_and_workdir"], ownerUserID, workDir)) if errors.Is(err, sql.ErrNoRows) { @@ -199,3 +217,24 @@ func (s *pgStore) DeleteInvitation(ctx context.Context, id string) error { _, err := s.db.ExecContext(ctx, s.queries["invitations.delete"], id) return err } + +// --- pgStore: user identities (spec ④) --- + +func (s *pgStore) GetUserIdentityByProviderSubject(ctx context.Context, provider, subject string) (*UserIdentity, error) { + id, err := scanIdentity(s.db.QueryRowContext(ctx, s.queries["identities.get_by_provider_subject"], provider, subject)) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrIdentityNotFound + } + return id, err +} + +func (s *pgStore) CreateUserIdentity(ctx context.Context, id *UserIdentity, now int64) error { + _, err := s.db.ExecContext(ctx, s.queries["identities.create"], + id.ID, id.UserID, id.Provider, id.Subject, id.DisplayName, id.Email, now, now) + return err +} + +func (s *pgStore) UpdateUserIdentityProfile(ctx context.Context, id, displayName, email string, now int64) error { + _, err := s.db.ExecContext(ctx, s.queries["identities.update_profile"], displayName, email, now, id) + return err +} diff --git a/internal/session/multitenancy_store.go b/internal/session/multitenancy_store.go index 11e4e427..b24b1036 100644 --- a/internal/session/multitenancy_store.go +++ b/internal/session/multitenancy_store.go @@ -32,12 +32,27 @@ type Invitation struct { UsedAt *int64 // nil = unused } +// UserIdentity binds an OAuth/OIDC identity to a local user (spec ④). +// One user may have multiple identities (different IdPs); each (provider, subject) +// pair uniquely maps to a single user_id via UNIQUE constraint. +type UserIdentity struct { + ID string // UUID + UserID string // FK → users.id + Provider string // provider name (config key) + Subject string // IdP subject (OIDC "sub" claim) + DisplayName string // synced from IdP + Email string // synced from IdP (not used for auto-merge) + CreatedAt int64 + UpdatedAt int64 +} + // Multitenancy store sentinels. var ( ErrWorkspaceNotFound = errors.New("session: workspace not found") ErrWorkspaceNotEmpty = errors.New("session: workspace has active sessions") ErrInvitationNotFound = errors.New("session: invitation not found") ErrInvitationAlreadyUsed = errors.New("session: invitation already used") + ErrIdentityNotFound = errors.New("session: user identity not found") ) // UserWorkspaceStore is the store capability surface used by gateway auth/workspace @@ -54,6 +69,7 @@ type UserWorkspaceStore interface { CreateWorkspace(ctx context.Context, w *Workspace, now int64) error GetWorkspaceByID(ctx context.Context, id string) (*Workspace, error) ListWorkspacesByOwner(ctx context.Context, ownerUserID string) ([]*Workspace, error) + ListAllWorkspaces(ctx context.Context) ([]*Workspace, error) GetWorkspaceByOwnerAndWorkDir(ctx context.Context, ownerUserID, workDir string) (*Workspace, error) UpdateWorkspace(ctx context.Context, w *Workspace, now int64) error DeleteWorkspace(ctx context.Context, id string) error @@ -68,6 +84,10 @@ type UserWorkspaceStore interface { SetInvitationUsedBy(ctx context.Context, id, oldUsedBy, newUsedBy string) error ListInvitations(ctx context.Context, limit, offset int) ([]*Invitation, error) DeleteInvitation(ctx context.Context, id string) error + // user identities (spec ④ SSO) + GetUserIdentityByProviderSubject(ctx context.Context, provider, subject string) (*UserIdentity, error) + CreateUserIdentity(ctx context.Context, id *UserIdentity, now int64) error + UpdateUserIdentityProfile(ctx context.Context, id, displayName, email string, now int64) error } // Compile-time assertions that both stores satisfy UserWorkspaceStore. @@ -98,6 +118,16 @@ func scanUser(sc rowScanner) (*security.User, error) { return &u, nil } +func scanIdentity(sc rowScanner) (*UserIdentity, error) { + var id UserIdentity + err := sc.Scan(&id.ID, &id.UserID, &id.Provider, &id.Subject, &id.DisplayName, + &id.Email, &id.CreatedAt, &id.UpdatedAt) + if err != nil { + return nil, err + } + return &id, nil +} + func scanWorkspace(sc rowScanner) (*Workspace, error) { var w Workspace var overrides, pref sql.NullString @@ -233,6 +263,25 @@ func (s *SQLiteStore) ListWorkspacesByOwner(ctx context.Context, ownerUserID str return out, rows.Err() } +// ListAllWorkspaces returns all active workspaces regardless of owner. Used by the +// gateway startup scan to detect stale/invalid agent_config_overrides (spec ② #749). +func (s *SQLiteStore) ListAllWorkspaces(ctx context.Context) ([]*Workspace, error) { + rows, err := s.db.QueryContext(ctx, queries["workspaces.list_all"]) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []*Workspace + for rows.Next() { + w, err := scanWorkspace(rows) + if err != nil { + return nil, err + } + out = append(out, w) + } + return out, rows.Err() +} + func (s *SQLiteStore) GetWorkspaceByOwnerAndWorkDir(ctx context.Context, ownerUserID, workDir string) (*Workspace, error) { w, err := scanWorkspace(s.db.QueryRowContext(ctx, queries["workspaces.get_by_owner_and_workdir"], ownerUserID, workDir)) if errors.Is(err, sql.ErrNoRows) { @@ -344,3 +393,28 @@ func (s *SQLiteStore) DeleteInvitation(ctx context.Context, id string) error { return err }) } + +// --- SQLiteStore: user identities (spec ④) --- + +func (s *SQLiteStore) GetUserIdentityByProviderSubject(ctx context.Context, provider, subject string) (*UserIdentity, error) { + id, err := scanIdentity(s.db.QueryRowContext(ctx, queries["identities.get_by_provider_subject"], provider, subject)) + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrIdentityNotFound + } + return id, err +} + +func (s *SQLiteStore) CreateUserIdentity(ctx context.Context, id *UserIdentity, now int64) error { + return s.writeMu.WithLock(func() error { + _, err := s.db.ExecContext(ctx, queries["identities.create"], + id.ID, id.UserID, id.Provider, id.Subject, id.DisplayName, id.Email, now, now) + return err + }) +} + +func (s *SQLiteStore) UpdateUserIdentityProfile(ctx context.Context, id, displayName, email string, now int64) error { + return s.writeMu.WithLock(func() error { + _, err := s.db.ExecContext(ctx, queries["identities.update_profile"], displayName, email, now, id) + return err + }) +} diff --git a/internal/session/multitenancy_store_test.go b/internal/session/multitenancy_store_test.go index dbe448cf..f2b37a88 100644 --- a/internal/session/multitenancy_store_test.go +++ b/internal/session/multitenancy_store_test.go @@ -109,6 +109,27 @@ func TestWorkspacesStore_ListByOwnerIsolated(t *testing.T) { } } +func TestWorkspacesStore_ListAllWorkspaces(t *testing.T) { + t.Parallel() + store, _ := helperDB(t) + ctx := context.Background() + require.NoError(t, store.CreateUser(ctx, &security.User{ID: "u-1", Username: "alice", Role: "user", Status: "active"}, 1700000000)) + require.NoError(t, store.CreateUser(ctx, &security.User{ID: "u-2", Username: "bob", Role: "user", Status: "active"}, 1700000000)) + require.NoError(t, store.CreateWorkspace(ctx, &Workspace{ID: "ws-1", OwnerUserID: "u-1", Name: "a", WorkDir: "/tmp/a", AgentConfigOverrides: `{"SOUL.md":"x"}`}, 1700000000)) + require.NoError(t, store.CreateWorkspace(ctx, &Workspace{ID: "ws-2", OwnerUserID: "u-2", Name: "b", WorkDir: "/tmp/b"}, 1700000000)) + + all, err := store.ListAllWorkspaces(ctx) + require.NoError(t, err) + require.Len(t, all, 2, "ListAllWorkspaces returns workspaces across all owners") + ids := map[string]bool{} + for _, w := range all { + ids[w.ID] = true + require.Equal(t, "active", w.Status) + } + require.True(t, ids["ws-1"]) + require.True(t, ids["ws-2"]) +} + func TestWorkspacesStore_CountActiveSessions(t *testing.T) { t.Parallel() store, _ := helperDB(t) @@ -203,3 +224,88 @@ func TestInvitationsStore_List(t *testing.T) { require.NoError(t, err) require.Len(t, list, 2) } + +// --- user identities (spec ④) --- + +func TestIdentities_CreateAndGet(t *testing.T) { + t.Parallel() + store, _ := helperDB(t) + ctx := context.Background() + + // Create a user first. + u := &security.User{ID: "u-sso-1", Username: "keycloak:sub123", Role: "user", Status: "active"} + require.NoError(t, store.CreateUser(ctx, u, 1700000000)) + + // Create an identity. + ident := &UserIdentity{ + ID: "ident-1", + UserID: "u-sso-1", + Provider: "keycloak", + Subject: "sub123", + DisplayName: "Alice", + Email: "alice@example.com", + } + require.NoError(t, store.CreateUserIdentity(ctx, ident, 1700000001)) + + // Lookup by (provider, subject). + got, err := store.GetUserIdentityByProviderSubject(ctx, "keycloak", "sub123") + require.NoError(t, err) + require.Equal(t, "ident-1", got.ID) + require.Equal(t, "u-sso-1", got.UserID) + require.Equal(t, "Alice", got.DisplayName) +} + +func TestIdentities_GetByProviderSubject_NotFound(t *testing.T) { + t.Parallel() + store, _ := helperDB(t) + _, err := store.GetUserIdentityByProviderSubject(context.Background(), "keycloak", "nonexistent") + require.ErrorIs(t, err, ErrIdentityNotFound) +} + +func TestIdentities_UpdateProfile(t *testing.T) { + t.Parallel() + store, _ := helperDB(t) + ctx := context.Background() + + u := &security.User{ID: "u-sso-2", Username: "authing:sub456", Role: "user", Status: "active"} + require.NoError(t, store.CreateUser(ctx, u, 1700000000)) + + ident := &UserIdentity{ + ID: "ident-2", + UserID: "u-sso-2", + Provider: "authing", + Subject: "sub456", + DisplayName: "Bob", + Email: "bob@old.com", + } + require.NoError(t, store.CreateUserIdentity(ctx, ident, 1700000001)) + + // Update profile from IdP. + require.NoError(t, store.UpdateUserIdentityProfile(ctx, "ident-2", "Bob Smith", "bob@new.com", 1700000002)) + + got, err := store.GetUserIdentityByProviderSubject(ctx, "authing", "sub456") + require.NoError(t, err) + require.Equal(t, "Bob Smith", got.DisplayName) + require.Equal(t, "bob@new.com", got.Email) +} + +func TestIdentities_UniqueProviderSubject(t *testing.T) { + t.Parallel() + store, _ := helperDB(t) + ctx := context.Background() + + u := &security.User{ID: "u-sso-3", Username: "oidc:sub789", Role: "user", Status: "active"} + require.NoError(t, store.CreateUser(ctx, u, 1700000000)) + + ident1 := &UserIdentity{ + ID: "ident-3a", UserID: "u-sso-3", Provider: "oidc", Subject: "sub789", + } + require.NoError(t, store.CreateUserIdentity(ctx, ident1, 1700000001)) + + // Same provider+subject should fail (UNIQUE constraint). + ident2 := &UserIdentity{ + ID: "ident-3b", UserID: "u-sso-3", Provider: "oidc", Subject: "sub789", + } + err := store.CreateUserIdentity(ctx, ident2, 1700000002) + require.Error(t, err, "duplicate provider+subject must fail") +} diff --git a/internal/session/sql/migrations-postgres/020_user_identities.pg.sql b/internal/session/sql/migrations-postgres/020_user_identities.pg.sql new file mode 100644 index 00000000..d79247f6 --- /dev/null +++ b/internal/session/sql/migrations-postgres/020_user_identities.pg.sql @@ -0,0 +1,24 @@ +-- +goose Up +-- WebChat 多租户 spec ④:企业 SSO(OIDC 统一认证) +-- user_identities 将 OAuth 身份与 users 解耦:一个用户可关联多个 IdP。 +-- UNIQUE(provider, subject) 保证 SSO 登录确定性映射到唯一 user_id。 +-- users 表不增加字段;密码账号无 user_identities 行,完全向后兼容。 + +CREATE TABLE user_identities ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + provider TEXT NOT NULL, + subject TEXT NOT NULL, + display_name TEXT NOT NULL DEFAULT '', + email TEXT NOT NULL DEFAULT '', + created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, + UNIQUE(provider, subject) +); +CREATE INDEX IF NOT EXISTS idx_user_identities_user_id ON user_identities(user_id); +CREATE INDEX IF NOT EXISTS idx_user_identities_lookup ON user_identities(provider, subject); + +-- +goose Down +DROP INDEX IF EXISTS idx_user_identities_lookup; +DROP INDEX IF EXISTS idx_user_identities_user_id; +DROP TABLE IF EXISTS user_identities; diff --git a/internal/session/sql/migrations/020_user_identities.sql b/internal/session/sql/migrations/020_user_identities.sql new file mode 100644 index 00000000..3c904778 --- /dev/null +++ b/internal/session/sql/migrations/020_user_identities.sql @@ -0,0 +1,24 @@ +-- +goose Up +-- WebChat 多租户 spec ④:企业 SSO(OIDC 统一认证) +-- user_identities 将 OAuth 身份与 users 解耦:一个用户可关联多个 IdP。 +-- UNIQUE(provider, subject) 保证 SSO 登录确定性映射到唯一 user_id。 +-- users 表不增加字段;密码账号无 user_identities 行,完全向后兼容。 + +CREATE TABLE user_identities ( + id TEXT PRIMARY KEY, -- UUID + user_id TEXT NOT NULL REFERENCES users(id), + provider TEXT NOT NULL, -- provider name(config key) + subject TEXT NOT NULL, -- IdP subject(OIDC "sub" claim) + display_name TEXT NOT NULL DEFAULT '', -- 从 IdP 同步 + email TEXT NOT NULL DEFAULT '', -- 从 IdP 同步(仅记录,不用于自动合并) + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + UNIQUE(provider, subject) +); +CREATE INDEX idx_user_identities_user_id ON user_identities(user_id); +CREATE INDEX idx_user_identities_lookup ON user_identities(provider, subject); + +-- +goose Down +DROP INDEX IF EXISTS idx_user_identities_lookup; +DROP INDEX IF EXISTS idx_user_identities_user_id; +DROP TABLE IF EXISTS user_identities; diff --git a/internal/session/sql/queries/identities.create.sql b/internal/session/sql/queries/identities.create.sql new file mode 100644 index 00000000..5fba7627 --- /dev/null +++ b/internal/session/sql/queries/identities.create.sql @@ -0,0 +1,3 @@ +-- identities.create: insert a new OAuth identity binding. +INSERT INTO user_identities (id, user_id, provider, subject, display_name, email, created_at, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?, ?) diff --git a/internal/session/sql/queries/identities.get_by_provider_subject.sql b/internal/session/sql/queries/identities.get_by_provider_subject.sql new file mode 100644 index 00000000..2102d26d --- /dev/null +++ b/internal/session/sql/queries/identities.get_by_provider_subject.sql @@ -0,0 +1,3 @@ +-- identities.get_by_provider_subject: lookup OAuth identity by (provider, sub) for SSO login. +SELECT id, user_id, provider, subject, display_name, email, created_at, updated_at +FROM user_identities WHERE provider = ? AND subject = ? diff --git a/internal/session/sql/queries/identities.update_profile.sql b/internal/session/sql/queries/identities.update_profile.sql new file mode 100644 index 00000000..3609533c --- /dev/null +++ b/internal/session/sql/queries/identities.update_profile.sql @@ -0,0 +1,2 @@ +-- identities.update_profile: sync display_name/email from IdP on each login. +UPDATE user_identities SET display_name = ?, email = ?, updated_at = ? WHERE id = ? diff --git a/internal/session/sql/queries/workspaces.list_all.sql b/internal/session/sql/queries/workspaces.list_all.sql new file mode 100644 index 00000000..5ca14519 --- /dev/null +++ b/internal/session/sql/queries/workspaces.list_all.sql @@ -0,0 +1,3 @@ +-- workspaces.list_all: list all active workspaces (for startup validation scan, spec ② #749). +SELECT id, owner_user_id, name, work_dir, agent_config_overrides, worker_preference, status, created_at, updated_at +FROM workspaces WHERE status = 'active' ORDER BY created_at ASC