From d70b1ddaac7b3f6e79f03035b8bb76a61e0c7c15 Mon Sep 17 00:00:00 2001 From: Jaseem Jas Date: Sat, 21 Feb 2026 18:30:22 +0530 Subject: [PATCH 1/3] chore: add pre-commit hooks with Ruff, ESLint, and Prettier Set up automated code quality checks using the pre-commit framework: - Backend: Ruff lint + format (backend/, code-explorer/) - Frontend: ESLint (migrated to flat config) + Prettier with Tailwind plugin - General: trailing whitespace, EOF fixer, YAML/JSON checks, merge conflict detection - CI: GitHub Actions lint workflow for PRs and pushes to main - Developer setup: scripts/setup-hooks.sh for one-command onboarding --- .github/workflows/lint.yml | 39 +++++++++++++ .gitignore | 3 + .pre-commit-config.yaml | 44 +++++++++++++++ backend/README.md | Bin 3570 -> 4337 bytes backend/pyproject.toml | 26 ++++++++- code-explorer/pyproject.toml | 16 ++++++ frontend/.eslintrc.json | 3 - frontend/.prettierignore | 6 ++ frontend/.prettierrc.json | 11 ++++ frontend/eslint.config.mjs | 20 +++++++ frontend/package-lock.json | 105 +++++++++++++++++++++++++++++++++++ frontend/package.json | 6 +- scripts/setup-hooks.sh | 27 +++++++++ 13 files changed, 301 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/lint.yml create mode 100644 .pre-commit-config.yaml delete mode 100644 frontend/.eslintrc.json create mode 100644 frontend/.prettierignore create mode 100644 frontend/.prettierrc.json create mode 100644 frontend/eslint.config.mjs create mode 100755 scripts/setup-hooks.sh diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..ded2d87 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,39 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + backend-lint: + name: Backend (Ruff) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 + with: + args: check + src: backend/ code-explorer/ + - uses: astral-sh/ruff-action@v3 + with: + args: format --check + src: backend/ code-explorer/ + + frontend-lint: + name: Frontend (ESLint + Prettier) + runs-on: ubuntu-latest + defaults: + run: + working-directory: frontend + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: 22 + cache: npm + cache-dependency-path: frontend/package-lock.json + - run: npm ci + - run: npx eslint . + - run: npx prettier --check . 
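The workflow above exercises the same gate as the local pre-commit hooks. As a rough sketch (assuming ruff is available on PATH and `npm ci` has already been run in frontend/), the CI steps reduce to:

    # Backend: lint, then verify formatting without rewriting files
    ruff check backend/ code-explorer/
    ruff format --check backend/ code-explorer/

    # Frontend: flat-config ESLint plus a Prettier check
    cd frontend && npx eslint . && npx prettier --check .

Running these before pushing should keep the Lint workflow green without waiting on CI.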
diff --git a/.gitignore b/.gitignore index 0333f9c..9e496f7 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,9 @@ env/ # UV .python-version +# Linting +.ruff_cache/ + # Testing .pytest_cache/ .coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0c60833 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,44 @@ +repos: + # Backend: Ruff lint + format + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.9.7 + hooks: + - id: ruff + name: "backend: ruff lint" + args: [--fix, --exit-non-zero-on-fix] + files: ^(backend|code-explorer)/ + types_or: [python, pyi] + - id: ruff-format + name: "backend: ruff format" + files: ^(backend|code-explorer)/ + types_or: [python, pyi] + + # Frontend: ESLint + Prettier + - repo: local + hooks: + - id: frontend-eslint + name: "frontend: eslint" + entry: bash -c 'args=(); for f in "$@"; do args+=("${f#frontend/}"); done; cd frontend && npx eslint --fix "${args[@]}"' -- + language: system + files: ^frontend/.*\.(ts|tsx|js|jsx)$ + types: [file] + - id: frontend-prettier + name: "frontend: prettier" + entry: bash -c 'args=(); for f in "$@"; do args+=("${f#frontend/}"); done; cd frontend && npx prettier --write "${args[@]}"' -- + language: system + files: ^frontend/.*\.(ts|tsx|js|jsx|css|json|md)$ + types: [file] + + # General checks + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + args: [--markdown-linebreak-ext=md] + - id: end-of-file-fixer + - id: check-yaml + - id: check-json + exclude: ^frontend/package-lock\.json$ + - id: check-merge-conflict + - id: check-added-large-files + args: [--maxkb=500] diff --git a/backend/README.md b/backend/README.md index 84982b6b31967c13643300f67566e40a201fd3f8..1c144706b6477283110129f388d92e453bf95b3e 100644 GIT binary patch delta 810 zcmZ9KO>WyT5QWh#ksA~RrpU%-O=<643d8}5{1izyMNmKtHM9s*q(D-Rz0w7Gm;yan zbApck>NFcd;v3AHc^_Zi|M~dq_N@CFP$*y5CPC%>E&@93gF>LQmMIyx1?ft##*N=? 
z@^o*MwDt(z0q!vzVQi5x#`(sCm_#x)Ld4WIYt9)fisDJeiuDIX6-<+2PIA{aV(`sK zq$0+?RVIRQF_A%pR3S8PMqI&jWmE;Bb!#So^UqT4(3PC8;qkuJ`p2S5sfp`6FTIMH zj44P<26-ScUkly-zF*v2=U*P~ZecI2DG7^rE2uGSk%%~LjY?2rLNzK@_KG%^N|LHqK_KFLk03k>cqr}c(fHro%yMH3TejkwvJq8FoO*&D9alD1-fLeIlzc#^2~;-83l+B;yF;Xkd2Rz^Q+dB ONSR|F#BY7}^W9&c1Qufe delta 37 scmeyU_(^($ACH7veojt)xk5=sYH_MUdTC}#YEEWeYO&SkRQ?`D02lZUB>(^b diff --git a/backend/pyproject.toml b/backend/pyproject.toml index efa8ffc..32de2df 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -51,6 +51,31 @@ enterprise = [ "slack-bolt>=1.21.0", "slack-sdk>=3.33.0", ] +dev = ["pre-commit>=4.0.0"] + +[tool.ruff] +target-version = "py312" +line-length = 120 + +[tool.ruff.lint] +select = ["F", "E", "W", "I"] +ignore = [ + "E501", # Line too long (formatter handles this) + "E711", # Comparison to None (SQLAlchemy: column == None) + "E712", # Comparison to True/False (SQLAlchemy: column == True) + "E741", # Ambiguous variable name + "F821", # Undefined name (SQLAlchemy forward references in Mapped["Model"]) + "F841", # Local variable assigned but never used (side-effect assignments) + "E402", # Module-level import not at top (common in workers/plugins) +] + +[tool.ruff.lint.isort] +known-first-party = ["app", "workers"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" [tool.pytest.ini_options] testpaths = ["tests"] @@ -63,4 +88,3 @@ addopts = [ "--strict-markers", "--strict-config", ] - diff --git a/code-explorer/pyproject.toml b/code-explorer/pyproject.toml index eaee676..f888a53 100644 --- a/code-explorer/pyproject.toml +++ b/code-explorer/pyproject.toml @@ -10,6 +10,22 @@ dependencies = [ "pydantic-settings>=2.6.0", ] +[tool.ruff] +target-version = "py312" +line-length = 120 + +[tool.ruff.lint] +select = ["F", "E", "W", "I"] +ignore = ["E501"] + +[tool.ruff.lint.isort] +known-first-party = ["app"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "lf" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/frontend/.eslintrc.json b/frontend/.eslintrc.json deleted file mode 100644 index bffb357..0000000 --- a/frontend/.eslintrc.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "extends": "next/core-web-vitals" -} diff --git a/frontend/.prettierignore b/frontend/.prettierignore new file mode 100644 index 0000000..419c80c --- /dev/null +++ b/frontend/.prettierignore @@ -0,0 +1,6 @@ +node_modules/ +.next/ +out/ +build/ +coverage/ +package-lock.json diff --git a/frontend/.prettierrc.json b/frontend/.prettierrc.json new file mode 100644 index 0000000..34a1965 --- /dev/null +++ b/frontend/.prettierrc.json @@ -0,0 +1,11 @@ +{ + "semi": true, + "singleQuote": false, + "tabWidth": 2, + "trailingComma": "all", + "printWidth": 100, + "bracketSpacing": true, + "arrowParens": "always", + "endOfLine": "lf", + "plugins": ["prettier-plugin-tailwindcss"] +} diff --git a/frontend/eslint.config.mjs b/frontend/eslint.config.mjs new file mode 100644 index 0000000..9f0a8b1 --- /dev/null +++ b/frontend/eslint.config.mjs @@ -0,0 +1,20 @@ +import nextCoreWebVitals from "eslint-config-next/core-web-vitals"; + +export default [ + ...nextCoreWebVitals, + { + rules: { + // Downgrade pre-existing issues to warnings (fix incrementally) + "react/no-unescaped-entities": "warn", + "react/no-children-prop": "warn", + "@next/next/no-img-element": "warn", + "@next/next/no-assign-module-variable": "warn", + "react-hooks/exhaustive-deps": "warn", + "react-hooks/refs": "warn", + 
"react-hooks/set-state-in-effect": "warn", + "react-hooks/static-components": "warn", + "react-hooks/set-state-in-render": "warn", + "react-hooks/immutability": "warn", + }, + }, +]; diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 04c8a91..b70489f 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -67,6 +67,8 @@ "eslint": "^9.39.1", "eslint-config-next": "^16.0.3", "postcss": "^8.5.6", + "prettier": "^3.4.0", + "prettier-plugin-tailwindcss": "^0.6.0", "typescript": "^5.8.0" } }, @@ -8950,6 +8952,109 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.8.1", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.8.1.tgz", + "integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/prettier-plugin-tailwindcss": { + "version": "0.6.14", + "resolved": "https://registry.npmjs.org/prettier-plugin-tailwindcss/-/prettier-plugin-tailwindcss-0.6.14.tgz", + "integrity": "sha512-pi2e/+ZygeIqntN+vC573BcW5Cve8zUB0SSAGxqpB4f96boZF4M3phPVoOFCeypwkpRYdi7+jQ5YJJUwrkGUAg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=14.21.3" + }, + "peerDependencies": { + "@ianvs/prettier-plugin-sort-imports": "*", + "@prettier/plugin-hermes": "*", + "@prettier/plugin-oxc": "*", + "@prettier/plugin-pug": "*", + "@shopify/prettier-plugin-liquid": "*", + "@trivago/prettier-plugin-sort-imports": "*", + "@zackad/prettier-plugin-twig": "*", + "prettier": "^3.0", + "prettier-plugin-astro": "*", + "prettier-plugin-css-order": "*", + "prettier-plugin-import-sort": "*", + "prettier-plugin-jsdoc": "*", + "prettier-plugin-marko": "*", + "prettier-plugin-multiline-arrays": "*", + "prettier-plugin-organize-attributes": "*", + "prettier-plugin-organize-imports": "*", + "prettier-plugin-sort-imports": "*", + "prettier-plugin-style-order": "*", + "prettier-plugin-svelte": "*" + }, + "peerDependenciesMeta": { + "@ianvs/prettier-plugin-sort-imports": { + "optional": true + }, + "@prettier/plugin-hermes": { + "optional": true + }, + "@prettier/plugin-oxc": { + "optional": true + }, + "@prettier/plugin-pug": { + "optional": true + }, + "@shopify/prettier-plugin-liquid": { + "optional": true + }, + "@trivago/prettier-plugin-sort-imports": { + "optional": true + }, + "@zackad/prettier-plugin-twig": { + "optional": true + }, + "prettier-plugin-astro": { + "optional": true + }, + "prettier-plugin-css-order": { + "optional": true + }, + "prettier-plugin-import-sort": { + "optional": true + }, + "prettier-plugin-jsdoc": { + "optional": true + }, + "prettier-plugin-marko": { + "optional": true + }, + "prettier-plugin-multiline-arrays": { + "optional": true + }, + "prettier-plugin-organize-attributes": { + "optional": true + }, + "prettier-plugin-organize-imports": { + "optional": true + }, + "prettier-plugin-sort-imports": { + "optional": true + }, + "prettier-plugin-style-order": { + "optional": true + }, + "prettier-plugin-svelte": { + "optional": true + } + } + }, "node_modules/prism-react-renderer": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/prism-react-renderer/-/prism-react-renderer-2.4.1.tgz", diff --git a/frontend/package.json b/frontend/package.json index d8bef3e..070555e 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -6,7 +6,9 @@ "dev": 
"node_modules/.bin/next dev", "build": "node_modules/.bin/next build", "start": "node_modules/.bin/next start", - "lint": "node_modules/.bin/next lint" + "lint": "node_modules/.bin/next lint", + "format": "prettier --write .", + "format:check": "prettier --check ." }, "dependencies": { "@monaco-editor/react": "^4.7.0", @@ -68,6 +70,8 @@ "eslint": "^9.39.1", "eslint-config-next": "^16.0.3", "postcss": "^8.5.6", + "prettier": "^3.4.0", + "prettier-plugin-tailwindcss": "^0.6.0", "typescript": "^5.8.0" } } diff --git a/scripts/setup-hooks.sh b/scripts/setup-hooks.sh new file mode 100755 index 0000000..457be24 --- /dev/null +++ b/scripts/setup-hooks.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" + +echo "==> Installing pre-commit..." +cd "$REPO_ROOT/backend" +uv pip install pre-commit + +echo "==> Installing pre-commit hooks..." +cd "$REPO_ROOT" +pre-commit install + +echo "==> Ensuring frontend dependencies are installed..." +cd "$REPO_ROOT/frontend" +if [ ! -d node_modules ]; then + npm install +else + echo " node_modules already exists, skipping npm install" +fi + +echo "" +echo "Done! Pre-commit hooks are now active." +echo "They will run automatically on every git commit." +echo "" +echo "To run hooks manually on all files:" +echo " pre-commit run --all-files" From 24a5e90dda6e4377db6b7dc12e6fab6ce0dd8e00 Mon Sep 17 00:00:00 2001 From: Jaseem Jas Date: Sat, 21 Feb 2026 18:30:36 +0530 Subject: [PATCH 2/3] style: apply Ruff and Prettier formatting to entire codebase One-time auto-format of all existing code: - 579 Python files reformatted by Ruff (backend + code-explorer) - Frontend files reformatted by Prettier with Tailwind class sorting --- .github/PULL_REQUEST_TEMPLATE.md | 4 +- README.md | 2 +- backend/alembic/README | 2 +- ...f349edc_add_feature_completion_tracking.py | 64 +- ...dd_completion_summary_and_triggered_by_.py | 15 +- ..._drop_conversation_generation_triggers_.py | 7 +- ...625ec8ba1_add_project_memberships_table.py | 41 +- .../0d2b4de48ed2_add_org_plan_fields.py | 70 +- ...fix_spec_version_unique_constraint_per_.py | 21 +- ...cb83afe_add_implementation_phases_table.py | 17 +- ...17e2d7ed64e2_add_testing_debug_settings.py | 34 +- .../1931abc02c3f_add_user_auth_fields.py | 15 +- ...52_add_organization_id_to_organizations.py | 23 +- ...bffcc2a_phase3_autogen_discovery_schema.py | 88 +- .../1e859fbdbd74_add_spec_coverage_reports.py | 35 +- .../22633d1c969b_add_feature_key_number.py | 19 +- ...3c2e92fa_add_llm_usage_tracking_to_jobs.py | 23 +- .../264f9632081a_add_notification_tables.py | 91 +- ...e_add_integration_configs_and_bug_sync_.py | 65 +- ...c2d246f32_merge_llm_usage_and_call_logs.py | 9 +- ...remove_completion_summary_from_features.py | 11 +- ..._add_prompt_plan_coverage_reports_table.py | 38 +- ...3270c_add_organizations_and_memberships.py | 55 +- .../4691251c9f11_add_user_trial_started_at.py | 14 +- ...ix_brainstorm_module_feature_types_data.py | 7 +- ...ccf3da8_add_key_lookup_hash_to_api_keys.py | 19 +- ...8b3923aa347_add_thread_ai_response_flag.py | 12 +- .../58b5bf6d73fc_add_discovery_tables.py | 99 +- ...add_generation_flags_to_impl_and_thread.py | 22 +- ...c2c0d52_add_threads_and_comments_tables.py | 72 +- ...cb961_fix_module_feature_types_properly.py | 7 +- .../6ebda2d1112d_add_llm_preference_table.py | 31 +- ...bb4_add_module_and_feature_type_columns.py | 39 +- ...dd_thread_items_and_followup_timestamps.py | 51 +- ..._add_display_name_to_integration_config.py | 25 +- 
..._add_deleted_at_and_key_constraints_to_.py | 23 +- ...664c9d7b_set_freemium_max_users_default.py | 7 +- .../8e970e133ccd_add_spec_versions_table.py | 17 +- .../8eebe3a100d3_add_mcp_oauth_tables.py | 7 +- ...73c88f9ac_add_email_verification_fields.py | 7 +- ...b_add_suggested_implementation_name_to_.py | 7 +- ...c0f1117_add_implementation_notes_tables.py | 97 +- ...d_brainstorming_phases_modules_features.py | 33 +- .../versions/a1b2c3d4e5f7_add_module_key.py | 14 +- .../a60117da6409_add_projects_table.py | 63 +- .../a6b7c8d9e0f1_user_level_api_keys.py | 2 +- .../versions/a8845f795cf1_add_jobs_table.py | 61 +- ...d4e5f6g7_add_draft_final_version_tables.py | 49 +- ...d9e0f1a2_add_summary_to_grounding_files.py | 3 +- ...0f1g2_add_coding_agent_name_to_mcp_logs.py | 3 +- ...94089ee371_add_api_key_encrypted_column.py | 2 +- ...ba27eeb7dc05_merge_cexp08_and_icshare01.py | 9 +- ...f8b5b7_add_api_keys_table_for_mcp_http_.py | 47 +- ...d4e5f6g7h8_add_thread_version_anchoring.py | 23 +- ...add_archived_at_to_brainstorming_phases.py | 3 +- .../c8d9e0f1g2h3_add_feature_updated_at.py | 17 +- .../cec01_add_code_exploration_cache.py | 2 +- .../versions/cexp01_add_code_explorer.py | 2 +- ...xp06_add_raw_output_to_code_exploration.py | 3 +- .../cexp07_add_thread_code_exploration.py | 3 +- .../cexp08_move_code_explorer_to_project.py | 3 +- .../d1cf77c4c1fa_add_llm_usage_logs_table.py | 7 +- .../d4e5f6g7h8i9_add_activity_logs.py | 21 +- ...d5b6e7f8g123_add_decision_summary_short.py | 3 +- ...2c9e1a3_add_non_goals_violations_column.py | 15 +- ..._add_show_create_implementation_button_.py | 12 +- .../dus01_daily_usage_summary_table.py | 2 +- .../dus02_realtime_aggregation_trigger.py | 2 - .../versions/dus03_fix_trigger_on_conflict.py | 1 - ...5g6h7i8j9_add_feature_priority_category.py | 33 +- .../e6f7g8h9i0j1_add_email_templates_table.py | 3 +- .../ea0f87fc2305_add_is_sample_to_projects.py | 3 +- ...2775e46_add_pending_approval_to_threads.py | 11 +- .../f5g6h7i8j9k0_add_llm_call_logs_table.py | 65 +- .../f7g8h9i0j1k2_seed_email_templates.py | 17 +- ...3d_initial_migration_create_users_table.py | 25 +- .../fb1c2d3e4f5g_add_notes_updated_fields.py | 3 +- ...13b_add_created_by_to_pre_phase_message.py | 24 +- .../freemium01_increase_max_users_to_5.py | 6 +- .../g6h7i8j9k0l1_add_mcp_call_logs_table.py | 101 +- ...9i0j1k2l3_add_triggered_by_and_duration.py | 12 +- ...gc2d3e4f5g6h_add_thread_ai_error_fields.py | 19 +- .../gen01flags_add_generation_status_flags.py | 2 +- .../ghoauth01_add_github_oauth_states.py | 2 +- ...2_add_github_oauth_to_platform_settings.py | 2 +- .../grdbr01_add_grounding_file_branches.py | 2 +- .../versions/grdbr02_add_is_merging_flag.py | 2 +- .../grdbr03_add_content_updated_at.py | 6 +- .../grdbr04_add_global_content_updated_at.py | 6 +- .../grdbr05_add_last_synced_with_global_at.py | 2 +- .../grnotes01_add_grounding_note_versions.py | 44 +- .../h7i8j9k0l1m2_add_vfs_metadata_table.py | 2 +- ...2l3m4_add_thread_item_summary_snapshots.py | 3 +- ...d3e4f5g6h7i_add_session_ai_error_fields.py | 19 +- .../i0j1k2l3m4n5_add_mcp_image_submissions.py | 68 +- .../i8j9k0l1m2n3_add_team_roles_tables.py | 2 +- ...cshare01_add_integration_config_sharing.py | 2 +- .../ie4f5g6h7i8j_convert_trial_to_freemium.py | 54 +- .../impl01_add_implementations_table.py | 6 +- .../impl02_add_implementation_created_enum.py | 2 +- ...d_completion_summary_to_implementations.py | 14 +- .../j1k2l3m4n5o6_add_form_drafts_table.py | 4 +- .../j9k0l1m2n3o4_make_team_roles_dynamic.py | 11 +- 
.../k0l1m2n3o4p5_add_grounding_files_table.py | 2 +- ...m2n3o4p5q6_add_feature_content_versions.py | 14 +- .../m2n3o4p5q6r7_add_feature_import_fields.py | 2 +- .../mcqans01_add_mcq_answer_item_type.py | 2 +- ...o4p5q6r7s8_add_identity_provider_tables.py | 2 +- .../o4p5q6r7s8t9_seed_identity_providers.py | 8 +- ...q6r7s8t9u0_add_platform_settings_tables.py | 12 +- .../pc01_rename_pre_phase_to_project_chat.py | 43 +- .../pc02_add_retry_status_to_project_chat.py | 4 +- .../phc01_add_phase_containers_table.py | 2 +- .../phc02_add_container_fields_to_phases.py | 2 +- .../phc03_migrate_phases_to_containers.py | 7 +- ...4_add_target_container_to_project_chats.py | 2 +- .../ppc01_add_visibility_to_project_chats.py | 5 +- .../ppd01_add_pre_phase_discussions.py | 6 +- .../ppd02_add_discussion_phase_link.py | 3 +- .../ppd03_add_pre_phase_feature_fields.py | 3 +- .../ppd04_add_org_scoped_discussions.py | 3 +- .../versions/ppd05_add_image_attachments.py | 3 +- .../alembic/versions/ppd06_add_chat_title.py | 3 +- .../versions/ppd08_add_summary_snapshot.py | 2 +- .../versions/pts01_add_project_tech_stack.py | 2 +- ...9u0v1_add_base_url_to_platform_settings.py | 3 +- .../qrs02_add_grounding_file_is_generating.py | 2 +- .../qst01phase_add_phase_question_stats.py | 42 +- ...r7s8t9u0v1w2_add_invitations_and_groups.py | 11 +- .../versions/rec01_plan_recommendations.py | 2 +- .../repo01_add_project_repositories.py | 4 +- ...8t9u0v1w2x3_add_current_org_id_to_users.py | 2 +- .../alembic/versions/sid01_add_short_ids.py | 3 +- .../versions/slack01_add_slack_tables.py | 3 +- .../versions/slack02_add_oauth_tables.py | 3 +- .../sys01_add_system_thread_item_type.py | 2 +- ...9u0v1w2x3y4_add_thread_decision_summary.py | 2 +- ...p01_add_exploration_prompt_search_query.py | 2 +- .../th01_add_retry_status_to_threads.py | 2 +- ...4z5_add_proactive_conversation_features.py | 9 +- ...at01_add_project_chat_fields_to_threads.py | 2 +- ...v1w2x3y4z5a6_add_user_question_sessions.py | 8 +- ...y4z5a6b7_remove_thread_followup_columns.py | 2 +- .../wsrch01_add_web_search_support.py | 2 +- .../x3y4z5a6b7c8_add_feature_description.py | 2 +- ...4z5a6b7c8d9_add_conversation_rerun_flag.py | 3 +- ...6b7c8d9e0_convert_project_key_to_prefix.py | 3 +- backend/app/agents/brainstorm/__init__.py | 14 +- backend/app/agents/brainstorm/generator.py | 39 +- backend/app/agents/brainstorm/orchestrator.py | 17 +- backend/app/agents/brainstorm/types.py | 36 +- .../brainstorm_conversation/__init__.py | 64 +- .../aspect_generator.py | 123 +- .../brainstorm_conversation/classifier.py | 73 +- .../code_explorer_stage.py | 119 +- .../brainstorm_conversation/critic_pruner.py | 59 +- .../brainstorm_conversation/logging_config.py | 40 +- .../brainstorm_conversation/orchestrator.py | 131 +- .../question_generator.py | 124 +- .../brainstorm_conversation/summarizer.py | 22 +- .../agents/brainstorm_conversation/types.py | 79 +- .../agents/brainstorm_conversation/utils.py | 8 +- .../agents/brainstorm_prompt_plan/__init__.py | 65 +- .../agents/brainstorm_prompt_plan/analyzer.py | 60 +- .../brainstorm_prompt_plan/logging_config.py | 40 +- .../brainstorm_prompt_plan/orchestrator.py | 52 +- .../agents/brainstorm_prompt_plan/planner.py | 105 +- .../app/agents/brainstorm_prompt_plan/qa.py | 47 +- .../agents/brainstorm_prompt_plan/types.py | 43 +- .../agents/brainstorm_prompt_plan/utils.py | 8 +- .../agents/brainstorm_prompt_plan/writer.py | 74 +- .../app/agents/brainstorm_spec/__init__.py | 67 +- .../agents/brainstorm_spec/logging_config.py | 40 +- 
.../agents/brainstorm_spec/orchestrator.py | 47 +- backend/app/agents/brainstorm_spec/planner.py | 86 +- .../app/agents/brainstorm_spec/qa_coverage.py | 63 +- .../app/agents/brainstorm_spec/summarizer.py | 56 +- backend/app/agents/brainstorm_spec/types.py | 40 +- backend/app/agents/brainstorm_spec/utils.py | 8 +- backend/app/agents/brainstorm_spec/writer.py | 86 +- .../collab_thread_assistant/__init__.py | 30 +- .../collab_thread_assistant/assistant.py | 41 +- .../collab_thread_assistant/context_loader.py | 224 ++-- .../exploration_parser.py | 3 +- .../instrumentation.py | 20 +- .../collab_thread_assistant/mcq_parser.py | 29 +- .../collab_thread_assistant/orchestrator.py | 122 +- .../agents/collab_thread_assistant/retry.py | 6 +- .../spec_draft_assistant.py | 3 +- .../spec_draft_handler.py | 9 +- .../collab_thread_assistant/summarizer.py | 6 +- .../agents/collab_thread_assistant/types.py | 2 +- .../web_search_parser.py | 3 +- .../orchestrator.py | 21 +- .../summarizer.py | 23 +- .../types.py | 2 +- .../app/agents/feature_content/__init__.py | 24 +- .../agents/feature_content/context_loader.py | 86 +- .../agents/feature_content/orchestrator.py | 48 +- backend/app/agents/feature_content/types.py | 19 +- backend/app/agents/grounding/__init__.py | 4 +- .../agents/grounding/merge_orchestrator.py | 33 +- backend/app/agents/grounding/orchestrator.py | 35 +- backend/app/agents/grounding/types.py | 3 + .../app/agents/image_annotator/annotator.py | 15 +- .../agents/image_annotator/orchestrator.py | 37 +- backend/app/agents/image_annotator/types.py | 2 + backend/app/agents/llm_client.py | 151 +-- backend/app/agents/module_feature/__init__.py | 46 +- .../agents/module_feature/logging_config.py | 79 +- backend/app/agents/module_feature/merger.py | 206 ++- .../app/agents/module_feature/orchestrator.py | 129 +- .../agents/module_feature/plan_structurer.py | 216 ++-- .../agents/module_feature/spec_analyzer.py | 112 +- backend/app/agents/module_feature/types.py | 59 +- backend/app/agents/module_feature/utils.py | 150 ++- .../app/agents/module_feature/validator.py | 55 +- backend/app/agents/module_feature/writer.py | 127 +- .../agents/project_chat_assistant/__init__.py | 16 +- .../project_chat_assistant/assistant.py | 60 +- .../project_chat_assistant/orchestrator.py | 258 ++-- .../agents/project_chat_assistant/types.py | 20 +- .../app/agents/project_chat_gating/agent.py | 2 +- backend/app/agents/response_parser.py | 60 +- backend/app/agents/retry.py | 18 +- backend/app/auth/__init__.py | 1 + backend/app/auth/api_key_utils.py | 1 + backend/app/auth/dependencies.py | 42 +- backend/app/auth/domain_validation.py | 1 + backend/app/auth/platform_admin.py | 1 + backend/app/auth/providers.py | 2 +- backend/app/auth/service.py | 17 +- backend/app/auth/trial.py | 4 +- backend/app/auth/utils.py | 13 +- backend/app/config.py | 40 +- backend/app/database.py | 7 +- backend/app/integrations/__init__.py | 1 + backend/app/integrations/base.py | 1 + backend/app/integrations/factory.py | 10 +- backend/app/integrations/github.py | 30 +- backend/app/integrations/gitlab.py | 1 + backend/app/integrations/jira.py | 41 +- backend/app/main.py | 58 +- backend/app/mcp/server.py | 6 +- backend/app/mcp/tools/append_feature_note.py | 3 +- .../tools/create_clarification_question.py | 7 +- backend/app/mcp/tools/get_context.py | 77 +- backend/app/mcp/tools/get_feature_notes.py | 3 +- backend/app/mcp/tools/get_section.py | 18 +- backend/app/mcp/tools/get_toc.py | 25 +- backend/app/mcp/tools/vfs_cat.py | 13 +- 
backend/app/mcp/tools/vfs_find.py | 88 +- backend/app/mcp/tools/vfs_grep.py | 90 +- backend/app/mcp/tools/vfs_head.py | 2 +- backend/app/mcp/tools/vfs_ls.py | 19 +- backend/app/mcp/tools/vfs_sed.py | 6 +- backend/app/mcp/tools/vfs_set_metadata.py | 6 +- backend/app/mcp/tools/vfs_tail.py | 2 +- backend/app/mcp/tools/vfs_tree.py | 171 +-- backend/app/mcp/tools/vfs_write.py | 125 +- backend/app/mcp/utils/markdown_parser.py | 16 +- backend/app/mcp/utils/project_resolver.py | 1 + backend/app/mcp/vfs/__init__.py | 12 +- backend/app/mcp/vfs/content.py | 425 +++---- backend/app/mcp/vfs/errors.py | 3 +- backend/app/mcp/vfs/metadata.py | 92 +- backend/app/mcp/vfs/path_resolver.py | 102 +- backend/app/middleware/__init__.py | 1 + backend/app/models/__init__.py | 126 +- backend/app/models/activity_log.py | 9 +- backend/app/models/api_key.py | 2 +- backend/app/models/brainstorming_phase.py | 18 +- backend/app/models/bug_sync_history.py | 9 +- backend/app/models/code_exploration_result.py | 12 +- backend/app/models/daily_usage_summary.py | 8 +- backend/app/models/email_template.py | 9 +- backend/app/models/events/__init__.py | 1 + .../models/events/phase_container_events.py | 9 +- backend/app/models/feature.py | 10 +- backend/app/models/feature_content_version.py | 8 +- backend/app/models/feature_import_comment.py | 9 +- backend/app/models/final_prompt_plan.py | 8 +- backend/app/models/final_spec.py | 8 +- backend/app/models/form_draft.py | 2 +- backend/app/models/github_oauth_state.py | 12 +- backend/app/models/grounding_file.py | 4 +- backend/app/models/grounding_file_branch.py | 9 +- backend/app/models/grounding_note_version.py | 8 +- backend/app/models/identity_provider.py | 9 +- backend/app/models/implementation.py | 16 +- backend/app/models/inbox_follow.py | 7 +- backend/app/models/inbox_mention.py | 6 +- backend/app/models/integration_config.py | 19 +- .../app/models/integration_config_share.py | 12 +- backend/app/models/job.py | 12 +- backend/app/models/llm_call_log.py | 15 +- backend/app/models/llm_preference.py | 29 +- backend/app/models/llm_usage_log.py | 7 +- backend/app/models/mcp_call_log.py | 3 +- backend/app/models/mcp_image_submission.py | 6 +- backend/app/models/mcp_oauth_client.py | 14 +- backend/app/models/mcp_oauth_code.py | 18 +- backend/app/models/mcp_oauth_token.py | 22 +- backend/app/models/module.py | 10 +- backend/app/models/notification_preference.py | 9 +- .../app/models/notification_project_mute.py | 9 +- .../app/models/notification_thread_watch.py | 9 +- backend/app/models/org_invitation.py | 32 +- backend/app/models/org_invitation_group.py | 8 +- backend/app/models/org_membership.py | 20 +- backend/app/models/organization.py | 8 +- backend/app/models/phase_container.py | 14 +- backend/app/models/plan_recommendation.py | 5 +- backend/app/models/platform_connector.py | 21 +- backend/app/models/platform_settings.py | 64 +- backend/app/models/project.py | 12 +- backend/app/models/project_chat.py | 22 +- backend/app/models/project_membership.py | 18 +- backend/app/models/project_repository.py | 7 +- backend/app/models/project_share.py | 10 +- backend/app/models/prompt_plan_coverage.py | 14 +- backend/app/models/provisioning.py | 6 +- backend/app/models/slack_channel_link.py | 24 +- backend/app/models/slack_user_mapping.py | 25 +- backend/app/models/spec_coverage.py | 14 +- backend/app/models/spec_version.py | 8 +- backend/app/models/team_role.py | 14 +- backend/app/models/thread.py | 29 +- backend/app/models/thread_item.py | 25 +- backend/app/models/user.py | 9 +- 
.../app/models/user_conversation_status.py | 5 +- backend/app/models/user_group.py | 8 +- backend/app/models/user_group_membership.py | 12 +- backend/app/models/user_identity.py | 10 +- backend/app/models/user_question_session.py | 15 +- backend/app/models/validators/__init__.py | 1 + .../app/models/validators/phase_validators.py | 43 +- backend/app/models/vfs_metadata.py | 2 +- backend/app/permissions/__init__.py | 6 +- backend/app/permissions/context.py | 10 +- backend/app/plugin_registry.py | 8 +- backend/app/routers/__init__.py | 1 + backend/app/routers/activity.py | 10 +- backend/app/routers/agent_api.py | 19 +- backend/app/routers/analytics.py | 8 +- backend/app/routers/api_keys.py | 12 +- backend/app/routers/auth.py | 48 +- backend/app/routers/brainstorming_phases.py | 228 ++-- backend/app/routers/conversations.py | 92 +- backend/app/routers/dashboard.py | 11 +- backend/app/routers/drafts.py | 9 +- backend/app/routers/email_templates.py | 1 + .../app/routers/feature_content_versions.py | 7 +- backend/app/routers/features.py | 150 +-- backend/app/routers/form_drafts.py | 2 +- backend/app/routers/grounding.py | 67 +- backend/app/routers/grounding_notes.py | 2 +- backend/app/routers/images.py | 53 +- backend/app/routers/implementations.py | 64 +- backend/app/routers/inbox.py | 16 +- backend/app/routers/inbox_deep_link.py | 2 +- backend/app/routers/inbox_follows.py | 43 +- backend/app/routers/integrations.py | 28 +- backend/app/routers/invitations.py | 35 +- backend/app/routers/invite_acceptance.py | 10 +- backend/app/routers/jobs.py | 7 +- backend/app/routers/llm_call_logs.py | 138 +- backend/app/routers/llm_preferences.py | 7 +- backend/app/routers/mcp_call_logs.py | 10 +- backend/app/routers/mcp_http.py | 240 ++-- backend/app/routers/mcp_images.py | 5 +- backend/app/routers/modules.py | 33 +- backend/app/routers/org_chats.py | 64 +- backend/app/routers/orgs.py | 47 +- backend/app/routers/phase_containers.py | 44 +- backend/app/routers/plan_recommendations.py | 7 +- backend/app/routers/platform_settings.py | 84 +- backend/app/routers/project_chat_images.py | 18 +- backend/app/routers/project_chats.py | 122 +- backend/app/routers/project_repositories.py | 48 +- backend/app/routers/project_shares.py | 17 +- backend/app/routers/projects.py | 48 +- backend/app/routers/team_roles.py | 18 +- backend/app/routers/testing.py | 3 +- backend/app/routers/thread_images.py | 13 +- backend/app/routers/thread_items.py | 14 +- backend/app/routers/threads.py | 84 +- backend/app/routers/user_groups.py | 20 +- backend/app/routers/user_question_sessions.py | 40 +- backend/app/routers/websocket.py | 2 +- backend/app/schemas/__init__.py | 87 +- backend/app/schemas/activity.py | 5 +- backend/app/schemas/analytics.py | 96 +- backend/app/schemas/api_key.py | 8 +- backend/app/schemas/auth.py | 2 + backend/app/schemas/brainstorming_phase.py | 13 +- backend/app/schemas/bug_sync_history.py | 5 +- backend/app/schemas/dashboard.py | 4 +- backend/app/schemas/draft_version.py | 7 +- backend/app/schemas/email_template.py | 1 + backend/app/schemas/feature.py | 16 +- .../app/schemas/feature_content_version.py | 3 +- backend/app/schemas/final_version.py | 11 +- backend/app/schemas/form_draft.py | 4 +- backend/app/schemas/grounding_note.py | 14 +- backend/app/schemas/identity.py | 1 + backend/app/schemas/implementation.py | 2 +- backend/app/schemas/inbox_badge.py | 1 + backend/app/schemas/inbox_conversation.py | 1 + backend/app/schemas/inbox_deep_link.py | 1 + backend/app/schemas/inbox_event.py | 2 + 
backend/app/schemas/inbox_follow.py | 5 +- backend/app/schemas/integration_config.py | 19 +- backend/app/schemas/invitation.py | 2 - backend/app/schemas/llm_call_log.py | 7 +- backend/app/schemas/llm_preference.py | 1 + backend/app/schemas/mcp_call_log.py | 6 +- backend/app/schemas/module.py | 11 +- backend/app/schemas/notification.py | 15 +- backend/app/schemas/oauth.py | 1 + backend/app/schemas/org.py | 5 +- backend/app/schemas/phase_container.py | 7 +- backend/app/schemas/plan_recommendation.py | 56 +- backend/app/schemas/platform_settings.py | 186 +-- backend/app/schemas/project.py | 7 +- backend/app/schemas/project_chat.py | 30 +- backend/app/schemas/project_repository.py | 1 + backend/app/schemas/project_share.py | 1 - backend/app/schemas/spec.py | 1 + backend/app/schemas/thread.py | 37 +- backend/app/schemas/thread_item.py | 39 +- backend/app/schemas/user_question_session.py | 15 +- backend/app/services/activity_log_service.py | 44 +- backend/app/services/agent_utils.py | 6 +- backend/app/services/analytics_cache.py | 4 +- backend/app/services/analytics_service.py | 245 ++-- backend/app/services/api_key_service.py | 41 +- .../services/brainstorming_phase_service.py | 1117 +++++++++-------- backend/app/services/bug_sync_service.py | 27 +- backend/app/services/code_explorer_client.py | 35 +- .../services/daily_usage_summary_service.py | 165 +-- backend/app/services/dashboard_service.py | 30 +- backend/app/services/draft_version_service.py | 12 +- backend/app/services/email_service.py | 13 +- .../app/services/email_template_service.py | 35 +- .../feature_content_version_service.py | 37 +- .../app/services/feature_import_service.py | 30 +- backend/app/services/feature_service.py | 81 +- backend/app/services/finalization_service.py | 20 +- .../github_integration_oauth_service.py | 16 +- .../app/services/grounding_note_service.py | 10 +- backend/app/services/grounding_service.py | 47 +- backend/app/services/image_service.py | 56 +- .../app/services/implementation_service.py | 168 +-- backend/app/services/inbox_badge_service.py | 49 +- .../app/services/inbox_broadcast_service.py | 90 +- .../services/inbox_conversation_service.py | 319 ++--- backend/app/services/inbox_follow_service.py | 205 +-- backend/app/services/inbox_mention_service.py | 86 +- backend/app/services/inbox_status_service.py | 213 ++-- .../integration_config_share_service.py | 22 +- backend/app/services/integration_service.py | 57 +- backend/app/services/invitation_service.py | 40 +- backend/app/services/job_service.py | 75 +- backend/app/services/kafka_producer.py | 13 +- backend/app/services/llm_adapters.py | 31 +- backend/app/services/llm_call_log_service.py | 12 +- backend/app/services/llm_mock.py | 5 +- backend/app/services/llm_mock_prompt_plan.py | 149 ++- backend/app/services/llm_mock_spec.py | 10 +- .../app/services/llm_preference_service.py | 5 +- backend/app/services/llm_usage_log_service.py | 82 +- backend/app/services/mcp_call_log_service.py | 18 +- backend/app/services/mcp_image_service.py | 4 +- backend/app/services/mcp_oauth_service.py | 13 +- backend/app/services/mention_utils.py | 2 +- backend/app/services/module_service.py | 84 +- .../notification_adapters/__init__.py | 1 + .../services/notification_adapters/base.py | 5 +- .../services/notification_adapters/email.py | 10 +- .../services/notification_adapters/slack.py | 10 +- .../services/notification_adapters/teams.py | 10 +- backend/app/services/notification_service.py | 207 +-- backend/app/services/org_service.py | 35 +- 
.../app/services/phase_container_service.py | 61 +- .../app/services/phase_progress_service.py | 14 +- .../services/plan_recommendation_service.py | 24 +- backend/app/services/plan_service.py | 28 +- .../app/services/platform_settings_service.py | 62 +- backend/app/services/prefix_service.py | 15 +- backend/app/services/project_chat_service.py | 746 +++++------ .../services/project_repository_service.py | 29 +- backend/app/services/project_service.py | 63 +- backend/app/services/project_share_service.py | 115 +- .../app/services/sample_project_service.py | 85 +- backend/app/services/spec_service.py | 3 +- backend/app/services/team_role_service.py | 60 +- backend/app/services/thread_service.py | 305 ++--- .../app/services/typing_indicator_service.py | 24 +- backend/app/services/user_group_service.py | 17 +- .../services/user_question_session_service.py | 214 ++-- backend/app/services/user_service.py | 17 +- backend/app/utils/deep_link.py | 1 + backend/app/websocket/broadcast_consumer.py | 5 +- backend/app/websocket/manager.py | 8 +- backend/export_mock_discovery_data.py | 23 +- backend/mcp_server.py | 4 +- .../scripts/fix_missing_content_versions.py | 46 +- backend/scripts/mcp_stdio_proxy.py | 71 +- .../test_input_validator.py | 77 +- .../test_mcq_parser.py | 10 +- .../test_brainstorm_code_explorer_stage.py | 142 ++- .../agents/test_brainstorm_spec_types.py | 163 +-- .../agents/test_brainstorm_spec_writer.py | 8 +- .../test_assistant.py | 26 +- .../test_context_loader.py | 78 +- .../test_exploration_parser.py | 12 +- .../test_instrumentation.py | 13 +- .../test_orchestrator.py | 115 +- .../test_quality.py | 43 +- .../test_retry.py | 29 +- .../test_stress.py | 124 +- .../test_summarizer.py | 68 +- .../test_types.py | 11 +- .../test_web_search_parser.py | 4 +- .../agents/test_grounding_orchestrator.py | 52 +- backend/tests/agents/test_llm_client.py | 24 +- .../tests/agents/test_project_chat_gating.py | 9 +- backend/tests/agents/test_response_parser.py | 121 +- .../agents/test_response_parser_realdata.py | 33 +- backend/tests/agents/test_retry.py | 2 +- backend/tests/conftest.py | 18 +- .../tests/services/test_web_search_service.py | 26 +- backend/tests/test_activity_log_model.py | 32 +- backend/tests/test_activity_log_service.py | 15 +- backend/tests/test_activity_log_wiring.py | 14 +- backend/tests/test_agent_api.py | 26 +- backend/tests/test_analytics_cache.py | 4 +- backend/tests/test_analytics_service.py | 29 +- backend/tests/test_api_key_encryption.py | 40 +- backend/tests/test_auth_service.py | 64 +- backend/tests/test_auth_utils.py | 7 +- backend/tests/test_brainstorm_agent.py | 29 +- backend/tests/test_brainstorm_generation.py | 10 +- .../test_brainstorming_phase_endpoints.py | 2 +- .../tests/test_brainstorming_phase_models.py | 88 +- .../tests/test_brainstorming_phase_service.py | 44 +- ...est_brainstorming_phase_service_summary.py | 2 +- backend/tests/test_brainstorming_preflight.py | 51 +- backend/tests/test_code_explorer_client.py | 10 +- backend/tests/test_credit_formatter.py | 2 - backend/tests/test_daily_aggregation_job.py | 76 +- .../tests/test_daily_usage_summary_service.py | 13 +- backend/tests/test_daily_usage_trigger.py | 87 +- backend/tests/test_database.py | 9 +- backend/tests/test_domain_validation.py | 2 +- backend/tests/test_draft_endpoints.py | 4 +- backend/tests/test_draft_version_service.py | 14 +- backend/tests/test_e2e_workflow.py | 34 +- backend/tests/test_email_service.py | 23 +- backend/tests/test_email_verification.py | 58 +- 
backend/tests/test_feature_endpoints.py | 28 +- backend/tests/test_feature_import_service.py | 10 +- backend/tests/test_feature_service.py | 16 +- backend/tests/test_final_version_models.py | 14 +- backend/tests/test_finalization_service.py | 17 +- backend/tests/test_form_draft_router.py | 33 +- backend/tests/test_form_draft_service.py | 25 +- backend/tests/test_github_adapter.py | 9 +- .../tests/test_grounding_note_endpoints.py | 6 +- backend/tests/test_grounding_note_service.py | 25 +- backend/tests/test_grounding_service.py | 8 +- backend/tests/test_health.py | 3 +- backend/tests/test_identity_models.py | 22 +- backend/tests/test_image_service.py | 23 +- backend/tests/test_image_signing.py | 4 +- backend/tests/test_images_router.py | 94 +- backend/tests/test_inbox_badge_service.py | 11 +- backend/tests/test_inbox_broadcast_service.py | 45 +- .../tests/test_inbox_conversation_service.py | 33 +- backend/tests/test_inbox_deep_link.py | 18 +- backend/tests/test_inbox_follow_router.py | 9 +- backend/tests/test_inbox_follow_service.py | 39 +- backend/tests/test_inbox_mention_service.py | 7 +- backend/tests/test_inbox_status_service.py | 39 +- backend/tests/test_invitation_router.py | 13 +- backend/tests/test_invitation_service.py | 152 +-- .../tests/test_invite_acceptance_router.py | 19 +- backend/tests/test_jobs.py | 17 +- backend/tests/test_legacy_migration.py | 159 ++- backend/tests/test_llm_mock_prompt_plan.py | 6 +- backend/tests/test_llm_mock_spec.py | 9 +- backend/tests/test_markdown_parser.py | 3 +- backend/tests/test_mcp_call_log.py | 52 +- backend/tests/test_mcp_get_context.py | 43 +- backend/tests/test_mcp_get_section.py | 38 +- backend/tests/test_mcp_get_toc.py | 38 +- backend/tests/test_mcp_image_upload_api.py | 94 +- backend/tests/test_mcp_permissions.py | 64 +- backend/tests/test_mcp_server.py | 11 +- .../test_mention_notification_handler.py | 21 +- backend/tests/test_mention_utils.py | 26 +- .../tests/test_message_sequence_numbers.py | 10 +- backend/tests/test_module_endpoints.py | 7 +- backend/tests/test_module_feature_agent.py | 128 +- backend/tests/test_module_service.py | 16 +- backend/tests/test_notification_adapters.py | 32 +- backend/tests/test_notification_models.py | 100 +- backend/tests/test_notification_service.py | 107 +- backend/tests/test_oauth_providers.py | 3 +- backend/tests/test_oauth_routes.py | 75 +- backend/tests/test_org_chats.py | 8 +- backend/tests/test_org_endpoints.py | 31 +- backend/tests/test_org_service.py | 51 +- backend/tests/test_permissions.py | 279 ++-- backend/tests/test_phase_container_service.py | 33 +- backend/tests/test_phase_containers_router.py | 78 +- backend/tests/test_phase_progress_service.py | 144 +-- backend/tests/test_phase_validators.py | 23 +- .../tests/test_plan_recommendation_service.py | 14 +- backend/tests/test_plan_service.py | 26 +- backend/tests/test_platform_admin.py | 4 +- .../tests/test_platform_settings_router.py | 27 +- .../tests/test_platform_settings_service.py | 13 +- backend/tests/test_prefix_service.py | 11 +- backend/tests/test_project_chat_list.py | 4 +- backend/tests/test_project_chat_reactions.py | 12 +- backend/tests/test_project_endpoints.py | 111 +- .../tests/test_project_membership_service.py | 8 +- backend/tests/test_project_resolver.py | 13 +- backend/tests/test_project_service.py | 85 +- backend/tests/test_project_share_service.py | 335 ++--- backend/tests/test_project_shares_router.py | 88 +- backend/tests/test_short_id.py | 9 +- backend/tests/test_slack_models.py | 3 - 
backend/tests/test_spec_service.py | 22 +- backend/tests/test_team_roles.py | 124 +- backend/tests/test_thread_decision_summary.py | 6 +- backend/tests/test_thread_endpoints.py | 190 ++- backend/tests/test_thread_item_deletion.py | 219 ++-- backend/tests/test_thread_item_reactions.py | 12 +- backend/tests/test_thread_service.py | 85 +- backend/tests/test_thread_service_version.py | 14 +- .../tests/test_thread_version_anchoring.py | 18 +- .../tests/test_typing_indicator_service.py | 13 +- backend/tests/test_user_group_service.py | 50 +- backend/tests/test_user_groups_router.py | 3 +- backend/tests/test_user_service.py | 28 +- backend/tests/test_vfs_content.py | 34 +- backend/tests/test_vfs_path_resolver.py | 41 +- backend/tests/test_vfs_sed.py | 27 +- backend/tests/test_vfs_write_conversations.py | 175 ++- .../tests/test_worker_graceful_shutdown.py | 6 +- backend/tests/workers/test_helpers.py | 1 + backend/workers/consumer.py | 42 +- backend/workers/core/__init__.py | 4 +- backend/workers/core/helpers.py | 40 +- backend/workers/core/worker.py | 16 +- backend/workers/handlers/__init__.py | 46 +- backend/workers/handlers/brainstorming.py | 130 +- backend/workers/handlers/code_explorer.py | 141 +-- backend/workers/handlers/collaboration.py | 174 +-- backend/workers/handlers/generation.py | 155 +-- backend/workers/handlers/grounding.py | 221 +--- backend/workers/handlers/image_annotator.py | 14 +- backend/workers/handlers/integration.py | 89 +- backend/workers/handlers/project_chat.py | 78 +- backend/workers/handlers/web_search.py | 14 +- backend/workers/scheduler.py | 38 +- code-explorer/app/routes/explore.py | 25 +- code-explorer/app/services/claude_runner.py | 39 +- code-explorer/app/services/worktree.py | 12 +- .../admin/analytics/AdminAnalyticsClient.tsx | 42 +- frontend/app/agent-log/AgentLogClient.tsx | 28 +- frontend/app/dashboard/page.tsx | 19 +- .../email-templates/EmailTemplatesClient.tsx | 18 +- frontend/app/globals.css | 32 +- frontend/app/inbox/[...path]/page.tsx | 43 +- frontend/app/inbox/page.tsx | 2 +- .../invite/accept/InviteRedirectClient.tsx | 4 +- frontend/app/invites/[token]/InviteClient.tsx | 87 +- frontend/app/layout.tsx | 6 +- frontend/app/login/LoginClient.tsx | 36 +- .../app/mcp-explorer/McpExplorerClient.tsx | 17 +- frontend/app/mcp-log/McpLogClient.tsx | 24 +- .../members-groups/MembersGroupsClient.tsx | 8 +- .../oauth/authorize/OAuthAuthorizeClient.tsx | 77 +- frontend/app/page.tsx | 2 +- .../PlatformSettingsClient.tsx | 8 +- frontend/app/projects/ProjectsClient.tsx | 55 +- .../projects/[projectId]/activity/page.tsx | 6 +- .../[phaseId]/BrainstormingLayoutClient.tsx | 43 +- .../brainstorming/[phaseId]/activity/page.tsx | 12 +- .../[phaseId]/conversations/page.tsx | 340 ++--- .../[phaseId]/description/page.tsx | 80 +- .../brainstorming/[phaseId]/features/page.tsx | 165 ++- .../brainstorming/[phaseId]/layout.tsx | 9 +- .../[phaseId]/prompt-plan/page.tsx | 6 +- .../[projectId]/brainstorming/page.tsx | 114 +- .../[featureId]/FeatureDetailClient.tsx | 308 +++-- .../[projectId]/features/[featureId]/page.tsx | 9 +- .../projects/[projectId]/features/page.tsx | 133 +- .../app/projects/[projectId]/jobs/page.tsx | 42 +- frontend/app/projects/[projectId]/layout.tsx | 13 +- .../project-chat/[discussionId]/page.tsx | 404 +++--- .../[projectId]/project-chat/page.tsx | 161 +-- .../[projectId]/project-settings/page.tsx | 21 +- frontend/app/register/RegisterClient.tsx | 20 +- frontend/app/settings/SettingsClient.tsx | 10 +- frontend/app/settings/page.tsx | 8 +- 
frontend/app/trial-expired/page.tsx | 61 +- .../app/verify-email/VerifyEmailClient.tsx | 93 +- frontend/components/AICommentCard.tsx | 6 +- frontend/components/AIErrorPanel.tsx | 26 +- frontend/components/AIQuickActionsButton.tsx | 11 +- frontend/components/AIThinkingCard.tsx | 33 +- frontend/components/ActivityLogCard.tsx | 25 +- frontend/components/ActivityLogView.tsx | 13 +- .../AddBugTrackerConnectorModal.tsx | 160 +-- frontend/components/AddLLMConnectorModal.tsx | 17 +- frontend/components/AddReactionButton.tsx | 11 +- frontend/components/AddShareModal.tsx | 66 +- frontend/components/ApiKeysSection.tsx | 91 +- frontend/components/AppTopNav.tsx | 65 +- frontend/components/BugTrackerConnectors.tsx | 43 +- frontend/components/CodeExplorationCard.tsx | 53 +- .../components/CodeExplorationThreadItem.tsx | 44 +- frontend/components/CommentCard.tsx | 23 +- frontend/components/CommentItemCard.tsx | 61 +- frontend/components/CopyableId.tsx | 12 +- frontend/components/CreateChoiceDialog.tsx | 43 +- frontend/components/CreateProjectModal.tsx | 72 +- frontend/components/CreatedFeaturesPanel.tsx | 23 +- frontend/components/DangerZoneSection.tsx | 20 +- .../components/DecisionSummarizerStatus.tsx | 15 +- .../components/DescriptionImageGallery.tsx | 28 +- .../components/DiscussionImageGallery.tsx | 28 +- .../EditBugTrackerConnectorModal.tsx | 127 +- frontend/components/EditLLMConnectorModal.tsx | 18 +- frontend/components/EmptyState.tsx | 14 +- .../components/ExtensionCreationModal.tsx | 80 +- frontend/components/FlowStep.tsx | 10 +- frontend/components/ImageAnnotationModal.tsx | 47 +- .../components/ImageAttachmentGallery.tsx | 28 +- frontend/components/ImageViewerModal.tsx | 32 +- .../ImplementationCreatedMarker.tsx | 20 +- frontend/components/IntegrationShareModal.tsx | 61 +- frontend/components/JobStatusBadge.tsx | 15 +- frontend/components/JobsTable.tsx | 20 +- frontend/components/LLMConnectors.tsx | 30 +- frontend/components/LLMSettingsForm.tsx | 37 +- frontend/components/MCPConnectionInfo.tsx | 135 +- frontend/components/MCQAnswerCard.tsx | 36 +- frontend/components/MCQFollowupCard.tsx | 114 +- frontend/components/MarkdownArticle.tsx | 67 +- frontend/components/MarkdownWithMentions.tsx | 106 +- frontend/components/NewChatConfirmDialog.tsx | 14 +- frontend/components/NoFollowupMessageCard.tsx | 6 +- frontend/components/OrgSwitcher.tsx | 18 +- .../components/PendingMFBTAIMCQsMessage.tsx | 10 +- frontend/components/PhaseCreatedBanner.tsx | 22 +- .../components/ProjectChatCreateBanner.tsx | 12 +- .../ProjectChatCreateFeatureBanner.tsx | 20 +- frontend/components/ProjectChatHeader.tsx | 129 +- frontend/components/ProjectChatSidebar.tsx | 79 +- frontend/components/ProjectLayoutClient.tsx | 4 +- .../components/ProjectMetadataSection.tsx | 68 +- .../components/ProjectRepositoriesSection.tsx | 102 +- frontend/components/ProjectSharesSection.tsx | 63 +- frontend/components/ProtectedRoute.tsx | 2 +- frontend/components/QuickActions.tsx | 8 +- frontend/components/ReactionBar.tsx | 91 +- frontend/components/SlackBotSettings.tsx | 162 +-- .../components/TeamAssignmentsSection.tsx | 54 +- frontend/components/TeamRoleSettings.tsx | 64 +- frontend/components/TestingDebugSettings.tsx | 78 +- frontend/components/ThreadItemList.tsx | 40 +- frontend/components/ThreadPanel.tsx | 256 ++-- frontend/components/TrialBadge.tsx | 4 +- frontend/components/UnresolvedPointsBadge.tsx | 16 +- .../components/UnresolvedPointsSection.tsx | 31 +- frontend/components/UserMenu.tsx | 21 +- frontend/components/WebSearchCard.tsx | 
126 +-
 frontend/components/WebSearchThreadItem.tsx | 70 +-
 .../admin/analytics/AnalyticsSkeleton.tsx | 10 +-
 .../admin/analytics/EfficiencyBadge.tsx | 20 +-
 .../admin/analytics/RecommendationCards.tsx | 25 +-
 .../admin/analytics/TopProjectsTable.tsx | 29 +-
 .../admin/analytics/TopUsersTable.tsx | 32 +-
 .../components/agent-log/AgentLogFilters.tsx | 80 +-
 .../components/agent-log/AgentLogJobCard.tsx | 37 +-
 .../agent-log/LLMCallDetailModal.tsx | 228 ++--
 frontend/components/agent-log/LLMCallRow.tsx | 68 +-
 .../components/agent-log/renderers/index.tsx | 485 +++----
 .../brainstorming/AISuggestionsTab.tsx | 154 ++-
 .../brainstorming/CombinedStatusPanel.tsx | 54 +-
 .../brainstorming/CompactGenerationStatus.tsx | 6 +-
 .../brainstorming/ConversationInfoPanel.tsx | 46 +-
 .../brainstorming/CustomQuestionsModal.tsx | 17 +-
 .../brainstorming/DocumentBlockRenderer.tsx | 12 +-
 .../DocumentWithCommentsLayout.tsx | 15 +-
 .../brainstorming/DraftFinalTabs.tsx | 12 +-
 .../brainstorming/DraftVersionHeader.tsx | 24 +-
 .../brainstorming/DraftsTabContent.tsx | 144 ++-
 .../brainstorming/FeatureThreadPanel.tsx | 29 +-
 .../brainstorming/FinalTabContent.tsx | 65 +-
 .../brainstorming/FinalizeConfirmDialog.tsx | 18 +-
 .../brainstorming/GenerateConfirmDialog.tsx | 38 +-
 .../GenerateFeatureConfirmDialog.tsx | 40 +-
 .../brainstorming/GenerateOwnTab.tsx | 134 +-
 .../brainstorming/GenerationProgressCard.tsx | 10 +-
 .../PendingQuestionsReviewModal.tsx | 19 +-
 .../brainstorming/PhaseStageProgress.tsx | 76 +-
 .../brainstorming/RenamePhaseDialog.tsx | 13 +-
 .../brainstorming/SpecDraftThreadPanel.tsx | 53 +-
 .../brainstorming/VersionDropdown.tsx | 12 +-
 frontend/components/chat/ChatFilters.tsx | 17 +-
 .../components/chat/CreateProjectBanner.tsx | 10 +-
 frontend/components/chat/MCQOptions.tsx | 4 +-
 .../components/chat/OnboardingSuggestions.tsx | 20 +-
 .../chat/ProjectChatMCQAnswerCard.tsx | 22 +-
 .../components/chat/ProjectCreatedBanner.tsx | 25 +-
 .../components/chat/QuickActionButtons.tsx | 11 +-
 frontend/components/chat/TypingIndicator.tsx | 13 +-
 .../dashboard/DashboardStatsCard.tsx | 18 +-
 .../components/dashboard/LLMUsageCard.tsx | 18 +-
 .../components/dashboard/PlanStatusCard.tsx | 28 +-
 .../dashboard/RecentLLMCallsTable.tsx | 34 +-
 frontend/components/drafts/DraftPicker.tsx | 25 +-
 frontend/components/editor/ActionToolbar.tsx | 6 +-
 frontend/components/editor/ChatInput.tsx | 177 +--
 .../components/editor/DescriptionEditor.tsx | 2 +-
 .../components/editor/EntityMentionList.tsx | 37 +-
 .../editor/FeatureDisambiguationModal.tsx | 16 +-
 .../components/editor/FormattingToolbar.tsx | 20 +-
 .../components/editor/ImagePreviewGrid.tsx | 2 +-
 .../components/editor/ImagePreviewItem.tsx | 24 +-
 frontend/components/editor/MentionList.tsx | 10 +-
 frontend/components/editor/PlusMenu.tsx | 6 +-
 frontend/components/editor/RichTextEditor.tsx | 218 ++--
 .../email-templates/EmailTemplateEditor.tsx | 62 +-
 .../email-templates/EmailTemplateSidebar.tsx | 30 +-
 .../email-templates/InsertVariableModal.tsx | 10 +-
 .../features/BulkClearStatusNotesDialog.tsx | 5 +-
 .../features/ClearStatusNotesDialog.tsx | 19 +-
 .../components/features/CopyableModuleKey.tsx | 9 +-
 .../features/CreateFeatureDialog.tsx | 21 +-
 .../features/CreateImplementationButton.tsx | 18 +-
 .../features/CreateModuleDialog.tsx | 13 +-
 .../components/features/EditModuleDialog.tsx | 20 +-
 .../features/FeatureContentEditModal.tsx | 24 +-
 .../features/FeatureContentViewer.tsx | 190 +--
 .../components/features/FeatureDiffModal.tsx | 64 +-
 .../features/FeatureFilterToolbar.tsx | 64 +-
 frontend/components/features/FeatureRow.tsx | 106 +-
 .../components/features/FeatureSidebar.tsx | 96 +-
 .../features/FeatureStageProgress.tsx | 23 +-
 frontend/components/features/FeatureTable.tsx | 83 +-
 .../features/FeatureThreadColumn.tsx | 315 +++--
 .../features/ImplementationSelector.tsx | 56 +-
 .../components/features/ImportIssueModal.tsx | 176 ++-
 frontend/components/features/ModuleCard.tsx | 115 +-
 .../features/RestoreFeatureDialog.tsx | 26 +-
 .../grounding/BranchMetadataPanel.tsx | 57 +-
 .../grounding/GroundingBranchTabs.tsx | 8 +-
 .../grounding/GroundingCreateModal.tsx | 12 +-
 .../grounding/GroundingDiffModal.tsx | 37 +-
 .../grounding/GroundingEditModal.tsx | 6 +-
 .../components/grounding/GroundingLayout.tsx | 106 +-
 .../grounding/GroundingNotesModal.tsx | 51 +-
 .../grounding/GroundingUpdateEvents.tsx | 39 +-
 .../grounding/MergeConfirmDialog.tsx | 51 +-
 .../components/inbox/ConversationCard.tsx | 43 +-
 .../inbox/ConversationCardSkeleton.tsx | 6 +-
 frontend/components/inbox/InboxBadge.tsx | 8 +-
 .../inbox/InboxConnectionStatus.tsx | 35 +-
 frontend/components/inbox/InboxContainer.tsx | 6 +-
 frontend/components/inbox/InboxEmptyState.tsx | 6 +-
 frontend/components/inbox/InboxHeader.tsx | 16 +-
 frontend/components/inbox/InboxNavBadge.tsx | 2 +-
 frontend/components/inbox/InboxSidebar.tsx | 86 +-
 .../components/inbox/InboxSidebarHeader.tsx | 4 +-
 .../components/inbox/InboxSidebarItem.tsx | 15 +-
 .../inbox/InboxSidebarProjectGroup.tsx | 23 +-
 .../components/inbox/InboxSidebarTrigger.tsx | 9 +-
 frontend/components/inbox/TrackedMessage.tsx | 7 +-
 .../mcp-explorer/MCPExplorerLayout.tsx | 33 +-
 .../mcp-explorer/VFSContentViewer.tsx | 46 +-
 .../mcp-explorer/VFSContextMenu.tsx | 33 +-
 .../mcp-explorer/VFSMetadataPanel.tsx | 37 +-
 .../mcp-explorer/VFSSearchResults.tsx | 42 +-
 .../components/mcp-explorer/VFSToolbar.tsx | 30 +-
 frontend/components/mcp-explorer/VFSTree.tsx | 25 +-
 frontend/components/mcp-log/MCPLogCard.tsx | 94 +-
 .../components/mcp-log/MCPLogDetailModal.tsx | 52 +-
 .../components/mcp-log/renderers/index.tsx | 276 ++--
 .../members-groups/CreateGroupModal.tsx | 19 +-
 .../components/members-groups/GroupsTab.tsx | 96 +-
 .../members-groups/InviteMembersModal.tsx | 33 +-
 .../components/members-groups/MembersTab.tsx | 84 +-
 .../members-groups/OrgSettingsTab.tsx | 31 +-
 .../platform/AddPlatformLLMConnectorModal.tsx | 17 +-
 .../EditPlatformLLMConnectorModal.tsx | 13 +-
 .../platform/PlatformCodeExplorerSettings.tsx | 33 +-
 .../platform/PlatformConnectorsTab.tsx | 57 +-
 .../platform/PlatformEmailSettings.tsx | 65 +-
 .../platform/PlatformFreemiumSettings.tsx | 54 +-
 .../platform/PlatformGitHubOAuthSettings.tsx | 42 +-
 .../platform/PlatformLLMConnectors.tsx | 24 +-
 .../platform/PlatformLLMSettings.tsx | 24 +-
 .../PlatformObjectStorageSettings.tsx | 28 +-
 .../platform/PlatformOtherSettings.tsx | 17 +-
 .../components/platform/PlatformUserPlans.tsx | 97 +-
 .../platform/PlatformWebSearchSettings.tsx | 45 +-
 .../components/platform/S3ConnectorModal.tsx | 54 +-
 .../platform/SendgridConnectorModal.tsx | 22 +-
 .../BrownfieldConnectorStep.tsx | 19 +-
 .../project-wizard/BrownfieldRepoStep.tsx | 116 +-
 .../project-wizard/GreenfieldDetailsStep.tsx | 29 +-
 .../InlineGitHubConnectorForm.tsx | 68 +-
 .../project-wizard/PathSelectionStep.tsx | 64 +-
 .../project-wizard/ProjectTypeSelector.tsx | 14 +-
 .../project-wizard/ProjectWizard.tsx | 124 +-
 .../project-wizard/WizardStepIndicator.tsx | 21 +-
 frontend/components/ui/alert-dialog.tsx | 96 +-
 frontend/components/ui/alert.tsx | 54 +-
 frontend/components/ui/avatar.tsx | 31 +-
 frontend/components/ui/badge.tsx | 25 +-
 frontend/components/ui/button.tsx | 44 +-
 frontend/components/ui/card.tsx | 131 +-
 frontend/components/ui/checkbox.tsx | 24 +-
 frontend/components/ui/collapsible.tsx | 12 +-
 frontend/components/ui/command.tsx | 92 +-
 frontend/components/ui/dialog.tsx | 91 +-
 frontend/components/ui/dropdown-menu.tsx | 111 +-
 frontend/components/ui/input.tsx | 18 +-
 frontend/components/ui/label.tsx | 29 +-
 .../components/ui/pagination-controls.tsx | 22 +-
 frontend/components/ui/popover.tsx | 24 +-
 frontend/components/ui/progress.tsx | 21 +-
 frontend/components/ui/radio-group.tsx | 36 +-
 frontend/components/ui/scroll-area.tsx | 30 +-
 frontend/components/ui/select.tsx | 71 +-
 frontend/components/ui/separator.tsx | 43 +-
 frontend/components/ui/sheet.tsx | 92 +-
 frontend/components/ui/skeleton.tsx | 16 +-
 frontend/components/ui/switch.tsx | 20 +-
 frontend/components/ui/table.tsx | 110 +-
 frontend/components/ui/tabs.tsx | 34 +-
 frontend/components/ui/textarea.tsx | 37 +-
 frontend/components/ui/toast.tsx | 63 +-
 frontend/components/ui/toaster.tsx | 16 +-
 frontend/components/ui/tooltip.tsx | 24 +-
 .../unified-chat/UnifiedChatContainer.tsx | 78 +-
 .../unified-chat/UnifiedChatHeader.tsx | 36 +-
 .../unified-chat/UnifiedMessageList.tsx | 32 +-
 .../unified-chat/UnifiedProposalBanner.tsx | 44 +-
 frontend/components/unified-chat/types.ts | 4 +-
 frontend/e2e/HAPPY_PATHS.md | 61 +
 frontend/e2e/HAPPY_PATHS_SELECTORS.md | 694 +++++-----
 frontend/hooks/use-toast.ts | 133 +-
 frontend/lib/activity-helpers.ts | 4 +-
 frontend/lib/api/client.ts | 769 +++---------
 frontend/lib/api/mcp-client.ts | 15 +-
 frontend/lib/api/types.ts | 65 +-
 frontend/lib/auth/AuthContext.tsx | 3 +-
 frontend/lib/auth/PlatformAdminContext.tsx | 4 +-
 frontend/lib/chatFilters.ts | 16 +-
 frontend/lib/contexts/InboxSidebarContext.tsx | 4 +-
 .../lib/editor/entity-mention-suggestion.ts | 8 +-
 frontend/lib/editor/mention-suggestion.ts | 6 +-
 frontend/lib/hooks/useBrowserNotifications.ts | 76 +-
 frontend/lib/hooks/useChatFilters.ts | 9 +-
 frontend/lib/hooks/useDashboardWebSocket.ts | 5 +-
 frontend/lib/hooks/useFeaturePreferences.ts | 37 +-
 frontend/lib/hooks/useFeatureWebSocket.ts | 10 +-
 frontend/lib/hooks/useFormDraft.ts | 48 +-
 frontend/lib/hooks/useGroundingWebSocket.ts | 19 +-
 frontend/lib/hooks/useHighlightedItem.ts | 6 +-
 .../lib/hooks/useImplementationWebSocket.ts | 5 +-
 frontend/lib/hooks/useInboxDeepLink.ts | 8 +-
 frontend/lib/hooks/useInboxEventHandler.ts | 27 +-
 frontend/lib/hooks/useInboxWebSocket.ts | 9 +-
 frontend/lib/hooks/useJobWebSocket.ts | 10 +-
 frontend/lib/hooks/useMessageHighlight.ts | 8 +-
 frontend/lib/hooks/useProjectChatWebSocket.ts | 9 +-
 frontend/lib/hooks/useProjectPermissions.ts | 4 +-
 frontend/lib/hooks/useRelativeTime.ts | 10 +-
 frontend/lib/hooks/useSessionDisclosure.ts | 14 +-
 .../useThreadDecisionSummaryWebSocket.ts | 7 +-
 frontend/lib/hooks/useThreadItemsWebSocket.ts | 7 +-
 frontend/lib/hooks/useTypingIndicator.ts | 39 +-
 frontend/lib/hooks/useUnifiedChat.ts | 193 ++-
 frontend/lib/hooks/useUnifiedChatWebSocket.ts | 13 +-
 frontend/lib/hooks/useViewportReadTracking.ts | 10 +-
 frontend/lib/url.ts | 43 +-
 frontend/lib/utils.ts | 6 +-
 frontend/lib/websocket/WebSocketContext.tsx | 27 +-
 frontend/lib/websocket/types.ts | 5 +-
 .../lib/websocket/useWebSocketSubscription.ts | 5 +-
 frontend/tsconfig.json | 14 +-
 prompting/phase_1/idea.md | 10 +-
 prompting/phase_1/prompt_plan.md | 1 -
 prompting/phase_4/spec.md | 3 +-
 prompting/phase_7/idea.md | 12 +-
prompting/phase_7/spec.md | 2 +- prompting/phase_8/spec.md | 2 +- 1007 files changed, 19524 insertions(+), 23805 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ee3563e..e0c4428 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -16,11 +16,11 @@ ## Database Migrations -- +- ## Env Config -- +- ## Relevant Docs diff --git a/README.md b/README.md index 264b328..61bab0a 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Need to give your existing projects the vibe coding boost? All you need to do is ## 🔍 Grounded and relevant -mfbt includes Code Exploration and Web Search agents to ensure that the implementation plans that are generated are not just based on your conversations, but are also based on real code exploration and web searches. +mfbt includes Code Exploration and Web Search agents to ensure that the implementation plans that are generated are not just based on your conversations, but are also based on real code exploration and web searches. diff --git a/backend/alembic/README b/backend/alembic/README index 98e4f9c..2500aa1 100644 --- a/backend/alembic/README +++ b/backend/alembic/README @@ -1 +1 @@ -Generic single-database configuration. \ No newline at end of file +Generic single-database configuration. diff --git a/backend/alembic/versions/00ebcf349edc_add_feature_completion_tracking.py b/backend/alembic/versions/00ebcf349edc_add_feature_completion_tracking.py index a87e32a..b038052 100644 --- a/backend/alembic/versions/00ebcf349edc_add_feature_completion_tracking.py +++ b/backend/alembic/versions/00ebcf349edc_add_feature_completion_tracking.py @@ -5,16 +5,17 @@ Create Date: 2025-12-05 14:21:20.764109 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = '00ebcf349edc' -down_revision: Union[str, Sequence[str], None] = '6baa75dcb961' +revision: str = "00ebcf349edc" +down_revision: Union[str, Sequence[str], None] = "6baa75dcb961" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -23,48 +24,21 @@ def upgrade() -> None: """Add feature completion tracking fields.""" # Create the feature_completion_status enum for PostgreSQL bind = op.get_bind() - if bind.dialect.name == 'postgresql': + if bind.dialect.name == "postgresql": feature_completion_status = postgresql.ENUM( - 'pending', 'in_progress', 'completed', - name='feature_completion_status', - create_type=False + "pending", "in_progress", "completed", name="feature_completion_status", create_type=False ) feature_completion_status.create(bind, checkfirst=True) # Add completion tracking columns to features table - op.add_column( - 'features', - sa.Column( - 'completion_status', - sa.String(20), - server_default='pending', - nullable=False - ) - ) - op.add_column( - 'features', - sa.Column('completion_summary', sa.Text(), nullable=True) - ) - op.add_column( - 'features', - sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True) - ) - op.add_column( - 'features', - sa.Column( - 'completed_by_id', - sa.UUID(), - nullable=True - ) - ) + op.add_column("features", sa.Column("completion_status", sa.String(20), server_default="pending", nullable=False)) + op.add_column("features", sa.Column("completion_summary", sa.Text(), nullable=True)) + op.add_column("features", sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True)) + op.add_column("features", sa.Column("completed_by_id", sa.UUID(), nullable=True)) # Add foreign key constraint (PostgreSQL only due to SQLite limitations) - if bind.dialect.name == 'postgresql': - op.create_foreign_key( - 'fk_features_completed_by_id', - 'features', 'users', - ['completed_by_id'], ['id'] - ) + if bind.dialect.name == "postgresql": + op.create_foreign_key("fk_features_completed_by_id", "features", "users", ["completed_by_id"], ["id"]) def downgrade() -> None: @@ -72,15 +46,15 @@ def downgrade() -> None: bind = op.get_bind() # Drop foreign key constraint (PostgreSQL only) - if bind.dialect.name == 'postgresql': - op.drop_constraint('fk_features_completed_by_id', 'features', type_='foreignkey') + if bind.dialect.name == "postgresql": + op.drop_constraint("fk_features_completed_by_id", "features", type_="foreignkey") # Drop columns - op.drop_column('features', 'completed_by_id') - op.drop_column('features', 'completed_at') - op.drop_column('features', 'completion_summary') - op.drop_column('features', 'completion_status') + op.drop_column("features", "completed_by_id") + op.drop_column("features", "completed_at") + op.drop_column("features", "completion_summary") + op.drop_column("features", "completion_status") # Drop the enum type (PostgreSQL only) - if bind.dialect.name == 'postgresql': + if bind.dialect.name == "postgresql": op.execute("DROP TYPE IF EXISTS feature_completion_status") diff --git a/backend/alembic/versions/021b37581165_add_completion_summary_and_triggered_by_.py b/backend/alembic/versions/021b37581165_add_completion_summary_and_triggered_by_.py index 1aee3e9..8e42599 100644 --- a/backend/alembic/versions/021b37581165_add_completion_summary_and_triggered_by_.py +++ b/backend/alembic/versions/021b37581165_add_completion_summary_and_triggered_by_.py @@ -5,26 +5,27 @@ Create Date: 2025-11-20 13:03:09.128947 """ + from typing import Sequence, Union -from 
alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '021b37581165' -down_revision: Union[str, Sequence[str], None] = 'ed7322775e46' +revision: str = "021b37581165" +down_revision: Union[str, Sequence[str], None] = "ed7322775e46" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - op.add_column('implementation_phases', sa.Column('completion_summary', sa.Text(), nullable=True)) - op.add_column('implementation_phases', sa.Column('triggered_by', sa.String(length=50), nullable=True)) + op.add_column("implementation_phases", sa.Column("completion_summary", sa.Text(), nullable=True)) + op.add_column("implementation_phases", sa.Column("triggered_by", sa.String(length=50), nullable=True)) def downgrade() -> None: """Downgrade schema.""" - op.drop_column('implementation_phases', 'triggered_by') - op.drop_column('implementation_phases', 'completion_summary') + op.drop_column("implementation_phases", "triggered_by") + op.drop_column("implementation_phases", "completion_summary") diff --git a/backend/alembic/versions/0c4f46a254f8_drop_conversation_generation_triggers_.py b/backend/alembic/versions/0c4f46a254f8_drop_conversation_generation_triggers_.py index a540909..2260abc 100644 --- a/backend/alembic/versions/0c4f46a254f8_drop_conversation_generation_triggers_.py +++ b/backend/alembic/versions/0c4f46a254f8_drop_conversation_generation_triggers_.py @@ -9,16 +9,17 @@ instead of using a batched scheduler with trigger records. """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. -revision: str = '0c4f46a254f8' -down_revision: Union[str, Sequence[str], None] = 'u0v1w2x3y4z5' +revision: str = "0c4f46a254f8" +down_revision: Union[str, Sequence[str], None] = "u0v1w2x3y4z5" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/0c5625ec8ba1_add_project_memberships_table.py b/backend/alembic/versions/0c5625ec8ba1_add_project_memberships_table.py index 82dc9cc..2cbbae4 100644 --- a/backend/alembic/versions/0c5625ec8ba1_add_project_memberships_table.py +++ b/backend/alembic/versions/0c5625ec8ba1_add_project_memberships_table.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 10:37:38.546482 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
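
The `bind.dialect.name` guard reformatted in 00ebcf349edc above is what keeps these migrations runnable on both PostgreSQL and SQLite: the ENUM type and the foreign key are created only on PostgreSQL, while SQLite gets a plain string column. A minimal sketch of that pattern, with illustrative table and type names (not from this repo):

    import sqlalchemy as sa
    from sqlalchemy.dialects import postgresql

    from alembic import op


    def upgrade() -> None:
        bind = op.get_bind()
        if bind.dialect.name == "postgresql":
            # create_type=False suppresses implicit creation; the explicit
            # create(checkfirst=True) below is idempotent on reruns.
            status = postgresql.ENUM("pending", "done", name="example_status", create_type=False)
            status.create(bind, checkfirst=True)
        # SQLite falls through to a plain VARCHAR column: no enum, no FK.
        op.add_column("example", sa.Column("status", sa.String(20), server_default="pending", nullable=False))
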
-revision: str = '0c5625ec8ba1' -down_revision: Union[str, Sequence[str], None] = 'a60117da6409' +revision: str = "0c5625ec8ba1" +down_revision: Union[str, Sequence[str], None] = "a60117da6409" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,28 +23,32 @@ def upgrade() -> None: """Upgrade schema.""" # Create project_memberships table op.create_table( - 'project_memberships', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('role', sa.String(length=20), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], name=op.f('fk_project_memberships_project_id_projects'), ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_project_memberships_user_id_users'), ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id', name=op.f('pk_project_memberships')), - sa.UniqueConstraint('project_id', 'user_id', name=op.f('uq_project_memberships_project_id_user_id')) + "project_memberships", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("role", sa.String(length=20), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint( + ["project_id"], ["projects.id"], name=op.f("fk_project_memberships_project_id_projects"), ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["user_id"], ["users.id"], name=op.f("fk_project_memberships_user_id_users"), ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_project_memberships")), + sa.UniqueConstraint("project_id", "user_id", name=op.f("uq_project_memberships_project_id_user_id")), ) # Create indexes for foreign keys - op.create_index(op.f('ix_project_memberships_project_id'), 'project_memberships', ['project_id'], unique=False) - op.create_index(op.f('ix_project_memberships_user_id'), 'project_memberships', ['user_id'], unique=False) + op.create_index(op.f("ix_project_memberships_project_id"), "project_memberships", ["project_id"], unique=False) + op.create_index(op.f("ix_project_memberships_user_id"), "project_memberships", ["user_id"], unique=False) def downgrade() -> None: """Downgrade schema.""" # Drop indexes - op.drop_index(op.f('ix_project_memberships_user_id'), table_name='project_memberships') - op.drop_index(op.f('ix_project_memberships_project_id'), table_name='project_memberships') + op.drop_index(op.f("ix_project_memberships_user_id"), table_name="project_memberships") + op.drop_index(op.f("ix_project_memberships_project_id"), table_name="project_memberships") # Drop table - op.drop_table('project_memberships') + op.drop_table("project_memberships") diff --git a/backend/alembic/versions/0d2b4de48ed2_add_org_plan_fields.py b/backend/alembic/versions/0d2b4de48ed2_add_org_plan_fields.py index fe7680e..06b2cc2 100644 --- a/backend/alembic/versions/0d2b4de48ed2_add_org_plan_fields.py +++ b/backend/alembic/versions/0d2b4de48ed2_add_org_plan_fields.py @@ -8,67 +8,41 @@ Create Date: 2025-12-16 19:19:14.846155 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
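
The `op.f()` wrapper that recurs in these create_table hunks marks a constraint name as already final, so Alembic's naming convention is not applied to it a second time. The convention itself is not part of this patch; judging from the generated names, it plausibly resembles the following sketch (illustrative, not the project's actual metadata definition):

    from sqlalchemy import MetaData

    metadata = MetaData(
        naming_convention={
            "pk": "pk_%(table_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
        }
    )
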
-revision: str = '0d2b4de48ed2' -down_revision: Union[str, Sequence[str], None] = '4691251c9f11' +revision: str = "0d2b4de48ed2" +down_revision: Union[str, Sequence[str], None] = "4691251c9f11" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Add plan/subscription fields to organizations table.""" - op.add_column( - 'organizations', - sa.Column('plan_name', sa.String(100), nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_llm_tokens_per_month', sa.Integer, nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_llm_tokens_total', sa.Integer, nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_llm_tokens_used', sa.Integer, nullable=False, server_default='0') - ) - op.add_column( - 'organizations', - sa.Column('plan_max_projects', sa.Integer, nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_max_users', sa.Integer, nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_start_date', sa.DateTime(timezone=True), nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_end_date', sa.DateTime(timezone=True), nullable=True) - ) - op.add_column( - 'organizations', - sa.Column('plan_billing_cycle_start', sa.Integer, nullable=True, server_default='1') - ) + op.add_column("organizations", sa.Column("plan_name", sa.String(100), nullable=True)) + op.add_column("organizations", sa.Column("plan_llm_tokens_per_month", sa.Integer, nullable=True)) + op.add_column("organizations", sa.Column("plan_llm_tokens_total", sa.Integer, nullable=True)) + op.add_column("organizations", sa.Column("plan_llm_tokens_used", sa.Integer, nullable=False, server_default="0")) + op.add_column("organizations", sa.Column("plan_max_projects", sa.Integer, nullable=True)) + op.add_column("organizations", sa.Column("plan_max_users", sa.Integer, nullable=True)) + op.add_column("organizations", sa.Column("plan_start_date", sa.DateTime(timezone=True), nullable=True)) + op.add_column("organizations", sa.Column("plan_end_date", sa.DateTime(timezone=True), nullable=True)) + op.add_column("organizations", sa.Column("plan_billing_cycle_start", sa.Integer, nullable=True, server_default="1")) def downgrade() -> None: """Remove plan/subscription fields from organizations table.""" - op.drop_column('organizations', 'plan_billing_cycle_start') - op.drop_column('organizations', 'plan_end_date') - op.drop_column('organizations', 'plan_start_date') - op.drop_column('organizations', 'plan_max_users') - op.drop_column('organizations', 'plan_max_projects') - op.drop_column('organizations', 'plan_llm_tokens_used') - op.drop_column('organizations', 'plan_llm_tokens_total') - op.drop_column('organizations', 'plan_llm_tokens_per_month') - op.drop_column('organizations', 'plan_name') + op.drop_column("organizations", "plan_billing_cycle_start") + op.drop_column("organizations", "plan_end_date") + op.drop_column("organizations", "plan_start_date") + op.drop_column("organizations", "plan_max_users") + op.drop_column("organizations", "plan_max_projects") + op.drop_column("organizations", "plan_llm_tokens_used") + op.drop_column("organizations", "plan_llm_tokens_total") + op.drop_column("organizations", "plan_llm_tokens_per_month") + op.drop_column("organizations", "plan_name") diff --git a/backend/alembic/versions/1227d87646fe_fix_spec_version_unique_constraint_per_.py b/backend/alembic/versions/1227d87646fe_fix_spec_version_unique_constraint_per_.py index 
42812dd..c07e6dd 100644 --- a/backend/alembic/versions/1227d87646fe_fix_spec_version_unique_constraint_per_.py +++ b/backend/alembic/versions/1227d87646fe_fix_spec_version_unique_constraint_per_.py @@ -5,15 +5,14 @@ Create Date: 2025-12-19 21:56:40.605043 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = '1227d87646fe' -down_revision: Union[str, Sequence[str], None] = '8bba664c9d7b' +revision: str = "1227d87646fe" +down_revision: Union[str, Sequence[str], None] = "8bba664c9d7b" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -26,24 +25,18 @@ def upgrade() -> None: its own version sequence (v1, v2, v3...). """ # Drop old constraint (per project) - op.drop_constraint('uq_project_spec_version', 'spec_versions', type_='unique') + op.drop_constraint("uq_project_spec_version", "spec_versions", type_="unique") # Create new constraint (per phase) op.create_unique_constraint( - 'uq_phase_spec_version', - 'spec_versions', - ['brainstorming_phase_id', 'spec_type', 'version'] + "uq_phase_spec_version", "spec_versions", ["brainstorming_phase_id", "spec_type", "version"] ) def downgrade() -> None: """Downgrade schema.""" # Drop new constraint - op.drop_constraint('uq_phase_spec_version', 'spec_versions', type_='unique') + op.drop_constraint("uq_phase_spec_version", "spec_versions", type_="unique") # Restore old constraint - op.create_unique_constraint( - 'uq_project_spec_version', - 'spec_versions', - ['project_id', 'spec_type', 'version'] - ) + op.create_unique_constraint("uq_project_spec_version", "spec_versions", ["project_id", "spec_type", "version"]) diff --git a/backend/alembic/versions/178e7cb83afe_add_implementation_phases_table.py b/backend/alembic/versions/178e7cb83afe_add_implementation_phases_table.py index f6f2d21..92a4957 100644 --- a/backend/alembic/versions/178e7cb83afe_add_implementation_phases_table.py +++ b/backend/alembic/versions/178e7cb83afe_add_implementation_phases_table.py @@ -5,15 +5,14 @@ Create Date: 2025-11-20 11:47:32.962280 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. 
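
The constraint swap in 1227d87646fe above is what moves version numbering from per-project to per-phase: with `uq_phase_spec_version` in place, the next version is allocated within a (phase, spec_type) pair. A rough sketch of the kind of query that relies on it; the lightweight table stub and function name are illustrative, not taken from the codebase:

    import sqlalchemy as sa

    spec_versions = sa.table(
        "spec_versions",
        sa.column("brainstorming_phase_id"),
        sa.column("spec_type"),
        sa.column("version"),
    )


    def next_spec_version(conn, phase_id, spec_type) -> int:
        # Highest version within this phase and spec type, or 0 if none yet.
        latest = conn.scalar(
            sa.select(sa.func.max(spec_versions.c.version)).where(
                spec_versions.c.brainstorming_phase_id == phase_id,
                spec_versions.c.spec_type == spec_type,
            )
        )
        return (latest or 0) + 1
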
-revision: str = '178e7cb83afe' -down_revision: Union[str, Sequence[str], None] = '8e970e133ccd' +revision: str = "178e7cb83afe" +down_revision: Union[str, Sequence[str], None] = "8e970e133ccd" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -45,18 +44,18 @@ def upgrade() -> None: """) # Create indexes - op.create_index('ix_implementation_phases_project_id', 'implementation_phases', ['project_id']) - op.create_index('ix_implementation_phases_status', 'implementation_phases', ['project_id', 'status']) + op.create_index("ix_implementation_phases_project_id", "implementation_phases", ["project_id"]) + op.create_index("ix_implementation_phases_status", "implementation_phases", ["project_id", "status"]) def downgrade() -> None: """Downgrade schema.""" # Drop indexes - op.drop_index('ix_implementation_phases_status', table_name='implementation_phases') - op.drop_index('ix_implementation_phases_project_id', table_name='implementation_phases') + op.drop_index("ix_implementation_phases_status", table_name="implementation_phases") + op.drop_index("ix_implementation_phases_project_id", table_name="implementation_phases") # Drop table - op.drop_table('implementation_phases') + op.drop_table("implementation_phases") # Drop ENUM op.execute("DROP TYPE phase_status") diff --git a/backend/alembic/versions/17e2d7ed64e2_add_testing_debug_settings.py b/backend/alembic/versions/17e2d7ed64e2_add_testing_debug_settings.py index a33a377..65ad74d 100644 --- a/backend/alembic/versions/17e2d7ed64e2_add_testing_debug_settings.py +++ b/backend/alembic/versions/17e2d7ed64e2_add_testing_debug_settings.py @@ -5,15 +5,16 @@ Create Date: 2025-11-22 09:37:51.442393 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
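
One pre-existing wrinkle the formatter cannot fix: the downgrade in 178e7cb83afe above drops `phase_status` unconditionally, while 00ebcf349edc earlier uses the `IF EXISTS` form. If the unconditional drop ever bites on a partially-downgraded database, the safer spelling is a one-liner:

    op.execute("DROP TYPE IF EXISTS phase_status")
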
-revision: str = '17e2d7ed64e2' -down_revision: Union[str, Sequence[str], None] = '840545b82f16' +revision: str = "17e2d7ed64e2" +down_revision: Union[str, Sequence[str], None] = "840545b82f16" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,40 +23,33 @@ def upgrade() -> None: """Add testing & debugging settings to llm_preferences table.""" # Add mock_discovery_enabled column (boolean, default false) op.add_column( - 'llm_preferences', - sa.Column('mock_discovery_enabled', sa.Boolean(), nullable=False, server_default='false') + "llm_preferences", sa.Column("mock_discovery_enabled", sa.Boolean(), nullable=False, server_default="false") ) # Add mock_discovery_question_limit column (integer, default 10, check constraint) op.add_column( - 'llm_preferences', - sa.Column('mock_discovery_question_limit', sa.Integer(), nullable=False, server_default='10') + "llm_preferences", sa.Column("mock_discovery_question_limit", sa.Integer(), nullable=False, server_default="10") ) op.create_check_constraint( - 'ck_mock_discovery_question_limit', - 'llm_preferences', - 'mock_discovery_question_limit IN (10, 20, 30)' + "ck_mock_discovery_question_limit", "llm_preferences", "mock_discovery_question_limit IN (10, 20, 30)" ) # Add mock_discovery_delay_seconds column (integer, default 5, check constraint) op.add_column( - 'llm_preferences', - sa.Column('mock_discovery_delay_seconds', sa.Integer(), nullable=False, server_default='5') + "llm_preferences", sa.Column("mock_discovery_delay_seconds", sa.Integer(), nullable=False, server_default="5") ) op.create_check_constraint( - 'ck_mock_discovery_delay_seconds', - 'llm_preferences', - 'mock_discovery_delay_seconds IN (5, 10, 20)' + "ck_mock_discovery_delay_seconds", "llm_preferences", "mock_discovery_delay_seconds IN (5, 10, 20)" ) def downgrade() -> None: """Remove testing & debugging settings from llm_preferences table.""" # Drop check constraints first - op.drop_constraint('ck_mock_discovery_delay_seconds', 'llm_preferences', type_='check') - op.drop_constraint('ck_mock_discovery_question_limit', 'llm_preferences', type_='check') + op.drop_constraint("ck_mock_discovery_delay_seconds", "llm_preferences", type_="check") + op.drop_constraint("ck_mock_discovery_question_limit", "llm_preferences", type_="check") # Drop columns - op.drop_column('llm_preferences', 'mock_discovery_delay_seconds') - op.drop_column('llm_preferences', 'mock_discovery_question_limit') - op.drop_column('llm_preferences', 'mock_discovery_enabled') + op.drop_column("llm_preferences", "mock_discovery_delay_seconds") + op.drop_column("llm_preferences", "mock_discovery_question_limit") + op.drop_column("llm_preferences", "mock_discovery_enabled") diff --git a/backend/alembic/versions/1931abc02c3f_add_user_auth_fields.py b/backend/alembic/versions/1931abc02c3f_add_user_auth_fields.py index 89a4c95..36abdca 100644 --- a/backend/alembic/versions/1931abc02c3f_add_user_auth_fields.py +++ b/backend/alembic/versions/1931abc02c3f_add_user_auth_fields.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 09:16:02.362772 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
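
A detail worth knowing when reading these hunks: `server_default` is rendered into DDL as literal SQL text, which is why booleans and integers are passed as strings ("false", "10", "5") while the check constraints enforce the allowed values server-side. Both spellings below are legal, and this patch uses each in different files:

    import sqlalchemy as sa

    # String-literal form, as in 17e2d7ed64e2 above.
    col_a = sa.Column("mock_discovery_delay_seconds", sa.Integer(), nullable=False, server_default="5")
    # Typed-construct form, as in the notification tables later in this patch.
    col_b = sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.true())
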
-revision: str = '1931abc02c3f' -down_revision: Union[str, Sequence[str], None] = 'a8845f795cf1' +revision: str = "1931abc02c3f" +down_revision: Union[str, Sequence[str], None] = "a8845f795cf1" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,16 +22,16 @@ def upgrade() -> None: """Upgrade schema.""" # Add password_hash field - op.add_column('users', sa.Column('password_hash', sa.String(length=255), nullable=False)) + op.add_column("users", sa.Column("password_hash", sa.String(length=255), nullable=False)) # Add display_name field - op.add_column('users', sa.Column('display_name', sa.String(length=100), nullable=True)) + op.add_column("users", sa.Column("display_name", sa.String(length=100), nullable=True)) def downgrade() -> None: """Downgrade schema.""" # Remove display_name field - op.drop_column('users', 'display_name') + op.drop_column("users", "display_name") # Remove password_hash field - op.drop_column('users', 'password_hash') + op.drop_column("users", "password_hash") diff --git a/backend/alembic/versions/19fa5dd5cb52_add_organization_id_to_organizations.py b/backend/alembic/versions/19fa5dd5cb52_add_organization_id_to_organizations.py index a9b43ee..379ba42 100644 --- a/backend/alembic/versions/19fa5dd5cb52_add_organization_id_to_organizations.py +++ b/backend/alembic/versions/19fa5dd5cb52_add_organization_id_to_organizations.py @@ -5,15 +5,16 @@ Create Date: 2025-12-24 00:20:21.946817 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '19fa5dd5cb52' -down_revision: Union[str, Sequence[str], None] = 'j1k2l3m4n5o6' +revision: str = "19fa5dd5cb52" +down_revision: Union[str, Sequence[str], None] = "j1k2l3m4n5o6" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,22 +22,14 @@ def upgrade() -> None: """Upgrade schema.""" # Add organization_id column to organizations table - op.add_column( - 'organizations', - sa.Column('organization_id', sa.String(length=255), nullable=True) - ) + op.add_column("organizations", sa.Column("organization_id", sa.String(length=255), nullable=True)) # Create unique index on organization_id - op.create_index( - 'ix_organizations_organization_id', - 'organizations', - ['organization_id'], - unique=True - ) + op.create_index("ix_organizations_organization_id", "organizations", ["organization_id"], unique=True) def downgrade() -> None: """Downgrade schema.""" # Drop the index first - op.drop_index('ix_organizations_organization_id', table_name='organizations') + op.drop_index("ix_organizations_organization_id", table_name="organizations") # Drop the column - op.drop_column('organizations', 'organization_id') + op.drop_column("organizations", "organization_id") diff --git a/backend/alembic/versions/1cfb3bffcc2a_phase3_autogen_discovery_schema.py b/backend/alembic/versions/1cfb3bffcc2a_phase3_autogen_discovery_schema.py index f63471d..516fd3d 100644 --- a/backend/alembic/versions/1cfb3bffcc2a_phase3_autogen_discovery_schema.py +++ b/backend/alembic/versions/1cfb3bffcc2a_phase3_autogen_discovery_schema.py @@ -5,15 +5,16 @@ Create Date: 2025-11-21 10:58:33.178811 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = '1cfb3bffcc2a' -down_revision: Union[str, Sequence[str], None] = '6ebda2d1112d' +revision: str = "1cfb3bffcc2a" +down_revision: Union[str, Sequence[str], None] = "6ebda2d1112d" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -71,37 +72,46 @@ def upgrade() -> None: op.execute("DROP TYPE discoverystatus_old;") # 3. Add new columns for Phase 3 - op.add_column('discovery_questions', sa.Column('type', sa.String(), nullable=False, server_default='mcq_with_custom_input')) - op.add_column('discovery_questions', sa.Column('allows_custom_input', sa.Boolean(), nullable=False, server_default='true')) - op.add_column('discovery_questions', sa.Column('followup_to', sa.dialects.postgresql.UUID(as_uuid=True), nullable=True)) - op.add_column('discovery_questions', sa.Column('depth', sa.Integer(), nullable=False, server_default='0')) - op.add_column('discovery_questions', sa.Column('explanation', sa.Text(), nullable=True)) - op.add_column('discovery_questions', sa.Column('internal_agent_notes', sa.Text(), nullable=True)) - op.add_column('discovery_questions', sa.Column('discussion_thread_id', sa.dialects.postgresql.UUID(as_uuid=True), nullable=True)) - op.add_column('discovery_questions', sa.Column('suggested_follow_ups', sa.JSON(), nullable=True)) + op.add_column( + "discovery_questions", sa.Column("type", sa.String(), nullable=False, server_default="mcq_with_custom_input") + ) + op.add_column( + "discovery_questions", sa.Column("allows_custom_input", sa.Boolean(), nullable=False, server_default="true") + ) + op.add_column( + "discovery_questions", sa.Column("followup_to", sa.dialects.postgresql.UUID(as_uuid=True), nullable=True) + ) + op.add_column("discovery_questions", sa.Column("depth", sa.Integer(), nullable=False, server_default="0")) + op.add_column("discovery_questions", sa.Column("explanation", sa.Text(), nullable=True)) + op.add_column("discovery_questions", sa.Column("internal_agent_notes", sa.Text(), nullable=True)) + op.add_column( + "discovery_questions", + sa.Column("discussion_thread_id", sa.dialects.postgresql.UUID(as_uuid=True), nullable=True), + ) + op.add_column("discovery_questions", sa.Column("suggested_follow_ups", sa.JSON(), nullable=True)) # 4. Add foreign key for followup_to op.create_foreign_key( - 'fk_discovery_questions_followup_to', - 'discovery_questions', - 'discovery_questions', - ['followup_to'], - ['id'], - ondelete='CASCADE' + "fk_discovery_questions_followup_to", + "discovery_questions", + "discovery_questions", + ["followup_to"], + ["id"], + ondelete="CASCADE", ) # 5. Add foreign key for discussion_thread_id op.create_foreign_key( - 'fk_discovery_questions_discussion_thread_id', - 'discovery_questions', - 'threads', - ['discussion_thread_id'], - ['id'], - ondelete='SET NULL' + "fk_discovery_questions_discussion_thread_id", + "discovery_questions", + "threads", + ["discussion_thread_id"], + ["id"], + ondelete="SET NULL", ) # 6. Create index on followup_to for better query performance - op.create_index('ix_discovery_questions_followup_to', 'discovery_questions', ['followup_to']) + op.create_index("ix_discovery_questions_followup_to", "discovery_questions", ["followup_to"]) # 7. Migrate existing data: map depends_on to followup_to op.execute(""" @@ -112,16 +122,16 @@ def upgrade() -> None: """) # 8. 
Add intent and complexity to projects table for caching - op.add_column('projects', sa.Column('discovery_intent', sa.String(), nullable=True)) - op.add_column('projects', sa.Column('discovery_complexity', sa.String(), nullable=True)) + op.add_column("projects", sa.Column("discovery_intent", sa.String(), nullable=True)) + op.add_column("projects", sa.Column("discovery_complexity", sa.String(), nullable=True)) def downgrade() -> None: """Downgrade schema to Phase 1.""" # 1. Remove columns from projects table - op.drop_column('projects', 'discovery_complexity') - op.drop_column('projects', 'discovery_intent') + op.drop_column("projects", "discovery_complexity") + op.drop_column("projects", "discovery_intent") # 2. Migrate followup_to back to depends_on op.execute(""" @@ -131,19 +141,19 @@ def downgrade() -> None: """) # 3. Drop indexes and foreign keys - op.drop_index('ix_discovery_questions_followup_to', table_name='discovery_questions') - op.drop_constraint('fk_discovery_questions_discussion_thread_id', 'discovery_questions', type_='foreignkey') - op.drop_constraint('fk_discovery_questions_followup_to', 'discovery_questions', type_='foreignkey') + op.drop_index("ix_discovery_questions_followup_to", table_name="discovery_questions") + op.drop_constraint("fk_discovery_questions_discussion_thread_id", "discovery_questions", type_="foreignkey") + op.drop_constraint("fk_discovery_questions_followup_to", "discovery_questions", type_="foreignkey") # 4. Drop new columns - op.drop_column('discovery_questions', 'suggested_follow_ups') - op.drop_column('discovery_questions', 'discussion_thread_id') - op.drop_column('discovery_questions', 'internal_agent_notes') - op.drop_column('discovery_questions', 'explanation') - op.drop_column('discovery_questions', 'depth') - op.drop_column('discovery_questions', 'followup_to') - op.drop_column('discovery_questions', 'allows_custom_input') - op.drop_column('discovery_questions', 'type') + op.drop_column("discovery_questions", "suggested_follow_ups") + op.drop_column("discovery_questions", "discussion_thread_id") + op.drop_column("discovery_questions", "internal_agent_notes") + op.drop_column("discovery_questions", "explanation") + op.drop_column("discovery_questions", "depth") + op.drop_column("discovery_questions", "followup_to") + op.drop_column("discovery_questions", "allows_custom_input") + op.drop_column("discovery_questions", "type") # 5. Revert status enum op.execute(""" diff --git a/backend/alembic/versions/1e859fbdbd74_add_spec_coverage_reports.py b/backend/alembic/versions/1e859fbdbd74_add_spec_coverage_reports.py index 8ce7dac..1b64a0d 100644 --- a/backend/alembic/versions/1e859fbdbd74_add_spec_coverage_reports.py +++ b/backend/alembic/versions/1e859fbdbd74_add_spec_coverage_reports.py @@ -5,16 +5,17 @@ Create Date: 2025-11-24 11:43:58.546628 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import UUID, JSON +from sqlalchemy.dialects.postgresql import JSON, UUID +from alembic import op # revision identifiers, used by Alembic. 
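
The `DROP TYPE discoverystatus_old` context line in 1cfb3bffcc2a above is the tail of PostgreSQL's usual enum-replacement dance, needed because an ENUM cannot simply have members removed. The migration's real statements are elided in the hunk; a sketch of the general shape, with illustrative member values:

    from alembic import op


    def upgrade() -> None:
        op.execute("ALTER TYPE discoverystatus RENAME TO discoverystatus_old")
        op.execute("CREATE TYPE discoverystatus AS ENUM ('pending', 'active', 'answered')")
        op.execute(
            "ALTER TABLE discovery_questions ALTER COLUMN status "
            "TYPE discoverystatus USING status::text::discoverystatus"
        )
        op.execute("DROP TYPE discoverystatus_old")
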
-revision: str = '1e859fbdbd74' -down_revision: Union[str, Sequence[str], None] = 'bc6ddbf8b5b7' +revision: str = "1e859fbdbd74" +down_revision: Union[str, Sequence[str], None] = "bc6ddbf8b5b7" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,21 +23,23 @@ def upgrade() -> None: """Upgrade schema.""" op.create_table( - 'spec_coverage_reports', - sa.Column('id', UUID(as_uuid=True), primary_key=True), - sa.Column('spec_version_id', UUID(as_uuid=True), sa.ForeignKey('spec_versions.id'), nullable=False, unique=True), - sa.Column('ok', sa.Boolean(), nullable=False), - sa.Column('uncovered_must_have_questions', JSON(), nullable=False, server_default='[]'), - sa.Column('weak_coverage_warnings', JSON(), nullable=False, server_default='[]'), - sa.Column('contradictions_found', JSON(), nullable=False, server_default='[]'), - sa.Column('suggested_rewrites', JSON(), nullable=False, server_default='[]'), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + "spec_coverage_reports", + sa.Column("id", UUID(as_uuid=True), primary_key=True), + sa.Column( + "spec_version_id", UUID(as_uuid=True), sa.ForeignKey("spec_versions.id"), nullable=False, unique=True + ), + sa.Column("ok", sa.Boolean(), nullable=False), + sa.Column("uncovered_must_have_questions", JSON(), nullable=False, server_default="[]"), + sa.Column("weak_coverage_warnings", JSON(), nullable=False, server_default="[]"), + sa.Column("contradictions_found", JSON(), nullable=False, server_default="[]"), + sa.Column("suggested_rewrites", JSON(), nullable=False, server_default="[]"), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), ) # Create index on spec_version_id for fast lookups - op.create_index('ix_spec_coverage_reports_spec_version_id', 'spec_coverage_reports', ['spec_version_id']) + op.create_index("ix_spec_coverage_reports_spec_version_id", "spec_coverage_reports", ["spec_version_id"]) def downgrade() -> None: """Downgrade schema.""" - op.drop_index('ix_spec_coverage_reports_spec_version_id', table_name='spec_coverage_reports') - op.drop_table('spec_coverage_reports') + op.drop_index("ix_spec_coverage_reports_spec_version_id", table_name="spec_coverage_reports") + op.drop_table("spec_coverage_reports") diff --git a/backend/alembic/versions/22633d1c969b_add_feature_key_number.py b/backend/alembic/versions/22633d1c969b_add_feature_key_number.py index 9418a73..3f874d6 100644 --- a/backend/alembic/versions/22633d1c969b_add_feature_key_number.py +++ b/backend/alembic/versions/22633d1c969b_add_feature_key_number.py @@ -5,22 +5,23 @@ Create Date: 2025-01-01 00:00:00.000000 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = '22633d1c969b' -down_revision: Union[str, None] = 'c8d9e0f1g2h3' +revision: str = "22633d1c969b" +down_revision: Union[str, None] = "c8d9e0f1g2h3" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add feature_key_number column (nullable initially for backfill) - op.add_column('features', sa.Column('feature_key_number', sa.Integer(), nullable=True)) + op.add_column("features", sa.Column("feature_key_number", sa.Integer(), nullable=True)) # Backfill existing rows by extracting the numeric part after the last hyphen # e.g., "1F40-090" -> 90, "USR-1234" -> 1234 @@ -33,15 +34,15 @@ def upgrade() -> None: """) # Make it non-nullable after backfill - op.alter_column('features', 'feature_key_number', nullable=False, server_default='0') + op.alter_column("features", "feature_key_number", nullable=False, server_default="0") # Remove the server default (it was just for the alter) - op.alter_column('features', 'feature_key_number', server_default=None) + op.alter_column("features", "feature_key_number", server_default=None) # Add index for efficient sorting - op.create_index('ix_features_feature_key_number', 'features', ['feature_key_number']) + op.create_index("ix_features_feature_key_number", "features", ["feature_key_number"]) def downgrade() -> None: - op.drop_index('ix_features_feature_key_number', table_name='features') - op.drop_column('features', 'feature_key_number') + op.drop_index("ix_features_feature_key_number", table_name="features") + op.drop_column("features", "feature_key_number") diff --git a/backend/alembic/versions/23893c2e92fa_add_llm_usage_tracking_to_jobs.py b/backend/alembic/versions/23893c2e92fa_add_llm_usage_tracking_to_jobs.py index bb6d2e7..eef497c 100644 --- a/backend/alembic/versions/23893c2e92fa_add_llm_usage_tracking_to_jobs.py +++ b/backend/alembic/versions/23893c2e92fa_add_llm_usage_tracking_to_jobs.py @@ -5,30 +5,31 @@ Create Date: 2025-12-05 16:14:42.646001 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
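
22633d1c969b above shows the standard three-step choreography for introducing a NOT NULL column on a populated table: add it nullable, backfill it, then tighten it (using a throwaway `server_default` for the ALTER and removing the default immediately after). The actual UPDATE is elided in the hunk; a hypothetical equivalent, assuming the key lives in a `feature_key` column, might read:

    from alembic import op


    def backfill() -> None:
        # Pulls the trailing digits: "1F40-090" -> 90, "USR-1234" -> 1234.
        op.execute("""
            UPDATE features
            SET feature_key_number = CAST(substring(feature_key FROM '([0-9]+)$') AS INTEGER)
        """)
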
-revision: str = '23893c2e92fa' -down_revision: Union[str, Sequence[str], None] = '00ebcf349edc' +revision: str = "23893c2e92fa" +down_revision: Union[str, Sequence[str], None] = "00ebcf349edc" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Add LLM usage tracking columns to jobs table.""" - op.add_column('jobs', sa.Column('model_used', sa.String(length=100), nullable=True)) - op.add_column('jobs', sa.Column('total_prompt_tokens', sa.Integer(), nullable=True)) - op.add_column('jobs', sa.Column('total_completion_tokens', sa.Integer(), nullable=True)) - op.add_column('jobs', sa.Column('total_cost_usd', sa.Numeric(precision=10, scale=6), nullable=True)) + op.add_column("jobs", sa.Column("model_used", sa.String(length=100), nullable=True)) + op.add_column("jobs", sa.Column("total_prompt_tokens", sa.Integer(), nullable=True)) + op.add_column("jobs", sa.Column("total_completion_tokens", sa.Integer(), nullable=True)) + op.add_column("jobs", sa.Column("total_cost_usd", sa.Numeric(precision=10, scale=6), nullable=True)) def downgrade() -> None: """Remove LLM usage tracking columns from jobs table.""" - op.drop_column('jobs', 'total_cost_usd') - op.drop_column('jobs', 'total_completion_tokens') - op.drop_column('jobs', 'total_prompt_tokens') - op.drop_column('jobs', 'model_used') + op.drop_column("jobs", "total_cost_usd") + op.drop_column("jobs", "total_completion_tokens") + op.drop_column("jobs", "total_prompt_tokens") + op.drop_column("jobs", "model_used") diff --git a/backend/alembic/versions/264f9632081a_add_notification_tables.py b/backend/alembic/versions/264f9632081a_add_notification_tables.py index 57888db..aa52f45 100644 --- a/backend/alembic/versions/264f9632081a_add_notification_tables.py +++ b/backend/alembic/versions/264f9632081a_add_notification_tables.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 00:00:00.000000 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = '264f9632081a' -down_revision: Union[str, Sequence[str], None] = '2aec09f57c3e' +revision: str = "264f9632081a" +down_revision: Union[str, Sequence[str], None] = "2aec09f57c3e" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,61 +23,61 @@ def upgrade() -> None: """Upgrade schema.""" # Create notification_preferences table op.create_table( - 'notification_preferences', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('channel', sa.Enum('EMAIL', 'SLACK', 'TEAMS', name='notificationchannel'), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False, server_default=sa.true()), - sa.Column('channel_config', sa.String(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "notification_preferences", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("channel", sa.Enum("EMAIL", "SLACK", "TEAMS", name="notificationchannel"), nullable=False), + sa.Column("enabled", sa.Boolean(), nullable=False, server_default=sa.true()), + sa.Column("channel_config", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('ix_notification_preferences_user_id', 'notification_preferences', ['user_id']) + op.create_index("ix_notification_preferences_user_id", "notification_preferences", ["user_id"]) # Create notification_project_mutes table op.create_table( - 'notification_project_mutes', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('user_id', 'project_id', name='uq_user_project_mute') + "notification_project_mutes", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "project_id", name="uq_user_project_mute"), ) - op.create_index('ix_notification_project_mutes_user_id', 'notification_project_mutes', ['user_id']) - op.create_index('ix_notification_project_mutes_project_id', 'notification_project_mutes', ['project_id']) + op.create_index("ix_notification_project_mutes_user_id", "notification_project_mutes", ["user_id"]) + op.create_index("ix_notification_project_mutes_project_id", "notification_project_mutes", 
["project_id"]) # Create notification_thread_watches table op.create_table( - 'notification_thread_watches', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('thread_id', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['thread_id'], ['threads.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('user_id', 'thread_id', name='uq_user_thread_watch') + "notification_thread_watches", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("thread_id", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["thread_id"], ["threads.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "thread_id", name="uq_user_thread_watch"), ) - op.create_index('ix_notification_thread_watches_user_id', 'notification_thread_watches', ['user_id']) - op.create_index('ix_notification_thread_watches_thread_id', 'notification_thread_watches', ['thread_id']) + op.create_index("ix_notification_thread_watches_user_id", "notification_thread_watches", ["user_id"]) + op.create_index("ix_notification_thread_watches_thread_id", "notification_thread_watches", ["thread_id"]) def downgrade() -> None: """Downgrade schema.""" - op.drop_index('ix_notification_thread_watches_thread_id', table_name='notification_thread_watches') - op.drop_index('ix_notification_thread_watches_user_id', table_name='notification_thread_watches') - op.drop_table('notification_thread_watches') + op.drop_index("ix_notification_thread_watches_thread_id", table_name="notification_thread_watches") + op.drop_index("ix_notification_thread_watches_user_id", table_name="notification_thread_watches") + op.drop_table("notification_thread_watches") - op.drop_index('ix_notification_project_mutes_project_id', table_name='notification_project_mutes') - op.drop_index('ix_notification_project_mutes_user_id', table_name='notification_project_mutes') - op.drop_table('notification_project_mutes') + op.drop_index("ix_notification_project_mutes_project_id", table_name="notification_project_mutes") + op.drop_index("ix_notification_project_mutes_user_id", table_name="notification_project_mutes") + op.drop_table("notification_project_mutes") - op.drop_index('ix_notification_preferences_user_id', table_name='notification_preferences') - op.drop_table('notification_preferences') + op.drop_index("ix_notification_preferences_user_id", table_name="notification_preferences") + op.drop_table("notification_preferences") - op.execute('DROP TYPE notificationchannel') + op.execute("DROP TYPE notificationchannel") diff --git a/backend/alembic/versions/2aec09f57c3e_add_integration_configs_and_bug_sync_.py b/backend/alembic/versions/2aec09f57c3e_add_integration_configs_and_bug_sync_.py index 777041f..e22660f 100644 --- a/backend/alembic/versions/2aec09f57c3e_add_integration_configs_and_bug_sync_.py +++ b/backend/alembic/versions/2aec09f57c3e_add_integration_configs_and_bug_sync_.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 13:18:11.302275 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # 
revision identifiers, used by Alembic. -revision: str = '2aec09f57c3e' -down_revision: Union[str, Sequence[str], None] = '021b37581165' +revision: str = "2aec09f57c3e" +down_revision: Union[str, Sequence[str], None] = "021b37581165" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,42 +23,42 @@ def upgrade() -> None: """Upgrade schema.""" # Create integration_configs table op.create_table( - 'integration_configs', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('organization_id', sa.UUID(), nullable=False), - sa.Column('provider', sa.String(length=50), nullable=False), - sa.Column('encrypted_token', sa.Text(), nullable=False), - sa.Column('config_json', sa.JSON(), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('organization_id', 'provider', name='uq_org_provider') + "integration_configs", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("organization_id", sa.UUID(), nullable=False), + sa.Column("provider", sa.String(length=50), nullable=False), + sa.Column("encrypted_token", sa.Text(), nullable=False), + sa.Column("config_json", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("organization_id", "provider", name="uq_org_provider"), ) - op.create_index('ix_integration_configs_organization_id', 'integration_configs', ['organization_id']) + op.create_index("ix_integration_configs_organization_id", "integration_configs", ["organization_id"]) # Create bug_sync_history table op.create_table( - 'bug_sync_history', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('synced_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('status', sa.String(length=20), nullable=False), - sa.Column('imported_data_json', sa.JSON(), nullable=True), - sa.Column('error_message', sa.Text(), nullable=True), - sa.Column('triggered_by', sa.String(length=20), nullable=False), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "bug_sync_history", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("synced_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("imported_data_json", sa.JSON(), nullable=True), + sa.Column("error_message", sa.Text(), nullable=True), + sa.Column("triggered_by", sa.String(length=20), nullable=False), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('ix_bug_sync_history_project_id', 'bug_sync_history', ['project_id']) - op.create_index('ix_bug_sync_history_synced_at', 'bug_sync_history', ['synced_at']) + op.create_index("ix_bug_sync_history_project_id", 
"bug_sync_history", ["project_id"]) + op.create_index("ix_bug_sync_history_synced_at", "bug_sync_history", ["synced_at"]) def downgrade() -> None: """Downgrade schema.""" - op.drop_index('ix_bug_sync_history_synced_at', table_name='bug_sync_history') - op.drop_index('ix_bug_sync_history_project_id', table_name='bug_sync_history') - op.drop_table('bug_sync_history') + op.drop_index("ix_bug_sync_history_synced_at", table_name="bug_sync_history") + op.drop_index("ix_bug_sync_history_project_id", table_name="bug_sync_history") + op.drop_table("bug_sync_history") - op.drop_index('ix_integration_configs_organization_id', table_name='integration_configs') - op.drop_table('integration_configs') + op.drop_index("ix_integration_configs_organization_id", table_name="integration_configs") + op.drop_table("integration_configs") diff --git a/backend/alembic/versions/2f8c2d246f32_merge_llm_usage_and_call_logs.py b/backend/alembic/versions/2f8c2d246f32_merge_llm_usage_and_call_logs.py index ed703ad..9e39db5 100644 --- a/backend/alembic/versions/2f8c2d246f32_merge_llm_usage_and_call_logs.py +++ b/backend/alembic/versions/2f8c2d246f32_merge_llm_usage_and_call_logs.py @@ -5,15 +5,12 @@ Create Date: 2025-12-05 17:42:23.198841 """ -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa +from typing import Sequence, Union # revision identifiers, used by Alembic. -revision: str = '2f8c2d246f32' -down_revision: Union[str, Sequence[str], None] = ('23893c2e92fa', 'f5g6h7i8j9k0') +revision: str = "2f8c2d246f32" +down_revision: Union[str, Sequence[str], None] = ("23893c2e92fa", "f5g6h7i8j9k0") branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/307472938827_remove_completion_summary_from_features.py b/backend/alembic/versions/307472938827_remove_completion_summary_from_features.py index 6437f27..b75d87e 100644 --- a/backend/alembic/versions/307472938827_remove_completion_summary_from_features.py +++ b/backend/alembic/versions/307472938827_remove_completion_summary_from_features.py @@ -5,15 +5,16 @@ Create Date: 2026-01-09 20:25:10.479596 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '307472938827' -down_revision: Union[str, Sequence[str], None] = 'implcs01' +revision: str = "307472938827" +down_revision: Union[str, Sequence[str], None] = "implcs01" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,9 +25,9 @@ def upgrade() -> None: completion_summary is now stored at the implementation level and auto-generated by the grounding agent when notes are written. 
""" - op.drop_column('features', 'completion_summary') + op.drop_column("features", "completion_summary") def downgrade() -> None: """Re-add completion_summary column to features table.""" - op.add_column('features', sa.Column('completion_summary', sa.Text(), nullable=True)) + op.add_column("features", sa.Column("completion_summary", sa.Text(), nullable=True)) diff --git a/backend/alembic/versions/3e35a2b90829_add_prompt_plan_coverage_reports_table.py b/backend/alembic/versions/3e35a2b90829_add_prompt_plan_coverage_reports_table.py index 877bc4c..0c76484 100644 --- a/backend/alembic/versions/3e35a2b90829_add_prompt_plan_coverage_reports_table.py +++ b/backend/alembic/versions/3e35a2b90829_add_prompt_plan_coverage_reports_table.py @@ -5,15 +5,16 @@ Create Date: 2025-11-24 16:47:11.941231 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '3e35a2b90829' -down_revision: Union[str, Sequence[str], None] = '1e859fbdbd74' +revision: str = "3e35a2b90829" +down_revision: Union[str, Sequence[str], None] = "1e859fbdbd74" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,22 +22,25 @@ def upgrade() -> None: """Upgrade schema.""" op.create_table( - 'prompt_plan_coverage_reports', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('spec_version_id', sa.UUID(), nullable=False), - sa.Column('ok', sa.Boolean(), nullable=False), - sa.Column('missing_phases', sa.JSON(), nullable=False), - sa.Column('missing_mcp_methods', sa.JSON(), nullable=False), - sa.Column('hallucinated_constraints', sa.JSON(), nullable=False), - sa.Column('weak_sections', sa.JSON(), nullable=False), - sa.Column('suggested_improvements', sa.JSON(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['spec_version_id'], ['spec_versions.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('spec_version_id') + "prompt_plan_coverage_reports", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("spec_version_id", sa.UUID(), nullable=False), + sa.Column("ok", sa.Boolean(), nullable=False), + sa.Column("missing_phases", sa.JSON(), nullable=False), + sa.Column("missing_mcp_methods", sa.JSON(), nullable=False), + sa.Column("hallucinated_constraints", sa.JSON(), nullable=False), + sa.Column("weak_sections", sa.JSON(), nullable=False), + sa.Column("suggested_improvements", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["spec_version_id"], + ["spec_versions.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("spec_version_id"), ) def downgrade() -> None: """Downgrade schema.""" - op.drop_table('prompt_plan_coverage_reports') + op.drop_table("prompt_plan_coverage_reports") diff --git a/backend/alembic/versions/4631a4a3270c_add_organizations_and_memberships.py b/backend/alembic/versions/4631a4a3270c_add_organizations_and_memberships.py index cd1c93c..41dc338 100644 --- a/backend/alembic/versions/4631a4a3270c_add_organizations_and_memberships.py +++ b/backend/alembic/versions/4631a4a3270c_add_organizations_and_memberships.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 09:48:16.630938 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = '4631a4a3270c' -down_revision: Union[str, Sequence[str], None] = '1931abc02c3f' +revision: str = "4631a4a3270c" +down_revision: Union[str, Sequence[str], None] = "1931abc02c3f" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,39 +23,43 @@ def upgrade() -> None: """Upgrade schema.""" # Create organizations table op.create_table( - 'organizations', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.PrimaryKeyConstraint('id', name=op.f('pk_organizations')) + "organizations", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_organizations")), ) # Create org_memberships table op.create_table( - 'org_memberships', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('org_id', sa.UUID(), nullable=False), - sa.Column('user_id', sa.UUID(), nullable=False), - sa.Column('role', sa.String(length=20), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['org_id'], ['organizations.id'], name=op.f('fk_org_memberships_org_id_organizations'), ondelete='CASCADE'), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], name=op.f('fk_org_memberships_user_id_users'), ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id', name=op.f('pk_org_memberships')), - sa.UniqueConstraint('org_id', 'user_id', name=op.f('uq_org_memberships_org_id_user_id')) + "org_memberships", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("org_id", sa.UUID(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=False), + sa.Column("role", sa.String(length=20), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.ForeignKeyConstraint( + ["org_id"], ["organizations.id"], name=op.f("fk_org_memberships_org_id_organizations"), ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["user_id"], ["users.id"], name=op.f("fk_org_memberships_user_id_users"), ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_org_memberships")), + sa.UniqueConstraint("org_id", "user_id", name=op.f("uq_org_memberships_org_id_user_id")), ) # Create indexes for foreign keys - op.create_index(op.f('ix_org_memberships_org_id'), 'org_memberships', ['org_id'], unique=False) - op.create_index(op.f('ix_org_memberships_user_id'), 'org_memberships', ['user_id'], unique=False) + op.create_index(op.f("ix_org_memberships_org_id"), "org_memberships", ["org_id"], unique=False) + op.create_index(op.f("ix_org_memberships_user_id"), "org_memberships", ["user_id"], unique=False) def downgrade() -> None: """Downgrade schema.""" # Drop indexes - op.drop_index(op.f('ix_org_memberships_user_id'), table_name='org_memberships') - op.drop_index(op.f('ix_org_memberships_org_id'), table_name='org_memberships') + op.drop_index(op.f("ix_org_memberships_user_id"), table_name="org_memberships") + op.drop_index(op.f("ix_org_memberships_org_id"), 
table_name="org_memberships") # Drop tables (reverse order) - op.drop_table('org_memberships') - op.drop_table('organizations') + op.drop_table("org_memberships") + op.drop_table("organizations") diff --git a/backend/alembic/versions/4691251c9f11_add_user_trial_started_at.py b/backend/alembic/versions/4691251c9f11_add_user_trial_started_at.py index 1c3cf71..76b85d0 100644 --- a/backend/alembic/versions/4691251c9f11_add_user_trial_started_at.py +++ b/backend/alembic/versions/4691251c9f11_add_user_trial_started_at.py @@ -8,28 +8,26 @@ Create Date: 2025-12-16 17:21:42.122199 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '4691251c9f11' -down_revision: Union[str, Sequence[str], None] = 'd1cf77c4c1fa' +revision: str = "4691251c9f11" +down_revision: Union[str, Sequence[str], None] = "d1cf77c4c1fa" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Add trial_started_at column to users table.""" - op.add_column( - 'users', - sa.Column('trial_started_at', sa.DateTime(timezone=True), nullable=True) - ) + op.add_column("users", sa.Column("trial_started_at", sa.DateTime(timezone=True), nullable=True)) # Existing users stay NULL (no trial = grandfathered with perpetual access) def downgrade() -> None: """Remove trial_started_at column from users table.""" - op.drop_column('users', 'trial_started_at') + op.drop_column("users", "trial_started_at") diff --git a/backend/alembic/versions/489421eb9675_fix_brainstorm_module_feature_types_data.py b/backend/alembic/versions/489421eb9675_fix_brainstorm_module_feature_types_data.py index d6098ce..7e26df6 100644 --- a/backend/alembic/versions/489421eb9675_fix_brainstorm_module_feature_types_data.py +++ b/backend/alembic/versions/489421eb9675_fix_brainstorm_module_feature_types_data.py @@ -5,15 +5,14 @@ Create Date: 2025-12-05 07:29:11.649749 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = '489421eb9675' -down_revision: Union[str, Sequence[str], None] = '77d468f92bb4' +revision: str = "489421eb9675" +down_revision: Union[str, Sequence[str], None] = "77d468f92bb4" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/5464eccf3da8_add_key_lookup_hash_to_api_keys.py b/backend/alembic/versions/5464eccf3da8_add_key_lookup_hash_to_api_keys.py index 223cf50..70d784c 100644 --- a/backend/alembic/versions/5464eccf3da8_add_key_lookup_hash_to_api_keys.py +++ b/backend/alembic/versions/5464eccf3da8_add_key_lookup_hash_to_api_keys.py @@ -8,16 +8,17 @@ Create Date: 2026-02-14 18:58:17.183207 """ + import hashlib from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '5464eccf3da8' -down_revision: Union[str, Sequence[str], None] = 'slack02' +revision: str = "5464eccf3da8" +down_revision: Union[str, Sequence[str], None] = "slack02" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -25,17 +26,15 @@ def upgrade() -> None: """Add key_lookup_hash column and backfill from encrypted keys.""" # 1. 
Add the column (nullable for backwards compatibility) - op.add_column('api_keys', sa.Column('key_lookup_hash', sa.String(64), nullable=True)) + op.add_column("api_keys", sa.Column("key_lookup_hash", sa.String(64), nullable=True)) # 2. Add index for fast lookups - op.create_index('idx_api_keys_key_lookup_hash', 'api_keys', ['key_lookup_hash']) + op.create_index("idx_api_keys_key_lookup_hash", "api_keys", ["key_lookup_hash"]) # 3. Backfill: decrypt key_encrypted → compute SHA-256 → store # This runs in-database using Python for the crypto operations conn = op.get_bind() - rows = conn.execute( - sa.text("SELECT id, key_encrypted FROM api_keys WHERE key_encrypted IS NOT NULL") - ).fetchall() + rows = conn.execute(sa.text("SELECT id, key_encrypted FROM api_keys WHERE key_encrypted IS NOT NULL")).fetchall() if rows: # Import decryption utils (requires ENCRYPTION_KEY to be set) @@ -61,5 +60,5 @@ def upgrade() -> None: def downgrade() -> None: """Remove key_lookup_hash column.""" - op.drop_index('idx_api_keys_key_lookup_hash', table_name='api_keys') - op.drop_column('api_keys', 'key_lookup_hash') + op.drop_index("idx_api_keys_key_lookup_hash", table_name="api_keys") + op.drop_column("api_keys", "key_lookup_hash") diff --git a/backend/alembic/versions/58b3923aa347_add_thread_ai_response_flag.py b/backend/alembic/versions/58b3923aa347_add_thread_ai_response_flag.py index 88c447b..7f1f4ad 100644 --- a/backend/alembic/versions/58b3923aa347_add_thread_ai_response_flag.py +++ b/backend/alembic/versions/58b3923aa347_add_thread_ai_response_flag.py @@ -5,15 +5,16 @@ Create Date: 2026-01-04 13:21:04.973636 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '58b3923aa347' -down_revision: Union[str, Sequence[str], None] = 'ppd06' +revision: str = "58b3923aa347" +down_revision: Union[str, Sequence[str], None] = "ppd06" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -24,11 +25,10 @@ def upgrade() -> None: # This flag is set TRUE when user triggers @MFBTAI AI mention, # and cleared to FALSE when job completes (success or failure). op.add_column( - 'threads', - sa.Column('is_generating_ai_response', sa.Boolean(), nullable=False, server_default='false') + "threads", sa.Column("is_generating_ai_response", sa.Boolean(), nullable=False, server_default="false") ) def downgrade() -> None: """Downgrade schema.""" - op.drop_column('threads', 'is_generating_ai_response') + op.drop_column("threads", "is_generating_ai_response") diff --git a/backend/alembic/versions/58b5bf6d73fc_add_discovery_tables.py b/backend/alembic/versions/58b5bf6d73fc_add_discovery_tables.py index 0498dec..19ce9dd 100644 --- a/backend/alembic/versions/58b5bf6d73fc_add_discovery_tables.py +++ b/backend/alembic/versions/58b5bf6d73fc_add_discovery_tables.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 10:57:24.440539 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
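The key_lookup_hash migration above follows a common Alembic backfill shape: add the column as nullable, index it, then use the migration's own connection to compute values in Python. A minimal sketch of that shape, with an illustrative table and a plain SHA-256 source (the real migration decrypts the stored key first):

import hashlib

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Illustrative names; not part of this patch.
    op.add_column("api_tokens", sa.Column("lookup_hash", sa.String(64), nullable=True))
    conn = op.get_bind()
    rows = conn.execute(sa.text("SELECT id, raw_value FROM api_tokens")).fetchall()
    for row in rows:
        # Compute the value in Python, then write it back row by row.
        digest = hashlib.sha256(row.raw_value.encode()).hexdigest()
        conn.execute(
            sa.text("UPDATE api_tokens SET lookup_hash = :h WHERE id = :id"),
            {"h": digest, "id": row.id},
        )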
-revision: str = '58b5bf6d73fc' -down_revision: Union[str, Sequence[str], None] = '0c5625ec8ba1' +revision: str = "58b5bf6d73fc" +down_revision: Union[str, Sequence[str], None] = "0c5625ec8ba1" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,52 +23,64 @@ def upgrade() -> None: """Upgrade schema.""" # Create discovery_questions table op.create_table( - 'discovery_questions', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('category', sa.String(), nullable=True), - sa.Column('priority', sa.Enum('LOW', 'MEDIUM', 'HIGH', name='discoverypriority'), nullable=False), - sa.Column('question_text', sa.Text(), nullable=False), - sa.Column('is_multiple_choice', sa.Boolean(), nullable=False), - sa.Column('options', sa.JSON(), nullable=False), - sa.Column('depends_on', sa.UUID(), nullable=True), - sa.Column('created_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('status', sa.Enum('OPEN', 'RESOLVED', 'NA', name='discoverystatus'), nullable=False), - sa.Column('resolved_by', sa.UUID(), nullable=True), - sa.Column('resolved_at', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['created_by'], ['users.id'], ), - sa.ForeignKeyConstraint(['depends_on'], ['discovery_questions.id'], ), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['resolved_by'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + "discovery_questions", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("category", sa.String(), nullable=True), + sa.Column("priority", sa.Enum("LOW", "MEDIUM", "HIGH", name="discoverypriority"), nullable=False), + sa.Column("question_text", sa.Text(), nullable=False), + sa.Column("is_multiple_choice", sa.Boolean(), nullable=False), + sa.Column("options", sa.JSON(), nullable=False), + sa.Column("depends_on", sa.UUID(), nullable=True), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("status", sa.Enum("OPEN", "RESOLVED", "NA", name="discoverystatus"), nullable=False), + sa.Column("resolved_by", sa.UUID(), nullable=True), + sa.Column("resolved_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint( + ["created_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint( + ["depends_on"], + ["discovery_questions.id"], + ), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["resolved_by"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_discovery_questions_project_id'), 'discovery_questions', ['project_id'], unique=False) - op.create_index(op.f('ix_discovery_questions_status'), 'discovery_questions', ['status'], unique=False) + op.create_index(op.f("ix_discovery_questions_project_id"), "discovery_questions", ["project_id"], unique=False) + op.create_index(op.f("ix_discovery_questions_status"), "discovery_questions", ["status"], unique=False) # Create discovery_answers table op.create_table( - 'discovery_answers', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('question_id', sa.UUID(), nullable=False), - sa.Column('selected_option_id', 
sa.String(), nullable=True), - sa.Column('free_text', sa.Text(), nullable=True), - sa.Column('answered_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint(['answered_by'], ['users.id'], ), - sa.ForeignKeyConstraint(['question_id'], ['discovery_questions.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "discovery_answers", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("question_id", sa.UUID(), nullable=False), + sa.Column("selected_option_id", sa.String(), nullable=True), + sa.Column("free_text", sa.Text(), nullable=True), + sa.Column("answered_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["answered_by"], + ["users.id"], + ), + sa.ForeignKeyConstraint(["question_id"], ["discovery_questions.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_discovery_answers_question_id'), 'discovery_answers', ['question_id'], unique=False) + op.create_index(op.f("ix_discovery_answers_question_id"), "discovery_answers", ["question_id"], unique=False) def downgrade() -> None: """Downgrade schema.""" - op.drop_index(op.f('ix_discovery_answers_question_id'), table_name='discovery_answers') - op.drop_table('discovery_answers') - op.drop_index(op.f('ix_discovery_questions_status'), table_name='discovery_questions') - op.drop_index(op.f('ix_discovery_questions_project_id'), table_name='discovery_questions') - op.drop_table('discovery_questions') - op.execute('DROP TYPE discoverystatus') - op.execute('DROP TYPE discoverypriority') + op.drop_index(op.f("ix_discovery_answers_question_id"), table_name="discovery_answers") + op.drop_table("discovery_answers") + op.drop_index(op.f("ix_discovery_questions_status"), table_name="discovery_questions") + op.drop_index(op.f("ix_discovery_questions_project_id"), table_name="discovery_questions") + op.drop_table("discovery_questions") + op.execute("DROP TYPE discoverystatus") + op.execute("DROP TYPE discoverypriority") diff --git a/backend/alembic/versions/616549379b06_add_generation_flags_to_impl_and_thread.py b/backend/alembic/versions/616549379b06_add_generation_flags_to_impl_and_thread.py index d811e8a..67b01c4 100644 --- a/backend/alembic/versions/616549379b06_add_generation_flags_to_impl_and_thread.py +++ b/backend/alembic/versions/616549379b06_add_generation_flags_to_impl_and_thread.py @@ -5,15 +5,16 @@ Create Date: 2026-01-03 15:28:40.447898 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
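A side effect worth noting in the discovery-tables migration above: sa.Enum columns inside create_table emit CREATE TYPE automatically on PostgreSQL, but DROP TABLE leaves the types behind, which is why downgrade() issues DROP TYPE by hand. A minimal sketch of that symmetry, with illustrative names:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Illustrative table; create_table with sa.Enum also emits CREATE TYPE.
    op.create_table(
        "tickets",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("status", sa.Enum("OPEN", "CLOSED", name="ticketstatus"), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    # Dropping the table does not drop the type, so drop it explicitly.
    op.drop_table("tickets")
    op.execute("DROP TYPE ticketstatus")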
-revision: str = '616549379b06' -down_revision: Union[str, Sequence[str], None] = 'a049ffa5b22b' +revision: str = "616549379b06" +down_revision: Union[str, Sequence[str], None] = "a049ffa5b22b" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,23 +23,20 @@ def upgrade() -> None: """Add generation status flags to implementations and threads tables.""" # Add is_generating_spec and is_generating_prompt_plan to implementations op.add_column( - 'implementations', - sa.Column('is_generating_spec', sa.Boolean(), nullable=False, server_default='false') + "implementations", sa.Column("is_generating_spec", sa.Boolean(), nullable=False, server_default="false") ) op.add_column( - 'implementations', - sa.Column('is_generating_prompt_plan', sa.Boolean(), nullable=False, server_default='false') + "implementations", sa.Column("is_generating_prompt_plan", sa.Boolean(), nullable=False, server_default="false") ) # Add is_generating_decision_summary to threads op.add_column( - 'threads', - sa.Column('is_generating_decision_summary', sa.Boolean(), nullable=False, server_default='false') + "threads", sa.Column("is_generating_decision_summary", sa.Boolean(), nullable=False, server_default="false") ) def downgrade() -> None: """Remove generation status flags from implementations and threads tables.""" - op.drop_column('threads', 'is_generating_decision_summary') - op.drop_column('implementations', 'is_generating_prompt_plan') - op.drop_column('implementations', 'is_generating_spec') + op.drop_column("threads", "is_generating_decision_summary") + op.drop_column("implementations", "is_generating_prompt_plan") + op.drop_column("implementations", "is_generating_spec") diff --git a/backend/alembic/versions/63256c2c0d52_add_threads_and_comments_tables.py b/backend/alembic/versions/63256c2c0d52_add_threads_and_comments_tables.py index 65912f5..0b403ae 100644 --- a/backend/alembic/versions/63256c2c0d52_add_threads_and_comments_tables.py +++ b/backend/alembic/versions/63256c2c0d52_add_threads_and_comments_tables.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 12:03:13.033739 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
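The generation-flag migration above relies on server_default to add NOT NULL booleans to already-populated tables: the default fills existing rows during the ALTER itself, so no separate backfill is needed. A minimal sketch (illustrative table name):

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # server_default="false" lets the column be NOT NULL from the start:
    # existing rows are filled as part of the ALTER TABLE.
    op.add_column(
        "documents",  # illustrative table
        sa.Column("is_processing", sa.Boolean(), nullable=False, server_default="false"),
    )


def downgrade() -> None:
    op.drop_column("documents", "is_processing")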
-revision: str = '63256c2c0d52' -down_revision: Union[str, Sequence[str], None] = '178e7cb83afe' +revision: str = "63256c2c0d52" +down_revision: Union[str, Sequence[str], None] = "178e7cb83afe" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,47 +23,42 @@ def upgrade() -> None: """Upgrade schema.""" # Create threads table op.create_table( - 'threads', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('context_type', sa.String(), nullable=False), - sa.Column('context_id', sa.String(), nullable=True), - sa.Column('title', sa.String(), nullable=True), - sa.Column('created_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.ForeignKeyConstraint(['created_by'], ['users.id']), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') - ) - op.create_index('ix_threads_project_id', 'threads', ['project_id']) - op.create_index( - 'ix_threads_context', - 'threads', - ['project_id', 'context_type', 'context_id'], - unique=False + "threads", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("context_type", sa.String(), nullable=False), + sa.Column("context_id", sa.String(), nullable=True), + sa.Column("title", sa.String(), nullable=True), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["created_by"], ["users.id"]), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) + op.create_index("ix_threads_project_id", "threads", ["project_id"]) + op.create_index("ix_threads_context", "threads", ["project_id", "context_type", "context_id"], unique=False) # Create comments table op.create_table( - 'comments', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('thread_id', sa.UUID(), nullable=False), - sa.Column('author_id', sa.UUID(), nullable=False), - sa.Column('body_markdown', sa.Text(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['author_id'], ['users.id']), - sa.ForeignKeyConstraint(['thread_id'], ['threads.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "comments", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("thread_id", sa.UUID(), nullable=False), + sa.Column("author_id", sa.UUID(), nullable=False), + sa.Column("body_markdown", sa.Text(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(["author_id"], ["users.id"]), + sa.ForeignKeyConstraint(["thread_id"], ["threads.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('ix_comments_thread_id', 'comments', ['thread_id']) + op.create_index("ix_comments_thread_id", "comments", ["thread_id"]) def downgrade() -> None: """Downgrade schema.""" - op.drop_index('ix_comments_thread_id', table_name='comments') - op.drop_table('comments') - op.drop_index('ix_threads_context', table_name='threads') - op.drop_index('ix_threads_project_id', 
table_name='threads') - op.drop_table('threads') + op.drop_index("ix_comments_thread_id", table_name="comments") + op.drop_table("comments") + op.drop_index("ix_threads_context", table_name="threads") + op.drop_index("ix_threads_project_id", table_name="threads") + op.drop_table("threads") diff --git a/backend/alembic/versions/6baa75dcb961_fix_module_feature_types_properly.py b/backend/alembic/versions/6baa75dcb961_fix_module_feature_types_properly.py index 0124b24..7e5dbf2 100644 --- a/backend/alembic/versions/6baa75dcb961_fix_module_feature_types_properly.py +++ b/backend/alembic/versions/6baa75dcb961_fix_module_feature_types_properly.py @@ -5,15 +5,14 @@ Create Date: 2025-12-05 07:30:43.900157 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = '6baa75dcb961' -down_revision: Union[str, Sequence[str], None] = '489421eb9675' +revision: str = "6baa75dcb961" +down_revision: Union[str, Sequence[str], None] = "489421eb9675" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/6ebda2d1112d_add_llm_preference_table.py b/backend/alembic/versions/6ebda2d1112d_add_llm_preference_table.py index c8ffc4b..34445a6 100644 --- a/backend/alembic/versions/6ebda2d1112d_add_llm_preference_table.py +++ b/backend/alembic/versions/6ebda2d1112d_add_llm_preference_table.py @@ -5,16 +5,17 @@ Create Date: 2025-11-20 19:06:43.742518 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. -revision: str = '6ebda2d1112d' -down_revision: Union[str, Sequence[str], None] = '83c377d42a07' +revision: str = "6ebda2d1112d" +down_revision: Union[str, Sequence[str], None] = "83c377d42a07" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,20 +23,20 @@ def upgrade() -> None: """Upgrade schema.""" op.create_table( - 'llm_preferences', - sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True), - sa.Column('organization_id', postgresql.UUID(as_uuid=True), nullable=False, index=True), - sa.Column('main_llm_config_id', postgresql.UUID(as_uuid=True), nullable=True), - sa.Column('lightweight_llm_config_id', postgresql.UUID(as_uuid=True), nullable=True), - sa.Column('created_at', postgresql.TIMESTAMP(timezone=True), nullable=False), - sa.Column('updated_at', postgresql.TIMESTAMP(timezone=True), nullable=False), - sa.ForeignKeyConstraint(['organization_id'], ['organizations.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['main_llm_config_id'], ['integration_configs.id'], ondelete='SET NULL'), - sa.ForeignKeyConstraint(['lightweight_llm_config_id'], ['integration_configs.id'], ondelete='SET NULL'), - sa.UniqueConstraint('organization_id', name='uq_org_llm_preference'), + "llm_preferences", + sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), + sa.Column("organization_id", postgresql.UUID(as_uuid=True), nullable=False, index=True), + sa.Column("main_llm_config_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("lightweight_llm_config_id", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"], 
ondelete="CASCADE"), + sa.ForeignKeyConstraint(["main_llm_config_id"], ["integration_configs.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint(["lightweight_llm_config_id"], ["integration_configs.id"], ondelete="SET NULL"), + sa.UniqueConstraint("organization_id", name="uq_org_llm_preference"), ) def downgrade() -> None: """Downgrade schema.""" - op.drop_table('llm_preferences') + op.drop_table("llm_preferences") diff --git a/backend/alembic/versions/77d468f92bb4_add_module_and_feature_type_columns.py b/backend/alembic/versions/77d468f92bb4_add_module_and_feature_type_columns.py index 622d1da..7f44d80 100644 --- a/backend/alembic/versions/77d468f92bb4_add_module_and_feature_type_columns.py +++ b/backend/alembic/versions/77d468f92bb4_add_module_and_feature_type_columns.py @@ -5,15 +5,16 @@ Create Date: 2025-12-04 18:33:23.985570 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '77d468f92bb4' -down_revision: Union[str, Sequence[str], None] = 'e4f5g6h7i8j9' +revision: str = "77d468f92bb4" +down_revision: Union[str, Sequence[str], None] = "e4f5g6h7i8j9" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,21 +22,15 @@ def upgrade() -> None: """Add module_type and feature_type columns with data migration.""" # Create the enum types - moduletype_enum = sa.Enum('conversation', 'implementation', name='moduletype') - featuretype_enum = sa.Enum('conversation', 'implementation', name='featuretype') + moduletype_enum = sa.Enum("conversation", "implementation", name="moduletype") + featuretype_enum = sa.Enum("conversation", "implementation", name="featuretype") moduletype_enum.create(op.get_bind(), checkfirst=True) featuretype_enum.create(op.get_bind(), checkfirst=True) # Add columns as nullable first - op.add_column( - 'modules', - sa.Column('module_type', moduletype_enum, nullable=True) - ) - op.add_column( - 'features', - sa.Column('feature_type', featuretype_enum, nullable=True) - ) + op.add_column("modules", sa.Column("module_type", moduletype_enum, nullable=True)) + op.add_column("features", sa.Column("feature_type", featuretype_enum, nullable=True)) # Data migration: Set feature_type based on Thread association # Features with BRAINSTORM_FEATURE threads are conversation questions @@ -69,24 +64,14 @@ def upgrade() -> None: """) # Make columns NOT NULL with defaults - op.alter_column( - 'modules', - 'module_type', - nullable=False, - server_default='implementation' - ) - op.alter_column( - 'features', - 'feature_type', - nullable=False, - server_default='implementation' - ) + op.alter_column("modules", "module_type", nullable=False, server_default="implementation") + op.alter_column("features", "feature_type", nullable=False, server_default="implementation") def downgrade() -> None: """Remove module_type and feature_type columns.""" - op.drop_column('features', 'feature_type') - op.drop_column('modules', 'module_type') + op.drop_column("features", "feature_type") + op.drop_column("modules", "module_type") # Drop the enum types op.execute("DROP TYPE IF EXISTS featuretype") diff --git a/backend/alembic/versions/8058002151ba_add_thread_items_and_followup_timestamps.py b/backend/alembic/versions/8058002151ba_add_thread_items_and_followup_timestamps.py index 02ec8bd..5ad70e6 100644 --- a/backend/alembic/versions/8058002151ba_add_thread_items_and_followup_timestamps.py +++ 
b/backend/alembic/versions/8058002151ba_add_thread_items_and_followup_timestamps.py @@ -5,16 +5,17 @@ Create Date: 2025-11-22 17:33:59.041505 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. -revision: str = '8058002151ba' -down_revision: Union[str, Sequence[str], None] = '17e2d7ed64e2' +revision: str = "8058002151ba" +down_revision: Union[str, Sequence[str], None] = "17e2d7ed64e2" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,25 +23,29 @@ def upgrade() -> None: """Upgrade schema.""" # Add timestamp columns to threads table - op.add_column('threads', sa.Column('last_followup_check_at', sa.DateTime(timezone=True), nullable=True)) - op.add_column('threads', sa.Column('last_user_comment_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column("threads", sa.Column("last_followup_check_at", sa.DateTime(timezone=True), nullable=True)) + op.add_column("threads", sa.Column("last_user_comment_at", sa.DateTime(timezone=True), nullable=True)) # Create thread_items table op.create_table( - 'thread_items', - sa.Column('id', postgresql.UUID(), nullable=False), - sa.Column('thread_id', postgresql.UUID(), nullable=False), - sa.Column('item_type', sa.Enum('comment', 'mcq_followup', 'no_followup_message', name='threaditemtype'), nullable=False), - sa.Column('author_id', postgresql.UUID(), nullable=False), - sa.Column('content_data', sa.JSON(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['author_id'], ['users.id']), - sa.ForeignKeyConstraint(['thread_id'], ['threads.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + "thread_items", + sa.Column("id", postgresql.UUID(), nullable=False), + sa.Column("thread_id", postgresql.UUID(), nullable=False), + sa.Column( + "item_type", + sa.Enum("comment", "mcq_followup", "no_followup_message", name="threaditemtype"), + nullable=False, + ), + sa.Column("author_id", postgresql.UUID(), nullable=False), + sa.Column("content_data", sa.JSON(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(["author_id"], ["users.id"]), + sa.ForeignKeyConstraint(["thread_id"], ["threads.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), ) - op.create_index('ix_thread_items_thread_id', 'thread_items', ['thread_id']) - op.create_index('ix_thread_items_item_type', 'thread_items', ['item_type']) + op.create_index("ix_thread_items_thread_id", "thread_items", ["thread_id"]) + op.create_index("ix_thread_items_item_type", "thread_items", ["item_type"]) # Migrate existing comments to thread_items op.execute(""" @@ -71,13 +76,13 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade schema.""" # Drop thread_items table and indexes - op.drop_index('ix_thread_items_item_type', table_name='thread_items') - op.drop_index('ix_thread_items_thread_id', table_name='thread_items') - op.drop_table('thread_items') + op.drop_index("ix_thread_items_item_type", table_name="thread_items") + op.drop_index("ix_thread_items_thread_id", table_name="thread_items") + op.drop_table("thread_items") # Drop enum type op.execute("DROP TYPE threaditemtype") # Drop timestamp columns from threads - 
op.drop_column('threads', 'last_user_comment_at') - op.drop_column('threads', 'last_followup_check_at') + op.drop_column("threads", "last_user_comment_at") + op.drop_column("threads", "last_followup_check_at") diff --git a/backend/alembic/versions/83c377d42a07_add_display_name_to_integration_config.py b/backend/alembic/versions/83c377d42a07_add_display_name_to_integration_config.py index d9f266a..afed7e8 100644 --- a/backend/alembic/versions/83c377d42a07_add_display_name_to_integration_config.py +++ b/backend/alembic/versions/83c377d42a07_add_display_name_to_integration_config.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 19:05:19.168793 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '83c377d42a07' -down_revision: Union[str, Sequence[str], None] = 'a04c6c0f1117' +revision: str = "83c377d42a07" +down_revision: Union[str, Sequence[str], None] = "a04c6c0f1117" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,30 +22,30 @@ def upgrade() -> None: """Upgrade schema.""" # Add display_name column (nullable initially to allow backfill) - op.add_column('integration_configs', sa.Column('display_name', sa.String(length=100), nullable=True)) + op.add_column("integration_configs", sa.Column("display_name", sa.String(length=100), nullable=True)) # Backfill display_name with provider name for existing rows op.execute("UPDATE integration_configs SET display_name = provider WHERE display_name IS NULL") # Make display_name non-nullable - op.alter_column('integration_configs', 'display_name', nullable=False) + op.alter_column("integration_configs", "display_name", nullable=False) # Drop old unique constraint - op.drop_constraint('uq_org_provider', 'integration_configs', type_='unique') + op.drop_constraint("uq_org_provider", "integration_configs", type_="unique") # Add new unique constraint with display_name - op.create_unique_constraint('uq_org_provider_name', 'integration_configs', - ['organization_id', 'provider', 'display_name']) + op.create_unique_constraint( + "uq_org_provider_name", "integration_configs", ["organization_id", "provider", "display_name"] + ) def downgrade() -> None: """Downgrade schema.""" # Drop new unique constraint - op.drop_constraint('uq_org_provider_name', 'integration_configs', type_='unique') + op.drop_constraint("uq_org_provider_name", "integration_configs", type_="unique") # Recreate old unique constraint - op.create_unique_constraint('uq_org_provider', 'integration_configs', - ['organization_id', 'provider']) + op.create_unique_constraint("uq_org_provider", "integration_configs", ["organization_id", "provider"]) # Drop display_name column - op.drop_column('integration_configs', 'display_name') + op.drop_column("integration_configs", "display_name") diff --git a/backend/alembic/versions/840545b82f16_add_deleted_at_and_key_constraints_to_.py b/backend/alembic/versions/840545b82f16_add_deleted_at_and_key_constraints_to_.py index 519590e..6005a1f 100644 --- a/backend/alembic/versions/840545b82f16_add_deleted_at_and_key_constraints_to_.py +++ b/backend/alembic/versions/840545b82f16_add_deleted_at_and_key_constraints_to_.py @@ -5,15 +5,16 @@ Create Date: 2025-11-22 07:09:28.843970 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
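The display_name migration above is the classic three-step tighten: add the column as nullable, backfill with an UPDATE, then flip it to NOT NULL once every row has a value. The same sequence in outline, under illustrative names:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Illustrative table and columns, not from this repo.
    # 1. Nullable first, so the ALTER succeeds on populated tables.
    op.add_column("accounts", sa.Column("label", sa.String(length=100), nullable=True))
    # 2. Backfill every existing row.
    op.execute("UPDATE accounts SET label = name WHERE label IS NULL")
    # 3. Only now is NOT NULL safe.
    op.alter_column("accounts", "label", nullable=False)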
-revision: str = '840545b82f16' -down_revision: Union[str, Sequence[str], None] = '1cfb3bffcc2a' +revision: str = "840545b82f16" +down_revision: Union[str, Sequence[str], None] = "1cfb3bffcc2a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,7 +22,7 @@ def upgrade() -> None: """Upgrade schema.""" # Add deleted_at column - op.add_column('projects', sa.Column('deleted_at', sa.TIMESTAMP(timezone=True), nullable=True)) + op.add_column("projects", sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True)) # Backfill NULL keys with auto-generated values # Format: PROJ-{first 8 chars of UUID} @@ -32,25 +33,25 @@ def upgrade() -> None: """) # Make key non-nullable now that all rows have values - op.alter_column('projects', 'key', nullable=False) + op.alter_column("projects", "key", nullable=False) # Add unique constraint on (org_id, key) to ensure uniqueness within org - op.create_unique_constraint('uq_projects_org_key', 'projects', ['org_id', 'key']) + op.create_unique_constraint("uq_projects_org_key", "projects", ["org_id", "key"]) # Add index on deleted_at for efficient filtering - op.create_index('ix_projects_deleted_at', 'projects', ['deleted_at']) + op.create_index("ix_projects_deleted_at", "projects", ["deleted_at"]) def downgrade() -> None: """Downgrade schema.""" # Drop index on deleted_at - op.drop_index('ix_projects_deleted_at', 'projects') + op.drop_index("ix_projects_deleted_at", "projects") # Drop unique constraint - op.drop_constraint('uq_projects_org_key', 'projects', type_='unique') + op.drop_constraint("uq_projects_org_key", "projects", type_="unique") # Make key nullable again - op.alter_column('projects', 'key', nullable=True) + op.alter_column("projects", "key", nullable=True) # Drop deleted_at column - op.drop_column('projects', 'deleted_at') + op.drop_column("projects", "deleted_at") diff --git a/backend/alembic/versions/8bba664c9d7b_set_freemium_max_users_default.py b/backend/alembic/versions/8bba664c9d7b_set_freemium_max_users_default.py index 9c6f051..b480c77 100644 --- a/backend/alembic/versions/8bba664c9d7b_set_freemium_max_users_default.py +++ b/backend/alembic/versions/8bba664c9d7b_set_freemium_max_users_default.py @@ -5,15 +5,14 @@ Create Date: 2025-12-19 18:33:47.267904 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = '8bba664c9d7b' -down_revision: Union[str, Sequence[str], None] = 'ie4f5g6h7i8j' +revision: str = "8bba664c9d7b" +down_revision: Union[str, Sequence[str], None] = "ie4f5g6h7i8j" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/8e970e133ccd_add_spec_versions_table.py b/backend/alembic/versions/8e970e133ccd_add_spec_versions_table.py index 8733cfe..4396675 100644 --- a/backend/alembic/versions/8e970e133ccd_add_spec_versions_table.py +++ b/backend/alembic/versions/8e970e133ccd_add_spec_versions_table.py @@ -5,15 +5,14 @@ Create Date: 2025-11-20 11:27:52.760625 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. 
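The deleted_at change above is a standard soft-delete setup: a nullable timestamp plus an index so "deleted_at IS NULL" filters stay cheap. Roughly, with illustrative names:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Illustrative table. NULL means "live"; a timestamp records deletion.
    op.add_column("documents", sa.Column("deleted_at", sa.TIMESTAMP(timezone=True), nullable=True))
    # Index it so list views filtering on deleted_at stay fast.
    op.create_index("ix_documents_deleted_at", "documents", ["deleted_at"])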
-revision: str = '8e970e133ccd' -down_revision: Union[str, Sequence[str], None] = '58b5bf6d73fc' +revision: str = "8e970e133ccd" +down_revision: Union[str, Sequence[str], None] = "58b5bf6d73fc" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -43,18 +42,18 @@ def upgrade() -> None: """) # Create indexes - op.create_index('ix_spec_versions_project_id', 'spec_versions', ['project_id']) - op.create_index('ix_spec_versions_active', 'spec_versions', ['project_id', 'spec_type', 'is_active']) + op.create_index("ix_spec_versions_project_id", "spec_versions", ["project_id"]) + op.create_index("ix_spec_versions_active", "spec_versions", ["project_id", "spec_type", "is_active"]) def downgrade() -> None: """Downgrade schema.""" # Drop indexes - op.drop_index('ix_spec_versions_active', table_name='spec_versions') - op.drop_index('ix_spec_versions_project_id', table_name='spec_versions') + op.drop_index("ix_spec_versions_active", table_name="spec_versions") + op.drop_index("ix_spec_versions_project_id", table_name="spec_versions") # Drop table - op.drop_table('spec_versions') + op.drop_table("spec_versions") # Drop ENUM op.execute("DROP TYPE spec_type") diff --git a/backend/alembic/versions/8eebe3a100d3_add_mcp_oauth_tables.py b/backend/alembic/versions/8eebe3a100d3_add_mcp_oauth_tables.py index bd4726b..4e486a4 100644 --- a/backend/alembic/versions/8eebe3a100d3_add_mcp_oauth_tables.py +++ b/backend/alembic/versions/8eebe3a100d3_add_mcp_oauth_tables.py @@ -5,15 +5,16 @@ Create Date: 2026-01-09 07:46:45.073505 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '8eebe3a100d3' -down_revision: Union[str, Sequence[str], None] = '58b3923aa347' +revision: str = "8eebe3a100d3" +down_revision: Union[str, Sequence[str], None] = "58b3923aa347" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/9b373c88f9ac_add_email_verification_fields.py b/backend/alembic/versions/9b373c88f9ac_add_email_verification_fields.py index c5c70d9..88f7f72 100644 --- a/backend/alembic/versions/9b373c88f9ac_add_email_verification_fields.py +++ b/backend/alembic/versions/9b373c88f9ac_add_email_verification_fields.py @@ -5,15 +5,16 @@ Create Date: 2025-12-16 16:30:15.630889 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = '9b373c88f9ac' -down_revision: Union[str, Sequence[str], None] = 'b7c8d9e0f1a2' +revision: str = "9b373c88f9ac" +down_revision: Union[str, Sequence[str], None] = "b7c8d9e0f1a2" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/a049ffa5b22b_add_suggested_implementation_name_to_.py b/backend/alembic/versions/a049ffa5b22b_add_suggested_implementation_name_to_.py index 257399a..b78c2a8 100644 --- a/backend/alembic/versions/a049ffa5b22b_add_suggested_implementation_name_to_.py +++ b/backend/alembic/versions/a049ffa5b22b_add_suggested_implementation_name_to_.py @@ -5,15 +5,16 @@ Create Date: 2026-01-03 12:32:14.162163 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
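The spec_versions migration in this hunk mixes raw SQL DDL (the CREATE TABLE sits above the diff context) with op helpers for the indexes; the two styles combine freely in one revision. A minimal sketch with an illustrative table:

from alembic import op


def upgrade() -> None:
    # Illustrative names. Raw SQL for the table, op helpers for the indexes;
    # both run in the same migration transaction on PostgreSQL.
    op.execute("""
        CREATE TABLE snapshots (
            id UUID PRIMARY KEY,
            project_id UUID NOT NULL REFERENCES projects (id)
        )
    """)
    op.create_index("ix_snapshots_project_id", "snapshots", ["project_id"])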
-revision: str = 'a049ffa5b22b' -down_revision: Union[str, Sequence[str], None] = 'impl02' +revision: str = "a049ffa5b22b" +down_revision: Union[str, Sequence[str], None] = "impl02" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/a04c6c0f1117_add_implementation_notes_tables.py b/backend/alembic/versions/a04c6c0f1117_add_implementation_notes_tables.py index c415ca8..0c5fa56 100644 --- a/backend/alembic/versions/a04c6c0f1117_add_implementation_notes_tables.py +++ b/backend/alembic/versions/a04c6c0f1117_add_implementation_notes_tables.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 14:59:14.842098 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'a04c6c0f1117' -down_revision: Union[str, Sequence[str], None] = '264f9632081a' +revision: str = "a04c6c0f1117" +down_revision: Union[str, Sequence[str], None] = "264f9632081a" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -22,66 +23,66 @@ def upgrade() -> None: """Upgrade schema.""" # Create project_implementation_notes table op.create_table( - 'project_implementation_notes', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('title', sa.Text(), nullable=False), - sa.Column('content_markdown', sa.Text(), nullable=False), - sa.Column('created_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('promoted_from_phase_id', sa.UUID(), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=False, server_default=sa.true()), - sa.PrimaryKeyConstraint('id'), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['created_by'], ['users.id']), - sa.ForeignKeyConstraint(['promoted_from_phase_id'], ['implementation_phases.id'], ondelete='SET NULL'), + "project_implementation_notes", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("title", sa.Text(), nullable=False), + sa.Column("content_markdown", sa.Text(), nullable=False), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("promoted_from_phase_id", sa.UUID(), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.true()), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["created_by"], ["users.id"]), + sa.ForeignKeyConstraint(["promoted_from_phase_id"], ["implementation_phases.id"], ondelete="SET NULL"), ) - op.create_index('ix_project_implementation_notes_project_id', 'project_implementation_notes', ['project_id']) + op.create_index("ix_project_implementation_notes_project_id", "project_implementation_notes", ["project_id"]) # Create phase_notes table op.create_table( - 'phase_notes', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('project_id', sa.UUID(), nullable=False), - sa.Column('implementation_phase_id', sa.UUID(), nullable=False), - sa.Column('content_markdown', sa.Text(), nullable=False), - sa.Column('created_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), 
server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['implementation_phase_id'], ['implementation_phases.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['created_by'], ['users.id']), + "phase_notes", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("project_id", sa.UUID(), nullable=False), + sa.Column("implementation_phase_id", sa.UUID(), nullable=False), + sa.Column("content_markdown", sa.Text(), nullable=False), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["implementation_phase_id"], ["implementation_phases.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["created_by"], ["users.id"]), ) - op.create_index('ix_phase_notes_project_id', 'phase_notes', ['project_id']) - op.create_index('ix_phase_notes_implementation_phase_id', 'phase_notes', ['implementation_phase_id']) + op.create_index("ix_phase_notes_project_id", "phase_notes", ["project_id"]) + op.create_index("ix_phase_notes_implementation_phase_id", "phase_notes", ["implementation_phase_id"]) # Create thread_notes table op.create_table( - 'thread_notes', - sa.Column('id', sa.UUID(), nullable=False), - sa.Column('thread_id', sa.UUID(), nullable=False), - sa.Column('content_markdown', sa.Text(), nullable=False), - sa.Column('created_by', sa.UUID(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.PrimaryKeyConstraint('id'), - sa.ForeignKeyConstraint(['thread_id'], ['threads.id'], ondelete='CASCADE'), - sa.ForeignKeyConstraint(['created_by'], ['users.id']), + "thread_notes", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("thread_id", sa.UUID(), nullable=False), + sa.Column("content_markdown", sa.Text(), nullable=False), + sa.Column("created_by", sa.UUID(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["thread_id"], ["threads.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["created_by"], ["users.id"]), ) - op.create_index('ix_thread_notes_thread_id', 'thread_notes', ['thread_id']) + op.create_index("ix_thread_notes_thread_id", "thread_notes", ["thread_id"]) def downgrade() -> None: """Downgrade schema.""" # Drop thread_notes table - op.drop_index('ix_thread_notes_thread_id', 'thread_notes') - op.drop_table('thread_notes') + op.drop_index("ix_thread_notes_thread_id", "thread_notes") + op.drop_table("thread_notes") # Drop phase_notes table - op.drop_index('ix_phase_notes_implementation_phase_id', 'phase_notes') - op.drop_index('ix_phase_notes_project_id', 'phase_notes') - op.drop_table('phase_notes') + op.drop_index("ix_phase_notes_implementation_phase_id", "phase_notes") + op.drop_index("ix_phase_notes_project_id", "phase_notes") + op.drop_table("phase_notes") # Drop project_implementation_notes table - op.drop_index('ix_project_implementation_notes_project_id', 'project_implementation_notes') - 
op.drop_table('project_implementation_notes') + op.drop_index("ix_project_implementation_notes_project_id", "project_implementation_notes") + op.drop_table("project_implementation_notes") diff --git a/backend/alembic/versions/a1b2c3d4e5f6_add_brainstorming_phases_modules_features.py b/backend/alembic/versions/a1b2c3d4e5f6_add_brainstorming_phases_modules_features.py index 10ce70e..891d152 100644 --- a/backend/alembic/versions/a1b2c3d4e5f6_add_brainstorming_phases_modules_features.py +++ b/backend/alembic/versions/a1b2c3d4e5f6_add_brainstorming_phases_modules_features.py @@ -5,15 +5,14 @@ Create Date: 2025-12-03 10:00:00.000000 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. -revision: str = 'a1b2c3d4e5f6' -down_revision: Union[str, Sequence[str], None] = 'd5f8b2c9e1a3' +revision: str = "a1b2c3d4e5f6" +down_revision: Union[str, Sequence[str], None] = "d5f8b2c9e1a3" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -42,7 +41,7 @@ def upgrade() -> None: FOREIGN KEY (created_by) REFERENCES users (id) ) """) - op.create_index('ix_brainstorming_phases_project_id', 'brainstorming_phases', ['project_id']) + op.create_index("ix_brainstorming_phases_project_id", "brainstorming_phases", ["project_id"]) # Create modules table op.execute(""" @@ -63,8 +62,8 @@ def upgrade() -> None: FOREIGN KEY (created_by) REFERENCES users (id) ) """) - op.create_index('ix_modules_project_id', 'modules', ['project_id']) - op.create_index('ix_modules_brainstorming_phase_id', 'modules', ['brainstorming_phase_id']) + op.create_index("ix_modules_project_id", "modules", ["project_id"]) + op.create_index("ix_modules_brainstorming_phase_id", "modules", ["brainstorming_phase_id"]) # Create features table op.execute(""" @@ -86,23 +85,23 @@ def upgrade() -> None: FOREIGN KEY (created_by) REFERENCES users (id) ) """) - op.create_index('ix_features_module_id', 'features', ['module_id']) - op.create_index('ix_features_feature_key', 'features', ['feature_key']) + op.create_index("ix_features_module_id", "features", ["module_id"]) + op.create_index("ix_features_feature_key", "features", ["feature_key"]) def downgrade() -> None: """Drop brainstorming_phases, modules, and features tables.""" # Drop indexes - op.drop_index('ix_features_feature_key', table_name='features') - op.drop_index('ix_features_module_id', table_name='features') - op.drop_index('ix_modules_brainstorming_phase_id', table_name='modules') - op.drop_index('ix_modules_project_id', table_name='modules') - op.drop_index('ix_brainstorming_phases_project_id', table_name='brainstorming_phases') + op.drop_index("ix_features_feature_key", table_name="features") + op.drop_index("ix_features_module_id", table_name="features") + op.drop_index("ix_modules_brainstorming_phase_id", table_name="modules") + op.drop_index("ix_modules_project_id", table_name="modules") + op.drop_index("ix_brainstorming_phases_project_id", table_name="brainstorming_phases") # Drop tables (in correct order due to foreign key dependencies) - op.drop_table('features') - op.drop_table('modules') - op.drop_table('brainstorming_phases') + op.drop_table("features") + op.drop_table("modules") + op.drop_table("brainstorming_phases") # Drop ENUMs op.execute("DROP TYPE feature_status") diff --git a/backend/alembic/versions/a1b2c3d4e5f7_add_module_key.py b/backend/alembic/versions/a1b2c3d4e5f7_add_module_key.py index a3d90e6..c9a32c7 100644 --- 
a/backend/alembic/versions/a1b2c3d4e5f7_add_module_key.py +++ b/backend/alembic/versions/a1b2c3d4e5f7_add_module_key.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy import text +from alembic import op # revision identifiers, used by Alembic. revision: str = "a1b2c3d4e5f7" @@ -39,7 +39,8 @@ def upgrade() -> None: if dialect == "postgresql": # PostgreSQL version with LPAD - conn.execute(text(""" + conn.execute( + text(""" WITH numbered_modules AS ( SELECT m.id, @@ -57,10 +58,12 @@ def upgrade() -> None: module_key = 'M' || nm.project_key || '-' || LPAD(CAST(nm.row_num AS TEXT), 3, '0') FROM numbered_modules nm WHERE modules.id = nm.id - """)) + """) + ) else: # SQLite version with printf - conn.execute(text(""" + conn.execute( + text(""" WITH numbered_modules AS ( SELECT m.id, @@ -76,7 +79,8 @@ def upgrade() -> None: SET module_key_number = (SELECT nm.row_num FROM numbered_modules nm WHERE nm.id = modules.id), module_key = 'M' || (SELECT nm.project_key FROM numbered_modules nm WHERE nm.id = modules.id) || '-' || printf('%03d', (SELECT nm.row_num FROM numbered_modules nm WHERE nm.id = modules.id)) - """)) + """) + ) # Make columns non-nullable op.alter_column("modules", "module_key", nullable=False) diff --git a/backend/alembic/versions/a60117da6409_add_projects_table.py b/backend/alembic/versions/a60117da6409_add_projects_table.py index c224c9c..1045c4d 100644 --- a/backend/alembic/versions/a60117da6409_add_projects_table.py +++ b/backend/alembic/versions/a60117da6409_add_projects_table.py @@ -5,15 +5,16 @@ Create Date: 2025-11-20 10:10:50.652293 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
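The module_key migration above branches on the live dialect because PostgreSQL and SQLite spell zero-padding differently (LPAD vs printf). A minimal sketch of that branch, with illustrative table and column names:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    conn = op.get_bind()
    # Same backfill, two SQL spellings; "items", "code", "n" are illustrative.
    if conn.dialect.name == "postgresql":
        conn.execute(sa.text("UPDATE items SET code = LPAD(CAST(n AS TEXT), 3, '0')"))
    else:
        conn.execute(sa.text("UPDATE items SET code = printf('%03d', n)"))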
-revision: str = 'a60117da6409'
-down_revision: Union[str, Sequence[str], None] = '4631a4a3270c'
+revision: str = "a60117da6409"
+down_revision: Union[str, Sequence[str], None] = "4631a4a3270c"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,41 +23,41 @@ def upgrade() -> None:
     """Upgrade schema."""
     # Create projects table
     op.create_table(
-        'projects',
-        sa.Column('id', sa.UUID(), nullable=False),
-        sa.Column('org_id', sa.UUID(), nullable=False),
-        sa.Column('parent_project_id', sa.UUID(), nullable=True),
-        sa.Column('type', sa.String(20), nullable=False),
-        sa.Column('key', sa.String(100), nullable=True),
-        sa.Column('name', sa.String(255), nullable=False),
-        sa.Column('short_description', sa.Text(), nullable=True),
-        sa.Column('idea_text', sa.Text(), nullable=True),
-        sa.Column('status', sa.String(30), nullable=False, server_default='draft'),
-        sa.Column('external_ticket_id', sa.String(100), nullable=True),
-        sa.Column('external_system', sa.String(50), nullable=True),
-        sa.Column('created_by', sa.UUID(), nullable=False),
-        sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
-        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
-        sa.PrimaryKeyConstraint('id'),
-        sa.ForeignKeyConstraint(['org_id'], ['organizations.id'], ondelete='CASCADE'),
-        sa.ForeignKeyConstraint(['parent_project_id'], ['projects.id'], ondelete='CASCADE'),
-        sa.ForeignKeyConstraint(['created_by'], ['users.id']),
+        "projects",
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("org_id", sa.UUID(), nullable=False),
+        sa.Column("parent_project_id", sa.UUID(), nullable=True),
+        sa.Column("type", sa.String(20), nullable=False),
+        sa.Column("key", sa.String(100), nullable=True),
+        sa.Column("name", sa.String(255), nullable=False),
+        sa.Column("short_description", sa.Text(), nullable=True),
+        sa.Column("idea_text", sa.Text(), nullable=True),
+        sa.Column("status", sa.String(30), nullable=False, server_default="draft"),
+        sa.Column("external_ticket_id", sa.String(100), nullable=True),
+        sa.Column("external_system", sa.String(50), nullable=True),
+        sa.Column("created_by", sa.UUID(), nullable=False),
+        sa.Column("created_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("updated_at", sa.TIMESTAMP(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.PrimaryKeyConstraint("id"),
+        sa.ForeignKeyConstraint(["org_id"], ["organizations.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["parent_project_id"], ["projects.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["created_by"], ["users.id"]),
     )

     # Create indexes
-    op.create_index('ix_projects_org_id', 'projects', ['org_id'])
-    op.create_index('ix_projects_type', 'projects', ['type'])
-    op.create_index('ix_projects_status', 'projects', ['status'])
-    op.create_index('ix_projects_parent_project_id', 'projects', ['parent_project_id'])
+    op.create_index("ix_projects_org_id", "projects", ["org_id"])
+    op.create_index("ix_projects_type", "projects", ["type"])
+    op.create_index("ix_projects_status", "projects", ["status"])
+    op.create_index("ix_projects_parent_project_id", "projects", ["parent_project_id"])


 def downgrade() -> None:
     """Downgrade schema."""
     # Drop indexes
-    op.drop_index('ix_projects_parent_project_id', 'projects')
-    op.drop_index('ix_projects_status', 'projects')
-    op.drop_index('ix_projects_type', 'projects')
-    op.drop_index('ix_projects_org_id', 'projects')
+    op.drop_index("ix_projects_parent_project_id", "projects")
+    op.drop_index("ix_projects_status", "projects")
+    op.drop_index("ix_projects_type", "projects")
+    op.drop_index("ix_projects_org_id", "projects")

     # Drop table
-    op.drop_table('projects')
+    op.drop_table("projects")
diff --git a/backend/alembic/versions/a6b7c8d9e0f1_user_level_api_keys.py b/backend/alembic/versions/a6b7c8d9e0f1_user_level_api_keys.py
index 4b46277..1fd23f2 100644
--- a/backend/alembic/versions/a6b7c8d9e0f1_user_level_api_keys.py
+++ b/backend/alembic/versions/a6b7c8d9e0f1_user_level_api_keys.py
@@ -7,9 +7,9 @@
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "a6b7c8d9e0f1"
diff --git a/backend/alembic/versions/a8845f795cf1_add_jobs_table.py b/backend/alembic/versions/a8845f795cf1_add_jobs_table.py
index 002d3a5..fb4d5f3 100644
--- a/backend/alembic/versions/a8845f795cf1_add_jobs_table.py
+++ b/backend/alembic/versions/a8845f795cf1_add_jobs_table.py
@@ -5,15 +5,16 @@
 Create Date: 2025-11-20 08:11:53.926713

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'a8845f795cf1'
-down_revision: Union[str, Sequence[str], None] = 'f9a38e4b733d'
+revision: str = "a8845f795cf1"
+down_revision: Union[str, Sequence[str], None] = "f9a38e4b733d"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,40 +23,40 @@ def upgrade() -> None:
     """Upgrade schema."""
     # Create jobs table
     op.create_table(
-        'jobs',
-        sa.Column('id', sa.UUID(), nullable=False),
-        sa.Column('org_id', sa.UUID(), nullable=True),
-        sa.Column('project_id', sa.UUID(), nullable=True),
-        sa.Column('job_type', sa.String(length=50), nullable=False),
-        sa.Column('status', sa.String(length=20), nullable=False),
-        sa.Column('payload', sa.JSON, nullable=False),
-        sa.Column('result', sa.JSON, nullable=True),
-        sa.Column('error_message', sa.Text(), nullable=True),
-        sa.Column('attempts', sa.Integer(), nullable=False, server_default='0'),
-        sa.Column('max_attempts', sa.Integer(), nullable=False, server_default='3'),
-        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
-        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
-        sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
-        sa.Column('finished_at', sa.DateTime(timezone=True), nullable=True),
-        sa.PrimaryKeyConstraint('id', name=op.f('pk_jobs'))
+        "jobs",
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("org_id", sa.UUID(), nullable=True),
+        sa.Column("project_id", sa.UUID(), nullable=True),
+        sa.Column("job_type", sa.String(length=50), nullable=False),
+        sa.Column("status", sa.String(length=20), nullable=False),
+        sa.Column("payload", sa.JSON, nullable=False),
+        sa.Column("result", sa.JSON, nullable=True),
+        sa.Column("error_message", sa.Text(), nullable=True),
+        sa.Column("attempts", sa.Integer(), nullable=False, server_default="0"),
+        sa.Column("max_attempts", sa.Integer(), nullable=False, server_default="3"),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True),
+        sa.PrimaryKeyConstraint("id", name=op.f("pk_jobs")),
     )

     # Create indexes
-    op.create_index(op.f('ix_jobs_org_id'), 'jobs', ['org_id'], unique=False)
-    op.create_index(op.f('ix_jobs_project_id'), 'jobs', ['project_id'], unique=False)
-    op.create_index(op.f('ix_jobs_job_type'), 'jobs', ['job_type'], unique=False)
-    op.create_index(op.f('ix_jobs_status'), 'jobs', ['status'], unique=False)
-    op.create_index(op.f('ix_jobs_created_at'), 'jobs', ['created_at'], unique=False)
+    op.create_index(op.f("ix_jobs_org_id"), "jobs", ["org_id"], unique=False)
+    op.create_index(op.f("ix_jobs_project_id"), "jobs", ["project_id"], unique=False)
+    op.create_index(op.f("ix_jobs_job_type"), "jobs", ["job_type"], unique=False)
+    op.create_index(op.f("ix_jobs_status"), "jobs", ["status"], unique=False)
+    op.create_index(op.f("ix_jobs_created_at"), "jobs", ["created_at"], unique=False)


 def downgrade() -> None:
     """Downgrade schema."""
     # Drop indexes
-    op.drop_index(op.f('ix_jobs_created_at'), table_name='jobs')
-    op.drop_index(op.f('ix_jobs_status'), table_name='jobs')
-    op.drop_index(op.f('ix_jobs_job_type'), table_name='jobs')
-    op.drop_index(op.f('ix_jobs_project_id'), table_name='jobs')
-    op.drop_index(op.f('ix_jobs_org_id'), table_name='jobs')
+    op.drop_index(op.f("ix_jobs_created_at"), table_name="jobs")
+    op.drop_index(op.f("ix_jobs_status"), table_name="jobs")
+    op.drop_index(op.f("ix_jobs_job_type"), table_name="jobs")
+    op.drop_index(op.f("ix_jobs_project_id"), table_name="jobs")
+    op.drop_index(op.f("ix_jobs_org_id"), table_name="jobs")

     # Drop table
-    op.drop_table('jobs')
+    op.drop_table("jobs")
diff --git a/backend/alembic/versions/b2c3d4e5f6g7_add_draft_final_version_tables.py b/backend/alembic/versions/b2c3d4e5f6g7_add_draft_final_version_tables.py
index 3535452..1805e22 100644
--- a/backend/alembic/versions/b2c3d4e5f6g7_add_draft_final_version_tables.py
+++ b/backend/alembic/versions/b2c3d4e5f6g7_add_draft_final_version_tables.py
@@ -5,15 +5,16 @@
 Create Date: 2025-12-03 11:00:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'b2c3d4e5f6g7'
-down_revision: Union[str, Sequence[str], None] = 'a1b2c3d4e5f6'
+revision: str = "b2c3d4e5f6g7"
+down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -21,24 +22,24 @@ def upgrade() -> None:
     """Add brainstorming_phase_id and blocks to spec_versions, create final tables."""
     # Add columns to spec_versions
-    op.add_column('spec_versions', sa.Column('brainstorming_phase_id', sa.UUID(), nullable=True))
-    op.add_column('spec_versions', sa.Column('blocks', sa.JSON(), nullable=True))
+    op.add_column("spec_versions", sa.Column("brainstorming_phase_id", sa.UUID(), nullable=True))
+    op.add_column("spec_versions", sa.Column("blocks", sa.JSON(), nullable=True))

     # Create foreign key constraint
     op.create_foreign_key(
-        'fk_spec_versions_brainstorming_phase_id',
-        'spec_versions',
-        'brainstorming_phases',
-        ['brainstorming_phase_id'],
-        ['id'],
-        ondelete='CASCADE'
+        "fk_spec_versions_brainstorming_phase_id",
+        "spec_versions",
+        "brainstorming_phases",
+        ["brainstorming_phase_id"],
+        ["id"],
+        ondelete="CASCADE",
     )

     # Create index on brainstorming_phase_id
-    op.create_index('ix_spec_versions_brainstorming_phase_id', 'spec_versions', ['brainstorming_phase_id'])
+    op.create_index("ix_spec_versions_brainstorming_phase_id", "spec_versions", ["brainstorming_phase_id"])

     # Make project_id nullable (for brainstorming phase-based specs)
-    op.alter_column('spec_versions', 'project_id', nullable=True)
+    op.alter_column("spec_versions", "project_id", nullable=True)

     # Create final_specs table
     op.execute("""
@@ -56,7 +57,7 @@ def upgrade() -> None:
             FOREIGN KEY (created_by) REFERENCES users (id)
         )
     """)
-    op.create_index('ix_final_specs_brainstorming_phase_id', 'final_specs', ['brainstorming_phase_id'])
+    op.create_index("ix_final_specs_brainstorming_phase_id", "final_specs", ["brainstorming_phase_id"])

     # Create final_prompt_plans table
     op.execute("""
@@ -74,23 +75,23 @@ def upgrade() -> None:
             FOREIGN KEY (created_by) REFERENCES users (id)
         )
     """)
-    op.create_index('ix_final_prompt_plans_brainstorming_phase_id', 'final_prompt_plans', ['brainstorming_phase_id'])
+    op.create_index("ix_final_prompt_plans_brainstorming_phase_id", "final_prompt_plans", ["brainstorming_phase_id"])


 def downgrade() -> None:
     """Remove final tables and spec_versions modifications."""
     # Drop final tables
-    op.drop_index('ix_final_prompt_plans_brainstorming_phase_id', table_name='final_prompt_plans')
-    op.drop_table('final_prompt_plans')
+    op.drop_index("ix_final_prompt_plans_brainstorming_phase_id", table_name="final_prompt_plans")
+    op.drop_table("final_prompt_plans")

-    op.drop_index('ix_final_specs_brainstorming_phase_id', table_name='final_specs')
-    op.drop_table('final_specs')
+    op.drop_index("ix_final_specs_brainstorming_phase_id", table_name="final_specs")
+    op.drop_table("final_specs")

     # Make project_id not nullable again
-    op.alter_column('spec_versions', 'project_id', nullable=False)
+    op.alter_column("spec_versions", "project_id", nullable=False)

     # Drop new columns from spec_versions
-    op.drop_index('ix_spec_versions_brainstorming_phase_id', table_name='spec_versions')
-    op.drop_constraint('fk_spec_versions_brainstorming_phase_id', 'spec_versions', type_='foreignkey')
-    op.drop_column('spec_versions', 'blocks')
-    op.drop_column('spec_versions', 'brainstorming_phase_id')
+    op.drop_index("ix_spec_versions_brainstorming_phase_id", table_name="spec_versions")
+    op.drop_constraint("fk_spec_versions_brainstorming_phase_id", "spec_versions", type_="foreignkey")
+    op.drop_column("spec_versions", "blocks")
+    op.drop_column("spec_versions", "brainstorming_phase_id")
diff --git a/backend/alembic/versions/b7c8d9e0f1a2_add_summary_to_grounding_files.py b/backend/alembic/versions/b7c8d9e0f1a2_add_summary_to_grounding_files.py
index 07a5a9c..186a560 100644
--- a/backend/alembic/versions/b7c8d9e0f1a2_add_summary_to_grounding_files.py
+++ b/backend/alembic/versions/b7c8d9e0f1a2_add_summary_to_grounding_files.py
@@ -8,11 +8,12 @@
 grounding files (primarily agents.md). The summary describes what's in the file, not what changed.
 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "b7c8d9e0f1a2"
diff --git a/backend/alembic/versions/b7c8d9e0f1g2_add_coding_agent_name_to_mcp_logs.py b/backend/alembic/versions/b7c8d9e0f1g2_add_coding_agent_name_to_mcp_logs.py
index 52ae154..2356029 100644
--- a/backend/alembic/versions/b7c8d9e0f1g2_add_coding_agent_name_to_mcp_logs.py
+++ b/backend/alembic/versions/b7c8d9e0f1g2_add_coding_agent_name_to_mcp_logs.py
@@ -7,11 +7,12 @@
 Adds coding_agent_name column to track which coding agent (e.g., Claude Code, Cursor, Cline) made each MCP tool call. This enables analytics on agent usage.
 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "b7c8d9e0f1g2"
diff --git a/backend/alembic/versions/b894089ee371_add_api_key_encrypted_column.py b/backend/alembic/versions/b894089ee371_add_api_key_encrypted_column.py
index c543c71..795f8fa 100644
--- a/backend/alembic/versions/b894089ee371_add_api_key_encrypted_column.py
+++ b/backend/alembic/versions/b894089ee371_add_api_key_encrypted_column.py
@@ -8,9 +8,9 @@
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "b894089ee371"
diff --git a/backend/alembic/versions/ba27eeb7dc05_merge_cexp08_and_icshare01.py b/backend/alembic/versions/ba27eeb7dc05_merge_cexp08_and_icshare01.py
index 6b33dce..9d01a25 100644
--- a/backend/alembic/versions/ba27eeb7dc05_merge_cexp08_and_icshare01.py
+++ b/backend/alembic/versions/ba27eeb7dc05_merge_cexp08_and_icshare01.py
@@ -5,15 +5,12 @@
 Create Date: 2026-01-15 08:18:33.406116

 """
-from typing import Sequence, Union
-
-from alembic import op
-import sqlalchemy as sa
+from typing import Sequence, Union

 # revision identifiers, used by Alembic.
-revision: str = 'ba27eeb7dc05'
-down_revision: Union[str, Sequence[str], None] = ('cexp08', 'icshare01')
+revision: str = "ba27eeb7dc05"
+down_revision: Union[str, Sequence[str], None] = ("cexp08", "icshare01")
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
diff --git a/backend/alembic/versions/bc6ddbf8b5b7_add_api_keys_table_for_mcp_http_.py b/backend/alembic/versions/bc6ddbf8b5b7_add_api_keys_table_for_mcp_http_.py
index 0e64e3a..21706c8 100644
--- a/backend/alembic/versions/bc6ddbf8b5b7_add_api_keys_table_for_mcp_http_.py
+++ b/backend/alembic/versions/bc6ddbf8b5b7_add_api_keys_table_for_mcp_http_.py
@@ -5,15 +5,16 @@
 Create Date: 2025-11-24 07:58:36.774569

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'bc6ddbf8b5b7'
-down_revision: Union[str, Sequence[str], None] = '8058002151ba'
+revision: str = "bc6ddbf8b5b7"
+down_revision: Union[str, Sequence[str], None] = "8058002151ba"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,33 +23,33 @@ def upgrade() -> None:
     """Upgrade schema."""
     # Create api_keys table
     op.create_table(
-        'api_keys',
-        sa.Column('id', sa.UUID(), nullable=False),
-        sa.Column('user_id', sa.UUID(), nullable=False),
-        sa.Column('project_id', sa.UUID(), nullable=False),
-        sa.Column('name', sa.String(length=200), nullable=False),
-        sa.Column('key_hash', sa.String(length=255), nullable=False),
-        sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False),
-        sa.Column('last_used_at', sa.TIMESTAMP(timezone=True), nullable=True),
-        sa.Column('revoked', sa.Boolean(), nullable=False, server_default='false'),
-        sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
-        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-        sa.PrimaryKeyConstraint('id'),
-        sa.UniqueConstraint('key_hash')
+        "api_keys",
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("user_id", sa.UUID(), nullable=False),
+        sa.Column("project_id", sa.UUID(), nullable=False),
+        sa.Column("name", sa.String(length=200), nullable=False),
+        sa.Column("key_hash", sa.String(length=255), nullable=False),
+        sa.Column("created_at", sa.TIMESTAMP(timezone=True), nullable=False),
+        sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
+        sa.Column("revoked", sa.Boolean(), nullable=False, server_default="false"),
+        sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("key_hash"),
     )

     # Create indexes
-    op.create_index('idx_api_keys_key_hash', 'api_keys', ['key_hash'])
-    op.create_index('idx_api_keys_project_id', 'api_keys', ['project_id'])
-    op.create_index('idx_api_keys_user_id', 'api_keys', ['user_id'])
+    op.create_index("idx_api_keys_key_hash", "api_keys", ["key_hash"])
+    op.create_index("idx_api_keys_project_id", "api_keys", ["project_id"])
+    op.create_index("idx_api_keys_user_id", "api_keys", ["user_id"])


 def downgrade() -> None:
     """Downgrade schema."""
     # Drop indexes
-    op.drop_index('idx_api_keys_user_id', table_name='api_keys')
-    op.drop_index('idx_api_keys_project_id', table_name='api_keys')
-    op.drop_index('idx_api_keys_key_hash', table_name='api_keys')
+    op.drop_index("idx_api_keys_user_id", table_name="api_keys")
+    op.drop_index("idx_api_keys_project_id", table_name="api_keys")
+    op.drop_index("idx_api_keys_key_hash", table_name="api_keys")

     # Drop table
-    op.drop_table('api_keys')
+    op.drop_table("api_keys")
diff --git a/backend/alembic/versions/c3d4e5f6g7h8_add_thread_version_anchoring.py b/backend/alembic/versions/c3d4e5f6g7h8_add_thread_version_anchoring.py
index 0668a80..7b3cf8a 100644
--- a/backend/alembic/versions/c3d4e5f6g7h8_add_thread_version_anchoring.py
+++ b/backend/alembic/versions/c3d4e5f6g7h8_add_thread_version_anchoring.py
@@ -5,15 +5,16 @@
 Create Date: 2025-12-03 12:00:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'c3d4e5f6g7h8'
-down_revision: Union[str, Sequence[str], None] = 'b2c3d4e5f6g7'
+revision: str = "c3d4e5f6g7h8"
+down_revision: Union[str, Sequence[str], None] = "b2c3d4e5f6g7"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -21,14 +22,14 @@ def upgrade() -> None:
     """Add version_id and block_id columns to threads, update context_type enum."""
     # Add new columns to threads
-    op.add_column('threads', sa.Column('version_id', sa.String(), nullable=True))
-    op.add_column('threads', sa.Column('block_id', sa.String(), nullable=True))
+    op.add_column("threads", sa.Column("version_id", sa.String(), nullable=True))
+    op.add_column("threads", sa.Column("block_id", sa.String(), nullable=True))

     # Create index on version_id
-    op.create_index('ix_threads_version_id', 'threads', ['version_id'])
+    op.create_index("ix_threads_version_id", "threads", ["version_id"])

     # Create composite index on (version_id, block_id) for efficient lookups
-    op.create_index('ix_threads_version_block', 'threads', ['version_id', 'block_id'])
+    op.create_index("ix_threads_version_block", "threads", ["version_id", "block_id"])

     # Note: context_type is stored as VARCHAR, not a PostgreSQL enum type,
     # so no ALTER TYPE is needed. New values ('spec_draft', 'prompt_plan_draft')
@@ -38,12 +39,12 @@ def upgrade() -> None:
 def downgrade() -> None:
     """Remove version_id and block_id columns from threads."""
     # Drop indexes
-    op.drop_index('ix_threads_version_block', table_name='threads')
-    op.drop_index('ix_threads_version_id', table_name='threads')
+    op.drop_index("ix_threads_version_block", table_name="threads")
+    op.drop_index("ix_threads_version_id", table_name="threads")

     # Drop columns
-    op.drop_column('threads', 'block_id')
-    op.drop_column('threads', 'version_id')
+    op.drop_column("threads", "block_id")
+    op.drop_column("threads", "version_id")

     # Note: PostgreSQL doesn't support removing enum values easily.
     # The new enum values will remain but won't be used after downgrade.
diff --git a/backend/alembic/versions/c4a5e7f8d123_add_archived_at_to_brainstorming_phases.py b/backend/alembic/versions/c4a5e7f8d123_add_archived_at_to_brainstorming_phases.py
index 68504b5..45402c8 100644
--- a/backend/alembic/versions/c4a5e7f8d123_add_archived_at_to_brainstorming_phases.py
+++ b/backend/alembic/versions/c4a5e7f8d123_add_archived_at_to_brainstorming_phases.py
@@ -7,9 +7,10 @@
 """
 from typing import Sequence, Union
-from alembic import op
+
 import sqlalchemy as sa

+from alembic import op

 revision: str = "c4a5e7f8d123"
 down_revision: Union[str, None] = "b894089ee371"
diff --git a/backend/alembic/versions/c8d9e0f1g2h3_add_feature_updated_at.py b/backend/alembic/versions/c8d9e0f1g2h3_add_feature_updated_at.py
index 9543f59..b3cb267 100644
--- a/backend/alembic/versions/c8d9e0f1g2h3_add_feature_updated_at.py
+++ b/backend/alembic/versions/c8d9e0f1g2h3_add_feature_updated_at.py
@@ -8,11 +8,12 @@
 are modified. This enables sorting features by "recently changed".
 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
revision: str = "c8d9e0f1g2h3" @@ -24,22 +25,16 @@ def upgrade() -> None: # Add updated_at column with server default of now() op.add_column( - 'features', - sa.Column( - 'updated_at', - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.func.now() - ) + "features", sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()) ) # Backfill existing rows with created_at value op.execute("UPDATE features SET updated_at = created_at") # Add index for efficient sorting by updated_at - op.create_index('ix_features_updated_at', 'features', ['updated_at']) + op.create_index("ix_features_updated_at", "features", ["updated_at"]) def downgrade() -> None: - op.drop_index('ix_features_updated_at', table_name='features') - op.drop_column('features', 'updated_at') + op.drop_index("ix_features_updated_at", table_name="features") + op.drop_column("features", "updated_at") diff --git a/backend/alembic/versions/cec01_add_code_exploration_cache.py b/backend/alembic/versions/cec01_add_code_exploration_cache.py index 248df9f..a838b55 100644 --- a/backend/alembic/versions/cec01_add_code_exploration_cache.py +++ b/backend/alembic/versions/cec01_add_code_exploration_cache.py @@ -14,8 +14,8 @@ from typing import Sequence, Union import sqlalchemy as sa -from alembic import op +from alembic import op # revision identifiers, used by Alembic. revision: str = "cec01" diff --git a/backend/alembic/versions/cexp01_add_code_explorer.py b/backend/alembic/versions/cexp01_add_code_explorer.py index 7bddc2e..d3fbe98 100644 --- a/backend/alembic/versions/cexp01_add_code_explorer.py +++ b/backend/alembic/versions/cexp01_add_code_explorer.py @@ -12,10 +12,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "cexp01" diff --git a/backend/alembic/versions/cexp06_add_raw_output_to_code_exploration.py b/backend/alembic/versions/cexp06_add_raw_output_to_code_exploration.py index 2e7bc23..e05088c 100644 --- a/backend/alembic/versions/cexp06_add_raw_output_to_code_exploration.py +++ b/backend/alembic/versions/cexp06_add_raw_output_to_code_exploration.py @@ -5,11 +5,12 @@ Create Date: 2026-01-11 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "cexp06" diff --git a/backend/alembic/versions/cexp07_add_thread_code_exploration.py b/backend/alembic/versions/cexp07_add_thread_code_exploration.py index 714857b..98c3227 100644 --- a/backend/alembic/versions/cexp07_add_thread_code_exploration.py +++ b/backend/alembic/versions/cexp07_add_thread_code_exploration.py @@ -8,12 +8,13 @@ Revises: cexp06, ppd07 Create Date: 2025-01-12 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import UUID +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "cexp07" diff --git a/backend/alembic/versions/cexp08_move_code_explorer_to_project.py b/backend/alembic/versions/cexp08_move_code_explorer_to_project.py index 9fd586b..d316789 100644 --- a/backend/alembic/versions/cexp08_move_code_explorer_to_project.py +++ b/backend/alembic/versions/cexp08_move_code_explorer_to_project.py @@ -8,12 +8,13 @@ Revises: cexp07 Create Date: 2025-01-14 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import UUID +from alembic import op # revision identifiers, used by Alembic. revision: str = "cexp08" diff --git a/backend/alembic/versions/d1cf77c4c1fa_add_llm_usage_logs_table.py b/backend/alembic/versions/d1cf77c4c1fa_add_llm_usage_logs_table.py index 893edc2..f077ef8 100644 --- a/backend/alembic/versions/d1cf77c4c1fa_add_llm_usage_logs_table.py +++ b/backend/alembic/versions/d1cf77c4c1fa_add_llm_usage_logs_table.py @@ -5,16 +5,17 @@ Create Date: 2025-12-16 17:04:42.006885 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'd1cf77c4c1fa' -down_revision: Union[str, Sequence[str], None] = '9b373c88f9ac' +revision: str = "d1cf77c4c1fa" +down_revision: Union[str, Sequence[str], None] = "9b373c88f9ac" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/backend/alembic/versions/d4e5f6g7h8i9_add_activity_logs.py b/backend/alembic/versions/d4e5f6g7h8i9_add_activity_logs.py index dacde96..d034597 100644 --- a/backend/alembic/versions/d4e5f6g7h8i9_add_activity_logs.py +++ b/backend/alembic/versions/d4e5f6g7h8i9_add_activity_logs.py @@ -5,15 +5,14 @@ Create Date: 2025-12-03 13:00:00.000000 """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. 
-revision: str = 'd4e5f6g7h8i9'
-down_revision: Union[str, Sequence[str], None] = 'c3d4e5f6g7h8'
+revision: str = "d4e5f6g7h8i9"
+down_revision: Union[str, Sequence[str], None] = "c3d4e5f6g7h8"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -34,14 +33,14 @@ def upgrade() -> None:
     """)

     # Create indexes for efficient querying
-    op.create_index('ix_activity_logs_entity', 'activity_logs', ['entity_type', 'entity_id'])
-    op.create_index('ix_activity_logs_event_type', 'activity_logs', ['event_type'])
-    op.create_index('ix_activity_logs_created_at', 'activity_logs', ['created_at'])
+    op.create_index("ix_activity_logs_entity", "activity_logs", ["entity_type", "entity_id"])
+    op.create_index("ix_activity_logs_event_type", "activity_logs", ["event_type"])
+    op.create_index("ix_activity_logs_created_at", "activity_logs", ["created_at"])


 def downgrade() -> None:
     """Drop activity_logs table."""
-    op.drop_index('ix_activity_logs_created_at', table_name='activity_logs')
-    op.drop_index('ix_activity_logs_event_type', table_name='activity_logs')
-    op.drop_index('ix_activity_logs_entity', table_name='activity_logs')
-    op.drop_table('activity_logs')
+    op.drop_index("ix_activity_logs_created_at", table_name="activity_logs")
+    op.drop_index("ix_activity_logs_event_type", table_name="activity_logs")
+    op.drop_index("ix_activity_logs_entity", table_name="activity_logs")
+    op.drop_table("activity_logs")
diff --git a/backend/alembic/versions/d5b6e7f8g123_add_decision_summary_short.py b/backend/alembic/versions/d5b6e7f8g123_add_decision_summary_short.py
index 3770f41..5d634ce 100644
--- a/backend/alembic/versions/d5b6e7f8g123_add_decision_summary_short.py
+++ b/backend/alembic/versions/d5b6e7f8g123_add_decision_summary_short.py
@@ -7,9 +7,10 @@
 """
 from typing import Sequence, Union
-from alembic import op
+
 import sqlalchemy as sa

+from alembic import op

 revision: str = "d5b6e7f8g123"
 down_revision: Union[str, None] = "c4a5e7f8d123"
diff --git a/backend/alembic/versions/d5f8b2c9e1a3_add_non_goals_violations_column.py b/backend/alembic/versions/d5f8b2c9e1a3_add_non_goals_violations_column.py
index 5ff0745..c5fc88e 100644
--- a/backend/alembic/versions/d5f8b2c9e1a3_add_non_goals_violations_column.py
+++ b/backend/alembic/versions/d5f8b2c9e1a3_add_non_goals_violations_column.py
@@ -5,15 +5,16 @@
 Create Date: 2025-11-25 16:35:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
+
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'd5f8b2c9e1a3'
-down_revision: Union[str, Sequence[str], None] = '3e35a2b90829'
+revision: str = "d5f8b2c9e1a3"
+down_revision: Union[str, Sequence[str], None] = "3e35a2b90829"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -21,11 +22,11 @@ def upgrade() -> None:
     """Add non_goals_violations JSON column to prompt_plan_coverage_reports table."""
     op.add_column(
-        'prompt_plan_coverage_reports',
-        sa.Column('non_goals_violations', sa.JSON(), nullable=False, server_default='[]')
+        "prompt_plan_coverage_reports",
+        sa.Column("non_goals_violations", sa.JSON(), nullable=False, server_default="[]"),
     )


 def downgrade() -> None:
     """Remove non_goals_violations column from prompt_plan_coverage_reports table."""
-    op.drop_column('prompt_plan_coverage_reports', 'non_goals_violations')
+    op.drop_column("prompt_plan_coverage_reports", "non_goals_violations")
diff --git a/backend/alembic/versions/df4d2a8eedf1_add_show_create_implementation_button_.py b/backend/alembic/versions/df4d2a8eedf1_add_show_create_implementation_button_.py
index c8e3804..88da725 100644
--- a/backend/alembic/versions/df4d2a8eedf1_add_show_create_implementation_button_.py
+++ b/backend/alembic/versions/df4d2a8eedf1_add_show_create_implementation_button_.py
@@ -5,15 +5,16 @@
 Create Date: 2026-01-03 16:53:37.917190

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'df4d2a8eedf1'
-down_revision: Union[str, Sequence[str], None] = '616549379b06'
+revision: str = "df4d2a8eedf1"
+down_revision: Union[str, Sequence[str], None] = "616549379b06"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -21,11 +22,10 @@ def upgrade() -> None:
     """Add show_create_implementation_button flag to threads table."""
     op.add_column(
-        'threads',
-        sa.Column('show_create_implementation_button', sa.Boolean(), nullable=False, server_default='false')
+        "threads", sa.Column("show_create_implementation_button", sa.Boolean(), nullable=False, server_default="false")
     )


 def downgrade() -> None:
     """Remove show_create_implementation_button flag from threads table."""
-    op.drop_column('threads', 'show_create_implementation_button')
+    op.drop_column("threads", "show_create_implementation_button")
diff --git a/backend/alembic/versions/dus01_daily_usage_summary_table.py b/backend/alembic/versions/dus01_daily_usage_summary_table.py
index d93bcca..cc678cf 100644
--- a/backend/alembic/versions/dus01_daily_usage_summary_table.py
+++ b/backend/alembic/versions/dus01_daily_usage_summary_table.py
@@ -12,10 +12,10 @@
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
+
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "dus01"
diff --git a/backend/alembic/versions/dus02_realtime_aggregation_trigger.py b/backend/alembic/versions/dus02_realtime_aggregation_trigger.py
index dac99ad..a039905 100644
--- a/backend/alembic/versions/dus02_realtime_aggregation_trigger.py
+++ b/backend/alembic/versions/dus02_realtime_aggregation_trigger.py
@@ -13,8 +13,6 @@
 from typing import Sequence, Union

 from alembic import op
-import sqlalchemy as sa
-

 # revision identifiers, used by Alembic.
revision: str = "dus02" diff --git a/backend/alembic/versions/dus03_fix_trigger_on_conflict.py b/backend/alembic/versions/dus03_fix_trigger_on_conflict.py index 678d553..372a748 100644 --- a/backend/alembic/versions/dus03_fix_trigger_on_conflict.py +++ b/backend/alembic/versions/dus03_fix_trigger_on_conflict.py @@ -14,7 +14,6 @@ from alembic import op - # revision identifiers, used by Alembic. revision: str = "dus03" down_revision: Union[str, None] = "rec01" diff --git a/backend/alembic/versions/e4f5g6h7i8j9_add_feature_priority_category.py b/backend/alembic/versions/e4f5g6h7i8j9_add_feature_priority_category.py index ea194e8..c7812b9 100644 --- a/backend/alembic/versions/e4f5g6h7i8j9_add_feature_priority_category.py +++ b/backend/alembic/versions/e4f5g6h7i8j9_add_feature_priority_category.py @@ -5,15 +5,16 @@ Create Date: 2025-12-04 07:20:00.000000 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'e4f5g6h7i8j9' -down_revision: Union[str, Sequence[str], None] = 'd4e5f6g7h8i9' +revision: str = "e4f5g6h7i8j9" +down_revision: Union[str, Sequence[str], None] = "d4e5f6g7h8i9" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,34 +22,20 @@ def upgrade() -> None: """Add priority and category columns to features table.""" # Create the enum type for feature priority - feature_priority_enum = sa.Enum( - 'must_have', 'important', 'optional', - name='feature_priority' - ) + feature_priority_enum = sa.Enum("must_have", "important", "optional", name="feature_priority") feature_priority_enum.create(op.get_bind(), checkfirst=True) # Add priority column with default 'important' - op.add_column( - 'features', - sa.Column( - 'priority', - feature_priority_enum, - nullable=False, - server_default='important' - ) - ) + op.add_column("features", sa.Column("priority", feature_priority_enum, nullable=False, server_default="important")) # Add category column (nullable string) - op.add_column( - 'features', - sa.Column('category', sa.String(100), nullable=True) - ) + op.add_column("features", sa.Column("category", sa.String(100), nullable=True)) def downgrade() -> None: """Remove priority and category columns from features table.""" - op.drop_column('features', 'category') - op.drop_column('features', 'priority') + op.drop_column("features", "category") + op.drop_column("features", "priority") # Drop the enum type - sa.Enum(name='feature_priority').drop(op.get_bind(), checkfirst=True) + sa.Enum(name="feature_priority").drop(op.get_bind(), checkfirst=True) diff --git a/backend/alembic/versions/e6f7g8h9i0j1_add_email_templates_table.py b/backend/alembic/versions/e6f7g8h9i0j1_add_email_templates_table.py index c67719f..06dc14a 100644 --- a/backend/alembic/versions/e6f7g8h9i0j1_add_email_templates_table.py +++ b/backend/alembic/versions/e6f7g8h9i0j1_add_email_templates_table.py @@ -7,10 +7,11 @@ """ from typing import Sequence, Union -from alembic import op + import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op revision: str = "e6f7g8h9i0j1" down_revision: Union[str, None] = "d5b6e7f8g123" diff --git a/backend/alembic/versions/ea0f87fc2305_add_is_sample_to_projects.py b/backend/alembic/versions/ea0f87fc2305_add_is_sample_to_projects.py index 365554e..8c5101d 100644 --- a/backend/alembic/versions/ea0f87fc2305_add_is_sample_to_projects.py +++ 
b/backend/alembic/versions/ea0f87fc2305_add_is_sample_to_projects.py @@ -7,11 +7,12 @@ Adds a boolean flag to identify sample onboarding projects created for new users during signup. """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "ea0f87fc2305" diff --git a/backend/alembic/versions/ed7322775e46_add_pending_approval_to_threads.py b/backend/alembic/versions/ed7322775e46_add_pending_approval_to_threads.py index 4edd369..f57684a 100644 --- a/backend/alembic/versions/ed7322775e46_add_pending_approval_to_threads.py +++ b/backend/alembic/versions/ed7322775e46_add_pending_approval_to_threads.py @@ -5,24 +5,25 @@ Create Date: 2025-11-20 13:01:24.678760 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'ed7322775e46' -down_revision: Union[str, Sequence[str], None] = '63256c2c0d52' +revision: str = "ed7322775e46" +down_revision: Union[str, Sequence[str], None] = "63256c2c0d52" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Upgrade schema.""" - op.add_column('threads', sa.Column('pending_approval', sa.Boolean(), nullable=False, server_default=sa.false())) + op.add_column("threads", sa.Column("pending_approval", sa.Boolean(), nullable=False, server_default=sa.false())) def downgrade() -> None: """Downgrade schema.""" - op.drop_column('threads', 'pending_approval') + op.drop_column("threads", "pending_approval") diff --git a/backend/alembic/versions/f5g6h7i8j9k0_add_llm_call_logs_table.py b/backend/alembic/versions/f5g6h7i8j9k0_add_llm_call_logs_table.py index ac73ee8..4fa9ac5 100644 --- a/backend/alembic/versions/f5g6h7i8j9k0_add_llm_call_logs_table.py +++ b/backend/alembic/versions/f5g6h7i8j9k0_add_llm_call_logs_table.py @@ -5,16 +5,17 @@ Create Date: 2025-12-05 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
-revision: str = 'f5g6h7i8j9k0'
-down_revision: Union[str, Sequence[str], None] = 'e4f5g6h7i8j9'
+revision: str = "f5g6h7i8j9k0"
+down_revision: Union[str, Sequence[str], None] = "e4f5g6h7i8j9"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,44 +23,46 @@ def upgrade() -> None:
     """Create llm_call_logs table."""
     op.create_table(
-        'llm_call_logs',
-        sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
-        sa.Column('job_id', postgresql.UUID(as_uuid=True), sa.ForeignKey('jobs.id', ondelete='CASCADE'), nullable=False),
+        "llm_call_logs",
+        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
+        sa.Column(
+            "job_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False
+        ),
         # Agent identification
-        sa.Column('agent_name', sa.String(100), nullable=False),
-        sa.Column('agent_display_name', sa.String(255), nullable=True),
+        sa.Column("agent_name", sa.String(100), nullable=False),
+        sa.Column("agent_display_name", sa.String(255), nullable=True),
         # Request details
-        sa.Column('request_messages', postgresql.JSONB, nullable=False),
-        sa.Column('request_model', sa.String(100), nullable=False),
-        sa.Column('request_temperature', sa.Numeric(3, 2), nullable=True),
-        sa.Column('request_max_tokens', sa.Integer(), nullable=True),
+        sa.Column("request_messages", postgresql.JSONB, nullable=False),
+        sa.Column("request_model", sa.String(100), nullable=False),
+        sa.Column("request_temperature", sa.Numeric(3, 2), nullable=True),
+        sa.Column("request_max_tokens", sa.Integer(), nullable=True),
         # Response details
-        sa.Column('response_content', sa.Text(), nullable=True),
-        sa.Column('response_finish_reason', sa.String(50), nullable=True),
-        sa.Column('response_tool_calls', postgresql.JSONB, nullable=True),
+        sa.Column("response_content", sa.Text(), nullable=True),
+        sa.Column("response_finish_reason", sa.String(50), nullable=True),
+        sa.Column("response_tool_calls", postgresql.JSONB, nullable=True),
         # Usage metrics
-        sa.Column('prompt_tokens', sa.Integer(), nullable=False),
-        sa.Column('completion_tokens', sa.Integer(), nullable=False),
-        sa.Column('cost_usd', sa.Numeric(10, 6), nullable=True),
+        sa.Column("prompt_tokens", sa.Integer(), nullable=False),
+        sa.Column("completion_tokens", sa.Integer(), nullable=False),
+        sa.Column("cost_usd", sa.Numeric(10, 6), nullable=True),
         # Timing
-        sa.Column('started_at', sa.DateTime(timezone=True), nullable=False),
-        sa.Column('finished_at', sa.DateTime(timezone=True), nullable=False),
-        sa.Column('duration_ms', sa.Integer(), nullable=False),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("duration_ms", sa.Integer(), nullable=False),
         # Metadata
-        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
     )

     # Create indexes
-    op.create_index('ix_llm_call_logs_job_id', 'llm_call_logs', ['job_id'])
-    op.create_index('ix_llm_call_logs_agent_name', 'llm_call_logs', ['agent_name'])
-    op.create_index('ix_llm_call_logs_created_at', 'llm_call_logs', ['created_at'])
-    op.create_index('ix_llm_call_logs_job_created', 'llm_call_logs', ['job_id', 'created_at'])
+    op.create_index("ix_llm_call_logs_job_id", "llm_call_logs", ["job_id"])
+    op.create_index("ix_llm_call_logs_agent_name", "llm_call_logs", ["agent_name"])
+    op.create_index("ix_llm_call_logs_created_at", "llm_call_logs", ["created_at"])
+    op.create_index("ix_llm_call_logs_job_created", "llm_call_logs", ["job_id", "created_at"])


 def downgrade() -> None:
     """Drop llm_call_logs table."""
-    op.drop_index('ix_llm_call_logs_job_created')
-    op.drop_index('ix_llm_call_logs_created_at')
-    op.drop_index('ix_llm_call_logs_agent_name')
-    op.drop_index('ix_llm_call_logs_job_id')
-    op.drop_table('llm_call_logs')
+    op.drop_index("ix_llm_call_logs_job_created")
+    op.drop_index("ix_llm_call_logs_created_at")
+    op.drop_index("ix_llm_call_logs_agent_name")
+    op.drop_index("ix_llm_call_logs_job_id")
+    op.drop_table("llm_call_logs")
diff --git a/backend/alembic/versions/f7g8h9i0j1k2_seed_email_templates.py b/backend/alembic/versions/f7g8h9i0j1k2_seed_email_templates.py
index 06e591e..d5a3714 100644
--- a/backend/alembic/versions/f7g8h9i0j1k2_seed_email_templates.py
+++ b/backend/alembic/versions/f7g8h9i0j1k2_seed_email_templates.py
@@ -5,11 +5,13 @@
 Create Date: 2025-01-21

 """
-from alembic import op
-import sqlalchemy as sa
-from datetime import datetime, timezone
-import uuid
 import json
+import uuid
+from datetime import datetime, timezone
+
+import sqlalchemy as sa
+
+from alembic import op

 revision = "f7g8h9i0j1k2"
 down_revision = "e6f7g8h9i0j1"
@@ -147,10 +149,9 @@ def upgrade() -> None:
             id=str(uuid.uuid4()),
             body=MENTION_NOTIFICATION_BODY,
             mandatory_vars=json.dumps(["mentioned_by_name", "context_title", "view_url"]),
-            available_vars=json.dumps([
-                "mentioned_by_name", "context_title", "context_description",
-                "recent_messages", "view_url"
-            ]),
+            available_vars=json.dumps(
+                ["mentioned_by_name", "context_title", "context_description", "recent_messages", "view_url"]
+            ),
             now=now,
         )
     )
diff --git a/backend/alembic/versions/f9a38e4b733d_initial_migration_create_users_table.py b/backend/alembic/versions/f9a38e4b733d_initial_migration_create_users_table.py
index 8f9b215..32252a6 100644
--- a/backend/alembic/versions/f9a38e4b733d_initial_migration_create_users_table.py
+++ b/backend/alembic/versions/f9a38e4b733d_initial_migration_create_users_table.py
@@ -1,18 +1,19 @@
 """Initial migration - create users table

 Revision ID: f9a38e4b733d
-Revises: 
+Revises:
 Create Date: 2025-11-20 07:58:16.443565

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'f9a38e4b733d'
+revision: str = "f9a38e4b733d"
 down_revision: Union[str, Sequence[str], None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,18 +23,18 @@ def upgrade() -> None:
     """Upgrade schema."""
     # Create users table
     op.create_table(
-        'users',
-        sa.Column('id', sa.UUID(), nullable=False),
-        sa.Column('email', sa.String(length=255), nullable=False),
-        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
-        sa.PrimaryKeyConstraint('id', name=op.f('pk_users')),
-        sa.UniqueConstraint('email', name=op.f('uq_users_email'))
+        "users",
+        sa.Column("id", sa.UUID(), nullable=False),
+        sa.Column("email", sa.String(length=255), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False),
+        sa.PrimaryKeyConstraint("id", name=op.f("pk_users")),
+        sa.UniqueConstraint("email", name=op.f("uq_users_email")),
     )
-    op.create_index(op.f('ix_email'), 'users', ['email'], unique=False)
+    op.create_index(op.f("ix_email"), "users", ["email"], unique=False)


 def downgrade() -> None:
     """Downgrade schema."""
     # Drop users table
-    op.drop_index(op.f('ix_email'), table_name='users')
-    op.drop_table('users')
+    op.drop_index(op.f("ix_email"), table_name="users")
+    op.drop_table("users")
diff --git a/backend/alembic/versions/fb1c2d3e4f5g_add_notes_updated_fields.py b/backend/alembic/versions/fb1c2d3e4f5g_add_notes_updated_fields.py
index 5ba7090..ad1b5ae 100644
--- a/backend/alembic/versions/fb1c2d3e4f5g_add_notes_updated_fields.py
+++ b/backend/alembic/versions/fb1c2d3e4f5g_add_notes_updated_fields.py
@@ -5,11 +5,12 @@
 Create Date: 2024-12-18 12:00:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "fb1c2d3e4f5g"
diff --git a/backend/alembic/versions/fd30aac0413b_add_created_by_to_pre_phase_message.py b/backend/alembic/versions/fd30aac0413b_add_created_by_to_pre_phase_message.py
index 4173927..f0c524b 100644
--- a/backend/alembic/versions/fd30aac0413b_add_created_by_to_pre_phase_message.py
+++ b/backend/alembic/versions/fd30aac0413b_add_created_by_to_pre_phase_message.py
@@ -5,36 +5,28 @@
 Create Date: 2026-01-14 17:13:02.981564

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'fd30aac0413b'
-down_revision: Union[str, Sequence[str], None] = 'sid01'
+revision: str = "fd30aac0413b"
+down_revision: Union[str, Sequence[str], None] = "sid01"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None


 def upgrade() -> None:
     """Add created_by column to pre_phase_messages for user tracking."""
-    op.add_column(
-        'pre_phase_messages',
-        sa.Column('created_by', postgresql.UUID(as_uuid=True), nullable=True)
-    )
-    op.create_foreign_key(
-        'fk_pre_phase_messages_created_by',
-        'pre_phase_messages',
-        'users',
-        ['created_by'],
-        ['id']
-    )
+    op.add_column("pre_phase_messages", sa.Column("created_by", postgresql.UUID(as_uuid=True), nullable=True))
+    op.create_foreign_key("fk_pre_phase_messages_created_by", "pre_phase_messages", "users", ["created_by"], ["id"])


 def downgrade() -> None:
     """Remove created_by column from pre_phase_messages."""
-    op.drop_constraint('fk_pre_phase_messages_created_by', 'pre_phase_messages', type_='foreignkey')
-    op.drop_column('pre_phase_messages', 'created_by')
+    op.drop_constraint("fk_pre_phase_messages_created_by", "pre_phase_messages", type_="foreignkey")
+    op.drop_column("pre_phase_messages", "created_by")
diff --git a/backend/alembic/versions/freemium01_increase_max_users_to_5.py b/backend/alembic/versions/freemium01_increase_max_users_to_5.py
index cdfdf3d..1eb674d 100644
--- a/backend/alembic/versions/freemium01_increase_max_users_to_5.py
+++ b/backend/alembic/versions/freemium01_increase_max_users_to_5.py
@@ -5,14 +5,14 @@
 Create Date: 2025-02-02

 """
+
 from typing import Sequence, Union

 from alembic import op

-
 # revision identifiers, used by Alembic.
-revision: str = 'freemium01'
-down_revision: Union[str, None] = 'grnotes01'
+revision: str = "freemium01"
+down_revision: Union[str, None] = "grnotes01"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
diff --git a/backend/alembic/versions/g6h7i8j9k0l1_add_mcp_call_logs_table.py b/backend/alembic/versions/g6h7i8j9k0l1_add_mcp_call_logs_table.py
index 44efb6c..b1ad141 100644
--- a/backend/alembic/versions/g6h7i8j9k0l1_add_mcp_call_logs_table.py
+++ b/backend/alembic/versions/g6h7i8j9k0l1_add_mcp_call_logs_table.py
@@ -5,16 +5,17 @@
 Create Date: 2025-12-06

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql
+
+from alembic import op

 # revision identifiers, used by Alembic.
-revision: str = 'g6h7i8j9k0l1'
-down_revision: Union[str, Sequence[str], None] = '2f8c2d246f32'
+revision: str = "g6h7i8j9k0l1"
+down_revision: Union[str, Sequence[str], None] = "2f8c2d246f32"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -22,58 +23,70 @@ def upgrade() -> None:
     """Create mcp_call_logs table."""
     op.create_table(
-        'mcp_call_logs',
-        sa.Column('id', postgresql.UUID(as_uuid=True), primary_key=True),
+        "mcp_call_logs",
+        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
         # Who made the call
-        sa.Column('user_id', postgresql.UUID(as_uuid=True),
-                  sa.ForeignKey('users.id', ondelete='CASCADE'), nullable=False),
-        sa.Column('api_key_id', postgresql.UUID(as_uuid=True),
-                  sa.ForeignKey('api_keys.id', ondelete='SET NULL'), nullable=True),
-        sa.Column('api_key_name', sa.String(200), nullable=False),
+        sa.Column(
+            "user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False
+        ),
+        sa.Column(
+            "api_key_id",
+            postgresql.UUID(as_uuid=True),
+            sa.ForeignKey("api_keys.id", ondelete="SET NULL"),
+            nullable=True,
+        ),
+        sa.Column("api_key_name", sa.String(200), nullable=False),
         # Where (context)
-        sa.Column('org_id', postgresql.UUID(as_uuid=True),
-                  sa.ForeignKey('organizations.id', ondelete='CASCADE'), nullable=False),
-        sa.Column('project_id', postgresql.UUID(as_uuid=True),
-                  sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False),
+        sa.Column(
+            "org_id",
+            postgresql.UUID(as_uuid=True),
+            sa.ForeignKey("organizations.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+        sa.Column(
+            "project_id",
+            postgresql.UUID(as_uuid=True),
+            sa.ForeignKey("projects.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
         # What was called
-        sa.Column('tool_name', sa.String(100), nullable=False),
-        sa.Column('jsonrpc_method', sa.String(100), nullable=False),
+        sa.Column("tool_name", sa.String(100), nullable=False),
+        sa.Column("jsonrpc_method", sa.String(100), nullable=False),
         # Request/Response
-        sa.Column('request_params', postgresql.JSONB, nullable=True),
-        sa.Column('response_result', postgresql.JSONB, nullable=True),
-        sa.Column('response_error', postgresql.JSONB, nullable=True),
-        sa.Column('is_error', sa.Boolean(), nullable=False, server_default='false'),
+        sa.Column("request_params", postgresql.JSONB, nullable=True),
+        sa.Column("response_result", postgresql.JSONB, nullable=True),
+        sa.Column("response_error", postgresql.JSONB, nullable=True),
+        sa.Column("is_error", sa.Boolean(), nullable=False, server_default="false"),
         # Timing
-        sa.Column('started_at', sa.DateTime(timezone=True), nullable=False),
-        sa.Column('finished_at', sa.DateTime(timezone=True), nullable=False),
-        sa.Column('duration_ms', sa.Integer(), nullable=False),
+        sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("finished_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("duration_ms", sa.Integer(), nullable=False),
         # Metadata
-        sa.Column('created_at', sa.DateTime(timezone=True),
-                  server_default=sa.func.now(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
     )

     # Create indexes
-    op.create_index('ix_mcp_call_logs_user_id', 'mcp_call_logs', ['user_id'])
-    op.create_index('ix_mcp_call_logs_api_key_id', 'mcp_call_logs', ['api_key_id'])
-    op.create_index('ix_mcp_call_logs_org_id', 'mcp_call_logs', ['org_id'])
-    op.create_index('ix_mcp_call_logs_project_id', 'mcp_call_logs', ['project_id'])
-    op.create_index('ix_mcp_call_logs_tool_name', 'mcp_call_logs', ['tool_name'])
-    op.create_index('ix_mcp_call_logs_created_at', 'mcp_call_logs', ['created_at'])
+    op.create_index("ix_mcp_call_logs_user_id", "mcp_call_logs", ["user_id"])
+    op.create_index("ix_mcp_call_logs_api_key_id", "mcp_call_logs", ["api_key_id"])
+    op.create_index("ix_mcp_call_logs_org_id", "mcp_call_logs", ["org_id"])
+    op.create_index("ix_mcp_call_logs_project_id", "mcp_call_logs", ["project_id"])
+    op.create_index("ix_mcp_call_logs_tool_name", "mcp_call_logs", ["tool_name"])
+    op.create_index("ix_mcp_call_logs_created_at", "mcp_call_logs", ["created_at"])

     # Composite indexes
-    op.create_index('ix_mcp_call_logs_org_created', 'mcp_call_logs', ['org_id', 'created_at'])
-    op.create_index('ix_mcp_call_logs_project_created', 'mcp_call_logs', ['project_id', 'created_at'])
-    op.create_index('ix_mcp_call_logs_user_created', 'mcp_call_logs', ['user_id', 'created_at'])
+    op.create_index("ix_mcp_call_logs_org_created", "mcp_call_logs", ["org_id", "created_at"])
+    op.create_index("ix_mcp_call_logs_project_created", "mcp_call_logs", ["project_id", "created_at"])
+    op.create_index("ix_mcp_call_logs_user_created", "mcp_call_logs", ["user_id", "created_at"])


 def downgrade() -> None:
     """Drop mcp_call_logs table."""
-    op.drop_index('ix_mcp_call_logs_user_created')
-    op.drop_index('ix_mcp_call_logs_project_created')
-    op.drop_index('ix_mcp_call_logs_org_created')
-    op.drop_index('ix_mcp_call_logs_created_at')
-    op.drop_index('ix_mcp_call_logs_tool_name')
-    op.drop_index('ix_mcp_call_logs_project_id')
-    op.drop_index('ix_mcp_call_logs_org_id')
-    op.drop_index('ix_mcp_call_logs_api_key_id')
-    op.drop_index('ix_mcp_call_logs_user_id')
-    op.drop_table('mcp_call_logs')
+    op.drop_index("ix_mcp_call_logs_user_created")
+    op.drop_index("ix_mcp_call_logs_project_created")
+    op.drop_index("ix_mcp_call_logs_org_created")
+    op.drop_index("ix_mcp_call_logs_created_at")
+    op.drop_index("ix_mcp_call_logs_tool_name")
+    op.drop_index("ix_mcp_call_logs_project_id")
+    op.drop_index("ix_mcp_call_logs_org_id")
+    op.drop_index("ix_mcp_call_logs_api_key_id")
+    op.drop_index("ix_mcp_call_logs_user_id")
+    op.drop_table("mcp_call_logs")
diff --git a/backend/alembic/versions/g8h9i0j1k2l3_add_triggered_by_and_duration.py b/backend/alembic/versions/g8h9i0j1k2l3_add_triggered_by_and_duration.py
index 870f10a..5c0fdd4 100644
--- a/backend/alembic/versions/g8h9i0j1k2l3_add_triggered_by_and_duration.py
+++ b/backend/alembic/versions/g8h9i0j1k2l3_add_triggered_by_and_duration.py
@@ -5,12 +5,14 @@
 Create Date: 2025-12-21 14:00:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op
+
 # revision identifiers, used by Alembic.
revision: str = "g8h9i0j1k2l3" down_revision: Union[str, None] = "f7g8h9i0j1k2" @@ -24,9 +26,7 @@ def upgrade() -> None: "jobs", sa.Column("triggered_by_user_id", postgresql.UUID(as_uuid=True), nullable=True), ) - op.create_index( - "ix_jobs_triggered_by_user_id", "jobs", ["triggered_by_user_id"], unique=False - ) + op.create_index("ix_jobs_triggered_by_user_id", "jobs", ["triggered_by_user_id"], unique=False) # Add triggered_by_user_id and duration_ms to llm_usage_logs table op.add_column( @@ -55,9 +55,7 @@ def upgrade() -> None: def downgrade() -> None: # Remove from llm_usage_logs - op.drop_constraint( - "fk_llm_usage_logs_triggered_by_user_id", "llm_usage_logs", type_="foreignkey" - ) + op.drop_constraint("fk_llm_usage_logs_triggered_by_user_id", "llm_usage_logs", type_="foreignkey") op.drop_index("ix_llm_usage_logs_triggered_by_user_id", table_name="llm_usage_logs") op.drop_column("llm_usage_logs", "duration_ms") op.drop_column("llm_usage_logs", "triggered_by_user_id") diff --git a/backend/alembic/versions/gc2d3e4f5g6h_add_thread_ai_error_fields.py b/backend/alembic/versions/gc2d3e4f5g6h_add_thread_ai_error_fields.py index a412316..0fd5af3 100644 --- a/backend/alembic/versions/gc2d3e4f5g6h_add_thread_ai_error_fields.py +++ b/backend/alembic/versions/gc2d3e4f5g6h_add_thread_ai_error_fields.py @@ -5,25 +5,26 @@ Create Date: 2025-01-01 00:00:00.000000 """ -from alembic import op + import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = 'gc2d3e4f5g6h' -down_revision = 'fb1c2d3e4f5g' +revision = "gc2d3e4f5g6h" +down_revision = "fb1c2d3e4f5g" branch_labels = None depends_on = None def upgrade() -> None: # Add AI error state columns for persisting error state across page refreshes - op.add_column('threads', sa.Column('ai_error_message', sa.Text(), nullable=True)) - op.add_column('threads', sa.Column('ai_error_job_id', sa.String(), nullable=True)) - op.add_column('threads', sa.Column('ai_error_user_message', sa.Text(), nullable=True)) + op.add_column("threads", sa.Column("ai_error_message", sa.Text(), nullable=True)) + op.add_column("threads", sa.Column("ai_error_job_id", sa.String(), nullable=True)) + op.add_column("threads", sa.Column("ai_error_user_message", sa.Text(), nullable=True)) def downgrade() -> None: - op.drop_column('threads', 'ai_error_user_message') - op.drop_column('threads', 'ai_error_job_id') - op.drop_column('threads', 'ai_error_message') + op.drop_column("threads", "ai_error_user_message") + op.drop_column("threads", "ai_error_job_id") + op.drop_column("threads", "ai_error_message") diff --git a/backend/alembic/versions/gen01flags_add_generation_status_flags.py b/backend/alembic/versions/gen01flags_add_generation_status_flags.py index 76f1d15..caf4a8d 100644 --- a/backend/alembic/versions/gen01flags_add_generation_status_flags.py +++ b/backend/alembic/versions/gen01flags_add_generation_status_flags.py @@ -11,9 +11,9 @@ when users switch tabs during generation. """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
revision = "gen01flags" diff --git a/backend/alembic/versions/ghoauth01_add_github_oauth_states.py b/backend/alembic/versions/ghoauth01_add_github_oauth_states.py index 3570aa6..4db2d64 100644 --- a/backend/alembic/versions/ghoauth01_add_github_oauth_states.py +++ b/backend/alembic/versions/ghoauth01_add_github_oauth_states.py @@ -10,10 +10,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import UUID +from alembic import op revision: str = "ghoauth01" down_revision: Union[str, Sequence[str], None] = "ba27eeb7dc05" diff --git a/backend/alembic/versions/ghoauth02_add_github_oauth_to_platform_settings.py b/backend/alembic/versions/ghoauth02_add_github_oauth_to_platform_settings.py index 613668e..7914465 100644 --- a/backend/alembic/versions/ghoauth02_add_github_oauth_to_platform_settings.py +++ b/backend/alembic/versions/ghoauth02_add_github_oauth_to_platform_settings.py @@ -15,9 +15,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "ghoauth02" diff --git a/backend/alembic/versions/grdbr01_add_grounding_file_branches.py b/backend/alembic/versions/grdbr01_add_grounding_file_branches.py index 2b2beac..0d3739e 100644 --- a/backend/alembic/versions/grdbr01_add_grounding_file_branches.py +++ b/backend/alembic/versions/grdbr01_add_grounding_file_branches.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "grdbr01" diff --git a/backend/alembic/versions/grdbr02_add_is_merging_flag.py b/backend/alembic/versions/grdbr02_add_is_merging_flag.py index 7d93440..b9007a4 100644 --- a/backend/alembic/versions/grdbr02_add_is_merging_flag.py +++ b/backend/alembic/versions/grdbr02_add_is_merging_flag.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "grdbr02" diff --git a/backend/alembic/versions/grdbr03_add_content_updated_at.py b/backend/alembic/versions/grdbr03_add_content_updated_at.py index 10d93a3..a6ac1eb 100644 --- a/backend/alembic/versions/grdbr03_add_content_updated_at.py +++ b/backend/alembic/versions/grdbr03_add_content_updated_at.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "grdbr03" @@ -31,9 +31,7 @@ def upgrade() -> None: ) # Set existing rows to use their updated_at value - op.execute( - "UPDATE grounding_file_branches SET content_updated_at = updated_at WHERE content_updated_at IS NULL" - ) + op.execute("UPDATE grounding_file_branches SET content_updated_at = updated_at WHERE content_updated_at IS NULL") # Make the column non-nullable with a default for new rows op.alter_column( diff --git a/backend/alembic/versions/grdbr04_add_global_content_updated_at.py b/backend/alembic/versions/grdbr04_add_global_content_updated_at.py index 94cba4b..330efbb 100644 --- a/backend/alembic/versions/grdbr04_add_global_content_updated_at.py +++ b/backend/alembic/versions/grdbr04_add_global_content_updated_at.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "grdbr04" @@ -31,9 +31,7 @@ def upgrade() -> None: ) # Set existing rows to use their updated_at value - op.execute( - "UPDATE grounding_files SET content_updated_at = updated_at WHERE content_updated_at IS NULL" - ) + op.execute("UPDATE grounding_files SET content_updated_at = updated_at WHERE content_updated_at IS NULL") # Make the column non-nullable with a default for new rows op.alter_column( diff --git a/backend/alembic/versions/grdbr05_add_last_synced_with_global_at.py b/backend/alembic/versions/grdbr05_add_last_synced_with_global_at.py index 563800e..5ca1bce 100644 --- a/backend/alembic/versions/grdbr05_add_last_synced_with_global_at.py +++ b/backend/alembic/versions/grdbr05_add_last_synced_with_global_at.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "grdbr05" diff --git a/backend/alembic/versions/grnotes01_add_grounding_note_versions.py b/backend/alembic/versions/grnotes01_add_grounding_note_versions.py index 0cbf21b..f6b25d8 100644 --- a/backend/alembic/versions/grnotes01_add_grounding_note_versions.py +++ b/backend/alembic/versions/grnotes01_add_grounding_note_versions.py @@ -5,15 +5,17 @@ Create Date: 2025-01-28 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op + # revision identifiers, used by Alembic. -revision: str = 'grnotes01' -down_revision: Union[str, None] = 'phc04' +revision: str = "grnotes01" +down_revision: Union[str, None] = "phc04" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,27 +23,27 @@ def upgrade() -> None: # Create grounding_note_versions table op.create_table( - 'grounding_note_versions', - sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False), - sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=False), - sa.Column('version', sa.Integer(), nullable=False), - sa.Column('content_markdown', sa.Text(), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), - sa.Column('edit_source', sa.String(length=50), nullable=True), - sa.Column('created_by', postgresql.UUID(as_uuid=True), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), - sa.ForeignKeyConstraint(['created_by'], ['users.id'], ondelete='SET NULL'), - sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('project_id', 'version', name='uq_grounding_note_version'), + "grounding_note_versions", + sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False), + sa.Column("version", sa.Integer(), nullable=False), + sa.Column("content_markdown", sa.Text(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"), + sa.Column("edit_source", sa.String(length=50), nullable=True), + sa.Column("created_by", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.ForeignKeyConstraint(["created_by"], ["users.id"], ondelete="SET NULL"), + sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("project_id", "version", 
name="uq_grounding_note_version"), ) # Create indexes - op.create_index('ix_grounding_note_versions_project_id', 'grounding_note_versions', ['project_id']) - op.create_index('ix_grounding_note_active', 'grounding_note_versions', ['project_id', 'is_active']) + op.create_index("ix_grounding_note_versions_project_id", "grounding_note_versions", ["project_id"]) + op.create_index("ix_grounding_note_active", "grounding_note_versions", ["project_id", "is_active"]) def downgrade() -> None: - op.drop_index('ix_grounding_note_active', table_name='grounding_note_versions') - op.drop_index('ix_grounding_note_versions_project_id', table_name='grounding_note_versions') - op.drop_table('grounding_note_versions') + op.drop_index("ix_grounding_note_active", table_name="grounding_note_versions") + op.drop_index("ix_grounding_note_versions_project_id", table_name="grounding_note_versions") + op.drop_table("grounding_note_versions") diff --git a/backend/alembic/versions/h7i8j9k0l1m2_add_vfs_metadata_table.py b/backend/alembic/versions/h7i8j9k0l1m2_add_vfs_metadata_table.py index 89afe14..0fb505e 100644 --- a/backend/alembic/versions/h7i8j9k0l1m2_add_vfs_metadata_table.py +++ b/backend/alembic/versions/h7i8j9k0l1m2_add_vfs_metadata_table.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "h7i8j9k0l1m2" diff --git a/backend/alembic/versions/h9i0j1k2l3m4_add_thread_item_summary_snapshots.py b/backend/alembic/versions/h9i0j1k2l3m4_add_thread_item_summary_snapshots.py index fb24400..cbfd0fb 100644 --- a/backend/alembic/versions/h9i0j1k2l3m4_add_thread_item_summary_snapshots.py +++ b/backend/alembic/versions/h9i0j1k2l3m4_add_thread_item_summary_snapshots.py @@ -5,11 +5,12 @@ Create Date: 2025-12-21 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "h9i0j1k2l3m4" diff --git a/backend/alembic/versions/hd3e4f5g6h7i_add_session_ai_error_fields.py b/backend/alembic/versions/hd3e4f5g6h7i_add_session_ai_error_fields.py index 0801e7c..26f9afa 100644 --- a/backend/alembic/versions/hd3e4f5g6h7i_add_session_ai_error_fields.py +++ b/backend/alembic/versions/hd3e4f5g6h7i_add_session_ai_error_fields.py @@ -5,26 +5,27 @@ Create Date: 2025-01-01 00:00:00.000000 """ -from alembic import op + import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
-revision = 'hd3e4f5g6h7i'
-down_revision = 'gc2d3e4f5g6h'
+revision = "hd3e4f5g6h7i"
+down_revision = "gc2d3e4f5g6h"
 branch_labels = None
 depends_on = None


 def upgrade() -> None:
     # Add AI error state columns for persisting error state across page refreshes
-    op.add_column('user_question_sessions', sa.Column('ai_error_message', sa.Text(), nullable=True))
-    op.add_column('user_question_sessions', sa.Column('ai_error_job_id', postgresql.UUID(as_uuid=True), nullable=True))
-    op.add_column('user_question_sessions', sa.Column('ai_error_user_prompt', sa.Text(), nullable=True))
+    op.add_column("user_question_sessions", sa.Column("ai_error_message", sa.Text(), nullable=True))
+    op.add_column("user_question_sessions", sa.Column("ai_error_job_id", postgresql.UUID(as_uuid=True), nullable=True))
+    op.add_column("user_question_sessions", sa.Column("ai_error_user_prompt", sa.Text(), nullable=True))


 def downgrade() -> None:
-    op.drop_column('user_question_sessions', 'ai_error_user_prompt')
-    op.drop_column('user_question_sessions', 'ai_error_job_id')
-    op.drop_column('user_question_sessions', 'ai_error_message')
+    op.drop_column("user_question_sessions", "ai_error_user_prompt")
+    op.drop_column("user_question_sessions", "ai_error_job_id")
+    op.drop_column("user_question_sessions", "ai_error_message")
diff --git a/backend/alembic/versions/i0j1k2l3m4n5_add_mcp_image_submissions.py b/backend/alembic/versions/i0j1k2l3m4n5_add_mcp_image_submissions.py
index 51cfd5e..988c592 100644
--- a/backend/alembic/versions/i0j1k2l3m4n5_add_mcp_image_submissions.py
+++ b/backend/alembic/versions/i0j1k2l3m4n5_add_mcp_image_submissions.py
@@ -5,15 +5,17 @@
 Create Date: 2025-12-22 12:00:00.000000

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op
+
 # revision identifiers, used by Alembic.
-revision: str = 'i0j1k2l3m4n5'
-down_revision: Union[str, None] = 'h9i0j1k2l3m4'
+revision: str = "i0j1k2l3m4n5"
+down_revision: Union[str, None] = "h9i0j1k2l3m4"
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None

@@ -21,38 +23,40 @@ def upgrade() -> None:
     """Create mcp_image_submissions table for staging VFS image uploads."""
     op.create_table(
-        'mcp_image_submissions',
-        sa.Column('id', postgresql.UUID(as_uuid=True), nullable=False),
-        sa.Column('submission_id', sa.String(length=100), nullable=False),
-        sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=False),
-        sa.Column('feature_id', postgresql.UUID(as_uuid=True), nullable=False),
-        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
-        sa.Column('filename', sa.String(length=255), nullable=False),
-        sa.Column('content_type', sa.String(length=50), nullable=False),
-        sa.Column('image_data', sa.LargeBinary(), nullable=False),
-        sa.Column('size_bytes', sa.Integer(), nullable=False),
-        sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
-        sa.Column('expires_at', sa.DateTime(timezone=True), nullable=False),
-        sa.ForeignKeyConstraint(['feature_id'], ['features.id'], ondelete='CASCADE'),
-        sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ondelete='CASCADE'),
-        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
-        sa.PrimaryKeyConstraint('id'),
+        "mcp_image_submissions",
+        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("submission_id", sa.String(length=100), nullable=False),
+        sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("feature_id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
+        sa.Column("filename", sa.String(length=255), nullable=False),
+        sa.Column("content_type", sa.String(length=50), nullable=False),
+        sa.Column("image_data", sa.LargeBinary(), nullable=False),
+        sa.Column("size_bytes", sa.Integer(), nullable=False),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+        sa.ForeignKeyConstraint(["feature_id"], ["features.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["project_id"], ["projects.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("id"),
     )

     # Indexes
-    op.create_index('ix_mcp_image_submissions_submission_id', 'mcp_image_submissions', ['submission_id'], unique=True)
-    op.create_index('ix_mcp_image_submissions_project_id', 'mcp_image_submissions', ['project_id'], unique=False)
-    op.create_index('ix_mcp_image_submissions_feature_id', 'mcp_image_submissions', ['feature_id'], unique=False)
-    op.create_index('ix_mcp_image_submissions_user_id', 'mcp_image_submissions', ['user_id'], unique=False)
-    op.create_index('ix_mcp_image_submissions_expires', 'mcp_image_submissions', ['expires_at'], unique=False)
-    op.create_index('ix_mcp_image_submissions_feature_user', 'mcp_image_submissions', ['feature_id', 'user_id'], unique=False)
+    op.create_index("ix_mcp_image_submissions_submission_id", "mcp_image_submissions", ["submission_id"], unique=True)
+    op.create_index("ix_mcp_image_submissions_project_id", "mcp_image_submissions", ["project_id"], unique=False)
+    op.create_index("ix_mcp_image_submissions_feature_id", "mcp_image_submissions", ["feature_id"], unique=False)
op.create_index("ix_mcp_image_submissions_user_id", "mcp_image_submissions", ["user_id"], unique=False) + op.create_index("ix_mcp_image_submissions_expires", "mcp_image_submissions", ["expires_at"], unique=False) + op.create_index( + "ix_mcp_image_submissions_feature_user", "mcp_image_submissions", ["feature_id", "user_id"], unique=False + ) def downgrade() -> None: """Drop mcp_image_submissions table.""" - op.drop_index('ix_mcp_image_submissions_feature_user', table_name='mcp_image_submissions') - op.drop_index('ix_mcp_image_submissions_expires', table_name='mcp_image_submissions') - op.drop_index('ix_mcp_image_submissions_user_id', table_name='mcp_image_submissions') - op.drop_index('ix_mcp_image_submissions_feature_id', table_name='mcp_image_submissions') - op.drop_index('ix_mcp_image_submissions_project_id', table_name='mcp_image_submissions') - op.drop_index('ix_mcp_image_submissions_submission_id', table_name='mcp_image_submissions') - op.drop_table('mcp_image_submissions') + op.drop_index("ix_mcp_image_submissions_feature_user", table_name="mcp_image_submissions") + op.drop_index("ix_mcp_image_submissions_expires", table_name="mcp_image_submissions") + op.drop_index("ix_mcp_image_submissions_user_id", table_name="mcp_image_submissions") + op.drop_index("ix_mcp_image_submissions_feature_id", table_name="mcp_image_submissions") + op.drop_index("ix_mcp_image_submissions_project_id", table_name="mcp_image_submissions") + op.drop_index("ix_mcp_image_submissions_submission_id", table_name="mcp_image_submissions") + op.drop_table("mcp_image_submissions") diff --git a/backend/alembic/versions/i8j9k0l1m2n3_add_team_roles_tables.py b/backend/alembic/versions/i8j9k0l1m2n3_add_team_roles_tables.py index d061f15..98b4636 100644 --- a/backend/alembic/versions/i8j9k0l1m2n3_add_team_roles_tables.py +++ b/backend/alembic/versions/i8j9k0l1m2n3_add_team_roles_tables.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "i8j9k0l1m2n3" diff --git a/backend/alembic/versions/icshare01_add_integration_config_sharing.py b/backend/alembic/versions/icshare01_add_integration_config_sharing.py index 36b8b7f..b826aef 100644 --- a/backend/alembic/versions/icshare01_add_integration_config_sharing.py +++ b/backend/alembic/versions/icshare01_add_integration_config_sharing.py @@ -12,10 +12,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import UUID +from alembic import op revision: str = "icshare01" down_revision: Union[str, Sequence[str], None] = "fd30aac0413b" diff --git a/backend/alembic/versions/ie4f5g6h7i8j_convert_trial_to_freemium.py b/backend/alembic/versions/ie4f5g6h7i8j_convert_trial_to_freemium.py index a6aef70..85814df 100644 --- a/backend/alembic/versions/ie4f5g6h7i8j_convert_trial_to_freemium.py +++ b/backend/alembic/versions/ie4f5g6h7i8j_convert_trial_to_freemium.py @@ -8,14 +8,14 @@ Create Date: 2025-01-01 00:00:00.000000 """ -from alembic import op + import sqlalchemy as sa -from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. -revision = 'ie4f5g6h7i8j' -down_revision = 'hd3e4f5g6h7i' +revision = "ie4f5g6h7i8j" +down_revision = "hd3e4f5g6h7i" branch_labels = None depends_on = None @@ -23,45 +23,45 @@ def upgrade() -> None: # 1. 
     op.add_column(
-        'platform_settings',
+        "platform_settings",
         sa.Column(
-            'freemium_initial_tokens',
+            "freemium_initial_tokens",
             sa.Integer(),
             nullable=False,
-            server_default='5000000',
-            comment='Initial tokens granted to new users on signup'
-        )
+            server_default="5000000",
+            comment="Initial tokens granted to new users on signup",
+        ),
     )
     op.add_column(
-        'platform_settings',
+        "platform_settings",
         sa.Column(
-            'freemium_weekly_topup_tokens',
+            "freemium_weekly_topup_tokens",
             sa.Integer(),
             nullable=False,
-            server_default='10000000',
-            comment='Tokens added each Monday (additive, up to max)'
-        )
+            server_default="10000000",
+            comment="Tokens added each Monday (additive, up to max)",
+        ),
     )
     op.add_column(
-        'platform_settings',
+        "platform_settings",
         sa.Column(
-            'freemium_max_tokens',
+            "freemium_max_tokens",
             sa.Integer(),
             nullable=False,
-            server_default='10000000',
-            comment='Maximum token balance for freemium users'
-        )
+            server_default="10000000",
+            comment="Maximum token balance for freemium users",
+        ),
     )

     # 2. Add last top-up tracking to organizations
     op.add_column(
-        'organizations',
+        "organizations",
         sa.Column(
-            'last_freemium_topup_at',
+            "last_freemium_topup_at",
             sa.DateTime(timezone=True),
             nullable=True,
-            comment='When the org last received a weekly freemium token top-up'
-        )
+            comment="When the org last received a weekly freemium token top-up",
+        ),
     )

     # 3. Convert existing trial orgs to freemium
@@ -87,7 +87,7 @@ def downgrade() -> None:
     """)

     # Remove columns
-    op.drop_column('organizations', 'last_freemium_topup_at')
-    op.drop_column('platform_settings', 'freemium_max_tokens')
-    op.drop_column('platform_settings', 'freemium_weekly_topup_tokens')
-    op.drop_column('platform_settings', 'freemium_initial_tokens')
+    op.drop_column("organizations", "last_freemium_topup_at")
+    op.drop_column("platform_settings", "freemium_max_tokens")
+    op.drop_column("platform_settings", "freemium_weekly_topup_tokens")
+    op.drop_column("platform_settings", "freemium_initial_tokens")
diff --git a/backend/alembic/versions/impl01_add_implementations_table.py b/backend/alembic/versions/impl01_add_implementations_table.py
index aaee021..2ce7ca5 100644
--- a/backend/alembic/versions/impl01_add_implementations_table.py
+++ b/backend/alembic/versions/impl01_add_implementations_table.py
@@ -6,14 +6,14 @@

 """

-from typing import Sequence, Union
 import uuid
 from datetime import datetime, timezone
+from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "impl01"
@@ -134,7 +134,7 @@ def upgrade() -> None:
                 "notes_updated_at": notes_updated_at,
                 "created_by": str(created_by) if created_by else None,
                 "created_at": created_at or datetime.now(timezone.utc),
-            }
+            },
         )
diff --git a/backend/alembic/versions/impl02_add_implementation_created_enum.py b/backend/alembic/versions/impl02_add_implementation_created_enum.py
index 0873c34..001ac65 100644
--- a/backend/alembic/versions/impl02_add_implementation_created_enum.py
+++ b/backend/alembic/versions/impl02_add_implementation_created_enum.py
@@ -5,11 +5,11 @@
 Create Date: 2025-01-02

 """
+
 from typing import Sequence, Union

 from alembic import op

-
 # revision identifiers, used by Alembic.
revision: str = "impl02" down_revision: Union[str, None] = "impl01" diff --git a/backend/alembic/versions/implcs01_add_completion_summary_to_implementations.py b/backend/alembic/versions/implcs01_add_completion_summary_to_implementations.py index 6b47566..485fc0b 100644 --- a/backend/alembic/versions/implcs01_add_completion_summary_to_implementations.py +++ b/backend/alembic/versions/implcs01_add_completion_summary_to_implementations.py @@ -5,23 +5,21 @@ Create Date: 2026-01-09 """ -from alembic import op + import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = 'implcs01' -down_revision = '8eebe3a100d3' +revision = "implcs01" +down_revision = "8eebe3a100d3" branch_labels = None depends_on = None def upgrade() -> None: - op.add_column( - 'implementations', - sa.Column('completion_summary', sa.Text(), nullable=True) - ) + op.add_column("implementations", sa.Column("completion_summary", sa.Text(), nullable=True)) def downgrade() -> None: - op.drop_column('implementations', 'completion_summary') + op.drop_column("implementations", "completion_summary") diff --git a/backend/alembic/versions/j1k2l3m4n5o6_add_form_drafts_table.py b/backend/alembic/versions/j1k2l3m4n5o6_add_form_drafts_table.py index ad2a16c..1ef51af 100644 --- a/backend/alembic/versions/j1k2l3m4n5o6_add_form_drafts_table.py +++ b/backend/alembic/versions/j1k2l3m4n5o6_add_form_drafts_table.py @@ -5,12 +5,14 @@ Create Date: 2025-12-23 12:00:00.000000 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op + # revision identifiers, used by Alembic. revision: str = "j1k2l3m4n5o6" down_revision: Union[str, None] = "i0j1k2l3m4n5" diff --git a/backend/alembic/versions/j9k0l1m2n3o4_make_team_roles_dynamic.py b/backend/alembic/versions/j9k0l1m2n3o4_make_team_roles_dynamic.py index 04c8c70..caef178 100644 --- a/backend/alembic/versions/j9k0l1m2n3o4_make_team_roles_dynamic.py +++ b/backend/alembic/versions/j9k0l1m2n3o4_make_team_roles_dynamic.py @@ -9,11 +9,10 @@ from typing import Sequence, Union from uuid import uuid4 -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql -from sqlalchemy.sql import table, column +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "j9k0l1m2n3o4" @@ -59,18 +58,14 @@ def upgrade() -> None: for role_key, title, description, order_index in DEFAULT_ROLES: # Check if this role already exists for this org existing = conn.execute( - sa.text( - "SELECT id FROM team_role_definitions WHERE org_id = :org_id AND role_key = :role_key" - ), + sa.text("SELECT id FROM team_role_definitions WHERE org_id = :org_id AND role_key = :role_key"), {"org_id": org_id, "role_key": role_key}, ).fetchone() if existing: # Mark existing as default conn.execute( - sa.text( - "UPDATE team_role_definitions SET is_default = true WHERE id = :id" - ), + sa.text("UPDATE team_role_definitions SET is_default = true WHERE id = :id"), {"id": existing[0]}, ) else: diff --git a/backend/alembic/versions/k0l1m2n3o4p5_add_grounding_files_table.py b/backend/alembic/versions/k0l1m2n3o4p5_add_grounding_files_table.py index bc2acef..95e3fd7 100644 --- a/backend/alembic/versions/k0l1m2n3o4p5_add_grounding_files_table.py +++ b/backend/alembic/versions/k0l1m2n3o4p5_add_grounding_files_table.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "k0l1m2n3o4p5" diff --git a/backend/alembic/versions/l1m2n3o4p5q6_add_feature_content_versions.py b/backend/alembic/versions/l1m2n3o4p5q6_add_feature_content_versions.py index 75eeb0e..57bef10 100644 --- a/backend/alembic/versions/l1m2n3o4p5q6_add_feature_content_versions.py +++ b/backend/alembic/versions/l1m2n3o4p5q6_add_feature_content_versions.py @@ -6,14 +6,14 @@ """ -from typing import Sequence, Union import uuid from datetime import datetime, timezone +from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "l1m2n3o4p5q6" @@ -25,7 +25,9 @@ def upgrade() -> None: # Create the enum type content_type_enum = postgresql.ENUM( - "spec", "prompt_plan", "implementation_notes", + "spec", + "prompt_plan", + "implementation_notes", name="featurecontenttype", create_type=False, ) @@ -119,7 +121,7 @@ def upgrade() -> None: "content": spec_text, "created_by": str(created_by) if created_by else None, "created_at": created_at or datetime.now(timezone.utc), - } + }, ) if prompt_plan_text: @@ -135,7 +137,7 @@ def upgrade() -> None: "content": prompt_plan_text, "created_by": str(created_by) if created_by else None, "created_at": created_at or datetime.now(timezone.utc), - } + }, ) if implementation_notes: @@ -151,7 +153,7 @@ def upgrade() -> None: "content": implementation_notes, "created_by": str(created_by) if created_by else None, "created_at": created_at or datetime.now(timezone.utc), - } + }, ) diff --git a/backend/alembic/versions/m2n3o4p5q6r7_add_feature_import_fields.py b/backend/alembic/versions/m2n3o4p5q6r7_add_feature_import_fields.py index 90b00de..8cd2bc5 100644 --- a/backend/alembic/versions/m2n3o4p5q6r7_add_feature_import_fields.py +++ b/backend/alembic/versions/m2n3o4p5q6r7_add_feature_import_fields.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "m2n3o4p5q6r7" diff --git a/backend/alembic/versions/mcqans01_add_mcq_answer_item_type.py b/backend/alembic/versions/mcqans01_add_mcq_answer_item_type.py index 507ad51..b18a6fe 100644 --- a/backend/alembic/versions/mcqans01_add_mcq_answer_item_type.py +++ b/backend/alembic/versions/mcqans01_add_mcq_answer_item_type.py @@ -5,11 +5,11 @@ Create Date: 2024-12-29 """ + from typing import Sequence, Union from alembic import op - # revision identifiers, used by Alembic. revision: str = "mcqans01" down_revision: Union[str, None] = "ppd04" diff --git a/backend/alembic/versions/n3o4p5q6r7s8_add_identity_provider_tables.py b/backend/alembic/versions/n3o4p5q6r7s8_add_identity_provider_tables.py index 1f60a4c..78ddb87 100644 --- a/backend/alembic/versions/n3o4p5q6r7s8_add_identity_provider_tables.py +++ b/backend/alembic/versions/n3o4p5q6r7s8_add_identity_provider_tables.py @@ -8,10 +8,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "n3o4p5q6r7s8" diff --git a/backend/alembic/versions/o4p5q6r7s8t9_seed_identity_providers.py b/backend/alembic/versions/o4p5q6r7s8t9_seed_identity_providers.py index ca3e86e..cc2eaf1 100644 --- a/backend/alembic/versions/o4p5q6r7s8t9_seed_identity_providers.py +++ b/backend/alembic/versions/o4p5q6r7s8t9_seed_identity_providers.py @@ -5,10 +5,12 @@ Create Date: 2025-01-20 """ -from alembic import op -import sqlalchemy as sa -from datetime import datetime, timezone import uuid +from datetime import datetime, timezone + +import sqlalchemy as sa + +from alembic import op revision = "o4p5q6r7s8t9" down_revision = "n3o4p5q6r7s8" diff --git a/backend/alembic/versions/p5q6r7s8t9u0_add_platform_settings_tables.py b/backend/alembic/versions/p5q6r7s8t9u0_add_platform_settings_tables.py index 2289775..3838dbb 100644 --- a/backend/alembic/versions/p5q6r7s8t9u0_add_platform_settings_tables.py +++ b/backend/alembic/versions/p5q6r7s8t9u0_add_platform_settings_tables.py @@ -5,11 +5,13 @@ Create Date: 2025-01-20 """ -from alembic import op +import uuid +from datetime import datetime, timezone + import sqlalchemy as sa from sqlalchemy.dialects import postgresql -from datetime import datetime, timezone -import uuid + +from alembic import op revision = "p5q6r7s8t9u0" down_revision = "o4p5q6r7s8t9" @@ -40,9 +42,7 @@ def upgrade() -> None: nullable=False, default=lambda: datetime.now(timezone.utc), ), - sa.UniqueConstraint( - "connector_type", "provider", "display_name", name="uq_platform_connector_name" - ), + sa.UniqueConstraint("connector_type", "provider", "display_name", name="uq_platform_connector_name"), ) # Create platform_settings table (singleton) diff --git a/backend/alembic/versions/pc01_rename_pre_phase_to_project_chat.py b/backend/alembic/versions/pc01_rename_pre_phase_to_project_chat.py index d30a8e4..79a2b1e 100644 --- a/backend/alembic/versions/pc01_rename_pre_phase_to_project_chat.py +++ b/backend/alembic/versions/pc01_rename_pre_phase_to_project_chat.py @@ -15,7 +15,6 @@ from alembic import op - # revision identifiers, used by Alembic. 
revision: str = "pc01" down_revision: Union[str, None] = "ghoauth02" @@ -39,31 +38,20 @@ def upgrade() -> None: ) # Rename indexes on project_chats (formerly pre_phase_discussions) + op.execute("ALTER INDEX IF EXISTS ix_pre_phase_discussions_org_id RENAME TO ix_project_chats_org_id") + op.execute("ALTER INDEX IF EXISTS ix_pre_phase_discussions_project_id RENAME TO ix_project_chats_project_id") op.execute( - "ALTER INDEX IF EXISTS ix_pre_phase_discussions_org_id " - "RENAME TO ix_project_chats_org_id" - ) - op.execute( - "ALTER INDEX IF EXISTS ix_pre_phase_discussions_project_id " - "RENAME TO ix_project_chats_project_id" - ) - op.execute( - "ALTER INDEX IF EXISTS ix_pre_phase_discussions_created_phase_id " - "RENAME TO ix_project_chats_created_phase_id" + "ALTER INDEX IF EXISTS ix_pre_phase_discussions_created_phase_id RENAME TO ix_project_chats_created_phase_id" ) op.execute( "ALTER INDEX IF EXISTS ix_pre_phase_discussions_created_project_id " "RENAME TO ix_project_chats_created_project_id" ) - op.execute( - "ALTER INDEX IF EXISTS ix_pre_phase_discussions_short_id " - "RENAME TO ix_project_chats_short_id" - ) + op.execute("ALTER INDEX IF EXISTS ix_pre_phase_discussions_short_id RENAME TO ix_project_chats_short_id") # Rename indexes on project_chat_messages (formerly pre_phase_messages) op.execute( - "ALTER INDEX IF EXISTS ix_pre_phase_messages_discussion_id " - "RENAME TO ix_project_chat_messages_project_chat_id" + "ALTER INDEX IF EXISTS ix_pre_phase_messages_discussion_id RENAME TO ix_project_chat_messages_project_chat_id" ) # Rename index on code_exploration_results @@ -96,31 +84,20 @@ def downgrade() -> None: # Rename indexes on project_chat_messages back op.execute( - "ALTER INDEX IF EXISTS ix_project_chat_messages_project_chat_id " - "RENAME TO ix_pre_phase_messages_discussion_id" + "ALTER INDEX IF EXISTS ix_project_chat_messages_project_chat_id RENAME TO ix_pre_phase_messages_discussion_id" ) # Rename indexes on project_chats back - op.execute( - "ALTER INDEX IF EXISTS ix_project_chats_short_id " - "RENAME TO ix_pre_phase_discussions_short_id" - ) + op.execute("ALTER INDEX IF EXISTS ix_project_chats_short_id RENAME TO ix_pre_phase_discussions_short_id") op.execute( "ALTER INDEX IF EXISTS ix_project_chats_created_project_id " "RENAME TO ix_pre_phase_discussions_created_project_id" ) op.execute( - "ALTER INDEX IF EXISTS ix_project_chats_created_phase_id " - "RENAME TO ix_pre_phase_discussions_created_phase_id" - ) - op.execute( - "ALTER INDEX IF EXISTS ix_project_chats_project_id " - "RENAME TO ix_pre_phase_discussions_project_id" - ) - op.execute( - "ALTER INDEX IF EXISTS ix_project_chats_org_id " - "RENAME TO ix_pre_phase_discussions_org_id" + "ALTER INDEX IF EXISTS ix_project_chats_created_phase_id RENAME TO ix_pre_phase_discussions_created_phase_id" ) + op.execute("ALTER INDEX IF EXISTS ix_project_chats_project_id RENAME TO ix_pre_phase_discussions_project_id") + op.execute("ALTER INDEX IF EXISTS ix_project_chats_org_id RENAME TO ix_pre_phase_discussions_org_id") # Rename foreign key column in code_exploration_results back op.alter_column( diff --git a/backend/alembic/versions/pc02_add_retry_status_to_project_chat.py b/backend/alembic/versions/pc02_add_retry_status_to_project_chat.py index bcaf091..3ea9d25 100644 --- a/backend/alembic/versions/pc02_add_retry_status_to_project_chat.py +++ b/backend/alembic/versions/pc02_add_retry_status_to_project_chat.py @@ -11,9 +11,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # 

 # revision identifiers, used by Alembic.
 revision: str = "pc02"
@@ -30,7 +30,7 @@ def upgrade() -> None:
             "retry_status",
             sa.String(100),
             nullable=True,
-            comment="Current retry status message for UI display (e.g., 'Invalid response. Retrying 2/3')"
+            comment="Current retry status message for UI display (e.g., 'Invalid response. Retrying 2/3')",
         ),
     )
diff --git a/backend/alembic/versions/phc01_add_phase_containers_table.py b/backend/alembic/versions/phc01_add_phase_containers_table.py
index 06fd0d7..dae84e4 100644
--- a/backend/alembic/versions/phc01_add_phase_containers_table.py
+++ b/backend/alembic/versions/phc01_add_phase_containers_table.py
@@ -11,9 +11,9 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "phc01"
diff --git a/backend/alembic/versions/phc02_add_container_fields_to_phases.py b/backend/alembic/versions/phc02_add_container_fields_to_phases.py
index e2080ff..e05e89d 100644
--- a/backend/alembic/versions/phc02_add_container_fields_to_phases.py
+++ b/backend/alembic/versions/phc02_add_container_fields_to_phases.py
@@ -11,9 +11,9 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "phc02"
diff --git a/backend/alembic/versions/phc03_migrate_phases_to_containers.py b/backend/alembic/versions/phc03_migrate_phases_to_containers.py
index c974136..fc9bfc4 100644
--- a/backend/alembic/versions/phc03_migrate_phases_to_containers.py
+++ b/backend/alembic/versions/phc03_migrate_phases_to_containers.py
@@ -9,13 +9,13 @@

 """

+from datetime import datetime, timezone
 from typing import Sequence, Union
 from uuid import uuid4
-from datetime import datetime, timezone

-from alembic import op
 import sqlalchemy as sa

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "phc03"
@@ -28,8 +28,9 @@ def generate_short_id(length: int = 11) -> str:
     """Generate a short ID for URL-friendly identifiers."""
     import random
     import string
+
     chars = string.ascii_letters + string.digits
-    return ''.join(random.choice(chars) for _ in range(length))
+    return "".join(random.choice(chars) for _ in range(length))


 def upgrade() -> None:
diff --git a/backend/alembic/versions/phc04_add_target_container_to_project_chats.py b/backend/alembic/versions/phc04_add_target_container_to_project_chats.py
index b1701ea..df7dd30 100644
--- a/backend/alembic/versions/phc04_add_target_container_to_project_chats.py
+++ b/backend/alembic/versions/phc04_add_target_container_to_project_chats.py
@@ -11,9 +11,9 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "phc04"
diff --git a/backend/alembic/versions/ppc01_add_visibility_to_project_chats.py b/backend/alembic/versions/ppc01_add_visibility_to_project_chats.py
index caa4b19..7bcc9be 100644
--- a/backend/alembic/versions/ppc01_add_visibility_to_project_chats.py
+++ b/backend/alembic/versions/ppc01_add_visibility_to_project_chats.py
@@ -14,10 +14,10 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
revision: str = "ppc01" @@ -29,7 +29,8 @@ def upgrade() -> None: # Create the visibility enum type visibility_enum = postgresql.ENUM( - "private", "team", + "private", + "team", name="projectchatvisibility", create_type=False, ) diff --git a/backend/alembic/versions/ppd01_add_pre_phase_discussions.py b/backend/alembic/versions/ppd01_add_pre_phase_discussions.py index 06716bb..2982685 100644 --- a/backend/alembic/versions/ppd01_add_pre_phase_discussions.py +++ b/backend/alembic/versions/ppd01_add_pre_phase_discussions.py @@ -9,12 +9,13 @@ Create Date: 2024-12-27 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "ppd01" @@ -26,7 +27,8 @@ def upgrade() -> None: # Create message_type enum for pre_phase_messages message_type_enum = postgresql.ENUM( - "user", "bot", + "user", + "bot", name="prephasemessagetype", create_type=False, ) diff --git a/backend/alembic/versions/ppd02_add_discussion_phase_link.py b/backend/alembic/versions/ppd02_add_discussion_phase_link.py index f321450..a3b8254 100644 --- a/backend/alembic/versions/ppd02_add_discussion_phase_link.py +++ b/backend/alembic/versions/ppd02_add_discussion_phase_link.py @@ -10,12 +10,13 @@ Create Date: 2024-12-28 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "ppd02" diff --git a/backend/alembic/versions/ppd03_add_pre_phase_feature_fields.py b/backend/alembic/versions/ppd03_add_pre_phase_feature_fields.py index ef75203..b8f4e03 100644 --- a/backend/alembic/versions/ppd03_add_pre_phase_feature_fields.py +++ b/backend/alembic/versions/ppd03_add_pre_phase_feature_fields.py @@ -11,12 +11,13 @@ Create Date: 2024-12-28 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "ppd03" diff --git a/backend/alembic/versions/ppd04_add_org_scoped_discussions.py b/backend/alembic/versions/ppd04_add_org_scoped_discussions.py index 9f7f90b..e1621f1 100644 --- a/backend/alembic/versions/ppd04_add_org_scoped_discussions.py +++ b/backend/alembic/versions/ppd04_add_org_scoped_discussions.py @@ -13,12 +13,13 @@ Create Date: 2024-12-28 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "ppd04" diff --git a/backend/alembic/versions/ppd05_add_image_attachments.py b/backend/alembic/versions/ppd05_add_image_attachments.py index 0310156..51aab23 100644 --- a/backend/alembic/versions/ppd05_add_image_attachments.py +++ b/backend/alembic/versions/ppd05_add_image_attachments.py @@ -13,12 +13,13 @@ Create Date: 2024-12-30 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "ppd05" diff --git a/backend/alembic/versions/ppd06_add_chat_title.py b/backend/alembic/versions/ppd06_add_chat_title.py index e6882f7..d22845b 100644 --- a/backend/alembic/versions/ppd06_add_chat_title.py +++ b/backend/alembic/versions/ppd06_add_chat_title.py @@ -9,11 +9,12 @@ Create Date: 2025-01-04 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "ppd06" diff --git a/backend/alembic/versions/ppd08_add_summary_snapshot.py b/backend/alembic/versions/ppd08_add_summary_snapshot.py index 221858b..fb7f62d 100644 --- a/backend/alembic/versions/ppd08_add_summary_snapshot.py +++ b/backend/alembic/versions/ppd08_add_summary_snapshot.py @@ -6,9 +6,9 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = "ppd08" diff --git a/backend/alembic/versions/pts01_add_project_tech_stack.py b/backend/alembic/versions/pts01_add_project_tech_stack.py index 201f72e..6d00f5f 100644 --- a/backend/alembic/versions/pts01_add_project_tech_stack.py +++ b/backend/alembic/versions/pts01_add_project_tech_stack.py @@ -10,9 +10,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op revision: str = "pts01" down_revision: Union[str, Sequence[str], None] = "ghoauth01" diff --git a/backend/alembic/versions/q6r7s8t9u0v1_add_base_url_to_platform_settings.py b/backend/alembic/versions/q6r7s8t9u0v1_add_base_url_to_platform_settings.py index dde1c7e..4242ad7 100644 --- a/backend/alembic/versions/q6r7s8t9u0v1_add_base_url_to_platform_settings.py +++ b/backend/alembic/versions/q6r7s8t9u0v1_add_base_url_to_platform_settings.py @@ -5,9 +5,10 @@ Create Date: 2025-01-20 """ -from alembic import op import sqlalchemy as sa +from alembic import op + revision = "q6r7s8t9u0v1" down_revision = "p5q6r7s8t9u0" branch_labels = None diff --git a/backend/alembic/versions/qrs02_add_grounding_file_is_generating.py b/backend/alembic/versions/qrs02_add_grounding_file_is_generating.py index ba4c6a5..21921c4 100644 --- a/backend/alembic/versions/qrs02_add_grounding_file_is_generating.py +++ b/backend/alembic/versions/qrs02_add_grounding_file_is_generating.py @@ -6,9 +6,9 @@ """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision = "qrs02" diff --git a/backend/alembic/versions/qst01phase_add_phase_question_stats.py b/backend/alembic/versions/qst01phase_add_phase_question_stats.py index 850d083..3c4c319 100644 --- a/backend/alembic/versions/qst01phase_add_phase_question_stats.py +++ b/backend/alembic/versions/qst01phase_add_phase_question_stats.py @@ -5,11 +5,12 @@ Create Date: 2024-12-26 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "qst01phase" @@ -53,14 +54,13 @@ def upgrade() -> None: connection = op.get_bind() # Get all phases - phases = connection.execute( - sa.text("SELECT id FROM brainstorming_phases") - ).fetchall() + phases = connection.execute(sa.text("SELECT id FROM brainstorming_phases")).fetchall() for (phase_id,) in phases: # Count active questions using a subquery for module_ids - active_total = connection.execute( - sa.text(""" + active_total = ( + connection.execute( + sa.text(""" SELECT COUNT(*) FROM features f WHERE f.module_id IN ( SELECT m.id FROM modules m @@ -71,12 +71,15 @@ def upgrade() -> None: AND f.visibility_status = 'active' AND f.archived_at IS NULL """), - {"phase_id": phase_id} - ).scalar() or 0 + {"phase_id": phase_id}, + ).scalar() + or 0 + ) # Count pending questions - pending_total = connection.execute( - sa.text(""" + pending_total = ( + connection.execute( + sa.text(""" SELECT COUNT(*) FROM features f WHERE f.module_id IN ( SELECT m.id FROM modules m @@ -87,12 +90,15 @@ def upgrade() -> None: AND f.visibility_status = 'pending' AND f.archived_at IS NULL """), - {"phase_id": phase_id} - ).scalar() or 0 + {"phase_id": phase_id}, + ).scalar() + or 0 + ) # Count answered questions (have MCQ with selected_option_id set) - active_answered = connection.execute( - sa.text(""" + active_answered = ( + connection.execute( + sa.text(""" SELECT COUNT(DISTINCT t.context_id) FROM threads t JOIN thread_items ti ON ti.thread_id = t.id @@ -111,8 +117,10 @@ def upgrade() -> None: AND ti.item_type = 'mcq_followup' AND ti.content_data->>'selected_option_id' IS NOT NULL """), - {"phase_id": phase_id} - ).scalar() or 0 + {"phase_id": phase_id}, + ).scalar() + or 0 + ) # Update the phase connection.execute( @@ -128,7 +136,7 @@ def upgrade() -> None: "active_answered": active_answered, "active_total": active_total, "pending_total": pending_total, - } + }, ) diff --git a/backend/alembic/versions/r7s8t9u0v1w2_add_invitations_and_groups.py b/backend/alembic/versions/r7s8t9u0v1w2_add_invitations_and_groups.py index a653526..b4bd35f 100644 --- a/backend/alembic/versions/r7s8t9u0v1w2_add_invitations_and_groups.py +++ b/backend/alembic/versions/r7s8t9u0v1w2_add_invitations_and_groups.py @@ -12,10 +12,11 @@ Create Date: 2025-01-20 """ -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSON, UUID +from alembic import op + revision = "r7s8t9u0v1w2" down_revision = "q6r7s8t9u0v1" branch_labels = None @@ -158,9 +159,7 @@ def upgrade() -> None: nullable=False, server_default=sa.func.now(), ), - sa.UniqueConstraint( - "project_id", "subject_type", "subject_id", name="uq_project_subject" - ), + sa.UniqueConstraint("project_id", "subject_type", "subject_id", name="uq_project_subject"), ) # 5. 
@@ -202,9 +201,7 @@ def upgrade() -> None:
         sa.Column("role", sa.String(20), nullable=False),
         sa.Column("token", sa.String(64), nullable=False, unique=True, index=True),
         sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
-        sa.Column(
-            "status", sa.String(20), nullable=False, server_default="pending"
-        ),
+        sa.Column("status", sa.String(20), nullable=False, server_default="pending"),
         sa.Column(
             "accepted_by_user_id",
             UUID(as_uuid=True),
diff --git a/backend/alembic/versions/rec01_plan_recommendations.py b/backend/alembic/versions/rec01_plan_recommendations.py
index a6fd76a..7a3bb41 100644
--- a/backend/alembic/versions/rec01_plan_recommendations.py
+++ b/backend/alembic/versions/rec01_plan_recommendations.py
@@ -11,10 +11,10 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "rec01"
diff --git a/backend/alembic/versions/repo01_add_project_repositories.py b/backend/alembic/versions/repo01_add_project_repositories.py
index 56c7051..1c4f507 100644
--- a/backend/alembic/versions/repo01_add_project_repositories.py
+++ b/backend/alembic/versions/repo01_add_project_repositories.py
@@ -11,13 +11,14 @@
 Create Date: 2026-01-16

 """
+
 import json
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "repo01"
@@ -98,6 +99,7 @@ def upgrade() -> None:
         slug = url_path.split("/")[-1].lower() if "/" in url_path else "repo"
         # Sanitize: replace non-alphanumeric with hyphens
         import re
+
         slug = re.sub(r"[^a-zA-Z0-9_-]", "-", slug)
         slug = re.sub(r"-+", "-", slug).strip("-") or "repo"
diff --git a/backend/alembic/versions/s8t9u0v1w2x3_add_current_org_id_to_users.py b/backend/alembic/versions/s8t9u0v1w2x3_add_current_org_id_to_users.py
index a4053c3..b045517 100644
--- a/backend/alembic/versions/s8t9u0v1w2x3_add_current_org_id_to_users.py
+++ b/backend/alembic/versions/s8t9u0v1w2x3_add_current_org_id_to_users.py
@@ -8,9 +8,9 @@

 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "s8t9u0v1w2x3"
diff --git a/backend/alembic/versions/sid01_add_short_ids.py b/backend/alembic/versions/sid01_add_short_ids.py
index 131cba4..f3359f4 100644
--- a/backend/alembic/versions/sid01_add_short_ids.py
+++ b/backend/alembic/versions/sid01_add_short_ids.py
@@ -18,10 +18,9 @@
 import string
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
 revision: str = "sid01"
diff --git a/backend/alembic/versions/slack01_add_slack_tables.py b/backend/alembic/versions/slack01_add_slack_tables.py
index 4e5cfec..ae63d6d 100644
--- a/backend/alembic/versions/slack01_add_slack_tables.py
+++ b/backend/alembic/versions/slack01_add_slack_tables.py
@@ -5,12 +5,13 @@
 Create Date: 2026-02-05

 """
+
 from typing import Sequence, Union

-from alembic import op
 import sqlalchemy as sa
 from sqlalchemy.dialects import postgresql

+from alembic import op

 # revision identifiers, used by Alembic.
revision: str = "slack01" diff --git a/backend/alembic/versions/slack02_add_oauth_tables.py b/backend/alembic/versions/slack02_add_oauth_tables.py index a0f376a..5cc449c 100644 --- a/backend/alembic/versions/slack02_add_oauth_tables.py +++ b/backend/alembic/versions/slack02_add_oauth_tables.py @@ -8,11 +8,12 @@ Create Date: 2026-02-05 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "slack02" diff --git a/backend/alembic/versions/sys01_add_system_thread_item_type.py b/backend/alembic/versions/sys01_add_system_thread_item_type.py index f9d6bc6..6a33e5d 100644 --- a/backend/alembic/versions/sys01_add_system_thread_item_type.py +++ b/backend/alembic/versions/sys01_add_system_thread_item_type.py @@ -5,11 +5,11 @@ Create Date: 2025-01-13 """ + from typing import Sequence, Union from alembic import op - # revision identifiers, used by Alembic. revision: str = "sys01" down_revision: Union[str, None] = "cexp07" diff --git a/backend/alembic/versions/t9u0v1w2x3y4_add_thread_decision_summary.py b/backend/alembic/versions/t9u0v1w2x3y4_add_thread_decision_summary.py index 7d5d147..d297894 100644 --- a/backend/alembic/versions/t9u0v1w2x3y4_add_thread_decision_summary.py +++ b/backend/alembic/versions/t9u0v1w2x3y4_add_thread_decision_summary.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "t9u0v1w2x3y4" diff --git a/backend/alembic/versions/tep01_add_exploration_prompt_search_query.py b/backend/alembic/versions/tep01_add_exploration_prompt_search_query.py index d23420a..157a290 100644 --- a/backend/alembic/versions/tep01_add_exploration_prompt_search_query.py +++ b/backend/alembic/versions/tep01_add_exploration_prompt_search_query.py @@ -11,9 +11,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "tep01" diff --git a/backend/alembic/versions/th01_add_retry_status_to_threads.py b/backend/alembic/versions/th01_add_retry_status_to_threads.py index aa4d9ec..568b862 100644 --- a/backend/alembic/versions/th01_add_retry_status_to_threads.py +++ b/backend/alembic/versions/th01_add_retry_status_to_threads.py @@ -11,9 +11,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "th01" diff --git a/backend/alembic/versions/u0v1w2x3y4z5_add_proactive_conversation_features.py b/backend/alembic/versions/u0v1w2x3y4z5_add_proactive_conversation_features.py index bf23a1a..7d8d774 100644 --- a/backend/alembic/versions/u0v1w2x3y4z5_add_proactive_conversation_features.py +++ b/backend/alembic/versions/u0v1w2x3y4z5_add_proactive_conversation_features.py @@ -12,10 +12,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "u0v1w2x3y4z5" @@ -27,7 +27,8 @@ def upgrade() -> None: # Create feature_visibility_status enum visibility_status_enum = postgresql.ENUM( - "pending", "active", + "pending", + "active", name="featurevisibilitystatus", create_type=False, ) @@ -46,7 +47,9 @@ def upgrade() -> None: # Create trigger_type enum trigger_type_enum = postgresql.ENUM( - "phase_created", "mcq_answered", "user_comment", + "phase_created", + "mcq_answered", + "user_comment", name="triggertype", create_type=False, ) diff --git a/backend/alembic/versions/uchat01_add_project_chat_fields_to_threads.py b/backend/alembic/versions/uchat01_add_project_chat_fields_to_threads.py index 473fd70..58af7f0 100644 --- a/backend/alembic/versions/uchat01_add_project_chat_fields_to_threads.py +++ b/backend/alembic/versions/uchat01_add_project_chat_fields_to_threads.py @@ -12,9 +12,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "uchat01" diff --git a/backend/alembic/versions/v1w2x3y4z5a6_add_user_question_sessions.py b/backend/alembic/versions/v1w2x3y4z5a6_add_user_question_sessions.py index 230bc92..20619cd 100644 --- a/backend/alembic/versions/v1w2x3y4z5a6_add_user_question_sessions.py +++ b/backend/alembic/versions/v1w2x3y4z5a6_add_user_question_sessions.py @@ -12,10 +12,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. revision: str = "v1w2x3y4z5a6" @@ -27,7 +27,8 @@ def upgrade() -> None: # Create session_status enum session_status_enum = postgresql.ENUM( - "active", "archived", + "active", + "archived", name="userquestionsessionstatus", create_type=False, ) @@ -35,7 +36,8 @@ def upgrade() -> None: # Create message_role enum message_role_enum = postgresql.ENUM( - "user", "assistant", + "user", + "assistant", name="messagerole", create_type=False, ) diff --git a/backend/alembic/versions/w2x3y4z5a6b7_remove_thread_followup_columns.py b/backend/alembic/versions/w2x3y4z5a6b7_remove_thread_followup_columns.py index 0954b3a..d8adb18 100644 --- a/backend/alembic/versions/w2x3y4z5a6b7_remove_thread_followup_columns.py +++ b/backend/alembic/versions/w2x3y4z5a6b7_remove_thread_followup_columns.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "w2x3y4z5a6b7" diff --git a/backend/alembic/versions/wsrch01_add_web_search_support.py b/backend/alembic/versions/wsrch01_add_web_search_support.py index 83a86ae..55de82a 100644 --- a/backend/alembic/versions/wsrch01_add_web_search_support.py +++ b/backend/alembic/versions/wsrch01_add_web_search_support.py @@ -14,10 +14,10 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql +from alembic import op # revision identifiers, used by Alembic. 
revision: str = "wsrch01" diff --git a/backend/alembic/versions/x3y4z5a6b7c8_add_feature_description.py b/backend/alembic/versions/x3y4z5a6b7c8_add_feature_description.py index a392fc4..2673ec2 100644 --- a/backend/alembic/versions/x3y4z5a6b7c8_add_feature_description.py +++ b/backend/alembic/versions/x3y4z5a6b7c8_add_feature_description.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "x3y4z5a6b7c8" diff --git a/backend/alembic/versions/y4z5a6b7c8d9_add_conversation_rerun_flag.py b/backend/alembic/versions/y4z5a6b7c8d9_add_conversation_rerun_flag.py index fa7195b..78e1051 100644 --- a/backend/alembic/versions/y4z5a6b7c8d9_add_conversation_rerun_flag.py +++ b/backend/alembic/versions/y4z5a6b7c8d9_add_conversation_rerun_flag.py @@ -5,11 +5,12 @@ Create Date: 2024-12-13 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "y4z5a6b7c8d9" diff --git a/backend/alembic/versions/z5a6b7c8d9e0_convert_project_key_to_prefix.py b/backend/alembic/versions/z5a6b7c8d9e0_convert_project_key_to_prefix.py index 17c18e0..20177f7 100644 --- a/backend/alembic/versions/z5a6b7c8d9e0_convert_project_key_to_prefix.py +++ b/backend/alembic/versions/z5a6b7c8d9e0_convert_project_key_to_prefix.py @@ -9,11 +9,10 @@ now store ticket prefixes that are used for feature key generation. """ + from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. revision: str = "z5a6b7c8d9e0" diff --git a/backend/app/agents/brainstorm/__init__.py b/backend/app/agents/brainstorm/__init__.py index ee0e67d..da9d377 100644 --- a/backend/app/agents/brainstorm/__init__.py +++ b/backend/app/agents/brainstorm/__init__.py @@ -22,17 +22,17 @@ - Initial MCQs for each question to start the conversation thread """ -from app.agents.brainstorm.types import ( - BrainstormContext, - GeneratedMCQ, - GeneratedClarificationQuestion, - GeneratedAspect, - BrainstormResult, -) from app.agents.brainstorm.orchestrator import ( BrainstormOrchestrator, create_orchestrator, ) +from app.agents.brainstorm.types import ( + BrainstormContext, + BrainstormResult, + GeneratedAspect, + GeneratedClarificationQuestion, + GeneratedMCQ, +) __all__ = [ "BrainstormContext", diff --git a/backend/app/agents/brainstorm/generator.py b/backend/app/agents/brainstorm/generator.py index 8e569fa..06baaa3 100644 --- a/backend/app/agents/brainstorm/generator.py +++ b/backend/app/agents/brainstorm/generator.py @@ -25,24 +25,23 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +# Import from common module +from app.agents.response_parser import strip_markdown_json + from .types import ( + MAX_ASPECTS, + MAX_MCQ_CHOICES, + MAX_QUESTIONS_PER_ASPECT, + MIN_ASPECTS, + MIN_MCQ_CHOICES, + MIN_QUESTIONS_PER_ASPECT, BrainstormContext, BrainstormResult, GeneratedAspect, GeneratedClarificationQuestion, GeneratedMCQ, - MIN_ASPECTS, - MAX_ASPECTS, - MIN_QUESTIONS_PER_ASPECT, - MAX_QUESTIONS_PER_ASPECT, - MIN_MCQ_CHOICES, - MAX_MCQ_CHOICES, ) -# Import from common module -from app.agents.response_parser import strip_markdown_json - - logger = logging.getLogger(__name__) @@ -147,23 +146,16 @@ async def generate(self, context: BrainstormContext) -> BrainstormResult: Raises: ValueError: If generation fails or returns invalid JSON """ - 
-        logger.info(
-            f"[brainstorm.generator] Starting generation for phase {context.brainstorming_phase_id}"
-        )
+        logger.info(f"[brainstorm.generator] Starting generation for phase {context.brainstorming_phase_id}")

         try:
             # Build the prompt
             prompt = self._build_prompt(context)

-            logger.debug(
-                f"[brainstorm.generator] Calling LLM with prompt length: {len(prompt)}"
-            )
+            logger.debug(f"[brainstorm.generator] Calling LLM with prompt length: {len(prompt)}")

             # Call the agent
-            response = await self.agent.on_messages(
-                [TextMessage(content=prompt, source="user")],
-                CancellationToken()
-            )
+            response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken())

             response_text = response.chat_message.content
             if isinstance(response_text, list):
@@ -176,9 +168,7 @@ async def generate(self, context: BrainstormContext) -> BrainstormResult:
             try:
                 generated_data = json.loads(cleaned_response)
             except json.JSONDecodeError as e:
-                logger.error(
-                    f"[brainstorm.generator] Failed to parse response as JSON: {e}"
-                )
+                logger.error(f"[brainstorm.generator] Failed to parse response as JSON: {e}")
                 logger.error(f"[brainstorm.generator] Raw response: {response_text[:500]}")
                 raise ValueError(f"Failed to parse generator response as JSON: {e}")
@@ -284,8 +274,7 @@ def _build_result(self, generated_data: dict) -> BrainstormResult:

 async def create_generator_agent(
-    model_client: ChatCompletionClient,
-    project_id: Optional[str] = None
+    model_client: ChatCompletionClient, project_id: Optional[str] = None
 ) -> BrainstormGeneratorAgent:
     """
     Factory function to create a Brainstorm Generator Agent.
diff --git a/backend/app/agents/brainstorm/orchestrator.py b/backend/app/agents/brainstorm/orchestrator.py
index 5f5de76..cba203f 100644
--- a/backend/app/agents/brainstorm/orchestrator.py
+++ b/backend/app/agents/brainstorm/orchestrator.py
@@ -13,17 +13,16 @@
 """

 import logging
-from typing import Optional, Callable, Dict, Any
+from typing import Any, Callable, Dict, Optional

 from autogen_core.models import ChatCompletionClient

+from .generator import BrainstormGeneratorAgent
 from .types import (
     BrainstormContext,
     BrainstormResult,
     validate_brainstorm_result,
 )
-from .generator import BrainstormGeneratorAgent
-
 logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@ def __init__(
         self,
         model_client: ChatCompletionClient,
         progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
-        project_id: Optional[str] = None
+        project_id: Optional[str] = None,
     ):
         """
         Initialize the Brainstorm Orchestrator.
@@ -73,9 +72,7 @@ async def generate_brainstorm(self, context: BrainstormContext) -> BrainstormRes Raises: Exception: If generation fails """ - logger.info( - f"[brainstorm.orchestrator] Starting generation for phase {context.brainstorming_phase_id}" - ) + logger.info(f"[brainstorm.orchestrator] Starting generation for phase {context.brainstorming_phase_id}") try: # Step 1: Analyzing (0-20%) @@ -99,9 +96,7 @@ async def generate_brainstorm(self, context: BrainstormContext) -> BrainstormRes issues = validate_brainstorm_result(result) if issues: - logger.warning( - f"[brainstorm.orchestrator] Validation issues: {issues}" - ) + logger.warning(f"[brainstorm.orchestrator] Validation issues: {issues}") # Add issues to generation notes result.generation_notes.extend([f"Validation: {issue}" for issue in issues]) @@ -152,7 +147,7 @@ async def create_orchestrator( api_key: str, config: dict, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> BrainstormOrchestrator: """ Factory function to create a Brainstorm Orchestrator with LLM client. diff --git a/backend/app/agents/brainstorm/types.py b/backend/app/agents/brainstorm/types.py index 3ca1ed5..5b917f6 100644 --- a/backend/app/agents/brainstorm/types.py +++ b/backend/app/agents/brainstorm/types.py @@ -14,10 +14,9 @@ """ from dataclasses import dataclass, field -from typing import List, Optional, Dict, Any +from typing import Dict, List, Optional from uuid import UUID - # Configuration constants MIN_ASPECTS = 2 MAX_ASPECTS = 8 @@ -37,6 +36,7 @@ class BrainstormContext: The agent uses the phase description as its primary input to generate aspects and clarification questions. """ + project_id: UUID brainstorming_phase_id: UUID phase_title: str @@ -54,6 +54,7 @@ class GeneratedMCQ: Each clarification question gets a thread with an initial MCQ that allows users to provide structured input. """ + question_text: str choices: List[Dict[str, str]] # [{"id": "a", "label": "..."}] explanation: Optional[str] = None # Why this question is important @@ -67,6 +68,7 @@ class GeneratedClarificationQuestion: Clarification questions represent specific topics to discuss within an aspect. Each one gets its own conversation thread. """ + title: str description: str initial_mcq: GeneratedMCQ # The MCQ to start the thread @@ -80,6 +82,7 @@ class GeneratedAspect: Aspects represent logical areas to explore during brainstorming. Each aspect contains multiple clarification questions. """ + title: str description: str clarification_questions: List[GeneratedClarificationQuestion] @@ -93,6 +96,7 @@ class BrainstormResult: Contains all generated aspects with their clarification questions. 
""" + aspects: List[GeneratedAspect] total_clarification_questions: int = 0 generation_notes: List[str] = field(default_factory=list) @@ -100,9 +104,7 @@ class BrainstormResult: def __post_init__(self): """Calculate total clarification questions if not provided.""" if self.total_clarification_questions == 0: - self.total_clarification_questions = sum( - len(aspect.clarification_questions) for aspect in self.aspects - ) + self.total_clarification_questions = sum(len(aspect.clarification_questions) for aspect in self.aspects) def validate_brainstorm_result(result: BrainstormResult) -> List[str]: @@ -121,13 +123,9 @@ def validate_brainstorm_result(result: BrainstormResult) -> List[str]: # Check total questions if result.total_clarification_questions < MIN_TOTAL_QUESTIONS: - issues.append( - f"Too few questions: {result.total_clarification_questions} < {MIN_TOTAL_QUESTIONS}" - ) + issues.append(f"Too few questions: {result.total_clarification_questions} < {MIN_TOTAL_QUESTIONS}") if result.total_clarification_questions > MAX_TOTAL_QUESTIONS: - issues.append( - f"Too many questions: {result.total_clarification_questions} > {MAX_TOTAL_QUESTIONS}" - ) + issues.append(f"Too many questions: {result.total_clarification_questions} > {MAX_TOTAL_QUESTIONS}") # Check each aspect for i, aspect in enumerate(result.aspects): @@ -136,13 +134,9 @@ def validate_brainstorm_result(result: BrainstormResult) -> List[str]: q_count = len(aspect.clarification_questions) if q_count < MIN_QUESTIONS_PER_ASPECT: - issues.append( - f"Aspect '{aspect.title}' has too few questions: {q_count} < {MIN_QUESTIONS_PER_ASPECT}" - ) + issues.append(f"Aspect '{aspect.title}' has too few questions: {q_count} < {MIN_QUESTIONS_PER_ASPECT}") if q_count > MAX_QUESTIONS_PER_ASPECT: - issues.append( - f"Aspect '{aspect.title}' has too many questions: {q_count} > {MAX_QUESTIONS_PER_ASPECT}" - ) + issues.append(f"Aspect '{aspect.title}' has too many questions: {q_count} > {MAX_QUESTIONS_PER_ASPECT}") # Check each question for j, question in enumerate(aspect.clarification_questions): @@ -152,13 +146,9 @@ def validate_brainstorm_result(result: BrainstormResult) -> List[str]: # Check MCQ choices choice_count = len(question.initial_mcq.choices) if choice_count < MIN_MCQ_CHOICES: - issues.append( - f"Question '{question.title}' MCQ has too few choices: {choice_count}" - ) + issues.append(f"Question '{question.title}' MCQ has too few choices: {choice_count}") if choice_count > MAX_MCQ_CHOICES: - issues.append( - f"Question '{question.title}' MCQ has too many choices: {choice_count}" - ) + issues.append(f"Question '{question.title}' MCQ has too many choices: {choice_count}") # Check for duplicate aspect titles aspect_titles = [a.title.lower().strip() for a in result.aspects] diff --git a/backend/app/agents/brainstorm_conversation/__init__.py b/backend/app/agents/brainstorm_conversation/__init__.py index cab4f21..74db348 100644 --- a/backend/app/agents/brainstorm_conversation/__init__.py +++ b/backend/app/agents/brainstorm_conversation/__init__.py @@ -34,49 +34,47 @@ result = await orchestrator.generate_brainstorm_conversations(context) """ +from .aspect_generator import AspectGeneratorAgent, create_aspect_generator +from .classifier import ComplexityClassifierAgent, create_complexity_classifier +from .critic_pruner import CriticPrunerAgent, create_critic_pruner +from .input_validator import InputValidatorAgent, ValidationResult, create_input_validator +from .orchestrator import ( + BrainstormConversationOrchestrator, + create_orchestrator, +) +from 
.question_generator import QuestionGeneratorAgent, create_question_generator +from .summarizer import SummarizerAgent, create_summarizer_agent from .types import ( - # Enums - PhaseType, - PhaseComplexity, - AspectCategory, - QuestionPriority, - # Dataclasses - MCQChoice, - GeneratedMCQ, - GeneratedClarificationQuestion, - GeneratedAspect, - BrainstormConversationContext, - SummarizedPhaseContext, - ClassificationResult, - ComplexityCaps, - BrainstormConversationResult, - AgentInfo, - CodeExplorationContext, # Constants AGENT_METADATA, - WORKFLOW_STEPS, COMPLEXITY_CAPS_CONFIG, - STANDARD_MCQ_CHOICES, - MIN_MCQ_CHOICES, MAX_MCQ_CHOICES, + MIN_MCQ_CHOICES, + STANDARD_MCQ_CHOICES, + WORKFLOW_STEPS, + AgentInfo, + AspectCategory, + BrainstormConversationContext, + BrainstormConversationResult, + ClassificationResult, + CodeExplorationContext, + ComplexityCaps, + GeneratedAspect, + GeneratedClarificationQuestion, + GeneratedMCQ, + # Dataclasses + MCQChoice, + PhaseComplexity, + # Enums + PhaseType, + QuestionPriority, + SummarizedPhaseContext, + create_mcq_choices, # Helper functions get_caps_for_complexity, - create_mcq_choices, validate_brainstorm_result, ) -from .orchestrator import ( - BrainstormConversationOrchestrator, - create_orchestrator, -) - -from .summarizer import SummarizerAgent, create_summarizer_agent -from .classifier import ComplexityClassifierAgent, create_complexity_classifier -from .aspect_generator import AspectGeneratorAgent, create_aspect_generator -from .question_generator import QuestionGeneratorAgent, create_question_generator -from .critic_pruner import CriticPrunerAgent, create_critic_pruner -from .input_validator import InputValidatorAgent, ValidationResult, create_input_validator - __all__ = [ # Main orchestrator "BrainstormConversationOrchestrator", diff --git a/backend/app/agents/brainstorm_conversation/aspect_generator.py b/backend/app/agents/brainstorm_conversation/aspect_generator.py index 65f6950..4cf2f1a 100644 --- a/backend/app/agents/brainstorm_conversation/aspect_generator.py +++ b/backend/app/agents/brainstorm_conversation/aspect_generator.py @@ -13,21 +13,21 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - GeneratedAspect, AspectCategory, - SummarizedPhaseContext, ClassificationResult, + CodeExplorationContext, ComplexityCaps, - ExistingConversationContext, - UserInitiatedContext, CrossProjectContext, - TechStackContext, - CodeExplorationContext, + ExistingConversationContext, + GeneratedAspect, SiblingPhasesContext, + SummarizedPhaseContext, + TechStackContext, + UserInitiatedContext, calculate_aspect_coverage_level, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -138,7 +138,7 @@ async def generate_aspects( tech_stack_context: Optional[TechStackContext] = None, code_exploration_context: Optional[CodeExplorationContext] = None, sibling_phases_context: Optional[SiblingPhasesContext] = None, - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> List[GeneratedAspect]: """ Generate aspects (exploration areas) for brainstorming. 
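The import reshuffling in these hunks is what Ruff's isort-compatible "I" rules produce: one group per origin (stdlib, third-party, first-party, relative) separated by a blank line, alphabetized within each group, and with CONSTANTS sorting before Classes and Classes before functions inside a from-list (visible in the re-sorted __init__ above). A small annotated illustration; only the stdlib imports are live here.

import json
import logging
from typing import Optional  # stdlib group, alphabetized

# A third-party group would follow after one blank line, e.g.:
# from autogen_core import CancellationToken

# Then first-party ("app.*") and finally relative imports, e.g.:
# from .types import MAX_ASPECTS, MIN_ASPECTS, BrainstormContext, create_mcq_choices

logging.getLogger(__name__).debug(json.dumps({"sorted": True}))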
@@ -168,7 +168,7 @@ async def generate_aspects( complexity=classification.complexity.value, target_range=f"{caps.min_aspects}-{caps.max_aspects}", existing_count=existing_count, - num_focus_areas=len(classification.suggested_focus_areas) + num_focus_areas=len(classification.suggested_focus_areas), ) try: @@ -183,20 +183,13 @@ async def generate_aspects( cross_project_context, tech_stack_context, code_exploration_context, - sibling_phases_context + sibling_phases_context, ) # Call the agent - self.logger.log_llm_call( - prompt=prompt, - model=str(self.model_client), - operation="generate_aspects" - ) + self.logger.log_llm_call(prompt=prompt, model=str(self.model_client), operation="generate_aspects") - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -209,10 +202,9 @@ async def generate_aspects( try: aspects_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) raise ValueError(f"Failed to parse aspect generator response as JSON: {e}") # Convert to GeneratedAspect objects @@ -228,7 +220,7 @@ async def generate_aspects( self.logger.log_agent_complete( generated_count=len(aspects), categories=list(set(a.category.value for a in aspects)), - aspect_titles=[a.title for a in aspects] + aspect_titles=[a.title for a in aspects], ) return aspects @@ -237,10 +229,7 @@ async def generate_aspects( self.logger.log_error(e, {"project_id": project_id}) raise - def _build_existing_context_section( - self, - existing_context: ExistingConversationContext - ) -> str: + def _build_existing_context_section(self, existing_context: ExistingConversationContext) -> str: """ Build the existing aspects context section for the prompt. @@ -260,20 +249,30 @@ def _build_existing_context_section( # Section 1: Existing Aspects Overview sections.append("### EXISTING ASPECTS - DO NOT DUPLICATE:") - sections.append("The following aspects have already been created. DO NOT create new aspects with the same or similar names.") - sections.append("If you want to add more questions to an existing topic, you MUST skip it - questions will be added separately.") + sections.append( + "The following aspects have already been created. DO NOT create new aspects with the same or similar names." + ) + sections.append( + "If you want to add more questions to an existing topic, you MUST skip it - questions will be added separately." + ) sections.append("") for aspect in existing_context.aspects: coverage = calculate_aspect_coverage_level(aspect.total_questions) - sections.append(f"**{aspect.title}** ({aspect.total_questions} questions, {aspect.answered_questions} answered) - {coverage.upper()} coverage") - sections.append(f" Description: {aspect.description[:150]}..." if len(aspect.description) > 150 else f" Description: {aspect.description}") + sections.append( + f"**{aspect.title}** ({aspect.total_questions} questions, {aspect.answered_questions} answered) - {coverage.upper()} coverage" + ) + sections.append( + f" Description: {aspect.description[:150]}..." 
+ if len(aspect.description) > 150 + else f" Description: {aspect.description}" + ) # Show questions with decision summaries for q in aspect.questions: if q.status == "answered" and q.decision_summary: sections.append(f" - Q: {q.question_title}") - sections.append(f" Decision: \"{q.decision_summary}\"") + sections.append(f' Decision: "{q.decision_summary}"') # Show unresolved points if any (limit to 2 for brevity) if q.unresolved_points: sections.append(f" Open questions: {', '.join(q.unresolved_points[:2])}") @@ -285,17 +284,14 @@ def _build_existing_context_section( # Section 2: Decisions that expand valid scope decisions_with_content = [ - (aspect, q) - for aspect in existing_context.aspects - for q in aspect.questions - if q.decision_summary + (aspect, q) for aspect in existing_context.aspects for q in aspect.questions if q.decision_summary ] if decisions_with_content: sections.append("### DECISIONS THAT EXPAND VALID SCOPE:") sections.append("These decisions enable follow-up aspects on related topics:") for aspect, q in decisions_with_content[:10]: # Cap at 10 - sections.append(f"- \"{q.decision_summary}\" → follow-up on this topic is VALID") + sections.append(f'- "{q.decision_summary}" → follow-up on this topic is VALID') sections.append("") # Section 3: Generation Guidelines @@ -303,16 +299,15 @@ def _build_existing_context_section( sections.append("- DO NOT create aspects with names similar to the ones listed above") sections.append("- Focus on NEW areas of exploration not yet covered") sections.append("- You MAY create aspects related to the decisions listed above") - sections.append("- Think like a human reviewer who already has these aspects - what NEW areas need exploration?") + sections.append( + "- Think like a human reviewer who already has these aspects - what NEW areas need exploration?" + ) sections.append("- If all major areas are covered, generate fewer aspects or focus on niche areas") sections.append("") return "\n".join(sections) - def _build_user_initiated_section( - self, - user_initiated_context: UserInitiatedContext - ) -> str: + def _build_user_initiated_section(self, user_initiated_context: UserInitiatedContext) -> str: """ Build the user-initiated context section for the prompt. @@ -394,11 +389,15 @@ def _build_tech_stack_section(self, tech_stack_context: TechStackContext) -> str sections = [] sections.append("### TECHNOLOGY STACK CONTEXT") sections.append("") - sections.append(f"The user indicated this tech stack during initial discussion: **{tech_stack_context.proposed_stack}**") + sections.append( + f"The user indicated this tech stack during initial discussion: **{tech_stack_context.proposed_stack}**" + ) sections.append("") if tech_stack_context.has_grounding: - sections.append("NOTE: This project has grounding documentation (agents.md), which may already document tech decisions.") + sections.append( + "NOTE: This project has grounding documentation (agents.md), which may already document tech decisions." + ) sections.append("Consider whether tech stack exploration is still needed.") sections.append("") @@ -408,10 +407,7 @@ def _build_tech_stack_section(self, tech_stack_context: TechStackContext) -> str return "\n".join(sections) - def _build_code_exploration_section( - self, - code_exploration_context: CodeExplorationContext - ) -> str: + def _build_code_exploration_section(self, code_exploration_context: CodeExplorationContext) -> str: """ Build the code exploration context section for the prompt. 
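The existing-context builder above labels each aspect's coverage from its question count via calculate_aspect_coverage_level. The real thresholds and labels live in .types; this stand-in just shows the shape of the mapping and how the rendered line uses it.

def calculate_aspect_coverage_level(total_questions: int) -> str:
    # Illustrative thresholds only; the actual cut-offs may differ.
    if total_questions == 0:
        return "none"
    if total_questions <= 2:
        return "light"
    if total_questions <= 5:
        return "moderate"
    return "deep"

coverage = calculate_aspect_coverage_level(4)
print(f"**Auth** (4 questions, 2 answered) - {coverage.upper()} coverage")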
@@ -436,10 +432,7 @@ def _build_code_exploration_section( return "\n".join(sections) - def _build_cross_project_section( - self, - cross_project_context: CrossProjectContext - ) -> str: + def _build_cross_project_section(self, cross_project_context: CrossProjectContext) -> str: """ Build the cross-phase and project-level context section for the prompt. @@ -475,7 +468,9 @@ def _build_cross_project_section( sections.append(f" {phase.phase_description}") for decision in phase.decisions[:10]: # Cap at 10 per phase for display - sections.append(f" - [{decision.aspect_title}] {decision.question_title}: \"{decision.decision_summary_short}\"") + sections.append( + f' - [{decision.aspect_title}] {decision.question_title}: "{decision.decision_summary_short}"' + ) sections.append("") # Section 2: Project-level features @@ -485,7 +480,7 @@ def _build_cross_project_section( sections.append("") for feat in cross_project_context.project_features[:15]: # Cap at 15 - sections.append(f" - [{feat.module_title}] {feat.feature_title}: \"{feat.decision_summary_short}\"") + sections.append(f' - [{feat.module_title}] {feat.feature_title}: "{feat.decision_summary_short}"') sections.append("") # Guidance for the LLM @@ -497,10 +492,7 @@ def _build_cross_project_section( return "\n".join(sections) - def _build_sibling_phases_section( - self, - sibling_phases_context: SiblingPhasesContext - ) -> str: + def _build_sibling_phases_section(self, sibling_phases_context: SiblingPhasesContext) -> str: """ Build the sibling phases context section for the prompt. @@ -518,7 +510,7 @@ def _build_sibling_phases_section( return "" sections = [] - sections.append(f"### SIBLING PHASES (same container: \"{sibling_phases_context.container_title}\"):") + sections.append(f'### SIBLING PHASES (same container: "{sibling_phases_context.container_title}"):') sections.append("These phases are part of the same feature group and share related context:") sections.append("") @@ -530,7 +522,9 @@ def _build_sibling_phases_section( if phase.decisions: for decision in phase.decisions[:10]: - sections.append(f" - [{decision.aspect_title}] {decision.question_title}: \"{decision.decision_summary_short}\"") + sections.append( + f' - [{decision.aspect_title}] {decision.question_title}: "{decision.decision_summary_short}"' + ) if phase.implementation_analysis: truncated = phase.implementation_analysis[:300] @@ -558,7 +552,7 @@ def _build_prompt( cross_project_context: Optional[CrossProjectContext] = None, tech_stack_context: Optional[TechStackContext] = None, code_exploration_context: Optional[CodeExplorationContext] = None, - sibling_phases_context: Optional[SiblingPhasesContext] = None + sibling_phases_context: Optional[SiblingPhasesContext] = None, ) -> str: """ Build the aspect generation prompt with existing context awareness. @@ -685,13 +679,12 @@ def _parse_aspect(self, a_data: dict, order_index: int) -> GeneratedAspect: category=category, order_index=order_index, clarification_questions=[], # Will be filled by question generator - internal_agent_notes=internal_notes + internal_agent_notes=internal_notes, ) async def create_aspect_generator( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> AspectGeneratorAgent: """ Factory function to create an Aspect Generator Agent. 
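All of the _build_*_section helpers in this file follow one pattern: return an empty string so absent context vanishes from the prompt, accumulate lines in a list with a cap on list length, and join once at the end. A standalone sketch of that pattern (names here are illustrative, not the real context types):

from typing import List

def build_context_section(title: str, items: List[str], cap: int = 10) -> str:
    # Empty sections are omitted entirely, as the real builders do.
    if not items:
        return ""
    sections = [f"### {title}:"]
    sections.extend(f"- {item}" for item in items[:cap])  # cap keeps prompts short
    sections.append("")
    return "\n".join(sections)

print(build_context_section("EXISTING ASPECTS - DO NOT DUPLICATE", ["Auth flows", "Billing"]))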
diff --git a/backend/app/agents/brainstorm_conversation/classifier.py b/backend/app/agents/brainstorm_conversation/classifier.py index 9a7c33e..c9e3c80 100644 --- a/backend/app/agents/brainstorm_conversation/classifier.py +++ b/backend/app/agents/brainstorm_conversation/classifier.py @@ -6,15 +6,15 @@ """ import json -from typing import Optional, List +from typing import Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient -from .types import PhaseComplexity, ClassificationResult, SummarizedPhaseContext, PhaseType from .logging_config import get_agent_logger +from .types import ClassificationResult, PhaseComplexity, PhaseType, SummarizedPhaseContext from .utils import strip_markdown_json @@ -111,10 +111,7 @@ def _get_system_message(self) -> str: The suggested_focus_areas should be 3-5 key areas that deserve exploration based on the phase description.""" async def classify( - self, - summarized_context: SummarizedPhaseContext, - phase_type: PhaseType, - project_id: Optional[str] = None + self, summarized_context: SummarizedPhaseContext, phase_type: PhaseType, project_id: Optional[str] = None ) -> ClassificationResult: """ Classify the brainstorming phase's complexity. @@ -134,35 +131,25 @@ async def classify( project_id=project_id, phase_type=phase_type.value, num_objectives=len(summarized_context.key_objectives), - num_constraints=len(summarized_context.constraints) + num_constraints=len(summarized_context.constraints), ) try: # Apply heuristics first - heuristic_complexity = self._apply_heuristics( - phase_type, - summarized_context - ) + heuristic_complexity = self._apply_heuristics(phase_type, summarized_context) # Build the prompt - prompt = self._build_prompt( - summarized_context, - phase_type, - heuristic_complexity - ) + prompt = self._build_prompt(summarized_context, phase_type, heuristic_complexity) # Call the agent self.logger.log_llm_call( prompt=prompt, model=str(self.model_client), operation="classify_complexity", - heuristic_complexity=heuristic_complexity + heuristic_complexity=heuristic_complexity, ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -190,22 +177,17 @@ async def classify( raise ValueError(f"Invalid complexity value: {e}") result = ClassificationResult( - complexity=complexity, - rationale=rationale, - suggested_focus_areas=suggested_focus_areas + complexity=complexity, rationale=rationale, suggested_focus_areas=suggested_focus_areas ) self.logger.log_decision( decision=f"complexity={complexity.value}", rationale=rationale, project_id=project_id, - focus_areas=suggested_focus_areas + focus_areas=suggested_focus_areas, ) - self.logger.log_agent_complete( - complexity=complexity.value, - num_focus_areas=len(suggested_focus_areas) - ) + self.logger.log_agent_complete(complexity=complexity.value, num_focus_areas=len(suggested_focus_areas)) return result @@ -213,11 +195,7 @@ async def classify( self.logger.log_error(e, {"project_id": project_id}) raise - def _apply_heuristics( - self, - phase_type: PhaseType, - summarized_context: SummarizedPhaseContext - ) -> Optional[str]: + def _apply_heuristics(self, phase_type: PhaseType, summarized_context: 
SummarizedPhaseContext) -> Optional[str]: """ Apply simple heuristics to guess complexity. @@ -262,28 +240,30 @@ def _apply_heuristics( # High complexity indicators high_keywords = [ - "architecture", "multiple systems", "integration", "complex", - "greenfield", "from scratch", "new platform", "many users", - "enterprise", "scalability", "migration" + "architecture", + "multiple systems", + "integration", + "complex", + "greenfield", + "from scratch", + "new platform", + "many users", + "enterprise", + "scalability", + "migration", ] if any(keyword in summary_lower for keyword in high_keywords): return PhaseComplexity.HIGH.value # Low complexity indicators - low_keywords = [ - "simple", "single", "small", "quick", "minor", - "tweak", "adjustment", "one feature" - ] + low_keywords = ["simple", "single", "small", "quick", "minor", "tweak", "adjustment", "one feature"] if any(keyword in summary_lower for keyword in low_keywords): return PhaseComplexity.LOW.value return base_complexity.value def _build_prompt( - self, - summarized_context: SummarizedPhaseContext, - phase_type: PhaseType, - heuristic_complexity: Optional[str] + self, summarized_context: SummarizedPhaseContext, phase_type: PhaseType, heuristic_complexity: Optional[str] ) -> str: """ Build the classification prompt. @@ -340,8 +320,7 @@ def _build_prompt( async def create_complexity_classifier( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> ComplexityClassifierAgent: """ Factory function to create a Complexity Classifier Agent. diff --git a/backend/app/agents/brainstorm_conversation/code_explorer_stage.py b/backend/app/agents/brainstorm_conversation/code_explorer_stage.py index 9388a80..9aa837d 100644 --- a/backend/app/agents/brainstorm_conversation/code_explorer_stage.py +++ b/backend/app/agents/brainstorm_conversation/code_explorer_stage.py @@ -15,22 +15,22 @@ import hashlib import logging from datetime import datetime, timezone -from typing import Optional, List, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from sqlalchemy.orm import Session -from app.models.project import Project from app.models.platform_settings import PlatformSettings +from app.models.project import Project from app.services.code_explorer_client import code_explorer_client if TYPE_CHECKING: from app.models.brainstorming_phase import BrainstormingPhase from .types import ( + ClassificationResult, CodeExplorationContext, SummarizedPhaseContext, - ClassificationResult, ) logger = logging.getLogger(__name__) @@ -135,22 +135,20 @@ async def _build_repos_list( github_token = None if repo.github_integration_config_id: try: - github_token = await get_github_token_for_org( - db, project.org_id, repo.github_integration_config_id - ) + github_token = await get_github_token_for_org(db, project.org_id, repo.github_integration_config_id) except Exception as e: # Continue without token (works for public repos) - logger.warning( - f"Failed to get GitHub token for repo {repo.slug}: {e}" - ) - - repos.append({ - "slug": repo.slug, - "repo_url": repo.repo_url, - "branch": repo.default_branch or "main", - "github_token": github_token, - "user_remarks": repo.user_remarks, - }) + logger.warning(f"Failed to get GitHub token for repo {repo.slug}: {e}") + + repos.append( + { + "slug": repo.slug, + "repo_url": repo.repo_url, + "branch": repo.default_branch or "main", + "github_token": github_token, + "user_remarks": repo.user_remarks, + } + ) 
return repos @@ -181,10 +179,10 @@ def _log_code_exploration_usage( response_content: The exploration output (for response log) """ try: + from app.models.job import Job from app.services.llm_call_log_service import LLMCallLogService from app.services.llm_usage_log_service import LLMUsageLogService - from app.models.job import Job - from workers.handlers.code_explorer import calculate_claude_cost, CODE_EXPLORER_MODEL + from workers.handlers.code_explorer import CODE_EXPLORER_MODEL, calculate_claude_cost # Get org_id, project_id, triggered_by_user_id from job job = db.query(Job).filter(Job.id == job_id).first() @@ -243,10 +241,7 @@ def _log_code_exploration_usage( duration_ms=duration_ms, ) - logger.debug( - f"Logged code exploration usage: {prompt_tokens} prompt + " - f"{completion_tokens} completion tokens" - ) + logger.debug(f"Logged code exploration usage: {prompt_tokens} prompt + {completion_tokens} completion tokens") except Exception as e: # Log error but don't disrupt agent execution logger.warning(f"Failed to log code exploration usage: {e}") @@ -288,29 +283,21 @@ async def get_or_run_code_exploration( # 1. Check if Code Explorer is enabled in platform settings settings = db.query(PlatformSettings).first() if not settings or not settings.code_explorer_enabled: - logger.debug( - f"Code Explorer disabled for project {project.id}, skipping exploration" - ) + logger.debug(f"Code Explorer disabled for project {project.id}, skipping exploration") return None # 2. Build repos list repos = await _build_repos_list(db, project) if not repos: - logger.debug( - f"No repositories configured for project {project.id}, skipping exploration" - ) + logger.debug(f"No repositories configured for project {project.id}, skipping exploration") return None repos_hash = _compute_repos_hash(repos) # 3. Check cache validity - if ( - phase.code_exploration_output - and phase.code_exploration_repos_hash == repos_hash - ): + if phase.code_exploration_output and phase.code_exploration_repos_hash == repos_hash: logger.info( - f"Using cached code exploration for phase {phase.id} " - f"(cached at {phase.code_exploration_cached_at})" + f"Using cached code exploration for phase {phase.id} (cached at {phase.code_exploration_cached_at})" ) return CodeExplorationContext( output=phase.code_exploration_output, @@ -323,9 +310,7 @@ async def get_or_run_code_exploration( ) # 4. Run new exploration - logger.info( - f"Running new code exploration for phase {phase.id} with {len(repos)} repos" - ) + logger.info(f"Running new code exploration for phase {phase.id} with {len(repos)} repos") started_at = datetime.now(timezone.utc) result = await run_code_exploration( db=db, @@ -343,10 +328,7 @@ async def get_or_run_code_exploration( phase.code_exploration_cached_at = datetime.now(timezone.utc) phase.code_exploration_repos_hash = repos_hash db.commit() - logger.info( - f"Cached code exploration for phase {phase.id} " - f"({len(result.output)} chars)" - ) + logger.info(f"Cached code exploration for phase {phase.id} ({len(result.output)} chars)") # Log LLM usage for the code exploration call if job_id and result.prompt_tokens and result.completion_tokens: @@ -400,17 +382,13 @@ async def run_code_exploration( # 1. 
Check if Code Explorer is enabled in platform settings settings = db.query(PlatformSettings).first() if not settings or not settings.code_explorer_enabled: - logger.debug( - f"Code Explorer disabled for project {project.id}, skipping exploration" - ) + logger.debug(f"Code Explorer disabled for project {project.id}, skipping exploration") return None # 2. Check if project has repositories repositories = project.repositories if not repositories: - logger.debug( - f"No repositories configured for project {project.id}, skipping exploration" - ) + logger.debug(f"No repositories configured for project {project.id}, skipping exploration") return None # 3. Get Anthropic API key from platform settings @@ -418,9 +396,7 @@ async def run_code_exploration( anthropic_key = get_code_explorer_api_key(db) if not anthropic_key: - logger.warning( - f"Code Explorer API key not configured, skipping exploration for project {project.id}" - ) + logger.warning(f"Code Explorer API key not configured, skipping exploration for project {project.id}") return None # 4. Build repos list with GitHub tokens @@ -431,22 +407,20 @@ async def run_code_exploration( github_token = None if repo.github_integration_config_id: try: - github_token = await get_github_token_for_org( - db, project.org_id, repo.github_integration_config_id - ) + github_token = await get_github_token_for_org(db, project.org_id, repo.github_integration_config_id) except Exception as e: # Continue without token (works for public repos) - logger.warning( - f"Failed to get GitHub token for repo {repo.slug}: {e}" - ) - - repos.append({ - "slug": repo.slug, - "repo_url": repo.repo_url, - "branch": repo.default_branch or "main", - "github_token": github_token, - "user_remarks": repo.user_remarks, - }) + logger.warning(f"Failed to get GitHub token for repo {repo.slug}: {e}") + + repos.append( + { + "slug": repo.slug, + "repo_url": repo.repo_url, + "branch": repo.default_branch or "main", + "github_token": github_token, + "user_remarks": repo.user_remarks, + } + ) # 5. Build exploration prompt exploration_prompt = _build_exploration_prompt( @@ -456,9 +430,7 @@ async def run_code_exploration( classification=classification, ) - logger.info( - f"Running code exploration for project {project.id} with {len(repos)} repos" - ) + logger.info(f"Running code exploration for project {project.id} with {len(repos)} repos") # 6. 
Call code explorer service try: @@ -478,16 +450,13 @@ async def run_code_exploration( # Check for success and actual output if not result["success"]: logger.warning( - f"Code exploration failed for project {project.id}: " - f"{result.get('error') or result.get('error_code')}" + f"Code exploration failed for project {project.id}: {result.get('error') or result.get('error_code')}" ) return None output = result.get("output") if not output or not output.strip(): - logger.warning( - f"Code exploration returned empty output for project {project.id}" - ) + logger.warning(f"Code exploration returned empty output for project {project.id}") return None logger.info( @@ -508,7 +477,5 @@ async def run_code_exploration( except Exception as e: # Log but don't fail - exploration is optional context - logger.warning( - f"Code exploration error for project {project.id}: {e}" - ) + logger.warning(f"Code exploration error for project {project.id}: {e}") return None diff --git a/backend/app/agents/brainstorm_conversation/critic_pruner.py b/backend/app/agents/brainstorm_conversation/critic_pruner.py index 2ca2660..89c0172 100644 --- a/backend/app/agents/brainstorm_conversation/critic_pruner.py +++ b/backend/app/agents/brainstorm_conversation/critic_pruner.py @@ -20,18 +20,17 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( + AspectCategory, + ComplexityCaps, + ExistingConversationContext, GeneratedAspect, GeneratedClarificationQuestion, GeneratedMCQ, MCQChoice, QuestionPriority, - AspectCategory, - ComplexityCaps, - PhaseComplexity, - ExistingConversationContext, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -159,7 +158,7 @@ async def prune_and_refine( project_id: Optional[str] = None, phase_summary: Optional[str] = None, phase_description: Optional[str] = None, - existing_context: Optional[ExistingConversationContext] = None + existing_context: Optional[ExistingConversationContext] = None, ) -> List[GeneratedAspect]: """ Refine the generated aspects and questions for quality. 
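The code-explorer stage above caches exploration output on the phase row and reuses it while code_exploration_repos_hash matches the current repo configuration. A minimal sketch of such an invalidation hash, assuming it covers the fields that affect results and skips credentials (the real _compute_repos_hash may select different fields):

import hashlib
import json
from typing import Dict, List, Optional

def compute_repos_hash(repos: List[Dict[str, Optional[str]]]) -> str:
    # Tokens are excluded on purpose so rotating credentials does not
    # bust the cache; sorting makes the hash order-independent.
    key = sorted((r["slug"], r["repo_url"], r.get("branch") or "main") for r in repos)
    return hashlib.sha256(json.dumps(key).encode()).hexdigest()

print(compute_repos_hash([{"slug": "api", "repo_url": "https://github.com/org/api"}]))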
@@ -188,7 +187,7 @@ async def prune_and_refine( aspect_count=len(aspects), question_count=total_questions, max_aspects=caps.max_aspects, - max_questions=caps.total_max_questions + max_questions=caps.total_max_questions, ) try: @@ -199,15 +198,10 @@ async def prune_and_refine( # Call the agent self.logger.log_llm_call( - prompt=prompt, - model=str(self.model_client), - operation="prune_aspects_and_questions" + prompt=prompt, model=str(self.model_client), operation="prune_aspects_and_questions" ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -220,10 +214,9 @@ async def prune_and_refine( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) raise ValueError(f"Failed to parse pruner response as JSON: {e}") # Extract refined aspects @@ -253,16 +246,13 @@ async def prune_and_refine( initial_questions=total_questions, final_questions=final_questions, pruned_items=[], # Could track specific items if needed - pruning_summary=pruning_summary + pruning_summary=pruning_summary, ) self.logger.log_agent_complete( kept_aspects=len(refined_aspects), kept_questions=final_questions, - within_caps=( - len(refined_aspects) <= caps.max_aspects and - final_questions <= caps.total_max_questions - ) + within_caps=(len(refined_aspects) <= caps.max_aspects and final_questions <= caps.total_max_questions), ) return refined_aspects @@ -278,7 +268,7 @@ def _build_prompt( phase_summary: Optional[str] = None, phase_description: Optional[str] = None, existing_context: Optional[ExistingConversationContext] = None, - is_user_initiated: bool = False + is_user_initiated: bool = False, ) -> str: """ Build the quality review prompt. @@ -326,7 +316,7 @@ def _build_prompt( prompt += "\n" # Add caps information - prompt += f"### Limits:\n" + prompt += "### Limits:\n" prompt += f"- Complexity level: {caps.complexity.value}\n" prompt += f"- Max aspects: {caps.max_aspects}\n" prompt += f"- Max questions per aspect: {caps.max_questions_per_aspect}\n" @@ -409,7 +399,7 @@ def _parse_aspect(self, a_data: dict, order_index: int) -> GeneratedAspect: category=category, order_index=order_index, clarification_questions=questions, - internal_agent_notes=a_data.get("pruning_note") + internal_agent_notes=a_data.get("pruning_note"), ) def _parse_question(self, q_data: dict) -> GeneratedClarificationQuestion: @@ -441,7 +431,7 @@ def _parse_question(self, q_data: dict) -> GeneratedClarificationQuestion: description=description, priority=priority, initial_mcq=mcq, - internal_agent_notes=q_data.get("pruning_note") + internal_agent_notes=q_data.get("pruning_note"), ) def _parse_mcq(self, mcq_data: dict) -> GeneratedMCQ: @@ -497,8 +487,7 @@ def _parse_mcq(self, mcq_data: dict) -> GeneratedMCQ: ) def _sort_questions_by_priority( - self, - questions: List[GeneratedClarificationQuestion] + self, questions: List[GeneratedClarificationQuestion] ) -> List[GeneratedClarificationQuestion]: """ Sort questions by priority (must_have first, then important, then optional). 
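The pruner's hard-cap enforcement below keeps the highest-priority questions when truncating, using exactly the sort key shown in the diff (priority_order.get(..., 1), so unknown priorities default to "important"). A standalone illustration with plain strings standing in for the QuestionPriority enum:

priority_order = {"must_have": 0, "important": 1, "optional": 2}

questions = [
    {"title": "Rate limiting?", "priority": "optional"},
    {"title": "Auth model?", "priority": "must_have"},
    {"title": "Error format?", "priority": "important"},
]

# Sort first, then slice: truncating to a cap drops the lowest priorities.
ranked = sorted(questions, key=lambda q: priority_order.get(q["priority"], 1))
print([q["title"] for q in ranked[:2]])  # -> ['Auth model?', 'Error format?']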
@@ -517,10 +506,7 @@ def _sort_questions_by_priority( return sorted(questions, key=lambda q: priority_order.get(q.priority, 1)) def _enforce_hard_caps( - self, - aspects: List[GeneratedAspect], - caps: ComplexityCaps, - is_user_initiated: bool = False + self, aspects: List[GeneratedAspect], caps: ComplexityCaps, is_user_initiated: bool = False ) -> List[GeneratedAspect]: """ Enforce hard caps if LLM didn't follow instructions. @@ -567,7 +553,7 @@ def _enforce_hard_caps( self.logger.logger.warning( f"LLM returned {len(aspects)} aspects but max is {caps.max_aspects}. Truncating." ) - aspects = aspects[:caps.max_aspects] + aspects = aspects[: caps.max_aspects] # Limit questions per aspect and total questions (priority-aware) total_questions = 0 @@ -580,7 +566,7 @@ def _enforce_hard_caps( self.logger.logger.info( f"Aspect '{aspect.title}': removing {removed_count} lowest-priority questions to meet per-aspect cap" ) - sorted_questions = sorted_questions[:caps.max_questions_per_aspect] + sorted_questions = sorted_questions[: caps.max_questions_per_aspect] # Check total cap remaining_budget = caps.total_max_questions - total_questions @@ -598,8 +584,7 @@ def _enforce_hard_caps( async def create_critic_pruner( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> CriticPrunerAgent: """ Factory function to create a Critic/Pruner Agent. diff --git a/backend/app/agents/brainstorm_conversation/logging_config.py b/backend/app/agents/brainstorm_conversation/logging_config.py index dd081c0..8d91cea 100644 --- a/backend/app/agents/brainstorm_conversation/logging_config.py +++ b/backend/app/agents/brainstorm_conversation/logging_config.py @@ -4,10 +4,10 @@ Provides structured logging for all agent decisions, LLM calls, and workflow steps. """ -import logging import json -from typing import Any, Dict, Optional +import logging from datetime import datetime, timezone +from typing import Any, Dict, Optional class BrainstormAgentLogger: @@ -33,19 +33,12 @@ def __init__(self, agent_name: str, project_id: Optional[str] = None): # Ensure structured output if not self.logger.handlers: handler = logging.StreamHandler() - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) - def _structured_log( - self, - level: str, - event: str, - extra_data: Optional[Dict[str, Any]] = None - ) -> None: + def _structured_log(self, level: str, event: str, extra_data: Optional[Dict[str, Any]] = None) -> None: """ Log a structured event. @@ -78,12 +71,7 @@ def log_agent_complete(self, **kwargs: Any) -> None: self._structured_log("info", f"{self.agent_name}_complete", kwargs) def log_llm_call( - self, - prompt: str, - model: str, - response: Optional[str] = None, - tokens_used: Optional[int] = None, - **kwargs: Any + self, prompt: str, model: str, response: Optional[str] = None, tokens_used: Optional[int] = None, **kwargs: Any ) -> None: """ Log an LLM API call. @@ -99,7 +87,7 @@ def log_llm_call( "model": model, "prompt_preview": prompt[:200] + "..." 
if len(prompt) > 200 else prompt, "prompt_length": len(prompt), - **kwargs + **kwargs, } if response: @@ -120,11 +108,7 @@ def log_decision(self, decision: str, rationale: str, **kwargs: Any) -> None: rationale: Explanation of why **kwargs: Additional context """ - data = { - "decision": decision, - "rationale": rationale, - **kwargs - } + data = {"decision": decision, "rationale": rationale, **kwargs} self._structured_log("info", "agent_decision", data) def log_pruning_stats( @@ -134,7 +118,7 @@ def log_pruning_stats( initial_questions: int, final_questions: int, pruned_items: list, - **kwargs: Any + **kwargs: Any, ) -> None: """ Log aspect and question pruning statistics. @@ -155,7 +139,7 @@ def log_pruning_stats( "pruned_aspects": initial_aspects - final_aspects, "pruned_questions": initial_questions - final_questions, "pruned_items": pruned_items, - **kwargs + **kwargs, } self._structured_log("info", "pruning_stats", data) @@ -186,11 +170,7 @@ def log_workflow_transition(self, from_state: str, to_state: str, **kwargs: Any) to_state: New state **kwargs: Additional context """ - data = { - "from_state": from_state, - "to_state": to_state, - **kwargs - } + data = {"from_state": from_state, "to_state": to_state, **kwargs} self._structured_log("info", "workflow_transition", data) diff --git a/backend/app/agents/brainstorm_conversation/orchestrator.py b/backend/app/agents/brainstorm_conversation/orchestrator.py index 7ca0bad..6da657a 100644 --- a/backend/app/agents/brainstorm_conversation/orchestrator.py +++ b/backend/app/agents/brainstorm_conversation/orchestrator.py @@ -11,36 +11,37 @@ """ import asyncio -from typing import Optional, Dict, Any, List, Callable, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from uuid import UUID from autogen_core.models import ChatCompletionClient if TYPE_CHECKING: from sqlalchemy.orm import Session - from app.models.project import Project + from app.models.brainstorming_phase import BrainstormingPhase + from app.models.project import Project -from app.agents.llm_client import create_litellm_client, LLMCallLogger +from app.agents.llm_client import LLMCallLogger, create_litellm_client +from .aspect_generator import AspectGeneratorAgent +from .classifier import ComplexityClassifierAgent +from .critic_pruner import CriticPrunerAgent +from .logging_config import get_agent_logger +from .question_generator import QuestionGeneratorAgent +from .summarizer import SummarizerAgent from .types import ( + AGENT_METADATA, + USER_INITIATED_CAPS, + WORKFLOW_STEP_DISPLAY_NAMES, + WORKFLOW_STEPS, BrainstormConversationContext, BrainstormConversationResult, + ClassificationResult, GeneratedAspect, PhaseComplexity, get_caps_for_complexity, - AGENT_METADATA, - WORKFLOW_STEPS, - WORKFLOW_STEP_DISPLAY_NAMES, - USER_INITIATED_CAPS, - ClassificationResult, ) -from .summarizer import SummarizerAgent -from .classifier import ComplexityClassifierAgent -from .aspect_generator import AspectGeneratorAgent -from .question_generator import QuestionGeneratorAgent -from .critic_pruner import CriticPrunerAgent -from .logging_config import get_agent_logger class BrainstormConversationOrchestrator: @@ -105,6 +106,7 @@ def __init__( self.call_logger = None if job_id and not mock_mode_enabled: from app.database import SessionLocal + self.call_logger = LLMCallLogger( db_session_factory=SessionLocal, job_id=job_id, @@ -134,16 +136,10 @@ def __init__( self.critic_pruner = None self.logger.log_agent_start( - model=str(model_client) if model_client else "mock", - 
provider=provider if not mock_mode_enabled else "mock" + model=str(model_client) if model_client else "mock", provider=provider if not mock_mode_enabled else "mock" ) - def _create_model_client( - self, - provider: str, - api_key: str, - config: Dict[str, Any] - ) -> ChatCompletionClient: + def _create_model_client(self, provider: str, api_key: str, config: Dict[str, Any]) -> ChatCompletionClient: """ Create a model client for the specified provider using LiteLLM. @@ -251,10 +247,7 @@ async def generate_brainstorm_conversations( # **NORMAL MODE**: Proceed with standard LLM workflow self.logger.log_workflow_transition( - from_state="start", - to_state="summarizing", - project_id=project_id_str, - phase_id=phase_id_str + from_state="start", to_state="summarizing", project_id=project_id_str, phase_id=phase_id_str ) try: @@ -273,7 +266,7 @@ async def generate_brainstorm_conversations( from_state="summarizing", to_state="generating_aspects", project_id=project_id_str, - user_initiated=True + user_initiated=True, ) complexity = PhaseComplexity.LOW @@ -283,21 +276,19 @@ async def generate_brainstorm_conversations( classification = ClassificationResult( complexity=complexity, rationale="User-initiated mode - conservative generation", - suggested_focus_areas=[context.user_initiated_context.user_prompt[:100]] + suggested_focus_areas=[context.user_initiated_context.user_prompt[:100]], ) self.logger.log_decision( decision="User-initiated mode: using conservative caps", rationale=f"User request: {context.user_initiated_context.user_prompt[:100]}", project_id=project_id_str, - user_initiated=True + user_initiated=True, ) else: # Normal mode: Run classifier self.logger.log_workflow_transition( - from_state="summarizing", - to_state="classifying", - project_id=project_id_str + from_state="summarizing", to_state="classifying", project_id=project_id_str ) # Step 2: Classify complexity @@ -305,9 +296,7 @@ async def generate_brainstorm_conversations( if self.call_logger: self.call_logger.set_agent("classifier", "Classifier") classification = await self.classifier.classify( - summarized_context, - phase_type=context.phase_type, - project_id=project_id_str + summarized_context, phase_type=context.phase_type, project_id=project_id_str ) complexity = classification.complexity @@ -317,14 +306,14 @@ async def generate_brainstorm_conversations( decision=f"Classified as {complexity.value} complexity", rationale=classification.rationale, project_id=project_id_str, - suggested_focus_areas=classification.suggested_focus_areas + suggested_focus_areas=classification.suggested_focus_areas, ) self.logger.log_workflow_transition( from_state="classifying", to_state="exploring_code", project_id=project_id_str, - complexity=complexity.value + complexity=complexity.value, ) # Step 2.5: Code Exploration (optional - provides codebase context) @@ -354,15 +343,11 @@ async def generate_brainstorm_conversations( project_id=project_id_str, ) else: - self.logger.logger.debug( - f"Code exploration returned no results for project {project_id_str}" - ) + self.logger.logger.debug(f"Code exploration returned no results for project {project_id_str}") except Exception as e: # Log but continue without code exploration context self.logger.log_error(e, {"context": "code_exploration_stage"}) - self.logger.logger.warning( - f"Code exploration failed, continuing without: {e}" - ) + self.logger.logger.warning(f"Code exploration failed, continuing without: {e}") self.logger.log_workflow_transition( from_state="exploring_code", @@ -390,14 +375,14 @@ 
async def generate_brainstorm_conversations( complexity=complexity.value, total_aspects=0, total_questions=0, - note="Phase fully explored - cumulative cap reached" + note="Phase fully explored - cumulative cap reached", ) return BrainstormConversationResult( aspects=[], complexity=complexity, total_aspects=0, total_questions=0, - generation_notes=["Phase is fully explored. No additional questions needed."] + generation_notes=["Phase is fully explored. No additional questions needed."], ) # Low engagement early stop: don't generate more if user isn't answering existing questions @@ -416,7 +401,7 @@ async def generate_brainstorm_conversations( complexity=complexity.value, total_aspects=0, total_questions=0, - note=f"Low engagement - only {answered_ratio:.1%} answered" + note=f"Low engagement - only {answered_ratio:.1%} answered", ) return BrainstormConversationResult( aspects=[], @@ -426,13 +411,14 @@ async def generate_brainstorm_conversations( generation_notes=[ f"Only {answered_count}/{existing_questions} questions answered. " "Please answer existing questions before generating more." - ] + ], ) # Create adjusted caps for this generation based on remaining budget # For user-initiated, keep USER_INITIATED_CAPS as-is (already conservative) if not is_user_initiated: from .types import ComplexityCaps + caps = ComplexityCaps( complexity=caps.complexity, min_aspects=min(caps.min_aspects, remaining_aspect_budget), @@ -457,14 +443,14 @@ async def generate_brainstorm_conversations( tech_stack_context=context.tech_stack_context, # Pass tech stack for 2nd-order questions code_exploration_context=code_exploration_context, # Pass codebase analysis sibling_phases_context=context.sibling_phases_context, # Pass sibling phase context - project_id=project_id_str + project_id=project_id_str, ) self.logger.log_workflow_transition( from_state="generating_aspects", to_state="generating_questions", project_id=project_id_str, - aspect_count=len(aspects) + aspect_count=len(aspects), ) # Step 4: Generate questions for each aspect @@ -482,7 +468,7 @@ async def generate_brainstorm_conversations( tech_stack_context=context.tech_stack_context, # Pass tech stack for 2nd-order questions code_exploration_context=code_exploration_context, # Pass codebase analysis sibling_phases_context=context.sibling_phases_context, # Pass sibling phase context - project_id=project_id_str + project_id=project_id_str, ) total_questions = sum(len(a.clarification_questions) for a in aspects_with_questions) @@ -491,7 +477,7 @@ async def generate_brainstorm_conversations( from_state="generating_questions", to_state="pruning", project_id=project_id_str, - total_questions=total_questions + total_questions=total_questions, ) # Step 5: Refine for quality @@ -505,7 +491,7 @@ async def generate_brainstorm_conversations( project_id=project_id_str, phase_summary=summarized_context.phase_summary, phase_description=context.phase_description, # Pass original for full context - existing_context=context.existing_context + existing_context=context.existing_context, ) final_question_count = sum(len(a.clarification_questions) for a in refined_aspects) @@ -515,7 +501,7 @@ async def generate_brainstorm_conversations( to_state="complete", project_id=project_id_str, final_aspects=len(refined_aspects), - final_questions=final_question_count + final_questions=final_question_count, ) # Step 6: Complete @@ -529,15 +515,15 @@ async def generate_brainstorm_conversations( total_questions=final_question_count, generation_notes=[ f"Complexity: {complexity.value}", - 
f"Suggested focus areas: {', '.join(classification.suggested_focus_areas)}" - ] + f"Suggested focus areas: {', '.join(classification.suggested_focus_areas)}", + ], ) self.logger.log_agent_complete( project_id=project_id_str, complexity=complexity.value, total_aspects=len(refined_aspects), - total_questions=final_question_count + total_questions=final_question_count, ) return result @@ -545,17 +531,11 @@ async def generate_brainstorm_conversations( except Exception as e: self.logger.log_error(e, {"project_id": project_id_str}) self.logger.log_workflow_transition( - from_state="error", - to_state="failed", - project_id=project_id_str, - error=str(e) + from_state="error", to_state="failed", project_id=project_id_str, error=str(e) ) raise - async def _run_mock_workflow( - self, - context: BrainstormConversationContext - ) -> BrainstormConversationResult: + async def _run_mock_workflow(self, context: BrainstormConversationContext) -> BrainstormConversationResult: """ Run mock workflow for testing without LLM calls. @@ -565,9 +545,7 @@ async def _run_mock_workflow( Returns: Mock BrainstormConversationResult """ - self.logger.logger.info( - f"Mock mode enabled: simulating workflow with {self.mock_delay_seconds}s delays" - ) + self.logger.logger.info(f"Mock mode enabled: simulating workflow with {self.mock_delay_seconds}s delays") # Step through workflow with delays to simulate real processing workflow_steps_with_agents = [ @@ -586,11 +564,10 @@ async def _run_mock_workflow( # Return mock result from .types import ( - GeneratedAspect, + AspectCategory, GeneratedClarificationQuestion, GeneratedMCQ, MCQChoice, - AspectCategory, QuestionPriority, ) @@ -613,10 +590,10 @@ async def _run_mock_workflow( MCQChoice(id="c", label="Both internal and external"), MCQChoice(id="d", label="Something else"), ], - explanation="Understanding the target users is critical for design" - ) + explanation="Understanding the target users is critical for design", + ), ) - ] + ], ), GeneratedAspect( title="Technical Architecture", @@ -635,11 +612,11 @@ async def _run_mock_workflow( MCQChoice(id="b", label="Evaluate new technologies"), MCQChoice(id="c", label="Something else"), ], - explanation="Technology choices impact development timeline" - ) + explanation="Technology choices impact development timeline", + ), ) - ] - ) + ], + ), ] return BrainstormConversationResult( @@ -647,7 +624,7 @@ async def _run_mock_workflow( complexity=PhaseComplexity.MEDIUM, total_aspects=len(mock_aspects), total_questions=sum(len(a.clarification_questions) for a in mock_aspects), - generation_notes=["Mock mode - using test data"] + generation_notes=["Mock mode - using test data"], ) async def close(self): diff --git a/backend/app/agents/brainstorm_conversation/question_generator.py b/backend/app/agents/brainstorm_conversation/question_generator.py index 7075d34..bb2667c 100644 --- a/backend/app/agents/brainstorm_conversation/question_generator.py +++ b/backend/app/agents/brainstorm_conversation/question_generator.py @@ -14,23 +14,23 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( + STANDARD_MCQ_CHOICES, + CodeExplorationContext, + ComplexityCaps, + CrossProjectContext, + ExistingConversationContext, GeneratedAspect, GeneratedClarificationQuestion, GeneratedMCQ, MCQChoice, QuestionPriority, + SiblingPhasesContext, SummarizedPhaseContext, - ComplexityCaps, - STANDARD_MCQ_CHOICES, - ExistingConversationContext, - 
UserInitiatedContext, - CrossProjectContext, TechStackContext, - CodeExplorationContext, - SiblingPhasesContext, + UserInitiatedContext, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -158,7 +158,7 @@ async def generate_questions_for_aspect( tech_stack_context: Optional[TechStackContext] = None, code_exploration_context: Optional[CodeExplorationContext] = None, sibling_phases_context: Optional[SiblingPhasesContext] = None, - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> List[GeneratedClarificationQuestion]: """ Generate clarification questions for a specific aspect. @@ -185,15 +185,22 @@ async def generate_questions_for_aspect( project_id=project_id, aspect_title=aspect.title, complexity=caps.complexity.value, - target_range=f"{caps.min_questions_per_aspect}-{caps.max_questions_per_aspect}" + target_range=f"{caps.min_questions_per_aspect}-{caps.max_questions_per_aspect}", ) try: # Build the prompt prompt = self._build_prompt( - aspect, summarized_context, caps, existing_context, user_initiated_context, - grounding_context, cross_project_context, tech_stack_context, - code_exploration_context, sibling_phases_context + aspect, + summarized_context, + caps, + existing_context, + user_initiated_context, + grounding_context, + cross_project_context, + tech_stack_context, + code_exploration_context, + sibling_phases_context, ) # Call the agent @@ -201,7 +208,7 @@ async def generate_questions_for_aspect( prompt=prompt, model=str(self.model_client), operation="generate_questions_for_aspect", - aspect=aspect.title + aspect=aspect.title, ) # Create fresh agent to avoid conversation history accumulation @@ -212,10 +219,7 @@ async def generate_questions_for_aspect( system_message=self._get_system_message(), model_client=self.model_client, ) - response = await fresh_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await fresh_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -228,10 +232,9 @@ async def generate_questions_for_aspect( try: questions_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) raise ValueError(f"Failed to parse question generator response as JSON: {e}") # Convert to GeneratedClarificationQuestion objects @@ -245,9 +248,7 @@ async def generate_questions_for_aspect( # Skip invalid questions but continue processing self.logger.log_agent_complete( - aspect=aspect.title, - generated_count=len(questions), - priorities=[q.priority.value for q in questions] + aspect=aspect.title, generated_count=len(questions), priorities=[q.priority.value for q in questions] ) return questions @@ -268,7 +269,7 @@ async def generate_questions_for_all_aspects( tech_stack_context: Optional[TechStackContext] = None, code_exploration_context: Optional[CodeExplorationContext] = None, sibling_phases_context: Optional[SiblingPhasesContext] = None, - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> List[GeneratedAspect]: """ Generate clarification questions for all aspects. 
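Two design choices in this file are worth calling out: a fresh AssistantAgent is built per aspect so earlier prompts do not accumulate as conversation history, and question generation fans out with one task per aspect (the hunk below shows the task list; awaiting them with gather is assumed here). A minimal sketch of the shape:

import asyncio

async def generate_for_aspect(aspect: str) -> str:
    # Stand-in for creating a fresh agent and awaiting on_messages();
    # a new agent per call keeps prior prompts out of its history.
    await asyncio.sleep(0)
    return f"questions for {aspect}"

async def generate_all(aspects: list[str]) -> list[str]:
    tasks = [generate_for_aspect(a) for a in aspects]
    return await asyncio.gather(*tasks)  # gather assumed; only the task list appears in the diff

print(asyncio.run(generate_all(["User Experience", "Technical Architecture"])))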
@@ -304,7 +305,7 @@ async def generate_questions_for_all_aspects( tech_stack_context=tech_stack_context, code_exploration_context=code_exploration_context, sibling_phases_context=sibling_phases_context, - project_id=project_id + project_id=project_id, ) for aspect in aspects ] @@ -323,15 +324,13 @@ async def generate_questions_for_all_aspects( self.logger.log_decision( decision=f"Generated {total_questions} questions for {len(aspects)} aspects", rationale=f"Complexity: {caps.complexity.value}", - project_id=project_id + project_id=project_id, ) return aspects def _build_existing_questions_section( - self, - existing_context: ExistingConversationContext, - aspect_title: str + self, existing_context: ExistingConversationContext, aspect_title: str ) -> str: """ Build the existing questions context section for the prompt. @@ -365,13 +364,15 @@ def _build_existing_questions_section( if related_questions: sections.append("### EXISTING QUESTIONS - DO NOT DUPLICATE:") - sections.append("The following questions have already been asked. DO NOT create questions with the same or similar meaning.") + sections.append( + "The following questions have already been asked. DO NOT create questions with the same or similar meaning." + ) sections.append("") for q in related_questions: if q.status == "answered" and q.decision_summary: sections.append(f"- Q: {q.question_title}") - sections.append(f" Decision: \"{q.decision_summary}\"") + sections.append(f' Decision: "{q.decision_summary}"') # Show unresolved points if any (limit to 2 for brevity) if q.unresolved_points: sections.append(f" Open: {', '.join(q.unresolved_points[:2])}") @@ -389,10 +390,7 @@ def _build_existing_questions_section( return "\n".join(sections) - def _build_user_initiated_section( - self, - user_initiated_context: UserInitiatedContext - ) -> str: + def _build_user_initiated_section(self, user_initiated_context: UserInitiatedContext) -> str: """ Build the user-initiated context section for the prompt. @@ -468,11 +466,15 @@ def _build_tech_stack_section(self, tech_stack_context: TechStackContext) -> str sections = [] sections.append("### TECHNOLOGY STACK CONTEXT") sections.append("") - sections.append(f"The user indicated this tech stack during initial discussion: **{tech_stack_context.proposed_stack}**") + sections.append( + f"The user indicated this tech stack during initial discussion: **{tech_stack_context.proposed_stack}**" + ) sections.append("") if tech_stack_context.has_grounding: - sections.append("NOTE: This project has grounding documentation (agents.md), which may already document tech decisions.") + sections.append( + "NOTE: This project has grounding documentation (agents.md), which may already document tech decisions." + ) sections.append("Be conservative - only ask stack questions if they seem genuinely unresolved.") sections.append("") @@ -493,11 +495,7 @@ def _build_tech_stack_section(self, tech_stack_context: TechStackContext) -> str return "\n".join(sections) - def _build_cross_project_section( - self, - cross_project_context: CrossProjectContext, - aspect_title: str - ) -> str: + def _build_cross_project_section(self, cross_project_context: CrossProjectContext, aspect_title: str) -> str: """ Build the cross-phase and project-level context section for the prompt. 
@@ -531,7 +529,8 @@ def _build_cross_project_section( for phase in cross_project_context.other_phases: # Filter to show decisions most relevant to this aspect relevant_decisions = [ - d for d in phase.decisions + d + for d in phase.decisions if any(word in d.aspect_title.lower() for word in aspect_title.lower().split()) or any(word in d.question_title.lower() for word in aspect_title.lower().split()) ][:5] # Cap at 5 relevant decisions @@ -543,7 +542,7 @@ def _build_cross_project_section( if relevant_decisions: sections.append(f"**{phase.phase_title}:**") for decision in relevant_decisions: - sections.append(f" - {decision.question_title}: \"{decision.decision_summary_short}\"") + sections.append(f' - {decision.question_title}: "{decision.decision_summary_short}"') sections.append("") # Section 2: Project-level features @@ -552,7 +551,7 @@ def _build_cross_project_section( sections.append("") for feat in cross_project_context.project_features[:10]: # Cap at 10 - sections.append(f" - [{feat.module_title}] {feat.feature_title}: \"{feat.decision_summary_short}\"") + sections.append(f' - [{feat.module_title}] {feat.feature_title}: "{feat.decision_summary_short}"') sections.append("") # Guidance @@ -563,11 +562,7 @@ def _build_cross_project_section( return "\n".join(sections) - def _build_sibling_phases_section( - self, - sibling_phases_context: SiblingPhasesContext, - aspect_title: str - ) -> str: + def _build_sibling_phases_section(self, sibling_phases_context: SiblingPhasesContext, aspect_title: str) -> str: """ Build the sibling phases context section for the question generation prompt. @@ -586,7 +581,7 @@ def _build_sibling_phases_section( return "" sections = [] - sections.append(f"### SIBLING PHASE DECISIONS (container: \"{sibling_phases_context.container_title}\"):") + sections.append(f'### SIBLING PHASE DECISIONS (container: "{sibling_phases_context.container_title}"):') sections.append("Decisions from related phases in the same container. Avoid duplicating these questions:") sections.append("") @@ -595,7 +590,8 @@ def _build_sibling_phases_section( for phase in sibling_phases_context.sibling_phases: # Filter decisions by keyword overlap with current aspect relevant_decisions = [ - d for d in phase.decisions + d + for d in phase.decisions if any(word in d.aspect_title.lower() for word in aspect_words) or any(word in d.question_title.lower() for word in aspect_words) ][:5] @@ -607,7 +603,7 @@ def _build_sibling_phases_section( if relevant_decisions: sections.append(f"**{phase.phase_title}:**") for decision in relevant_decisions: - sections.append(f" - {decision.question_title}: \"{decision.decision_summary_short}\"") + sections.append(f' - {decision.question_title}: "{decision.decision_summary_short}"') sections.append("") sections.append("Avoid asking questions already answered in sibling phases above.") @@ -615,10 +611,7 @@ def _build_sibling_phases_section( return "\n".join(sections) - def _build_code_exploration_section( - self, - code_exploration_context: CodeExplorationContext - ) -> str: + def _build_code_exploration_section(self, code_exploration_context: CodeExplorationContext) -> str: """ Build the code exploration context section for the prompt. 
@@ -654,7 +647,7 @@ def _build_prompt( cross_project_context: Optional[CrossProjectContext] = None, tech_stack_context: Optional[TechStackContext] = None, code_exploration_context: Optional[CodeExplorationContext] = None, - sibling_phases_context: Optional[SiblingPhasesContext] = None + sibling_phases_context: Optional[SiblingPhasesContext] = None, ) -> str: """ Build the question generation prompt for a specific aspect. @@ -677,7 +670,7 @@ def _build_prompt( prompt += f"clarification questions for the aspect: **{aspect.title}**\n\n" # Add aspect details - prompt += f"### Aspect Details:\n" + prompt += "### Aspect Details:\n" prompt += f"- Title: {aspect.title}\n" prompt += f"- Category: {aspect.category.value}\n" prompt += f"- Description: {aspect.description}\n\n" @@ -780,7 +773,7 @@ def _parse_question(self, q_data: dict) -> GeneratedClarificationQuestion: description=description, priority=priority, initial_mcq=mcq, - internal_agent_notes=internal_notes + internal_agent_notes=internal_notes, ) def _parse_mcq(self, mcq_data: dict) -> GeneratedMCQ: @@ -801,10 +794,10 @@ def _parse_mcq(self, mcq_data: dict) -> GeneratedMCQ: for i, choice in enumerate(choices_data): if isinstance(choice, dict): - choice_id = choice.get("id", chr(ord('a') + i)) + choice_id = choice.get("id", chr(ord("a") + i)) choice_label = choice.get("label", f"Option {i + 1}") else: - choice_id = chr(ord('a') + i) + choice_id = chr(ord("a") + i) choice_label = str(choice) choices.append(MCQChoice(id=choice_id, label=choice_label)) @@ -850,8 +843,7 @@ def _parse_mcq(self, mcq_data: dict) -> GeneratedMCQ: async def create_question_generator( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> QuestionGeneratorAgent: """ Factory function to create a Question Generator Agent. 
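Note: the hunks above are almost entirely mechanical `ruff format` rewrites, not behavioural changes. Three formatter rules account for most of them; a minimal sketch follows (illustration only, not project code, and assuming the raised line length — likely around 120 — that this patch appears to configure in backend/pyproject.toml, whose contents are not shown in this part of the diff):

```python
# Illustration of the three ruff-format behaviours behind most hunks above.


def choice_id(index: int) -> str:
    # 1. Quote normalization: the formatter rewrites chr(ord('a') + i)
    #    to use double quotes, as seen in _parse_mcq above.
    return chr(ord("a") + index)


# 2. A call that fits within the configured line length is collapsed onto one line:
ids = [choice_id(i) for i in range(3)]

# 3. ...unless a "magic trailing comma" is present, which pins the expanded,
#    one-argument-per-line layout:
labels = dict(
    a="Option 1",
    b="Option 2",
    c="Something else",
)
print(ids, labels)
```

The magic trailing comma is why some calls in these hunks expand to one argument per line while equally long calls collapse: the formatter treats a trailing comma as a request to keep the multi-line layout.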
diff --git a/backend/app/agents/brainstorm_conversation/summarizer.py b/backend/app/agents/brainstorm_conversation/summarizer.py index 4393f3a..34db4d8 100644 --- a/backend/app/agents/brainstorm_conversation/summarizer.py +++ b/backend/app/agents/brainstorm_conversation/summarizer.py @@ -13,8 +13,8 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient -from .types import BrainstormConversationContext, SummarizedPhaseContext from .logging_config import get_agent_logger +from .types import BrainstormConversationContext, SummarizedPhaseContext from .utils import strip_markdown_json @@ -87,7 +87,7 @@ async def summarize(self, context: BrainstormConversationContext) -> SummarizedP project_id=str(context.project_id), phase_id=str(context.brainstorming_phase_id), phase_type=context.phase_type.value, - description_length=len(context.phase_description) + description_length=len(context.phase_description), ) try: @@ -95,16 +95,9 @@ async def summarize(self, context: BrainstormConversationContext) -> SummarizedP prompt = self._build_prompt(context) # Call the agent - self.logger.log_llm_call( - prompt=prompt, - model=str(self.model_client), - operation="summarize_phase_context" - ) + self.logger.log_llm_call(prompt=prompt, model=str(self.model_client), operation="summarize_phase_context") - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -126,7 +119,7 @@ async def summarize(self, context: BrainstormConversationContext) -> SummarizedP key_objectives=summary_data.get("key_objectives", []), constraints=summary_data.get("constraints", []), target_users=summary_data.get("target_users"), - technical_context=summary_data.get("technical_context") + technical_context=summary_data.get("technical_context"), ) self.logger.log_agent_complete( @@ -134,7 +127,7 @@ async def summarize(self, context: BrainstormConversationContext) -> SummarizedP num_objectives=len(result.key_objectives), num_constraints=len(result.constraints), has_target_users=bool(result.target_users), - has_technical_context=bool(result.technical_context) + has_technical_context=bool(result.technical_context), ) return result @@ -183,8 +176,7 @@ def _build_prompt(self, context: BrainstormConversationContext) -> str: async def create_summarizer_agent( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> SummarizerAgent: """ Factory function to create a Summarizer Agent. 
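Note: besides formatting, each file's import block is re-sorted — presumably Ruff's isort-compatible `I` rules, enabled alongside the formatter in backend/pyproject.toml (again not shown here). Sorting within a `from` import is case-sensitive, which is why ALL_CAPS names such as TYPE_CHECKING and STANDARD_MCQ_CHOICES jump ahead of CamelCase classes in the hunks above and below:

```python
# Before `ruff check --fix` (order as originally written, per the types.py hunk below):
#   from typing import List, Optional, Dict, Any, TYPE_CHECKING
# After (case-sensitive alphabetical sort puts ALL_CAPS names first):
from typing import TYPE_CHECKING, Any, Dict, List, Optional

# Local imports sort alphabetically by module name, which is why
# .logging_config now precedes .types throughout this patch.
print(TYPE_CHECKING, Any, Dict, List, Optional)
```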
diff --git a/backend/app/agents/brainstorm_conversation/types.py b/backend/app/agents/brainstorm_conversation/types.py index 16b6c2f..af1b234 100644 --- a/backend/app/agents/brainstorm_conversation/types.py +++ b/backend/app/agents/brainstorm_conversation/types.py @@ -15,26 +15,28 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Dict, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID if TYPE_CHECKING: - from sqlalchemy.orm import Session - from app.models.project import Project + pass # ============================ # Enums # ============================ + class PhaseType(str, Enum): """Type of brainstorming phase.""" + INITIAL = "initial" # Greenfield project brainstorming FEATURE_SPECIFIC = "feature_specific" # Adding features to existing project class PhaseComplexity(str, Enum): """Brainstorming phase complexity level determined by the classifier.""" + LOW = "low" MEDIUM = "medium" HIGH = "high" @@ -42,6 +44,7 @@ class PhaseComplexity(str, Enum): class AspectCategory(str, Enum): """Categories for aspects (exploration areas).""" + USER_EXPERIENCE = "User_Experience" TECHNICAL_ARCHITECTURE = "Technical_Architecture" DATA_MANAGEMENT = "Data_Management" @@ -54,18 +57,21 @@ class AspectCategory(str, Enum): class QuestionPriority(str, Enum): """3-level priority model for clarification questions.""" + MUST_HAVE = "must_have" # Critical; should not be pruned - IMPORTANT = "important" # High value; can be pruned only in low-complexity - OPTIONAL = "optional" # Helpful but non-essential; first candidate for pruning + IMPORTANT = "important" # High value; can be pruned only in low-complexity + OPTIONAL = "optional" # Helpful but non-essential; first candidate for pruning # ============================ # Dataclasses # ============================ + @dataclass class MCQChoice: """A single choice option for an MCQ question.""" + id: str label: str @@ -76,6 +82,7 @@ class GeneratedMCQ: An MCQ (Multiple Choice Question) to start a conversation thread. Each clarification question gets a thread with an initial MCQ. """ + question_text: str choices: List[MCQChoice] explanation: Optional[str] = None @@ -89,6 +96,7 @@ class GeneratedClarificationQuestion: A clarification question generated by the Question Generator. This will be stored as a Feature record. """ + title: str description: str priority: QuestionPriority @@ -102,6 +110,7 @@ class GeneratedAspect: An aspect (exploration area) generated by the Aspect Generator. This will be stored as a Module record. """ + title: str description: str category: AspectCategory @@ -118,6 +127,7 @@ class UserInitiatedContext: When present, the pipeline uses conservative caps and incorporates the user's specific prompt into aspect/question generation. """ + user_prompt: str # User's focus area / request num_questions: int = 3 # 1-5, default 3 session_history: List[Dict[str, Any]] = field(default_factory=list) # Previous messages @@ -168,6 +178,7 @@ class BrainstormConversationContext: Raw context passed to the Orchestrator for generating brainstorm conversations. This is the full input data available for the multi-agent workflow. """ + project_id: UUID brainstorming_phase_id: UUID phase_title: str @@ -205,6 +216,7 @@ class SummarizedPhaseContext: Output from the Summarizer Agent. Condensed summary of the phase description and objectives (300-600 tokens). 
""" + phase_summary: str key_objectives: List[str] constraints: List[str] @@ -217,6 +229,7 @@ class ClassificationResult: """ Output from the Complexity Classification Agent. """ + complexity: PhaseComplexity rationale: str suggested_focus_areas: List[str] @@ -228,6 +241,7 @@ class ComplexityCaps: Configuration caps based on complexity level. Defines limits for aspect and question generation. """ + complexity: PhaseComplexity # Aspects @@ -248,6 +262,7 @@ class BrainstormConversationResult: Final result from the Brainstorm Conversation Orchestrator. Contains all generated aspects with their clarification questions. """ + aspects: List[GeneratedAspect] complexity: PhaseComplexity total_aspects: int @@ -259,12 +274,14 @@ class BrainstormConversationResult: # Agent Metadata for UI # ============================ + @dataclass class AgentInfo: """ UI metadata for an agent in the Brainstorm Conversation workflow. Used for progress tracking and visual representation. """ + name: str description: str color: str # Hex color for UI tag @@ -275,37 +292,37 @@ class AgentInfo: "orchestrator": AgentInfo( name="Orchestrator", description="Coordinating the brainstorming workflow", - color="#8B5CF6" # Purple + color="#8B5CF6", # Purple ), "summarizer": AgentInfo( name="Summarizer", description="Analyzing phase description and objectives", - color="#3B82F6" # Blue + color="#3B82F6", # Blue ), "classifier": AgentInfo( name="Classifier", description="Determining phase complexity and focus", - color="#10B981" # Green + color="#10B981", # Green ), "aspect_generator": AgentInfo( name="Aspect Generator", description="Identifying key areas to explore", - color="#F59E0B" # Amber + color="#F59E0B", # Amber ), "question_generator": AgentInfo( name="Question Generator", description="Creating clarification questions with MCQs", - color="#6366F1" # Indigo + color="#6366F1", # Indigo ), "critic_pruner": AgentInfo( name="Critic", description="Refining and optimizing question quality", - color="#EC4899" # Pink + color="#EC4899", # Pink ), "code_explorer": AgentInfo( name="Code Explorer", description="Analyzing codebase for relevant patterns", - color="#14B8A6" # Teal + color="#14B8A6", # Teal ), } @@ -319,7 +336,7 @@ class AgentInfo: "generating_aspects", "generating_questions", "pruning", - "complete" + "complete", ] # User-friendly display names for workflow steps @@ -395,6 +412,7 @@ class AgentInfo: # Helper Functions # ============================ + def get_caps_for_complexity(complexity: PhaseComplexity) -> ComplexityCaps: """Get the configuration caps for a given complexity level.""" return COMPLEXITY_CAPS_CONFIG[complexity] @@ -414,10 +432,7 @@ def create_mcq_choices(options: List[str]) -> List[MCQChoice]: if not options or len(options) < 1 or len(options) > 3: raise ValueError("Must provide 1-3 specific options for MCQ question") - choices = [ - MCQChoice(id=f"option_{i}", label=label) - for i, label in enumerate(options, start=1) - ] + choices = [MCQChoice(id=f"option_{i}", label=label) for i, label in enumerate(options, start=1)] # Append standard choices choices.extend(STANDARD_MCQ_CHOICES) @@ -429,12 +444,14 @@ def create_mcq_choices(options: List[str]) -> List[MCQChoice]: # Existing Context Types (for Generate Additional) # ============================ + @dataclass class ExistingQuestionWithAnswer: """ A single existing clarification question with its decision context. Uses thread decision summaries for richer, more scalable context. 
""" + question_id: str question_title: str question_description: str @@ -453,6 +470,7 @@ class ExistingAspect: An existing aspect (module) with its questions. Used to provide context to generators about what's already been explored. """ + aspect_id: str title: str description: str @@ -468,6 +486,7 @@ class ExistingConversationContext: Complete context about existing aspects and questions for generators. Enables aspect-aware generation that builds on previous conversations. """ + aspects: List[ExistingAspect] total_aspects: int total_questions: int @@ -479,12 +498,14 @@ class ExistingConversationContext: # Cross-Phase Context Types # ============================ + @dataclass class CrossPhaseDecision: """ A decision from another phase's thread. Uses decision_summary_short for compact cross-phase context. """ + question_title: str decision_summary_short: str # Single-sentence summary (~100-150 chars) aspect_title: str @@ -496,6 +517,7 @@ class CrossPhaseContext: Context from another brainstorming phase in the same project. Contains the phase info and its ACTIVE thread decisions. """ + phase_id: str phase_title: str phase_description: str # Truncated to ~200 chars for compactness @@ -508,6 +530,7 @@ class ProjectFeatureDecision: A decision from a project-level feature thread. Project features are IMPLEMENTATION features with module.brainstorming_phase_id IS NULL. """ + feature_title: str module_title: str decision_summary_short: str @@ -521,6 +544,7 @@ class CrossProjectContext: 1. Decisions from OTHER brainstorming phases (not the current one) 2. Decisions from project-level implementation features """ + other_phases: List[CrossPhaseContext] project_features: List[ProjectFeatureDecision] @@ -529,6 +553,7 @@ class CrossProjectContext: # Sibling Phase Context Types (for phases in same container) # ============================ + @dataclass class SiblingPhaseContext: """ @@ -537,6 +562,7 @@ class SiblingPhaseContext: Sibling phases share a container and have container_sequence ordering, allowing agents to understand the progression of related phases. """ + phase_id: str phase_title: str phase_subtype: str # "INITIAL_SPEC" or "EXTENSION" @@ -554,6 +580,7 @@ class SiblingPhasesContext: This provides agents with rich context about related phases when generating content for extension phases within a container. 
""" + container_id: str container_title: str sibling_phases: List[SiblingPhaseContext] @@ -605,20 +632,16 @@ def validate_brainstorm_result(result: BrainstormConversationResult) -> List[str # Check total questions if result.total_questions > caps.total_max_questions: - issues.append( - f"Too many questions: {result.total_questions} > {caps.total_max_questions}" - ) + issues.append(f"Too many questions: {result.total_questions} > {caps.total_max_questions}") # Check each aspect for aspect in result.aspects: if not aspect.title.strip(): - issues.append(f"Aspect has empty title") + issues.append("Aspect has empty title") q_count = len(aspect.clarification_questions) if q_count < caps.min_questions_per_aspect: - issues.append( - f"Aspect '{aspect.title}' has too few questions: {q_count} < {caps.min_questions_per_aspect}" - ) + issues.append(f"Aspect '{aspect.title}' has too few questions: {q_count} < {caps.min_questions_per_aspect}") if q_count > caps.max_questions_per_aspect: issues.append( f"Aspect '{aspect.title}' has too many questions: {q_count} > {caps.max_questions_per_aspect}" @@ -632,13 +655,9 @@ def validate_brainstorm_result(result: BrainstormConversationResult) -> List[str # Check MCQ choices choice_count = len(question.initial_mcq.choices) if choice_count < MIN_MCQ_CHOICES: - issues.append( - f"Question '{question.title}' MCQ has too few choices: {choice_count}" - ) + issues.append(f"Question '{question.title}' MCQ has too few choices: {choice_count}") if choice_count > MAX_MCQ_CHOICES: - issues.append( - f"Question '{question.title}' MCQ has too many choices: {choice_count}" - ) + issues.append(f"Question '{question.title}' MCQ has too many choices: {choice_count}") # Check for duplicate aspect titles aspect_titles = [a.title.lower().strip() for a in result.aspects] diff --git a/backend/app/agents/brainstorm_conversation/utils.py b/backend/app/agents/brainstorm_conversation/utils.py index 7962f3d..469fc1d 100644 --- a/backend/app/agents/brainstorm_conversation/utils.py +++ b/backend/app/agents/brainstorm_conversation/utils.py @@ -5,9 +5,9 @@ import re # Import from common module and re-export for backwards compatibility -from app.agents.response_parser import strip_markdown_json, normalize_response_content +from app.agents.response_parser import normalize_response_content, strip_markdown_json -__all__ = ['strip_markdown_json', 'normalize_response_content', 'truncate_text', 'normalize_whitespace'] +__all__ = ["strip_markdown_json", "normalize_response_content", "truncate_text", "normalize_whitespace"] def truncate_text(text: str, max_length: int = 500) -> str: @@ -25,7 +25,7 @@ def truncate_text(text: str, max_length: int = 500) -> str: return text # Truncate at word boundary - truncated = text[:max_length].rsplit(' ', 1)[0] + truncated = text[:max_length].rsplit(" ", 1)[0] return truncated + "..." @@ -40,5 +40,5 @@ def normalize_whitespace(text: str) -> str: Normalized text """ # Replace multiple whitespace with single space - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) return text.strip() diff --git a/backend/app/agents/brainstorm_prompt_plan/__init__.py b/backend/app/agents/brainstorm_prompt_plan/__init__.py index 497eebd..24a4cc4 100644 --- a/backend/app/agents/brainstorm_prompt_plan/__init__.py +++ b/backend/app/agents/brainstorm_prompt_plan/__init__.py @@ -11,54 +11,49 @@ 4. 
QA - Validates completeness and actionability """ -from .types import ( - # Enums - PromptPlanSectionId, - # Dataclasses - BrainstormPromptPlanContext, - AnalyzedContext, - PromptPlanOutlineSection, - BrainstormPromptPlanOutline, - PromptPlanSectionContent, - BrainstormPromptPlan, - PromptPlanValidationReport, - BrainstormPromptPlanResult, - AgentInfo, - # Constants - AGENT_METADATA, - WORKFLOW_STEPS, - FIXED_SECTIONS, - ASPECT_CATEGORY_TO_SECTION_MAPPING, - # Helper functions - get_section_by_id, - get_outline_section_by_id, - build_section_tree, +from .analyzer import ( + AnalyzerAgent, + create_analyzer, ) - from .orchestrator import ( BrainstormPromptPlanOrchestrator, create_orchestrator, ) - -from .analyzer import ( - AnalyzerAgent, - create_analyzer, -) - from .planner import ( PlannerAgent, create_planner, ) - -from .writer import ( - WriterAgent, - create_writer, -) - from .qa import ( QAAgent, create_qa, ) +from .types import ( + # Constants + AGENT_METADATA, + ASPECT_CATEGORY_TO_SECTION_MAPPING, + FIXED_SECTIONS, + WORKFLOW_STEPS, + AgentInfo, + AnalyzedContext, + BrainstormPromptPlan, + # Dataclasses + BrainstormPromptPlanContext, + BrainstormPromptPlanOutline, + BrainstormPromptPlanResult, + PromptPlanOutlineSection, + PromptPlanSectionContent, + # Enums + PromptPlanSectionId, + PromptPlanValidationReport, + build_section_tree, + get_outline_section_by_id, + # Helper functions + get_section_by_id, +) +from .writer import ( + WriterAgent, + create_writer, +) __all__ = [ # Enums diff --git a/backend/app/agents/brainstorm_prompt_plan/analyzer.py b/backend/app/agents/brainstorm_prompt_plan/analyzer.py index 4b38f76..7ce9fda 100644 --- a/backend/app/agents/brainstorm_prompt_plan/analyzer.py +++ b/backend/app/agents/brainstorm_prompt_plan/analyzer.py @@ -13,11 +13,11 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - BrainstormPromptPlanContext, AnalyzedContext, + BrainstormPromptPlanContext, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -102,11 +102,7 @@ def _get_system_message(self) -> str: Return ONLY the JSON object, no additional text.""" - async def analyze( - self, - context: BrainstormPromptPlanContext, - project_id: Optional[str] = None - ) -> AnalyzedContext: + async def analyze(self, context: BrainstormPromptPlanContext, project_id: Optional[str] = None) -> AnalyzedContext: """ Analyze brainstorming data and extract implementation components. @@ -125,7 +121,7 @@ async def analyze( phase_title=context.phase_title, aspects_count=len(context.aspects), questions_count=len(context.clarification_questions), - threads_count=len(context.thread_discussions) + threads_count=len(context.thread_discussions), ) try: @@ -136,13 +132,10 @@ async def analyze( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="analyze_brainstorm_context" + operation="analyze_brainstorm_context", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -155,10 +148,9 @@ async def analyze( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) raise ValueError(f"Failed to parse analyzer response as JSON: {e}") # Convert to AnalyzedContext @@ -202,7 +194,7 @@ def _build_prompt(self, context: BrainstormPromptPlanContext) -> str: Returns: Formatted prompt string """ - prompt = f"Analyze the following brainstorming data to extract implementation components.\n\n" + prompt = "Analyze the following brainstorming data to extract implementation components.\n\n" # Add phase info prompt += f"### Phase Title: {context.phase_title}\n" @@ -225,11 +217,11 @@ def _build_prompt(self, context: BrainstormPromptPlanContext) -> str: prompt += f" Question: {q.get('spec_text', q.get('description', 'No description'))}\n" # Include ALL answered MCQs for this feature - answered_mcqs = q.get('answered_mcqs', []) + answered_mcqs = q.get("answered_mcqs", []) for mcq in answered_mcqs: - question_text = mcq.get('question_text', '') - selected_label = mcq.get('selected_label', '') - free_text = mcq.get('free_text') + question_text = mcq.get("question_text", "") + selected_label = mcq.get("selected_label", "") + free_text = mcq.get("free_text") prompt += f" - MCQ: {question_text}\n" prompt += f" Answer: {selected_label}\n" @@ -243,22 +235,22 @@ def _build_prompt(self, context: BrainstormPromptPlanContext) -> str: for thread in context.thread_discussions: prompt += f"- **{thread.get('title', 'Untitled Thread')}**\n" # Prefer decision summary if available - if thread.get('decision_summary'): + if thread.get("decision_summary"): prompt += f" {thread['decision_summary']}\n" - if thread.get('unresolved_points'): + if thread.get("unresolved_points"): prompt += " **Unresolved:**\n" - for point in thread['unresolved_points']: - question = point.get('question', '') if isinstance(point, dict) else str(point) + for point in thread["unresolved_points"]: + question = point.get("question", "") if isinstance(point, dict) else str(point) prompt += f" - {question}\n" else: # Fallback to raw comments if no decision summary - comments = thread.get('comments', []) + comments = thread.get("comments", []) for comment in comments[:5]: # Limit to first 5 comments per thread prompt += f" - {comment.get('content', '')[:200]}\n" # Note any images attached to this thread - if thread.get('images'): + if thread.get("images"): prompt += " **Attached Images:**\n" - for img in thread['images'][:5]: # Limit to first 5 images + for img in thread["images"][:5]: # Limit to first 5 images prompt += f" - {img.get('filename', 'unnamed')} (ID: {img.get('id', 'unknown')})\n" prompt += "\n" @@ -307,10 +299,7 @@ def _build_sibling_implementation_section(sibling_phases_context) -> str: if not sibling_phases_context: return "" - phases_with_analysis = [ - p for p in sibling_phases_context.sibling_phases - if 
p.implementation_analysis - ] + phases_with_analysis = [p for p in sibling_phases_context.sibling_phases if p.implementation_analysis] if not phases_with_analysis: return "" @@ -327,10 +316,7 @@ def _build_sibling_implementation_section(sibling_phases_context) -> str: return "\n".join(sections) -async def create_analyzer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> AnalyzerAgent: +async def create_analyzer(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> AnalyzerAgent: """ Factory function to create an Analyzer Agent. diff --git a/backend/app/agents/brainstorm_prompt_plan/logging_config.py b/backend/app/agents/brainstorm_prompt_plan/logging_config.py index dd081c0..8d91cea 100644 --- a/backend/app/agents/brainstorm_prompt_plan/logging_config.py +++ b/backend/app/agents/brainstorm_prompt_plan/logging_config.py @@ -4,10 +4,10 @@ Provides structured logging for all agent decisions, LLM calls, and workflow steps. """ -import logging import json -from typing import Any, Dict, Optional +import logging from datetime import datetime, timezone +from typing import Any, Dict, Optional class BrainstormAgentLogger: @@ -33,19 +33,12 @@ def __init__(self, agent_name: str, project_id: Optional[str] = None): # Ensure structured output if not self.logger.handlers: handler = logging.StreamHandler() - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) - def _structured_log( - self, - level: str, - event: str, - extra_data: Optional[Dict[str, Any]] = None - ) -> None: + def _structured_log(self, level: str, event: str, extra_data: Optional[Dict[str, Any]] = None) -> None: """ Log a structured event. @@ -78,12 +71,7 @@ def log_agent_complete(self, **kwargs: Any) -> None: self._structured_log("info", f"{self.agent_name}_complete", kwargs) def log_llm_call( - self, - prompt: str, - model: str, - response: Optional[str] = None, - tokens_used: Optional[int] = None, - **kwargs: Any + self, prompt: str, model: str, response: Optional[str] = None, tokens_used: Optional[int] = None, **kwargs: Any ) -> None: """ Log an LLM API call. @@ -99,7 +87,7 @@ def log_llm_call( "model": model, "prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt, "prompt_length": len(prompt), - **kwargs + **kwargs, } if response: @@ -120,11 +108,7 @@ def log_decision(self, decision: str, rationale: str, **kwargs: Any) -> None: rationale: Explanation of why **kwargs: Additional context """ - data = { - "decision": decision, - "rationale": rationale, - **kwargs - } + data = {"decision": decision, "rationale": rationale, **kwargs} self._structured_log("info", "agent_decision", data) def log_pruning_stats( @@ -134,7 +118,7 @@ def log_pruning_stats( initial_questions: int, final_questions: int, pruned_items: list, - **kwargs: Any + **kwargs: Any, ) -> None: """ Log aspect and question pruning statistics. 
@@ -155,7 +139,7 @@ def log_pruning_stats( "pruned_aspects": initial_aspects - final_aspects, "pruned_questions": initial_questions - final_questions, "pruned_items": pruned_items, - **kwargs + **kwargs, } self._structured_log("info", "pruning_stats", data) @@ -186,11 +170,7 @@ def log_workflow_transition(self, from_state: str, to_state: str, **kwargs: Any) to_state: New state **kwargs: Additional context """ - data = { - "from_state": from_state, - "to_state": to_state, - **kwargs - } + data = {"from_state": from_state, "to_state": to_state, **kwargs} self._structured_log("info", "workflow_transition", data) diff --git a/backend/app/agents/brainstorm_prompt_plan/orchestrator.py b/backend/app/agents/brainstorm_prompt_plan/orchestrator.py index 4f84e99..0878e6e 100644 --- a/backend/app/agents/brainstorm_prompt_plan/orchestrator.py +++ b/backend/app/agents/brainstorm_prompt_plan/orchestrator.py @@ -11,23 +11,22 @@ """ import logging -from typing import Callable, Dict, Any, List, Optional +from typing import Any, Callable, Dict, Optional from uuid import UUID +# Re-use the same exception from brainstorm_spec +from app.agents.brainstorm_spec import JobCancelledException + +from .analyzer import create_analyzer +from .logging_config import get_agent_logger +from .planner import create_planner +from .qa import create_qa from .types import ( + AGENT_METADATA, BrainstormPromptPlanContext, BrainstormPromptPlanResult, - AGENT_METADATA, - WORKFLOW_STEPS, ) -from .analyzer import create_analyzer -from .planner import create_planner from .writer import create_writer -from .qa import create_qa -from .logging_config import get_agent_logger - -# Re-use the same exception from brainstorm_spec -from app.agents.brainstorm_spec import JobCancelledException logger = logging.getLogger(__name__) @@ -76,8 +75,8 @@ def _check_cancelled(self) -> None: if not self.job_id: return - from app.services.job_service import JobService from app.database import SessionLocal + from app.services.job_service import JobService db = SessionLocal() try: @@ -88,10 +87,7 @@ def _check_cancelled(self) -> None: db.close() def _report_progress( - self, - workflow_step: str, - agent_key: Optional[str] = None, - extra_data: Optional[Dict[str, Any]] = None + self, workflow_step: str, agent_key: Optional[str] = None, extra_data: Optional[Dict[str, Any]] = None ) -> None: """ Report progress to the callback. @@ -139,10 +135,7 @@ def _report_progress( except Exception as e: self.logger.log_error(e, {"context": "progress_callback"}) - async def generate_brainstorm_prompt_plan( - self, - context: BrainstormPromptPlanContext - ) -> BrainstormPromptPlanResult: + async def generate_brainstorm_prompt_plan(self, context: BrainstormPromptPlanContext) -> BrainstormPromptPlanResult: """ Generate a complete brainstorm prompt plan. 
@@ -188,11 +181,7 @@ async def generate_brainstorm_prompt_plan( if self.call_logger: self.call_logger.set_agent("planner", "Planner") planner = await create_planner(self.model_client, project_id) - outline = await planner.create_outline( - analyzed_context, - context.clarification_questions, - project_id - ) + outline = await planner.create_outline(analyzed_context, context.clarification_questions, project_id) # Check for cancellation before next step self._check_cancelled() @@ -205,10 +194,7 @@ async def generate_brainstorm_prompt_plan( self.call_logger.set_agent("writer", "Writer") writer = await create_writer(self.model_client, project_id) prompt_plan = await writer.write_prompt_plan( - outline, - analyzed_context, - context.clarification_questions, - project_id + outline, analyzed_context, context.clarification_questions, project_id ) # Check for cancellation before next step @@ -221,12 +207,7 @@ async def generate_brainstorm_prompt_plan( if self.call_logger: self.call_logger.set_agent("qa", "QA Validator") qa = await create_qa(self.model_client, project_id) - validation_report = await qa.validate( - prompt_plan, - outline, - context.clarification_questions, - project_id - ) + validation_report = await qa.validate(prompt_plan, outline, context.clarification_questions, project_id) # Complete self._report_progress("complete", "orchestrator") @@ -276,7 +257,7 @@ async def create_orchestrator( Returns: Initialized BrainstormPromptPlanOrchestrator instance """ - from app.agents.llm_client import create_litellm_client, LLMCallLogger + from app.agents.llm_client import LLMCallLogger, create_litellm_client config = config or {} model = config.get("model") @@ -287,6 +268,7 @@ async def create_orchestrator( call_logger = None if job_id: from app.database import SessionLocal + call_logger = LLMCallLogger( db_session_factory=SessionLocal, job_id=job_id, diff --git a/backend/app/agents/brainstorm_prompt_plan/planner.py b/backend/app/agents/brainstorm_prompt_plan/planner.py index 73d5ae1..263be3e 100644 --- a/backend/app/agents/brainstorm_prompt_plan/planner.py +++ b/backend/app/agents/brainstorm_prompt_plan/planner.py @@ -13,15 +13,15 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( + ASPECT_CATEGORY_TO_SECTION_MAPPING, + FIXED_SECTIONS, AnalyzedContext, BrainstormPromptPlanOutline, - PromptPlanOutlineSection, PhaseDomainMapping, - FIXED_SECTIONS, - ASPECT_CATEGORY_TO_SECTION_MAPPING, + PromptPlanOutlineSection, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -161,10 +161,7 @@ def _get_system_message(self) -> str: Return ONLY the JSON object.""" async def create_outline( - self, - analyzed_context: AnalyzedContext, - clarification_questions: List[dict], - project_id: Optional[str] = None + self, analyzed_context: AnalyzedContext, clarification_questions: List[dict], project_id: Optional[str] = None ) -> BrainstormPromptPlanOutline: """ Create the prompt plan outline from analyzed context. @@ -180,10 +177,7 @@ async def create_outline( Raises: ValueError: If planning fails or returns invalid JSON """ - self.logger.log_agent_start( - project_id=project_id, - questions_count=len(clarification_questions) - ) + self.logger.log_agent_start(project_id=project_id, questions_count=len(clarification_questions)) try: # Build the prompt @@ -193,13 +187,10 @@ async def create_outline( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="create_prompt_plan_outline" + operation="create_prompt_plan_outline", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -212,10 +203,9 @@ async def create_outline( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) # Fall back to default outline return self._create_default_outline(analyzed_context, clarification_questions) @@ -230,12 +220,14 @@ async def create_outline( # Parse phase_domain_mapping phase_domain_mapping = [] for m_data in result_data.get("phase_domain_mapping", []): - phase_domain_mapping.append(PhaseDomainMapping( - phase_title=m_data.get("phase_title", ""), - phase_index=m_data.get("phase_index", 0), - domains=m_data.get("domains", []), - keywords=m_data.get("keywords", []), - )) + phase_domain_mapping.append( + PhaseDomainMapping( + phase_title=m_data.get("phase_title", ""), + phase_index=m_data.get("phase_index", 0), + domains=m_data.get("domains", []), + keywords=m_data.get("keywords", []), + ) + ) # Ensure all fixed sections are present outline = BrainstormPromptPlanOutline( @@ -246,7 +238,7 @@ async def create_outline( self.logger.log_agent_complete( sections_count=len(outline.sections), - total_subsections=sum(len(s.subsections) for s in outline.sections) + total_subsections=sum(len(s.subsections) for s in outline.sections), ) return outline @@ -256,11 +248,7 @@ async def create_outline( # Return default outline on error return self._create_default_outline(analyzed_context, clarification_questions) - def _build_prompt( - self, - analyzed_context: AnalyzedContext, - clarification_questions: List[dict] - ) -> str: + def _build_prompt(self, analyzed_context: AnalyzedContext, clarification_questions: List[dict]) -> str: """ Build the planning prompt. 
@@ -315,9 +303,9 @@ def _build_prompt( if clarification_questions: prompt += "### Clarification Questions (link these to sections):\n" for q in clarification_questions: - q_id = q.get('id', 'unknown') - q_title = q.get('title', 'Untitled') - q_category = q.get('category', 'General') + q_id = q.get("id", "unknown") + q_title = q.get("title", "Untitled") + q_category = q.get("category", "General") prompt += f"- [{q_id}] {q_title} (Category: {q_category})\n" prompt += "\n" @@ -356,13 +344,15 @@ def _ensure_all_sections(self, outline: BrainstormPromptPlanOutline) -> Brainsto for fixed in FIXED_SECTIONS: if fixed["id"] not in existing_ids: - outline.sections.append(PromptPlanOutlineSection( - id=fixed["id"], - title=fixed["title"], - description="", - subsections=[], - linked_questions=[], - )) + outline.sections.append( + PromptPlanOutlineSection( + id=fixed["id"], + title=fixed["title"], + description="", + subsections=[], + linked_questions=[], + ) + ) # Sort by fixed order section_order = {s["id"]: i for i, s in enumerate(FIXED_SECTIONS)} @@ -371,9 +361,7 @@ def _ensure_all_sections(self, outline: BrainstormPromptPlanOutline) -> Brainsto return outline def _create_default_outline( - self, - analyzed_context: AnalyzedContext, - clarification_questions: List[dict] + self, analyzed_context: AnalyzedContext, clarification_questions: List[dict] ) -> BrainstormPromptPlanOutline: """Create a default outline when LLM fails.""" sections = [] @@ -381,8 +369,8 @@ def _create_default_outline( # Create question-to-section mapping based on categories question_section_links = {} for q in clarification_questions: - q_id = q.get('id', '') - category = q.get('category', 'Business_Logic') + q_id = q.get("id", "") + category = q.get("category", "Business_Logic") target_sections = ASPECT_CATEGORY_TO_SECTION_MAPPING.get(category, []) for section_id in target_sections: if section_id.value not in question_section_links: @@ -390,21 +378,20 @@ def _create_default_outline( question_section_links[section_id.value].append(q_id) for fixed in FIXED_SECTIONS: - sections.append(PromptPlanOutlineSection( - id=fixed["id"], - title=fixed["title"], - description="", - subsections=[], - linked_questions=question_section_links.get(fixed["id"], []), - )) + sections.append( + PromptPlanOutlineSection( + id=fixed["id"], + title=fixed["title"], + description="", + subsections=[], + linked_questions=question_section_links.get(fixed["id"], []), + ) + ) return BrainstormPromptPlanOutline(sections=sections, phase_domain_mapping=[]) -async def create_planner( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> PlannerAgent: +async def create_planner(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> PlannerAgent: """ Factory function to create a Planner Agent. 
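Note: one non-cosmetic pattern worth calling out in the planner hunks above (and in the QA agent below): when the model's JSON response cannot be parsed, these agents log the error and fall back to a deterministic default (`_create_default_outline`, `_create_fallback_report`) rather than raising. A minimal sketch of that parse-or-fallback convention — hypothetical names, with the fence-stripping only approximating the real `strip_markdown_json` helper:

```python
import json
import logging

logger = logging.getLogger("agent")


def parse_or_default(response_text: str, default: dict) -> dict:
    # Drop a ```json ... ``` fence if present (approximates strip_markdown_json).
    cleaned = response_text.strip().removeprefix("```json").removesuffix("```").strip()
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError as exc:
        # Log with a truncated preview, as the agents above do, then degrade
        # gracefully to a deterministic default instead of failing the workflow.
        logger.error("JSON parse failed: %s; preview=%r", exc, response_text[:500])
        return default


print(parse_or_default('```json\n{"sections": []}\n```', {"sections": []}))
```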
diff --git a/backend/app/agents/brainstorm_prompt_plan/qa.py b/backend/app/agents/brainstorm_prompt_plan/qa.py index 6d1c1e3..47ecf31 100644 --- a/backend/app/agents/brainstorm_prompt_plan/qa.py +++ b/backend/app/agents/brainstorm_prompt_plan/qa.py @@ -13,12 +13,12 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( BrainstormPromptPlan, BrainstormPromptPlanOutline, PromptPlanValidationReport, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -92,7 +92,7 @@ async def validate( prompt_plan: BrainstormPromptPlan, outline: BrainstormPromptPlanOutline, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> PromptPlanValidationReport: """ Validate the prompt plan for completeness and quality. @@ -112,7 +112,7 @@ async def validate( self.logger.log_agent_start( project_id=project_id, sections_count=len(prompt_plan.sections), - questions_count=len(clarification_questions) + questions_count=len(clarification_questions), ) try: @@ -123,13 +123,10 @@ async def validate( self.logger.log_llm_call( prompt=prompt[:500] + "..." if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="validate_prompt_plan" + operation="validate_prompt_plan", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -142,10 +139,9 @@ async def validate( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) # Return a default report on parse failure return self._create_fallback_report(prompt_plan, outline) @@ -176,7 +172,7 @@ def _build_prompt( self, prompt_plan: BrainstormPromptPlan, outline: BrainstormPromptPlanOutline, - clarification_questions: List[dict] + clarification_questions: List[dict], ) -> str: """ Build the validation prompt. @@ -195,14 +191,18 @@ def _build_prompt( prompt += "### Generated Prompt Plan:\n" for section in prompt_plan.sections: prompt += f"\n**{section.title}** (ID: {section.id})\n" - prompt += f"{section.body_markdown[:500]}...\n" if len(section.body_markdown) > 500 else f"{section.body_markdown}\n" + prompt += ( + f"{section.body_markdown[:500]}...\n" + if len(section.body_markdown) > 500 + else f"{section.body_markdown}\n" + ) if section.linked_questions: prompt += f"Linked Questions: {', '.join(section.linked_questions)}\n" prompt += "\n### Clarification Questions:\n" for q in clarification_questions: - q_id = q.get('id', 'unknown') - q_title = q.get('title', 'Untitled') + q_id = q.get("id", "unknown") + q_title = q.get("title", "Untitled") prompt += f"- [{q_id}] {q_title}\n" prompt += "\nValidate completeness, clarity, and dependencies. Return ONLY the JSON object." 
@@ -210,9 +210,7 @@ def _build_prompt( return prompt def _create_fallback_report( - self, - prompt_plan: BrainstormPromptPlan, - outline: BrainstormPromptPlanOutline + self, prompt_plan: BrainstormPromptPlan, outline: BrainstormPromptPlanOutline ) -> PromptPlanValidationReport: """ Create a fallback validation report using simple heuristics. @@ -235,16 +233,13 @@ def _create_fallback_report( missing_components=[], unclear_instructions=empty_sections, dependency_issues=[], - suggested_improvements=[ - "Fallback validation: Consider reviewing prompt plan for completeness" - ] if empty_sections else [], + suggested_improvements=["Fallback validation: Consider reviewing prompt plan for completeness"] + if empty_sections + else [], ) -async def create_qa( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> QAAgent: +async def create_qa(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> QAAgent: """ Factory function to create a QA Agent. diff --git a/backend/app/agents/brainstorm_prompt_plan/types.py b/backend/app/agents/brainstorm_prompt_plan/types.py index d94e9fc..51da56a 100644 --- a/backend/app/agents/brainstorm_prompt_plan/types.py +++ b/backend/app/agents/brainstorm_prompt_plan/types.py @@ -14,16 +14,17 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional from uuid import UUID - # ============================ # Enums # ============================ + class PromptPlanSectionId(str, Enum): """Fixed section IDs for the brainstorm prompt plan outline.""" + INTRODUCTION = "introduction" PROJECT_OVERVIEW = "project_overview" IMPLEMENTATION_PHASES = "implementation_phases" @@ -40,12 +41,14 @@ class PromptPlanSectionId(str, Enum): # Dataclasses # ============================ + @dataclass class BrainstormPromptPlanContext: """ Raw context passed to the Orchestrator for generating Brainstorm Prompt Plan. This is the full input data from a brainstorming phase. """ + project_id: UUID brainstorming_phase_id: UUID phase_title: str @@ -90,6 +93,7 @@ class AnalyzedContext: Output from the Analyzer Agent. Extracted implementation components from brainstorming data. """ + phase_summary: str implementation_goals: List[str] = field(default_factory=list) components_to_build: List[str] = field(default_factory=list) @@ -115,10 +119,11 @@ class PromptPlanOutlineSection: A single section in the prompt plan outline. Can have nested subsections and linked questions. """ + id: str title: str description: str = "" - subsections: List['PromptPlanOutlineSection'] = field(default_factory=list) + subsections: List["PromptPlanOutlineSection"] = field(default_factory=list) linked_questions: List[str] = field(default_factory=list) # Q&A-aware generation: tracks whether section has sufficient answered questions has_qa_backing: bool = True # Default True for backwards compatibility @@ -131,6 +136,7 @@ class PhaseDomainMapping: Mapping of a prompt plan phase to its relevant domains and keywords. Used by module_feature extraction to match requirements to phases. """ + phase_title: str phase_index: int domains: List[str] = field(default_factory=list) # e.g., ["Authentication", "Security"] @@ -143,6 +149,7 @@ class BrainstormPromptPlanOutline: Full prompt plan outline generated by the Planner Agent. Contains hierarchical section tree and phase-domain mapping for downstream extraction. 
""" + sections: List[PromptPlanOutlineSection] = field(default_factory=list) # Mapping of phases to domains/keywords for module_feature extraction phase_domain_mapping: List[PhaseDomainMapping] = field(default_factory=list) @@ -154,6 +161,7 @@ class PromptPlanSectionContent: A single prompt plan section with rendered markdown content. Output from the Writer Agent. """ + id: str title: str body_markdown: str @@ -166,6 +174,7 @@ class BrainstormPromptPlan: Complete prompt plan document generated by the Writer Agent. Contains all sections with rendered markdown. """ + sections: List[PromptPlanSectionContent] = field(default_factory=list) # Phase-domain mapping for downstream module_feature extraction phase_domain_mapping: List[PhaseDomainMapping] = field(default_factory=list) @@ -218,6 +227,7 @@ class PromptPlanValidationReport: Quality assurance report generated by the QA Agent. Validates completeness and actionability of the prompt plan. """ + ok: bool missing_components: List[str] = field(default_factory=list) unclear_instructions: List[str] = field(default_factory=list) @@ -231,6 +241,7 @@ class BrainstormPromptPlanResult: Final result returned by the Orchestrator. Includes the complete prompt plan, outline, and validation report. """ + prompt_plan: BrainstormPromptPlan outline: BrainstormPromptPlanOutline analyzed_context: AnalyzedContext @@ -241,12 +252,14 @@ class BrainstormPromptPlanResult: # Agent Metadata for UI # ============================ + @dataclass class AgentInfo: """ UI metadata for an agent in the Brainstorm Prompt Plan generation workflow. Used for progress tracking and visual representation. """ + name: str description: str color: str # Hex color for UI tag @@ -257,40 +270,33 @@ class AgentInfo: "orchestrator": AgentInfo( name="Orchestrator", description="Coordinating prompt plan generation workflow", - color="#8B5CF6" # Purple + color="#8B5CF6", # Purple ), "analyzer": AgentInfo( name="Analyzer", description="Extracting implementation components from brainstorming", - color="#3B82F6" # Blue + color="#3B82F6", # Blue ), "planner": AgentInfo( name="Planner", description="Creating prompt plan outline with implementation phases", - color="#10B981" # Green + color="#10B981", # Green ), "writer": AgentInfo( name="Writer", description="Generating instructional content for each section", - color="#F59E0B" # Amber + color="#F59E0B", # Amber ), "qa": AgentInfo( name="QA", description="Validating completeness and actionability", - color="#EC4899" # Pink + color="#EC4899", # Pink ), } # Workflow step definitions for progress tracking -WORKFLOW_STEPS = [ - "start", - "analyzing", - "planning", - "writing", - "validating", - "complete" -] +WORKFLOW_STEPS = ["start", "analyzing", "planning", "writing", "validating", "complete"] # ============================ @@ -346,6 +352,7 @@ class AgentInfo: # Helper Functions # ============================ + def get_section_by_id(sections: List[PromptPlanSectionContent], section_id: str) -> Optional[PromptPlanSectionContent]: """Get a section by its ID.""" for section in sections: @@ -354,7 +361,9 @@ def get_section_by_id(sections: List[PromptPlanSectionContent], section_id: str) return None -def get_outline_section_by_id(sections: List[PromptPlanOutlineSection], section_id: str) -> Optional[PromptPlanOutlineSection]: +def get_outline_section_by_id( + sections: List[PromptPlanOutlineSection], section_id: str +) -> Optional[PromptPlanOutlineSection]: """Get an outline section by its ID (supports nested subsections).""" for section in sections: if 
section.id == section_id: diff --git a/backend/app/agents/brainstorm_prompt_plan/utils.py b/backend/app/agents/brainstorm_prompt_plan/utils.py index 9818f3e..4ce4b13 100644 --- a/backend/app/agents/brainstorm_prompt_plan/utils.py +++ b/backend/app/agents/brainstorm_prompt_plan/utils.py @@ -5,9 +5,9 @@ import re # Import from common module and re-export for backwards compatibility -from app.agents.response_parser import strip_markdown_json, normalize_response_content +from app.agents.response_parser import normalize_response_content, strip_markdown_json -__all__ = ['strip_markdown_json', 'normalize_response_content', 'truncate_text', 'normalize_whitespace'] +__all__ = ["strip_markdown_json", "normalize_response_content", "truncate_text", "normalize_whitespace"] def truncate_text(text: str, max_length: int = 500) -> str: @@ -25,7 +25,7 @@ def truncate_text(text: str, max_length: int = 500) -> str: return text # Truncate at word boundary - truncated = text[:max_length].rsplit(' ', 1)[0] + truncated = text[:max_length].rsplit(" ", 1)[0] return truncated + "..." @@ -40,5 +40,5 @@ def normalize_whitespace(text: str) -> str: Normalized text """ # Replace multiple whitespace with single space - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) return text.strip() diff --git a/backend/app/agents/brainstorm_prompt_plan/writer.py b/backend/app/agents/brainstorm_prompt_plan/writer.py index 2028504..16ecf9a 100644 --- a/backend/app/agents/brainstorm_prompt_plan/writer.py +++ b/backend/app/agents/brainstorm_prompt_plan/writer.py @@ -6,7 +6,6 @@ """ import asyncio -import json from typing import List, Optional from autogen_agentchat.agents import AssistantAgent @@ -14,15 +13,14 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( AnalyzedContext, - BrainstormPromptPlanOutline, BrainstormPromptPlan, - PromptPlanSectionContent, + BrainstormPromptPlanOutline, PromptPlanOutlineSection, + PromptPlanSectionContent, ) -from .logging_config import get_agent_logger -from .utils import strip_markdown_json class WriterAgent: @@ -126,7 +124,7 @@ async def write_prompt_plan( outline: BrainstormPromptPlanOutline, analyzed_context: AnalyzedContext, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> BrainstormPromptPlan: """ Write the full prompt plan from outline and context. 
@@ -143,19 +141,11 @@ async def write_prompt_plan( Raises: Exception: If writing fails for all sections """ - self.logger.log_agent_start( - project_id=project_id, - sections_count=len(outline.sections) - ) + self.logger.log_agent_start(project_id=project_id, sections_count=len(outline.sections)) # Create tasks for ALL sections to run in parallel tasks = [ - self._write_section( - section, - analyzed_context, - clarification_questions, - project_id - ) + self._write_section(section, analyzed_context, clarification_questions, project_id) for section in outline.sections ] @@ -169,12 +159,14 @@ async def write_prompt_plan( if isinstance(result, Exception): self.logger.log_error(result, {"section_id": section.id}) # Create a placeholder for failed sections - sections.append(PromptPlanSectionContent( - id=section.id, - title=section.title, - body_markdown=f"*Content generation failed: {str(result)}*", - linked_questions=section.linked_questions, - )) + sections.append( + PromptPlanSectionContent( + id=section.id, + title=section.title, + body_markdown=f"*Content generation failed: {str(result)}*", + linked_questions=section.linked_questions, + ) + ) else: sections.append(result) @@ -185,8 +177,7 @@ async def write_prompt_plan( ) self.logger.log_agent_complete( - sections_written=len(sections), - total_markdown_length=sum(len(s.body_markdown) for s in sections) + sections_written=len(sections), total_markdown_length=sum(len(s.body_markdown) for s in sections) ) return prompt_plan @@ -196,7 +187,7 @@ async def _write_section( section: PromptPlanOutlineSection, analyzed_context: AnalyzedContext, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> PromptPlanSectionContent: """ Write content for a single section. @@ -218,7 +209,7 @@ async def _write_section( prompt=prompt[:300] + "..." if len(prompt) > 300 else prompt, model=str(self.model_client), operation="write_section", - section_id=section.id + section_id=section.id, ) # Create a FRESH agent for each section to avoid conversation history accumulation @@ -230,10 +221,7 @@ async def _write_section( model_client=self.model_client, ) - response = await section_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await section_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -257,10 +245,7 @@ async def _write_section( ) def _build_section_prompt( - self, - section: PromptPlanOutlineSection, - analyzed_context: AnalyzedContext, - clarification_questions: List[dict] + self, section: PromptPlanOutlineSection, analyzed_context: AnalyzedContext, clarification_questions: List[dict] ) -> str: """ Build the prompt for writing a section. @@ -279,10 +264,10 @@ def _build_section_prompt( prompt += f"Section Description: {section.description}\n\n" # Check if section has Q&A backing - if not, generate placeholder - if hasattr(section, 'has_qa_backing') and not section.has_qa_backing: + if hasattr(section, "has_qa_backing") and not section.has_qa_backing: prompt += "### BLOCKED ON PENDING TOPICS:\n" prompt += "This section lacks sufficient Q&A backing. 
The following topics need user decisions:\n" - if hasattr(section, 'blocked_on') and section.blocked_on: + if hasattr(section, "blocked_on") and section.blocked_on: for topic in section.blocked_on: prompt += f"- {topic}\n" else: @@ -376,20 +361,20 @@ def _build_section_prompt( for q_id in section.linked_questions: # Find the question for q in clarification_questions: - if q.get('id') == q_id: + if q.get("id") == q_id: prompt += f"- **{q.get('title', 'Untitled')}**\n" prompt += f" {q.get('spec_text', q.get('description', ''))}\n" # Include answer if available - mcq_data = q.get('mcq_data', {}) + mcq_data = q.get("mcq_data", {}) if mcq_data: - selected = mcq_data.get('selected_option_id') + selected = mcq_data.get("selected_option_id") if selected: - choices = mcq_data.get('choices', []) + choices = mcq_data.get("choices", []) for choice in choices: - if choice.get('id') == selected: + if choice.get("id") == selected: prompt += f" Answer: {choice.get('label', selected)}\n" break - free_text = mcq_data.get('free_text') + free_text = mcq_data.get("free_text") if free_text: prompt += f" Additional: {free_text}\n" break @@ -420,10 +405,7 @@ def _build_section_prompt( return prompt -async def create_writer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> WriterAgent: +async def create_writer(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> WriterAgent: """ Factory function to create a Writer Agent. diff --git a/backend/app/agents/brainstorm_spec/__init__.py b/backend/app/agents/brainstorm_spec/__init__.py index 2bcb537..35bdd15 100644 --- a/backend/app/agents/brainstorm_spec/__init__.py +++ b/backend/app/agents/brainstorm_spec/__init__.py @@ -11,56 +11,51 @@ 4. QA/Coverage - Validates completeness and coverage """ -from .types import ( - # Enums - BrainstormSpecSectionId, - # Dataclasses - BrainstormSpecContext, - NormalizedBrainstormContext, - SpecOutlineSection, - BrainstormSpecOutline, - SpecSectionContent, - BrainstormSpecification, - CoverageReport, - BrainstormSpecResult, - AgentInfo, - # Constants - AGENT_METADATA, - WORKFLOW_STEPS, - FIXED_SECTIONS, - ASPECT_CATEGORY_TO_SECTION_MAPPING, - # Helper functions - get_section_by_id, - get_outline_section_by_id, - build_section_tree, -) - from .orchestrator import ( BrainstormSpecOrchestrator, - create_orchestrator, JobCancelledException, + create_orchestrator, +) +from .planner import ( + PlannerAgent, + create_planner, +) +from .qa_coverage import ( + QACoverageAgent, + create_qa_coverage, ) - from .summarizer import ( SummarizerAgent, create_summarizer, ) - -from .planner import ( - PlannerAgent, - create_planner, +from .types import ( + # Constants + AGENT_METADATA, + ASPECT_CATEGORY_TO_SECTION_MAPPING, + FIXED_SECTIONS, + WORKFLOW_STEPS, + AgentInfo, + # Dataclasses + BrainstormSpecContext, + BrainstormSpecification, + BrainstormSpecOutline, + BrainstormSpecResult, + # Enums + BrainstormSpecSectionId, + CoverageReport, + NormalizedBrainstormContext, + SpecOutlineSection, + SpecSectionContent, + build_section_tree, + get_outline_section_by_id, + # Helper functions + get_section_by_id, ) - from .writer import ( WriterAgent, create_writer, ) -from .qa_coverage import ( - QACoverageAgent, - create_qa_coverage, -) - __all__ = [ # Enums "BrainstormSpecSectionId", diff --git a/backend/app/agents/brainstorm_spec/logging_config.py b/backend/app/agents/brainstorm_spec/logging_config.py index dd081c0..8d91cea 100644 --- a/backend/app/agents/brainstorm_spec/logging_config.py +++ 
b/backend/app/agents/brainstorm_spec/logging_config.py @@ -4,10 +4,10 @@ Provides structured logging for all agent decisions, LLM calls, and workflow steps. """ -import logging import json -from typing import Any, Dict, Optional +import logging from datetime import datetime, timezone +from typing import Any, Dict, Optional class BrainstormAgentLogger: @@ -33,19 +33,12 @@ def __init__(self, agent_name: str, project_id: Optional[str] = None): # Ensure structured output if not self.logger.handlers: handler = logging.StreamHandler() - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) - def _structured_log( - self, - level: str, - event: str, - extra_data: Optional[Dict[str, Any]] = None - ) -> None: + def _structured_log(self, level: str, event: str, extra_data: Optional[Dict[str, Any]] = None) -> None: """ Log a structured event. @@ -78,12 +71,7 @@ def log_agent_complete(self, **kwargs: Any) -> None: self._structured_log("info", f"{self.agent_name}_complete", kwargs) def log_llm_call( - self, - prompt: str, - model: str, - response: Optional[str] = None, - tokens_used: Optional[int] = None, - **kwargs: Any + self, prompt: str, model: str, response: Optional[str] = None, tokens_used: Optional[int] = None, **kwargs: Any ) -> None: """ Log an LLM API call. @@ -99,7 +87,7 @@ def log_llm_call( "model": model, "prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt, "prompt_length": len(prompt), - **kwargs + **kwargs, } if response: @@ -120,11 +108,7 @@ def log_decision(self, decision: str, rationale: str, **kwargs: Any) -> None: rationale: Explanation of why **kwargs: Additional context """ - data = { - "decision": decision, - "rationale": rationale, - **kwargs - } + data = {"decision": decision, "rationale": rationale, **kwargs} self._structured_log("info", "agent_decision", data) def log_pruning_stats( @@ -134,7 +118,7 @@ def log_pruning_stats( initial_questions: int, final_questions: int, pruned_items: list, - **kwargs: Any + **kwargs: Any, ) -> None: """ Log aspect and question pruning statistics. 
@@ -155,7 +139,7 @@ def log_pruning_stats( "pruned_aspects": initial_aspects - final_aspects, "pruned_questions": initial_questions - final_questions, "pruned_items": pruned_items, - **kwargs + **kwargs, } self._structured_log("info", "pruning_stats", data) @@ -186,11 +170,7 @@ def log_workflow_transition(self, from_state: str, to_state: str, **kwargs: Any) to_state: New state **kwargs: Additional context """ - data = { - "from_state": from_state, - "to_state": to_state, - **kwargs - } + data = {"from_state": from_state, "to_state": to_state, **kwargs} self._structured_log("info", "workflow_transition", data) diff --git a/backend/app/agents/brainstorm_spec/orchestrator.py b/backend/app/agents/brainstorm_spec/orchestrator.py index 0b83fc4..5697e41 100644 --- a/backend/app/agents/brainstorm_spec/orchestrator.py +++ b/backend/app/agents/brainstorm_spec/orchestrator.py @@ -11,26 +11,26 @@ """ import logging -from typing import Callable, Dict, Any, List, Optional +from typing import Any, Callable, Dict, Optional from uuid import UUID +from .logging_config import get_agent_logger +from .planner import create_planner +from .qa_coverage import create_qa_coverage +from .summarizer import create_summarizer from .types import ( + AGENT_METADATA, BrainstormSpecContext, BrainstormSpecResult, - AGENT_METADATA, - WORKFLOW_STEPS, ) -from .summarizer import create_summarizer -from .planner import create_planner from .writer import create_writer -from .qa_coverage import create_qa_coverage -from .logging_config import get_agent_logger logger = logging.getLogger(__name__) class JobCancelledException(Exception): """Raised when a job is cancelled during pipeline execution.""" + pass @@ -78,8 +78,8 @@ def _check_cancelled(self) -> None: if not self.job_id: return - from app.services.job_service import JobService from app.database import SessionLocal + from app.services.job_service import JobService db = SessionLocal() try: @@ -90,10 +90,7 @@ def _check_cancelled(self) -> None: db.close() def _report_progress( - self, - workflow_step: str, - agent_key: Optional[str] = None, - extra_data: Optional[Dict[str, Any]] = None + self, workflow_step: str, agent_key: Optional[str] = None, extra_data: Optional[Dict[str, Any]] = None ) -> None: """ Report progress to the callback. @@ -141,10 +138,7 @@ def _report_progress( except Exception as e: self.logger.log_error(e, {"context": "progress_callback"}) - async def generate_brainstorm_spec( - self, - context: BrainstormSpecContext - ) -> BrainstormSpecResult: + async def generate_brainstorm_spec(self, context: BrainstormSpecContext) -> BrainstormSpecResult: """ Generate a complete brainstorm specification. 
@@ -190,11 +184,7 @@ async def generate_brainstorm_spec( if self.call_logger: self.call_logger.set_agent("planner", "Planner") planner = await create_planner(self.model_client, project_id) - outline = await planner.create_outline( - normalized_context, - context.clarification_questions, - project_id - ) + outline = await planner.create_outline(normalized_context, context.clarification_questions, project_id) # Check for cancellation before next step self._check_cancelled() @@ -207,10 +197,7 @@ async def generate_brainstorm_spec( self.call_logger.set_agent("writer", "Writer") writer = await create_writer(self.model_client, project_id) specification = await writer.write_specification( - outline, - normalized_context, - context.clarification_questions, - project_id + outline, normalized_context, context.clarification_questions, project_id ) # Check for cancellation before next step @@ -223,12 +210,7 @@ async def generate_brainstorm_spec( if self.call_logger: self.call_logger.set_agent("qa_coverage", "QA Coverage") qa = await create_qa_coverage(self.model_client, project_id) - coverage_report = await qa.validate( - specification, - outline, - context.clarification_questions, - project_id - ) + coverage_report = await qa.validate(specification, outline, context.clarification_questions, project_id) # Complete self._report_progress("complete", "orchestrator") @@ -278,7 +260,7 @@ async def create_orchestrator( Returns: Initialized BrainstormSpecOrchestrator instance """ - from app.agents.llm_client import create_litellm_client, LLMCallLogger + from app.agents.llm_client import LLMCallLogger, create_litellm_client config = config or {} model = config.get("model") @@ -289,6 +271,7 @@ async def create_orchestrator( call_logger = None if job_id: from app.database import SessionLocal + call_logger = LLMCallLogger( db_session_factory=SessionLocal, job_id=job_id, diff --git a/backend/app/agents/brainstorm_spec/planner.py b/backend/app/agents/brainstorm_spec/planner.py index 8182cb3..d21e251 100644 --- a/backend/app/agents/brainstorm_spec/planner.py +++ b/backend/app/agents/brainstorm_spec/planner.py @@ -13,14 +13,14 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - NormalizedBrainstormContext, + ASPECT_CATEGORY_TO_SECTION_MAPPING, + FIXED_SECTIONS, BrainstormSpecOutline, + NormalizedBrainstormContext, SpecOutlineSection, - FIXED_SECTIONS, - ASPECT_CATEGORY_TO_SECTION_MAPPING, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -133,7 +133,7 @@ async def create_outline( self, normalized_context: NormalizedBrainstormContext, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> BrainstormSpecOutline: """ Create the specification outline from normalized context. @@ -149,10 +149,7 @@ async def create_outline( Raises: ValueError: If planning fails or returns invalid JSON """ - self.logger.log_agent_start( - project_id=project_id, - questions_count=len(clarification_questions) - ) + self.logger.log_agent_start(project_id=project_id, questions_count=len(clarification_questions)) try: # Build the prompt @@ -162,13 +159,10 @@ async def create_outline( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="create_spec_outline" + operation="create_spec_outline", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -181,10 +175,9 @@ async def create_outline( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) # Fall back to default outline return self._create_default_outline(clarification_questions) @@ -202,7 +195,7 @@ async def create_outline( self.logger.log_agent_complete( sections_count=len(outline.sections), - total_subsections=sum(len(s.subsections) for s in outline.sections) + total_subsections=sum(len(s.subsections) for s in outline.sections), ) return outline @@ -213,9 +206,7 @@ async def create_outline( return self._create_default_outline(clarification_questions) def _build_prompt( - self, - normalized_context: NormalizedBrainstormContext, - clarification_questions: List[dict] + self, normalized_context: NormalizedBrainstormContext, clarification_questions: List[dict] ) -> str: """ Build the planning prompt. @@ -257,9 +248,9 @@ def _build_prompt( if clarification_questions: prompt += "### Clarification Questions (link these to sections):\n" for q in clarification_questions: - q_id = q.get('id', 'unknown') - q_title = q.get('title', 'Untitled') - q_category = q.get('category', 'General') + q_id = q.get("id", "unknown") + q_title = q.get("title", "Untitled") + q_category = q.get("category", "General") prompt += f"- [{q_id}] {q_title} (Category: {q_category})\n" prompt += "\n" @@ -271,7 +262,9 @@ def _build_prompt( prompt += "\n" prompt += "Create the outline with all 11 sections. " - prompt += "For sections affected by unanswered topics, set has_qa_backing: false and list pending_clarifications. " + prompt += ( + "For sections affected by unanswered topics, set has_qa_backing: false and list pending_clarifications. " + ) prompt += "Return ONLY the JSON object." 
return prompt @@ -298,13 +291,15 @@ def _ensure_all_sections(self, outline: BrainstormSpecOutline) -> BrainstormSpec for fixed in FIXED_SECTIONS: if fixed["id"] not in existing_ids: - outline.sections.append(SpecOutlineSection( - id=fixed["id"], - title=fixed["title"], - description="", - subsections=[], - linked_questions=[], - )) + outline.sections.append( + SpecOutlineSection( + id=fixed["id"], + title=fixed["title"], + description="", + subsections=[], + linked_questions=[], + ) + ) # Sort by fixed order section_order = {s["id"]: i for i, s in enumerate(FIXED_SECTIONS)} @@ -319,8 +314,8 @@ def _create_default_outline(self, clarification_questions: List[dict]) -> Brains # Create question-to-section mapping based on categories question_section_links = {} for q in clarification_questions: - q_id = q.get('id', '') - category = q.get('category', 'Business_Logic') + q_id = q.get("id", "") + category = q.get("category", "Business_Logic") target_sections = ASPECT_CATEGORY_TO_SECTION_MAPPING.get(category, []) for section_id in target_sections: if section_id.value not in question_section_links: @@ -328,21 +323,20 @@ def _create_default_outline(self, clarification_questions: List[dict]) -> Brains question_section_links[section_id.value].append(q_id) for fixed in FIXED_SECTIONS: - sections.append(SpecOutlineSection( - id=fixed["id"], - title=fixed["title"], - description="", - subsections=[], - linked_questions=question_section_links.get(fixed["id"], []), - )) + sections.append( + SpecOutlineSection( + id=fixed["id"], + title=fixed["title"], + description="", + subsections=[], + linked_questions=question_section_links.get(fixed["id"], []), + ) + ) return BrainstormSpecOutline(sections=sections) -async def create_planner( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> PlannerAgent: +async def create_planner(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> PlannerAgent: """ Factory function to create a Planner Agent. diff --git a/backend/app/agents/brainstorm_spec/qa_coverage.py b/backend/app/agents/brainstorm_spec/qa_coverage.py index e5bc4eb..ec2beb7 100644 --- a/backend/app/agents/brainstorm_spec/qa_coverage.py +++ b/backend/app/agents/brainstorm_spec/qa_coverage.py @@ -13,12 +13,12 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( BrainstormSpecification, BrainstormSpecOutline, CoverageReport, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -92,7 +92,7 @@ async def validate( specification: BrainstormSpecification, outline: BrainstormSpecOutline, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> CoverageReport: """ Validate the specification for coverage and quality. @@ -112,7 +112,7 @@ async def validate( self.logger.log_agent_start( project_id=project_id, sections_count=len(specification.sections), - questions_count=len(clarification_questions) + questions_count=len(clarification_questions), ) try: @@ -123,13 +123,10 @@ async def validate( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="validate_coverage" + operation="validate_coverage", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -142,14 +139,11 @@ async def validate( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) - # Return a default report on parse failure - return self._create_fallback_report( - specification, outline, clarification_questions + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} ) + # Return a default report on parse failure + return self._create_fallback_report(specification, outline, clarification_questions) # Convert to CoverageReport report = CoverageReport( @@ -172,15 +166,13 @@ async def validate( except Exception as e: self.logger.log_error(e, {"project_id": project_id}) # Return fallback report on error - return self._create_fallback_report( - specification, outline, clarification_questions - ) + return self._create_fallback_report(specification, outline, clarification_questions) def _build_prompt( self, specification: BrainstormSpecification, outline: BrainstormSpecOutline, - clarification_questions: List[dict] + clarification_questions: List[dict], ) -> str: """ Build the validation prompt. @@ -199,18 +191,22 @@ def _build_prompt( prompt += "### Generated Specification:\n" for section in specification.sections: prompt += f"\n**{section.title}** (ID: {section.id})\n" - prompt += f"{section.body_markdown[:500]}...\n" if len(section.body_markdown) > 500 else f"{section.body_markdown}\n" + prompt += ( + f"{section.body_markdown[:500]}...\n" + if len(section.body_markdown) > 500 + else f"{section.body_markdown}\n" + ) if section.linked_questions: prompt += f"Linked Questions: {', '.join(section.linked_questions)}\n" prompt += "\n### Clarification Questions:\n" must_have_questions = [] for q in clarification_questions: - q_id = q.get('id', 'unknown') - q_title = q.get('title', 'Untitled') - q_priority = q.get('priority', 'optional') + q_id = q.get("id", "unknown") + q_title = q.get("title", "Untitled") + q_priority = q.get("priority", "optional") prompt += f"- [{q_id}] {q_title} (Priority: {q_priority})\n" - if q_priority == 'must_have': + if q_priority == "must_have": must_have_questions.append(q_id) if must_have_questions: @@ -224,7 +220,7 @@ def _create_fallback_report( self, specification: BrainstormSpecification, outline: BrainstormSpecOutline, - clarification_questions: List[dict] + clarification_questions: List[dict], ) -> CoverageReport: """ Create a fallback coverage report using simple heuristics. 
@@ -247,9 +243,9 @@ def _create_fallback_report( # Find uncovered must-have questions uncovered_must_have = [] for q in clarification_questions: - q_id = q.get('id', '') - priority = q.get('priority', 'optional') - if priority == 'must_have' and q_id not in covered_questions: + q_id = q.get("id", "") + priority = q.get("priority", "optional") + if priority == "must_have" and q_id not in covered_questions: uncovered_must_have.append(q_id) return CoverageReport( @@ -257,16 +253,13 @@ def _create_fallback_report( uncovered_must_have_questions=uncovered_must_have, weak_coverage_warnings=[], contradictions_found=[], - suggested_improvements=[ - "Fallback validation: Consider reviewing specification for completeness" - ] if uncovered_must_have else [], + suggested_improvements=["Fallback validation: Consider reviewing specification for completeness"] + if uncovered_must_have + else [], ) -async def create_qa_coverage( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> QACoverageAgent: +async def create_qa_coverage(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> QACoverageAgent: """ Factory function to create a QA/Coverage Agent. diff --git a/backend/app/agents/brainstorm_spec/summarizer.py b/backend/app/agents/brainstorm_spec/summarizer.py index 18a7bf2..1304c8d 100644 --- a/backend/app/agents/brainstorm_spec/summarizer.py +++ b/backend/app/agents/brainstorm_spec/summarizer.py @@ -14,11 +14,11 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( BrainstormSpecContext, NormalizedBrainstormContext, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -100,9 +100,7 @@ def _get_system_message(self) -> str: Return ONLY the JSON object, no additional text.""" async def normalize( - self, - context: BrainstormSpecContext, - project_id: Optional[str] = None + self, context: BrainstormSpecContext, project_id: Optional[str] = None ) -> NormalizedBrainstormContext: """ Normalize brainstorming data into structured context. @@ -122,7 +120,7 @@ async def normalize( phase_title=context.phase_title, aspects_count=len(context.aspects), questions_count=len(context.clarification_questions), - threads_count=len(context.thread_discussions) + threads_count=len(context.thread_discussions), ) try: @@ -133,13 +131,10 @@ async def normalize( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="normalize_brainstorm_context" + operation="normalize_brainstorm_context", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -152,10 +147,9 @@ async def normalize( try: result_data = json.loads(cleaned_response) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned_response[:500] - }) + self.logger.log_error( + e, {"raw_response": response_text[:500], "cleaned_response": cleaned_response[:500]} + ) raise ValueError(f"Failed to parse summarizer response as JSON: {e}") # Convert to NormalizedBrainstormContext @@ -199,7 +193,7 @@ def _build_prompt(self, context: BrainstormSpecContext) -> str: Returns: Formatted prompt string """ - prompt = f"Normalize the following brainstorming data into structured context.\n\n" + prompt = "Normalize the following brainstorming data into structured context.\n\n" # Add phase info prompt += f"### Phase Title: {context.phase_title}\n" @@ -222,11 +216,11 @@ def _build_prompt(self, context: BrainstormSpecContext) -> str: prompt += f" Question: {q.get('spec_text', q.get('description', 'No description'))}\n" # Include ALL answered MCQs for this feature - answered_mcqs = q.get('answered_mcqs', []) + answered_mcqs = q.get("answered_mcqs", []) for mcq in answered_mcqs: - question_text = mcq.get('question_text', '') - selected_label = mcq.get('selected_label', '') - free_text = mcq.get('free_text') + question_text = mcq.get("question_text", "") + selected_label = mcq.get("selected_label", "") + free_text = mcq.get("free_text") prompt += f" - MCQ: {question_text}\n" prompt += f" Answer: {selected_label}\n" @@ -240,22 +234,22 @@ def _build_prompt(self, context: BrainstormSpecContext) -> str: for thread in context.thread_discussions: prompt += f"- **{thread.get('title', 'Untitled Thread')}**\n" # Prefer decision summary if available - if thread.get('decision_summary'): + if thread.get("decision_summary"): prompt += f" {thread['decision_summary']}\n" - if thread.get('unresolved_points'): + if thread.get("unresolved_points"): prompt += " **Unresolved:**\n" - for point in thread['unresolved_points']: - question = point.get('question', '') if isinstance(point, dict) else str(point) + for point in thread["unresolved_points"]: + question = point.get("question", "") if isinstance(point, dict) else str(point) prompt += f" - {question}\n" else: # Fallback to raw comments if no decision summary - comments = thread.get('comments', []) + comments = thread.get("comments", []) for comment in comments[:5]: # Limit to first 5 comments per thread prompt += f" - {comment.get('content', '')[:200]}\n" # Note any images attached to this thread - if thread.get('images'): + if thread.get("images"): prompt += " **Attached Images:**\n" - for img in thread['images'][:5]: # Limit to first 5 images + for img in thread["images"][:5]: # Limit to first 5 images prompt += f" - {img.get('filename', 'unnamed')} (ID: {img.get('id', 'unknown')})\n" prompt += "\n" @@ -292,10 +286,7 @@ def _build_sibling_implementation_section(sibling_phases_context) -> str: if not sibling_phases_context: return "" - phases_with_analysis = [ - p for p in sibling_phases_context.sibling_phases - if 
p.implementation_analysis - ] + phases_with_analysis = [p for p in sibling_phases_context.sibling_phases if p.implementation_analysis] if not phases_with_analysis: return "" @@ -312,10 +303,7 @@ def _build_sibling_implementation_section(sibling_phases_context) -> str: return "\n".join(sections) -async def create_summarizer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> SummarizerAgent: +async def create_summarizer(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> SummarizerAgent: """ Factory function to create a Summarizer Agent. diff --git a/backend/app/agents/brainstorm_spec/types.py b/backend/app/agents/brainstorm_spec/types.py index 77faba6..34f3e30 100644 --- a/backend/app/agents/brainstorm_spec/types.py +++ b/backend/app/agents/brainstorm_spec/types.py @@ -14,16 +14,17 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional from uuid import UUID - # ============================ # Enums # ============================ + class BrainstormSpecSectionId(str, Enum): """Fixed section IDs for the brainstorm specification outline.""" + EXECUTIVE_SUMMARY = "executive_summary" PROBLEM_STATEMENT = "problem_statement" GOALS_AND_NON_GOALS = "goals_and_non_goals" @@ -41,12 +42,14 @@ class BrainstormSpecSectionId(str, Enum): # Dataclasses # ============================ + @dataclass class BrainstormSpecContext: """ Raw context passed to the Orchestrator for generating Brainstorm Specification. This is the full input data from a brainstorming phase. """ + project_id: UUID brainstorming_phase_id: UUID phase_title: str @@ -84,6 +87,7 @@ class NormalizedBrainstormContext: Output from the Summarizer/Normalizer Agent. Condensed summaries of all brainstorming inputs (300-600 tokens per category). """ + phase_summary: str key_objectives: List[str] = field(default_factory=list) user_personas: List[str] = field(default_factory=list) @@ -111,10 +115,11 @@ class SpecOutlineSection: A single section in the specification outline. Can have nested subsections and linked clarification questions. """ + id: str title: str description: str = "" - subsections: List['SpecOutlineSection'] = field(default_factory=list) + subsections: List["SpecOutlineSection"] = field(default_factory=list) linked_questions: List[str] = field(default_factory=list) # Question IDs # Q&A-aware generation: tracks whether section has sufficient answered questions has_qa_backing: bool = True # Default True for backwards compatibility @@ -127,6 +132,7 @@ class BrainstormSpecOutline: Full specification outline generated by the Planner Agent. Contains hierarchical section tree with linked questions. """ + sections: List[SpecOutlineSection] = field(default_factory=list) @@ -136,6 +142,7 @@ class SpecSectionContent: A single specification section with rendered markdown content. Output from the Writer Agent. """ + id: str title: str body_markdown: str @@ -149,6 +156,7 @@ class BrainstormSpecification: Complete specification document generated by the Writer Agent. Contains all sections with rendered markdown. """ + sections: List[SpecSectionContent] = field(default_factory=list) def to_markdown(self) -> str: @@ -192,7 +200,7 @@ def to_summary_markdown(self) -> str: # Fallback: use first 300 chars of body if no summary truncated = section.body_markdown[:300] if len(section.body_markdown) > 300: - truncated = truncated.rsplit(' ', 1)[0] + "..." + truncated = truncated.rsplit(" ", 1)[0] + "..." 
lines.append(truncated) lines.append("") return "\n".join(lines) @@ -204,6 +212,7 @@ class CoverageReport: Quality assurance and coverage report generated by the QA Agent. Validates completeness, consistency, and quality of the specification. """ + ok: bool uncovered_must_have_questions: List[str] = field(default_factory=list) weak_coverage_warnings: List[str] = field(default_factory=list) @@ -217,6 +226,7 @@ class BrainstormSpecResult: Final result returned by the Orchestrator. Includes the complete specification, outline, and coverage report. """ + specification: BrainstormSpecification outline: BrainstormSpecOutline normalized_context: NormalizedBrainstormContext @@ -227,12 +237,14 @@ class BrainstormSpecResult: # Agent Metadata for UI # ============================ + @dataclass class AgentInfo: """ UI metadata for an agent in the Brainstorm Spec generation workflow. Used for progress tracking and visual representation. """ + name: str description: str color: str # Hex color for UI tag @@ -243,40 +255,33 @@ class AgentInfo: "orchestrator": AgentInfo( name="Orchestrator", description="Coordinating specification generation workflow", - color="#8B5CF6" # Purple + color="#8B5CF6", # Purple ), "summarizer": AgentInfo( name="Summarizer", description="Normalizing brainstorm discussions into structured context", - color="#3B82F6" # Blue + color="#3B82F6", # Blue ), "planner": AgentInfo( name="Planner", description="Creating specification outline with linked questions", - color="#10B981" # Green + color="#10B981", # Green ), "writer": AgentInfo( name="Writer", description="Generating specification content for each section", - color="#F59E0B" # Amber + color="#F59E0B", # Amber ), "qa_coverage": AgentInfo( name="QA/Coverage", description="Validating completeness and coverage of specification", - color="#EC4899" # Pink + color="#EC4899", # Pink ), } # Workflow step definitions for progress tracking -WORKFLOW_STEPS = [ - "start", - "normalizing", - "planning", - "writing", - "validating", - "complete" -] +WORKFLOW_STEPS = ["start", "normalizing", "planning", "writing", "validating", "complete"] # ============================ @@ -335,6 +340,7 @@ class AgentInfo: # Helper Functions # ============================ + def get_section_by_id(sections: List[SpecSectionContent], section_id: str) -> Optional[SpecSectionContent]: """Get a section by its ID.""" for section in sections: diff --git a/backend/app/agents/brainstorm_spec/utils.py b/backend/app/agents/brainstorm_spec/utils.py index 37b931a..c1b8642 100644 --- a/backend/app/agents/brainstorm_spec/utils.py +++ b/backend/app/agents/brainstorm_spec/utils.py @@ -5,9 +5,9 @@ import re # Import from common module and re-export for backwards compatibility -from app.agents.response_parser import strip_markdown_json, normalize_response_content +from app.agents.response_parser import normalize_response_content, strip_markdown_json -__all__ = ['strip_markdown_json', 'normalize_response_content', 'truncate_text', 'normalize_whitespace'] +__all__ = ["strip_markdown_json", "normalize_response_content", "truncate_text", "normalize_whitespace"] def truncate_text(text: str, max_length: int = 500) -> str: @@ -25,7 +25,7 @@ def truncate_text(text: str, max_length: int = 500) -> str: return text # Truncate at word boundary - truncated = text[:max_length].rsplit(' ', 1)[0] + truncated = text[:max_length].rsplit(" ", 1)[0] return truncated + "..." 
@@ -40,5 +40,5 @@ def normalize_whitespace(text: str) -> str: Normalized text """ # Replace multiple whitespace with single space - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) return text.strip() diff --git a/backend/app/agents/brainstorm_spec/writer.py b/backend/app/agents/brainstorm_spec/writer.py index 1724af5..db20d3b 100644 --- a/backend/app/agents/brainstorm_spec/writer.py +++ b/backend/app/agents/brainstorm_spec/writer.py @@ -14,14 +14,14 @@ from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - NormalizedBrainstormContext, - BrainstormSpecOutline, BrainstormSpecification, - SpecSectionContent, + BrainstormSpecOutline, + NormalizedBrainstormContext, SpecOutlineSection, + SpecSectionContent, ) -from .logging_config import get_agent_logger from .utils import strip_markdown_json @@ -142,7 +142,7 @@ async def write_specification( outline: BrainstormSpecOutline, normalized_context: NormalizedBrainstormContext, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> BrainstormSpecification: """ Write the full specification from outline and context. @@ -159,19 +159,11 @@ async def write_specification( Raises: Exception: If writing fails for all sections """ - self.logger.log_agent_start( - project_id=project_id, - sections_count=len(outline.sections) - ) + self.logger.log_agent_start(project_id=project_id, sections_count=len(outline.sections)) # Create tasks for all sections to execute in parallel tasks = [ - self._write_section( - section, - normalized_context, - clarification_questions, - project_id - ) + self._write_section(section, normalized_context, clarification_questions, project_id) for section in outline.sections ] @@ -185,20 +177,21 @@ async def write_specification( if isinstance(result, Exception): self.logger.log_error(result, {"section_id": section.id}) # Create a placeholder for failed sections - sections.append(SpecSectionContent( - id=section.id, - title=section.title, - body_markdown=f"*Content generation failed: {str(result)}*", - linked_questions=section.linked_questions, - )) + sections.append( + SpecSectionContent( + id=section.id, + title=section.title, + body_markdown=f"*Content generation failed: {str(result)}*", + linked_questions=section.linked_questions, + ) + ) else: sections.append(result) specification = BrainstormSpecification(sections=sections) self.logger.log_agent_complete( - sections_written=len(sections), - total_markdown_length=sum(len(s.body_markdown) for s in sections) + sections_written=len(sections), total_markdown_length=sum(len(s.body_markdown) for s in sections) ) return specification @@ -208,7 +201,7 @@ async def _write_section( section: SpecOutlineSection, normalized_context: NormalizedBrainstormContext, clarification_questions: List[dict], - project_id: Optional[str] = None + project_id: Optional[str] = None, ) -> SpecSectionContent: """ Write content for a single section. @@ -230,7 +223,7 @@ async def _write_section( prompt=prompt[:300] + "..." 
if len(prompt) > 300 else prompt, model=str(self.model_client), operation="write_section", - section_id=section.id + section_id=section.id, ) # Create a FRESH agent for each section to avoid conversation history accumulation @@ -242,10 +235,7 @@ async def _write_section( model_client=self.model_client, ) - response = await section_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await section_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -291,8 +281,7 @@ def _parse_writer_response(self, response_text: str, section_id: str) -> tuple: # Validate content exists if not content: self.logger.log_error( - ValueError("Empty content in JSON response"), - {"section_id": section_id, "response": cleaned[:500]} + ValueError("Empty content in JSON response"), {"section_id": section_id, "response": cleaned[:500]} ) # Fallback: treat entire response as content return self._clean_markdown(cleaned), "" @@ -307,7 +296,7 @@ def _parse_writer_response(self, response_text: str, section_id: str) -> tuple: # Fallback: treat response as raw markdown (legacy behavior) self.logger.log_error( ValueError("Failed to parse JSON, using raw markdown"), - {"section_id": section_id, "response": cleaned[:200]} + {"section_id": section_id, "response": cleaned[:200]}, ) body_markdown = self._clean_markdown(cleaned) summary = self._generate_fallback_summary(body_markdown) @@ -329,14 +318,14 @@ def _generate_fallback_summary(self, content: str, max_chars: int = 300) -> str: if len(content) <= max_chars: return content # Take first N characters, break at word boundary - truncated = content[:max_chars].rsplit(' ', 1)[0] + truncated = content[:max_chars].rsplit(" ", 1)[0] return truncated + "..." def _build_section_prompt( self, section: SpecOutlineSection, normalized_context: NormalizedBrainstormContext, - clarification_questions: List[dict] + clarification_questions: List[dict], ) -> str: """ Build the prompt for writing a section. @@ -355,10 +344,10 @@ def _build_section_prompt( prompt += f"Section Description: {section.description}\n\n" # Check if section has Q&A backing - if not, generate placeholder - if hasattr(section, 'has_qa_backing') and not section.has_qa_backing: + if hasattr(section, "has_qa_backing") and not section.has_qa_backing: prompt += "### PENDING CLARIFICATIONS:\n" prompt += "This section lacks sufficient Q&A backing. The following topics need clarification:\n" - if hasattr(section, 'pending_clarifications') and section.pending_clarifications: + if hasattr(section, "pending_clarifications") and section.pending_clarifications: for clarification in section.pending_clarifications: prompt += f"- {clarification}\n" else: @@ -418,7 +407,9 @@ def _build_section_prompt( if normalized_context.system_context_summary: prompt += "\n### System Context (for coherence, not direct requirements):\n" prompt += normalized_context.system_context_summary - prompt += "\n\nUse this to ensure consistency with other phases. Do NOT add requirements from other phases.\n" + prompt += ( + "\n\nUse this to ensure consistency with other phases. 
Do NOT add requirements from other phases.\n" + ) elif "data" in section_id: if normalized_context.data_requirements: @@ -441,7 +432,9 @@ def _build_section_prompt( if normalized_context.system_context_summary: prompt += "\n### System Context (for coherence, not direct requirements):\n" prompt += normalized_context.system_context_summary - prompt += "\n\nUse this to ensure consistency with other phases. Do NOT add requirements from other phases.\n" + prompt += ( + "\n\nUse this to ensure consistency with other phases. Do NOT add requirements from other phases.\n" + ) elif "constraint" in section_id or "assumption" in section_id: if normalized_context.constraints: @@ -463,20 +456,20 @@ def _build_section_prompt( for q_id in section.linked_questions: # Find the question for q in clarification_questions: - if q.get('id') == q_id: + if q.get("id") == q_id: prompt += f"- **{q.get('title', 'Untitled')}**\n" prompt += f" {q.get('spec_text', q.get('description', ''))}\n" # Include answer if available - mcq_data = q.get('mcq_data', {}) + mcq_data = q.get("mcq_data", {}) if mcq_data: - selected = mcq_data.get('selected_option_id') + selected = mcq_data.get("selected_option_id") if selected: - choices = mcq_data.get('choices', []) + choices = mcq_data.get("choices", []) for choice in choices: - if choice.get('id') == selected: + if choice.get("id") == selected: prompt += f" Answer: {choice.get('label', selected)}\n" break - free_text = mcq_data.get('free_text') + free_text = mcq_data.get("free_text") if free_text: prompt += f" Additional: {free_text}\n" break @@ -507,10 +500,7 @@ def _build_section_prompt( return prompt -async def create_writer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> WriterAgent: +async def create_writer(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> WriterAgent: """ Factory function to create a Writer Agent. 
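Both writer diffs above reshape the same fan-out pattern without changing its behavior: every outline section is written concurrently via asyncio.gather(..., return_exceptions=True), a fresh agent is built per section so conversation history does not accumulate, and a failed section degrades to a placeholder instead of failing the batch. A minimal runnable sketch of that pattern follows; SectionContent, write_section, and the sample outline are simplified stand-ins for illustration, not code from this patch.

import asyncio
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SectionContent:
    id: str
    title: str
    body_markdown: str


async def write_section(section_id: str, title: str) -> SectionContent:
    # Stand-in for the real per-section LLM call; the actual code constructs
    # a fresh agent per section to avoid conversation-history accumulation.
    await asyncio.sleep(0)
    return SectionContent(id=section_id, title=title, body_markdown=f"## {title}\n...")


async def write_all(outline: List[Tuple[str, str]]) -> List[SectionContent]:
    tasks = [write_section(section_id, title) for section_id, title in outline]
    # return_exceptions=True lets one failed section become a placeholder
    # rather than cancelling the whole batch.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    sections: List[SectionContent] = []
    for (section_id, title), result in zip(outline, results):
        if isinstance(result, Exception):
            sections.append(
                SectionContent(
                    id=section_id,
                    title=title,
                    body_markdown=f"*Content generation failed: {result}*",
                )
            )
        else:
            sections.append(result)
    return sections


if __name__ == "__main__":
    written = asyncio.run(write_all([("summary", "Executive Summary"), ("goals", "Goals")]))
    for section in written:
        print(section.id, "->", section.title)

The placeholder-on-failure branch is what keeps downstream rendering total: every outline section is guaranteed a SectionContent, so the document assembles even when individual LLM calls error out.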
diff --git a/backend/app/agents/collab_thread_assistant/__init__.py b/backend/app/agents/collab_thread_assistant/__init__.py index 5d3d14a..55fe50b 100644 --- a/backend/app/agents/collab_thread_assistant/__init__.py +++ b/backend/app/agents/collab_thread_assistant/__init__.py @@ -16,30 +16,30 @@ - Phase 8: Production Hardening & Observability """ -from .types import CollabThreadContext, AssistantResponse -from .validators import ResponseValidator, ValidationResult +from .assistant import SYSTEM_PROMPT, CollabThreadAssistant from .config import ( - TOKEN_THRESHOLD, - SUMMARY_MAX_TOKENS, - RECENT_MESSAGES_COUNT, + ENABLE_DEBUG_LOGGING, MAX_RETRIES, + RECENT_MESSAGES_COUNT, RETRY_BACKOFF_MS, - ENABLE_DEBUG_LOGGING, + SUMMARY_MAX_TOKENS, + TOKEN_THRESHOLD, ) -from .context_loader import load_thread, load_files, token_count, load_spec_draft_context -from .summarizer import SummarizerAgent -from .assistant import CollabThreadAssistant, SYSTEM_PROMPT -from .orchestrator import build_context, call_assistant, handle_ai_mention -from .spec_draft_handler import handle_spec_draft_ai_mention -from .spec_draft_assistant import SpecDraftAssistant -from .retry import with_retry, with_retry_sync, RetryError +from .context_loader import load_files, load_spec_draft_context, load_thread, token_count from .instrumentation import ( CollabThreadAssistantLogger, - get_assistant_logger, DebugInfo, - SummarizationEvent, RetryEvent, + SummarizationEvent, + get_assistant_logger, ) +from .orchestrator import build_context, call_assistant, handle_ai_mention +from .retry import RetryError, with_retry, with_retry_sync +from .spec_draft_assistant import SpecDraftAssistant +from .spec_draft_handler import handle_spec_draft_ai_mention +from .summarizer import SummarizerAgent +from .types import AssistantResponse, CollabThreadContext +from .validators import ResponseValidator, ValidationResult __all__ = [ # Types diff --git a/backend/app/agents/collab_thread_assistant/assistant.py b/backend/app/agents/collab_thread_assistant/assistant.py index 4406302..81ce5a3 100644 --- a/backend/app/agents/collab_thread_assistant/assistant.py +++ b/backend/app/agents/collab_thread_assistant/assistant.py @@ -8,20 +8,21 @@ """ import logging -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from app.agents.brainstorm_conversation.types import CrossProjectContext + from .types import ( - CollabThreadContext, - ThreadMessage, BrainstormingPhaseContext, + CollabThreadContext, CurrentThreadContext, + ThreadMessage, ) -from app.agents.brainstorm_conversation.types import CrossProjectContext if TYPE_CHECKING: from app.agents.llm_client import LLMCallLogger @@ -169,9 +170,7 @@ def _create_agent(self) -> AssistantAgent: model_client=self.model_client, ) - def _format_brainstorming_context( - self, brainstorming_context: BrainstormingPhaseContext - ) -> str: + def _format_brainstorming_context(self, brainstorming_context: BrainstormingPhaseContext) -> str: """ Format brainstorming phase context for the prompt. 
@@ -222,9 +221,7 @@ def _format_brainstorming_context( return "\n".join(lines) - def _format_cross_project_context( - self, cross_project_context: CrossProjectContext - ) -> str: + def _format_cross_project_context(self, cross_project_context: CrossProjectContext) -> str: """ Format cross-project context (decisions from other phases + project features). @@ -256,8 +253,7 @@ def _format_cross_project_context( for decision in phase_ctx.decisions: lines.append( - f"- **{decision.question_title}** ({decision.aspect_title}): " - f"{decision.decision_summary_short}" + f"- **{decision.question_title}** ({decision.aspect_title}): {decision.decision_summary_short}" ) lines.append("") @@ -267,8 +263,7 @@ def _format_cross_project_context( lines.append("") for feature in cross_project_context.project_features: lines.append( - f"- **{feature.feature_title}** ({feature.module_title}): " - f"{feature.decision_summary_short}" + f"- **{feature.feature_title}** ({feature.module_title}): {feature.decision_summary_short}" ) lines.append("") @@ -305,8 +300,7 @@ def _format_context(self, context: CollabThreadContext) -> str: # Cross-project context (decisions from OTHER phases + project features) # Placed early to give the assistant broad project awareness if context.cross_project_context and ( - context.cross_project_context.other_phases - or context.cross_project_context.project_features + context.cross_project_context.other_phases or context.cross_project_context.project_features ): sections.append(self._format_cross_project_context(context.cross_project_context)) sections.append("") @@ -399,9 +393,7 @@ def _format_recent_messages(self, messages: List[ThreadMessage]) -> str: lines.append("") return "\n".join(lines) - def _format_current_thread_context( - self, current_thread_context: CurrentThreadContext - ) -> str: + def _format_current_thread_context(self, current_thread_context: CurrentThreadContext) -> str: """ Format the current thread context for the User's Question section. @@ -499,9 +491,7 @@ def _build_prompt( # Format current thread context for User's Question section current_thread_section = "" if context.current_thread_context: - current_thread_section = self._format_current_thread_context( - context.current_thread_context - ) + current_thread_section = self._format_current_thread_context(context.current_thread_context) current_thread_section += "\n" # Add additional context if provided (e.g., MCQ answer context) @@ -687,10 +677,7 @@ async def respond( """ # Set agent context for LLM call logging if self.llm_call_logger: - self.llm_call_logger.set_agent( - "collab_thread_assistant", - "Collab Thread Assistant (@MFBTAI)" - ) + self.llm_call_logger.set_agent("collab_thread_assistant", "Collab Thread Assistant (@MFBTAI)") # Create a fresh agent for this call agent = self._create_agent() @@ -743,7 +730,7 @@ def _generate_fallback_response(self, user_message: str, error: str) -> str: ## Key Points -- Your question: {user_message[:200]}{'...' if len(user_message) > 200 else ''} +- Your question: {user_message[:200]}{"..." if len(user_message) > 200 else ""} - I was unable to generate a complete response due to a technical issue. 
## Next Steps diff --git a/backend/app/agents/collab_thread_assistant/context_loader.py b/backend/app/agents/collab_thread_assistant/context_loader.py index cc37649..947f86e 100644 --- a/backend/app/agents/collab_thread_assistant/context_loader.py +++ b/backend/app/agents/collab_thread_assistant/context_loader.py @@ -15,31 +15,29 @@ import litellm from sqlalchemy.orm import Session, joinedload -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.feature import Feature, FeatureVisibilityStatus, FeatureType -from app.models.module import Module, ModuleType -from app.models.brainstorming_phase import BrainstormingPhase -from app.services.grounding_service import GroundingService -from app.services.agent_utils import AGENT_EMAIL - from app.agents.brainstorm_conversation.types import ( - CrossProjectContext, CrossPhaseContext, CrossPhaseDecision, + CrossProjectContext, ProjectFeatureDecision, ) +from app.models.brainstorming_phase import BrainstormingPhase +from app.models.feature import Feature, FeatureType, FeatureVisibilityStatus +from app.models.module import Module, ModuleType +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType +from app.services.agent_utils import AGENT_EMAIL +from app.services.grounding_service import GroundingService +from .config import RECENT_MESSAGES_COUNT, TOKEN_COUNT_MODEL from .types import ( - ThreadMessage, - BrainstormingPhaseContext, AnsweredQuestion, - ThreadDiscussionSummary, - MCQChoice, - FeatureContext, + BrainstormingPhaseContext, CurrentThreadContext, + FeatureContext, + MCQChoice, + ThreadMessage, ) -from .config import RECENT_MESSAGES_COUNT, TOKEN_COUNT_MODEL logger = logging.getLogger(__name__) @@ -130,9 +128,7 @@ def load_thread( thread = ( db.query(Thread) .filter(Thread.id == thread_id) - .options( - joinedload(Thread.items).joinedload(ThreadItem.author) - ) + .options(joinedload(Thread.items).joinedload(ThreadItem.author)) .first() ) @@ -145,10 +141,7 @@ def load_thread( ThreadItemType.CODE_EXPLORATION.value, ThreadItemType.WEB_SEARCH.value, ] - relevant_items = [ - item for item in thread.items - if item.item_type in allowed_types - ] + relevant_items = [item for item in thread.items if item.item_type in allowed_types] # Convert to ThreadMessage objects messages: List[ThreadMessage] = [] @@ -370,12 +363,7 @@ def load_files( # Infer project_id from feature if not provided if project_id is None: - feature = ( - db.query(Feature) - .options(joinedload(Feature.module)) - .filter(Feature.id == feature_id) - .first() - ) + feature = db.query(Feature).options(joinedload(Feature.module)).filter(Feature.id == feature_id).first() if feature and feature.module: project_id = str(feature.module.project_id) @@ -416,11 +404,13 @@ def _extract_mcq_choices(choices: List[dict], selected_option_id: Optional[str]) """Extract MCQ choices with selection status.""" mcq_choices = [] for choice in choices: - mcq_choices.append(MCQChoice( - id=choice.get("id", ""), - label=choice.get("label", ""), - is_selected=(choice.get("id") == selected_option_id) if selected_option_id else False, - )) + mcq_choices.append( + MCQChoice( + id=choice.get("id", ""), + label=choice.get("label", ""), + is_selected=(choice.get("id") == selected_option_id) if selected_option_id else False, + ) + ) return mcq_choices @@ -429,7 +419,7 @@ def _extract_unresolved_points(thread: Thread) -> List[str]: unresolved = [] if thread.unresolved_points: for point in 
thread.unresolved_points: - question = point.get('question', '') if isinstance(point, dict) else str(point) + question = point.get("question", "") if isinstance(point, dict) else str(point) if question: unresolved.append(question) return unresolved @@ -437,10 +427,7 @@ def _extract_unresolved_points(thread: Thread) -> List[str]: def _extract_key_points_fallback(thread: Thread) -> List[str]: """Extract key points from comments as fallback when no decision summary.""" - comment_items = [ - item for item in thread.items - if item.item_type == ThreadItemType.COMMENT.value - ] + comment_items = [item for item in thread.items if item.item_type == ThreadItemType.COMMENT.value] if not comment_items: return [] @@ -453,11 +440,7 @@ def _extract_key_points_fallback(thread: Thread) -> List[str]: if body: # Truncate long comments truncated = body[:200] + "..." if len(body) > 200 else body - author_name = ( - item.author.display_name - if item.author and item.author.display_name - else "User" - ) + author_name = item.author.display_name if item.author and item.author.display_name else "User" key_points.append(f"{author_name}: {truncated}") return key_points @@ -489,9 +472,7 @@ def load_brainstorming_phase_context( # Load feature with module and brainstorming phase feature = ( db.query(Feature) - .options( - joinedload(Feature.module).joinedload(Module.brainstorming_phase) - ) + .options(joinedload(Feature.module).joinedload(Module.brainstorming_phase)) .filter(Feature.id == feature_id) .first() ) @@ -586,14 +567,16 @@ def load_brainstorming_phase_context( # Create FeatureContext combining questions + summary if feature_questions or decision_summary or key_points: - feature_contexts.append(FeatureContext( - feature_id=str(conv_feature.id), - feature_title=conv_feature.title, - answered_questions=feature_questions, - decision_summary=decision_summary, - unresolved_points=unresolved, - key_points=key_points, - )) + feature_contexts.append( + FeatureContext( + feature_id=str(conv_feature.id), + feature_title=conv_feature.title, + answered_questions=feature_questions, + decision_summary=decision_summary, + unresolved_points=unresolved, + key_points=key_points, + ) + ) # Apply limit feature_contexts = feature_contexts[:MAX_FEATURE_CONTEXTS] @@ -645,12 +628,7 @@ def load_current_thread_context( return None # Load thread with items - thread = ( - db.query(Thread) - .filter(Thread.id == thread_id) - .options(joinedload(Thread.items)) - .first() - ) + thread = db.query(Thread).filter(Thread.id == thread_id).options(joinedload(Thread.items)).first() if not thread: logger.debug(f"Thread {thread_id} not found for current thread context") return None @@ -675,13 +653,15 @@ def load_current_thread_context( mcq_choices = _extract_mcq_choices(choices, selected_option_id) - mcq_questions.append(AnsweredQuestion( - question_text=content_data.get("question_text", ""), - selected_label=selected_label, - choices=mcq_choices, - free_text=content_data.get("free_text"), - feature_title=feature.title, - )) + mcq_questions.append( + AnsweredQuestion( + question_text=content_data.get("question_text", ""), + selected_label=selected_label, + choices=mcq_choices, + free_text=content_data.get("free_text"), + feature_title=feature.title, + ) + ) # Extract decision summary and unresolved points unresolved = _extract_unresolved_points(thread) @@ -740,41 +720,40 @@ def load_cross_project_context( # 1. 
Query all brainstorming phases (not archived) phase_query = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project_id, - BrainstormingPhase.archived_at.is_(None) + BrainstormingPhase.project_id == project_id, BrainstormingPhase.archived_at.is_(None) ) # Exclude current phase if specified if exclude_phase_id: phase_query = phase_query.filter(BrainstormingPhase.id != exclude_phase_id) - all_phases = phase_query.order_by( - BrainstormingPhase.created_at - ).limit(MAX_PHASES_FOR_CROSS_CONTEXT).all() + all_phases = phase_query.order_by(BrainstormingPhase.created_at).limit(MAX_PHASES_FOR_CROSS_CONTEXT).all() for phase in all_phases: decisions: List[CrossPhaseDecision] = [] # Get modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == phase.id, - Module.archived_at.is_(None) - ).all() + modules = db.query(Module).filter(Module.brainstorming_phase_id == phase.id, Module.archived_at.is_(None)).all() for module in modules: # Get ACTIVE features (questions) with threads that have decisions - features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .all() + ) for feature in features: # Get thread for this feature - thread = db.query(Thread).filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id == str(feature.id) - ).first() + thread = ( + db.query(Thread) + .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id)) + .first() + ) # Only include if thread has decision_summary_short or decision_summary if thread and (thread.decision_summary_short or thread.decision_summary): @@ -785,11 +764,13 @@ def load_cross_project_context( else thread.decision_summary ) if summary: - decisions.append(CrossPhaseDecision( - question_title=feature.title, - decision_summary_short=summary, - aspect_title=module.title, - )) + decisions.append( + CrossPhaseDecision( + question_title=feature.title, + decision_summary_short=summary, + aspect_title=module.title, + ) + ) # Cap decisions per phase if len(decisions) >= MAX_DECISIONS_PER_PHASE: @@ -805,34 +786,38 @@ def load_cross_project_context( if len(description) > 200: description = description[:200] + "..." - phases_context.append(CrossPhaseContext( - phase_id=str(phase.id), - phase_title=phase.title, - phase_description=description, - decisions=decisions, - )) + phases_context.append( + CrossPhaseContext( + phase_id=str(phase.id), + phase_title=phase.title, + phase_description=description, + decisions=decisions, + ) + ) # 2. 
Query project-level features (module.brainstorming_phase_id IS NULL) - project_modules = db.query(Module).filter( - Module.project_id == project_id, - Module.brainstorming_phase_id.is_(None), - Module.archived_at.is_(None) - ).all() + project_modules = ( + db.query(Module) + .filter(Module.project_id == project_id, Module.brainstorming_phase_id.is_(None), Module.archived_at.is_(None)) + .all() + ) for module in project_modules: # Get IMPLEMENTATION features (not CONVERSATION) - features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .all() + ) for feature in features: # Get thread for this feature (could be SPEC or GENERAL context type) - thread = db.query(Thread).filter( - Thread.context_id == str(feature.id) - ).first() + thread = db.query(Thread).filter(Thread.context_id == str(feature.id)).first() # Only include if thread has decision summary if thread and (thread.decision_summary_short or thread.decision_summary): @@ -842,11 +827,13 @@ def load_cross_project_context( else thread.decision_summary ) if summary: - project_features_context.append(ProjectFeatureDecision( - feature_title=feature.title, - module_title=module.title, - decision_summary_short=summary, - )) + project_features_context.append( + ProjectFeatureDecision( + feature_title=feature.title, + module_title=module.title, + decision_summary_short=summary, + ) + ) # Cap project features if len(project_features_context) >= MAX_PROJECT_FEATURES_FOR_CROSS_CONTEXT: @@ -908,8 +895,8 @@ def load_spec_draft_context( Raises: ValueError: If draft version or phase not found. """ - from app.services.draft_version_service import DraftVersionService from app.services.brainstorming_phase_service import BrainstormingPhaseService + from app.services.draft_version_service import DraftVersionService # Load draft version draft = DraftVersionService.get_draft(db, UUID(version_id)) @@ -931,10 +918,7 @@ def load_spec_draft_context( # Load grounding files for project context grounding_data = load_grounding_files(db, str(phase.project_id)) - logger.info( - f"Loaded spec draft context for version {version_id}, " - f"block {block_id}, phase {phase.title}" - ) + logger.info(f"Loaded spec draft context for version {version_id}, block {block_id}, phase {phase.title}") return { "full_document": draft.content_markdown or "", diff --git a/backend/app/agents/collab_thread_assistant/exploration_parser.py b/backend/app/agents/collab_thread_assistant/exploration_parser.py index bbbdcf2..8317a79 100644 --- a/backend/app/agents/collab_thread_assistant/exploration_parser.py +++ b/backend/app/agents/collab_thread_assistant/exploration_parser.py @@ -60,8 +60,7 @@ def parse_exploration_request(response_text: str) -> Optional[CodeExplorationReq for fallback in fallback_patterns: if re.search(fallback, response_text.lower()): logger.warning( - f"Detected exploration intent without proper block format. " - f"Response contains: '{fallback}'" + f"Detected exploration intent without proper block format. 
Response contains: '{fallback}'" ) return None diff --git a/backend/app/agents/collab_thread_assistant/instrumentation.py b/backend/app/agents/collab_thread_assistant/instrumentation.py index 383d378..8a3a322 100644 --- a/backend/app/agents/collab_thread_assistant/instrumentation.py +++ b/backend/app/agents/collab_thread_assistant/instrumentation.py @@ -7,11 +7,11 @@ retry attempts, and debug information collection. """ -import logging import json +import logging from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional from datetime import datetime, timezone +from typing import Any, Dict, List, Optional @dataclass @@ -110,9 +110,7 @@ def __init__( # Ensure structured output if not self.logger.handlers: handler = logging.StreamHandler() - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) @@ -158,11 +156,7 @@ def log_request_start(self, message_preview: str) -> None: Args: message_preview: Truncated preview of the user message. """ - preview = ( - message_preview[:200] + "..." - if len(message_preview) > 200 - else message_preview - ) + preview = message_preview[:200] + "..." if len(message_preview) > 200 else message_preview self._structured_log( "info", "request_start", @@ -274,11 +268,7 @@ def log_summarization_triggered( ) ) - reduction_pct = ( - round((1 - summarized_tokens / original_tokens) * 100, 1) - if original_tokens > 0 - else 0 - ) + reduction_pct = round((1 - summarized_tokens / original_tokens) * 100, 1) if original_tokens > 0 else 0 self._structured_log( "info", diff --git a/backend/app/agents/collab_thread_assistant/mcq_parser.py b/backend/app/agents/collab_thread_assistant/mcq_parser.py index f0bea3f..69d1d6d 100644 --- a/backend/app/agents/collab_thread_assistant/mcq_parser.py +++ b/backend/app/agents/collab_thread_assistant/mcq_parser.py @@ -17,15 +17,13 @@ MAX_MCQS_PER_RESPONSE = 3 # Regex pattern to find MCQ blocks -MCQ_BLOCK_PATTERN = re.compile( - r'\[MFBT_MCQ\](.*?)\[/MFBT_MCQ\]', - re.DOTALL -) +MCQ_BLOCK_PATTERN = re.compile(r"\[MFBT_MCQ\](.*?)\[/MFBT_MCQ\]", re.DOTALL) @dataclass class ParsedMCQ: """A single parsed MCQ from an MFBTAI response.""" + question_text: str choices: list[dict] # [{"id": "option_1", "label": "..."}] explanation: Optional[str] = None @@ -36,6 +34,7 @@ class ParsedMCQ: @dataclass class ParsedResponse: """Result of parsing an MFBTAI response for MCQ blocks.""" + preamble_text: Optional[str] = None # Text before MCQ block mcqs: list[ParsedMCQ] = field(default_factory=list) has_mcq_block: bool = False @@ -107,9 +106,7 @@ def _parse_mcq_json(json_str: str) -> tuple[list[ParsedMCQ], Optional[str]]: return [], "'questions' must be an array" if len(questions) > MAX_MCQS_PER_RESPONSE: - logger.warning( - f"MCQ block has {len(questions)} questions, limiting to {MAX_MCQS_PER_RESPONSE}" - ) + logger.warning(f"MCQ block has {len(questions)} questions, limiting to {MAX_MCQS_PER_RESPONSE}") questions = questions[:MAX_MCQS_PER_RESPONSE] parsed_mcqs = [] @@ -133,13 +130,15 @@ def _parse_mcq_json(json_str: str) -> tuple[list[ParsedMCQ], Optional[str]]: recommended_option_id = None recommended_reason = None - parsed_mcqs.append(ParsedMCQ( - question_text=mcq_data["question_text"], - choices=mcq_data["choices"], - explanation=mcq_data.get("explanation"), - recommended_option_id=recommended_option_id, - 
recommended_reason=recommended_reason, - )) + parsed_mcqs.append( + ParsedMCQ( + question_text=mcq_data["question_text"], + choices=mcq_data["choices"], + explanation=mcq_data.get("explanation"), + recommended_option_id=recommended_option_id, + recommended_reason=recommended_reason, + ) + ) return parsed_mcqs, None @@ -179,7 +178,7 @@ def parse_mfbtai_response(response_text: str) -> ParsedResponse: ) # Extract preamble (text before MCQ block) - preamble = response_text[:match.start()].strip() + preamble = response_text[: match.start()].strip() preamble_text = preamble if preamble else None # Parse MCQ JSON diff --git a/backend/app/agents/collab_thread_assistant/orchestrator.py b/backend/app/agents/collab_thread_assistant/orchestrator.py index 844e556..e3ef9a9 100644 --- a/backend/app/agents/collab_thread_assistant/orchestrator.py +++ b/backend/app/agents/collab_thread_assistant/orchestrator.py @@ -14,41 +14,41 @@ import asyncio import logging import time -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple from uuid import UUID from autogen_core.models import ChatCompletionClient from sqlalchemy.orm import Session, joinedload from app.models.feature import Feature -from app.models.thread import Thread, ContextType -from app.models.project import Project from app.models.module import Module +from app.models.project import Project +from app.models.thread import ContextType, Thread -from .types import AssistantResponse, CollabThreadContext -from .context_loader import ( - load_thread, - load_feature_files, - load_grounding_files, - load_brainstorming_phase_context, - load_current_thread_context, - load_cross_project_context, - token_count, -) -from .summarizer import SummarizerAgent from .assistant import CollabThreadAssistant -from .retry import with_retry, RetryError from .config import ( - TOKEN_THRESHOLD, - SUMMARY_MAX_TOKENS, - RECENT_MESSAGES_COUNT, - ENABLE_SUMMARIZATION, ENABLE_DEBUG_LOGGING, - MIN_MESSAGES_FOR_SUMMARY, + ENABLE_SUMMARIZATION, MAX_RETRIES, + MIN_MESSAGES_FOR_SUMMARY, + RECENT_MESSAGES_COUNT, RETRY_BACKOFF_MS, + SUMMARY_MAX_TOKENS, + TOKEN_THRESHOLD, +) +from .context_loader import ( + load_brainstorming_phase_context, + load_cross_project_context, + load_current_thread_context, + load_feature_files, + load_grounding_files, + load_thread, + token_count, ) from .instrumentation import CollabThreadAssistantLogger, get_assistant_logger +from .retry import RetryError, with_retry +from .summarizer import SummarizerAgent +from .types import AssistantResponse, CollabThreadContext if TYPE_CHECKING: from app.agents.llm_client import LLMCallLogger @@ -102,12 +102,7 @@ async def build_context( # Infer project_id from feature if not provided if project_id is None: - feature = ( - db.query(Feature) - .options(joinedload(Feature.module)) - .filter(Feature.id == feature_id) - .first() - ) + feature = db.query(Feature).options(joinedload(Feature.module)).filter(Feature.id == feature_id).first() if feature and feature.module: project_id = str(feature.module.project_id) @@ -134,6 +129,7 @@ async def build_context( # Check if web search is available from app.services.platform_settings_service import is_web_search_available_sync + web_search_enabled = is_web_search_available_sync(db) # Load brainstorming context if this is a BRAINSTORM_FEATURE thread @@ -187,9 +183,9 @@ async def build_context( exclude_phase_id=current_phase_id, # Exclude current phase to avoid 
duplication ) if cross_project_context: - total_decisions = sum( - len(phase.decisions) for phase in cross_project_context.other_phases - ) + len(cross_project_context.project_features) + total_decisions = sum(len(phase.decisions) for phase in cross_project_context.other_phases) + len( + cross_project_context.project_features + ) logger.info( f"Loaded cross-project context: {len(cross_project_context.other_phases)} phases, " f"{len(cross_project_context.project_features)} project features, " @@ -233,57 +229,53 @@ async def build_context( messages_to_summarize = None # Check if thread needs summarization - if ( - token_counts["thread"] > TOKEN_THRESHOLD - and len(all_messages) >= MIN_MESSAGES_FOR_SUMMARY - ): + if token_counts["thread"] > TOKEN_THRESHOLD and len(all_messages) >= MIN_MESSAGES_FOR_SUMMARY: # Get messages excluding recent ones (they'll be included verbatim) messages_to_summarize = ( - all_messages[:-RECENT_MESSAGES_COUNT] - if len(all_messages) > RECENT_MESSAGES_COUNT - else all_messages + all_messages[:-RECENT_MESSAGES_COUNT] if len(all_messages) > RECENT_MESSAGES_COUNT else all_messages ) if messages_to_summarize: logger.info( f"Summarizing thread {thread_id}: {len(messages_to_summarize)} messages " f"({token_counts['thread']} tokens)" ) - summarization_tasks.append(( - "thread", - summarizer.summarize_thread(messages_to_summarize, max_tokens=SUMMARY_MAX_TOKENS) - )) + summarization_tasks.append( + ("thread", summarizer.summarize_thread(messages_to_summarize, max_tokens=SUMMARY_MAX_TOKENS)) + ) # Check if spec needs summarization if token_counts["spec"] > TOKEN_THRESHOLD: logger.info(f"Summarizing spec for feature {feature_id}: {token_counts['spec']} tokens") - summarization_tasks.append(( - "spec", - summarizer.summarize(spec, max_tokens=SUMMARY_MAX_TOKENS, context_type="spec") - )) + summarization_tasks.append( + ("spec", summarizer.summarize(spec, max_tokens=SUMMARY_MAX_TOKENS, context_type="spec")) + ) # Check if prompt_plan needs summarization if token_counts["prompt_plan"] > TOKEN_THRESHOLD: logger.info(f"Summarizing prompt_plan for feature {feature_id}: {token_counts['prompt_plan']} tokens") - summarization_tasks.append(( - "prompt_plan", - summarizer.summarize(prompt_plan, max_tokens=SUMMARY_MAX_TOKENS, context_type="prompt_plan") - )) + summarization_tasks.append( + ( + "prompt_plan", + summarizer.summarize(prompt_plan, max_tokens=SUMMARY_MAX_TOKENS, context_type="prompt_plan"), + ) + ) # Check if notes needs summarization if token_counts["notes"] > TOKEN_THRESHOLD: logger.info(f"Summarizing notes for feature {feature_id}: {token_counts['notes']} tokens") - summarization_tasks.append(( - "notes", - summarizer.summarize(notes, max_tokens=SUMMARY_MAX_TOKENS, context_type="notes") - )) + summarization_tasks.append( + ("notes", summarizer.summarize(notes, max_tokens=SUMMARY_MAX_TOKENS, context_type="notes")) + ) # Check if grounding needs summarization if token_counts["grounding"] > TOKEN_THRESHOLD: logger.info(f"Summarizing grounding for project {project_id}: {token_counts['grounding']} tokens") - summarization_tasks.append(( - "grounding", - summarizer.summarize(grounding_combined, max_tokens=SUMMARY_MAX_TOKENS, context_type="grounding") - )) + summarization_tasks.append( + ( + "grounding", + summarizer.summarize(grounding_combined, max_tokens=SUMMARY_MAX_TOKENS, context_type="grounding"), + ) + ) # Execute all summarization tasks in parallel if summarization_tasks: @@ -460,7 +452,7 @@ def _generate_graceful_fallback( ## Key Points -- Your question: 
{message[:200]}{'...' if len(message) > 200 else ''} +- Your question: {message[:200]}{"..." if len(message) > 200 else ""} - Thread ID: {thread_id} - Feature ID: {feature_id} - Multiple retry attempts were made before this fallback response. @@ -681,10 +673,7 @@ def on_assistant_retry(attempt: int, error: Exception) -> None: tokens_used = sum(context.token_counts.values()) assistant_logger.log_request_complete(latency_ms, success=False, tokens_used=tokens_used) - summarized_parts = [ - key for key, was_summarized in context.summarization_applied.items() - if was_summarized - ] + summarized_parts = [key for key, was_summarized in context.summarization_applied.items() if was_summarized] metadata = { "implemented": True, @@ -720,10 +709,7 @@ def on_assistant_retry(attempt: int, error: Exception) -> None: latency_ms = (time.time() - start_time) * 1000 # Determine which parts were summarized - summarized_parts = [ - key for key, was_summarized in context.summarization_applied.items() - if was_summarized - ] + summarized_parts = [key for key, was_summarized in context.summarization_applied.items() if was_summarized] # Log successful request completion tokens_used = sum(context.token_counts.values()) @@ -800,14 +786,14 @@ async def handle_ai_mention( Raises: ValueError: If thread not found, feature not found, or no LLM configured. """ - from app.models import Thread, Project from app.agents.llm_client import ( + DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, LiteLLMChatCompletionClient, LLMCallLogger, - DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, ) - from app.services.platform_settings_service import require_llm_config_sync from app.database import SessionLocal + from app.models import Project, Thread + from app.services.platform_settings_service import require_llm_config_sync # Get thread to find project and org thread = db.query(Thread).filter(Thread.id == thread_id).first() diff --git a/backend/app/agents/collab_thread_assistant/retry.py b/backend/app/agents/collab_thread_assistant/retry.py index 8367456..c734e9c 100644 --- a/backend/app/agents/collab_thread_assistant/retry.py +++ b/backend/app/agents/collab_thread_assistant/retry.py @@ -11,14 +11,14 @@ # Re-export everything from the shared retry module from app.agents.retry import ( + # Constants (re-export from config for compatibility) + LEGACY_BACKOFF_MS, # Exception RetryError, + calculate_backoff_delay, # Functions with_retry, with_retry_sync, - calculate_backoff_delay, - # Constants (re-export from config for compatibility) - LEGACY_BACKOFF_MS, ) # Re-export config values for code that imports them from here diff --git a/backend/app/agents/collab_thread_assistant/spec_draft_assistant.py b/backend/app/agents/collab_thread_assistant/spec_draft_assistant.py index 74c29f9..e9c08ac 100644 --- a/backend/app/agents/collab_thread_assistant/spec_draft_assistant.py +++ b/backend/app/agents/collab_thread_assistant/spec_draft_assistant.py @@ -6,7 +6,7 @@ """ import logging -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_core.models import ChatCompletionClient @@ -190,6 +190,7 @@ async def respond( elif isinstance(content, dict): # JSON response - convert to string import json + return json.dumps(content) return str(content) diff --git a/backend/app/agents/collab_thread_assistant/spec_draft_handler.py b/backend/app/agents/collab_thread_assistant/spec_draft_handler.py index 4be1499..0cb6ba5 100644 --- 
a/backend/app/agents/collab_thread_assistant/spec_draft_handler.py +++ b/backend/app/agents/collab_thread_assistant/spec_draft_handler.py @@ -54,15 +54,16 @@ async def handle_spec_draft_ai_mention( Raises: ValueError: If thread not found, version not found, or no LLM configured. """ - from app.models import Thread, Project from app.agents.llm_client import ( + DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, LiteLLMChatCompletionClient, LLMCallLogger, - DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, ) - from app.services.platform_settings_service import require_llm_config_sync from app.database import SessionLocal - from .context_loader import load_spec_draft_context, load_thread, load_grounding_files + from app.models import Project, Thread + from app.services.platform_settings_service import require_llm_config_sync + + from .context_loader import load_spec_draft_context, load_thread from .spec_draft_assistant import SpecDraftAssistant # Get thread to find project diff --git a/backend/app/agents/collab_thread_assistant/summarizer.py b/backend/app/agents/collab_thread_assistant/summarizer.py index 1a43ef9..36cd9c3 100644 --- a/backend/app/agents/collab_thread_assistant/summarizer.py +++ b/backend/app/agents/collab_thread_assistant/summarizer.py @@ -8,15 +8,15 @@ """ import logging -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient -from .types import ThreadMessage from .config import SUMMARY_MAX_TOKENS +from .types import ThreadMessage if TYPE_CHECKING: from app.agents.llm_client import LLMCallLogger @@ -167,7 +167,7 @@ async def summarize( agent = self._create_agent(system_prompt, "file_summarizer") # Build the prompt - prompt = f"""Please summarize the following {context_type or 'document'}: + prompt = f"""Please summarize the following {context_type or "document"}: --- {text} diff --git a/backend/app/agents/collab_thread_assistant/types.py b/backend/app/agents/collab_thread_assistant/types.py index 454a077..550c79e 100644 --- a/backend/app/agents/collab_thread_assistant/types.py +++ b/backend/app/agents/collab_thread_assistant/types.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID if TYPE_CHECKING: diff --git a/backend/app/agents/collab_thread_assistant/web_search_parser.py b/backend/app/agents/collab_thread_assistant/web_search_parser.py index d54ac52..f91495d 100644 --- a/backend/app/agents/collab_thread_assistant/web_search_parser.py +++ b/backend/app/agents/collab_thread_assistant/web_search_parser.py @@ -59,8 +59,7 @@ def parse_web_search_request(response_text: str) -> Optional[WebSearchRequest]: for fallback in fallback_patterns: if re.search(fallback, response_text.lower()): logger.warning( - f"Detected web search intent without proper block format. " - f"Response contains: '{fallback}'" + f"Detected web search intent without proper block format. 
Response contains: '{fallback}'" ) return None diff --git a/backend/app/agents/collab_thread_decision_summarizer/orchestrator.py b/backend/app/agents/collab_thread_decision_summarizer/orchestrator.py index ae7745c..61723c7 100644 --- a/backend/app/agents/collab_thread_decision_summarizer/orchestrator.py +++ b/backend/app/agents/collab_thread_decision_summarizer/orchestrator.py @@ -17,10 +17,9 @@ create_litellm_client, ) from app.database import SessionLocal -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType from app.models.feature import Feature -from app.models.user import User +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType from app.services.agent_utils import AGENT_EMAIL from .config import ( @@ -137,9 +136,7 @@ async def summarize_thread( all_items = sorted(thread.items, key=lambda x: x.created_at) # Find unprocessed items (after last_summarized_item_id) - unprocessed_items = self._get_unprocessed_items( - all_items, thread.last_summarized_item_id - ) + unprocessed_items = self._get_unprocessed_items(all_items, thread.last_summarized_item_id) if not unprocessed_items: logger.info(f"Thread {thread_id}: No unprocessed items") @@ -153,9 +150,7 @@ async def summarize_thread( final_suggested_implementation_name=thread.suggested_implementation_name, ) - logger.info( - f"Thread {thread_id}: Processing {len(unprocessed_items)} unprocessed items" - ) + logger.info(f"Thread {thread_id}: Processing {len(unprocessed_items)} unprocessed items") # Load existing state current_summary = thread.decision_summary @@ -172,9 +167,7 @@ async def summarize_thread( { "current_item": i + 1, "total_items": len(unprocessed_items), - "progress_percentage": 10 + int( - (i / len(unprocessed_items)) * 70 - ), + "progress_percentage": 10 + int((i / len(unprocessed_items)) * 70), }, ) @@ -420,9 +413,7 @@ def _update_thread_summary( db.commit() db.refresh(thread) - logger.debug( - f"Updated thread {thread.id}: last_summarized_item_id={last_processed_id}" - ) + logger.debug(f"Updated thread {thread.id}: last_summarized_item_id={last_processed_id}") async def create_orchestrator( diff --git a/backend/app/agents/collab_thread_decision_summarizer/summarizer.py b/backend/app/agents/collab_thread_decision_summarizer/summarizer.py index c63e4f0..284197b 100644 --- a/backend/app/agents/collab_thread_decision_summarizer/summarizer.py +++ b/backend/app/agents/collab_thread_decision_summarizer/summarizer.py @@ -6,20 +6,19 @@ import json import logging -from typing import Awaitable, Callable, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Tuple from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .config import SUMMARY_MAX_TOKENS from .types import ( DecisionSummaryContext, DecisionSummaryResult, UnresolvedPoint, - UnresolvedStatus, ) -from .config import SUMMARY_MAX_TOKENS if TYPE_CHECKING: from app.agents.llm_client import LLMCallLogger @@ -195,10 +194,7 @@ async def process_item( # Set agent context for LLM call logging if self.llm_call_logger: retry_suffix = f" (retry {attempt})" if attempt > 0 else "" - self.llm_call_logger.set_agent( - "decision_summarizer", - f"Decision Summarizer{retry_suffix}" - ) + self.llm_call_logger.set_agent("decision_summarizer", f"Decision Summarizer{retry_suffix}") # 
Create fresh agent for each attempt (AutoGen agents accumulate history) agent = self._create_agent() @@ -225,10 +221,7 @@ async def process_item( if json_parsed: # Successfully parsed JSON - logger.debug( - f"Processed item {context.new_item.item_id}: " - f"summary_changed={result.summary_changed}" - ) + logger.debug(f"Processed item {context.new_item.item_id}: summary_changed={result.summary_changed}") return result # JSON parsing failed - retry if not last attempt @@ -242,9 +235,7 @@ async def process_item( continue else: # Last attempt - use fallback - logger.warning( - f"LLM did not return valid JSON after {MAX_JSON_RETRIES} attempts, using fallback" - ) + logger.warning(f"LLM did not return valid JSON after {MAX_JSON_RETRIES} attempts, using fallback") return result except Exception as e: @@ -303,7 +294,7 @@ def _build_prompt(self, context: DecisionSummaryContext) -> str: item_content = f"""Type: MCQ Answer (EXPLICIT DECISION) Question: {item.mcq_question} Selected Answer: {item.mcq_selected_answer} -Additional Notes: {item.mcq_free_text or 'None provided'}""" +Additional Notes: {item.mcq_free_text or "None provided"}""" else: item_content = f"""Type: Comment Content: {item.content}""" @@ -317,7 +308,7 @@ def _build_prompt(self, context: DecisionSummaryContext) -> str: {unresolved_json} ## NEW MESSAGE: -Author: {item.author}{' (AI Assistant)' if item.is_ai else ''} +Author: {item.author}{" (AI Assistant)" if item.is_ai else ""} Timestamp: {item.created_at.isoformat()} {item_content} diff --git a/backend/app/agents/collab_thread_decision_summarizer/types.py b/backend/app/agents/collab_thread_decision_summarizer/types.py index 6a7eb96..a723638 100644 --- a/backend/app/agents/collab_thread_decision_summarizer/types.py +++ b/backend/app/agents/collab_thread_decision_summarizer/types.py @@ -1,6 +1,6 @@ """Type definitions for the CollabThreadDecisionSummarizer agent.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional diff --git a/backend/app/agents/feature_content/__init__.py b/backend/app/agents/feature_content/__init__.py index 1d5c79e..906d173 100644 --- a/backend/app/agents/feature_content/__init__.py +++ b/backend/app/agents/feature_content/__init__.py @@ -29,26 +29,24 @@ print(result.content_markdown) """ -from .types import ( - ContentType, - FeatureContentContext, - ContentGenerationResult, - FeatureInfo, - ModuleInfo, - ThreadItem, - AGENT_METADATA, - WORKFLOW_STEPS, -) - from .context_loader import ( - load_feature_context, format_thread_items, + load_feature_context, ) - from .orchestrator import ( FeatureContentOrchestrator, create_orchestrator, ) +from .types import ( + AGENT_METADATA, + WORKFLOW_STEPS, + ContentGenerationResult, + ContentType, + FeatureContentContext, + FeatureInfo, + ModuleInfo, + ThreadItem, +) __all__ = [ # Types diff --git a/backend/app/agents/feature_content/context_loader.py b/backend/app/agents/feature_content/context_loader.py index 396623e..6048dc1 100644 --- a/backend/app/agents/feature_content/context_loader.py +++ b/backend/app/agents/feature_content/context_loader.py @@ -12,15 +12,14 @@ from sqlalchemy.orm import Session, joinedload -from app.models.feature import Feature -from app.models.module import Module -from app.models.project import Project -from app.models.implementation import Implementation -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem as ThreadItemModel, 
ThreadItemType from app.agents.collab_thread_assistant.context_loader import ( load_grounding_files, ) +from app.models.feature import Feature +from app.models.implementation import Implementation +from app.models.project import Project +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem as ThreadItemModel from .types import ( FeatureContentContext, @@ -52,7 +51,7 @@ def _find_implementation_segment( # Find all IMPLEMENTATION_CREATED markers ordered by created_at markers = [] for item in thread_items_db: - item_type_str = item.item_type.value if hasattr(item.item_type, 'value') else item.item_type + item_type_str = item.item_type.value if hasattr(item.item_type, "value") else item.item_type if item_type_str == "implementation_created": markers.append(item) @@ -76,7 +75,7 @@ def _find_implementation_segment( ) return (None, None, True) - is_first = (target_index == 0) + is_first = target_index == 0 # Segment end is this marker's timestamp segment_end = target_marker.created_at @@ -111,12 +110,7 @@ def load_feature_context( ValueError: If feature, module, project, or thread not found """ # Load feature with module - feature = ( - db.query(Feature) - .options(joinedload(Feature.module)) - .filter(Feature.id == feature_id) - .first() - ) + feature = db.query(Feature).options(joinedload(Feature.module)).filter(Feature.id == feature_id).first() if not feature: raise ValueError(f"Feature {feature_id} not found") @@ -242,7 +236,7 @@ def _convert_thread_items( items = [] for item in thread_items_db: # item_type may be an enum or string depending on how it's loaded - item_type_str = item.item_type.value if hasattr(item.item_type, 'value') else item.item_type + item_type_str = item.item_type.value if hasattr(item.item_type, "value") else item.item_type content_data = item.content_data or {} # Determine if this item is in the focus segment @@ -255,12 +249,14 @@ def _convert_thread_items( if item_type_str == "comment": author = item.author.display_name if item.author else "Unknown" - items.append(ThreadItem( - item_type="comment", - author_name=author, - body=content_data.get("body_markdown", ""), - is_focus_segment=is_focus, - )) + items.append( + ThreadItem( + item_type="comment", + author_name=author, + body=content_data.get("body_markdown", ""), + is_focus_segment=is_focus, + ) + ) elif item_type_str == "mcq_followup": question = content_data.get("question_text", "") selected_option_id = content_data.get("selected_option_id") @@ -273,31 +269,37 @@ def _convert_thread_items( selected_label = choice.get("label", "") break - items.append(ThreadItem( - item_type="mcq", - author_name="System", - mcq_question=question, - mcq_selected_label=selected_label, - mcq_free_text=content_data.get("free_text"), - is_focus_segment=is_focus, - )) + items.append( + ThreadItem( + item_type="mcq", + author_name="System", + mcq_question=question, + mcq_selected_label=selected_label, + mcq_free_text=content_data.get("free_text"), + is_focus_segment=is_focus, + ) + ) elif item_type_str == "code_exploration": # Include full code exploration results (prompt + output) - items.append(ThreadItem( - item_type="code_exploration", - author_name="Code Explorer", - exploration_prompt=content_data.get("prompt", ""), - exploration_output=content_data.get("output", ""), - is_focus_segment=is_focus, - )) + items.append( + ThreadItem( + item_type="code_exploration", + author_name="Code Explorer", + exploration_prompt=content_data.get("prompt", ""), + 
exploration_output=content_data.get("output", ""), + is_focus_segment=is_focus, + ) + ) elif item_type_str == "web_search": # Include web search query only - the conclusion comes as MFBTAI comment - items.append(ThreadItem( - item_type="web_search", - author_name="Web Search", - exploration_prompt=content_data.get("query", ""), - is_focus_segment=is_focus, - )) + items.append( + ThreadItem( + item_type="web_search", + author_name="Web Search", + exploration_prompt=content_data.get("query", ""), + is_focus_segment=is_focus, + ) + ) return items diff --git a/backend/app/agents/feature_content/orchestrator.py b/backend/app/agents/feature_content/orchestrator.py index 44b3e8d..acdffd8 100644 --- a/backend/app/agents/feature_content/orchestrator.py +++ b/backend/app/agents/feature_content/orchestrator.py @@ -14,17 +14,16 @@ from autogen_agentchat.messages import TextMessage from autogen_core.models import ChatCompletionClient -from app.agents.llm_client import create_litellm_client, LLMCallLogger from app.agents.brainstorm_spec import JobCancelledException +from app.agents.llm_client import LLMCallLogger, create_litellm_client from .types import ( AGENT_METADATA, - WORKFLOW_STEPS, + PROMPT_PLAN_SYSTEM_MESSAGE, + SPEC_SYSTEM_MESSAGE, + ContentGenerationResult, ContentType, FeatureContentContext, - ContentGenerationResult, - SPEC_SYSTEM_MESSAGE, - PROMPT_PLAN_SYSTEM_MESSAGE, ) logger = logging.getLogger(__name__) @@ -128,8 +127,8 @@ def _check_cancelled(self) -> None: if not self.job_id: return - from app.services.job_service import JobService from app.database import SessionLocal + from app.services.job_service import JobService db = SessionLocal() try: @@ -154,9 +153,9 @@ def _build_spec_prompt(self, context: FeatureContentContext) -> str: - **Feature Key**: {context.feature.feature_key} - **Title**: {context.feature.title} - **Module**: {context.module.title} -- **Module Description**: {context.module.description or 'N/A'} -- **Category**: {context.feature.category or 'N/A'} -- **Priority**: {context.feature.priority or 'N/A'} +- **Module Description**: {context.module.description or "N/A"} +- **Category**: {context.feature.category or "N/A"} +- **Priority**: {context.feature.priority or "N/A"} """ # Add feature description if available @@ -170,7 +169,7 @@ def _build_spec_prompt(self, context: FeatureContentContext) -> str: if context.implementation_id: prompt += f""" ## Implementation Context -- **Implementation Name**: {context.implementation_name or 'N/A'} +- **Implementation Name**: {context.implementation_name or "N/A"} - **Type**: {"First implementation" if context.is_first_implementation else "Subsequent implementation"} """ if not context.is_first_implementation: @@ -226,9 +225,9 @@ def _build_prompt_plan_prompt(self, context: FeatureContentContext) -> str: - **Feature Key**: {context.feature.feature_key} - **Title**: {context.feature.title} - **Module**: {context.module.title} -- **Module Description**: {context.module.description or 'N/A'} -- **Category**: {context.feature.category or 'N/A'} -- **Priority**: {context.feature.priority or 'N/A'} +- **Module Description**: {context.module.description or "N/A"} +- **Category**: {context.feature.category or "N/A"} +- **Priority**: {context.feature.priority or "N/A"} """ # Add feature description if available @@ -242,7 +241,7 @@ def _build_prompt_plan_prompt(self, context: FeatureContentContext) -> str: if context.implementation_id: prompt += f""" ## Implementation Context -- **Implementation Name**: {context.implementation_name or 'N/A'} 
+- **Implementation Name**: {context.implementation_name or "N/A"} - **Type**: {"First implementation" if context.is_first_implementation else "Subsequent implementation"} """ if not context.is_first_implementation: @@ -325,9 +324,7 @@ async def generate( Raises: ValueError: If generation fails """ - logger.info( - f"Generating {content_type.value} for feature {context.feature.feature_key}" - ) + logger.info(f"Generating {content_type.value} for feature {context.feature.feature_key}") # Determine agent key based on content type agent_key = "spec_generator" if content_type == ContentType.SPEC else "prompt_plan_generator" @@ -363,10 +360,7 @@ async def generate( self._check_cancelled() # Run the agent - response = await agent.on_messages( - [TextMessage(content=prompt, source="user")], - cancellation_token=None - ) + response = await agent.on_messages([TextMessage(content=prompt, source="user")], cancellation_token=None) # Extract response text response_text = response.chat_message.content @@ -380,19 +374,15 @@ async def generate( self._report_progress("complete", 100, agent_key) logger.info( - f"Generated {content_type.value} for feature {context.feature.feature_key}: " - f"{len(content_markdown)} chars" + f"Generated {content_type.value} for feature {context.feature.feature_key}: {len(content_markdown)} chars" ) # Build metadata metadata = {} - if hasattr(self.model_client, 'get_usage_stats'): + if hasattr(self.model_client, "get_usage_stats"): usage_stats = self.model_client.get_usage_stats() metadata["model"] = usage_stats.get("model") - metadata["tokens_used"] = ( - usage_stats.get("prompt_tokens", 0) + - usage_stats.get("completion_tokens", 0) - ) + metadata["tokens_used"] = usage_stats.get("prompt_tokens", 0) + usage_stats.get("completion_tokens", 0) metadata["cost_usd"] = usage_stats.get("cost_usd") return ContentGenerationResult( @@ -403,7 +393,7 @@ async def generate( def get_usage_stats(self) -> Dict[str, Any]: """Get LLM usage statistics from the model client.""" - if hasattr(self.model_client, 'get_usage_stats'): + if hasattr(self.model_client, "get_usage_stats"): return self.model_client.get_usage_stats() return {} diff --git a/backend/app/agents/feature_content/types.py b/backend/app/agents/feature_content/types.py index ae1a67f..ae9211f 100644 --- a/backend/app/agents/feature_content/types.py +++ b/backend/app/agents/feature_content/types.py @@ -14,6 +14,7 @@ class ContentType(str, Enum): """Type of content to generate.""" + SPEC = "spec" PROMPT_PLAN = "prompt_plan" @@ -21,6 +22,7 @@ class ContentType(str, Enum): @dataclass class ThreadItem: """A single item from a feature's discussion thread.""" + item_type: str # "comment", "mcq", "code_exploration", or "web_search" author_name: str body: Optional[str] = None # For comments @@ -36,6 +38,7 @@ class ThreadItem: @dataclass class FeatureInfo: """Basic feature information for context.""" + feature_id: UUID feature_key: str # e.g., "USER-001" title: str @@ -48,6 +51,7 @@ class FeatureInfo: @dataclass class ModuleInfo: """Module information for context.""" + module_id: UUID title: str description: Optional[str] = None @@ -61,6 +65,7 @@ class FeatureContentContext: Contains all information needed to generate a specification or prompt plan from a feature's discussion thread. """ + # Identifiers feature_id: UUID project_id: UUID @@ -92,6 +97,7 @@ class ContentGenerationResult: """ Result of content generation by the agent. 
""" + content_markdown: str # The generated spec or prompt plan content_type: ContentType metadata: Dict[str, Any] = field(default_factory=dict) @@ -105,6 +111,7 @@ class ContentGenerationResult: @dataclass class AgentInfo: """Agent display information for UI progress tracking.""" + name: str description: str color: str @@ -115,23 +122,17 @@ class AgentInfo: "spec_generator": AgentInfo( name="Spec Writer", description="Generating feature specification from conversation", - color="#3B82F6" # Blue + color="#3B82F6", # Blue ), "prompt_plan_generator": AgentInfo( name="Prompt Plan Writer", description="Creating implementation prompt plan from specification", - color="#10B981" # Green + color="#10B981", # Green ), } # Workflow step definitions for progress tracking -WORKFLOW_STEPS = [ - "start", - "gathering_context", - "generating_content", - "saving_version", - "complete" -] +WORKFLOW_STEPS = ["start", "gathering_context", "generating_content", "saving_version", "complete"] # System messages for different content types diff --git a/backend/app/agents/grounding/__init__.py b/backend/app/agents/grounding/__init__.py index b8933dd..8d805f4 100644 --- a/backend/app/agents/grounding/__init__.py +++ b/backend/app/agents/grounding/__init__.py @@ -1,11 +1,11 @@ """Grounding update agent for maintaining agents.md grounding files.""" +from .orchestrator import GroundingUpdateOrchestrator, create_orchestrator from .types import ( + GroundingChanges, GroundingUpdateContext, GroundingUpdateResult, - GroundingChanges, ) -from .orchestrator import create_orchestrator, GroundingUpdateOrchestrator __all__ = [ "GroundingUpdateContext", diff --git a/backend/app/agents/grounding/merge_orchestrator.py b/backend/app/agents/grounding/merge_orchestrator.py index 5a43aa3..3ad4b1a 100644 --- a/backend/app/agents/grounding/merge_orchestrator.py +++ b/backend/app/agents/grounding/merge_orchestrator.py @@ -8,19 +8,20 @@ import json from dataclasses import dataclass, field -from typing import Optional, Dict, Any, Callable, List +from typing import Any, Callable, Dict, List, Optional from uuid import UUID from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core.models import ChatCompletionClient -from app.agents.llm_client import create_litellm_client, LLMCallLogger +from app.agents.llm_client import LLMCallLogger, create_litellm_client @dataclass class GroundingMergeChanges: """Summary of changes made during merge.""" + added: List[str] = field(default_factory=list) updated: List[str] = field(default_factory=list) kept: List[str] = field(default_factory=list) @@ -29,6 +30,7 @@ class GroundingMergeChanges: @dataclass class GroundingMergeResult: """Result of grounding merge by the agent.""" + merged_content: str changes: GroundingMergeChanges summary: str @@ -184,12 +186,7 @@ def __init__( system_message=MERGE_SYSTEM_PROMPT, ) - def _create_model_client( - self, - provider: str, - api_key: str, - config: Dict[str, Any] - ) -> ChatCompletionClient: + def _create_model_client(self, provider: str, api_key: str, config: Dict[str, Any]) -> ChatCompletionClient: """Create a model client for the specified provider.""" model = config.get("model") if not model: @@ -246,15 +243,9 @@ def _parse_response(self, response_text: str) -> GroundingMergeResult: try: data = json.loads(response_text) except json.JSONDecodeError as e: - is_truncated = ( - "Unterminated string" in str(e) or - not response_text.rstrip().endswith("}") - ) + is_truncated = "Unterminated string" in 
str(e) or not response_text.rstrip().endswith("}") if is_truncated: - raise ValueError( - f"LLM response was truncated. The files may be too large. " - f"Parse error: {e}" - ) + raise ValueError(f"LLM response was truncated. The files may be too large. Parse error: {e}") raise ValueError(f"Failed to parse JSON response: {e}") changes_data = data.get("changes", {}) @@ -304,10 +295,7 @@ async def merge_grounding( self.llm_call_logger.set_agent("grounding_merger", "Grounding Merger") # Run the agent - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - cancellation_token=None - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], cancellation_token=None) if progress_callback: progress_callback("Processing response...", 80) @@ -387,10 +375,7 @@ async def pull_from_global( self.llm_call_logger.set_agent("grounding_puller", "Grounding Puller") # Run the agent - response = await pull_agent.on_messages( - [TextMessage(content=prompt, source="user")], - cancellation_token=None - ) + response = await pull_agent.on_messages([TextMessage(content=prompt, source="user")], cancellation_token=None) if progress_callback: progress_callback("Processing response...", 80) diff --git a/backend/app/agents/grounding/orchestrator.py b/backend/app/agents/grounding/orchestrator.py index 7425be6..65fa4f4 100644 --- a/backend/app/agents/grounding/orchestrator.py +++ b/backend/app/agents/grounding/orchestrator.py @@ -7,22 +7,21 @@ """ import json -from typing import Optional, Dict, Any, Callable +from typing import Any, Callable, Dict, Optional from uuid import UUID from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core.models import ChatCompletionClient -from app.agents.llm_client import create_litellm_client, LLMCallLogger +from app.agents.llm_client import LLMCallLogger, create_litellm_client from .types import ( + GroundingChanges, GroundingUpdateContext, GroundingUpdateResult, - GroundingChanges, ) - # System prompt for the grounding updater agent SYSTEM_PROMPT = """You are an expert grounding file updater for a software project management platform. @@ -149,12 +148,7 @@ def __init__( system_message=SYSTEM_PROMPT, ) - def _create_model_client( - self, - provider: str, - api_key: str, - config: Dict[str, Any] - ) -> ChatCompletionClient: + def _create_model_client(self, provider: str, api_key: str, config: Dict[str, Any]) -> ChatCompletionClient: """ Create a model client for the specified provider using LiteLLM. @@ -248,10 +242,7 @@ def _parse_response(self, response_text: str) -> GroundingUpdateResult: data = json.loads(response_text) except json.JSONDecodeError as e: # Check if this looks like a truncated response - is_truncated = ( - "Unterminated string" in str(e) or - not response_text.rstrip().endswith("}") - ) + is_truncated = "Unterminated string" in str(e) or not response_text.rstrip().endswith("}") if is_truncated: raise ValueError( f"LLM response was truncated (likely exceeded max_tokens). " @@ -276,9 +267,7 @@ def _parse_response(self, response_text: str) -> GroundingUpdateResult: ) async def update_grounding( - self, - context: GroundingUpdateContext, - progress_callback: Optional[Callable[[str, int], None]] = None + self, context: GroundingUpdateContext, progress_callback: Optional[Callable[[str, int], None]] = None ) -> GroundingUpdateResult: """ Analyze feature notes and update the grounding file. 
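
For reference, both grounding orchestrators share the truncation heuristic that ruff-format reflows onto one line in the hunks above: a reply cut off by max_tokens usually fails to parse with an "Unterminated string" error, or never reaches the closing brace at all. As a standalone sketch (the helper name parse_llm_json and its use as a free function are assumptions for illustration, not part of this patch):

    import json

    def parse_llm_json(response_text: str) -> dict:
        """Parse an LLM JSON reply, distinguishing truncation from malformed JSON."""
        try:
            return json.loads(response_text)
        except json.JSONDecodeError as e:
            # A truncated reply typically dies inside a string literal, or
            # simply never reaches the closing brace of the top-level object.
            is_truncated = "Unterminated string" in str(e) or not response_text.rstrip().endswith("}")
            if is_truncated:
                raise ValueError(f"LLM response was truncated (likely exceeded max_tokens). Parse error: {e}") from e
            raise ValueError(f"Failed to parse JSON response: {e}") from e

Separating the two failure modes matters: a truncated reply calls for a larger max_tokens or smaller input, while merely malformed JSON is worth a retry.
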
@@ -307,10 +296,7 @@ async def update_grounding( self.llm_call_logger.set_agent("grounding_updater", "Grounding Updater") # Run the agent - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - cancellation_token=None - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], cancellation_token=None) if progress_callback: progress_callback("Processing response...", 80) @@ -329,9 +315,7 @@ async def close(self): pass async def summarize_content( - self, - content: str, - progress_callback: Optional[Callable[[str, int], None]] = None + self, content: str, progress_callback: Optional[Callable[[str, int], None]] = None ) -> str: """ Generate a summary of agents.md content without modifying it. @@ -374,8 +358,7 @@ async def summarize_content( # Run the agent response = await summarize_agent.on_messages( - [TextMessage(content=prompt, source="user")], - cancellation_token=None + [TextMessage(content=prompt, source="user")], cancellation_token=None ) if progress_callback: diff --git a/backend/app/agents/grounding/types.py b/backend/app/agents/grounding/types.py index 6f75c2a..1b8e4ea 100644 --- a/backend/app/agents/grounding/types.py +++ b/backend/app/agents/grounding/types.py @@ -15,6 +15,7 @@ class GroundingUpdateContext: """ Context passed to the grounding agent for analysis. """ + project_id: UUID feature_id: UUID feature_key: str # e.g., "USER-001" @@ -29,6 +30,7 @@ class GroundingChanges: """ Summary of changes made to agents.md. """ + added: List[str] = field(default_factory=list) # New items added updated: List[str] = field(default_factory=list) # Items that were modified removed: List[str] = field(default_factory=list) # Items that were removed @@ -39,6 +41,7 @@ class GroundingUpdateResult: """ Result of grounding update by the agent. """ + updated_content: str # The new agents.md content changes: GroundingChanges summary: str # Brief summary of what was changed diff --git a/backend/app/agents/image_annotator/annotator.py b/backend/app/agents/image_annotator/annotator.py index e8d1dfe..e079efc 100644 --- a/backend/app/agents/image_annotator/annotator.py +++ b/backend/app/agents/image_annotator/annotator.py @@ -6,11 +6,12 @@ """ import logging -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import MultiModalMessage -from autogen_core import CancellationToken, Image as AGImage +from autogen_core import CancellationToken +from autogen_core import Image as AGImage from autogen_core.models import ChatCompletionClient from .types import ImageAnnotationContext, ImageAnnotationResult @@ -139,10 +140,7 @@ async def annotate( # Set agent name for call logging if self.llm_call_logger: - self.llm_call_logger.set_agent( - "image_annotator", - "Image Annotator" - ) + self.llm_call_logger.set_agent("image_annotator", "Image Annotator") # Build multimodal message with text and image # Parse data URI to extract base64 data and create AutoGen Image @@ -184,10 +182,7 @@ async def annotate( if annotation.startswith('"') and annotation.endswith('"'): annotation = annotation[1:-1] - logger.info( - f"Generated annotation for image {context.image_id}: " - f"{annotation[:100]}..." 
- ) + logger.info(f"Generated annotation for image {context.image_id}: {annotation[:100]}...") return ImageAnnotationResult(annotation=annotation) diff --git a/backend/app/agents/image_annotator/orchestrator.py b/backend/app/agents/image_annotator/orchestrator.py index cce0009..1f365bd 100644 --- a/backend/app/agents/image_annotator/orchestrator.py +++ b/backend/app/agents/image_annotator/orchestrator.py @@ -5,22 +5,20 @@ in pre-phase discussions. """ -import base64 import logging -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional from uuid import UUID -from sqlalchemy.orm import Session - from autogen_core.models import ChatCompletionClient +from sqlalchemy.orm import Session -from app.models.project_chat import ProjectChat -from app.models.project import Project -from app.models.organization import Organization from app.models.grounding_file import GroundingFile +from app.models.organization import Organization +from app.models.project import Project +from app.models.project_chat import ProjectChat -from .types import ImageAnnotationContext, ImageAnnotationResult from .annotator import ImageAnnotatorAgent +from .types import ImageAnnotationContext, ImageAnnotationResult if TYPE_CHECKING: from app.agents.llm_client import LLMCallLogger @@ -54,17 +52,13 @@ def load_context( ValueError: If discussion not found. """ # Load discussion - discussion = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not discussion: raise ValueError(f"Discussion {project_chat_id} not found") # Load organization - organization = db.query(Organization).filter( - Organization.id == discussion.org_id - ).first() + organization = db.query(Organization).filter(Organization.id == discussion.org_id).first() if not organization: raise ValueError(f"Organization {discussion.org_id} not found") @@ -88,9 +82,7 @@ def load_context( ) # Project-scoped discussion - load project context - project = db.query(Project).filter( - Project.id == discussion.project_id - ).first() + project = db.query(Project).filter(Project.id == discussion.project_id).first() if not project: raise ValueError(f"Project {discussion.project_id} not found") @@ -99,10 +91,11 @@ def load_context( has_grounding = False grounding_summary = None - grounding_file = db.query(GroundingFile).filter( - GroundingFile.project_id == project.id, - GroundingFile.filename == "agents.md" - ).first() + grounding_file = ( + db.query(GroundingFile) + .filter(GroundingFile.project_id == project.id, GroundingFile.filename == "agents.md") + .first() + ) if grounding_file and grounding_file.summary: has_grounding = True @@ -223,7 +216,7 @@ async def handle_image_annotation( # Get usage stats from model client usage_stats = {} - if hasattr(model_client, 'get_usage_stats'): + if hasattr(model_client, "get_usage_stats"): usage_stats = model_client.get_usage_stats() return { diff --git a/backend/app/agents/image_annotator/types.py b/backend/app/agents/image_annotator/types.py index 856ee85..9a95fd2 100644 --- a/backend/app/agents/image_annotator/types.py +++ b/backend/app/agents/image_annotator/types.py @@ -18,6 +18,7 @@ class ImageAnnotationContext: Contains all information the agent needs to generate a meaningful annotation for an uploaded image. """ + # Image info image_id: str image_filename: str @@ -51,6 +52,7 @@ class ImageAnnotationResult: Contains the generated annotation for the image. 
""" + # The annotation text describing the image annotation: str diff --git a/backend/app/agents/llm_client.py b/backend/app/agents/llm_client.py index 8579b0a..a6f8203 100644 --- a/backend/app/agents/llm_client.py +++ b/backend/app/agents/llm_client.py @@ -25,13 +25,10 @@ import litellm from autogen_core import CancellationToken, FunctionCall - -from app.agents.retry import llm_retry from autogen_core.models import ( AssistantMessage, ChatCompletionClient, CreateResult, - FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, ModelCapabilities, @@ -44,6 +41,8 @@ from autogen_core.tools import Tool, ToolSchema from pydantic import BaseModel +from app.agents.retry import llm_retry + logger = logging.getLogger(__name__) # Suppress LiteLLM's verbose logging @@ -189,9 +188,9 @@ def log_call( try: # Import here to avoid circular imports + from app.models.job import Job from app.services.llm_call_log_service import LLMCallLogService from app.services.llm_usage_log_service import LLMUsageLogService - from app.models.job import Job db = self.db_session_factory() try: @@ -324,11 +323,17 @@ def __init__( def _supports_vision(self) -> bool: """Check if model supports vision/images.""" model_lower = self._model.lower() - return any(x in model_lower for x in [ - "gpt-4o", "gpt-4-turbo", "gpt-4-vision", - "claude-3", "claude-3.5", - "gemini", - ]) + return any( + x in model_lower + for x in [ + "gpt-4o", + "gpt-4-turbo", + "gpt-4-vision", + "claude-3", + "claude-3.5", + "gemini", + ] + ) def _convert_messages(self, messages: Sequence[LLMMessage]) -> List[Dict[str, Any]]: """ @@ -344,17 +349,21 @@ def _convert_messages(self, messages: Sequence[LLMMessage]) -> List[Dict[str, An for msg in messages: if isinstance(msg, SystemMessage): - result.append({ - "role": "system", - "content": msg.content, - }) + result.append( + { + "role": "system", + "content": msg.content, + } + ) elif isinstance(msg, UserMessage): # Handle string or list content (for images) if isinstance(msg.content, str): - result.append({ - "role": "user", - "content": msg.content, - }) + result.append( + { + "role": "user", + "content": msg.content, + } + ) else: # Handle multimodal content (text + images) content_parts = [] @@ -363,45 +372,52 @@ def _convert_messages(self, messages: Sequence[LLMMessage]) -> List[Dict[str, An content_parts.append({"type": "text", "text": part}) else: # Image object - convert to base64 URL - content_parts.append({ - "type": "image_url", - "image_url": {"url": part.data_uri} - }) - result.append({ - "role": "user", - "content": content_parts, - }) + content_parts.append({"type": "image_url", "image_url": {"url": part.data_uri}}) + result.append( + { + "role": "user", + "content": content_parts, + } + ) elif isinstance(msg, AssistantMessage): if isinstance(msg.content, str): - result.append({ - "role": "assistant", - "content": msg.content, - }) + result.append( + { + "role": "assistant", + "content": msg.content, + } + ) else: # Function calls tool_calls = [] for fc in msg.content: - tool_calls.append({ - "id": fc.id, - "type": "function", - "function": { - "name": fc.name, - "arguments": fc.arguments, + tool_calls.append( + { + "id": fc.id, + "type": "function", + "function": { + "name": fc.name, + "arguments": fc.arguments, + }, } - }) - result.append({ - "role": "assistant", - "content": None, - "tool_calls": tool_calls, - }) + ) + result.append( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + } + ) elif isinstance(msg, FunctionExecutionResultMessage): # Add tool 
results for fr in msg.content: - result.append({ - "role": "tool", - "tool_call_id": fr.call_id, - "content": fr.content, - }) + result.append( + { + "role": "tool", + "tool_call_id": fr.call_id, + "content": fr.content, + } + ) return result @@ -414,14 +430,16 @@ def _convert_tools(self, tools: Sequence[Tool | ToolSchema]) -> List[Dict[str, A else: schema = tool - result.append({ - "type": "function", - "function": { - "name": schema["name"], - "description": schema.get("description", ""), - "parameters": schema.get("parameters", {"type": "object", "properties": {}}), + result.append( + { + "type": "function", + "function": { + "name": schema["name"], + "description": schema.get("description", ""), + "parameters": schema.get("parameters", {"type": "object", "properties": {}}), + }, } - }) + ) return result def _parse_response(self, response: Any, cached: bool = False) -> CreateResult: @@ -467,7 +485,9 @@ def _parse_response(self, response: Any, cached: bool = False) -> CreateResult: thought=getattr(message, "reasoning_content", None), ) - def _normalize_finish_reason(self, reason: Optional[str]) -> Literal["stop", "length", "function_calls", "content_filter", "unknown"]: + def _normalize_finish_reason( + self, reason: Optional[str] + ) -> Literal["stop", "length", "function_calls", "content_filter", "unknown"]: """Normalize finish reason to AutoGen's expected values.""" if reason is None: return "unknown" @@ -585,10 +605,7 @@ async def create( if tools: kwargs["tools"] = self._convert_tools(tools) if isinstance(tool_choice, Tool): - kwargs["tool_choice"] = { - "type": "function", - "function": {"name": tool_choice.schema["name"]} - } + kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_choice.schema["name"]}} elif tool_choice != "auto": kwargs["tool_choice"] = tool_choice @@ -675,10 +692,7 @@ async def create_stream( if tools: kwargs["tools"] = self._convert_tools(tools) if isinstance(tool_choice, Tool): - kwargs["tool_choice"] = { - "type": "function", - "function": {"name": tool_choice.schema["name"]} - } + kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_choice.schema["name"]}} elif tool_choice != "auto": kwargs["tool_choice"] = tool_choice @@ -702,16 +716,16 @@ async def create_stream( try: async for chunk in response: - if hasattr(chunk, 'choices') and chunk.choices: + if hasattr(chunk, "choices") and chunk.choices: delta = chunk.choices[0].delta # Yield content chunks - if hasattr(delta, 'content') and delta.content: + if hasattr(delta, "content") and delta.content: full_content += delta.content yield delta.content # Accumulate tool calls - if hasattr(delta, 'tool_calls') and delta.tool_calls: + if hasattr(delta, "tool_calls") and delta.tool_calls: for tc in delta.tool_calls: idx = tc.index while len(tool_calls) <= idx: @@ -729,7 +743,7 @@ async def create_stream( finish_reason = chunk.choices[0].finish_reason # Get usage from final chunk if available - if hasattr(chunk, 'usage') and chunk.usage: + if hasattr(chunk, "usage") and chunk.usage: prompt_tokens = chunk.usage.prompt_tokens completion_tokens = chunk.usage.completion_tokens except Exception as e: @@ -818,10 +832,7 @@ def count_tokens( return count except Exception: # Fallback: rough estimate - total_chars = sum( - len(str(m.get("content", ""))) - for m in litellm_messages - ) + total_chars = sum(len(str(m.get("content", ""))) for m in litellm_messages) return total_chars // 4 # Rough estimate: 4 chars per token def remaining_tokens( diff --git 
a/backend/app/agents/module_feature/__init__.py b/backend/app/agents/module_feature/__init__.py index c71e6bf..015ea36 100644 --- a/backend/app/agents/module_feature/__init__.py +++ b/backend/app/agents/module_feature/__init__.py @@ -15,41 +15,41 @@ - prompt_plan_text: HOW to build (step-by-step instructions) """ +from .merger import MergerAgent +from .orchestrator import ModuleFeatureOrchestrator, create_orchestrator +from .plan_structurer import PlanStructurerAgent +from .spec_analyzer import SpecAnalyzerAgent from .types import ( - # Input/Output types - ModuleFeatureContext, - ExtractedModule, + AGENT_METADATA, + # UI metadata + WORKFLOW_STEPS, + AgentInfo, + CoverageReport, ExtractedFeature, + ExtractedModule, ExtractionResult, - # Agent intermediate types - SpecAnalysis, - SpecRequirement, - PlanStructure, + FeatureCategoryType, + FeatureContent, + FeatureMapping, + FeaturePriorityLevel, ImplementationPhase, ImplementationStep, MergedMapping, - ModuleMapping, - FeatureMapping, - FeatureContent, - WriterOutput, - CoverageReport, # Enums ModuleCategory, - FeaturePriorityLevel, - FeatureCategoryType, - # UI metadata - WORKFLOW_STEPS, - AGENT_METADATA, - AgentInfo, + # Input/Output types + ModuleFeatureContext, + ModuleMapping, + PlanStructure, + # Agent intermediate types + SpecAnalysis, + SpecRequirement, + WriterOutput, # Helpers validate_extraction_result, ) -from .orchestrator import ModuleFeatureOrchestrator, create_orchestrator -from .spec_analyzer import SpecAnalyzerAgent -from .plan_structurer import PlanStructurerAgent -from .merger import MergerAgent -from .writer import WriterAgent from .validator import ValidatorAgent +from .writer import WriterAgent __all__ = [ # Input/Output Types diff --git a/backend/app/agents/module_feature/logging_config.py b/backend/app/agents/module_feature/logging_config.py index 37722d0..76460aa 100644 --- a/backend/app/agents/module_feature/logging_config.py +++ b/backend/app/agents/module_feature/logging_config.py @@ -12,10 +12,10 @@ - Coverage Validator (Agent 5) """ -import logging import json -from typing import Any, Dict, List, Optional +import logging from datetime import datetime, timezone +from typing import Any, Dict, List, Optional class ModuleFeatureAgentLogger: @@ -41,19 +41,12 @@ def __init__(self, agent_name: str, project_id: Optional[str] = None): # Ensure structured output if not self.logger.handlers: handler = logging.StreamHandler() - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO) - def _structured_log( - self, - level: str, - event: str, - extra_data: Optional[Dict[str, Any]] = None - ) -> None: + def _structured_log(self, level: str, event: str, extra_data: Optional[Dict[str, Any]] = None) -> None: """ Log a structured event. @@ -86,12 +79,7 @@ def log_agent_complete(self, **kwargs: Any) -> None: self._structured_log("info", f"{self.agent_name}_complete", kwargs) def log_llm_call( - self, - prompt: str, - model: str, - response: Optional[str] = None, - tokens_used: Optional[int] = None, - **kwargs: Any + self, prompt: str, model: str, response: Optional[str] = None, tokens_used: Optional[int] = None, **kwargs: Any ) -> None: """ Log an LLM API call. @@ -107,7 +95,7 @@ def log_llm_call( "model": model, "prompt_preview": prompt[:200] + "..." 
if len(prompt) > 200 else prompt, "prompt_length": len(prompt), - **kwargs + **kwargs, } if response: @@ -119,13 +107,7 @@ def log_llm_call( self._structured_log("info", "llm_call", data) - def log_extraction_stats( - self, - modules_count: int, - total_features: int, - spec_length: int, - **kwargs: Any - ) -> None: + def log_extraction_stats(self, modules_count: int, total_features: int, spec_length: int, **kwargs: Any) -> None: """ Log extraction statistics. @@ -140,7 +122,7 @@ def log_extraction_stats( "total_features": total_features, "spec_length": spec_length, "avg_features_per_module": round(total_features / modules_count, 2) if modules_count > 0 else 0, - **kwargs + **kwargs, } self._structured_log("info", "extraction_stats", data) @@ -171,11 +153,7 @@ def log_workflow_transition(self, from_state: str, to_state: str, **kwargs: Any) to_state: New state **kwargs: Additional context """ - data = { - "from_state": from_state, - "to_state": to_state, - **kwargs - } + data = {"from_state": from_state, "to_state": to_state, **kwargs} self._structured_log("info", "workflow_transition", data) def log_validation_issues(self, issues: list, **kwargs: Any) -> None: @@ -190,7 +168,7 @@ def log_validation_issues(self, issues: list, **kwargs: Any) -> None: "issues_count": len(issues), "issues": issues[:10], # Limit to first 10 issues "has_issues": len(issues) > 0, - **kwargs + **kwargs, } self._structured_log("warning" if issues else "info", "validation_issues", data) @@ -200,7 +178,7 @@ def log_spec_analysis( domain_areas: List[str], data_models_count: int, api_endpoints_count: int, - **kwargs: Any + **kwargs: Any, ) -> None: """ Log spec analysis results from Agent 1. @@ -218,16 +196,12 @@ def log_spec_analysis( "domain_areas_count": len(domain_areas), "data_models_count": data_models_count, "api_endpoints_count": api_endpoints_count, - **kwargs + **kwargs, } self._structured_log("info", "spec_analysis_complete", data) def log_plan_structure( - self, - phases_count: int, - total_steps: int, - cross_cutting_concerns: List[str], - **kwargs: Any + self, phases_count: int, total_steps: int, cross_cutting_concerns: List[str], **kwargs: Any ) -> None: """ Log plan structure results from Agent 2. @@ -243,17 +217,12 @@ def log_plan_structure( "total_steps": total_steps, "cross_cutting_concerns": cross_cutting_concerns, "avg_steps_per_phase": round(total_steps / phases_count, 2) if phases_count > 0 else 0, - **kwargs + **kwargs, } self._structured_log("info", "plan_structure_complete", data) def log_merge_result( - self, - modules_count: int, - features_count: int, - unmapped_requirements: int, - unmapped_steps: int, - **kwargs: Any + self, modules_count: int, features_count: int, unmapped_requirements: int, unmapped_steps: int, **kwargs: Any ) -> None: """ Log merge results from Agent 3. 
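(Note: every log_* helper in this file funnels through _structured_log. A minimal sketch of that pattern, assuming a JSON-per-line payload — the actual method body sits outside these hunks, so the payload shape here is an assumption.)

import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional

logger = logging.getLogger("module_feature.merger")

def structured_log(event: str, extra_data: Optional[Dict[str, Any]] = None) -> None:
    # Assumed payload shape: timestamp + event name + arbitrary context.
    payload = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "event": event,
        **(extra_data or {}),
    }
    logger.info(json.dumps(payload))

structured_log("merge_complete", {"modules_count": 4, "features_count": 23})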
@@ -272,17 +241,15 @@ def log_merge_result( "unmapped_steps": unmapped_steps, "mapping_success_rate": round( (features_count / (features_count + unmapped_requirements + unmapped_steps)) * 100, 2 - ) if (features_count + unmapped_requirements + unmapped_steps) > 0 else 100, - **kwargs + ) + if (features_count + unmapped_requirements + unmapped_steps) > 0 + else 100, + **kwargs, } self._structured_log("info", "merge_complete", data) def log_content_written( - self, - features_processed: int, - avg_spec_text_length: int, - avg_prompt_plan_text_length: int, - **kwargs: Any + self, features_processed: int, avg_spec_text_length: int, avg_prompt_plan_text_length: int, **kwargs: Any ) -> None: """ Log content writing results from Agent 4. @@ -297,7 +264,7 @@ def log_content_written( "features_processed": features_processed, "avg_spec_text_length": avg_spec_text_length, "avg_prompt_plan_text_length": avg_prompt_plan_text_length, - **kwargs + **kwargs, } self._structured_log("info", "content_written", data) @@ -308,7 +275,7 @@ def log_coverage_report( uncovered_requirements: int, uncovered_steps: int, content_issues: int, - **kwargs: Any + **kwargs: Any, ) -> None: """ Log coverage validation results from Agent 5. @@ -327,7 +294,7 @@ def log_coverage_report( "uncovered_requirements": uncovered_requirements, "uncovered_steps": uncovered_steps, "content_issues": content_issues, - **kwargs + **kwargs, } level = "info" if ok else "warning" self._structured_log(level, "coverage_report", data) diff --git a/backend/app/agents/module_feature/merger.py b/backend/app/agents/module_feature/merger.py index 794bf96..21205ac 100644 --- a/backend/app/agents/module_feature/merger.py +++ b/backend/app/agents/module_feature/merger.py @@ -10,26 +10,25 @@ import asyncio import json from collections import defaultdict -from typing import Optional, List, Dict, Any, Set +from typing import Any, Dict, List, Optional, Set from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - ModuleFeatureContext, - SpecAnalysis, - SpecRequirement, - PlanStructure, + FeatureMapping, ImplementationPhase, MergedMapping, + ModuleFeatureContext, ModuleMapping, - FeatureMapping, + PlanStructure, + SpecAnalysis, + SpecRequirement, ) -from .logging_config import get_agent_logger -from .utils import strip_markdown_json, generate_unique_id, generate_semantic_id, parse_json_with_repair - +from .utils import generate_semantic_id, parse_json_with_repair, strip_markdown_json # Default domain to phase keyword mapping (fallback when phase_domain_mapping not available) DEFAULT_DOMAIN_KEYWORDS = { @@ -207,7 +206,7 @@ async def merge( project_id=str(context.project_id), requirements_count=len(spec_analysis.requirements), phases_count=len(plan_structure.phases), - total_steps=plan_structure.total_steps + total_steps=plan_structure.total_steps, ) try: @@ -244,17 +243,13 @@ async def _merge_parallel( Returns: Combined MergedMapping from all phases """ - self.logger.logger.info( - f"Using parallel phase processing for {len(plan_structure.phases)} phases" - ) + self.logger.logger.info(f"Using parallel phase processing for {len(plan_structure.phases)} phases") # Extract phase_domain_mapping from prompt plan phase_mappings = self._extract_phase_domain_mapping(context) # Create mapping dict for quick lookup - mapping_by_index = { - m.get("phase_index"): m 
for m in phase_mappings - } + mapping_by_index = {m.get("phase_index"): m for m in phase_mappings} # SINGLE-BEST MATCHING: Pre-compute requirement assignments sequentially # This ensures each requirement is assigned to only ONE phase @@ -272,9 +267,7 @@ async def _merge_parallel( break # Match requirements to this phase - matching_reqs = self._match_requirements_to_phase( - phase, spec_analysis.requirements, phase_mapping - ) + matching_reqs = self._match_requirements_to_phase(phase, spec_analysis.requirements, phase_mapping) # Filter out requirements already matched to earlier phases matching_reqs = [r for r in matching_reqs if r.id not in matched_requirement_ids] @@ -289,9 +282,7 @@ async def _merge_parallel( # Create parallel tasks using pre-computed assignments tasks = [] for phase in plan_structure.phases: - tasks.append( - self._merge_phase(phase, phase_requirements[phase.phase_index], spec_analysis, context) - ) + tasks.append(self._merge_phase(phase, phase_requirements[phase.phase_index], spec_analysis, context)) # Run all phases in parallel results = await asyncio.gather(*tasks, return_exceptions=True) @@ -300,32 +291,39 @@ async def _merge_parallel( phase_results = [] for i, result in enumerate(results): if isinstance(result, Exception): - self.logger.log_error(result, { - "phase_index": plan_structure.phases[i].phase_index, - "phase_title": plan_structure.phases[i].title, - }) + self.logger.log_error( + result, + { + "phase_index": plan_structure.phases[i].phase_index, + "phase_title": plan_structure.phases[i].title, + }, + ) # Use fallback for failed phases phase = plan_structure.phases[i] - phase_results.append({ - "phase_index": phase.phase_index, - "module": { - "title": phase.title, - "description": phase.objective, - "order_index": phase.phase_index, - "category": "phase", - "phase_reference": phase.phase_index, - }, - "features": [{ - "title": f"Implement {phase.title}", + phase_results.append( + { "phase_index": phase.phase_index, - "step_index": 1, - "spec_requirement_ids": [], - "plan_step_id": f"P{phase.phase_index}-S1", - "priority": "important", - "category": "Other", - }], - "unmapped_requirements": [], - }) + "module": { + "title": phase.title, + "description": phase.objective, + "order_index": phase.phase_index, + "category": "phase", + "phase_reference": phase.phase_index, + }, + "features": [ + { + "title": f"Implement {phase.title}", + "phase_index": phase.phase_index, + "step_index": 1, + "spec_requirement_ids": [], + "plan_step_id": f"P{phase.phase_index}-S1", + "priority": "important", + "category": "Other", + } + ], + "unmapped_requirements": [], + } + ) else: phase_results.append(result) @@ -385,13 +383,10 @@ async def _merge_single_call( self.logger.log_llm_call( prompt=prompt[:500] + "..." 
if len(prompt) > 500 else prompt, model=str(self.model_client), - operation="merge_spec_and_plan" + operation="merge_spec_and_plan", ) - response = await self.agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await self.agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -402,12 +397,15 @@ async def _merge_single_call( result_data = parse_json_with_repair(response_text) except json.JSONDecodeError as e: cleaned = strip_markdown_json(response_text) - self.logger.log_error(e, { - "raw_response": response_text[:500], - "cleaned_response": cleaned[:500], - "response_length": len(response_text), - "error_position": e.pos if hasattr(e, 'pos') else None - }) + self.logger.log_error( + e, + { + "raw_response": response_text[:500], + "cleaned_response": cleaned[:500], + "response_length": len(response_text), + "error_position": e.pos if hasattr(e, "pos") else None, + }, + ) raise ValueError(f"Failed to parse merger response as JSON: {e}") # Convert to MergedMapping @@ -415,32 +413,36 @@ async def _merge_single_call( raw_modules = result_data.get("modules", []) for idx, mod_data in enumerate(raw_modules, start=1): mod_title = mod_data.get("title", f"Module {idx}") - modules.append(ModuleMapping( - module_id=generate_semantic_id("MOD", mod_title, idx), - title=mod_title, - description=mod_data.get("description", ""), - order_index=mod_data.get("order_index", idx), - category=mod_data.get("category", "phase"), - phase_reference=mod_data.get("phase_reference"), - feature_ids=mod_data.get("feature_ids", []), - )) + modules.append( + ModuleMapping( + module_id=generate_semantic_id("MOD", mod_title, idx), + title=mod_title, + description=mod_data.get("description", ""), + order_index=mod_data.get("order_index", idx), + category=mod_data.get("category", "phase"), + phase_reference=mod_data.get("phase_reference"), + feature_ids=mod_data.get("feature_ids", []), + ) + ) features = [] raw_features = result_data.get("features", []) for idx, feat_data in enumerate(raw_features, start=1): feat_title = feat_data.get("title", f"Feature {idx}") - features.append(FeatureMapping( - feature_id=generate_semantic_id("FEAT", feat_title, idx), - title=feat_title, - module_id=feat_data.get("module_id", "MOD-001"), - phase_index=feat_data.get("phase_index", 0), - step_index=feat_data.get("step_index", idx), - global_order=feat_data.get("global_order", idx), - spec_requirement_ids=feat_data.get("spec_requirement_ids", []), - plan_step_id=feat_data.get("plan_step_id"), - priority=feat_data.get("priority", "important"), - category=feat_data.get("category", "Other"), - )) + features.append( + FeatureMapping( + feature_id=generate_semantic_id("FEAT", feat_title, idx), + title=feat_title, + module_id=feat_data.get("module_id", "MOD-001"), + phase_index=feat_data.get("phase_index", 0), + step_index=feat_data.get("step_index", idx), + global_order=feat_data.get("global_order", idx), + spec_requirement_ids=feat_data.get("spec_requirement_ids", []), + plan_step_id=feat_data.get("plan_step_id"), + priority=feat_data.get("priority", "important"), + category=feat_data.get("category", "Other"), + ) + ) merged = MergedMapping( modules=modules, @@ -464,10 +466,7 @@ async def _merge_single_call( return merged def _build_prompt( - self, - spec_analysis: SpecAnalysis, - plan_structure: PlanStructure, - context: ModuleFeatureContext + self, spec_analysis: SpecAnalysis, 
plan_structure: PlanStructure, context: ModuleFeatureContext ) -> str: """ Build the merge prompt from spec analysis and plan structure. @@ -534,9 +533,9 @@ def _build_prompt( prompt += "\n" # Add cross-phase context if available (for awareness only) - if hasattr(context, 'cross_project_context') and context.cross_project_context: + if hasattr(context, "cross_project_context") and context.cross_project_context: cross_ctx = context.cross_project_context - if hasattr(cross_ctx, 'other_phases') and cross_ctx.other_phases: + if hasattr(cross_ctx, "other_phases") and cross_ctx.other_phases: prompt += "## DECISIONS FROM OTHER PHASES (for awareness only):\n\n" prompt += "Use these for consistency. Do NOT add their requirements to this phase.\n\n" for phase_ctx in cross_ctx.other_phases[:3]: # Limit to top 3 phases @@ -720,10 +719,7 @@ async def _merge_phase( model_client=self.model_client, ) - response = await phase_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await phase_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -734,10 +730,7 @@ async def _merge_phase( result["phase_index"] = phase.phase_index return result except json.JSONDecodeError as e: - self.logger.log_error(e, { - "phase": phase.title, - "response_preview": response_text[:300] - }) + self.logger.log_error(e, {"phase": phase.title, "response_preview": response_text[:300]}) # Return minimal fallback return { "phase_index": phase.phase_index, @@ -809,18 +802,20 @@ def _combine_phase_results( feature_id = generate_semantic_id("FEAT", feat_title, feat_counter) module.feature_ids.append(feature_id) - features.append(FeatureMapping( - feature_id=feature_id, - title=feat_title, - module_id=module.module_id, - phase_index=feat_data.get("phase_index", result.get("phase_index", 0)), - step_index=feat_data.get("step_index", feat_counter), - global_order=feat_counter, - spec_requirement_ids=feat_data.get("spec_requirement_ids", []), - plan_step_id=feat_data.get("plan_step_id"), - priority=feat_data.get("priority", "important"), - category=feat_data.get("category", "Other"), - )) + features.append( + FeatureMapping( + feature_id=feature_id, + title=feat_title, + module_id=module.module_id, + phase_index=feat_data.get("phase_index", result.get("phase_index", 0)), + step_index=feat_data.get("step_index", feat_counter), + global_order=feat_counter, + spec_requirement_ids=feat_data.get("spec_requirement_ids", []), + plan_step_id=feat_data.get("plan_step_id"), + priority=feat_data.get("priority", "important"), + category=feat_data.get("category", "Other"), + ) + ) modules.append(module) @@ -902,10 +897,7 @@ def _deduplicate_features(self, merged: MergedMapping) -> MergedMapping: return merged -async def create_merger( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> MergerAgent: +async def create_merger(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> MergerAgent: """ Factory function to create a Merger Agent. 
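(Note: the single-best matching loop reformatted in _merge_parallel above is the key idea in this file — requirements are assigned sequentially so each one lands in exactly ONE phase. A minimal standalone sketch, with a keyword-overlap matcher standing in for _match_requirements_to_phase, whose body is outside these hunks; the data shapes are illustrative.)

from typing import Any, Dict, List, Set

def assign_requirements(
    phases: List[Dict[str, Any]], requirements: List[Dict[str, Any]]
) -> Dict[int, List[Dict[str, Any]]]:
    matched_ids: Set[str] = set()
    phase_requirements: Dict[int, List[Dict[str, Any]]] = {}
    for phase in phases:  # sequential, so earlier phases claim requirements first
        keywords = {k.lower() for k in phase["keywords"]}
        matches = [
            r
            for r in requirements
            if r["id"] not in matched_ids and keywords & set(r["text"].lower().split())
        ]
        matched_ids.update(r["id"] for r in matches)
        phase_requirements[phase["phase_index"]] = matches
    return phase_requirements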
diff --git a/backend/app/agents/module_feature/orchestrator.py b/backend/app/agents/module_feature/orchestrator.py index df56140..76d61ef 100644 --- a/backend/app/agents/module_feature/orchestrator.py +++ b/backend/app/agents/module_feature/orchestrator.py @@ -11,29 +11,29 @@ import asyncio import logging -from typing import Optional, Callable, Dict, Any +from typing import Any, Callable, Dict, Optional from uuid import UUID from autogen_core.models import ChatCompletionClient +# Re-use the same exception from brainstorm_spec +from app.agents.brainstorm_spec import JobCancelledException + +from .logging_config import get_agent_logger +from .merger import MergerAgent +from .plan_structurer import PlanStructurerAgent +from .spec_analyzer import SpecAnalyzerAgent from .types import ( - ModuleFeatureContext, - ExtractionResult, - ExtractedModule, - ExtractedFeature, - WORKFLOW_STEPS, AGENT_METADATA, + WORKFLOW_STEPS, + ExtractedFeature, + ExtractedModule, + ExtractionResult, + ModuleFeatureContext, validate_extraction_result, ) -from .logging_config import get_agent_logger -from .spec_analyzer import SpecAnalyzerAgent -from .plan_structurer import PlanStructurerAgent -from .merger import MergerAgent -from .writer import WriterAgent from .validator import ValidatorAgent - -# Re-use the same exception from brainstorm_spec -from app.agents.brainstorm_spec import JobCancelledException +from .writer import WriterAgent logger = logging.getLogger(__name__) @@ -87,8 +87,8 @@ def _check_cancelled(self) -> None: if not self.job_id: return - from app.services.job_service import JobService from app.database import SessionLocal + from app.services.job_service import JobService db = SessionLocal() try: @@ -121,7 +121,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self.logger.log_workflow_transition( from_state=WORKFLOW_STEPS[0], # "start" - to_state=WORKFLOW_STEPS[1] # "analyzing_spec" + to_state=WORKFLOW_STEPS[1], # "analyzing_spec" ) try: @@ -142,12 +142,12 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self._report_progress( f"Extracted {len(spec_analysis.requirements)} requirements and {len(plan_structure.phases)} phases", 30, - "plan_structurer" + "plan_structurer", ) self.logger.log_workflow_transition( from_state=WORKFLOW_STEPS[1], # "analyzing_spec" - to_state=WORKFLOW_STEPS[3] # "merging" (skip structuring_plan since parallel) + to_state=WORKFLOW_STEPS[3], # "merging" (skip structuring_plan since parallel) ) # Step 3: Merge Spec and Plan (30-50%) @@ -158,7 +158,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self._report_progress( f"Merger: Created {len(merged_mapping.modules)} modules with {len(merged_mapping.features)} features", 50, - "merger" + "merger", ) # Check for cancellation after merger @@ -166,20 +166,16 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self.logger.log_workflow_transition( from_state=WORKFLOW_STEPS[3], # "merging" - to_state=WORKFLOW_STEPS[4] # "writing_content" + to_state=WORKFLOW_STEPS[4], # "writing_content" ) # Step 4: Write Feature Content (50-85%) self._report_progress("Content Writer: Generating spec_text and prompt_plan_text", 50, "writer") if self.call_logger: self.call_logger.set_agent("writer", "Content Writer") - writer_output = await self.writer.write_all( - merged_mapping, spec_analysis, plan_structure, context - ) + writer_output = await self.writer.write_all(merged_mapping, spec_analysis, plan_structure, 
context) self._report_progress( - f"Content Writer: Generated content for {len(writer_output.feature_contents)} features", - 85, - "writer" + f"Content Writer: Generated content for {len(writer_output.feature_contents)} features", 85, "writer" ) # Check for cancellation after writer @@ -187,7 +183,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self.logger.log_workflow_transition( from_state=WORKFLOW_STEPS[4], # "writing_content" - to_state=WORKFLOW_STEPS[5] # "validating" + to_state=WORKFLOW_STEPS[5], # "validating" ) # Step 5: Validate Coverage (85-100%) @@ -195,11 +191,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra coverage_report = await self.validator.validate( spec_analysis, plan_structure, merged_mapping, writer_output ) - self._report_progress( - f"Validator: Coverage {coverage_report.coverage_percentage}%", - 95, - "validator" - ) + self._report_progress(f"Validator: Coverage {coverage_report.coverage_percentage}%", 95, "validator") # Build final result (pass context for image resolution) result = self._build_final_result(merged_mapping, writer_output, coverage_report, context) @@ -214,7 +206,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra self.logger.log_workflow_transition( from_state=WORKFLOW_STEPS[5], # "validating" - to_state=WORKFLOW_STEPS[6] # "complete" + to_state=WORKFLOW_STEPS[6], # "complete" ) self.logger.log_agent_complete( @@ -227,10 +219,7 @@ async def extract_modules_features(self, context: ModuleFeatureContext) -> Extra return result except Exception as e: - self.logger.log_error(e, { - "project_id": str(context.project_id), - "workflow_step": "unknown" - }) + self.logger.log_error(e, {"project_id": str(context.project_id), "workflow_step": "unknown"}) raise def _build_final_result( @@ -253,16 +242,11 @@ def _build_final_result( Final ExtractionResult ready for persistence """ # Build content lookup - content_map = { - fc.feature_id: fc - for fc in writer_output.feature_contents - } + content_map = {fc.feature_id: fc for fc in writer_output.feature_contents} # Build image attachment lookup by ID for quick resolution image_attachment_map = { - img.get("id"): img - for img in context.phase_description_image_attachments - if img.get("id") + img.get("id"): img for img in context.phase_description_image_attachments if img.get("id") } # Build module-to-features lookup @@ -282,9 +266,15 @@ def _build_final_result( features = [] for feat in module_features.get(mod.module_id, []): content = content_map.get(feat.feature_id) - description = content.description if content else f"This feature enables {feat.title.lower()} functionality for users." + description = ( + content.description + if content + else f"This feature enables {feat.title.lower()} functionality for users." + ) spec_text = content.spec_text if content else f"## {feat.title}\n\nImplement as specified." - prompt_plan_text = content.prompt_plan_text if content else f"## Implementation\n\n1. Implement {feat.title}" + prompt_plan_text = ( + content.prompt_plan_text if content else f"## Implementation\n\n1. 
Implement {feat.title}" + ) # Resolve relevant_image_ids to full image attachment dicts description_image_attachments = [] @@ -293,26 +283,30 @@ def _build_final_result( if img_id in image_attachment_map: description_image_attachments.append(image_attachment_map[img_id]) - features.append(ExtractedFeature( - title=feat.title, - description=description, - spec_text=spec_text, - prompt_plan_text=prompt_plan_text, - priority=feat.priority, - category=feat.category, - order_index=feat.step_index, - spec_requirement_refs=feat.spec_requirement_ids, - description_image_attachments=description_image_attachments, - )) - - modules.append(ExtractedModule( - title=mod.title, - description=mod.description, - order_index=mod.order_index, - module_category=mod.category, - phase_reference=mod.phase_reference, - features=features, - )) + features.append( + ExtractedFeature( + title=feat.title, + description=description, + spec_text=spec_text, + prompt_plan_text=prompt_plan_text, + priority=feat.priority, + category=feat.category, + order_index=feat.step_index, + spec_requirement_refs=feat.spec_requirement_ids, + description_image_attachments=description_image_attachments, + ) + ) + + modules.append( + ExtractedModule( + title=mod.title, + description=mod.description, + order_index=mod.order_index, + module_category=mod.category, + phase_reference=mod.phase_reference, + features=features, + ) + ) # Sort modules by order_index modules.sort(key=lambda m: m.order_index) @@ -394,7 +388,7 @@ async def create_orchestrator( Raises: ValueError: If provider is unsupported """ - from app.agents.llm_client import create_litellm_client, LLMCallLogger + from app.agents.llm_client import LLMCallLogger, create_litellm_client model = config.get("model") if not model: @@ -415,6 +409,7 @@ async def create_orchestrator( call_logger = None if job_id: from app.database import SessionLocal + call_logger = LLMCallLogger( db_session_factory=SessionLocal, job_id=job_id, diff --git a/backend/app/agents/module_feature/plan_structurer.py b/backend/app/agents/module_feature/plan_structurer.py index 3a19601..2529898 100644 --- a/backend/app/agents/module_feature/plan_structurer.py +++ b/backend/app/agents/module_feature/plan_structurer.py @@ -9,21 +9,21 @@ import asyncio import json -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( - ModuleFeatureContext, - PlanStructure, ImplementationPhase, ImplementationStep, + ModuleFeatureContext, + PlanStructure, ) -from .logging_config import get_agent_logger -from .utils import strip_markdown_json, extract_markdown_sections +from .utils import extract_markdown_sections, strip_markdown_json class PlanStructurerAgent: @@ -112,7 +112,7 @@ async def structure( self.logger.log_agent_start( project_id=str(context.project_id), plan_length=len(context.prompt_plan_markdown), - has_json_plan=context.prompt_plan_json is not None + has_json_plan=context.prompt_plan_json is not None, ) try: @@ -123,9 +123,7 @@ async def structure( if phase_mapping: # Use phase_domain_mapping to guide phase extraction - self.logger.logger.info( - f"Using phase_domain_mapping with {len(phase_mapping)} defined phases" - ) + self.logger.logger.info(f"Using phase_domain_mapping with {len(phase_mapping)} defined phases") return 
await self._structure_with_phase_mapping(context, phase_mapping) # Fallback: Split plan into sections (legacy behavior) @@ -133,19 +131,12 @@ async def structure( if not sections: # Fallback: treat entire plan as one section - sections = [{ - 'id': 'full_plan', - 'title': 'Full Prompt Plan', - 'content': context.prompt_plan_markdown - }] + sections = [{"id": "full_plan", "title": "Full Prompt Plan", "content": context.prompt_plan_markdown}] self.logger.logger.info(f"Processing {len(sections)} sections in parallel (no phase_domain_mapping)") # Process all sections in parallel - tasks = [ - self._structure_section(section, context) - for section in sections - ] + tasks = [self._structure_section(section, context) for section in sections] section_results = await asyncio.gather(*tasks, return_exceptions=True) # Merge results from all sections @@ -154,16 +145,14 @@ async def structure( for idx, (section, result) in enumerate(zip(sections, section_results)): if isinstance(result, Exception): - self.logger.log_error(result, { - "section_id": section['id'], - "section_title": section['title'], - "error": str(result) - }) + self.logger.log_error( + result, {"section_id": section["id"], "section_title": section["title"], "error": str(result)} + ) continue # Collect phases from this section for phase in result.get("phases", []): - phase["source_section"] = section['title'] + phase["source_section"] = section["title"] all_phases.append(phase) all_cross_cutting.extend(result.get("cross_cutting_concerns", [])) @@ -178,7 +167,9 @@ def dedupe_list(items): if key not in seen: seen.add(key) if isinstance(item, dict): - result.append(item.get("name") or item.get("concern") or item.get("description") or str(item)) + result.append( + item.get("name") or item.get("concern") or item.get("description") or str(item) + ) else: result.append(item) return result @@ -195,21 +186,25 @@ def dedupe_list(items): raw_steps = phase_data.get("steps", []) for step_idx, step_data in enumerate(raw_steps, start=1): - steps.append(ImplementationStep( - step_id=f"P{new_idx}-S{step_idx}", - title=step_data.get("title", "Untitled Step"), - description=step_data.get("description", ""), - expected_artifacts=step_data.get("expected_artifacts", []), - completion_criteria=step_data.get("completion_criteria", []), - )) - - phases.append(ImplementationPhase( - phase_index=new_idx, - title=phase_data.get("title", f"Phase {new_idx}"), - objective=phase_data.get("objective", ""), - steps=steps, - dependencies=phase_data.get("dependencies", []), - )) + steps.append( + ImplementationStep( + step_id=f"P{new_idx}-S{step_idx}", + title=step_data.get("title", "Untitled Step"), + description=step_data.get("description", ""), + expected_artifacts=step_data.get("expected_artifacts", []), + completion_criteria=step_data.get("completion_criteria", []), + ) + ) + + phases.append( + ImplementationPhase( + phase_index=new_idx, + title=phase_data.get("title", f"Phase {new_idx}"), + objective=phase_data.get("objective", ""), + steps=steps, + dependencies=phase_data.get("dependencies", []), + ) + ) structure = PlanStructure( phases=phases, @@ -253,10 +248,7 @@ async def _structure_with_phase_mapping( PlanStructure with exactly the number of phases in phase_mapping """ # Process each phase in parallel to extract steps - tasks = [ - self._extract_phase_steps(context, phase_def, idx + 1) - for idx, phase_def in enumerate(phase_mapping) - ] + tasks = [self._extract_phase_steps(context, phase_def, idx + 1) for idx, phase_def in enumerate(phase_mapping)] 
phase_results = await asyncio.gather(*tasks, return_exceptions=True) # Build phases from results @@ -267,38 +259,41 @@ async def _structure_with_phase_mapping( phase_index = idx + 1 if isinstance(result, Exception): - self.logger.log_error(result, { - "phase_title": phase_def.get("phase_title"), - "error": str(result) - }) + self.logger.log_error(result, {"phase_title": phase_def.get("phase_title"), "error": str(result)}) # Create minimal phase on error - phases.append(ImplementationPhase( - phase_index=phase_index, - title=phase_def.get("phase_title", f"Phase {phase_index}"), - objective=f"Implement {phase_def.get('phase_title', 'phase')}", - steps=[], - dependencies=[], - )) + phases.append( + ImplementationPhase( + phase_index=phase_index, + title=phase_def.get("phase_title", f"Phase {phase_index}"), + objective=f"Implement {phase_def.get('phase_title', 'phase')}", + steps=[], + dependencies=[], + ) + ) continue # Build steps from result steps = [] for step_idx, step_data in enumerate(result.get("steps", []), start=1): - steps.append(ImplementationStep( - step_id=f"P{phase_index}-S{step_idx}", - title=step_data.get("title", "Untitled Step"), - description=step_data.get("description", ""), - expected_artifacts=step_data.get("expected_artifacts", []), - completion_criteria=step_data.get("completion_criteria", []), - )) - - phases.append(ImplementationPhase( - phase_index=phase_index, - title=phase_def.get("phase_title", f"Phase {phase_index}"), - objective=result.get("objective", f"Implement {phase_def.get('phase_title', 'phase')}"), - steps=steps, - dependencies=result.get("dependencies", []), - )) + steps.append( + ImplementationStep( + step_id=f"P{phase_index}-S{step_idx}", + title=step_data.get("title", "Untitled Step"), + description=step_data.get("description", ""), + expected_artifacts=step_data.get("expected_artifacts", []), + completion_criteria=step_data.get("completion_criteria", []), + ) + ) + + phases.append( + ImplementationPhase( + phase_index=phase_index, + title=phase_def.get("phase_title", f"Phase {phase_index}"), + objective=result.get("objective", f"Implement {phase_def.get('phase_title', 'phase')}"), + steps=steps, + dependencies=result.get("dependencies", []), + ) + ) all_cross_cutting.extend(result.get("cross_cutting_concerns", [])) @@ -354,8 +349,8 @@ async def _extract_phase_steps( # Build prompt to extract steps for this specific phase prompt = f"""Extract implementation steps for Phase {phase_index}: {phase_title} -**Phase Keywords:** {', '.join(keywords)} -**Domains:** {', '.join(domains)} +**Phase Keywords:** {", ".join(keywords)} +**Domains:** {", ".join(domains)} **Prompt Plan Content:** {context.prompt_plan_markdown[:15000]} @@ -385,9 +380,7 @@ async def _extract_phase_steps( Return ONLY the JSON object.""" self.logger.log_llm_call( - prompt=f"Phase {phase_index}: {phase_title}", - model=str(self.model_client), - operation="extract_phase_steps" + prompt=f"Phase {phase_index}: {phase_title}", model=str(self.model_client), operation="extract_phase_steps" ) # Create fresh agent for this phase @@ -397,10 +390,7 @@ async def _extract_phase_steps( model_client=self.model_client, ) - response = await phase_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await phase_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -411,21 +401,20 @@ async def _extract_phase_steps( try: result = 
json.loads(cleaned) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "phase_title": phase_title, - "response_preview": cleaned[:200] - }) + self.logger.log_error(e, {"phase_title": phase_title, "response_preview": cleaned[:200]}) # Return minimal result return { "objective": f"Implement {phase_title}", - "steps": [{ - "title": f"Implement {phase_title}", - "description": f"Complete implementation for {phase_title}", - "expected_artifacts": [], - "completion_criteria": [f"{phase_title} complete"] - }], + "steps": [ + { + "title": f"Implement {phase_title}", + "description": f"Complete implementation for {phase_title}", + "expected_artifacts": [], + "completion_criteria": [f"{phase_title} complete"], + } + ], "dependencies": [], - "cross_cutting_concerns": [] + "cross_cutting_concerns": [], } return result @@ -446,18 +435,15 @@ async def _structure_section( Dict with phases and cross_cutting_concerns """ # Skip empty sections - if not section['content'].strip(): - return { - "phases": [], - "cross_cutting_concerns": [] - } + if not section["content"].strip(): + return {"phases": [], "cross_cutting_concerns": []} prompt = self._build_section_prompt(section, context) self.logger.log_llm_call( prompt=f"Section: {section['title']} ({len(section['content'])} chars)", model=str(self.model_client), - operation="structure_section" + operation="structure_section", ) # Create a FRESH agent for each section to avoid conversation history accumulation @@ -469,10 +455,7 @@ async def _structure_section( model_client=self.model_client, ) - response = await section_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await section_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -483,25 +466,21 @@ async def _structure_section( try: result = json.loads(cleaned) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "section_id": section['id'], - "section_title": section['title'], - "response_length": len(response_text), - "response_preview": cleaned[:200] - }) + self.logger.log_error( + e, + { + "section_id": section["id"], + "section_title": section["title"], + "response_length": len(response_text), + "response_preview": cleaned[:200], + }, + ) # Return empty result rather than failing entirely - return { - "phases": [], - "cross_cutting_concerns": [] - } + return {"phases": [], "cross_cutting_concerns": []} return result - def _build_section_prompt( - self, - section: Dict[str, str], - context: ModuleFeatureContext - ) -> str: + def _build_section_prompt(self, section: Dict[str, str], context: ModuleFeatureContext) -> str: """ Build the structuring prompt for a single section. @@ -512,7 +491,7 @@ def _build_section_prompt( Returns: Formatted prompt string """ - prompt = f"Extract implementation phases and steps from this prompt plan section.\n\n" + prompt = "Extract implementation phases and steps from this prompt plan section.\n\n" # Add project context if context.project_name: @@ -520,7 +499,7 @@ def _build_section_prompt( # Add section prompt += f"## {section['title']}\n\n" - prompt += section['content'] + prompt += section["content"] prompt += "\n\n" prompt += "Extract phases and steps from THIS SECTION. 
" @@ -530,8 +509,7 @@ def _build_section_prompt( async def create_plan_structurer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> PlanStructurerAgent: """ Factory function to create a Plan Structurer Agent. diff --git a/backend/app/agents/module_feature/spec_analyzer.py b/backend/app/agents/module_feature/spec_analyzer.py index 61a881d..3c4f4b2 100644 --- a/backend/app/agents/module_feature/spec_analyzer.py +++ b/backend/app/agents/module_feature/spec_analyzer.py @@ -9,23 +9,23 @@ import asyncio import json -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( ModuleFeatureContext, SpecAnalysis, SpecRequirement, ) -from .logging_config import get_agent_logger from .utils import ( - strip_markdown_json, - generate_unique_id, extract_markdown_sections, + generate_unique_id, + strip_markdown_json, ) @@ -111,7 +111,7 @@ async def analyze( self.logger.log_agent_start( project_id=str(context.project_id), spec_length=len(context.final_spec_markdown), - has_json_spec=context.final_spec_json is not None + has_json_spec=context.final_spec_json is not None, ) try: @@ -120,19 +120,12 @@ async def analyze( if not sections: # Fallback: treat entire spec as one section - sections = [{ - 'id': 'full_spec', - 'title': 'Full Specification', - 'content': context.final_spec_markdown - }] + sections = [{"id": "full_spec", "title": "Full Specification", "content": context.final_spec_markdown}] self.logger.logger.info(f"Processing {len(sections)} sections in parallel") # Process all sections in parallel - tasks = [ - self._analyze_section(section, context) - for section in sections - ] + tasks = [self._analyze_section(section, context) for section in sections] section_results = await asyncio.gather(*tasks, return_exceptions=True) # Merge results from all sections @@ -145,17 +138,15 @@ async def analyze( for idx, (section, result) in enumerate(zip(sections, section_results)): if isinstance(result, Exception): - self.logger.log_error(result, { - "section_id": section['id'], - "section_title": section['title'], - "error": str(result) - }) + self.logger.log_error( + result, {"section_id": section["id"], "section_title": section["title"], "error": str(result)} + ) continue # Add section info to each requirement for req in result.get("requirements", []): - req["section_id"] = section['id'] - req["section_title"] = section['title'] + req["section_id"] = section["id"] + req["section_title"] = section["title"] all_requirements.append(req) if req.get("domain_area"): all_domain_areas.add(req["domain_area"]) @@ -177,7 +168,13 @@ def dedupe_list(items): seen.add(key) # For dicts, extract just the name/text if possible if isinstance(item, dict): - result.append(item.get("component_name") or item.get("name") or item.get("constraint_type") or item.get("description") or str(item)) + result.append( + item.get("component_name") + or item.get("name") + or item.get("constraint_type") + or item.get("description") + or str(item) + ) else: result.append(item) return result @@ -190,14 +187,16 @@ def dedupe_list(items): # Convert to SpecRequirement objects with unique IDs requirements = [] for idx, req_data in enumerate(all_requirements, 
start=1): - requirements.append(SpecRequirement( - id=generate_unique_id("REQ", idx), - section_id=req_data.get("section_id", "unknown"), - section_title=req_data.get("section_title", "Unknown Section"), - requirement_text=req_data.get("requirement_text", ""), - domain_area=req_data.get("domain_area", "Other"), - implied_components=req_data.get("implied_components", []), - )) + requirements.append( + SpecRequirement( + id=generate_unique_id("REQ", idx), + section_id=req_data.get("section_id", "unknown"), + section_title=req_data.get("section_title", "Unknown Section"), + requirement_text=req_data.get("requirement_text", ""), + domain_area=req_data.get("domain_area", "Other"), + implied_components=req_data.get("implied_components", []), + ) + ) analysis = SpecAnalysis( requirements=requirements, @@ -243,21 +242,15 @@ async def _analyze_section( Dict with requirements, data_models, api_endpoints, etc. """ # Skip empty sections - if not section['content'].strip(): - return { - "requirements": [], - "data_models": [], - "api_endpoints": [], - "ui_components": [], - "constraints": [] - } + if not section["content"].strip(): + return {"requirements": [], "data_models": [], "api_endpoints": [], "ui_components": [], "constraints": []} prompt = self._build_section_prompt(section, context) self.logger.log_llm_call( prompt=f"Section: {section['title']} ({len(section['content'])} chars)", model=str(self.model_client), - operation="analyze_section" + operation="analyze_section", ) # Create a FRESH agent for each section to avoid conversation history accumulation @@ -269,10 +262,7 @@ async def _analyze_section( model_client=self.model_client, ) - response = await section_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await section_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -284,28 +274,21 @@ async def _analyze_section( try: result = json.loads(cleaned) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "section_id": section['id'], - "section_title": section['title'], - "response_length": len(response_text), - "response_preview": cleaned[:200] - }) + self.logger.log_error( + e, + { + "section_id": section["id"], + "section_title": section["title"], + "response_length": len(response_text), + "response_preview": cleaned[:200], + }, + ) # Return empty result rather than failing entirely - return { - "requirements": [], - "data_models": [], - "api_endpoints": [], - "ui_components": [], - "constraints": [] - } + return {"requirements": [], "data_models": [], "api_endpoints": [], "ui_components": [], "constraints": []} return result - def _build_section_prompt( - self, - section: Dict[str, str], - context: ModuleFeatureContext - ) -> str: + def _build_section_prompt(self, section: Dict[str, str], context: ModuleFeatureContext) -> str: """ Build the analysis prompt for a single section. @@ -316,7 +299,7 @@ def _build_section_prompt( Returns: Formatted prompt string """ - prompt = f"Extract requirements from this specification section.\n\n" + prompt = "Extract requirements from this specification section.\n\n" # Add project context if context.project_name: @@ -324,7 +307,7 @@ def _build_section_prompt( # Add section prompt += f"## {section['title']}\n\n" - prompt += section['content'] + prompt += section["content"] prompt += "\n\n" prompt += "Extract all implementable requirements from THIS SECTION. 
" @@ -334,8 +317,7 @@ def _build_section_prompt( async def create_spec_analyzer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None + model_client: ChatCompletionClient, project_id: Optional[str] = None ) -> SpecAnalyzerAgent: """ Factory function to create a Spec Analyzer Agent. diff --git a/backend/app/agents/module_feature/types.py b/backend/app/agents/module_feature/types.py index 84054e1..d4921c9 100644 --- a/backend/app/agents/module_feature/types.py +++ b/backend/app/agents/module_feature/types.py @@ -16,22 +16,24 @@ from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Dict, Any +from typing import Any, Dict, List, Optional from uuid import UUID - # ============================ # Enums # ============================ + class ModuleCategory(str, Enum): """Category of module based on how it was derived.""" + PHASE = "phase" # Maps to an implementation phase from the prompt plan CROSS_CUTTING = "cross_cutting" # Cross-cutting concern (config, testing, etc.) class FeaturePriorityLevel(str, Enum): """Priority level for features.""" + MUST_HAVE = "must_have" IMPORTANT = "important" OPTIONAL = "optional" @@ -39,6 +41,7 @@ class FeaturePriorityLevel(str, Enum): class FeatureCategoryType(str, Enum): """Category of feature based on technical domain.""" + DATA_MODEL = "Data Model" API = "API" UI = "UI" @@ -56,11 +59,13 @@ class FeatureCategoryType(str, Enum): # Input Context # ============================ + @dataclass class ImageAttachmentInfo: """ Lightweight image info for the LLM to understand available images. """ + id: str # UUID of the image filename: str description: str = "" # Optional description if available @@ -74,6 +79,7 @@ class ModuleFeatureContext: Contains all the information needed to extract modules and features from a finalized specification AND prompt plan. """ + project_id: UUID brainstorming_phase_id: UUID @@ -106,11 +112,13 @@ class ModuleFeatureContext: # Agent 1: Spec Analyzer Types # ============================ + @dataclass class SpecRequirement: """ A single requirement extracted from the specification. """ + id: str # Unique ID for traceability (e.g., "REQ-001") section_id: str # Source section ID (e.g., "functional_requirements") section_title: str @@ -125,6 +133,7 @@ class SpecAnalysis: Output from the Spec Analyzer Agent. Parsed specification with extracted requirements by domain. """ + requirements: List[SpecRequirement] = field(default_factory=list) domain_areas: List[str] = field(default_factory=list) data_models: List[str] = field(default_factory=list) @@ -142,11 +151,13 @@ def __post_init__(self): # Agent 2: Plan Structurer Types # ============================ + @dataclass class ImplementationStep: """ A single step within an implementation phase. """ + step_id: str # Unique ID (e.g., "P1-S1" for Phase 1, Step 1) title: str description: str @@ -159,6 +170,7 @@ class ImplementationPhase: """ An implementation phase from the prompt plan. """ + phase_index: int # 1-based index title: str objective: str @@ -172,6 +184,7 @@ class PlanStructure: Output from the Plan Structurer Agent. Parsed prompt plan with implementation phases and steps. """ + phases: List[ImplementationPhase] = field(default_factory=list) total_steps: int = 0 cross_cutting_concerns: List[str] = field(default_factory=list) # Config, testing, etc. 
@@ -185,11 +198,13 @@ def __post_init__(self): # Agent 3: Merger Types # ============================ + @dataclass class FeatureMapping: """ Mapping of a feature to its sources in spec and plan. """ + feature_id: str # Unique ID (e.g., "FEAT-001") title: str module_id: str # Which module this belongs to @@ -213,6 +228,7 @@ class ModuleMapping: """ Mapping of a module to its source phase. """ + module_id: str # Unique ID (e.g., "MOD-001") title: str description: str @@ -228,6 +244,7 @@ class MergedMapping: Output from the Merger Agent. Aligned modules and features from spec and plan. """ + modules: List[ModuleMapping] = field(default_factory=list) features: List[FeatureMapping] = field(default_factory=list) @@ -240,12 +257,14 @@ class MergedMapping: # Agent 4: Content Writer Types # ============================ + @dataclass class FeatureContent: """ Rich content for a single feature. Generated by the Content Writer Agent. """ + feature_id: str description: str # User-friendly description of what the feature does spec_text: str # WHAT to build - requirements, acceptance criteria @@ -260,6 +279,7 @@ class WriterOutput: Output from the Content Writer Agent. All features with their content populated. """ + feature_contents: List[FeatureContent] = field(default_factory=list) @@ -267,12 +287,14 @@ class WriterOutput: # Agent 5: Coverage Validator Types # ============================ + @dataclass class CoverageReport: """ Quality assurance report from the Coverage Validator Agent. Validates completeness and ordering. """ + ok: bool # Coverage issues @@ -300,12 +322,14 @@ class CoverageReport: # Final Output Types # ============================ + @dataclass class ExtractedFeature: """ A fully extracted feature with all content. Final output for persistence. """ + title: str description: str # User-friendly description of what the feature does spec_text: str # WHAT: requirements, acceptance criteria @@ -324,6 +348,7 @@ class ExtractedModule: A fully extracted module with all features. Final output for persistence. """ + title: str description: str order_index: int @@ -338,6 +363,7 @@ class ExtractionResult: Final result from the module/feature extraction process. Includes all modules, features, and validation report. """ + modules: List[ExtractedModule] = field(default_factory=list) total_features: int = 0 coverage_report: Optional[CoverageReport] = None @@ -352,12 +378,14 @@ def __post_init__(self): # Agent Metadata for UI # ============================ + @dataclass class AgentInfo: """ UI metadata for an agent in the extraction workflow. Used for progress tracking and visual representation. 
""" + name: str description: str color: str # Hex color for UI tag @@ -368,46 +396,38 @@ class AgentInfo: "orchestrator": AgentInfo( name="Orchestrator", description="Coordinating the module/feature extraction workflow", - color="#8B5CF6" # Purple + color="#8B5CF6", # Purple ), "spec_analyzer": AgentInfo( name="Spec Analyzer", description="Parsing specification to extract requirements", - color="#3B82F6" # Blue + color="#3B82F6", # Blue ), "plan_structurer": AgentInfo( name="Plan Structurer", description="Parsing prompt plan to extract implementation phases", - color="#10B981" # Green + color="#10B981", # Green ), "merger": AgentInfo( name="Merger", description="Aligning spec requirements with plan phases", - color="#F59E0B" # Amber + color="#F59E0B", # Amber ), "writer": AgentInfo( name="Content Writer", description="Generating spec_text and prompt_plan_text for features", - color="#EF4444" # Red + color="#EF4444", # Red ), "validator": AgentInfo( name="Validator", description="Validating coverage and ordering", - color="#EC4899" # Pink + color="#EC4899", # Pink ), } # Workflow step definitions for progress tracking -WORKFLOW_STEPS = [ - "start", - "analyzing_spec", - "structuring_plan", - "merging", - "writing_content", - "validating", - "complete" -] +WORKFLOW_STEPS = ["start", "analyzing_spec", "structuring_plan", "merging", "writing_content", "validating", "complete"] # ============================ @@ -437,6 +457,7 @@ class AgentInfo: # Helper Functions # ============================ + def get_module_by_id(modules: List[ModuleMapping], module_id: str) -> Optional[ModuleMapping]: """Get a module mapping by its ID.""" for module in modules: @@ -508,7 +529,9 @@ def validate_extraction_result(result: ExtractionResult) -> List[str]: issues.append(f"Feature '{feature.title}' has spec_text too short ({len(feature.spec_text)} chars)") if len(feature.prompt_plan_text) < MIN_PROMPT_PLAN_TEXT_LENGTH: - issues.append(f"Feature '{feature.title}' has prompt_plan_text too short ({len(feature.prompt_plan_text)} chars)") + issues.append( + f"Feature '{feature.title}' has prompt_plan_text too short ({len(feature.prompt_plan_text)} chars)" + ) # Check for duplicate module titles titles = [m.title.lower() for m in result.modules] diff --git a/backend/app/agents/module_feature/utils.py b/backend/app/agents/module_feature/utils.py index cfbcd00..9615167 100644 --- a/backend/app/agents/module_feature/utils.py +++ b/backend/app/agents/module_feature/utils.py @@ -4,16 +4,16 @@ Provides JSON parsing, text processing, and markdown extraction utilities. 
""" -import re import json import logging +import re import traceback from typing import Any, Dict, List, Optional from json_repair import repair_json # Import from common module and re-export for backwards compatibility -from app.agents.response_parser import strip_markdown_json, normalize_response_content +from app.agents.response_parser import strip_markdown_json # Get logger for JSON repair warnings _repair_logger = logging.getLogger("module_feature.json_repair") @@ -87,7 +87,7 @@ def repair_truncated_json(text: str) -> str: escape_next = False continue - if char == '\\': + if char == "\\": escape_next = True continue @@ -98,15 +98,15 @@ def repair_truncated_json(text: str) -> str: if in_string: continue - if char == '{': - stack.append('}') - elif char == '[': - stack.append(']') - elif char == '}': - if stack and stack[-1] == '}': + if char == "{": + stack.append("}") + elif char == "[": + stack.append("]") + elif char == "}": + if stack and stack[-1] == "}": stack.pop() - elif char == ']': - if stack and stack[-1] == ']': + elif char == "]": + if stack and stack[-1] == "]": stack.pop() # If we're inside a string, close it first @@ -138,7 +138,7 @@ def fix_common_json_errors(text: str) -> str: """ # Remove trailing commas before closing brackets/braces # Pattern: comma followed by optional whitespace then ] or } - text = re.sub(r',(\s*[\]}])', r'\1', text) + text = re.sub(r",(\s*[\]}])", r"\1", text) # Fix missing colons after object keys # Pattern: "key" followed by whitespace then another "value" without a colon @@ -164,7 +164,7 @@ def fix_common_json_errors(text: str) -> str: i += 1 continue - if char == '\\': + if char == "\\": escape_next = True result.append(char) i += 1 @@ -178,15 +178,15 @@ def fix_common_json_errors(text: str) -> str: if in_string: # Check for control characters that need escaping - if char == '\n': - result.append('\\n') - elif char == '\r': - result.append('\\r') - elif char == '\t': - result.append('\\t') + if char == "\n": + result.append("\\n") + elif char == "\r": + result.append("\\r") + elif char == "\t": + result.append("\\t") elif ord(char) < 32: # Other control characters - escape as unicode - result.append(f'\\u{ord(char):04x}') + result.append(f"\\u{ord(char):04x}") else: result.append(char) else: @@ -194,7 +194,7 @@ def fix_common_json_errors(text: str) -> str: i += 1 - return ''.join(result) + return "".join(result) def _log_json_repair_banner(strategy: str, text_preview: str, text_length: int, extra_info: str = ""): @@ -210,10 +210,10 @@ def _log_json_repair_banner(strategy: str, text_preview: str, text_length: int, ║ Strategy: {strategy:<67}║ ║ Text length: {text_length:<64}║ ║ Preview: {text_preview[:60]:<68}║ -{f'║ Info: {extra_info:<71}║' if extra_info else ''} +{f"║ Info: {extra_info:<71}║" if extra_info else ""} ╠══════════════════════════════════════════════════════════════════════════════╣ ║ CALL STACK: ║ -{chr(10).join(f'║ {line:<77}║' for line in caller_info.split(chr(10)))} +{chr(10).join(f"║ {line:<77}║" for line in caller_info.split(chr(10)))} ╚══════════════════════════════════════════════════════════════════════════════╝ """ _repair_logger.warning(banner) @@ -255,18 +255,18 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: if isinstance(repaired, dict): _log_json_repair_banner( "Strategy 2: json-repair library", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Initial error: {initial_error[:50]}" + f"Initial error: {initial_error[:50]}", ) return repaired elif 
isinstance(repaired, str): result = json.loads(repaired) _log_json_repair_banner( "Strategy 2: json-repair library (string result)", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Initial error: {initial_error[:50]}" + f"Initial error: {initial_error[:50]}", ) return result except Exception: @@ -278,9 +278,9 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: result = json.loads(fixed_and_repaired) _log_json_repair_banner( "Strategy 3: fix_common_errors + repair_truncated", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Initial error: {initial_error[:50]}" + f"Initial error: {initial_error[:50]}", ) return result except json.JSONDecodeError: @@ -290,7 +290,7 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: # Sometimes LLMs add extra text before or after the JSON try: # Find the first { and match to its closing } - start = cleaned.find('{') + start = cleaned.find("{") if start != -1: # Track braces to find matching close depth = 0 @@ -302,16 +302,16 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: if escape_next: escape_next = False continue - if char == '\\': + if char == "\\": escape_next = True continue if char == '"' and not escape_next: in_string = not in_string continue if not in_string: - if char == '{': + if char == "{": depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0: end = i + 1 @@ -325,9 +325,9 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: if isinstance(repaired, dict): _log_json_repair_banner( "Strategy 4: Extract JSON + json-repair", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Extracted {end - start} chars from position {start}" + f"Extracted {end - start} chars from position {start}", ) return repaired except Exception: @@ -335,9 +335,9 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: result = json.loads(extracted) _log_json_repair_banner( "Strategy 4: Extract JSON object", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Extracted {end - start} chars from position {start}" + f"Extracted {end - start} chars from position {start}", ) return result except json.JSONDecodeError: @@ -346,9 +346,9 @@ def parse_json_with_repair(text: str) -> Dict[str, Any]: # All strategies failed - raise with the original cleaned text _log_json_repair_banner( "ALL STRATEGIES FAILED - raising exception", - cleaned[:100].replace('\n', '\\n'), + cleaned[:100].replace("\n", "\\n"), len(cleaned), - f"Initial error: {initial_error[:50]}" + f"Initial error: {initial_error[:50]}", ) return json.loads(cleaned) @@ -368,7 +368,7 @@ def truncate_text(text: str, max_length: int = 500) -> str: return text # Truncate at word boundary - truncated = text[:max_length].rsplit(' ', 1)[0] + truncated = text[:max_length].rsplit(" ", 1)[0] return truncated + "..." 
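
For context on the helpers reformatted in these hunks: a minimal usage sketch of the repair ladder, assuming app.agents.module_feature.utils is importable from the backend package. The payload strings are made up, and the hunks above show some of these functions only in part, so treat the expected outputs as illustrative rather than guaranteed.

    import json

    from app.agents.module_feature.utils import (
        fix_common_json_errors,
        parse_json_with_repair,
        repair_truncated_json,
    )

    # Truncated LLM output: the closing brace and bracket are missing.
    # repair_truncated_json walks a bracket stack and appends the closers,
    # e.g. ...\"Login\"}]} for the input below.
    truncated = '{"features": [{"feature_id": "FEAT-001", "title": "Login"'
    print(repair_truncated_json(truncated))

    # A raw newline inside a string plus a trailing comma before the brace.
    # fix_common_json_errors escapes the control character and drops the comma.
    messy = '{"title": "Login", "spec_text": "line one\nline two",}'
    assert json.loads(fix_common_json_errors(messy))["spec_text"] == "line one\nline two"

    # parse_json_with_repair tries the strategies in order: plain json.loads,
    # the json-repair library, common-error fixes plus truncation repair, then
    # brace extraction, before re-raising the original parse error.
    assert parse_json_with_repair('```json\n{"ok": true}\n```') == {"ok": True}
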
@@ -383,7 +383,7 @@ def normalize_whitespace(text: str) -> str: Normalized text """ # Replace multiple whitespace with single space - text = re.sub(r'\s+', ' ', text) + text = re.sub(r"\s+", " ", text) return text.strip() @@ -401,18 +401,20 @@ def extract_markdown_sections(markdown: str) -> List[Dict[str, str]]: current_section = None current_content = [] - for line in markdown.split('\n'): + for line in markdown.split("\n"): # Check for ## header - header_match = re.match(r'^##\s+(.+)$', line) + header_match = re.match(r"^##\s+(.+)$", line) if header_match: # Save previous section if exists if current_section is not None: - sections.append({ - 'id': slugify(current_section), - 'title': current_section, - 'content': '\n'.join(current_content).strip() - }) + sections.append( + { + "id": slugify(current_section), + "title": current_section, + "content": "\n".join(current_content).strip(), + } + ) current_section = header_match.group(1).strip() current_content = [] @@ -421,11 +423,9 @@ def extract_markdown_sections(markdown: str) -> List[Dict[str, str]]: # Save final section if current_section is not None: - sections.append({ - 'id': slugify(current_section), - 'title': current_section, - 'content': '\n'.join(current_content).strip() - }) + sections.append( + {"id": slugify(current_section), "title": current_section, "content": "\n".join(current_content).strip()} + ) return sections @@ -451,13 +451,13 @@ def slugify(text: str) -> str: # Convert to lowercase text = text.lower() # Replace spaces with underscores - text = re.sub(r'\s+', '_', text) + text = re.sub(r"\s+", "_", text) # Remove non-alphanumeric characters (except underscores) - text = re.sub(r'[^a-z0-9_]', '', text) + text = re.sub(r"[^a-z0-9_]", "", text) # Remove consecutive underscores - text = re.sub(r'_+', '_', text) + text = re.sub(r"_+", "_", text) # Remove leading/trailing underscores - text = text.strip('_') + text = text.strip("_") return text @@ -480,10 +480,7 @@ def extract_phase_steps_from_markdown(markdown: str) -> List[Dict[str, Any]]: current_steps = [] # Patterns for phase headers - phase_pattern = re.compile( - r'^(?:#{1,4}\s*)?(?:Phase\s*)?(\d+)[.:]\s*(.+?)$', - re.IGNORECASE | re.MULTILINE - ) + phase_pattern = re.compile(r"^(?:#{1,4}\s*)?(?:Phase\s*)?(\d+)[.:]\s*(.+?)$", re.IGNORECASE | re.MULTILINE) # Split by phase headers parts = phase_pattern.split(markdown) @@ -500,11 +497,7 @@ def extract_phase_steps_from_markdown(markdown: str) -> List[Dict[str, Any]]: # Extract steps from content steps = extract_steps_from_content(phase_content) - phases.append({ - 'index': phase_index, - 'title': phase_title, - 'steps': steps - }) + phases.append({"index": phase_index, "title": phase_title, "steps": steps}) i += 3 @@ -526,17 +519,16 @@ def extract_steps_from_content(content: str) -> List[Dict[str, str]]: steps = [] # Pattern for list items (numbered or bulleted) - item_pattern = re.compile( - r'^(?:\d+[.)]|\*|-)\s+(.+?)$', - re.MULTILINE - ) + item_pattern = re.compile(r"^(?:\d+[.)]|\*|-)\s+(.+?)$", re.MULTILINE) for match in item_pattern.finditer(content): step_text = match.group(1).strip() - steps.append({ - 'title': step_text, - 'description': '' # Could be enhanced to capture following lines - }) + steps.append( + { + "title": step_text, + "description": "", # Could be enhanced to capture following lines + } + ) return steps @@ -583,14 +575,14 @@ def generate_semantic_id(prefix: str, title: str, index: int, max_len: int = 20) return f"{prefix}-{index:03d}" # Convert to slug: uppercase, alphanumeric + hyphens only - 
slug = re.sub(r'[^a-zA-Z0-9]+', '-', title.strip()) - slug = slug.strip('-').upper() + slug = re.sub(r"[^a-zA-Z0-9]+", "-", title.strip()) + slug = slug.strip("-").upper() # Truncate if too long if len(slug) > max_len: # Try to break at a hyphen truncated = slug[:max_len] - last_hyphen = truncated.rfind('-') + last_hyphen = truncated.rfind("-") if last_hyphen > max_len // 2: slug = truncated[:last_hyphen] else: @@ -613,7 +605,7 @@ def chunk_list(items: List[Any], chunk_size: int) -> List[List[Any]]: Returns: List of chunks """ - return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)] + return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)] def merge_markdown_sections(sections: List[Dict[str, str]]) -> str: @@ -630,6 +622,6 @@ def merge_markdown_sections(sections: List[Dict[str, str]]) -> str: for section in sections: lines.append(f"## {section['title']}") lines.append("") - lines.append(section['content']) + lines.append(section["content"]) lines.append("") return "\n".join(lines) diff --git a/backend/app/agents/module_feature/validator.py b/backend/app/agents/module_feature/validator.py index 4b9b020..ad9355b 100644 --- a/backend/app/agents/module_feature/validator.py +++ b/backend/app/agents/module_feature/validator.py @@ -7,16 +7,16 @@ from typing import List, Optional, Set +from .logging_config import get_agent_logger from .types import ( - SpecAnalysis, - PlanStructure, + MIN_PROMPT_PLAN_TEXT_LENGTH, + MIN_SPEC_TEXT_LENGTH, + CoverageReport, MergedMapping, + PlanStructure, + SpecAnalysis, WriterOutput, - CoverageReport, - MIN_SPEC_TEXT_LENGTH, - MIN_PROMPT_PLAN_TEXT_LENGTH, ) -from .logging_config import get_agent_logger class ValidatorAgent: @@ -68,23 +68,14 @@ async def validate( try: # Build content lookup - content_map = { - fc.feature_id: fc - for fc in writer_output.feature_contents - } + content_map = {fc.feature_id: fc for fc in writer_output.feature_contents} # Validate coverage - uncovered_requirements = self._check_requirement_coverage( - spec_analysis, merged_mapping - ) - uncovered_steps = self._check_step_coverage( - plan_structure, merged_mapping - ) + uncovered_requirements = self._check_requirement_coverage(spec_analysis, merged_mapping) + uncovered_steps = self._check_step_coverage(plan_structure, merged_mapping) # Validate ordering - ordering_issues = self._check_ordering( - plan_structure, merged_mapping - ) + ordering_issues = self._check_ordering(plan_structure, merged_mapping) # Content quality checks removed - relying on prompt-based prevention # in Writer agent (TBD markers) rather than post-hoc pattern matching @@ -115,10 +106,7 @@ async def validate( must_have_without_content.append(feature.feature_id) # Calculate coverage percentage - total_items = ( - len(spec_analysis.requirements) + - plan_structure.total_steps - ) + total_items = len(spec_analysis.requirements) + plan_structure.total_steps covered_items = total_items - len(uncovered_requirements) - len(uncovered_steps) coverage_percentage = (covered_items / total_items * 100) if total_items > 0 else 100 @@ -134,8 +122,7 @@ async def validate( # Determine if validation passed # OK if: no must_have requirements uncovered, no must_have features missing content must_have_reqs_uncovered = [ - req_id for req_id in uncovered_requirements - if self._is_must_have_requirement(spec_analysis, req_id) + req_id for req_id in uncovered_requirements if self._is_must_have_requirement(spec_analysis, req_id) ] ok = len(must_have_reqs_uncovered) == 0 and 
len(must_have_without_content) == 0 @@ -357,26 +344,20 @@ def _generate_suggestions( if uncovered_requirements: suggestions.append( f"Consider adding features for {len(uncovered_requirements)} uncovered requirements: " - f"{', '.join(uncovered_requirements[:5])}" - + ("..." if len(uncovered_requirements) > 5 else "") + f"{', '.join(uncovered_requirements[:5])}" + ("..." if len(uncovered_requirements) > 5 else "") ) if uncovered_steps: suggestions.append( f"Consider adding features for {len(uncovered_steps)} uncovered plan steps: " - f"{', '.join(uncovered_steps[:5])}" - + ("..." if len(uncovered_steps) > 5 else "") + f"{', '.join(uncovered_steps[:5])}" + ("..." if len(uncovered_steps) > 5 else "") ) if ordering_issues: - suggestions.append( - f"Review feature ordering to ensure phase dependencies are respected" - ) + suggestions.append("Review feature ordering to ensure phase dependencies are respected") if empty_spec_text: - suggestions.append( - f"Enhance spec_text for {len(empty_spec_text)} features with insufficient content" - ) + suggestions.append(f"Enhance spec_text for {len(empty_spec_text)} features with insufficient content") if empty_prompt_plan_text: suggestions.append( @@ -389,9 +370,7 @@ def _generate_suggestions( return suggestions -async def create_validator( - project_id: Optional[str] = None -) -> ValidatorAgent: +async def create_validator(project_id: Optional[str] = None) -> ValidatorAgent: """ Factory function to create a Validator Agent. diff --git a/backend/app/agents/module_feature/writer.py b/backend/app/agents/module_feature/writer.py index 3e46d2e..f65b86d 100644 --- a/backend/app/agents/module_feature/writer.py +++ b/backend/app/agents/module_feature/writer.py @@ -7,24 +7,24 @@ import asyncio import json -from typing import List, Optional, Tuple +from typing import List, Optional from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken from autogen_core.models import ChatCompletionClient +from .logging_config import get_agent_logger from .types import ( + FeatureContent, + FeatureMapping, + MergedMapping, ModuleFeatureContext, - SpecAnalysis, PlanStructure, - MergedMapping, - FeatureMapping, - FeatureContent, + SpecAnalysis, WriterOutput, ) -from .logging_config import get_agent_logger -from .utils import strip_markdown_json, chunk_list +from .utils import chunk_list, strip_markdown_json class WriterAgent: @@ -131,10 +131,7 @@ async def write_all( Raises: ValueError: If writing fails """ - self.logger.log_agent_start( - project_id=str(context.project_id), - features_count=len(merged_mapping.features) - ) + self.logger.log_agent_start(project_id=str(context.project_id), features_count=len(merged_mapping.features)) try: # Process features individually (1 per call) in PARALLEL @@ -149,12 +146,7 @@ async def write_all( # Create tasks for parallel processing tasks = [ self._write_chunk_with_index( - chunk_idx, - feature_chunk, - spec_analysis, - plan_structure, - context, - len(feature_chunks) + chunk_idx, feature_chunk, spec_analysis, plan_structure, context, len(feature_chunks) ) for chunk_idx, feature_chunk in enumerate(feature_chunks) ] @@ -166,18 +158,17 @@ async def write_all( all_contents: List[FeatureContent] = [] for chunk_idx, result in enumerate(chunk_results): if isinstance(result, Exception): - self.logger.log_error(result, { - "chunk_index": chunk_idx, - "error": str(result) - }) + self.logger.log_error(result, {"chunk_index": chunk_idx, "error": str(result)}) # 
Generate fallback content for failed chunks for feature in feature_chunks[chunk_idx]: - all_contents.append(FeatureContent( - feature_id=feature.feature_id, - description=f"This feature enables {feature.title.lower()} functionality for users.", - spec_text=f"## {feature.title}\n\nImplement this feature as specified.", - prompt_plan_text=f"## Implementation\n\n1. Implement {feature.title}\n2. Write tests\n3. Verify functionality", - )) + all_contents.append( + FeatureContent( + feature_id=feature.feature_id, + description=f"This feature enables {feature.title.lower()} functionality for users.", + spec_text=f"## {feature.title}\n\nImplement this feature as specified.", + prompt_plan_text=f"## Implementation\n\n1. Implement {feature.title}\n2. Write tests\n3. Verify functionality", + ) + ) else: all_contents.extend(result) @@ -194,9 +185,7 @@ async def write_all( avg_prompt_plan_text_length=avg_plan_len, ) - self.logger.log_agent_complete( - features_written=len(all_contents) - ) + self.logger.log_agent_complete(features_written=len(all_contents)) return WriterOutput(feature_contents=all_contents) @@ -231,12 +220,10 @@ async def _write_chunk_with_index( prompt=f"Processing chunk {chunk_idx + 1}/{total_chunks} ({len(features)} features)", model=str(self.model_client), operation="write_feature_content", - chunk_size=len(features) + chunk_size=len(features), ) - return await self._write_chunk( - features, spec_analysis, plan_structure, context - ) + return await self._write_chunk(features, spec_analysis, plan_structure, context) async def _write_chunk( self, @@ -268,10 +255,7 @@ async def _write_chunk( model_client=self.model_client, ) - response = await chunk_agent.on_messages( - [TextMessage(content=prompt, source="user")], - CancellationToken() - ) + response = await chunk_agent.on_messages([TextMessage(content=prompt, source="user")], CancellationToken()) response_text = response.chat_message.content if isinstance(response_text, list): @@ -283,11 +267,14 @@ async def _write_chunk( try: result_data = json.loads(cleaned) except json.JSONDecodeError as e: - self.logger.log_error(e, { - "response_length": len(response_text), - "response_preview": cleaned[:300], - "error_position": e.pos if hasattr(e, 'pos') else None - }) + self.logger.log_error( + e, + { + "response_length": len(response_text), + "response_preview": cleaned[:300], + "error_position": e.pos if hasattr(e, "pos") else None, + }, + ) # Return fallback content on parse failure return [ FeatureContent( @@ -307,7 +294,7 @@ async def _write_chunk( "result_keys": list(result_data.keys()) if isinstance(result_data, dict) else None, "response_preview": response_text[:500], "response_length": len(response_text), - } + }, ) # Convert to FeatureContent objects @@ -322,21 +309,18 @@ async def _write_chunk( else: self.logger.log_error( ValueError(f"Invalid feature entry: expected dict with feature_id, got {type(f).__name__}"), - {"raw_value": str(f)[:200] if f else None} + {"raw_value": str(f)[:200] if f else None}, ) # Log if we had to filter out invalid entries if len(valid_features) < len(raw_features): self.logger.log_error( ValueError(f"Filtered out {len(raw_features) - len(valid_features)} invalid feature entries"), - {"total_raw": len(raw_features), "valid_count": len(valid_features)} + {"total_raw": len(raw_features), "valid_count": len(valid_features)}, ) # Create a map for easy lookup - content_map = { - f.get("feature_id", ""): f - for f in valid_features - } + content_map = {f.get("feature_id", ""): f for f in valid_features} for 
feature in features: if feature.feature_id in content_map: @@ -347,22 +331,30 @@ async def _write_chunk( raw_image_ids = [] relevant_image_ids = [str(img_id) for img_id in raw_image_ids if img_id] - contents.append(FeatureContent( - feature_id=feature.feature_id, - description=data.get("description", f"This feature enables {feature.title.lower()} functionality for users."), - spec_text=data.get("spec_text", f"## {feature.title}\n\nImplement as specified."), - prompt_plan_text=data.get("prompt_plan_text", f"## Implementation\n\n1. Implement {feature.title}"), - relevant_image_ids=relevant_image_ids, - )) + contents.append( + FeatureContent( + feature_id=feature.feature_id, + description=data.get( + "description", f"This feature enables {feature.title.lower()} functionality for users." + ), + spec_text=data.get("spec_text", f"## {feature.title}\n\nImplement as specified."), + prompt_plan_text=data.get( + "prompt_plan_text", f"## Implementation\n\n1. Implement {feature.title}" + ), + relevant_image_ids=relevant_image_ids, + ) + ) else: # Fallback for missing features - contents.append(FeatureContent( - feature_id=feature.feature_id, - description=f"This feature enables {feature.title.lower()} functionality for users.", - spec_text=f"## {feature.title}\n\nImplement this feature according to the specification.", - prompt_plan_text=f"## Implementation\n\n1. Implement {feature.title}\n2. Write unit tests\n3. Verify functionality", - relevant_image_ids=[], - )) + contents.append( + FeatureContent( + feature_id=feature.feature_id, + description=f"This feature enables {feature.title.lower()} functionality for users.", + spec_text=f"## {feature.title}\n\nImplement this feature according to the specification.", + prompt_plan_text=f"## Implementation\n\n1. Implement {feature.title}\n2. Write unit tests\n3. Verify functionality", + relevant_image_ids=[], + ) + ) return contents @@ -400,7 +392,7 @@ def _build_prompt( prompt += "\n" # Add pending clarification topics if available - if hasattr(context, 'topics_pending_clarification') and context.topics_pending_clarification: + if hasattr(context, "topics_pending_clarification") and context.topics_pending_clarification: prompt += "## PENDING CLARIFICATION TOPICS:\n\n" prompt += "The following topics were not fully decided in brainstorming. Do NOT invent solutions:\n" for topic in context.topics_pending_clarification: @@ -470,10 +462,7 @@ def _build_prompt( return prompt -async def create_writer( - model_client: ChatCompletionClient, - project_id: Optional[str] = None -) -> WriterAgent: +async def create_writer(model_client: ChatCompletionClient, project_id: Optional[str] = None) -> WriterAgent: """ Factory function to create a Writer Agent. diff --git a/backend/app/agents/project_chat_assistant/__init__.py b/backend/app/agents/project_chat_assistant/__init__.py index 5857d20..4c109b0 100644 --- a/backend/app/agents/project_chat_assistant/__init__.py +++ b/backend/app/agents/project_chat_assistant/__init__.py @@ -7,17 +7,17 @@ description when ready. 
""" -from .types import ( - ProjectChatContext, - ProjectChatAssistantResponse, - MCQOption, - ExistingPhaseInfo, -) -from .assistant import ProjectChatAssistant, SYSTEM_PROMPT +from .assistant import SYSTEM_PROMPT, ProjectChatAssistant from .orchestrator import ( - load_context, generate_response, handle_user_message, + load_context, +) +from .types import ( + ExistingPhaseInfo, + MCQOption, + ProjectChatAssistantResponse, + ProjectChatContext, ) __all__ = [ diff --git a/backend/app/agents/project_chat_assistant/assistant.py b/backend/app/agents/project_chat_assistant/assistant.py index 2e2a0a3..8e9bdf3 100644 --- a/backend/app/agents/project_chat_assistant/assistant.py +++ b/backend/app/agents/project_chat_assistant/assistant.py @@ -8,28 +8,24 @@ import json import logging import re -from typing import Awaitable, Callable, List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional from autogen_agentchat.agents import AssistantAgent -from autogen_agentchat.messages import StructuredMessage, TextMessage +from autogen_agentchat.messages import TextMessage from autogen_core import CancellationToken -from pydantic import ValidationError -from autogen_core.models import ChatCompletionClient, UserMessage, AssistantMessage, LLMMessage from autogen_core.model_context import TokenLimitedChatCompletionContext +from autogen_core.models import AssistantMessage, ChatCompletionClient, LLMMessage, UserMessage -from .types import ( - ProjectChatContext, - ProjectChatAssistantResponse, - ContainerContext, - MCQOption, - ExistingPhaseInfo, - ExistingModuleInfo, -) from app.agents.collab_thread_assistant.web_search_parser import has_web_search_block from app.agents.response_parser import ( - strip_markdown_json, - normalize_response_content, extract_json_from_text, + strip_markdown_json, +) + +from .types import ( + ContainerContext, + ProjectChatAssistantResponse, + ProjectChatContext, ) if TYPE_CHECKING: @@ -350,7 +346,9 @@ def _build_grounding_section(grounding_context: str) -> str: sections.append("") sections.append(grounding_context) sections.append("") - sections.append("Use this context to understand the project's technical foundation when helping users define what to build.") + sections.append( + "Use this context to understand the project's technical foundation when helping users define what to build." + ) return "\n".join(sections) @@ -360,7 +358,7 @@ def _build_extension_context_section(container_ctx: ContainerContext) -> str: lines = [ "### EXTENSION MODE", "", - f"You are helping create an **extension** for the container **\"{container_ctx.container_title}\"**.", + f'You are helping create an **extension** for the container **"{container_ctx.container_title}"**.', "", "An extension explores a new aspect that builds upon or complements the initial spec.", "Focus on what NEW aspect extends the initial spec - avoid duplicating what's already covered.", @@ -447,9 +445,7 @@ def _build_context_section(context: ProjectChatContext) -> str: ) sections.append(exploration_info) else: - sections.append( - "This is a GREENFIELD project - starting fresh with no existing codebase." 
- ) + sections.append("This is a GREENFIELD project - starting fresh with no existing codebase.") # Existing phases if context.existing_phases: @@ -472,7 +468,7 @@ def _build_context_section(context: ProjectChatContext) -> str: for decision in phase.decisions[:5]: # Limit to 5 decisions per phase decision_lines.append( f"- [{decision.aspect_title}] {decision.question_title}: " - f"\"{decision.decision_summary_short}\"" + f'"{decision.decision_summary_short}"' ) if len(decision_lines) > 1: # Has actual decisions sections.append("\n".join(decision_lines)) @@ -480,10 +476,7 @@ def _build_context_section(context: ProjectChatContext) -> str: if cross_ctx.project_features: feature_lines = ["Project-level feature decisions:"] for feat in cross_ctx.project_features[:5]: # Limit to 5 - feature_lines.append( - f"- [{feat.module_title}] {feat.feature_title}: " - f"\"{feat.decision_summary_short}\"" - ) + feature_lines.append(f'- [{feat.module_title}] {feat.feature_title}: "{feat.decision_summary_short}"') sections.append("\n".join(feature_lines)) # Extension mode context @@ -728,9 +721,7 @@ def _parse_response(self, response_text: str) -> tuple[ProjectChatAssistantRespo try: inner_data = json.loads(inner_json_str) if isinstance(inner_data, dict) and "reply_text" in inner_data: - logger.warning( - "Detected JSON in reply_text field - extracting inner reply_text" - ) + logger.warning("Detected JSON in reply_text field - extracting inner reply_text") # Use the inner data instead data = inner_data except (json.JSONDecodeError, TypeError): @@ -739,7 +730,11 @@ def _parse_response(self, response_text: str) -> tuple[ProjectChatAssistantRespo # Fallback: Detect exploration intent when JSON parsing failed or was incomplete # This catches cases where the agent says "Let me explore..." 
without proper JSON # Skip if web search is requested (via JSON field or block) - "let me search" should NOT trigger code exploration - if not data.get("wants_code_exploration") and not data.get("wants_web_search") and not has_web_search_block(response_text): + if ( + not data.get("wants_code_exploration") + and not data.get("wants_web_search") + and not has_web_search_block(response_text) + ): reply_text = data.get("reply_text", response_text).lower() original_reply = data.get("reply_text", response_text) @@ -870,9 +865,7 @@ async def generate_response( # While adding: smart token-aware trimming from the middle initial_messages = None if context.recent_messages: - initial_messages = self._convert_messages_to_llm_format( - context.recent_messages - ) + initial_messages = self._convert_messages_to_llm_format(context.recent_messages) # Only send the current user message - history is in the model_context user_message = context.user_message @@ -894,10 +887,7 @@ async def generate_response( # Set agent name for call logging if self.llm_call_logger: retry_suffix = f" (retry {attempt})" if attempt > 0 else "" - self.llm_call_logger.set_agent( - "project_chat", - f"Pre-Phase Discussion Assistant{retry_suffix}" - ) + self.llm_call_logger.set_agent("project_chat", f"Pre-Phase Discussion Assistant{retry_suffix}") try: # Always append schema as suffix - it's the last thing the LLM sees diff --git a/backend/app/agents/project_chat_assistant/orchestrator.py b/backend/app/agents/project_chat_assistant/orchestrator.py index fcd625b..799f241 100644 --- a/backend/app/agents/project_chat_assistant/orchestrator.py +++ b/backend/app/agents/project_chat_assistant/orchestrator.py @@ -6,33 +6,37 @@ """ import logging -from typing import Any, Awaitable, Callable, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional from uuid import UUID +from autogen_core.models import ChatCompletionClient from sqlalchemy import func from sqlalchemy.orm import Session -from autogen_core.models import ChatCompletionClient - -from app.models.project_chat import ProjectChat -from app.models.project import Project -from app.models.organization import Organization -from app.models.brainstorming_phase import BrainstormingPhase -from app.models.grounding_file import GroundingFile -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureVisibilityStatus, FeatureType -from app.models.thread import Thread, ContextType -from app.models.platform_settings import PlatformSettings - from app.agents.brainstorm_conversation.types import ( - CrossProjectContext, CrossPhaseContext, CrossPhaseDecision, + CrossProjectContext, ProjectFeatureDecision, ) +from app.models.brainstorming_phase import BrainstormingPhase +from app.models.feature import Feature, FeatureType, FeatureVisibilityStatus +from app.models.grounding_file import GroundingFile +from app.models.module import Module, ModuleType +from app.models.organization import Organization +from app.models.platform_settings import PlatformSettings +from app.models.project import Project +from app.models.project_chat import ProjectChat +from app.models.thread import ContextType, Thread -from .types import ProjectChatContext, ProjectChatAssistantResponse, ExistingPhaseInfo, ExistingModuleInfo, ContainerContext from .assistant import ProjectChatAssistant +from .types import ( + ContainerContext, + ExistingModuleInfo, + ExistingPhaseInfo, + ProjectChatAssistantResponse, + ProjectChatContext, +) if TYPE_CHECKING: 
from app.agents.llm_client import LLMCallLogger @@ -67,47 +71,55 @@ def _build_cross_project_context_for_project_chat( project_features_context = [] # 1. Query ALL brainstorming phases (not archived) - all_phases = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project_id, - BrainstormingPhase.archived_at.is_(None) - ).order_by(BrainstormingPhase.created_at).limit(max_phases).all() + all_phases = ( + db.query(BrainstormingPhase) + .filter(BrainstormingPhase.project_id == project_id, BrainstormingPhase.archived_at.is_(None)) + .order_by(BrainstormingPhase.created_at) + .limit(max_phases) + .all() + ) for phase in all_phases: decisions = [] # Get modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == phase.id, - Module.archived_at.is_(None) - ).all() + modules = db.query(Module).filter(Module.brainstorming_phase_id == phase.id, Module.archived_at.is_(None)).all() for module in modules: # Get ACTIVE features (questions) with threads that have decisions - features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .all() + ) for feature in features: # Get thread for this feature - thread = db.query(Thread).filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id == str(feature.id) - ).first() + thread = ( + db.query(Thread) + .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id)) + .first() + ) # Only include if thread has decision_summary_short or decision_summary if thread and (thread.decision_summary_short or thread.decision_summary): summary = thread.decision_summary_short or ( - thread.decision_summary[:100] + "..." if len(thread.decision_summary or "") > 100 + thread.decision_summary[:100] + "..." + if len(thread.decision_summary or "") > 100 else thread.decision_summary ) if summary: - decisions.append(CrossPhaseDecision( - question_title=feature.title, - decision_summary_short=summary, - aspect_title=module.title, - )) + decisions.append( + CrossPhaseDecision( + question_title=feature.title, + decision_summary_short=summary, + aspect_title=module.title, + ) + ) # Cap decisions per phase if len(decisions) >= max_decisions_per_phase: @@ -123,47 +135,54 @@ def _build_cross_project_context_for_project_chat( if len(description) > 200: description = description[:200] + "..." - phases_context.append(CrossPhaseContext( - phase_id=str(phase.id), - phase_title=phase.title, - phase_description=description, - decisions=decisions, - )) + phases_context.append( + CrossPhaseContext( + phase_id=str(phase.id), + phase_title=phase.title, + phase_description=description, + decisions=decisions, + ) + ) # 2. 
Query project-level features (module.brainstorming_phase_id IS NULL) - project_modules = db.query(Module).filter( - Module.project_id == project_id, - Module.brainstorming_phase_id.is_(None), - Module.archived_at.is_(None) - ).all() + project_modules = ( + db.query(Module) + .filter(Module.project_id == project_id, Module.brainstorming_phase_id.is_(None), Module.archived_at.is_(None)) + .all() + ) for module in project_modules: # Get IMPLEMENTATION features (not CONVERSATION) - features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .all() + ) for feature in features: # Get thread for this feature (could be SPEC or GENERAL context type) - thread = db.query(Thread).filter( - Thread.context_id == str(feature.id) - ).first() + thread = db.query(Thread).filter(Thread.context_id == str(feature.id)).first() # Only include if thread has decision summary if thread and (thread.decision_summary_short or thread.decision_summary): summary = thread.decision_summary_short or ( - thread.decision_summary[:100] + "..." if len(thread.decision_summary or "") > 100 + thread.decision_summary[:100] + "..." + if len(thread.decision_summary or "") > 100 else thread.decision_summary ) if summary: - project_features_context.append(ProjectFeatureDecision( - feature_title=feature.title, - module_title=module.title, - decision_summary_short=summary, - )) + project_features_context.append( + ProjectFeatureDecision( + feature_title=feature.title, + module_title=module.title, + decision_summary_short=summary, + ) + ) # Cap project features if len(project_features_context) >= max_project_features: @@ -205,9 +224,7 @@ def load_context( ValueError: If discussion not found or not associated with a project. 
""" # Load discussion - discussion = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not discussion: raise ValueError(f"Discussion {project_chat_id} not found") @@ -220,9 +237,7 @@ def load_context( ) # Load organization (always required) - organization = db.query(Organization).filter( - Organization.id == discussion.org_id - ).first() + organization = db.query(Organization).filter(Organization.id == discussion.org_id).first() if not organization: raise ValueError(f"Organization {discussion.org_id} not found") @@ -243,9 +258,7 @@ def load_context( recent_messages.append(msg_data) # Load project context - project = db.query(Project).filter( - Project.id == discussion.project_id - ).first() + project = db.query(Project).filter(Project.id == discussion.project_id).first() if not project: raise ValueError(f"Project {discussion.project_id} not found") @@ -255,10 +268,11 @@ def load_context( grounding_summary = None grounding_context = None - grounding_file = db.query(GroundingFile).filter( - GroundingFile.project_id == project.id, - GroundingFile.filename == "agents.md" - ).first() + grounding_file = ( + db.query(GroundingFile) + .filter(GroundingFile.project_id == project.id, GroundingFile.filename == "agents.md") + .first() + ) if grounding_file: # Prefer summary when available (saves tokens), fall back to full content @@ -274,8 +288,7 @@ def load_context( has_grounding = True grounding_context = grounding_file.content logger.info( - f"Using full agents.md for project chat {project_chat_id} " - f"({len(grounding_file.content)} chars)" + f"Using full agents.md for project chat {project_chat_id} ({len(grounding_file.content)} chars)" ) # Check for repositories (brownfield indicator) @@ -294,14 +307,16 @@ def load_context( if sum(languages.values()) > 0: has_repositories = True - repositories.append({ - "slug": repo.slug, - "display_name": repo.display_name, - "repo_url": repo.repo_url, - "default_branch": repo.default_branch, - "user_remarks": repo.user_remarks, - "primary_language": primary_language, - }) + repositories.append( + { + "slug": repo.slug, + "display_name": repo.display_name, + "repo_url": repo.repo_url, + "default_branch": repo.default_branch, + "user_remarks": repo.user_remarks, + "primary_language": primary_language, + } + ) # Check if code explorer is enabled code_explorer_enabled = False @@ -311,6 +326,7 @@ def load_context( # Check if web search is enabled from app.services.platform_settings_service import is_web_search_available_sync + web_search_enabled = is_web_search_available_sync(db) # Note: We no longer load last_exploration_output/prompt here because @@ -318,10 +334,13 @@ def load_context( # messages, providing proper chronological context to the agent. 
# Load existing phases - existing_phases_db = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project.id, - BrainstormingPhase.archived_at.is_(None) - ).order_by(BrainstormingPhase.created_at.desc()).limit(10).all() + existing_phases_db = ( + db.query(BrainstormingPhase) + .filter(BrainstormingPhase.project_id == project.id, BrainstormingPhase.archived_at.is_(None)) + .order_by(BrainstormingPhase.created_at.desc()) + .limit(10) + .all() + ) existing_phases = [ ExistingPhaseInfo( @@ -333,27 +352,40 @@ def load_context( ] # Load project-level modules (for feature placement proposals) - existing_modules_db = db.query(Module).filter( - Module.project_id == project.id, - Module.brainstorming_phase_id.is_(None), # Project-level modules only - Module.archived_at.is_(None), - Module.module_type == ModuleType.IMPLEMENTATION, - ).order_by(Module.created_at.desc()).limit(20).all() + existing_modules_db = ( + db.query(Module) + .filter( + Module.project_id == project.id, + Module.brainstorming_phase_id.is_(None), # Project-level modules only + Module.archived_at.is_(None), + Module.module_type == ModuleType.IMPLEMENTATION, + ) + .order_by(Module.created_at.desc()) + .limit(20) + .all() + ) existing_modules = [] for module in existing_modules_db: # Count features in this module - feature_count = db.query(func.count(Feature.id)).filter( - Feature.module_id == module.id, - Feature.archived_at.is_(None), - ).scalar() or 0 - - existing_modules.append(ExistingModuleInfo( - module_id=str(module.id), - title=module.title, - description=module.description[:200] if module.description else None, - feature_count=feature_count, - )) + feature_count = ( + db.query(func.count(Feature.id)) + .filter( + Feature.module_id == module.id, + Feature.archived_at.is_(None), + ) + .scalar() + or 0 + ) + + existing_modules.append( + ExistingModuleInfo( + module_id=str(module.id), + title=module.title, + description=module.description[:200] if module.description else None, + feature_count=feature_count, + ) + ) # Build cross-project context (decisions from existing phases) cross_project_context = _build_cross_project_context_for_project_chat(db, project.id) @@ -363,9 +395,7 @@ def load_context( if discussion.target_container_id: from app.services.phase_container_service import PhaseContainerService - preview = PhaseContainerService.get_extension_preview( - db, discussion.target_container_id - ) + preview = PhaseContainerService.get_extension_preview(db, discussion.target_container_id) if preview: target_container = ContainerContext( container_id=preview["container_id"], @@ -436,7 +466,9 @@ async def generate_response( # Load context logger.info(f"Loading context for discussion {project_chat_id}") context = load_context( - db, project_chat_id, user_message, + db, + project_chat_id, + user_message, is_exploration_followup=is_exploration_followup, is_web_search_followup=is_web_search_followup, ) @@ -497,7 +529,7 @@ async def handle_user_message( # Get usage stats from model client usage_stats = {} - if hasattr(model_client, 'get_usage_stats'): + if hasattr(model_client, "get_usage_stats"): usage_stats = model_client.get_usage_stats() return { diff --git a/backend/app/agents/project_chat_assistant/types.py b/backend/app/agents/project_chat_assistant/types.py index ccf0280..ad8c1b1 100644 --- a/backend/app/agents/project_chat_assistant/types.py +++ b/backend/app/agents/project_chat_assistant/types.py @@ -6,10 +6,11 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, 
TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID -from pydantic import BaseModel, Field as PydanticField +from pydantic import BaseModel +from pydantic import Field as PydanticField if TYPE_CHECKING: from app.agents.brainstorm_conversation.types import CrossProjectContext @@ -18,6 +19,7 @@ @dataclass class ExistingPhaseInfo: """Summary info about an existing brainstorming phase.""" + phase_id: str title: str description: Optional[str] = None @@ -26,6 +28,7 @@ class ExistingPhaseInfo: @dataclass class ExistingModuleInfo: """Summary info about an existing project-level module.""" + module_id: str title: str description: Optional[str] = None @@ -34,6 +37,7 @@ class ExistingModuleInfo: class MCQOption(BaseModel): """A single option for an MCQ question.""" + id: str text: str @@ -41,6 +45,7 @@ class MCQOption(BaseModel): @dataclass class ContainerContext: """Context about a target container for extension creation.""" + container_id: str container_title: str initial_spec_summary: Optional[str] = None @@ -60,6 +65,7 @@ class ProjectChatContext: Note: Project creation is handled by the ProjectWizard, not by project-chat discussions. Pre-phase discussions always operate within an existing project. """ + # Identifiers project_chat_id: UUID org_id: UUID @@ -78,7 +84,9 @@ class ProjectChatContext: grounding_context: Optional[str] = None # Full agents.md content or summary (for LLM context) # Repository info (brownfield indicator) - supports multiple repos - repositories: List[Dict[str, Any]] = field(default_factory=list) # List of repos with slug, display_name, repo_url, default_branch, user_remarks, primary_language + repositories: List[Dict[str, Any]] = field( + default_factory=list + ) # List of repos with slug, display_name, repo_url, default_branch, user_remarks, primary_language has_repositories: bool = False # True if project has any repos with code # Code exploration capability @@ -126,6 +134,7 @@ class ProjectChatAssistantResponse(BaseModel): This is a Pydantic model to enable AutoGen's output_content_type feature for structured JSON output from the LLM. 
""" + # The reply text to show the user reply_text: str @@ -175,10 +184,7 @@ def to_response_data(self) -> Dict[str, Any]: } if self.mcq_options: - data["mcq_options"] = [ - {"id": opt.id, "text": opt.text} - for opt in self.mcq_options - ] + data["mcq_options"] = [{"id": opt.id, "text": opt.text} for opt in self.mcq_options] # Phase fields if self.proposed_title: diff --git a/backend/app/agents/project_chat_gating/agent.py b/backend/app/agents/project_chat_gating/agent.py index 02cfe66..e378fc8 100644 --- a/backend/app/agents/project_chat_gating/agent.py +++ b/backend/app/agents/project_chat_gating/agent.py @@ -15,7 +15,7 @@ from .types import GatingResponse if TYPE_CHECKING: - from app.agents.llm_client import LLMCallLogger, LiteLLMChatCompletionClient + from app.agents.llm_client import LiteLLMChatCompletionClient, LLMCallLogger logger = logging.getLogger(__name__) diff --git a/backend/app/agents/response_parser.py b/backend/app/agents/response_parser.py index cbec62f..7539d61 100644 --- a/backend/app/agents/response_parser.py +++ b/backend/app/agents/response_parser.py @@ -12,7 +12,6 @@ import json import logging -import re from typing import Any, Dict, List, Optional, Union logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ def strip_markdown_json(text: str) -> str: if not text: return text - lines = text.split('\n') + lines = text.split("\n") # Need at least 2 lines for opening and closing fences if len(lines) < 2: @@ -66,18 +65,18 @@ def strip_markdown_json(text: str) -> str: # Check for opening fence (```, ```json, ```JSON, etc.) # and closing fence (must be exactly ```) - if first_line.startswith('```') and last_line == '```': + if first_line.startswith("```") and last_line == "```": # Remove first and last lines only, preserving everything in between - inner_content = '\n'.join(lines[1:-1]) + inner_content = "\n".join(lines[1:-1]) return inner_content.strip() # Handle case where there's trailing content after closing fence # e.g., "```json\n{...}\n```\n\nSome explanation" # Find the last line that is exactly "```" for i in range(len(lines) - 1, 0, -1): - if lines[i].strip() == '```': - if first_line.startswith('```'): - inner_content = '\n'.join(lines[1:i]) + if lines[i].strip() == "```": + if first_line.startswith("```"): + inner_content = "\n".join(lines[1:i]) return inner_content.strip() break @@ -101,7 +100,7 @@ def strip_markdown_content(text: str) -> str: if not text: return text - lines = text.split('\n') + lines = text.split("\n") if len(lines) < 2: return text @@ -110,18 +109,18 @@ def strip_markdown_content(text: str) -> str: last_line = lines[-1].strip() # Check for opening fence and closing fence - opens_with_fence = first_line.startswith('```') - closes_with_fence = last_line == '```' + opens_with_fence = first_line.startswith("```") + closes_with_fence = last_line == "```" if opens_with_fence and closes_with_fence: - inner_content = '\n'.join(lines[1:-1]) + inner_content = "\n".join(lines[1:-1]) return inner_content.strip() # Handle trailing content after fence for i in range(len(lines) - 1, 0, -1): - if lines[i].strip() == '```': + if lines[i].strip() == "```": if opens_with_fence: - inner_content = '\n'.join(lines[1:i]) + inner_content = "\n".join(lines[1:i]) return inner_content.strip() break @@ -169,8 +168,8 @@ def extract_json_from_text(text: str) -> Optional[str]: text = text.strip() # Find the start of JSON (either { or [) - obj_start = text.find('{') - arr_start = text.find('[') + obj_start = text.find("{") + arr_start = text.find("[") if obj_start == -1 
and arr_start == -1: return None @@ -178,17 +177,17 @@ def extract_json_from_text(text: str) -> Optional[str]: # Determine which comes first if obj_start == -1: start = arr_start - open_char, close_char = '[', ']' + open_char, close_char = "[", "]" elif arr_start == -1: start = obj_start - open_char, close_char = '{', '}' + open_char, close_char = "{", "}" else: if obj_start < arr_start: start = obj_start - open_char, close_char = '{', '}' + open_char, close_char = "{", "}" else: start = arr_start - open_char, close_char = '[', ']' + open_char, close_char = "[", "]" # Track brackets to find matching close depth = 0 @@ -200,7 +199,7 @@ def extract_json_from_text(text: str) -> Optional[str]: escape_next = False continue - if char == '\\': + if char == "\\": escape_next = True continue @@ -214,15 +213,12 @@ def extract_json_from_text(text: str) -> Optional[str]: elif char == close_char: depth -= 1 if depth == 0: - return text[start:i + 1] + return text[start : i + 1] return None -def parse_json_response( - text: str, - fallback_to_raw: bool = False -) -> Union[Dict[str, Any], List[Any], str]: +def parse_json_response(text: str, fallback_to_raw: bool = False) -> Union[Dict[str, Any], List[Any], str]: """ Parse a JSON response from an LLM with multiple fallback strategies. @@ -268,23 +264,15 @@ def parse_json_response( # Strategy 4: Fallback to raw text if fallback_to_raw: - logger.warning( - f"Failed to parse JSON from response, returning raw text. " - f"Preview: {cleaned[:100]}..." - ) + logger.warning(f"Failed to parse JSON from response, returning raw text. Preview: {cleaned[:100]}...") return cleaned # All strategies failed - raise json.JSONDecodeError( - f"Failed to parse JSON from response", - cleaned, - 0 - ) + raise json.JSONDecodeError("Failed to parse JSON from response", cleaned, 0) def safe_parse_json( - text: str, - default: Optional[Union[Dict[str, Any], List[Any]]] = None + text: str, default: Optional[Union[Dict[str, Any], List[Any]]] = None ) -> Union[Dict[str, Any], List[Any]]: """ Safely parse JSON, returning a default value on failure. 
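
The response_parser helpers touched above compose into a small parsing ladder: strip fences, extract an embedded object, then parse with fallbacks. A usage sketch, assuming app.agents.response_parser is importable; the LLM replies here are hypothetical examples, not from this patch.

    from app.agents.response_parser import (
        extract_json_from_text,
        parse_json_response,
        strip_markdown_json,
    )

    # Fenced reply: drop the opening/closing ``` lines, keep the payload.
    fenced = '```json\n{"ok": true}\n```'
    assert strip_markdown_json(fenced) == '{"ok": true}'

    # Prose-wrapped reply: bracket matching recovers just the JSON object.
    chatty = 'Here you go:\n{"modules": [], "features": []}\nHope that helps!'
    assert extract_json_from_text(chatty) == '{"modules": [], "features": []}'

    # parse_json_response chains these strategies; with fallback_to_raw=True
    # it returns the cleaned text instead of raising when nothing parses.
    result = parse_json_response(chatty, fallback_to_raw=True)
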
diff --git a/backend/app/agents/retry.py b/backend/app/agents/retry.py index a850762..ec2e831 100644 --- a/backend/app/agents/retry.py +++ b/backend/app/agents/retry.py @@ -18,7 +18,7 @@ async def create(self, messages): import asyncio import logging -from typing import Any, Awaitable, Callable, List, Optional, Tuple, Type, TypeVar +from typing import Awaitable, Callable, List, Optional, Tuple, Type, TypeVar import litellm.exceptions as litellm_exc from tenacity import ( @@ -242,9 +242,7 @@ def before_sleep_callback(retry_state: RetryCallState) -> None: with attempt: result = await func() if attempt.retry_state.attempt_number > 1: - logger.info( - f"Retry succeeded on attempt {attempt.retry_state.attempt_number}" - ) + logger.info(f"Retry succeeded on attempt {attempt.retry_state.attempt_number}") return result except Exception as e: # Check if it's a retryable exception that exhausted all attempts @@ -281,9 +279,7 @@ async def _with_retry_legacy( return result except Exception as e: last_exception = e - logger.warning( - f"Attempt {attempt + 1}/{max_attempts} failed: {type(e).__name__}: {e}" - ) + logger.warning(f"Attempt {attempt + 1}/{max_attempts} failed: {type(e).__name__}: {e}") # Call the on_retry callback if provided if on_retry is not None: @@ -306,9 +302,7 @@ async def _with_retry_legacy( ) -def _calculate_legacy_backoff( - attempt: int, backoff_ms: Optional[List[int]] = None -) -> float: +def _calculate_legacy_backoff(attempt: int, backoff_ms: Optional[List[int]] = None) -> float: """Calculate backoff delay for legacy mode.""" if backoff_ms is None: backoff_ms = LEGACY_BACKOFF_MS @@ -318,9 +312,7 @@ def _calculate_legacy_backoff( # Backward-compatible alias -def calculate_backoff_delay( - attempt: int, backoff_ms: Optional[List[int]] = None -) -> float: +def calculate_backoff_delay(attempt: int, backoff_ms: Optional[List[int]] = None) -> float: """ Calculate the backoff delay for a given attempt. diff --git a/backend/app/auth/__init__.py b/backend/app/auth/__init__.py index 32ae8fa..da826d1 100644 --- a/backend/app/auth/__init__.py +++ b/backend/app/auth/__init__.py @@ -1,6 +1,7 @@ """ Authentication utilities and dependencies. """ + from app.auth.providers import ( KNOWN_PROVIDERS, get_configured_providers, diff --git a/backend/app/auth/api_key_utils.py b/backend/app/auth/api_key_utils.py index 015c930..ff1b6e7 100644 --- a/backend/app/auth/api_key_utils.py +++ b/backend/app/auth/api_key_utils.py @@ -2,6 +2,7 @@ import hashlib import uuid + import bcrypt diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index 1081e25..54f0e5e 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -4,28 +4,27 @@ Provides dependency functions for extracting and validating the current user from JWT tokens, session cookies, or API keys in requests. 
""" + from datetime import datetime, timezone from typing import Annotated -from uuid import UUID -from fastapi import Cookie, Depends, HTTPException, Request, status, Header -from fastapi.security import OAuth2PasswordBearer, HTTPBearer +from fastapi import Cookie, Depends, Header, HTTPException, status +from fastapi.security import HTTPBearer, OAuth2PasswordBearer from jose import JWTError from sqlalchemy.orm import Session -from app.auth.utils import decode_access_token -from app.auth.api_key_utils import verify_api_key, hash_api_key_sha256 +from app.auth.api_key_utils import hash_api_key_sha256, verify_api_key from app.auth.encryption_utils import verify_api_key_encrypted +from app.auth.utils import decode_access_token from app.config import settings from app.database import get_db -from app.models.user import User from app.models.api_key import ApiKey from app.models.project import Project -from app.services.user_service import UserService +from app.models.user import User +from app.services.mcp_oauth_service import MCPOAuthService from app.services.project_service import ProjectService from app.services.project_share_service import ProjectShareService -from app.services.mcp_oauth_service import MCPOAuthService - +from app.services.user_service import UserService # OAuth2 scheme for token extraction (makes it optional so we can fall back to cookie) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False) @@ -207,11 +206,7 @@ async def __call__( # Fast O(1) lookup via SHA-256 index lookup_hash = hash_api_key_sha256(api_key) - matched_key = ( - db.query(ApiKey) - .filter(ApiKey.key_lookup_hash == lookup_hash, ApiKey.revoked == False) - .first() - ) + matched_key = db.query(ApiKey).filter(ApiKey.key_lookup_hash == lookup_hash, ApiKey.revoked == False).first() if matched_key: # Verify the key actually matches (SHA-256 collision safety) @@ -223,11 +218,7 @@ async def __call__( # Fallback: scan legacy keys that don't have a lookup hash yet if not matched_key: - legacy_keys = ( - db.query(ApiKey) - .filter(ApiKey.key_lookup_hash.is_(None), ApiKey.revoked == False) - .all() - ) + legacy_keys = db.query(ApiKey).filter(ApiKey.key_lookup_hash.is_(None), ApiKey.revoked == False).all() for key in legacy_keys: if key.key_encrypted: if verify_api_key_encrypted(api_key, key.key_encrypted): @@ -262,9 +253,7 @@ async def __call__( ) # Verify user has access to this project (direct, via group, or via org) - has_access = ProjectShareService.user_has_project_access( - db, project.id, user.id - ) + has_access = ProjectShareService.user_has_project_access(db, project.id, user.id) if not has_access: # Return 404 for privacy (don't reveal project exists) raise HTTPException( @@ -329,9 +318,7 @@ async def __call__( credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing credentials", - headers={ - "WWW-Authenticate": f'Bearer resource="{resource_url}"' - }, + headers={"WWW-Authenticate": f'Bearer resource="{resource_url}"'}, ) if not authorization: @@ -351,6 +338,7 @@ async def __call__( # Try to decode as JWT import logging + logger = logging.getLogger(__name__) try: @@ -392,9 +380,7 @@ async def __call__( project = ProjectService.get_by_identifier(db, self.project_id) if project: # Verify user has access to this project (direct or via group) - has_access = ProjectShareService.user_has_project_access( - db, project.id, user.id - ) + has_access = ProjectShareService.user_has_project_access(db, project.id, user.id) if has_access: 
return user, project, None diff --git a/backend/app/auth/domain_validation.py b/backend/app/auth/domain_validation.py index 8f5007f..1048e59 100644 --- a/backend/app/auth/domain_validation.py +++ b/backend/app/auth/domain_validation.py @@ -1,4 +1,5 @@ """Email domain validation for signup.""" + from app.config import settings diff --git a/backend/app/auth/platform_admin.py b/backend/app/auth/platform_admin.py index ac4b3a1..63ea31f 100644 --- a/backend/app/auth/platform_admin.py +++ b/backend/app/auth/platform_admin.py @@ -1,4 +1,5 @@ """Platform admin authorization helpers.""" + from typing import Annotated from fastapi import Depends, HTTPException, status diff --git a/backend/app/auth/providers.py b/backend/app/auth/providers.py index 1891161..a7e0bba 100644 --- a/backend/app/auth/providers.py +++ b/backend/app/auth/providers.py @@ -7,12 +7,12 @@ Enterprise auth providers (e.g. Scalekit SSO) are registered via the plugin registry — see app/plugin_registry.py. """ + from authlib.integrations.starlette_client import OAuth from app.config import settings from app.schemas.oauth import NormalizedUserInfo - # Global OAuth registry instance oauth = OAuth() diff --git a/backend/app/auth/service.py b/backend/app/auth/service.py index c157ba8..326f970 100644 --- a/backend/app/auth/service.py +++ b/backend/app/auth/service.py @@ -32,6 +32,7 @@ - Check expires_at and refresh tokens proactively - Consistent with existing OrgBugTracker token encryption pattern """ + from datetime import datetime, timezone from sqlalchemy.orm import Session @@ -42,7 +43,6 @@ from app.models.user_identity import UserIdentity from app.schemas.oauth import NormalizedUserInfo - # Hardcoded provider mapping for Phase 1 # In Phase 2, this could be read from the database for external IdPs PROVIDER_CONFIG = { @@ -103,16 +103,11 @@ def get_or_create_identity_provider( # Validate slug if provider_slug not in PROVIDER_CONFIG: raise ValueError( - f"Unknown provider: '{provider_slug}'. " - f"Valid providers are: {', '.join(PROVIDER_CONFIG.keys())}" + f"Unknown provider: '{provider_slug}'. 
Valid providers are: {', '.join(PROVIDER_CONFIG.keys())}" ) # Try to find existing provider - provider = ( - db.query(IdentityProvider) - .filter(IdentityProvider.slug == provider_slug) - .first() - ) + provider = db.query(IdentityProvider).filter(IdentityProvider.slug == provider_slug).first() if provider: return provider @@ -248,8 +243,4 @@ def get_user_identities( Returns: List of UserIdentity instances """ - return ( - db.query(UserIdentity) - .filter(UserIdentity.user_id == user_id) - .all() - ) + return db.query(UserIdentity).filter(UserIdentity.user_id == user_id).all() diff --git a/backend/app/auth/trial.py b/backend/app/auth/trial.py index 442351d..2a2d8f5 100644 --- a/backend/app/auth/trial.py +++ b/backend/app/auth/trial.py @@ -15,7 +15,8 @@ - No token limits enforced - No trial expiration checks """ -from datetime import datetime, timezone, timedelta + +from datetime import datetime, timedelta, timezone from typing import Annotated, Optional from fastapi import Depends, HTTPException, status @@ -33,6 +34,7 @@ def _has_plan_plugin() -> bool: """Check if the plan enforcement plugin is registered.""" from app.plugin_registry import get_plugin_registry + return get_plugin_registry().plan_plugin is not None diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py index b0651d5..1505d46 100644 --- a/backend/app/auth/utils.py +++ b/backend/app/auth/utils.py @@ -3,8 +3,9 @@ Provides password hashing and JWT token management utilities. """ -from datetime import datetime, timedelta, UTC -from typing import Dict, Any + +from datetime import UTC, datetime, timedelta +from typing import Any, Dict import bcrypt from jose import JWTError, jwt @@ -23,14 +24,14 @@ def hash_password(password: str) -> str: Bcrypt hashed password string """ # Convert password to bytes (bcrypt requires bytes) - password_bytes = password.encode('utf-8') + password_bytes = password.encode("utf-8") # Generate salt and hash salt = bcrypt.gensalt() hashed = bcrypt.hashpw(password_bytes, salt) # Return as string for database storage - return hashed.decode('utf-8') + return hashed.decode("utf-8") def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -45,8 +46,8 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: True if the password matches, False otherwise """ # Convert both to bytes - password_bytes = plain_password.encode('utf-8') - hashed_bytes = hashed_password.encode('utf-8') + password_bytes = plain_password.encode("utf-8") + hashed_bytes = hashed_password.encode("utf-8") # Verify return bcrypt.checkpw(password_bytes, hashed_bytes) diff --git a/backend/app/config.py b/backend/app/config.py index cb6cd06..171834f 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -44,13 +44,9 @@ class Settings(BaseSettings): default="postgresql://mfbt:iammfbt@localhost:5432/mfbt_dev", description="PostgreSQL connection URL", ) - database_echo: bool = Field( - default=False, description="Echo SQL queries (for debugging)" - ) + database_echo: bool = Field(default=False, description="Echo SQL queries (for debugging)") database_pool_size: int = Field(default=10, description="Database connection pool size") - database_max_overflow: int = Field( - default=20, description="Max overflow connections" - ) + database_max_overflow: int = Field(default=20, description="Max overflow connections") # Testing test_database_url: PostgresDsn = Field( @@ -86,9 +82,7 @@ def kafka_bootstrap_servers(self) -> str: default="", description="Encryption key for API keys and integration tokens 
(used with PBKDF2)", ) - access_token_expire_minutes: int = Field( - default=129600, description="Access token expiration time in minutes" - ) + access_token_expire_minutes: int = Field(default=129600, description="Access token expiration time in minutes") # OAuth Providers (Phase 9) google_client_id: str | None = Field( @@ -235,6 +229,7 @@ def _default_slack_command_name(self) -> "Settings": if not self.slack_command_name: self.slack_command_name = "/mfbt" return self + slack_oauth_redirect_url: str | None = Field( default=None, description="Override for Slack OAuth redirect URL (e.g. ngrok tunnel). If empty, constructed from BASE_URL.", @@ -253,33 +248,21 @@ def platform_admin_emails(self) -> set[str]: """Parse platform admin emails into a set for fast lookup.""" if not self.platform_admins: return set() - return { - email.strip().lower() - for email in self.platform_admins.split(",") - if email.strip() - } + return {email.strip().lower() for email in self.platform_admins.split(",") if email.strip()} @property def permitted_domains(self) -> set[str]: """Parse permitted signup domains into a set for fast lookup.""" if not self.permitted_signup_domains: return set() - return { - domain.strip().lower() - for domain in self.permitted_signup_domains.split(",") - if domain.strip() - } + return {domain.strip().lower() for domain in self.permitted_signup_domains.split(",") if domain.strip()} @property def trial_exempted_emails(self) -> set[str]: """Parse trial-exempted emails into a set for fast lookup.""" if not self.trial_mode_exempted_user_emails: return set() - return { - email.strip().lower() - for email in self.trial_mode_exempted_user_emails.split(",") - if email.strip() - } + return {email.strip().lower() for email in self.trial_mode_exempted_user_emails.split(",") if email.strip()} @property def trial_exempted_domains(self) -> set[str]: @@ -287,9 +270,7 @@ def trial_exempted_domains(self) -> set[str]: if not self.trial_mode_exempted_email_domains: return set() return { - domain.strip().lower() - for domain in self.trial_mode_exempted_email_domains.split(",") - if domain.strip() + domain.strip().lower() for domain in self.trial_mode_exempted_email_domains.split(",") if domain.strip() } @property @@ -331,10 +312,7 @@ def has_slack_oauth_config(self) -> bool: @property def has_github_oauth_env_config(self) -> bool: """Check if GitHub OAuth env vars are configured with non-empty values.""" - return bool( - self.github_integration_oauth_client_id - and self.github_integration_oauth_client_secret - ) + return bool(self.github_integration_oauth_client_id and self.github_integration_oauth_client_secret) @lru_cache diff --git a/backend/app/database.py b/backend/app/database.py index d249ad3..2069737 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -3,17 +3,16 @@ Provides SQLAlchemy engine, session factory, and base model class. 
""" + from contextlib import asynccontextmanager from typing import AsyncGenerator, Generator -from sqlalchemy import create_engine, MetaData +from sqlalchemy import MetaData, create_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session -from sqlalchemy.pool import NullPool +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker from app.config import settings - # Naming convention for constraints (helpful for Alembic migrations) NAMING_CONVENTION = { "ix": "ix_%(column_0_label)s", diff --git a/backend/app/integrations/__init__.py b/backend/app/integrations/__init__.py index 3458ede..1e68add 100644 --- a/backend/app/integrations/__init__.py +++ b/backend/app/integrations/__init__.py @@ -1,4 +1,5 @@ """Bug tracker integrations.""" + from app.integrations.base import BugTrackerAdapter, TicketData __all__ = ["BugTrackerAdapter", "TicketData"] diff --git a/backend/app/integrations/base.py b/backend/app/integrations/base.py index 53f1373..646a20d 100644 --- a/backend/app/integrations/base.py +++ b/backend/app/integrations/base.py @@ -1,4 +1,5 @@ """Base interface for bug tracker adapters.""" + from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any diff --git a/backend/app/integrations/factory.py b/backend/app/integrations/factory.py index 4e8c022..1e4c3d4 100644 --- a/backend/app/integrations/factory.py +++ b/backend/app/integrations/factory.py @@ -1,4 +1,5 @@ """Factory for creating bug tracker adapters.""" + from typing import Any from app.integrations.base import BugTrackerAdapter @@ -7,9 +8,7 @@ from app.integrations.jira import JiraAdapter -def get_adapter( - provider: str, token: str, config: dict[str, Any] | None = None -) -> BugTrackerAdapter: +def get_adapter(provider: str, token: str, config: dict[str, Any] | None = None) -> BugTrackerAdapter: """Get the appropriate bug tracker adapter for the given provider. Args: @@ -31,9 +30,6 @@ def get_adapter( adapter_class = adapters.get(provider.lower()) if not adapter_class: - raise ValueError( - f"Unsupported provider: {provider}. " - f"Supported providers: {', '.join(adapters.keys())}" - ) + raise ValueError(f"Unsupported provider: {provider}. Supported providers: {', '.join(adapters.keys())}") return adapter_class(token=token, config=config) diff --git a/backend/app/integrations/github.py b/backend/app/integrations/github.py index 0100c26..ce78b6d 100644 --- a/backend/app/integrations/github.py +++ b/backend/app/integrations/github.py @@ -1,4 +1,5 @@ """GitHub Issues adapter with PAT, GitHub App, and OAuth authentication support.""" + import re import time from typing import Any @@ -261,9 +262,7 @@ async def fetch_ticket(self, external_id: str) -> TicketData: # Parse external_id (format: "owner/repo#123") match = re.match(r"^([^/]+)/([^#]+)#(\d+)$", external_id) if not match: - raise ValueError( - f"Invalid GitHub issue format: {external_id}. Expected 'owner/repo#123'" - ) + raise ValueError(f"Invalid GitHub issue format: {external_id}. 
Expected 'owner/repo#123'") owner, repo, issue_number = match.groups() @@ -281,9 +280,7 @@ async def fetch_ticket(self, external_id: str) -> TicketData: comments_url = issue_data.get("comments_url") comments_data = [] if comments_url and issue_data.get("comments", 0) > 0: - comments_response = await client.get( - comments_url, headers=headers, timeout=30.0 - ) + comments_response = await client.get(comments_url, headers=headers, timeout=30.0) comments_response.raise_for_status() raw_comments = comments_response.json() comments_data = [ @@ -310,13 +307,8 @@ async def fetch_ticket(self, external_id: str) -> TicketData: "updated_at": issue_data.get("updated_at"), "closed_at": issue_data.get("closed_at"), "author": issue_data.get("user", {}).get("login", "unknown"), - "assignees": [ - assignee.get("login") - for assignee in issue_data.get("assignees", []) - ], - "milestone": issue_data.get("milestone", {}).get("title") - if issue_data.get("milestone") - else None, + "assignees": [assignee.get("login") for assignee in issue_data.get("assignees", [])], + "milestone": issue_data.get("milestone", {}).get("title") if issue_data.get("milestone") else None, } return TicketData( @@ -349,9 +341,7 @@ async def list_repositories(self, page: int = 1, per_page: int = 100) -> list[di # For GitHub App, use installation repositories endpoint url = f"{self.BASE_URL}/installation/repositories" params = {"per_page": per_page, "page": page} - response = await client.get( - url, headers=headers, params=params, timeout=30.0 - ) + response = await client.get(url, headers=headers, params=params, timeout=30.0) response.raise_for_status() data = response.json() return data.get("repositories", []) @@ -370,9 +360,7 @@ async def list_repositories(self, page: int = 1, per_page: int = 100) -> list[di "sort": "updated", "affiliation": "owner,collaborator,organization_member", } - response = await client.get( - url, headers=headers, params=params, timeout=30.0 - ) + response = await client.get(url, headers=headers, params=params, timeout=30.0) response.raise_for_status() return response.json() @@ -400,9 +388,7 @@ async def search_issues( """ # Validate repo format if "/" not in repo_or_project: - raise ValueError( - f"Invalid repository format: {repo_or_project}. Expected 'owner/repo'" - ) + raise ValueError(f"Invalid repository format: {repo_or_project}. 
Expected 'owner/repo'") # Build search query # GitHub search syntax: repo:owner/repo is:issue query state:open diff --git a/backend/app/integrations/gitlab.py b/backend/app/integrations/gitlab.py index 4d9dc89..163d271 100644 --- a/backend/app/integrations/gitlab.py +++ b/backend/app/integrations/gitlab.py @@ -1,4 +1,5 @@ """GitLab adapter (stub implementation).""" + from typing import Any from app.integrations.base import BugTrackerAdapter, IssueSearchResult, TicketData diff --git a/backend/app/integrations/jira.py b/backend/app/integrations/jira.py index 47f098f..cda77d3 100644 --- a/backend/app/integrations/jira.py +++ b/backend/app/integrations/jira.py @@ -1,4 +1,5 @@ """Jira adapter.""" + import base64 import logging import re @@ -121,18 +122,14 @@ async def test_connection(self) -> dict[str, Any]: # Log response body for debugging non-200 responses try: response_text = response.text[:500] - logger.warning( - f"Jira test_connection: non-200 response body: {response_text}" - ) + logger.warning(f"Jira test_connection: non-200 response body: {response_text}") except Exception: pass response.raise_for_status() user_data = response.json() - display_name = user_data.get( - "displayName", user_data.get("emailAddress", "unknown") - ) + display_name = user_data.get("displayName", user_data.get("emailAddress", "unknown")) logger.info(f"Jira test_connection: success, connected as {display_name}") return { "success": True, @@ -146,9 +143,7 @@ async def test_connection(self) -> dict[str, Any]: error_detail = "" try: error_body = e.response.text[:500] - logger.error( - f"Jira test_connection: HTTP {status_code} error, body: {error_body}" - ) + logger.error(f"Jira test_connection: HTTP {status_code} error, body: {error_body}") error_detail = f" - {error_body}" if error_body else "" except Exception: logger.error(f"Jira test_connection: HTTP {status_code} error") @@ -222,9 +217,7 @@ async def fetch_ticket(self, external_id: str) -> TicketData: else: # Validate issue key format (e.g., "PROJ-123") if not re.match(r"^[A-Z]+-\d+$", external_id): - raise ValueError( - f"Invalid Jira issue key: {external_id}. Expected 'PROJECTKEY-123'" - ) + raise ValueError(f"Invalid Jira issue key: {external_id}. 
Expected 'PROJECTKEY-123'") issue_key = external_id # Fetch issue data @@ -240,9 +233,7 @@ async def fetch_ticket(self, external_id: str) -> TicketData: "fields": "summary,description,status,labels,comment,created,updated,resolution,assignee,reporter,issuetype" } - response = await client.get( - api_url, headers=headers, params=params, timeout=30.0 - ) + response = await client.get(api_url, headers=headers, params=params, timeout=30.0) response.raise_for_status() issue_data = response.json() @@ -273,11 +264,13 @@ async def fetch_ticket(self, external_id: str) -> TicketData: comment_body = self._extract_text_from_adf(comment_body) elif comment_body is None: comment_body = "" - comments_data.append({ - "author": author_info.get("displayName") or author_info.get("emailAddress", "unknown"), - "body": comment_body, - "created_at": comment.get("created"), - }) + comments_data.append( + { + "author": author_info.get("displayName") or author_info.get("emailAddress", "unknown"), + "body": comment_body, + "created_at": comment.get("created"), + } + ) # Build metadata (use site_url for human-readable browse links) metadata = { @@ -467,9 +460,7 @@ async def search_issues( "maxResults": 30, # Limit for import UI } - response = await client.post( - api_url, headers=headers, json=payload, timeout=30.0 - ) + response = await client.post(api_url, headers=headers, json=payload, timeout=30.0) response.raise_for_status() data = response.json() @@ -534,9 +525,7 @@ async def list_projects(self) -> list[dict]: api_url = f"{api_base_url}/rest/api/3/project" params = {"maxResults": 100, "orderBy": "name"} - response = await client.get( - api_url, headers=headers, params=params, timeout=30.0 - ) + response = await client.get(api_url, headers=headers, params=params, timeout=30.0) response.raise_for_status() projects = response.json() diff --git a/backend/app/main.py b/backend/app/main.py index ae63291..038f909 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -6,21 +6,64 @@ To run the server: uv run uvicorn app.main:app --reload """ + import asyncio import logging from contextlib import asynccontextmanager from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse from starlette.middleware.sessions import SessionMiddleware from app.auth.trial import require_active_trial, require_tokens_available from app.config import settings -from app.routers import auth, orgs, projects, project_repositories, threads, thread_items, integrations, llm_preferences, jobs, websocket, testing, api_keys, mcp_http -from app.routers import brainstorming_phases, modules, features, drafts, activity, agent_api, llm_call_logs, mcp_call_logs, team_roles, grounding, feature_content_versions, platform_settings, email_templates, implementations, phase_containers, grounding_notes -from app.routers import invitations, invite_acceptance, project_shares, user_groups, user_question_sessions, thread_images, dashboard, images, mcp_images, form_drafts, project_chats, org_chats, project_chat_images, conversations -from app.routers import analytics, plan_recommendations +from app.routers import ( + activity, + agent_api, + analytics, + api_keys, + auth, + brainstorming_phases, + conversations, + dashboard, + drafts, + email_templates, + feature_content_versions, + features, + form_drafts, + grounding, + grounding_notes, + images, + implementations, + integrations, + invitations, + invite_acceptance, + jobs, + llm_call_logs, + llm_preferences, + mcp_call_logs, + mcp_http, + 
mcp_images, + modules, + org_chats, + orgs, + phase_containers, + plan_recommendations, + platform_settings, + project_chat_images, + project_chats, + project_repositories, + project_shares, + projects, + team_roles, + testing, + thread_images, + thread_items, + threads, + user_groups, + user_question_sessions, + websocket, +) logger = logging.getLogger(__name__) @@ -33,7 +76,7 @@ async def lifespan(app: FastAPI): Starts the WebSocket broadcast consumer on startup and stops it on shutdown. """ # Startup - with open('/tmp/mfbt_lifespan.log', 'a') as f: + with open("/tmp/mfbt_lifespan.log", "a") as f: f.write("=" * 80 + "\n") f.write("LIFESPAN STARTUP CALLED\n") f.write("=" * 80 + "\n") @@ -44,8 +87,8 @@ async def lifespan(app: FastAPI): logger.info("Starting WebSocket broadcast consumer...") try: - from app.websocket.broadcast_consumer import get_broadcast_consumer from app.services.kafka_producer import get_sync_kafka_producer + from app.websocket.broadcast_consumer import get_broadcast_consumer # Start sync Kafka producer sync_producer = get_sync_kafka_producer() @@ -64,6 +107,7 @@ async def lifespan(app: FastAPI): except Exception as e: print(f"ERROR creating broadcast consumer: {e}") import traceback + traceback.print_exc() raise diff --git a/backend/app/mcp/server.py b/backend/app/mcp/server.py index ef41aa1..f8a86d2 100644 --- a/backend/app/mcp/server.py +++ b/backend/app/mcp/server.py @@ -5,7 +5,7 @@ from mcp.server import FastMCP from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import sessionmaker from app.mcp.tools.append_feature_note import append_feature_note from app.mcp.tools.create_clarification_question import create_clarification_question @@ -132,9 +132,7 @@ async def handle_get_toc( ) -> dict: """Get table of contents for spec or prompt plan.""" with get_db() as db: - return get_toc( - db, project_id=project_id, project_key=project_key, target=target - ) + return get_toc(db, project_id=project_id, project_key=project_key, target=target) # Register getSection tool @mcp.tool( diff --git a/backend/app/mcp/tools/append_feature_note.py b/backend/app/mcp/tools/append_feature_note.py index f44dffc..f151fb7 100644 --- a/backend/app/mcp/tools/append_feature_note.py +++ b/backend/app/mcp/tools/append_feature_note.py @@ -1,7 +1,8 @@ """appendFeatureNote MCP tool - append a note to a feature's implementation notes.""" -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import UUID + from sqlalchemy.orm import Session from app.mcp.utils.project_resolver import resolve_project diff --git a/backend/app/mcp/tools/create_clarification_question.py b/backend/app/mcp/tools/create_clarification_question.py index 9ede8c0..ce96486 100644 --- a/backend/app/mcp/tools/create_clarification_question.py +++ b/backend/app/mcp/tools/create_clarification_question.py @@ -1,11 +1,12 @@ """createClarificationQuestion MCP tool - create a clarification question thread.""" -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional + from sqlalchemy.orm import Session from app.mcp.utils.project_resolver import resolve_project -from app.services.thread_service import ThreadService from app.models import ContextType +from app.services.thread_service import ThreadService def create_clarification_question( @@ -88,7 +89,7 @@ def create_clarification_question( "project_key": project.key, "title": thread.title, "pending_approval": thread.pending_approval, - "context_type": thread.context_type.value if 
hasattr(thread.context_type, 'value') else thread.context_type, + "context_type": thread.context_type.value if hasattr(thread.context_type, "value") else thread.context_type, "context_id": thread.context_id, "created_at": thread.created_at.isoformat(), "comment_preview": body[:200] + "..." if len(body) > 200 else body, diff --git a/backend/app/mcp/tools/get_context.py b/backend/app/mcp/tools/get_context.py index adfb1ec..74cbfe4 100644 --- a/backend/app/mcp/tools/get_context.py +++ b/backend/app/mcp/tools/get_context.py @@ -1,14 +1,15 @@ """getContext MCP tool - retrieve full project context.""" -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional + from sqlalchemy.orm import Session from app.mcp.utils.project_resolver import resolve_project from app.models.brainstorming_phase import BrainstormingPhase -from app.models.final_spec import FinalSpec +from app.models.feature import Feature, FeatureStatus, FeatureType from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus def get_context( @@ -45,7 +46,7 @@ def get_context( # Build project metadata project_info = { "id": str(project.id), - "type": project.type.value if hasattr(project.type, 'value') else project.type, + "type": project.type.value if hasattr(project.type, "value") else project.type, "name": project.name, "key": project.key, "parent_application_key": None, @@ -68,18 +69,10 @@ def get_context( phases_info: List[Dict[str, Any]] = [] for phase in phases: # Get final spec for this phase - final_spec = ( - db.query(FinalSpec) - .filter(FinalSpec.brainstorming_phase_id == phase.id) - .first() - ) + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() # Get final prompt plan for this phase - final_plan = ( - db.query(FinalPromptPlan) - .filter(FinalPromptPlan.brainstorming_phase_id == phase.id) - .first() - ) + final_plan = db.query(FinalPromptPlan).filter(FinalPromptPlan.brainstorming_phase_id == phase.id).first() # Get implementation modules and features for this phase modules = ( @@ -109,45 +102,47 @@ def get_context( features_info: List[Dict[str, Any]] = [] for feature in features: - features_info.append({ - "id": str(feature.id), - "feature_key": feature.feature_key, - "title": feature.title, - "priority": feature.priority.value if hasattr(feature.priority, 'value') else feature.priority, - "category": feature.category, - "completion_status": feature.completion_status.value if hasattr(feature.completion_status, 'value') else feature.completion_status, - "spec_text": feature.spec_text, - "prompt_plan_text": feature.prompt_plan_text, - "has_implementation_notes": bool(feature.implementation_notes), - }) - - modules_info.append({ - "id": str(module.id), - "title": module.title, - "description": module.description, - "order_index": module.order_index, - "features": features_info, - "feature_count": len(features_info), - }) + features_info.append( + { + "id": str(feature.id), + "feature_key": feature.feature_key, + "title": feature.title, + "priority": feature.priority.value if hasattr(feature.priority, "value") else feature.priority, + "category": feature.category, + "completion_status": feature.completion_status.value + if hasattr(feature.completion_status, "value") + else feature.completion_status, + "spec_text": feature.spec_text, + "prompt_plan_text": feature.prompt_plan_text, + 
"has_implementation_notes": bool(feature.implementation_notes), + } + ) + + modules_info.append( + { + "id": str(module.id), + "title": module.title, + "description": module.description, + "order_index": module.order_index, + "features": features_info, + "feature_count": len(features_info), + } + ) # Calculate completion stats total_features = sum(len(m["features"]) for m in modules_info) completed_features = sum( - 1 for m in modules_info - for f in m["features"] - if f["completion_status"] == "completed" + 1 for m in modules_info for f in m["features"] if f["completion_status"] == "completed" ) in_progress_features = sum( - 1 for m in modules_info - for f in m["features"] - if f["completion_status"] == "in_progress" + 1 for m in modules_info for f in m["features"] if f["completion_status"] == "in_progress" ) phase_info = { "id": str(phase.id), "title": phase.title, "description": phase.description, - "phase_type": phase.phase_type.value if hasattr(phase.phase_type, 'value') else phase.phase_type, + "phase_type": phase.phase_type.value if hasattr(phase.phase_type, "value") else phase.phase_type, "final_spec": { "available": final_spec is not None, "content_markdown": final_spec.content_markdown if final_spec else None, diff --git a/backend/app/mcp/tools/get_feature_notes.py b/backend/app/mcp/tools/get_feature_notes.py index 8903eef..1b56ce3 100644 --- a/backend/app/mcp/tools/get_feature_notes.py +++ b/backend/app/mcp/tools/get_feature_notes.py @@ -1,7 +1,8 @@ """getFeatureNotes MCP tool - retrieve implementation notes for a feature.""" -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import UUID + from sqlalchemy.orm import Session from app.mcp.utils.project_resolver import resolve_project diff --git a/backend/app/mcp/tools/get_section.py b/backend/app/mcp/tools/get_section.py index aa1238c..5d03ad6 100644 --- a/backend/app/mcp/tools/get_section.py +++ b/backend/app/mcp/tools/get_section.py @@ -1,11 +1,12 @@ """getSection MCP tool - retrieve specific section from spec or prompt plan.""" -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional + from sqlalchemy.orm import Session -from app.mcp.utils.project_resolver import resolve_project from app.mcp.utils.markdown_parser import extract_section -from app.models import SpecType, BrainstormingPhase +from app.mcp.utils.project_resolver import resolve_project +from app.models import BrainstormingPhase, SpecType from app.services.spec_service import SpecService @@ -66,17 +67,13 @@ def get_section( # Fallback: Legacy project-level specs (SpecVersion with project_id) if not content_markdown: spec_type = SpecType.SPECIFICATION if target == "spec" else SpecType.PROMPT_PLAN - active_spec = SpecService.get_active_spec( - db, project_id=project.id, spec_type=spec_type - ) + active_spec = SpecService.get_active_spec(db, project_id=project.id, spec_type=spec_type) if active_spec: content_markdown = active_spec.content_markdown content_json = active_spec.content_json if not content_markdown: - raise ValueError( - f"No active {target} found for project {project.id}" - ) + raise ValueError(f"No active {target} found for project {project.id}") # Extract section section_markdown = None @@ -92,8 +89,7 @@ def get_section( # List available section IDs to help with debugging available_ids = [s["id"] for s in content_json["sections"]] raise ValueError( - f"Section '{section_id}' not found in {target}. 
" - f"Available section IDs: {', '.join(available_ids)}" + f"Section '{section_id}' not found in {target}. Available section IDs: {', '.join(available_ids)}" ) else: # Fallback: extract from markdown using header matching diff --git a/backend/app/mcp/tools/get_toc.py b/backend/app/mcp/tools/get_toc.py index 2630f27..5a96cb1 100644 --- a/backend/app/mcp/tools/get_toc.py +++ b/backend/app/mcp/tools/get_toc.py @@ -1,11 +1,12 @@ """getToc MCP tool - retrieve table of contents for spec or prompt plan.""" -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional + from sqlalchemy.orm import Session +from app.mcp.utils.markdown_parser import TocEntry, extract_toc from app.mcp.utils.project_resolver import resolve_project -from app.mcp.utils.markdown_parser import extract_toc, TocEntry -from app.models import SpecType, FinalSpec, FinalPromptPlan, BrainstormingPhase +from app.models import BrainstormingPhase, SpecType from app.services.spec_service import SpecService @@ -67,9 +68,7 @@ def get_toc( # Fallback: Legacy project-level specs (SpecVersion with project_id) if not content_markdown: spec_type = SpecType.SPECIFICATION if target == "spec" else SpecType.PROMPT_PLAN - active_spec = SpecService.get_active_spec( - db, project_id=project.id, spec_type=spec_type - ) + active_spec = SpecService.get_active_spec(db, project_id=project.id, spec_type=spec_type) if active_spec: content_markdown = active_spec.content_markdown content_json = active_spec.content_json @@ -79,12 +78,14 @@ def get_toc( # If content_json exists with sections, use it for structured TOC if content_json and "sections" in content_json: for section in content_json["sections"]: - toc.append({ - "id": section["id"], - "title": section["title"], - "has_content": bool(section.get("body_markdown")), - "linked_questions_count": len(section.get("linked_questions", [])) - }) + toc.append( + { + "id": section["id"], + "title": section["title"], + "has_content": bool(section.get("body_markdown")), + "linked_questions_count": len(section.get("linked_questions", [])), + } + ) else: # Fallback: parse markdown headers toc = extract_toc(content_markdown) diff --git a/backend/app/mcp/tools/vfs_cat.py b/backend/app/mcp/tools/vfs_cat.py index 9422f2b..99b215b 100644 --- a/backend/app/mcp/tools/vfs_cat.py +++ b/backend/app/mcp/tools/vfs_cat.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError +from app.mcp.vfs import PathNotFoundError, resolve_path from app.mcp.vfs.content import get_file_content @@ -58,11 +58,7 @@ def vfs_cat( # This ensures users see their own version, not global from app.mcp.vfs import NodeType - if ( - resolved.node_type == NodeType.GROUNDING_FILE - and resolved.grounding_filename == "agents.md" - and user_id - ): + if resolved.node_type == NodeType.GROUNDING_FILE and resolved.grounding_filename == "agents.md" and user_id: from app.services.grounding_service import GroundingService user_uuid = UUID(user_id) @@ -98,10 +94,7 @@ def vfs_cat( return {"error": str(e)} # Add note for agents.md global fallback (when no branch version exists) - if ( - resolved.node_type == NodeType.GROUNDING_FILE - and resolved.grounding_filename == "agents.md" - ): + if resolved.node_type == NodeType.GROUNDING_FILE and resolved.grounding_filename == "agents.md": result["note"] = ( "Always include branch_name when reading or writing this file. " "This is the global starting-point version. 
Your branch-specific " diff --git a/backend/app/mcp/tools/vfs_find.py b/backend/app/mcp/tools/vfs_find.py index 7906037..475b35b 100644 --- a/backend/app/mcp/tools/vfs_find.py +++ b/backend/app/mcp/tools/vfs_find.py @@ -6,12 +6,12 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError, NodeType, slugify, feature_dir_name, module_dir_name +from app.mcp.vfs import NodeType, PathNotFoundError, feature_dir_name, module_dir_name, resolve_path, slugify from app.mcp.vfs.content import list_directory from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureProvenance +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType from app.models.implementation import Implementation +from app.models.module import Module, ModuleType from app.services.grounding_service import GroundingService @@ -113,9 +113,7 @@ def _collect_all_paths( elif resolved.node_type == NodeType.SYSTEM_GENERATED_DIR: # /phases/system-generated/ lists all phases - phases = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project_id - ).all() + phases = db.query(BrainstormingPhase).filter(BrainstormingPhase.project_id == project_id).all() for phase in phases: phase_slug = slugify(phase.title) paths.extend(_collect_all_paths(db, project_id, f"/phases/system-generated/{phase_slug}")) @@ -143,22 +141,30 @@ def _collect_all_paths( elif resolved.node_type == NodeType.FEATURES_DIR: # System-generated features directory - modules = db.query(Module).filter( - Module.brainstorming_phase_id == resolved.phase_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == resolved.phase_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .all() + ) for module in modules: module_slug = module_dir_name(module.module_key, module.title) paths.extend(_collect_all_paths(db, project_id, f"{resolved.path}/{module_slug}")) elif resolved.node_type == NodeType.MODULE_DIR: # System-generated module directory (only IMPLEMENTATION features) - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) paths.extend(_collect_all_paths(db, project_id, f"{resolved.path}/{feature_slug}")) @@ -174,35 +180,43 @@ def _collect_all_paths( elif resolved.node_type == NodeType.USER_DEFINED_FEATURES_DIR: # /phases/user-defined/features/ lists modules with user-defined features - modules = db.query(Module).filter( - Module.project_id == project_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + Module.project_id == project_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .all() + ) for module in modules: # Check if this module has any user-defined features - has_user_features = db.query(Feature).filter( - Feature.module_id == module.id, - 
Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).first() + has_user_features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .first() + ) if has_user_features: module_slug = module_dir_name(module.module_key, module.title) paths.extend(_collect_all_paths(db, project_id, f"/phases/user-defined/features/{module_slug}")) elif resolved.node_type == NodeType.USER_DEFINED_MODULE_DIR: # /phases/user-defined/features/{module}/ lists user-defined features - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) paths.extend(_collect_all_paths(db, project_id, f"{resolved.path}/{feature_slug}")) diff --git a/backend/app/mcp/tools/vfs_grep.py b/backend/app/mcp/tools/vfs_grep.py index e39530d..84f479a 100644 --- a/backend/app/mcp/tools/vfs_grep.py +++ b/backend/app/mcp/tools/vfs_grep.py @@ -6,12 +6,12 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError, NodeType, slugify, feature_dir_name, module_dir_name +from app.mcp.vfs import NodeType, PathNotFoundError, feature_dir_name, module_dir_name, resolve_path, slugify from app.mcp.vfs.content import get_file_content, list_directory from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureProvenance +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType from app.models.implementation import Implementation +from app.models.module import Module, ModuleType from app.services.grounding_service import GroundingService @@ -116,9 +116,7 @@ def _collect_files(db: Session, project_id: UUID, path: str) -> List[str]: # /phases/system-generated/ lists all phases elif resolved.node_type == NodeType.SYSTEM_GENERATED_DIR: - phases = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project_id - ).all() + phases = db.query(BrainstormingPhase).filter(BrainstormingPhase.project_id == project_id).all() for phase in phases: phase_slug = slugify(phase.title) files.extend(_collect_files(db, project_id, f"/phases/system-generated/{phase_slug}")) @@ -149,22 +147,30 @@ def _collect_files(db: Session, project_id: UUID, path: str) -> List[str]: # Get features directory contents elif resolved.node_type == NodeType.FEATURES_DIR: - modules = db.query(Module).filter( - Module.brainstorming_phase_id == resolved.phase_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + 
Module.brainstorming_phase_id == resolved.phase_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .all() + ) for module in modules: module_slug = module_dir_name(module.module_key, module.title) files.extend(_collect_files(db, project_id, f"{resolved.path}/{module_slug}")) # Get module contents elif resolved.node_type == NodeType.MODULE_DIR: - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) files.extend(_collect_files(db, project_id, f"{resolved.path}/{feature_slug}")) @@ -180,35 +186,43 @@ def _collect_files(db: Session, project_id: UUID, path: str) -> List[str]: # /phases/user-defined/features/ lists modules with user-defined features elif resolved.node_type == NodeType.USER_DEFINED_FEATURES_DIR: - modules = db.query(Module).filter( - Module.project_id == project_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + Module.project_id == project_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .all() + ) for module in modules: # Check if this module has any user-defined features - has_user_features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).first() + has_user_features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .first() + ) if has_user_features: module_slug = module_dir_name(module.module_key, module.title) files.extend(_collect_files(db, project_id, f"/phases/user-defined/features/{module_slug}")) # /phases/user-defined/features/{module}/ lists user-defined features elif resolved.node_type == NodeType.USER_DEFINED_MODULE_DIR: - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) files.extend(_collect_files(db, project_id, f"{resolved.path}/{feature_slug}")) @@ -281,7 +295,7 @@ def _search_content( start = max(0, i - context_lines) end = min(len(lines), i + context_lines + 1) match["context_before"] = lines[start:i] - match["context_after"] = lines[i + 1:end] + match["context_after"] = lines[i + 1 : end] matches.append(match) 
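The _search_content hunk just above only re-spaces the slice lines[i + 1 : end], but the surrounding logic is the standard grep-style context window, clamped at both ends of the file. A minimal, runnable sketch of that pattern in plain Python — the name search_with_context and the result shape are illustrative, not the actual vfs_grep API:

    def search_with_context(text: str, needle: str, context_lines: int = 2) -> list[dict]:
        """Collect every matching line plus up to context_lines lines of context on each side."""
        lines = text.splitlines()
        matches = []
        for i, line in enumerate(lines):
            if needle in line:
                start = max(0, i - context_lines)             # clamp at top of file
                end = min(len(lines), i + context_lines + 1)  # clamp at bottom of file
                matches.append({
                    "line_number": i + 1,
                    "line": line,
                    "context_before": lines[start:i],
                    "context_after": lines[i + 1 : end],
                })
        return matches

For a three-line input "a\nb needle\nc", this returns a single match whose context_before is ["a"] and context_after is ["c"]; the max/min clamps keep matches on the first or last line from slicing out of range.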
diff --git a/backend/app/mcp/tools/vfs_head.py b/backend/app/mcp/tools/vfs_head.py index 83a8873..2ec257c 100644 --- a/backend/app/mcp/tools/vfs_head.py +++ b/backend/app/mcp/tools/vfs_head.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError +from app.mcp.vfs import PathNotFoundError, resolve_path from app.mcp.vfs.content import get_file_content diff --git a/backend/app/mcp/tools/vfs_ls.py b/backend/app/mcp/tools/vfs_ls.py index be9c153..7ec97e1 100644 --- a/backend/app/mcp/tools/vfs_ls.py +++ b/backend/app/mcp/tools/vfs_ls.py @@ -1,11 +1,11 @@ """VFS ls tool - list directory contents.""" -from typing import Any, Dict, Optional +from typing import Any, Dict from uuid import UUID from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError, NotADirectoryError +from app.mcp.vfs import PathNotFoundError, resolve_path from app.mcp.vfs.content import list_directory @@ -64,22 +64,13 @@ def vfs_ls( # If not long format, simplify entries if not long: - result["entries"] = [ - {"name": e["name"], "type": e["type"]} - for e in result["entries"] - ] + result["entries"] = [{"name": e["name"], "type": e["type"]} for e in result["entries"]] # If not all, filter hidden files if not all: - result["entries"] = [ - e for e in result["entries"] - if not e["name"].startswith(".") - ] + result["entries"] = [e for e in result["entries"] if not e["name"].startswith(".")] # Also update text output lines = result.get("text", "").split("\n") - result["text"] = "\n".join( - line for line in lines - if not any(part.startswith(".") for part in line.split()) - ) + result["text"] = "\n".join(line for line in lines if not any(part.startswith(".") for part in line.split())) return result diff --git a/backend/app/mcp/tools/vfs_sed.py b/backend/app/mcp/tools/vfs_sed.py index 2bdeacc..6c82e51 100644 --- a/backend/app/mcp/tools/vfs_sed.py +++ b/backend/app/mcp/tools/vfs_sed.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError, NodeType +from app.mcp.vfs import NodeType, PathNotFoundError, resolve_path from app.mcp.vfs.content import get_file_content from app.models.feature import Feature @@ -120,8 +120,8 @@ def _sed_feature_notes( db.refresh(feature) # Broadcast and trigger grounding update - from app.services.feature_service import FeatureService from app.mcp.tools.vfs_write import _trigger_grounding_update + from app.services.feature_service import FeatureService FeatureService._broadcast_feature_update(db, feature, "notes") _trigger_grounding_update(db, project_id, feature.id) @@ -147,8 +147,8 @@ def _sed_grounding_file( matches: int, ) -> Dict[str, Any]: """Apply sed to grounding file.""" - from app.services.grounding_service import GroundingService from app.mcp.tools.vfs_write import _trigger_grounding_summarize + from app.services.grounding_service import GroundingService filename = resolved.grounding_filename grounding_file = GroundingService.get_file(db, project_id, filename) diff --git a/backend/app/mcp/tools/vfs_set_metadata.py b/backend/app/mcp/tools/vfs_set_metadata.py index e1fc2ef..ecdfdb9 100644 --- a/backend/app/mcp/tools/vfs_set_metadata.py +++ b/backend/app/mcp/tools/vfs_set_metadata.py @@ -5,9 +5,9 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError -from app.mcp.vfs.metadata import set_metadata, COMPUTED_KEYS, SIDE_EFFECT_KEYS -from app.mcp.vfs.errors import PermissionDeniedError, InvalidPathError +from app.mcp.vfs 
import PathNotFoundError, resolve_path +from app.mcp.vfs.errors import InvalidPathError, PermissionDeniedError +from app.mcp.vfs.metadata import COMPUTED_KEYS, set_metadata def vfs_set_metadata( diff --git a/backend/app/mcp/tools/vfs_tail.py b/backend/app/mcp/tools/vfs_tail.py index 13c1ee7..7fa1100 100644 --- a/backend/app/mcp/tools/vfs_tail.py +++ b/backend/app/mcp/tools/vfs_tail.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError +from app.mcp.vfs import PathNotFoundError, resolve_path from app.mcp.vfs.content import get_file_content diff --git a/backend/app/mcp/tools/vfs_tree.py b/backend/app/mcp/tools/vfs_tree.py index a5419d7..0a0d71a 100644 --- a/backend/app/mcp/tools/vfs_tree.py +++ b/backend/app/mcp/tools/vfs_tree.py @@ -5,12 +5,20 @@ from sqlalchemy.orm import Session -from app.mcp.vfs import resolve_path, PathNotFoundError, NodeType, slugify, feature_dir_name, module_dir_name, ResolvedPath +from app.mcp.vfs import ( + NodeType, + PathNotFoundError, + ResolvedPath, + feature_dir_name, + module_dir_name, + resolve_path, + slugify, +) from app.mcp.vfs.content import list_directory from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureProvenance +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType from app.models.implementation import Implementation +from app.models.module import Module, ModuleType from app.services.grounding_service import GroundingService @@ -106,9 +114,7 @@ def _build_tree( for subdir in ["phases", "project-info", "system-info", "for-coding-agents"]: try: child_resolved = resolve_path(db, project_id, f"/{subdir}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -120,9 +126,7 @@ def _build_tree( for subdir in ["system-generated", "user-defined"]: try: child_resolved = resolve_path(db, project_id, f"/phases/{subdir}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -131,16 +135,17 @@ def _build_tree( elif resolved.node_type == NodeType.SYSTEM_GENERATED_DIR: # /phases/system-generated/ lists all phases - phases = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project_id - ).order_by(BrainstormingPhase.created_at).all() + phases = ( + db.query(BrainstormingPhase) + .filter(BrainstormingPhase.project_id == project_id) + .order_by(BrainstormingPhase.created_at) + .all() + ) for phase in phases: phase_slug = slugify(phase.title) try: child_resolved = resolve_path(db, project_id, f"/phases/system-generated/{phase_slug}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -151,9 +156,7 @@ def 
_build_tree( for subdir in ["phase-spec", "phase-prompt-plan", "features"]: try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{subdir}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -175,9 +178,7 @@ def _build_tree( if current_depth + 1 < max_depth: try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/by-section") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) # Replace the placeholder with actual tree children[0] = child_node stats["directories"] += child_stats["directories"] - 1 @@ -189,28 +190,33 @@ def _build_tree( try: result = list_directory(db, project_id, resolved) for entry in result.get("entries", []): - children.append({ - "name": entry["name"], - "type": "file", - }) + children.append( + { + "name": entry["name"], + "type": "file", + } + ) stats["files"] += 1 except Exception: pass elif resolved.node_type == NodeType.FEATURES_DIR: - modules = db.query(Module).filter( - Module.brainstorming_phase_id == resolved.phase_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).order_by(Module.order_index).all() + modules = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == resolved.phase_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .order_by(Module.order_index) + .all() + ) for module in modules: module_slug = module_dir_name(module.module_key, module.title) try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{module_slug}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -218,19 +224,22 @@ def _build_tree( pass elif resolved.node_type == NodeType.MODULE_DIR: - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).order_by(Feature.created_at).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .order_by(Feature.created_at) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{feature_slug}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -248,9 +257,7 @@ def _build_tree( for dirname in ["implementations", "conversations"]: try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{dirname}") - child_node, child_stats = _build_tree( - db, 
project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) # Replace placeholder with actual tree for i, c in enumerate(children): if c["name"] == dirname: @@ -265,9 +272,7 @@ def _build_tree( # /phases/user-defined/ has: features/ try: child_resolved = resolve_path(db, project_id, "/phases/user-defined/features") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -276,30 +281,35 @@ def _build_tree( elif resolved.node_type == NodeType.USER_DEFINED_FEATURES_DIR: # /phases/user-defined/features/ lists modules with user-defined features - modules = db.query(Module).filter( - Module.project_id == project_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).order_by(Module.order_index).all() + modules = ( + db.query(Module) + .filter( + Module.project_id == project_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .order_by(Module.order_index) + .all() + ) for module in modules: # Check if this module has any user-defined features - has_user_features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).first() + has_user_features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .first() + ) if has_user_features: module_slug = module_dir_name(module.module_key, module.title) try: child_resolved = resolve_path(db, project_id, f"/phases/user-defined/features/{module_slug}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -308,22 +318,23 @@ def _build_tree( elif resolved.node_type == NodeType.USER_DEFINED_MODULE_DIR: # /phases/user-defined/features/{module}/ lists user-defined features - features = db.query(Feature).filter( - Feature.module_id == resolved.module_id, - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ).order_by(Feature.created_at).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == resolved.module_id, + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) + .order_by(Feature.created_at) + .all() + ) for feature in features: feature_slug = feature_dir_name(feature.feature_key, feature.title) try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{feature_slug}") - child_node, child_stats = 
_build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -341,9 +352,7 @@ def _build_tree( for dirname in ["implementations", "conversations"]: try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{dirname}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) # Replace placeholder with actual tree for i, c in enumerate(children): if c["name"] == dirname: @@ -363,9 +372,7 @@ def _build_tree( # /system-info/ has: users/ try: child_resolved = resolve_path(db, project_id, "/system-info/users") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] @@ -396,9 +403,7 @@ def _build_tree( impl_slug = slugify(impl.name) try: child_resolved = resolve_path(db, project_id, f"{resolved.path}/{impl_slug}") - child_node, child_stats = _build_tree( - db, project_id, child_resolved, max_depth, current_depth + 1 - ) + child_node, child_stats = _build_tree(db, project_id, child_resolved, max_depth, current_depth + 1) children.append(child_node) stats["directories"] += child_stats["directories"] stats["files"] += child_stats["files"] diff --git a/backend/app/mcp/tools/vfs_write.py b/backend/app/mcp/tools/vfs_write.py index 820a8af..f80c25e 100644 --- a/backend/app/mcp/tools/vfs_write.py +++ b/backend/app/mcp/tools/vfs_write.py @@ -5,9 +5,7 @@ from sqlalchemy.orm import Session -from app.config import settings -from app.mcp.vfs import resolve_path, PathNotFoundError, NodeType -from app.mcp.vfs.errors import ReadOnlyError +from app.mcp.vfs import NodeType, PathNotFoundError, resolve_path from app.models.feature import Feature from app.models.implementation import Implementation @@ -77,29 +75,43 @@ def vfs_write( # Handle grounding files (/for-coding-agents/*) if resolved.node_type == NodeType.GROUNDING_FILE: return _write_grounding_file( - db, project_uuid, user_uuid, resolved, content, append, - branch_name=branch_name, repo_path=repo_path + db, project_uuid, user_uuid, resolved, content, append, branch_name=branch_name, repo_path=repo_path ) # Handle feature notes.md files (both system-generated and user-defined) - if resolved.node_type in (NodeType.FEATURE_FILE, NodeType.USER_DEFINED_FEATURE_FILE) and resolved.file_name == "notes.md": + if ( + resolved.node_type in (NodeType.FEATURE_FILE, NodeType.USER_DEFINED_FEATURE_FILE) + and resolved.file_name == "notes.md" + ): return _write_feature_notes( - db, project_uuid, user_uuid, resolved, content, append, - coding_agent_name, branch_name=branch_name, repo_path=repo_path + db, + project_uuid, + user_uuid, + resolved, + content, + append, + coding_agent_name, + branch_name=branch_name, + repo_path=repo_path, ) # Handle implementation notes.md files if resolved.node_type == NodeType.IMPLEMENTATION_NOTES_FILE: return _write_implementation_notes( - db, project_uuid, user_uuid, resolved, content, append, - coding_agent_name, branch_name=branch_name, repo_path=repo_path + db, + 
project_uuid, + user_uuid, + resolved, + content, + append, + coding_agent_name, + branch_name=branch_name, + repo_path=repo_path, ) # Handle conversations.md writes (create thread comments) if resolved.node_type == NodeType.CONVERSATIONS_FILE: - return _write_conversation_comment( - db, project_uuid, user_uuid, resolved, content, coding_agent_name - ) + return _write_conversation_comment(db, project_uuid, user_uuid, resolved, content, coding_agent_name) # All other files are read-only return { @@ -154,14 +166,14 @@ def _write_feature_notes( # Broadcast feature update to WebSocket clients from app.services.feature_service import FeatureService + FeatureService._broadcast_feature_update(db, feature, "notes") # Trigger grounding update job targeting user's branch-specific agents.md # Default to "main" branch to ensure user isolation effective_branch = branch_name or "main" _trigger_grounding_update( - db, project_id, feature.id, - user_id=user_id, branch_name=effective_branch, repo_path=repo_path + db, project_id, feature.id, user_id=user_id, branch_name=effective_branch, repo_path=repo_path ) return { @@ -194,9 +206,7 @@ def _write_implementation_notes( if not resolved.implementation_id: return {"error": "Implementation not found"} - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() if not impl: return {"error": "Implementation not found"} @@ -226,14 +236,19 @@ def _write_implementation_notes( # Broadcast implementation update to WebSocket clients from app.services.implementation_service import ImplementationService + ImplementationService.broadcast_implementation_updated(db, impl, "notes") # Trigger grounding update job targeting user's branch-specific agents.md # Default to "main" branch to ensure user isolation effective_branch = branch_name or "main" _trigger_grounding_update( - db, project_id, feature.id, - user_id=user_id, branch_name=effective_branch, repo_path=repo_path, + db, + project_id, + feature.id, + user_id=user_id, + branch_name=effective_branch, + repo_path=repo_path, implementation_id=impl.id, ) @@ -270,9 +285,9 @@ def _trigger_grounding_update( """ import logging + from app.models.job import JobType from app.services.job_service import JobService from app.services.project_service import ProjectService - from app.models.job import JobType from workers.core.helpers import publish_job_to_kafka logger = logging.getLogger(__name__) @@ -318,9 +333,7 @@ def _trigger_grounding_update( if success: branch_info = f", branch={branch_name}" if branch_name else "" - logger.info( - f"Triggered grounding update job {job.id} for feature {feature_id}{branch_info}" - ) + logger.info(f"Triggered grounding update job {job.id} for feature {feature_id}{branch_info}") else: logger.warning(f"Failed to publish grounding update job {job.id} to Kafka") @@ -340,9 +353,9 @@ def _trigger_grounding_summarize(db: Session, project_id: UUID) -> None: """ import logging + from app.models.job import JobType from app.services.job_service import JobService from app.services.project_service import ProjectService - from app.models.job import JobType from workers.core.helpers import publish_job_to_kafka logger = logging.getLogger(__name__) @@ -382,9 +395,7 @@ def _trigger_grounding_summarize(db: Session, project_id: UUID) -> None: logger.error(f"Failed to trigger grounding summarize: {e}") -def _trigger_grounding_branch_summarize( - db: Session, project_id: UUID, 
user_id: UUID, branch_name: str -) -> None: +def _trigger_grounding_branch_summarize(db: Session, project_id: UUID, user_id: UUID, branch_name: str) -> None: """ Trigger an async job to regenerate the summary for a branch-specific agents.md. @@ -396,9 +407,9 @@ def _trigger_grounding_branch_summarize( """ import logging + from app.models.job import JobType from app.services.job_service import JobService from app.services.project_service import ProjectService - from app.models.job import JobType from workers.core.helpers import publish_job_to_kafka logger = logging.getLogger(__name__) @@ -407,9 +418,7 @@ def _trigger_grounding_branch_summarize( # Get project to find org_id project = ProjectService.get_project_by_id(db, project_id) if not project: - logger.warning( - f"Could not trigger branch grounding summarize: project {project_id} not found" - ) + logger.warning(f"Could not trigger branch grounding summarize: project {project_id} not found") return # Create the job @@ -435,13 +444,10 @@ def _trigger_grounding_branch_summarize( if success: logger.info( - f"Triggered branch grounding summarize job {job.id} for " - f"project {project_id}, branch {branch_name}" + f"Triggered branch grounding summarize job {job.id} for project {project_id}, branch {branch_name}" ) else: - logger.warning( - f"Failed to publish branch grounding summarize job {job.id} to Kafka" - ) + logger.warning(f"Failed to publish branch grounding summarize job {job.id} to Kafka") except Exception as e: logger.error(f"Failed to trigger branch grounding summarize: {e}") @@ -474,15 +480,12 @@ def _write_grounding_file( if filename == "agents.md": effective_branch = branch_name or "main" branch_file = GroundingService.update_branch_file( - db, project_id, user_id, effective_branch, content, - append=append, repo_path=repo_path, filename=filename + db, project_id, user_id, effective_branch, content, append=append, repo_path=repo_path, filename=filename ) action = "appended" if append else "written" # Broadcast branch file update to WebSocket clients - GroundingService._broadcast_branch_grounding_update( - db, project_id, branch_file, action - ) + GroundingService._broadcast_branch_grounding_update(db, project_id, branch_file, action) # Trigger summarization for branch file _trigger_grounding_branch_summarize(db, project_id, user_id, effective_branch) @@ -500,9 +503,7 @@ def _write_grounding_file( if grounding_file: # Update existing file - grounding_file = GroundingService.update_file( - db, project_id, filename, content, append=append - ) + grounding_file = GroundingService.update_file(db, project_id, filename, content, append=append) action = "appended" if append else "written" else: # Create new file (only for valid extensions) @@ -511,9 +512,7 @@ def _write_grounding_file( "error": f"Invalid file extension for: {filename}", "hint": "Allowed: .md, .txt, .json, .yaml, .yml", } - grounding_file = GroundingService.create_file( - db, project_id, filename, content, user_id - ) + grounding_file = GroundingService.create_file(db, project_id, filename, content, user_id) action = "created" # Broadcast grounding file update to WebSocket clients @@ -551,12 +550,11 @@ def _write_conversation_comment( from sqlalchemy.orm.attributes import flag_modified - from app.services.thread_service import ThreadService - from app.services.project_share_service import ProjectShareService - from app.services.mention_utils import extract_user_mentions - from app.models.thread import Thread, ContextType + from app.models.thread import Thread from 
app.models.user import User - from app.models.project import Project + from app.services.mention_utils import extract_user_mentions + from app.services.project_share_service import ProjectShareService + from app.services.thread_service import ThreadService logger = logging.getLogger(__name__) @@ -677,8 +675,8 @@ def _get_or_create_feature_thread( user_id: UUID, ): """Get existing thread or create new one for a feature.""" - from app.models.thread import Thread, ContextType from app.models.feature import Feature + from app.models.thread import ContextType, Thread # Check if thread already exists thread = ( @@ -761,9 +759,13 @@ def run_async_safely(coro): for ref in image_refs: # Find staged image - submission = db.query(MCPImageSubmission).filter( - MCPImageSubmission.submission_id == ref, - ).first() + submission = ( + db.query(MCPImageSubmission) + .filter( + MCPImageSubmission.submission_id == ref, + ) + .first() + ) if not submission: return { @@ -786,12 +788,14 @@ def run_async_safely(coro): # Upload to S3 using async service try: + async def upload_to_s3(): # Create a fresh async engine for this event loop to avoid # "Future attached to a different loop" errors when running # in a thread pool with a new event loop via asyncio.run() - from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession + from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker + from app.config import settings db_url = str(settings.database_url) @@ -826,10 +830,7 @@ async def upload_to_s3(): processed_images.append(result.metadata.to_dict()) submissions_to_delete.append(submission) - logger.info( - f"Uploaded staged image to S3: submission_id={ref}, " - f"s3_key={result.metadata.s3_key}" - ) + logger.info(f"Uploaded staged image to S3: submission_id={ref}, s3_key={result.metadata.s3_key}") except Exception as e: logger.error(f"Failed to upload staged image {ref}: {e}", exc_info=True) diff --git a/backend/app/mcp/utils/markdown_parser.py b/backend/app/mcp/utils/markdown_parser.py index 6aea5d9..190ecdb 100644 --- a/backend/app/mcp/utils/markdown_parser.py +++ b/backend/app/mcp/utils/markdown_parser.py @@ -1,7 +1,7 @@ """Markdown parsing utilities for TOC extraction and section extraction.""" import re -from typing import List, Dict, TypedDict +from typing import List, TypedDict class TocEntry(TypedDict): @@ -27,10 +27,10 @@ def _heading_to_id(heading_text: str) -> str: text = heading_text.replace("'", "") # Replace special characters and spaces with hyphens - text = re.sub(r'[^a-zA-Z0-9\s-]', '', text) + text = re.sub(r"[^a-zA-Z0-9\s-]", "", text) # Convert to lowercase and replace spaces/multiple hyphens with single hyphen - text = re.sub(r'[\s-]+', '-', text.strip().lower()) + text = re.sub(r"[\s-]+", "-", text.strip().lower()) # Prefix with "sec-" return f"sec-{text}" @@ -59,7 +59,7 @@ def extract_toc(markdown: str) -> List[TocEntry]: return [] # Regex to match markdown headings: ^(#{1,6})\s+(.+)$ - heading_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE) + heading_pattern = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE) matches = heading_pattern.findall(markdown) if not matches: @@ -138,10 +138,10 @@ def extract_section(markdown: str, section_id: str) -> str: # We extract from the target heading until: # - The next heading at any level (for "direct content only" semantics) # - OR end of document - lines = markdown.split('\n') + lines = markdown.split("\n") section_lines: List[str] = [] in_section = False - heading_marker = '#' * 
target_entry["level"] + heading_marker = "#" * target_entry["level"] heading_line = f"{heading_marker} {target_entry['title']}" for line in lines: @@ -154,7 +154,7 @@ def extract_section(markdown: str, section_id: str) -> str: # If in section, collect lines until ANY next heading if in_section: # Check if this line is ANY heading - if line.strip().startswith('#') and re.match(r'^#{1,6}\s+.+$', line.strip()): + if line.strip().startswith("#") and re.match(r"^#{1,6}\s+.+$", line.strip()): # This is a heading - stop collecting break @@ -163,4 +163,4 @@ def extract_section(markdown: str, section_id: str) -> str: if not section_lines: raise ValueError(f"Section '{section_id}' not found in markdown") - return '\n'.join(section_lines) + return "\n".join(section_lines) diff --git a/backend/app/mcp/utils/project_resolver.py b/backend/app/mcp/utils/project_resolver.py index 938a565..6dc459b 100644 --- a/backend/app/mcp/utils/project_resolver.py +++ b/backend/app/mcp/utils/project_resolver.py @@ -2,6 +2,7 @@ from typing import Optional from uuid import UUID + from sqlalchemy.orm import Session from app.models import Project diff --git a/backend/app/mcp/vfs/__init__.py b/backend/app/mcp/vfs/__init__.py index 1033f75..60cee13 100644 --- a/backend/app/mcp/vfs/__init__.py +++ b/backend/app/mcp/vfs/__init__.py @@ -35,21 +35,21 @@ """ from app.mcp.vfs.errors import ( - VFSError, - PathNotFoundError, + InvalidPathError, NotADirectoryError, NotAFileError, + PathNotFoundError, PermissionDeniedError, - InvalidPathError, + VFSError, ) from app.mcp.vfs.path_resolver import ( NodeType, ResolvedPath, - resolve_path, - slugify, + build_path, feature_dir_name, module_dir_name, - build_path, + resolve_path, + slugify, ) __all__ = [ diff --git a/backend/app/mcp/vfs/content.py b/backend/app/mcp/vfs/content.py index 19950a9..b2306ff 100644 --- a/backend/app/mcp/vfs/content.py +++ b/backend/app/mcp/vfs/content.py @@ -7,23 +7,23 @@ from sqlalchemy.orm import Session -from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureProvenance -from app.models.implementation import Implementation -from app.models.project import Project -from app.models.org_membership import OrgMembership -from app.models.spec_version import SpecVersion, SpecType -from app.models.user import User +from app.mcp.utils.markdown_parser import extract_section, extract_toc +from app.mcp.vfs.errors import PathNotFoundError from app.mcp.vfs.path_resolver import ( NodeType, ResolvedPath, - slugify, feature_dir_name, module_dir_name, + slugify, ) -from app.mcp.vfs.errors import PathNotFoundError -from app.mcp.utils.markdown_parser import extract_toc, extract_section +from app.models.brainstorming_phase import BrainstormingPhase +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType +from app.models.implementation import Implementation +from app.models.module import Module, ModuleType +from app.models.org_membership import OrgMembership +from app.models.project import Project +from app.models.spec_version import SpecType, SpecVersion +from app.models.user import User from app.services.team_role_service import TeamRoleService @@ -111,7 +111,7 @@ def _resolve_image_references(markdown: str) -> str: from app.services.image_service import ImageService # Pattern matches IMAGE_REF:uuid-format - pattern = r'IMAGE_REF:([a-f0-9-]+)' + pattern = r"IMAGE_REF:([a-f0-9-]+)" def replacer(match): image_id = 
match.group(1) @@ -331,9 +331,7 @@ def replacer(match): } -def _wrap_prompt_plan_with_instructions( - content: str, feature: Feature, resolved: ResolvedPath -) -> str: +def _wrap_prompt_plan_with_instructions(content: str, feature: Feature, resolved: ResolvedPath) -> str: """Wrap prompt plan with preamble/postscript instructions for coding agents.""" feature_dir = resolved.path.rsplit("/", 1)[0] + "/" @@ -466,13 +464,9 @@ def get_file_content( raise ValueError(f"Cannot get file content for node type: {resolved.node_type}") -def _get_full_document_content( - db: Session, resolved: ResolvedPath -) -> Dict[str, Any]: +def _get_full_document_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for full.md files (spec or prompt plan).""" - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == resolved.phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == resolved.phase_id).first() if not phase: raise PathNotFoundError(resolved.path) @@ -483,9 +477,7 @@ def _get_full_document_content( # Resolve any IMAGE_REF:xxx references to signed URLs content = _resolve_image_references(content_markdown) # Append any phase description images that weren't embedded by the LLM - content = _append_description_images( - content, phase.description_image_attachments, "phase description" - ) + content = _append_description_images(content, phase.description_image_attachments, "phase description") elif resolved.document_type == "spec": content = "# No specification available\n\nThis phase does not have a specification yet." else: @@ -502,15 +494,11 @@ def _get_full_document_content( } -def _get_summary_document_content( - db: Session, resolved: ResolvedPath -) -> Dict[str, Any]: +def _get_summary_document_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for summary.md files (spec only).""" from app.services.brainstorming_phase_service import _build_spec_summary_from_json - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == resolved.phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == resolved.phase_id).first() if not phase: raise PathNotFoundError(resolved.path) @@ -523,6 +511,7 @@ def _get_summary_document_content( # Create a simple object with content_json attribute to use existing function class ContentHolder: pass + holder = ContentHolder() holder.content_json = content_json content = _build_spec_summary_from_json(holder) @@ -543,9 +532,7 @@ class ContentHolder: def _get_section_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for section files (by-section/*.md).""" - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == resolved.phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == resolved.phase_id).first() if not phase: raise PathNotFoundError(resolved.path) @@ -601,17 +588,13 @@ def _get_feature_file_content(db: Session, resolved: ResolvedPath) -> Dict[str, # Resolve any IMAGE_REF:xxx references to signed URLs content = _resolve_image_references(content) # Append any feature description images that weren't embedded - content = _append_description_images( - content, feature.description_image_attachments, "feature" - ) + content = _append_description_images(content, feature.description_image_attachments, "feature") elif resolved.file_name == "prompt_plan.md": if feature.prompt_plan_text: # Resolve image references before wrapping with instructions 
resolved_text = _resolve_image_references(feature.prompt_plan_text) # Append any feature description images that weren't embedded - resolved_text = _append_description_images( - resolved_text, feature.description_image_attachments, "feature" - ) + resolved_text = _append_description_images(resolved_text, feature.description_image_attachments, "feature") content = _wrap_prompt_plan_with_instructions(resolved_text, feature, resolved) else: content = "# No prompt plan\n\nThis feature does not have a prompt plan." @@ -633,9 +616,7 @@ def _get_feature_file_content(db: Session, resolved: ResolvedPath) -> Dict[str, } -def list_directory( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def list_directory(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """ List contents of a virtual directory. @@ -691,9 +672,7 @@ def list_directory( raise ValueError(f"Cannot list directory for node type: {resolved.node_type}") -def _list_root( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_root(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """List root directory (top-level directories).""" entries = [ {"name": "phases/", "type": "d", "description": "System-generated and user-defined features"}, @@ -717,16 +696,10 @@ def _list_root( } -def _list_phases_dir( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_phases_dir(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """List /phases/ directory (system-generated and user-defined).""" # Count system-generated phases - phase_count = ( - db.query(BrainstormingPhase) - .filter(BrainstormingPhase.project_id == project_id) - .count() - ) + phase_count = db.query(BrainstormingPhase).filter(BrainstormingPhase.project_id == project_id).count() # Count user-defined features user_feature_count = ( @@ -739,10 +712,7 @@ def _list_phases_dir( Feature.feature_type == FeatureType.IMPLEMENTATION, Feature.status == FeatureStatus.ACTIVE, ) - .filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) .count() ) @@ -777,9 +747,7 @@ def _list_phases_dir( } -def _list_system_generated_dir( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_system_generated_dir(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """List /phases/system-generated/ directory (all phases).""" phases = ( db.query(BrainstormingPhase) @@ -798,14 +766,16 @@ def _list_system_generated_dir( progress = PhaseProgressService.get_phase_progress(db, phase.id) - entries.append({ - "name": f"{phase_slug}/", - "type": "d", - "title": phase.title, - "total_features": progress.total_features, - "completed_features": progress.completed_features, - "progress": progress.progress_percent, - }) + entries.append( + { + "name": f"{phase_slug}/", + "type": "d", + "title": phase.title, + "total_features": progress.total_features, + "completed_features": progress.completed_features, + "progress": progress.progress_percent, + } + ) text_lines.append(f"drwxr-xr-x {phase_slug}/") return { @@ -818,9 +788,7 @@ def _list_system_generated_dir( } -def _list_user_defined_dir( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_user_defined_dir(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: 
"""List /phases/user-defined/ directory.""" # Count user-defined features user_features = ( @@ -833,10 +801,7 @@ def _list_user_defined_dir( Feature.feature_type == FeatureType.IMPLEMENTATION, Feature.status == FeatureStatus.ACTIVE, ) - .filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) .all() ) @@ -872,9 +837,7 @@ def _list_user_defined_dir( } -def _list_user_defined_features_dir( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_user_defined_features_dir(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """List /phases/user-defined/features/ directory (modules with user-defined features).""" # Get all modules in project modules = ( @@ -904,10 +867,7 @@ def _list_user_defined_features_dir( Feature.feature_type == FeatureType.IMPLEMENTATION, Feature.status == FeatureStatus.ACTIVE, ) - .filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) .all() ) @@ -915,22 +875,22 @@ def _list_user_defined_features_dir( continue # Skip modules with no user-defined features feature_count = len(features) - completed = sum( - 1 for f in features if f.completion_status and f.completion_status.value == "completed" - ) + completed = sum(1 for f in features if f.completion_status and f.completion_status.value == "completed") progress = (completed / feature_count * 100) if feature_count > 0 else 0 total_features += feature_count completed_features += completed - entries.append({ - "name": f"{module_slug}/", - "type": "d", - "title": module.title, - "feature_count": feature_count, - "completed": completed, - "progress": round(progress, 1), - }) + entries.append( + { + "name": f"{module_slug}/", + "type": "d", + "title": module.title, + "feature_count": feature_count, + "completed": completed, + "progress": round(progress, 1), + } + ) text_lines.append(f"drwxr-xr-x {module_slug}/") overall_progress = (completed_features / total_features * 100) if total_features > 0 else 0 @@ -959,10 +919,7 @@ def _list_user_defined_module(db: Session, resolved: ResolvedPath) -> Dict[str, Feature.feature_type == FeatureType.IMPLEMENTATION, Feature.status == FeatureStatus.ACTIVE, ) - .filter( - (Feature.provenance == FeatureProvenance.USER) - | (Feature.external_provider.isnot(None)) - ) + .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None))) .order_by(Feature.created_at) .all() ) @@ -980,16 +937,18 @@ def _list_user_defined_module(db: Session, resolved: ResolvedPath) -> Dict[str, if next_feature is None and status == "pending": next_feature = feature.feature_key - entries.append({ - "name": f"{feature_slug}/", - "type": "d", - "feature_key": feature.feature_key, - "title": feature.title, - "status": status, - "priority": priority, - "provenance": provenance, - "external_provider": feature.external_provider, - }) + entries.append( + { + "name": f"{feature_slug}/", + "type": "d", + "feature_key": feature.feature_key, + "title": feature.title, + "status": status, + "priority": priority, + "provenance": provenance, + "external_provider": feature.external_provider, + } + ) text_lines.append(f"drwxr-xr-x {feature_slug}/") completed = sum(1 for f in features if f.completion_status and f.completion_status.value == "completed") @@ 
-1011,7 +970,7 @@ def _list_user_defined_module(db: Session, resolved: ResolvedPath) -> Dict[str, def _list_user_defined_feature(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """List user-defined feature directory (implementations/, conversations/).""" - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem feature = db.query(Feature).filter(Feature.id == resolved.feature_id).first() @@ -1019,11 +978,7 @@ def _list_user_defined_feature(db: Session, resolved: ResolvedPath) -> Dict[str, # Count implementations implementation_count = 0 if feature: - implementation_count = ( - db.query(Implementation) - .filter(Implementation.feature_id == feature.id) - .count() - ) + implementation_count = db.query(Implementation).filter(Implementation.feature_id == feature.id).count() # Check if feature has a conversation thread with items has_conversation = False @@ -1037,11 +992,7 @@ def _list_user_defined_feature(db: Session, resolved: ResolvedPath) -> Dict[str, .first() ) if thread: - item_count = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == thread.id) - .count() - ) + item_count = db.query(ThreadItem).filter(ThreadItem.thread_id == thread.id).count() has_conversation = item_count > 0 entries = [ @@ -1187,14 +1138,16 @@ def _get_team_info_json(db: Session, project_id: UUID) -> Dict[str, Any]: } for a in assignments ] - roles.append({ - "id": str(role_def.id), - "role_key": role_def.role_key, - "title": role_def.title, - "description": role_def.description, - "is_default": role_def.is_default, - "members": members, - }) + roles.append( + { + "id": str(role_def.id), + "role_key": role_def.role_key, + "title": role_def.title, + "description": role_def.description, + "is_default": role_def.is_default, + "members": members, + } + ) content = json.dumps({"roles": roles}, indent=2) @@ -1220,9 +1173,7 @@ def _list_phase(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: ] # Get phase metadata - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == resolved.phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == resolved.phase_id).first() from app.services.phase_progress_service import PhaseProgressService @@ -1273,9 +1224,7 @@ def _list_document_dir(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: def _list_sections(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """List by-section/ directory with all available sections.""" - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == resolved.phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == resolved.phase_id).first() sections = [] @@ -1290,29 +1239,35 @@ def _list_sections(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: # Try structured content first if content_json and isinstance(content_json, dict) and "sections" in content_json: for section in content_json["sections"]: - sections.append({ - "id": section.get("id", "unknown"), - "title": section.get("title", "Untitled"), - }) + sections.append( + { + "id": section.get("id", "unknown"), + "title": section.get("title", "Untitled"), + } + ) elif full_content: # Fall back to markdown parsing toc = extract_toc(full_content) for entry in toc: - sections.append({ - "id": entry["id"], - "title": entry["title"], - }) + sections.append( + { + "id": entry["id"], + "title": entry["title"], + } + ) entries = [] text_lines = [] for section in sections: file_name = 
f"{section['id']}.md" - entries.append({ - "name": file_name, - "type": "f", - "title": section["title"], - }) + entries.append( + { + "name": file_name, + "type": "f", + "title": section["title"], + } + ) text_lines.append(f"-rw-r--r-- {file_name}") return { @@ -1338,14 +1293,16 @@ def _list_features_dir(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: for mod in phase_progress.modules: module_slug = module_dir_name(mod.module_key, mod.title) - entries.append({ - "name": f"{module_slug}/", - "type": "d", - "title": mod.title, - "feature_count": mod.total_features, - "completed": mod.completed_features, - "progress": mod.progress_percent, - }) + entries.append( + { + "name": f"{module_slug}/", + "type": "d", + "title": mod.title, + "feature_count": mod.total_features, + "completed": mod.completed_features, + "progress": mod.progress_percent, + } + ) text_lines.append(f"drwxr-xr-x {module_slug}/") return { @@ -1378,9 +1335,7 @@ def _list_module(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: from app.services.phase_progress_service import PhaseProgressService - total, completed, pending, in_prog, pct, next_feature = ( - PhaseProgressService.compute_feature_stats(features) - ) + total, completed, pending, in_prog, pct, next_feature = PhaseProgressService.compute_feature_stats(features) entries = [] text_lines = [] @@ -1391,16 +1346,18 @@ def _list_module(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: priority = feature.priority.value if feature.priority else "important" provenance = feature.provenance.value if feature.provenance else "system" - entries.append({ - "name": f"{feature_slug}/", - "type": "d", - "feature_key": feature.feature_key, - "title": feature.title, - "status": f_status, - "priority": priority, - "provenance": provenance, - "external_provider": feature.external_provider, - }) + entries.append( + { + "name": f"{feature_slug}/", + "type": "d", + "feature_key": feature.feature_key, + "title": feature.title, + "status": f_status, + "priority": priority, + "provenance": provenance, + "external_provider": feature.external_provider, + } + ) text_lines.append(f"drwxr-xr-x {feature_slug}/") return { @@ -1419,7 +1376,7 @@ def _list_module(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: def _list_feature(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """List feature directory (implementations/, conversations/).""" - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem feature = db.query(Feature).filter(Feature.id == resolved.feature_id).first() @@ -1427,11 +1384,7 @@ def _list_feature(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: # Count implementations implementation_count = 0 if feature: - implementation_count = ( - db.query(Implementation) - .filter(Implementation.feature_id == feature.id) - .count() - ) + implementation_count = db.query(Implementation).filter(Implementation.feature_id == feature.id).count() # Check if feature has a conversation thread with items has_conversation = False @@ -1445,11 +1398,7 @@ def _list_feature(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: .first() ) if thread: - item_count = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == thread.id) - .count() - ) + item_count = db.query(ThreadItem).filter(ThreadItem.thread_id == thread.id).count() has_conversation = item_count > 0 entries = [ @@ -1484,9 +1433,7 @@ def _list_feature(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: } -def 
_list_grounding_dir( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_grounding_dir(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """List /for-coding-agents/ directory.""" from app.services.grounding_service import GroundingService @@ -1497,22 +1444,26 @@ for gf in files: permissions = "-rw-r--r--" if gf.is_protected else "-rw-rw-rw-" - entries.append({ - "name": gf.filename, - "type": "f", - "is_protected": gf.is_protected, - "updated_at": gf.updated_at.isoformat() if gf.updated_at else None, - "writable": True, - }) + entries.append( + { + "name": gf.filename, + "type": "f", + "is_protected": gf.is_protected, + "updated_at": gf.updated_at.isoformat() if gf.updated_at else None, + "writable": True, + } + ) text_lines.append(f"{permissions} {gf.filename}") # Add static mfbt-usage-guide directory - entries.append({ - "name": "mfbt-usage-guide/", - "type": "d", - "description": "MFBT usage guides for coding agents", - "is_static": True, - }) + entries.append( + { + "name": "mfbt-usage-guide/", + "type": "d", + "description": "MFBT usage guides for coding agents", + "is_static": True, + } + ) text_lines.append("drwxr-xr-x mfbt-usage-guide/") return { @@ -1526,9 +1477,7 @@ } -def _get_grounding_file_content( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _get_grounding_file_content(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for grounding files (/for-coding-agents/*).""" from app.services.grounding_service import GroundingService @@ -1615,6 +1564,7 @@ def _get_extension_from_content_type(content_type: str) -> str: def _transform_mentions_for_vfs(body: str) -> str: """Transform @[Name](uuid) to @Name (user_id: uuid) for coding agents.""" import re + # Match @[Name](id) where id can be any alphanumeric+hyphen string pattern = re.compile(r"@\[([^\]]+)\]\(([a-zA-Z0-9-]+)\)") return pattern.sub(r"@\1 (user_id: \2)", body) @@ -1645,9 +1595,7 @@ def _get_thread_images(db: Session, thread_id) -> List[dict]: return images -def _list_conversations_dir( - db: Session, resolved: ResolvedPath -) -> Dict[str, Any]: +def _list_conversations_dir(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """List {feature}/conversations/ directory. Note: Images are now embedded directly in conversations.md as markdown @@ -1673,11 +1621,7 @@ def _list_conversations_dir( if resolved.thread_id: thread = db.query(Thread).filter(Thread.id == resolved.thread_id).first() if thread: - item_count = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == resolved.thread_id) - .count() - ) + item_count = db.query(ThreadItem).filter(ThreadItem.thread_id == resolved.thread_id).count() return { "path": resolved.path, @@ -1707,9 +1651,10 @@ def _get_conversations_file_content( The upload token is only generated when all of project_id, user_id, and resolved.feature_id are available.
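    An illustrative sketch of the no-thread fallback shape (the "content"
    value appears verbatim in the fallback branch below; the other keys are
    assumed to follow the same pattern as the other content getters):

        {
            "path": resolved.path,
            "type": "file",
            "content": "# No Conversation\n\nNo discussion has occurred for this feature yet.",
        }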
""" + from sqlalchemy.orm import joinedload + from app.models.thread import Thread from app.models.thread_item import ThreadItem, ThreadItemType - from sqlalchemy.orm import joinedload if not resolved.thread_id: return { @@ -1718,11 +1663,7 @@ def _get_conversations_file_content( "content": "# No Conversation\n\nNo discussion has occurred for this feature yet.", } - thread = ( - db.query(Thread) - .filter(Thread.id == resolved.thread_id) - .first() - ) + thread = db.query(Thread).filter(Thread.id == resolved.thread_id).first() if not thread: return { @@ -1782,9 +1723,7 @@ def _get_conversations_file_content( width = img.get("width", 0) height = img.get("height", 0) if width and height: - markdown_lines.append( - f"![{original_filename} ({width}x{height})]({signed_url})\n\n" - ) + markdown_lines.append(f"![{original_filename} ({width}x{height})]({signed_url})\n\n") else: markdown_lines.append(f"![{original_filename}]({signed_url})\n\n") @@ -1823,8 +1762,8 @@ def _get_conversations_file_content( # Generate upload token if we have the required context if project_id and user_id and resolved.feature_id: - from app.services.image_service import ImageService from app.config import settings + from app.services.image_service import ImageService upload_token = ImageService.generate_upload_token( project_id=project_id, @@ -1834,10 +1773,10 @@ def _get_conversations_file_content( base_url = settings.base_url.rstrip("/") markdown_lines.append("1. (Optional) Upload images via API:\n") - markdown_lines.append(f" curl -X POST \"{base_url}/api/v1/mcp/images/upload\" \\\n") - markdown_lines.append(f" -H \"Authorization: Bearer {upload_token}\" \\\n") - markdown_lines.append(" -F \"file=@/path/to/screenshot.png\"\n\n") - markdown_lines.append(" Response: {\"image_id\": \"abc123\", \"expires_in_hours\": 1}\n\n") + markdown_lines.append(f' curl -X POST "{base_url}/api/v1/mcp/images/upload" \\\n') + markdown_lines.append(f' -H "Authorization: Bearer {upload_token}" \\\n') + markdown_lines.append(' -F "file=@/path/to/screenshot.png"\n\n') + markdown_lines.append(' Response: {"image_id": "abc123", "expires_in_hours": 1}\n\n') markdown_lines.append(" Supported formats: png, jpg, jpeg, gif, webp (max 10MB)\n") markdown_lines.append(" Token valid for 1 hour.\n\n") else: @@ -1845,7 +1784,9 @@ def _get_conversations_file_content( markdown_lines.append("2. Post comment (with optional image references):\n") markdown_lines.append("```json\n") - markdown_lines.append('{"action": "add_comment", "body_markdown": "Your text. Use @[Name](user_id) for mentions.", "images": [""], "coding_agent_name": "claude_code"}\n') + markdown_lines.append( + '{"action": "add_comment", "body_markdown": "Your text. 
Use @[Name](user_id) for mentions.", "images": [""], "coding_agent_name": "claude_code"}\n' + ) markdown_lines.append("```\n\n") markdown_lines.append("Notes:\n") markdown_lines.append("- Get user_id from /project-info/team-info.json for mentions\n") @@ -1886,16 +1827,18 @@ def _list_implementations_dir(db: Session, resolved: ResolvedPath) -> Dict[str, has_prompt_plan = bool(impl.prompt_plan_text) has_notes = bool(impl.implementation_notes) - entries.append({ - "name": f"{impl_slug}/", - "type": "d", - "title": impl.name, - "is_complete": impl.is_complete, - "is_primary": impl.is_primary, - "has_spec": has_spec, - "has_prompt_plan": has_prompt_plan, - "has_notes": has_notes, - }) + entries.append( + { + "name": f"{impl_slug}/", + "type": "d", + "title": impl.name, + "is_complete": impl.is_complete, + "is_primary": impl.is_primary, + "has_spec": has_spec, + "has_prompt_plan": has_prompt_plan, + "has_notes": has_notes, + } + ) text_lines.append(f"drwxr-xr-x {impl_slug}/") # Count completed implementations @@ -1915,9 +1858,7 @@ def _list_implementations_dir(db: Session, resolved: ResolvedPath) -> Dict[str, def _list_implementation(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """List {feature}/implementations/{name}/ directory.""" - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() has_spec = bool(impl.spec_text) if impl else False has_prompt_plan = bool(impl.prompt_plan_text) if impl else False @@ -1954,9 +1895,7 @@ def _list_implementation(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: def _get_implementation_spec_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for implementation spec.md files.""" - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() if not impl: raise PathNotFoundError(resolved.path) @@ -1968,9 +1907,7 @@ def _get_implementation_spec_content(db: Session, resolved: ResolvedPath) -> Dic content = _resolve_image_references(impl.spec_text) # Append any feature description images that weren't embedded if feature: - content = _append_description_images( - content, feature.description_image_attachments, "feature" - ) + content = _append_description_images(content, feature.description_image_attachments, "feature") else: content = "# No specification\n\nThis implementation does not have a specification yet." 
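# A minimal, self-contained sketch of the IMAGE_REF substitution that the
# content getters above rely on. The regex is the one this patch shows in
# _resolve_image_references; `get_signed_url` is a hypothetical stand-in for
# the real lookup (production code delegates to app.services.image_service).
import re

IMAGE_REF_PATTERN = re.compile(r"IMAGE_REF:([a-f0-9-]+)")

def resolve_image_refs(markdown: str, get_signed_url) -> str:
    # Swap each IMAGE_REF:<uuid> token for a markdown image; keep unknown refs as-is.
    def replacer(match: re.Match) -> str:
        url = get_signed_url(match.group(1))
        return f"![image]({url})" if url else match.group(0)
    return IMAGE_REF_PATTERN.sub(replacer, markdown)

# e.g. resolve_image_refs("see IMAGE_REF:abc-123", lambda _id: None) returns
# the text unchanged, since no signed URL is available for the reference.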
@@ -1988,9 +1925,7 @@ def _get_implementation_spec_content(db: Session, resolved: ResolvedPath) -> Dic def _get_implementation_prompt_plan_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for implementation prompt_plan.md files.""" - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() if not impl: raise PathNotFoundError(resolved.path) @@ -2003,9 +1938,7 @@ def _get_implementation_prompt_plan_content(db: Session, resolved: ResolvedPath) resolved_text = _resolve_image_references(impl.prompt_plan_text) # Append any feature description images that weren't embedded if feature: - resolved_text = _append_description_images( - resolved_text, feature.description_image_attachments, "feature" - ) + resolved_text = _append_description_images(resolved_text, feature.description_image_attachments, "feature") content = _wrap_prompt_plan_with_instructions(resolved_text, feature, resolved) else: content = "# No prompt plan\n\nThis implementation does not have a prompt plan yet." @@ -2024,9 +1957,7 @@ def _get_implementation_prompt_plan_content(db: Session, resolved: ResolvedPath) def _get_implementation_notes_content(db: Session, resolved: ResolvedPath) -> Dict[str, Any]: """Get content for implementation notes.md files.""" - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() if not impl: raise PathNotFoundError(resolved.path) diff --git a/backend/app/mcp/vfs/errors.py b/backend/app/mcp/vfs/errors.py index f23d2ec..5df72a9 100644 --- a/backend/app/mcp/vfs/errors.py +++ b/backend/app/mcp/vfs/errors.py @@ -17,8 +17,7 @@ def __init__(self, path: str, available: Optional[List[str]] = None): self.available = available or [] if available: super().__init__( - f"Path not found: {path}. Available: {', '.join(available[:5])}" - + ("..." if len(available) > 5 else "") + f"Path not found: {path}. Available: {', '.join(available[:5])}" + ("..." if len(available) > 5 else "") ) else: super().__init__(f"Path not found: {path}") diff --git a/backend/app/mcp/vfs/metadata.py b/backend/app/mcp/vfs/metadata.py index cfd6624..b798dc8 100644 --- a/backend/app/mcp/vfs/metadata.py +++ b/backend/app/mcp/vfs/metadata.py @@ -1,15 +1,14 @@ """VFS metadata management - computed and agent-settable metadata.""" from datetime import datetime, timezone -from typing import Any, Dict, List, Optional +from typing import Any, Dict from uuid import UUID from sqlalchemy.orm import Session -from app.models.feature import Feature, FeatureCompletionStatus -from app.mcp.vfs.path_resolver import NodeType, ResolvedPath from app.mcp.vfs.errors import InvalidPathError, PermissionDeniedError - +from app.mcp.vfs.path_resolver import NodeType, ResolvedPath +from app.models.feature import Feature, FeatureCompletionStatus # Reserved metadata keys that are computed and cannot be set COMPUTED_KEYS = { @@ -42,16 +41,14 @@ } -def get_metadata( - db: Session, project_id: UUID, resolved: ResolvedPath, user_id: UUID -) -> Dict[str, Any]: +def get_metadata(db: Session, project_id: UUID, resolved: ResolvedPath, user_id: UUID) -> Dict[str, Any]: """ Get all metadata for a path (computed + agent-set). This is already included in list_directory and get_file_content responses, but this function can be used for direct metadata access. 
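    Example (an illustrative sketch; the exact keys mirror whichever of
    those two responses applies, and the path shown is hypothetical):

        meta = get_metadata(db, project_id, resolved, user_id)
        meta["path"]  # e.g. "/phases/system-generated/phase-1-auth/"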
""" - from app.mcp.vfs.content import list_directory, get_file_content + from app.mcp.vfs.content import get_file_content, list_directory if resolved.is_directory: result = list_directory(db, project_id, resolved) @@ -99,17 +96,11 @@ def set_metadata( """ # Check for computed keys if key in COMPUTED_KEYS: - raise PermissionDeniedError( - resolved.path, - f"set computed metadata key '{key}'" - ) + raise PermissionDeniedError(resolved.path, f"set computed metadata key '{key}'") # Only allow metadata on directories (phases, modules, features) if not resolved.is_directory: - raise InvalidPathError( - resolved.path, - "metadata can only be set on directories" - ) + raise InvalidPathError(resolved.path, "metadata can only be set on directories") # Check for side effect keys if key in SIDE_EFFECT_KEYS: @@ -130,10 +121,7 @@ def _set_completion_status( ) -> Dict[str, Any]: """Set completion_status on a feature directory.""" if resolved.node_type not in (NodeType.FEATURE_DIR, NodeType.USER_DEFINED_FEATURE_DIR): - raise InvalidPathError( - resolved.path, - "completion_status can only be set on feature directories" - ) + raise InvalidPathError(resolved.path, "completion_status can only be set on feature directories") if not resolved.feature_id: raise InvalidPathError(resolved.path, "feature_id not found") @@ -152,8 +140,7 @@ def _set_completion_status( status_str = str(value).lower() if status_str not in status_map: raise InvalidPathError( - resolved.path, - f"invalid completion_status '{value}'. Valid values: pending, in_progress, completed" + resolved.path, f"invalid completion_status '{value}'. Valid values: pending, in_progress, completed" ) new_status = status_map[status_str] @@ -161,14 +148,13 @@ def _set_completion_status( # If setting to completed, guide agent to use is_complete on implementations instead if new_status == FeatureCompletionStatus.COMPLETED: from app.models.implementation import Implementation - impl_count = db.query(Implementation).filter( - Implementation.feature_id == feature.id - ).count() + + impl_count = db.query(Implementation).filter(Implementation.feature_id == feature.id).count() if impl_count > 0: return { "error": "Do not set completion_status to 'completed' directly", "hint": "Use setMetadataValueForKey on the implementation directory with key 'is_complete' and value 'true'. 
Feature status auto-syncs from implementations.", - "example": f"setMetadataValueForKey .../implementations/{{impl_name}}/ is_complete true", + "example": "setMetadataValueForKey .../implementations/{impl_name}/ is_complete true", } # Update feature @@ -190,6 +176,7 @@ def _set_completion_status( # Broadcast feature update to WebSocket clients from app.services.feature_service import FeatureService + FeatureService._broadcast_feature_update(db, feature, "completion_status") return { @@ -214,13 +201,9 @@ def _set_in_progress( ) -> Dict[str, Any]: """Shortcut to set a feature to in_progress status.""" if value in (True, "true", "True", "1", 1): - return _set_completion_status( - db, project_id, resolved, user_id, "in_progress" - ) + return _set_completion_status(db, project_id, resolved, user_id, "in_progress") else: - return _set_completion_status( - db, project_id, resolved, user_id, "pending" - ) + return _set_completion_status(db, project_id, resolved, user_id, "pending") def _set_is_complete( @@ -232,21 +215,16 @@ def _set_is_complete( ) -> Dict[str, Any]: """Set is_complete on an implementation directory.""" from app.models.implementation import Implementation - from app.services.implementation_service import ImplementationService from app.services.feature_service import FeatureService + from app.services.implementation_service import ImplementationService if resolved.node_type != NodeType.IMPLEMENTATION_DIR: - raise InvalidPathError( - resolved.path, - "is_complete can only be set on implementation directories" - ) + raise InvalidPathError(resolved.path, "is_complete can only be set on implementation directories") if not resolved.implementation_id: raise InvalidPathError(resolved.path, "implementation_id not found") - impl = db.query(Implementation).filter( - Implementation.id == resolved.implementation_id - ).first() + impl = db.query(Implementation).filter(Implementation.id == resolved.implementation_id).first() if not impl: raise InvalidPathError(resolved.path, "implementation not found") @@ -283,9 +261,7 @@ def _set_is_complete( } -def _get_agent_metadata( - db: Session, project_id: UUID, resolved: ResolvedPath -) -> Dict[str, Any]: +def _get_agent_metadata(db: Session, project_id: UUID, resolved: ResolvedPath) -> Dict[str, Any]: """Get agent-set metadata from VFSMetadata table.""" from app.models.vfs_metadata import VFSMetadata @@ -322,9 +298,10 @@ def _store_metadata( value: Any, ) -> Dict[str, Any]: """Store metadata in VFSMetadata table.""" - from app.models.vfs_metadata import VFSMetadata import json + from app.models.vfs_metadata import VFSMetadata + # Determine which entity this metadata applies to phase_id = None module_id = None @@ -337,10 +314,7 @@ def _store_metadata( elif resolved.node_type in (NodeType.FEATURE_DIR, NodeType.USER_DEFINED_FEATURE_DIR): feature_id = resolved.feature_id else: - raise InvalidPathError( - resolved.path, - "metadata can only be set on phase, module, or feature directories" - ) + raise InvalidPathError(resolved.path, "metadata can only be set on phase, module, or feature directories") # Serialize value to JSON string value_str = json.dumps(value) if not isinstance(value, str) else value @@ -352,16 +326,20 @@ def _store_metadata( ] if phase_id: - filters.extend([ - VFSMetadata.brainstorming_phase_id == phase_id, - VFSMetadata.module_id.is_(None), - VFSMetadata.feature_id.is_(None), - ]) + filters.extend( + [ + VFSMetadata.brainstorming_phase_id == phase_id, + VFSMetadata.module_id.is_(None), + VFSMetadata.feature_id.is_(None), + ] + ) elif 
module_id: - filters.extend([ - VFSMetadata.module_id == module_id, - VFSMetadata.feature_id.is_(None), - ]) + filters.extend( + [ + VFSMetadata.module_id == module_id, + VFSMetadata.feature_id.is_(None), + ] + ) elif feature_id: filters.append(VFSMetadata.feature_id == feature_id) diff --git a/backend/app/mcp/vfs/path_resolver.py b/backend/app/mcp/vfs/path_resolver.py index ba333cd..f30fdf9 100644 --- a/backend/app/mcp/vfs/path_resolver.py +++ b/backend/app/mcp/vfs/path_resolver.py @@ -3,16 +3,16 @@ import re from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Tuple +from typing import List, Optional from uuid import UUID from sqlalchemy.orm import Session +from app.mcp.vfs.errors import InvalidPathError, PathNotFoundError from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleType -from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureProvenance +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType from app.models.implementation import Implementation -from app.mcp.vfs.errors import PathNotFoundError, InvalidPathError +from app.models.module import Module, ModuleType class NodeType(Enum): @@ -61,7 +61,9 @@ class NodeType(Enum): IMPLEMENTATIONS_DIR = "implementations_dir" # {feature}/implementations/ IMPLEMENTATION_DIR = "implementation_dir" # {feature}/implementations/{name}/ IMPLEMENTATION_SPEC_FILE = "implementation_spec_file" # {feature}/implementations/{name}/spec.md - IMPLEMENTATION_PROMPT_PLAN_FILE = "implementation_prompt_plan_file" # {feature}/implementations/{name}/prompt_plan.md + IMPLEMENTATION_PROMPT_PLAN_FILE = ( + "implementation_prompt_plan_file" # {feature}/implementations/{name}/prompt_plan.md + ) IMPLEMENTATION_NOTES_FILE = "implementation_notes_file" # {feature}/implementations/{name}/notes.md @@ -217,9 +219,7 @@ def _normalize_path(path: str) -> str: return path -def _find_phase_by_slug( - db: Session, project_id: UUID, slug: str -) -> Optional[BrainstormingPhase]: +def _find_phase_by_slug(db: Session, project_id: UUID, slug: str) -> Optional[BrainstormingPhase]: """Find a brainstorming phase by its slugified title.""" phases = ( db.query(BrainstormingPhase) @@ -235,9 +235,7 @@ def _find_phase_by_slug( return None -def _find_module_by_slug( - db: Session, phase_id: UUID, slug: str -) -> Optional[Module]: +def _find_module_by_slug(db: Session, phase_id: UUID, slug: str) -> Optional[Module]: """Find a module by its slugified directory name within a phase. 
Directory format: {module-key}-{slugified-title} @@ -259,9 +257,7 @@ def _find_module_by_slug( return None -def _find_feature_by_slug( - db: Session, module_id: UUID, slug: str -) -> Optional[Feature]: +def _find_feature_by_slug(db: Session, module_id: UUID, slug: str) -> Optional[Feature]: """Find a feature by its slugified directory name within a module.""" features = ( db.query(Feature) @@ -321,15 +317,9 @@ def _get_available_features(db: Session, module_id: UUID) -> List[str]: return [feature_dir_name(f.feature_key, f.title) for f in features] -def _find_implementation_by_slug( - db: Session, feature_id: UUID, slug: str -) -> Optional[Implementation]: +def _find_implementation_by_slug(db: Session, feature_id: UUID, slug: str) -> Optional[Implementation]: """Find an implementation by slugified name within a feature.""" - implementations = ( - db.query(Implementation) - .filter(Implementation.feature_id == feature_id) - .all() - ) + implementations = db.query(Implementation).filter(Implementation.feature_id == feature_id).all() for impl in implementations: if slugify(impl.name) == slug: return impl @@ -349,20 +339,14 @@ def _get_available_implementations(db: Session, feature_id: UUID) -> List[str]: def _is_user_defined_feature(feature: Feature) -> bool: """Check if a feature is user-defined (provenance=USER or has external_provider).""" - return ( - feature.provenance == FeatureProvenance.USER - or feature.external_provider is not None - ) + return feature.provenance == FeatureProvenance.USER or feature.external_provider is not None -def _find_user_defined_module_by_slug( - db: Session, project_id: UUID, slug: str -) -> Optional[Module]: +def _find_user_defined_module_by_slug(db: Session, project_id: UUID, slug: str) -> Optional[Module]: """Find a module by directory name that has user-defined features. 
     Directory format: {module-key}-{slugified-title}
     """
-    from sqlalchemy import exists, and_

     # Get all modules in the project that have user-defined features
     modules = (
@@ -385,10 +369,7 @@
                 Feature.feature_type == FeatureType.IMPLEMENTATION,
                 Feature.status == FeatureStatus.ACTIVE,
             )
-            .filter(
-                (Feature.provenance == FeatureProvenance.USER)
-                | (Feature.external_provider.isnot(None))
-            )
+            .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None)))
             .first()
         )
         if has_user_features:
@@ -396,9 +377,7 @@
     return None


-def _find_user_defined_feature_by_slug(
-    db: Session, module_id: UUID, slug: str
-) -> Optional[Feature]:
+def _find_user_defined_feature_by_slug(db: Session, module_id: UUID, slug: str) -> Optional[Feature]:
     """Find a user-defined feature by its slugified directory name within a module."""
     features = (
         db.query(Feature)
@@ -407,10 +386,7 @@
             Feature.feature_type == FeatureType.IMPLEMENTATION,
             Feature.status == FeatureStatus.ACTIVE,
         )
-        .filter(
-            (Feature.provenance == FeatureProvenance.USER)
-            | (Feature.external_provider.isnot(None))
-        )
+        .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None)))
         .all()
     )
     for feature in features:
@@ -442,10 +418,7 @@
                 Feature.feature_type == FeatureType.IMPLEMENTATION,
                 Feature.status == FeatureStatus.ACTIVE,
             )
-            .filter(
-                (Feature.provenance == FeatureProvenance.USER)
-                | (Feature.external_provider.isnot(None))
-            )
+            .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None)))
             .first()
         )
         if has_user_features:
@@ -462,10 +435,7 @@
             Feature.feature_type == FeatureType.IMPLEMENTATION,
             Feature.status == FeatureStatus.ACTIVE,
         )
-        .filter(
-            (Feature.provenance == FeatureProvenance.USER)
-            | (Feature.external_provider.isnot(None))
-        )
+        .filter((Feature.provenance == FeatureProvenance.USER) | (Feature.external_provider.isnot(None)))
        .all()
     )
     return [feature_dir_name(f.feature_key, f.title) for f in features]
@@ -516,9 +486,7 @@
     raise PathNotFoundError(path, ["phases", "project-info", "system-info", "for-coding-agents"])


-def _resolve_phases_path(
-    db: Session, project_id: UUID, full_path: str, parts: List[str]
-) -> ResolvedPath:
+def _resolve_phases_path(db: Session, project_id: UUID, full_path: str, parts: List[str]) -> ResolvedPath:
     """Resolve paths under /phases/."""
     # Just /phases/
     if len(parts) == 1:
@@ -541,9 +509,7 @@
     raise PathNotFoundError(full_path, ["system-generated", "user-defined"])


-def _resolve_system_generated_path(
-    db: Session, project_id: UUID, full_path: str, parts: List[str]
-) -> ResolvedPath:
+def _resolve_system_generated_path(db: Session, project_id: UUID, full_path: str, parts: List[str]) -> ResolvedPath:
     """Resolve paths under /phases/system-generated/."""
     # Just /phases/system-generated/
     if len(parts) == 2:
@@ -575,29 +541,21 @@
     # Handle phase-spec directory
     if fourth == "phase-spec":
-        return _resolve_document_path(
-            db, full_path, parts, phase, phase_slug, "spec"
-        )
+        return _resolve_document_path(db, full_path, parts, phase, phase_slug, "spec")

     # Handle phase-prompt-plan directory
     if fourth == "phase-prompt-plan":
-        return _resolve_document_path(
-            db, full_path, parts, phase, phase_slug, "prompt_plan"
-        )
+        return _resolve_document_path(db, full_path, parts, phase, phase_slug, "prompt_plan")

     # Handle features directory
     if fourth == "features":
         return _resolve_features_path(db, full_path, parts, phase, phase_slug)

     # Invalid fourth-level path
-    raise PathNotFoundError(
-        full_path, ["phase-spec", "phase-prompt-plan", "features"]
-    )
+    raise PathNotFoundError(full_path, ["phase-spec", "phase-prompt-plan", "features"])


-def _resolve_user_defined_path(
-    db: Session, project_id: UUID, full_path: str, parts: List[str]
-) -> ResolvedPath:
+def _resolve_user_defined_path(db: Session, project_id: UUID, full_path: str, parts: List[str]) -> ResolvedPath:
     """Resolve paths under /phases/user-defined/."""
     # Just /phases/user-defined/
     if len(parts) == 2:
@@ -781,9 +739,7 @@
     # Just the document directory: /phases/system-generated/{phase}/{doc-type}
     if len(parts) == 4:
         return ResolvedPath(
-            node_type=NodeType.PHASE_SPEC_DIR
-            if doc_type == "spec"
-            else NodeType.PHASE_PROMPT_PLAN_DIR,
+            node_type=NodeType.PHASE_SPEC_DIR if doc_type == "spec" else NodeType.PHASE_PROMPT_PLAN_DIR,
             is_directory=True,
             path=f"{base_path}/{dir_name}",
             phase_slug=phase_slug,
@@ -839,9 +795,7 @@
     if len(parts) == 6:
         section_file = parts[5]
         if not section_file.endswith(".md"):
-            raise InvalidPathError(
-                full_path, "section files must have .md extension"
-            )
+            raise InvalidPathError(full_path, "section files must have .md extension")
         section_id = section_file[:-3]  # Remove .md
         return ResolvedPath(
             node_type=NodeType.SECTION_FILE,
@@ -970,7 +924,7 @@ def _resolve_features_path(
 def _get_feature_thread(db: Session, feature_id: UUID):
     """Get the BRAINSTORM_FEATURE thread for a feature, if one exists."""
-    from app.models.thread import Thread, ContextType
+    from app.models.thread import ContextType, Thread

     return (
         db.query(Thread)
diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py
index 17f3d3c..6624db5 100644
--- a/backend/app/middleware/__init__.py
+++ b/backend/app/middleware/__init__.py
@@ -1,4 +1,5 @@
 """Middleware package for MFBT Backend."""
+
 from app.middleware.trial import TrialMiddleware

 __all__ = ["TrialMiddleware"]
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index 5e2b1b4..cb8ccdf 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -3,81 +3,89 @@
 This module exports all database models for easy importing.
""" -from app.models.user import User + +# Import validators and events to register SQLAlchemy event listeners +from app.models import ( + events, # noqa: F401 + validators, # noqa: F401 +) +from app.models.activity_log import ActivityLog from app.models.api_key import ApiKey +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype +from app.models.bug_sync_history import BugSyncHistory +from app.models.code_exploration_result import CodeExplorationResult +from app.models.daily_usage_summary import SENTINEL_UUID, DailyUsageSummary +from app.models.email_template import EmailTemplate, EmailTemplateKey +from app.models.feature import ( + Feature, + FeatureCompletionStatus, + FeatureProvenance, + FeatureStatus, + FeatureVisibilityStatus, +) +from app.models.feature_content_version import FeatureContentType, FeatureContentVersion +from app.models.feature_import_comment import FeatureImportComment +from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec +from app.models.form_draft import FormDraft, FormDraftType +from app.models.github_oauth_state import GitHubOAuthState +from app.models.grounding_file import GroundingFile +from app.models.grounding_file_branch import GroundingFileBranch +from app.models.grounding_note_version import GroundingNoteVersion from app.models.identity_provider import IdentityProvider, IdentityProviderType -from app.models.user_identity import UserIdentity -from app.models.job import Job, JobType, JobStatus -from app.models.organization import Organization -from app.models.org_membership import OrgMembership, OrgRole -from app.models.project import Project, ProjectType, ProjectStatus -from app.models.project_repository import ProjectRepository -from app.models.project_membership import ProjectMembership, ProjectRole -from app.models.provisioning import ProvisioningSource -from app.models.user_group import UserGroup -from app.models.user_group_membership import UserGroupMembership -from app.models.project_share import ProjectShare, ShareSubjectType -from app.models.org_invitation import OrgInvitation, InvitationStatus -from app.models.org_invitation_group import OrgInvitationGroup -from app.models.spec_version import SpecVersion, SpecType -from app.models.spec_coverage import SpecCoverageReport -from app.models.prompt_plan_coverage import PromptPlanCoverageReport -from app.models.thread import Thread, Comment, ContextType, ProjectChatVisibility -from app.models.thread_item import ThreadItem +from app.models.implementation import Implementation from app.models.integration_config import IntegrationConfig, IntegrationVisibility from app.models.integration_config_share import IntegrationConfigShare, IntegrationShareSubjectType -from app.models.bug_sync_history import BugSyncHistory -from app.models.notification_preference import NotificationPreference, NotificationChannel -from app.models.notification_project_mute import NotificationProjectMute -from app.models.notification_thread_watch import NotificationThreadWatch -from app.models.llm_preference import LLMPreference -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype -from app.models.phase_container import PhaseContainer -from app.models.module import Module, ModuleProvenance -from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureCompletionStatus, FeatureVisibilityStatus -from app.models.final_spec import FinalSpec -from app.models.final_prompt_plan import 
FinalPromptPlan -from app.models.activity_log import ActivityLog +from app.models.job import Job, JobStatus, JobType from app.models.llm_call_log import LLMCallLog +from app.models.llm_preference import LLMPreference from app.models.llm_usage_log import LLMUsageLog -from app.models.daily_usage_summary import DailyUsageSummary, SENTINEL_UUID from app.models.mcp_call_log import MCPCallLog -from app.models.vfs_metadata import VFSMetadata -from app.models.team_role import TeamRoleDefinition, ProjectTeamAssignment, DEFAULT_TEAM_ROLES -from app.models.grounding_file import GroundingFile -from app.models.grounding_file_branch import GroundingFileBranch -from app.models.feature_content_version import FeatureContentVersion, FeatureContentType -from app.models.feature_import_comment import FeatureImportComment -from app.models.implementation import Implementation +from app.models.mcp_image_submission import MCP_IMAGE_SUBMISSION_EXPIRY_HOURS, MCPImageSubmission +from app.models.mcp_oauth_client import MCPOAuthClient +from app.models.mcp_oauth_code import MCPOAuthAuthorizationCode +from app.models.mcp_oauth_token import MCPOAuthToken +from app.models.module import Module, ModuleProvenance +from app.models.notification_preference import NotificationChannel, NotificationPreference +from app.models.notification_project_mute import NotificationProjectMute +from app.models.notification_thread_watch import NotificationThreadWatch +from app.models.org_invitation import InvitationStatus, OrgInvitation +from app.models.org_invitation_group import OrgInvitationGroup +from app.models.org_membership import OrgMembership, OrgRole +from app.models.organization import Organization +from app.models.phase_container import PhaseContainer +from app.models.plan_recommendation import PlanRecommendation, RecommendationAction, RecommendationStatus from app.models.platform_connector import PlatformConnector, PlatformConnectorType from app.models.platform_settings import PlatformSettings -from app.models.user_question_session import ( - UserQuestionSession, - UserQuestionSessionStatus, - UserQuestionMessage, - MessageRole, -) -from app.models.email_template import EmailTemplate, EmailTemplateKey -from app.models.mcp_image_submission import MCPImageSubmission, MCP_IMAGE_SUBMISSION_EXPIRY_HOURS -from app.models.form_draft import FormDraft, FormDraftType +from app.models.project import Project, ProjectStatus, ProjectType from app.models.project_chat import ( ProjectChat, ProjectChatMessage, ProjectChatMessageType, ) -from app.models.mcp_oauth_client import MCPOAuthClient -from app.models.mcp_oauth_code import MCPOAuthAuthorizationCode -from app.models.mcp_oauth_token import MCPOAuthToken -from app.models.code_exploration_result import CodeExplorationResult -from app.models.github_oauth_state import GitHubOAuthState -from app.models.plan_recommendation import PlanRecommendation, RecommendationAction, RecommendationStatus -from app.models.grounding_note_version import GroundingNoteVersion +from app.models.project_membership import ProjectMembership, ProjectRole +from app.models.project_repository import ProjectRepository +from app.models.project_share import ProjectShare, ShareSubjectType +from app.models.prompt_plan_coverage import PromptPlanCoverageReport +from app.models.provisioning import ProvisioningSource from app.models.slack_channel_link import SlackChannelProjectLink, SlackChannelScope from app.models.slack_user_mapping import SlackUserMapping - -# Import validators and events to register SQLAlchemy event listeners -from 
app.models import validators # noqa: F401 -from app.models import events # noqa: F401 +from app.models.spec_coverage import SpecCoverageReport +from app.models.spec_version import SpecType, SpecVersion +from app.models.team_role import DEFAULT_TEAM_ROLES, ProjectTeamAssignment, TeamRoleDefinition +from app.models.thread import Comment, ContextType, ProjectChatVisibility, Thread +from app.models.thread_item import ThreadItem +from app.models.user import User +from app.models.user_group import UserGroup +from app.models.user_group_membership import UserGroupMembership +from app.models.user_identity import UserIdentity +from app.models.user_question_session import ( + MessageRole, + UserQuestionMessage, + UserQuestionSession, + UserQuestionSessionStatus, +) +from app.models.vfs_metadata import VFSMetadata __all__ = [ "User", diff --git a/backend/app/models/activity_log.py b/backend/app/models/activity_log.py index 608ff6b..d2a87fa 100644 --- a/backend/app/models/activity_log.py +++ b/backend/app/models/activity_log.py @@ -1,8 +1,12 @@ """ActivityLog model for tracking entity events.""" + import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, DateTime, ForeignKey, Index -from sqlalchemy.dialects.postgresql import UUID as PostgresUUID, JSON + +from sqlalchemy import Column, DateTime, Index, String +from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID + from app.database import Base @@ -16,6 +20,7 @@ class ActivityLog(Base): - SPEC_DRAFT_CREATED, SPEC_FINAL_GENERATED - PROMPT_PLAN_DRAFT_CREATED, PROMPT_PLAN_FINAL_GENERATED """ + __tablename__ = "activity_logs" id = Column(PostgresUUID(as_uuid=True), primary_key=True, default=uuid.uuid4) diff --git a/backend/app/models/api_key.py b/backend/app/models/api_key.py index e77b3ee..b95ade4 100644 --- a/backend/app/models/api_key.py +++ b/backend/app/models/api_key.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from uuid import UUID, uuid4 -from sqlalchemy import Boolean, DateTime, ForeignKey, String, Index +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base diff --git a/backend/app/models/brainstorming_phase.py b/backend/app/models/brainstorming_phase.py index 4c612d8..d3cf393 100644 --- a/backend/app/models/brainstorming_phase.py +++ b/backend/app/models/brainstorming_phase.py @@ -3,11 +3,13 @@ import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Boolean, Column, String, Text, ForeignKey, Enum, DateTime, Integer, JSON + +from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base -from app.utils.short_id import generate_short_id, build_url_identifier +from app.utils.short_id import build_url_identifier, generate_short_id class BrainstormingPhaseType(str, enum.Enum): @@ -53,9 +55,7 @@ class BrainstormingPhase(Base): # "size_bytes": N, "width": N, "height": N, "thumbnail_s3_key": "..."}] description_image_attachments = Column(JSON, nullable=True) created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - created_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: 
datetime.now(timezone.utc)) updated_at = Column( DateTime(timezone=True), nullable=False, @@ -130,5 +130,9 @@ def created_by_name(self) -> str | None: # Modules are not cascade deleted - they become orphaned with NULL brainstorming_phase_id modules = relationship("Module", back_populates="brainstorming_phase") spec_versions = relationship("SpecVersion", back_populates="brainstorming_phase", cascade="all, delete-orphan") - final_spec = relationship("FinalSpec", back_populates="brainstorming_phase", uselist=False, cascade="all, delete-orphan") - final_prompt_plan = relationship("FinalPromptPlan", back_populates="brainstorming_phase", uselist=False, cascade="all, delete-orphan") + final_spec = relationship( + "FinalSpec", back_populates="brainstorming_phase", uselist=False, cascade="all, delete-orphan" + ) + final_prompt_plan = relationship( + "FinalPromptPlan", back_populates="brainstorming_phase", uselist=False, cascade="all, delete-orphan" + ) diff --git a/backend/app/models/bug_sync_history.py b/backend/app/models/bug_sync_history.py index 7045a9d..53acd6b 100644 --- a/backend/app/models/bug_sync_history.py +++ b/backend/app/models/bug_sync_history.py @@ -1,4 +1,5 @@ """Bug sync history model.""" + import uuid from datetime import datetime, timezone @@ -14,9 +15,7 @@ class BugSyncHistory(Base): __tablename__ = "bug_sync_history" - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) project_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), @@ -31,9 +30,7 @@ class BugSyncHistory(Base): status: Mapped[str] = mapped_column(String(20), nullable=False) # success/error imported_data_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True) - triggered_by: Mapped[str] = mapped_column( - String(20), nullable=False - ) # system/user/agent + triggered_by: Mapped[str] = mapped_column(String(20), nullable=False) # system/user/agent # Relationships project: Mapped["Project"] = relationship("Project", back_populates="sync_history") diff --git a/backend/app/models/code_exploration_result.py b/backend/app/models/code_exploration_result.py index d32e193..4bec43a 100644 --- a/backend/app/models/code_exploration_result.py +++ b/backend/app/models/code_exploration_result.py @@ -24,9 +24,7 @@ class CodeExplorationResult(Base): __tablename__ = "code_exploration_results" - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) project_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("projects.id", ondelete="CASCADE"), @@ -72,14 +70,10 @@ class CodeExplorationResult(Base): completion_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True) # Execution time - execution_time_seconds: Mapped[Decimal | None] = mapped_column( - Numeric(10, 2), nullable=True - ) + execution_time_seconds: Mapped[Decimal | None] = mapped_column(Numeric(10, 2), nullable=True) # Timestamps - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) # Relationships project: Mapped["Project"] = relationship("Project", 
back_populates="code_explorations") diff --git a/backend/app/models/daily_usage_summary.py b/backend/app/models/daily_usage_summary.py index 2102e8a..f8c8952 100644 --- a/backend/app/models/daily_usage_summary.py +++ b/backend/app/models/daily_usage_summary.py @@ -5,18 +5,18 @@ combination, enabling efficient analytics queries without scanning the full llm_usage_logs table. """ -from datetime import datetime, date + +from datetime import date, datetime +from decimal import Decimal from typing import Optional from uuid import UUID, uuid4 -from decimal import Decimal -from sqlalchemy import String, Integer, BigInteger, Date, DateTime, ForeignKey, Numeric, Index +from sqlalchemy import BigInteger, Date, DateTime, ForeignKey, Integer, Numeric from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column from app.database import Base - # Sentinel UUID used in unique constraint to handle NULL user_id/project_id # PostgreSQL treats NULL as distinct, so we use this sentinel for COALESCE SENTINEL_UUID = UUID("00000000-0000-0000-0000-000000000000") diff --git a/backend/app/models/email_template.py b/backend/app/models/email_template.py index 6832135..c4cc322 100644 --- a/backend/app/models/email_template.py +++ b/backend/app/models/email_template.py @@ -1,4 +1,5 @@ """Email template model for platform-level email templates.""" + import uuid from datetime import datetime, timezone from enum import Enum @@ -34,9 +35,7 @@ class EmailTemplate(Base): __tablename__ = "email_templates" __table_args__ = (UniqueConstraint("key", name="uq_email_template_key"),) - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) # Unique key for code reference (e.g., "verification", "invitation") key: Mapped[str] = mapped_column(String(50), nullable=False, unique=True, index=True) @@ -68,9 +67,7 @@ class EmailTemplate(Base): is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) # Timestamps - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), diff --git a/backend/app/models/events/__init__.py b/backend/app/models/events/__init__.py index d265fc0..39b9aaf 100644 --- a/backend/app/models/events/__init__.py +++ b/backend/app/models/events/__init__.py @@ -3,6 +3,7 @@ This module registers SQLAlchemy event listeners for model lifecycle events. Import event handlers here to ensure they are registered when models are loaded. """ + from app.models.events import phase_container_events # noqa: F401 __all__ = ["phase_container_events"] diff --git a/backend/app/models/events/phase_container_events.py b/backend/app/models/events/phase_container_events.py index d5fa350..02ac031 100644 --- a/backend/app/models/events/phase_container_events.py +++ b/backend/app/models/events/phase_container_events.py @@ -3,12 +3,13 @@ This module contains SQLAlchemy event listeners that handle cascade operations when phase containers are archived. 
""" + import logging + from sqlalchemy import event, inspect, text from app.models.phase_container import PhaseContainer - logger = logging.getLogger(__name__) @@ -41,9 +42,7 @@ def cascade_archive_on_container_archive(mapper, connection, target): if old_val is not None or target.archived_at is None: return - logger.info( - f"Cascading archive from container {target.id} to all contained phases" - ) + logger.info(f"Cascading archive from container {target.id} to all contained phases") # Use raw SQL to update phases - this avoids event loops # and ensures all phases get the exact same archived_at timestamp @@ -53,5 +52,5 @@ def cascade_archive_on_container_archive(mapper, connection, target): SET archived_at = :archived_at WHERE container_id = :container_id AND archived_at IS NULL """), - {"container_id": str(target.id), "archived_at": target.archived_at} + {"container_id": str(target.id), "archived_at": target.archived_at}, ) diff --git a/backend/app/models/feature.py b/backend/app/models/feature.py index 75a5b3c..b74bcdb 100644 --- a/backend/app/models/feature.py +++ b/backend/app/models/feature.py @@ -3,11 +3,13 @@ import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, ForeignKey, Enum, DateTime, JSON, Integer + +from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base -from app.utils.short_id import generate_short_id, build_url_identifier +from app.utils.short_id import build_url_identifier, generate_short_id class FeatureProvenance(str, enum.Enum): @@ -116,9 +118,7 @@ class Feature(Base): default=FeatureVisibilityStatus.ACTIVE, ) created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - created_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) updated_at = Column( DateTime(timezone=True), nullable=False, diff --git a/backend/app/models/feature_content_version.py b/backend/app/models/feature_content_version.py index 5d1e1b4..7db9c70 100644 --- a/backend/app/models/feature_content_version.py +++ b/backend/app/models/feature_content_version.py @@ -3,9 +3,11 @@ import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, Integer, Boolean, ForeignKey, Enum, DateTime, UniqueConstraint, Index, JSON + +from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Index, Integer, String, Text, UniqueConstraint from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base @@ -54,9 +56,7 @@ class FeatureContentVersion(Base): ForeignKey("users.id", ondelete="SET NULL"), nullable=True, ) - created_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) # Relationships feature = relationship("Feature", back_populates="content_versions") diff --git a/backend/app/models/feature_import_comment.py b/backend/app/models/feature_import_comment.py index 20ac209..eeefc99 100644 --- a/backend/app/models/feature_import_comment.py +++ b/backend/app/models/feature_import_comment.py @@ -1,10 +1,11 @@ """FeatureImportComment 
model for storing original comments from imported issues.""" import uuid -from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, ForeignKey, DateTime, Integer, Index + +from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base @@ -41,6 +42,4 @@ class FeatureImportComment(Base): # Relationships feature = relationship("Feature", back_populates="import_comments") - __table_args__ = ( - Index("ix_feature_import_comments_feature_order", "feature_id", "order_index"), - ) + __table_args__ = (Index("ix_feature_import_comments_feature_order", "feature_id", "order_index"),) diff --git a/backend/app/models/final_prompt_plan.py b/backend/app/models/final_prompt_plan.py index 09c871d..911ed53 100644 --- a/backend/app/models/final_prompt_plan.py +++ b/backend/app/models/final_prompt_plan.py @@ -2,9 +2,11 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, ForeignKey, DateTime, JSON + +from sqlalchemy import JSON, Column, DateTime, ForeignKey, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base @@ -41,9 +43,7 @@ class FinalPromptPlan(Base): ForeignKey("spec_versions.id", ondelete="SET NULL"), nullable=True, ) - generated_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + generated_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False) # Relationships diff --git a/backend/app/models/final_spec.py b/backend/app/models/final_spec.py index 3a9922f..b4db61c 100644 --- a/backend/app/models/final_spec.py +++ b/backend/app/models/final_spec.py @@ -2,9 +2,11 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, ForeignKey, DateTime, JSON + +from sqlalchemy import JSON, Column, DateTime, ForeignKey, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base @@ -41,9 +43,7 @@ class FinalSpec(Base): ForeignKey("spec_versions.id", ondelete="SET NULL"), nullable=True, ) - generated_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + generated_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False) # Relationships diff --git a/backend/app/models/form_draft.py b/backend/app/models/form_draft.py index 58ea426..d2e09f5 100644 --- a/backend/app/models/form_draft.py +++ b/backend/app/models/form_draft.py @@ -4,7 +4,7 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, ForeignKey, DateTime, Enum, Index, JSON +from sqlalchemy import JSON, Column, DateTime, Enum, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship diff --git a/backend/app/models/github_oauth_state.py b/backend/app/models/github_oauth_state.py index c74cccc..6296dc5 100644 --- a/backend/app/models/github_oauth_state.py +++ b/backend/app/models/github_oauth_state.py @@ -29,17 +29,11 @@ class GitHubOAuthState(Base): nullable=False, 
         index=True,
     )
-    state_token: Mapped[str] = mapped_column(
-        String(64), unique=True, nullable=False, index=True
-    )
+    state_token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
     display_name: Mapped[str] = mapped_column(String(100), nullable=False)
     visibility: Mapped[str] = mapped_column(String(20), nullable=False, default="org")
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), nullable=False
-    )
-    expires_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), nullable=False
-    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
+    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

     def __repr__(self) -> str:
         return f""
diff --git a/backend/app/models/grounding_file.py b/backend/app/models/grounding_file.py
index ceb4248..06ce407 100644
--- a/backend/app/models/grounding_file.py
+++ b/backend/app/models/grounding_file.py
@@ -2,9 +2,11 @@

 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Column, String, Text, ForeignKey, DateTime, Boolean, Index
+
+from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import relationship
+
 from app.database import Base
diff --git a/backend/app/models/grounding_file_branch.py b/backend/app/models/grounding_file_branch.py
index 6ef98ed..d389afe 100644
--- a/backend/app/models/grounding_file_branch.py
+++ b/backend/app/models/grounding_file_branch.py
@@ -2,9 +2,11 @@

 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Column, String, Text, ForeignKey, DateTime, Boolean, Index
+
+from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, String, Text
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import relationship
+
 from app.database import Base
@@ -89,7 +91,4 @@ class GroundingFileBranch(Base):
     )

     def __repr__(self) -> str:
-        return (
-            f""
-        )
+        return f""
diff --git a/backend/app/models/grounding_note_version.py b/backend/app/models/grounding_note_version.py
index 1662780..19587ea 100644
--- a/backend/app/models/grounding_note_version.py
+++ b/backend/app/models/grounding_note_version.py
@@ -2,9 +2,11 @@

 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Column, String, Text, Integer, Boolean, ForeignKey, DateTime, UniqueConstraint, Index
+
+from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import relationship
+
 from app.database import Base
@@ -37,9 +39,7 @@ class GroundingNoteVersion(Base):
         ForeignKey("users.id", ondelete="SET NULL"),
         nullable=True,
     )
-    created_at = Column(
-        DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
-    )
+    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))

     # Relationships
     project = relationship("Project", back_populates="grounding_note_versions")
diff --git a/backend/app/models/identity_provider.py b/backend/app/models/identity_provider.py
index e43e14f..dc211a6 100644
--- a/backend/app/models/identity_provider.py
+++ b/backend/app/models/identity_provider.py
@@ -4,12 +4,14 @@

 Represents authentication providers (Google, GitHub, future external IdPs).
 Designed to support Phase 2 external IdP integration without data migration.
 """
+
 from datetime import datetime
-from typing import Optional, TYPE_CHECKING
-from uuid import UUID, uuid4
 from enum import Enum
+from typing import TYPE_CHECKING, Optional
+from uuid import UUID, uuid4

-from sqlalchemy import String, DateTime, func, Text, Enum as SQLEnum
+from sqlalchemy import DateTime, String, func
+from sqlalchemy import Enum as SQLEnum
 from sqlalchemy.dialects.postgresql import JSON
 from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -29,6 +31,7 @@ class IdentityProviderType(str, Enum):
     - OAUTH_GITHUB: GitHub OAuth2
     - OIDC_GENERIC: Future external IdPs (Auth0, WorkOS, Keycloak, etc.)
     """
+
     LOCAL = "local"
     OAUTH_GOOGLE = "oauth_google"
     OAUTH_GITHUB = "oauth_github"
diff --git a/backend/app/models/implementation.py b/backend/app/models/implementation.py
index 29124b4..8e0a41c 100644
--- a/backend/app/models/implementation.py
+++ b/backend/app/models/implementation.py
@@ -2,9 +2,11 @@

 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Column, String, Text, Boolean, ForeignKey, DateTime, Integer
+
+from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import relationship
+
 from app.database import Base
@@ -47,9 +49,7 @@
     # Completion tracking
     is_complete = Column(Boolean, nullable=False, default=False)
     completed_at = Column(DateTime(timezone=True), nullable=True)
-    completed_by_id = Column(
-        PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=True
-    )
+    completed_by_id = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=True)
     completion_summary = Column(Text, nullable=True)  # Summary when marked complete

     # Primary flag - first/default implementation for the feature
@@ -63,12 +63,8 @@
     is_generating_prompt_plan = Column(Boolean, nullable=False, default=False)

     # Audit
-    created_by = Column(
-        PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False
-    )
-    created_at = Column(
-        DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
-    )
+    created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False)
+    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
     updated_at = Column(
         DateTime(timezone=True),
         nullable=False,
diff --git a/backend/app/models/inbox_follow.py b/backend/app/models/inbox_follow.py
index 746e368..8e39985 100644
--- a/backend/app/models/inbox_follow.py
+++ b/backend/app/models/inbox_follow.py
@@ -2,11 +2,12 @@
 Hierarchical follow system for projects and threads.
""" + import enum -from datetime import datetime, timezone +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import String, DateTime, Enum, ForeignKey, func, Index +from sqlalchemy import DateTime, Enum, ForeignKey, Index, String, func from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -15,12 +16,14 @@ class InboxFollowType(str, enum.Enum): """Type of follow relationship.""" + PROJECT = "project" THREAD = "thread" class InboxThreadType(str, enum.Enum): """Type of thread being followed.""" + FEATURE = "feature" PHASE = "phase" PROJECT_CHAT = "project_chat" diff --git a/backend/app/models/inbox_mention.py b/backend/app/models/inbox_mention.py index 2e3c21e..46f189d 100644 --- a/backend/app/models/inbox_mention.py +++ b/backend/app/models/inbox_mention.py @@ -2,11 +2,12 @@ Tracks mentions of users in project chats, features, and phase threads. """ + import enum -from datetime import datetime, timezone +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import String, Boolean, DateTime, Enum, ForeignKey, func, Index, Integer +from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Index, Integer, String, func from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -15,6 +16,7 @@ class InboxConversationType(str, enum.Enum): """Type of conversation where mention occurred.""" + PROJECT_CHAT = "project_chat" FEATURE = "feature" PHASE = "phase" diff --git a/backend/app/models/integration_config.py b/backend/app/models/integration_config.py index e513cc5..2be4359 100644 --- a/backend/app/models/integration_config.py +++ b/backend/app/models/integration_config.py @@ -1,4 +1,5 @@ """Integration configuration model.""" + import uuid from datetime import datetime, timezone from enum import Enum @@ -26,15 +27,9 @@ class IntegrationConfig(Base): """External integration configuration (per organization).""" __tablename__ = "integration_configs" - __table_args__ = ( - UniqueConstraint( - "organization_id", "provider", "display_name", name="uq_org_provider_name" - ), - ) + __table_args__ = (UniqueConstraint("organization_id", "provider", "display_name", name="uq_org_provider_name"),) - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) organization_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), @@ -60,9 +55,7 @@ class IntegrationConfig(Base): ForeignKey("users.id", ondelete="SET NULL"), nullable=True, ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), @@ -70,7 +63,5 @@ class IntegrationConfig(Base): ) # Relationships - organization: Mapped["Organization"] = relationship( - "Organization", back_populates="integration_configs" - ) + organization: Mapped["Organization"] = relationship("Organization", back_populates="integration_configs") created_by: Mapped[Optional["User"]] = relationship("User", lazy="joined") diff --git a/backend/app/models/integration_config_share.py 
b/backend/app/models/integration_config_share.py index 6f4f292..0c6b49c 100644 --- a/backend/app/models/integration_config_share.py +++ b/backend/app/models/integration_config_share.py @@ -47,14 +47,10 @@ class IntegrationConfigShare(Base): ForeignKey("users.id", ondelete="SET NULL"), nullable=True, ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) # Relationships - integration_config: Mapped["IntegrationConfig"] = relationship( - "IntegrationConfig", lazy="joined" - ) + integration_config: Mapped["IntegrationConfig"] = relationship("IntegrationConfig", lazy="joined") created_by: Mapped[Optional["User"]] = relationship("User", lazy="joined") __table_args__ = ( @@ -67,4 +63,6 @@ class IntegrationConfigShare(Base): ) def __repr__(self) -> str: - return f"" + return ( + f"" + ) diff --git a/backend/app/models/job.py b/backend/app/models/job.py index 624d095..95146ea 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -3,19 +3,21 @@ Jobs represent asynchronous tasks processed by workers via Kafka. """ -from datetime import datetime, UTC -from typing import Optional, TYPE_CHECKING, List -from uuid import UUID, uuid4 + +from datetime import datetime from enum import Enum +from typing import TYPE_CHECKING, List, Optional +from uuid import UUID, uuid4 -from sqlalchemy import String, Integer, DateTime, Enum as SQLEnum, func, Text, JSON, Numeric +from sqlalchemy import JSON, DateTime, Integer, Numeric, String, Text, func +from sqlalchemy import Enum as SQLEnum from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base if TYPE_CHECKING: - from app.models.llm_call_log import LLMCallLog from app.models.code_exploration_result import CodeExplorationResult + from app.models.llm_call_log import LLMCallLog class JobType(str, Enum): diff --git a/backend/app/models/llm_call_log.py b/backend/app/models/llm_call_log.py index 7dffe9d..4f96b9d 100644 --- a/backend/app/models/llm_call_log.py +++ b/backend/app/models/llm_call_log.py @@ -4,12 +4,13 @@ This model stores detailed information about each LLM call made during agent execution, including request/response data for debugging and analysis. 
""" + from datetime import datetime -from typing import Optional, List -from uuid import UUID, uuid4 from decimal import Decimal +from typing import List, Optional +from uuid import UUID, uuid4 -from sqlalchemy import String, Integer, DateTime, Text, ForeignKey, Numeric, Index, JSON +from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, Numeric, String, Text from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -148,13 +149,13 @@ class LLMCallLog(Base): job: Mapped["Job"] = relationship("Job", back_populates="llm_call_logs") # Composite index for efficient job + time queries - __table_args__ = ( - Index("ix_llm_call_logs_job_created", "job_id", "created_at"), - ) + __table_args__ = (Index("ix_llm_call_logs_job_created", "job_id", "created_at"),) def __repr__(self) -> str: """String representation of LLMCallLog.""" - return f"" + return ( + f"" + ) @property def total_tokens(self) -> int: diff --git a/backend/app/models/llm_preference.py b/backend/app/models/llm_preference.py index 978ff31..44dea3c 100644 --- a/backend/app/models/llm_preference.py +++ b/backend/app/models/llm_preference.py @@ -1,4 +1,5 @@ """LLM Preference model for organization-level LLM selections.""" + import uuid from datetime import datetime, timezone @@ -13,13 +14,9 @@ class LLMPreference(Base): """Organization-level LLM preferences for main and lightweight LLMs.""" __tablename__ = "llm_preferences" - __table_args__ = ( - UniqueConstraint("organization_id", name="uq_org_llm_preference"), - ) + __table_args__ = (UniqueConstraint("organization_id", name="uq_org_llm_preference"),) - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) organization_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), @@ -38,19 +35,11 @@ class LLMPreference(Base): ) # Testing & Debugging settings - mock_discovery_enabled: Mapped[bool] = mapped_column( - Boolean, default=False, nullable=False - ) - mock_discovery_question_limit: Mapped[int] = mapped_column( - Integer, default=10, nullable=False - ) - mock_discovery_delay_seconds: Mapped[int] = mapped_column( - Integer, default=5, nullable=False - ) + mock_discovery_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + mock_discovery_question_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False) + mock_discovery_delay_seconds: Mapped[int] = mapped_column(Integer, default=5, nullable=False) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), @@ -58,9 +47,7 @@ class LLMPreference(Base): ) # Relationships - organization: Mapped["Organization"] = relationship( - "Organization", back_populates="llm_preference" - ) + organization: Mapped["Organization"] = relationship("Organization", back_populates="llm_preference") main_llm_config: Mapped["IntegrationConfig"] = relationship( "IntegrationConfig", foreign_keys=[main_llm_config_id], diff --git a/backend/app/models/llm_usage_log.py b/backend/app/models/llm_usage_log.py index 310f480..1fd4866 100644 --- 
a/backend/app/models/llm_usage_log.py +++ b/backend/app/models/llm_usage_log.py @@ -4,15 +4,16 @@ This model stores essential usage metrics for all LLM calls (including standalone calls outside of job context), enabling per-agent analytics and cost tracking. """ + from datetime import datetime -from typing import Optional, TYPE_CHECKING -from uuid import UUID, uuid4 from decimal import Decimal +from typing import TYPE_CHECKING, Optional +from uuid import UUID, uuid4 if TYPE_CHECKING: from app.models.user import User -from sqlalchemy import String, Integer, DateTime, ForeignKey, Numeric, Index +from sqlalchemy import DateTime, ForeignKey, Index, Integer, Numeric, String from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column, relationship diff --git a/backend/app/models/mcp_call_log.py b/backend/app/models/mcp_call_log.py index 74f2aa7..851443f 100644 --- a/backend/app/models/mcp_call_log.py +++ b/backend/app/models/mcp_call_log.py @@ -5,11 +5,12 @@ via the HTTP transport, including request/response data for debugging and auditing. """ + from datetime import datetime from typing import Optional from uuid import UUID, uuid4 -from sqlalchemy import String, Integer, DateTime, ForeignKey, Index, JSON, Boolean +from sqlalchemy import JSON, Boolean, DateTime, ForeignKey, Index, Integer, String from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column, relationship diff --git a/backend/app/models/mcp_image_submission.py b/backend/app/models/mcp_image_submission.py index 2346b5a..2fcd15d 100644 --- a/backend/app/models/mcp_image_submission.py +++ b/backend/app/models/mcp_image_submission.py @@ -4,16 +4,16 @@ This model stores image data temporarily until a comment references it, at which point the image is uploaded to S3 and attached to the comment. 
""" -from datetime import datetime, timezone, timedelta + +from datetime import datetime, timedelta, timezone from uuid import UUID, uuid4 -from sqlalchemy import String, DateTime, ForeignKey, LargeBinary, Index +from sqlalchemy import DateTime, ForeignKey, Index, LargeBinary, String from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base - # Default expiry time for staged images (1 hour) MCP_IMAGE_SUBMISSION_EXPIRY_HOURS = 1 diff --git a/backend/app/models/mcp_oauth_client.py b/backend/app/models/mcp_oauth_client.py index 970c9d7..aa79d80 100644 --- a/backend/app/models/mcp_oauth_client.py +++ b/backend/app/models/mcp_oauth_client.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from uuid import UUID, uuid4 -from sqlalchemy import Boolean, DateTime, String, Index +from sqlalchemy import Boolean, DateTime, Index, String from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.orm import Mapped, mapped_column @@ -37,14 +37,10 @@ class MCPOAuthClient(Base): ) # Response types supported - response_types: Mapped[list] = mapped_column( - JSON, nullable=False, default=lambda: ["code"] - ) + response_types: Mapped[list] = mapped_column(JSON, nullable=False, default=lambda: ["code"]) # Token endpoint auth method (always "none" for public clients) - token_endpoint_auth_method: Mapped[str] = mapped_column( - String(50), nullable=False, default="none" - ) + token_endpoint_auth_method: Mapped[str] = mapped_column(String(50), nullable=False, default="none") # Registration metadata created_at: Mapped[datetime] = mapped_column( @@ -53,9 +49,7 @@ class MCPOAuthClient(Base): last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) revoked: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - __table_args__ = ( - Index("idx_mcp_oauth_clients_client_id", "client_id"), - ) + __table_args__ = (Index("idx_mcp_oauth_clients_client_id", "client_id"),) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/mcp_oauth_code.py b/backend/app/models/mcp_oauth_code.py index 53b2950..2e28481 100644 --- a/backend/app/models/mcp_oauth_code.py +++ b/backend/app/models/mcp_oauth_code.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from uuid import UUID, uuid4 -from sqlalchemy import Boolean, DateTime, ForeignKey, String, Index +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base @@ -28,29 +28,21 @@ class MCPOAuthAuthorizationCode(Base): code_hash: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) # Client that requested the code - client_id: Mapped[UUID] = mapped_column( - ForeignKey("mcp_oauth_clients.id", ondelete="CASCADE"), nullable=False - ) + client_id: Mapped[UUID] = mapped_column(ForeignKey("mcp_oauth_clients.id", ondelete="CASCADE"), nullable=False) # User who authorized the code - user_id: Mapped[UUID] = mapped_column( - ForeignKey("users.id", ondelete="CASCADE"), nullable=False - ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False) # Resource (audience) for token binding - the MCP endpoint URL # Format: https://app.mfbt.ai/api/v1/projects/{project_id}/mcp resource: Mapped[str] = mapped_column(String(512), nullable=False) # Project extracted from resource URL for efficient lookup - project_id: Mapped[UUID] = mapped_column( - ForeignKey("projects.id", 
ondelete="CASCADE"), nullable=False - ) + project_id: Mapped[UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) # PKCE parameters (required per MCP spec, only S256 supported) code_challenge: Mapped[str] = mapped_column(String(128), nullable=False) - code_challenge_method: Mapped[str] = mapped_column( - String(10), nullable=False, default="S256" - ) + code_challenge_method: Mapped[str] = mapped_column(String(10), nullable=False, default="S256") # OAuth parameters redirect_uri: Mapped[str] = mapped_column(String(512), nullable=False) diff --git a/backend/app/models/mcp_oauth_token.py b/backend/app/models/mcp_oauth_token.py index 33a1234..77ffb4b 100644 --- a/backend/app/models/mcp_oauth_token.py +++ b/backend/app/models/mcp_oauth_token.py @@ -3,7 +3,7 @@ from datetime import datetime, timezone from uuid import UUID, uuid4 -from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, Index +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String, Text from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base @@ -31,23 +31,17 @@ class MCPOAuthToken(Base): id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) # User who authorized the token - user_id: Mapped[UUID] = mapped_column( - ForeignKey("users.id", ondelete="CASCADE"), nullable=False - ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False) # Client that requested the token - client_id: Mapped[UUID] = mapped_column( - ForeignKey("mcp_oauth_clients.id", ondelete="CASCADE"), nullable=False - ) + client_id: Mapped[UUID] = mapped_column(ForeignKey("mcp_oauth_clients.id", ondelete="CASCADE"), nullable=False) # Resource (audience) this token is valid for # Format: https://app.mfbt.ai/api/v1/projects/{project_id}/mcp resource: Mapped[str] = mapped_column(String(512), nullable=False) # Project ID extracted from resource for efficient lookup - project_id: Mapped[UUID] = mapped_column( - ForeignKey("projects.id", ondelete="CASCADE"), nullable=False - ) + project_id: Mapped[UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) # Access token JTI (JWT ID) for validation and revocation checks access_token_jti: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) @@ -63,12 +57,8 @@ class MCPOAuthToken(Base): issued_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False ) - access_token_expires_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) - refresh_token_expires_at: Mapped[datetime | None] = mapped_column( - DateTime(timezone=True), nullable=True - ) + access_token_expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + refresh_token_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) # Status diff --git a/backend/app/models/module.py b/backend/app/models/module.py index 59af7e2..b5afd83 100644 --- a/backend/app/models/module.py +++ b/backend/app/models/module.py @@ -3,11 +3,13 @@ import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, Integer, ForeignKey, Enum, DateTime + +from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import 
relationship + from app.database import Base -from app.utils.short_id import generate_short_id, build_url_identifier +from app.utils.short_id import build_url_identifier, generate_short_id class ModuleProvenance(str, enum.Enum): @@ -69,9 +71,7 @@ class Module(Base): module_key = Column(String(20), nullable=False, index=True) # e.g., "MPROJ-001" module_key_number = Column(Integer, nullable=False, index=True) # Numeric part for sorting created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=False) - created_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) archived_at = Column(DateTime(timezone=True), nullable=True) # Short ID for URL-friendly identifiers (YouTube-style) diff --git a/backend/app/models/notification_preference.py b/backend/app/models/notification_preference.py index c1d973c..1f52cea 100644 --- a/backend/app/models/notification_preference.py +++ b/backend/app/models/notification_preference.py @@ -3,19 +3,22 @@ Stores user-level notification channel preferences. """ -from datetime import datetime, UTC + +from datetime import datetime from enum import Enum from typing import Optional from uuid import UUID, uuid4 -from sqlalchemy import String, DateTime, Enum as SAEnum, ForeignKey, func, Boolean -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy import Boolean, DateTime, ForeignKey, String, func +from sqlalchemy import Enum as SAEnum +from sqlalchemy.orm import Mapped, mapped_column from app.database import Base class NotificationChannel(str, Enum): """Notification channel types.""" + EMAIL = "email" SLACK = "slack" TEAMS = "teams" diff --git a/backend/app/models/notification_project_mute.py b/backend/app/models/notification_project_mute.py index 32e6de2..ee9166d 100644 --- a/backend/app/models/notification_project_mute.py +++ b/backend/app/models/notification_project_mute.py @@ -3,10 +3,11 @@ Allows users to mute notifications for specific projects. """ -from datetime import datetime, UTC + +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import DateTime, ForeignKey, func, UniqueConstraint +from sqlalchemy import DateTime, ForeignKey, UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column from app.database import Base @@ -21,9 +22,7 @@ class NotificationProjectMute(Base): """ __tablename__ = "notification_project_mutes" - __table_args__ = ( - UniqueConstraint("user_id", "project_id", name="uq_user_project_mute"), - ) + __table_args__ = (UniqueConstraint("user_id", "project_id", name="uq_user_project_mute"),) id: Mapped[UUID] = mapped_column( primary_key=True, diff --git a/backend/app/models/notification_thread_watch.py b/backend/app/models/notification_thread_watch.py index 06bde8d..0c481b8 100644 --- a/backend/app/models/notification_thread_watch.py +++ b/backend/app/models/notification_thread_watch.py @@ -3,10 +3,11 @@ Tracks which users are watching specific threads for notifications. 
""" -from datetime import datetime, UTC + +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import DateTime, ForeignKey, String, func, UniqueConstraint +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column from app.database import Base @@ -21,9 +22,7 @@ class NotificationThreadWatch(Base): """ __tablename__ = "notification_thread_watches" - __table_args__ = ( - UniqueConstraint("user_id", "thread_id", name="uq_user_thread_watch"), - ) + __table_args__ = (UniqueConstraint("user_id", "thread_id", name="uq_user_thread_watch"),) id: Mapped[UUID] = mapped_column( primary_key=True, diff --git a/backend/app/models/org_invitation.py b/backend/app/models/org_invitation.py index 77ce94a..e5b5a14 100644 --- a/backend/app/models/org_invitation.py +++ b/backend/app/models/org_invitation.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional from uuid import UUID, uuid4 -from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint, func +from sqlalchemy import DateTime, ForeignKey, String, func from sqlalchemy import Enum as SQLEnum from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -15,9 +15,9 @@ from app.models.provisioning import ProvisioningSource if TYPE_CHECKING: + from app.models.org_invitation_group import OrgInvitationGroup from app.models.organization import Organization from app.models.user import User - from app.models.org_invitation_group import OrgInvitationGroup class InvitationStatus(str, Enum): @@ -54,12 +54,8 @@ class OrgInvitation(Base): SQLEnum(OrgRole, native_enum=False, length=20), nullable=False, ) - token: Mapped[str] = mapped_column( - String(64), unique=True, nullable=False, index=True - ) - expires_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), nullable=False - ) + token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) + expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) status: Mapped[InvitationStatus] = mapped_column( SQLEnum( InvitationStatus, @@ -74,9 +70,7 @@ class OrgInvitation(Base): ForeignKey("users.id", ondelete="SET NULL"), nullable=True, ) - accepted_at: Mapped[Optional[datetime]] = mapped_column( - DateTime(timezone=True), nullable=True - ) + accepted_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) provisioning_source: Mapped[ProvisioningSource] = mapped_column( SQLEnum( ProvisioningSource, @@ -87,12 +81,8 @@ class OrgInvitation(Base): default=ProvisioningSource.MANUAL, nullable=False, ) - metadata_json: Mapped[Optional[dict]] = mapped_column( - JSON, nullable=True, default=None - ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + metadata_json: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -102,12 +92,8 @@ class OrgInvitation(Base): # Relationships organization: Mapped["Organization"] = relationship("Organization", lazy="joined") - invited_by: Mapped[Optional["User"]] = relationship( - "User", foreign_keys=[invited_by_user_id], lazy="joined" - ) - accepted_by: Mapped[Optional["User"]] = relationship( - "User", foreign_keys=[accepted_by_user_id] - ) + 
invited_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[invited_by_user_id], lazy="joined") + accepted_by: Mapped[Optional["User"]] = relationship("User", foreign_keys=[accepted_by_user_id]) group_assignments: Mapped[list["OrgInvitationGroup"]] = relationship( "OrgInvitationGroup", back_populates="invitation", diff --git a/backend/app/models/org_invitation_group.py b/backend/app/models/org_invitation_group.py index ae31c17..0571eff 100644 --- a/backend/app/models/org_invitation_group.py +++ b/backend/app/models/org_invitation_group.py @@ -35,14 +35,10 @@ class OrgInvitationGroup(Base): ) # Relationships - invitation: Mapped["OrgInvitation"] = relationship( - "OrgInvitation", back_populates="group_assignments" - ) + invitation: Mapped["OrgInvitation"] = relationship("OrgInvitation", back_populates="group_assignments") group: Mapped["UserGroup"] = relationship("UserGroup", lazy="joined") - __table_args__ = ( - UniqueConstraint("invitation_id", "group_id", name="uq_invitation_group"), - ) + __table_args__ = (UniqueConstraint("invitation_id", "group_id", name="uq_invitation_group"),) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/org_membership.py b/backend/app/models/org_membership.py index 495d23a..9205de0 100644 --- a/backend/app/models/org_membership.py +++ b/backend/app/models/org_membership.py @@ -38,12 +38,8 @@ class OrgMembership(Base): nullable=False, index=True, ) - user_id: Mapped[UUID] = mapped_column( - ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True - ) - role: Mapped[OrgRole] = mapped_column( - SQLEnum(OrgRole, native_enum=False, length=20), nullable=False - ) + user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) + role: Mapped[OrgRole] = mapped_column(SQLEnum(OrgRole, native_enum=False, length=20), nullable=False) provisioning_source: Mapped[ProvisioningSource] = mapped_column( SQLEnum( ProvisioningSource, @@ -54,12 +50,8 @@ class OrgMembership(Base): default=ProvisioningSource.MANUAL, nullable=False, ) - metadata_json: Mapped[Optional[dict]] = mapped_column( - JSON, nullable=True, default=None - ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + metadata_json: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, default=None) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -71,9 +63,7 @@ class OrgMembership(Base): org: Mapped["Organization"] = relationship("Organization", lazy="joined") user: Mapped["User"] = relationship("User", lazy="joined") - __table_args__ = ( - UniqueConstraint("org_id", "user_id", name="uq_org_user"), - ) + __table_args__ = (UniqueConstraint("org_id", "user_id", name="uq_org_user"),) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/organization.py b/backend/app/models/organization.py index 990041c..96319a7 100644 --- a/backend/app/models/organization.py +++ b/backend/app/models/organization.py @@ -24,9 +24,7 @@ class Organization(Base): index=True, doc="External SSO provider organization ID (e.g., from Scalekit, Auth0, WorkOS)", ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) 
     updated_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
@@ -93,7 +91,9 @@ class Organization(Base):
     # Relationships
     projects = relationship("Project", back_populates="organization", cascade="all, delete-orphan")
     integration_configs = relationship("IntegrationConfig", back_populates="organization", cascade="all, delete-orphan")
-    llm_preference = relationship("LLMPreference", back_populates="organization", uselist=False, cascade="all, delete-orphan")
+    llm_preference = relationship(
+        "LLMPreference", back_populates="organization", uselist=False, cascade="all, delete-orphan"
+    )
     project_chats = relationship("ProjectChat", back_populates="organization", cascade="all, delete-orphan")
     def __repr__(self) -> str:
diff --git a/backend/app/models/phase_container.py b/backend/app/models/phase_container.py
index b1053a8..43f92c6 100644
--- a/backend/app/models/phase_container.py
+++ b/backend/app/models/phase_container.py
@@ -2,11 +2,13 @@
 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Column, String, ForeignKey, DateTime, Integer
+
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import relationship
+
 from app.database import Base
-from app.utils.short_id import generate_short_id, build_url_identifier
+from app.utils.short_id import build_url_identifier, generate_short_id
 class PhaseContainer(Base):
@@ -32,9 +34,7 @@ class PhaseContainer(Base):
     title = Column(String(500), nullable=False)
     # Short ID for URL-friendly identifiers (YouTube-style)
-    short_id = Column(
-        String(11), nullable=False, unique=True, index=True, default=generate_short_id
-    )
+    short_id = Column(String(11), nullable=False, unique=True, index=True, default=generate_short_id)
     # Ordering within the project
     order_index = Column(Integer, nullable=False, default=0)
@@ -48,9 +48,7 @@ class PhaseContainer(Base):
     )
     # Timestamps
-    created_at = Column(
-        DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)
-    )
+    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
     updated_at = Column(
         DateTime(timezone=True),
         nullable=False,
diff --git a/backend/app/models/plan_recommendation.py b/backend/app/models/plan_recommendation.py
index d68f427..3eb18bc 100644
--- a/backend/app/models/plan_recommendation.py
+++ b/backend/app/models/plan_recommendation.py
@@ -9,8 +9,9 @@
 from typing import Optional
 from uuid import UUID, uuid4
-from sqlalchemy import BigInteger, Date, DateTime, Float, ForeignKey, Integer, String, Text, func
-from sqlalchemy.dialects.postgresql import UUID as PGUUID, ENUM
+from sqlalchemy import BigInteger, Date, DateTime, Float, ForeignKey, Integer, Text, func
+from sqlalchemy.dialects.postgresql import ENUM
+from sqlalchemy.dialects.postgresql import UUID as PGUUID
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from app.database import Base
diff --git a/backend/app/models/platform_connector.py b/backend/app/models/platform_connector.py
index c5e5eb5..d6bcd9b 100644
--- a/backend/app/models/platform_connector.py
+++ b/backend/app/models/platform_connector.py
@@ -1,4 +1,5 @@
 """Platform-level connector configuration model."""
+
 import uuid
 from datetime import datetime, timezone
 from enum import Enum
@@ -30,27 +31,17 @@ class PlatformConnector(Base):
     __tablename__ = "platform_connectors"
     __table_args__ = (
-        UniqueConstraint(
-            "connector_type", "provider", "display_name", name="uq_platform_connector_name"
-        ),
+        UniqueConstraint("connector_type", "provider", "display_name", name="uq_platform_connector_name"),
     )
-    id: Mapped[uuid.UUID] = mapped_column(
-        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
-    )
-    connector_type: Mapped[str] = mapped_column(
-        String(50), nullable=False, index=True
-    )  # llm, email, object_storage
-    provider: Mapped[str] = mapped_column(
-        String(50), nullable=False
-    )  # anthropic, openai, sendgrid, aws-s3
+    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    connector_type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)  # llm, email, object_storage
+    provider: Mapped[str] = mapped_column(String(50), nullable=False)  # anthropic, openai, sendgrid, aws-s3
     display_name: Mapped[str] = mapped_column(String(100), nullable=False)
     encrypted_credentials: Mapped[str] = mapped_column(Text, nullable=False)
     config_json: Mapped[dict | None] = mapped_column(JSON, nullable=True)
     is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
-    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=True),
         default=lambda: datetime.now(timezone.utc),
diff --git a/backend/app/models/platform_settings.py b/backend/app/models/platform_settings.py
index b74f1a5..5345b36 100644
--- a/backend/app/models/platform_settings.py
+++ b/backend/app/models/platform_settings.py
@@ -1,4 +1,5 @@
 """Platform-level settings model (singleton)."""
+
 import uuid
 from datetime import datetime, timezone
@@ -19,23 +20,13 @@ class PlatformSettings(Base):
     __tablename__ = "platform_settings"
-    id: Mapped[uuid.UUID] = mapped_column(
-        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
-    )
+    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
     # Selected connectors for each service
-    main_llm_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True
-    )
-    lightweight_llm_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True
-    )
-    email_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True
-    )
-    object_storage_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True
-    )
+    main_llm_connector_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
+    lightweight_llm_connector_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
+    email_connector_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
+    object_storage_connector_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), nullable=True)
     # General platform settings
     base_url: Mapped[str | None] = mapped_column(
@@ -43,65 +34,48 @@ class PlatformSettings(Base):
     )
     # Testing/debug settings (migrated from org-level)
-    mock_discovery_enabled: Mapped[bool] = mapped_column(
-        Boolean, default=False, nullable=False
-    )
-    mock_discovery_question_limit: Mapped[int] = mapped_column(
-        Integer, default=10, nullable=False
-    )
-    mock_discovery_delay_seconds: Mapped[int] = mapped_column(
-        Integer, default=5, nullable=False
-    )
+    mock_discovery_enabled: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
+    mock_discovery_question_limit: Mapped[int] = mapped_column(Integer, default=10, nullable=False)
+    mock_discovery_delay_seconds: Mapped[int] = mapped_column(Integer, default=5, nullable=False)
     # Freemium plan settings
     freemium_initial_tokens: Mapped[int] = mapped_column(
-        Integer, default=5_000_000, nullable=False,
-        comment="Initial tokens granted to new users on signup"
+        Integer, default=5_000_000, nullable=False, comment="Initial tokens granted to new users on signup"
     )
     freemium_weekly_topup_tokens: Mapped[int] = mapped_column(
-        Integer, default=10_000_000, nullable=False,
-        comment="Tokens added each Monday (additive, up to max)"
+        Integer, default=10_000_000, nullable=False, comment="Tokens added each Monday (additive, up to max)"
     )
     freemium_max_tokens: Mapped[int] = mapped_column(
-        Integer, default=10_000_000, nullable=False,
-        comment="Maximum token balance for freemium users"
+        Integer, default=10_000_000, nullable=False, comment="Maximum token balance for freemium users"
     )
     # Code Explorer settings
     code_explorer_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True,
-        comment="Platform connector for Code Explorer Anthropic API key"
+        UUID(as_uuid=True), nullable=True, comment="Platform connector for Code Explorer Anthropic API key"
    )
     code_explorer_enabled: Mapped[bool] = mapped_column(
-        Boolean, default=False, nullable=False,
-        comment="Whether Code Explorer feature is enabled"
+        Boolean, default=False, nullable=False, comment="Whether Code Explorer feature is enabled"
     )
     # Web Search settings
     web_search_connector_id: Mapped[uuid.UUID | None] = mapped_column(
-        UUID(as_uuid=True), nullable=True,
-        comment="Platform connector for Tavily web search API key"
+        UUID(as_uuid=True), nullable=True, comment="Platform connector for Tavily web search API key"
     )
     web_search_enabled: Mapped[bool] = mapped_column(
-        Boolean, default=False, nullable=False,
-        comment="Whether Web Search feature is enabled"
+        Boolean, default=False, nullable=False, comment="Whether Web Search feature is enabled"
     )
     # GitHub OAuth for Integration Connectors
     # These allow UI configuration of OAuth credentials, taking precedence over env vars
     github_oauth_client_id_encrypted: Mapped[str | None] = mapped_column(
-        nullable=True,
-        comment="Encrypted GitHub OAuth App Client ID for integration connectors"
+        nullable=True, comment="Encrypted GitHub OAuth App Client ID for integration connectors"
     )
     github_oauth_client_secret_encrypted: Mapped[str | None] = mapped_column(
-        nullable=True,
-        comment="Encrypted GitHub OAuth App Client Secret for integration connectors"
+        nullable=True, comment="Encrypted GitHub OAuth App Client Secret for integration connectors"
     )
     # Timestamps
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
-    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
     updated_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=True),
         default=lambda: datetime.now(timezone.utc),
diff --git a/backend/app/models/project.py b/backend/app/models/project.py
index e915632..335e5dc 100644
--- a/backend/app/models/project.py
+++ b/backend/app/models/project.py
@@ -1,13 +1,14 @@
 """Project model."""
+
 import enum
 import uuid
 from datetime import datetime, timezone
-from sqlalchemy import Boolean, Column, String, Text, ForeignKey, Enum as SAEnum, TIMESTAMP, UUID
+from sqlalchemy import TIMESTAMP, UUID, Boolean, Column, ForeignKey, String, Text
 from sqlalchemy.orm import relationship
 from app.database import Base
-from app.utils.short_id import generate_short_id, build_url_identifier
+from app.utils.short_id import build_url_identifier, generate_short_id
 class ProjectType(str, enum.Enum):
@@ -69,7 +70,12 @@ class Project(Base):
     created_by = Column(UUID, ForeignKey("users.id"), nullable=False)
     created_at = Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
-    updated_at = Column(TIMESTAMP(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
+    updated_at = Column(
+        TIMESTAMP(timezone=True),
+        nullable=False,
+        default=lambda: datetime.now(timezone.utc),
+        onupdate=lambda: datetime.now(timezone.utc),
+    )
     deleted_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True)
     # Sample project flag for onboarding
diff --git a/backend/app/models/project_chat.py b/backend/app/models/project_chat.py
index 89629ab..e7eaadd 100644
--- a/backend/app/models/project_chat.py
+++ b/backend/app/models/project_chat.py
@@ -7,29 +7,30 @@
 """
 import enum
-from datetime import datetime, timezone
-from typing import Optional, List, TYPE_CHECKING
+from datetime import datetime
+from typing import TYPE_CHECKING, List, Optional
 from uuid import UUID, uuid4
-from sqlalchemy import String, Boolean, DateTime, Text, Enum, ForeignKey, JSON, func
+from sqlalchemy import JSON, Boolean, DateTime, Enum, ForeignKey, String, Text, func
 from sqlalchemy.dialects.postgresql import UUID as PostgresUUID
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 from app.database import Base
-from app.utils.short_id import generate_short_id, build_url_identifier
+from app.utils.short_id import build_url_identifier, generate_short_id
 if TYPE_CHECKING:
-    from app.models.user import User
-    from app.models.project import Project
-    from app.models.job import Job
     from app.models.brainstorming_phase import BrainstormingPhase
+    from app.models.code_exploration_result import CodeExplorationResult
+    from app.models.job import Job
     from app.models.module import Module
     from app.models.organization import Organization
-    from app.models.code_exploration_result import CodeExplorationResult
+    from app.models.project import Project
+    from app.models.user import User
 class ProjectChatMessageType(str, enum.Enum):
     """Type of message in a project chat."""
+
     USER = "user"
     BOT = "bot"
     CODE_EXPLORATION = "code_exploration"
@@ -39,6 +40,7 @@ class ProjectChatMessageType(str, enum.Enum):
 class ProjectChatVisibility(str, enum.Enum):
     """Visibility level for project chats."""
+
     PRIVATE = "private"
     TEAM = "team"
@@ -527,7 +529,9 @@ class ProjectChatMessage(Base):
     )
     def __repr__(self) -> str:
-        return f""
+        return (
+            f""
+        )
     @property
     def is_user_message(self) -> bool:
diff --git a/backend/app/models/project_membership.py b/backend/app/models/project_membership.py
index 6c44a3c..4ed5590 100644
--- a/backend/app/models/project_membership.py
+++ b/backend/app/models/project_membership.py
@@ -2,12 +2,12 @@
 from datetime import datetime
 from enum import Enum
+from typing import TYPE_CHECKING
 from uuid import UUID, uuid4
 from sqlalchemy import DateTime, ForeignKey, UniqueConstraint, func
 from sqlalchemy import Enum as SQLEnum
 from sqlalchemy.orm import Mapped, mapped_column, relationship
-from typing import TYPE_CHECKING
 from app.database import Base
@@ -36,23 +36,15 @@ class ProjectMembership(Base):
         nullable=False,
         index=True,
     )
-    user_id: Mapped[UUID] = mapped_column(
-        ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
-    )
-    role: Mapped[ProjectRole] = mapped_column(
-        SQLEnum(ProjectRole, native_enum=False, length=20), nullable=False
-    )
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), nullable=False
-    )
+    user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
+    role: Mapped[ProjectRole] = mapped_column(SQLEnum(ProjectRole, native_enum=False, length=20), nullable=False)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
     # Relationships
     project: Mapped["Project"] = relationship("Project", lazy="joined")
     user: Mapped["User"] = relationship("User", lazy="joined")
-    __table_args__ = (
-        UniqueConstraint("project_id", "user_id", name="uq_project_user"),
-    )
+    __table_args__ = (UniqueConstraint("project_id", "user_id", name="uq_project_user"),)
     def __repr__(self) -> str:
         return f""
diff --git a/backend/app/models/project_repository.py b/backend/app/models/project_repository.py
index d0eab88..3b42d0e 100644
--- a/backend/app/models/project_repository.py
+++ b/backend/app/models/project_repository.py
@@ -1,10 +1,11 @@
 """ProjectRepository model for multi-repository support."""
+
 import re
 import uuid
 from datetime import datetime, timezone
 from urllib.parse import urlparse
-from sqlalchemy import Column, String, Text, Integer, ForeignKey, TIMESTAMP, UniqueConstraint, UUID
+from sqlalchemy import TIMESTAMP, UUID, Column, ForeignKey, Integer, String, Text, UniqueConstraint
 from sqlalchemy.dialects.postgresql import JSON
 from sqlalchemy.orm import relationship
@@ -47,9 +48,7 @@ class ProjectRepository(Base):
         onupdate=lambda: datetime.now(timezone.utc),
     )
-    __table_args__ = (
-        UniqueConstraint("project_id", "slug", name="uq_project_repo_slug"),
-    )
+    __table_args__ = (UniqueConstraint("project_id", "slug", name="uq_project_repo_slug"),)
     # Relationships
     project = relationship("Project", back_populates="repositories")
diff --git a/backend/app/models/project_share.py b/backend/app/models/project_share.py
index 94774f9..17bc433 100644
--- a/backend/app/models/project_share.py
+++ b/backend/app/models/project_share.py
@@ -53,9 +53,7 @@ class ProjectShare(Base):
         ForeignKey("users.id", ondelete="SET NULL"),
         nullable=True,
     )
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), nullable=False
-    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
     updated_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
@@ -67,11 +65,7 @@ class ProjectShare(Base):
     project: Mapped["Project"] = relationship("Project", lazy="joined")
     created_by: Mapped[Optional["User"]] = relationship("User", lazy="joined")
-    __table_args__ = (
-        UniqueConstraint(
-            "project_id", "subject_type", "subject_id", name="uq_project_subject"
-        ),
-    )
+    __table_args__ = (UniqueConstraint("project_id", "subject_type", "subject_id", name="uq_project_subject"),)
     def __repr__(self) -> str:
         return f""
diff --git a/backend/app/models/prompt_plan_coverage.py b/backend/app/models/prompt_plan_coverage.py
index 849826a..da63b86 100644
--- a/backend/app/models/prompt_plan_coverage.py
+++ b/backend/app/models/prompt_plan_coverage.py
@@ -2,13 +2,16 @@
 PromptPlanCoverageReport model for storing prompt plan quality validation results.
""" -from sqlalchemy import Column, Boolean, ForeignKey, DateTime -from sqlalchemy.dialects.postgresql import UUID as PostgresUUID, JSON -from sqlalchemy.orm import relationship -from app.database import Base import uuid from datetime import datetime, timezone +from sqlalchemy import Boolean, Column, DateTime, ForeignKey +from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID +from sqlalchemy.orm import relationship + +from app.database import Base + class PromptPlanCoverageReport(Base): """ @@ -17,6 +20,7 @@ class PromptPlanCoverageReport(Base): Generated by the QA Agent after prompt plan generation. Tracks completeness, correctness, and coverage of implementation phases and MCP methods. """ + __tablename__ = "prompt_plan_coverage_reports" id = Column(PostgresUUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -24,7 +28,7 @@ class PromptPlanCoverageReport(Base): PostgresUUID(as_uuid=True), ForeignKey("spec_versions.id"), nullable=False, - unique=True # One report per prompt plan version + unique=True, # One report per prompt plan version ) # Overall validation status diff --git a/backend/app/models/provisioning.py b/backend/app/models/provisioning.py index c0501c4..4bf30ed 100644 --- a/backend/app/models/provisioning.py +++ b/backend/app/models/provisioning.py @@ -10,7 +10,7 @@ class ProvisioningSource(str, Enum): SSO/SCIM integration in Phase 2 without schema changes. """ - MANUAL = "manual" # Created by admin via UI - INVITE = "invite" # Created via invitation acceptance - SSO_JIT = "sso_jit" # Just-in-time provisioning via SSO (Phase 2) + MANUAL = "manual" # Created by admin via UI + INVITE = "invite" # Created via invitation acceptance + SSO_JIT = "sso_jit" # Just-in-time provisioning via SSO (Phase 2) SSO_SCIM = "sso_scim" # Provisioned via SCIM directory sync (Phase 2) diff --git a/backend/app/models/slack_channel_link.py b/backend/app/models/slack_channel_link.py index 3224db2..73a9ddc 100644 --- a/backend/app/models/slack_channel_link.py +++ b/backend/app/models/slack_channel_link.py @@ -36,18 +36,10 @@ class SlackChannelProjectLink(Base): ), ) - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) - slack_team_id: Mapped[str] = mapped_column( - String(50), nullable=False, index=True - ) - slack_channel_id: Mapped[str] = mapped_column( - String(50), nullable=False, index=True - ) - slack_channel_name: Mapped[str | None] = mapped_column( - String(255), nullable=True - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + slack_team_id: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + slack_channel_id: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + slack_channel_name: Mapped[str | None] = mapped_column(String(255), nullable=True) organization_id: Mapped[uuid.UUID] = mapped_column( UUID(as_uuid=True), ForeignKey("organizations.id", ondelete="CASCADE"), @@ -70,12 +62,8 @@ class SlackChannelProjectLink(Base): ForeignKey("users.id", ondelete="SET NULL"), nullable=True, ) - linked_by_slack_user_id: Mapped[str | None] = mapped_column( - String(50), nullable=True - ) - is_active: Mapped[bool] = mapped_column( - Boolean, nullable=False, default=True - ) + linked_by_slack_user_id: Mapped[str | None] = mapped_column(String(50), nullable=True) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) created_at: Mapped[datetime] = mapped_column( 
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), diff --git a/backend/app/models/slack_user_mapping.py b/backend/app/models/slack_user_mapping.py index f27b23d..1e2cd25 100644 --- a/backend/app/models/slack_user_mapping.py +++ b/backend/app/models/slack_user_mapping.py @@ -27,21 +27,11 @@ class SlackUserMapping(Base): ), ) - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) - slack_team_id: Mapped[str] = mapped_column( - String(50), nullable=False, index=True - ) - slack_user_id: Mapped[str] = mapped_column( - String(50), nullable=False, index=True - ) - slack_display_name: Mapped[str | None] = mapped_column( - String(255), nullable=True - ) - slack_email: Mapped[str | None] = mapped_column( - String(255), nullable=True - ) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + slack_team_id: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + slack_user_id: Mapped[str] = mapped_column(String(50), nullable=False, index=True) + slack_display_name: Mapped[str | None] = mapped_column(String(255), nullable=True) + slack_email: Mapped[str | None] = mapped_column(String(255), nullable=True) user_id: Mapped[uuid.UUID | None] = mapped_column( UUID(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), @@ -69,8 +59,5 @@ class SlackUserMapping(Base): def __repr__(self) -> str: return ( - f"" + f"" ) diff --git a/backend/app/models/spec_coverage.py b/backend/app/models/spec_coverage.py index 70153b5..7922947 100644 --- a/backend/app/models/spec_coverage.py +++ b/backend/app/models/spec_coverage.py @@ -2,13 +2,16 @@ SpecCoverageReport model for storing specification quality validation results. """ -from sqlalchemy import Column, Boolean, ForeignKey, DateTime -from sqlalchemy.dialects.postgresql import UUID as PostgresUUID, JSON -from sqlalchemy.orm import relationship -from app.database import Base import uuid from datetime import datetime, timezone +from sqlalchemy import Boolean, Column, DateTime, ForeignKey +from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import UUID as PostgresUUID +from sqlalchemy.orm import relationship + +from app.database import Base + class SpecCoverageReport(Base): """ @@ -17,6 +20,7 @@ class SpecCoverageReport(Base): Generated by the QA/Coverage Agent after spec generation. Tracks completeness, consistency, and coverage of discovery requirements. 
""" + __tablename__ = "spec_coverage_reports" id = Column(PostgresUUID(as_uuid=True), primary_key=True, default=uuid.uuid4) @@ -24,7 +28,7 @@ class SpecCoverageReport(Base): PostgresUUID(as_uuid=True), ForeignKey("spec_versions.id"), nullable=False, - unique=True # One report per spec version + unique=True, # One report per spec version ) # Overall validation status diff --git a/backend/app/models/spec_version.py b/backend/app/models/spec_version.py index afc0abf..ba3aea8 100644 --- a/backend/app/models/spec_version.py +++ b/backend/app/models/spec_version.py @@ -3,9 +3,11 @@ import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, Integer, Boolean, ForeignKey, Enum, DateTime, JSON + +from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, ForeignKey, Integer, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship + from app.database import Base @@ -41,9 +43,7 @@ class SpecVersion(Base): content_json = Column(JSON, nullable=True) # Optional structured representation blocks = Column(JSON, nullable=True) # Structured block content for inline comments created_by = Column(PostgresUUID(as_uuid=True), ForeignKey("users.id"), nullable=True) - created_at = Column( - DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc) - ) + created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) # Relationships project = relationship("Project", back_populates="spec_versions") diff --git a/backend/app/models/team_role.py b/backend/app/models/team_role.py index 48d261e..65c0f15 100644 --- a/backend/app/models/team_role.py +++ b/backend/app/models/team_role.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Optional from uuid import UUID, uuid4 -from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text, UniqueConstraint, func, Integer +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base @@ -46,9 +46,7 @@ class TeamRoleDefinition(Base): description: Mapped[Optional[str]] = mapped_column(Text, nullable=True) order_index: Mapped[int] = mapped_column(Integer, default=0) is_default: Mapped[bool] = mapped_column(Boolean, default=False) # Track if seeded default - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False ) @@ -59,9 +57,7 @@ class TeamRoleDefinition(Base): "ProjectTeamAssignment", back_populates="role_definition", cascade="all, delete-orphan" ) - __table_args__ = ( - UniqueConstraint("org_id", "role_key", name="uq_org_role_key"), - ) + __table_args__ = (UniqueConstraint("org_id", "role_key", name="uq_org_role_key"),) def __repr__(self) -> str: return f"" @@ -92,9 +88,7 @@ class ProjectTeamAssignment(Base): ForeignKey("users.id"), nullable=False, ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) # Relationships project: Mapped["Project"] = relationship("Project", lazy="joined") diff 
--git a/backend/app/models/thread.py b/backend/app/models/thread.py index c265d8b..4e4d8ab 100644 --- a/backend/app/models/thread.py +++ b/backend/app/models/thread.py @@ -1,14 +1,19 @@ """Thread and Comment models for discussions.""" + import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum as SQLEnum, Boolean, JSON + +from sqlalchemy import JSON, Boolean, Column, DateTime, ForeignKey, String, Text +from sqlalchemy import Enum as SQLEnum from sqlalchemy.orm import relationship + from app.database import Base class ContextType(str, enum.Enum): """Context types for threads.""" + SPEC = "spec" GENERAL = "general" SPEC_DRAFT = "spec_draft" # Thread anchored to a spec draft version block @@ -19,12 +24,14 @@ class ContextType(str, enum.Enum): class ProjectChatVisibility(str, enum.Enum): """Visibility level for project chat threads.""" + PRIVATE = "private" TEAM = "team" class Thread(Base): """Thread model for project discussions.""" + __tablename__ = "threads" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) @@ -38,7 +45,12 @@ class Thread(Base): pending_approval = Column(Boolean, nullable=False, default=False) created_by = Column(String, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) - updated_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime(timezone=True), + nullable=False, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) # Decision summary tracking decision_summary = Column(Text, nullable=True) # Compact decision summary for downstream use @@ -110,7 +122,9 @@ class Thread(Base): proposed_project_key = Column(String(10), nullable=True) # Phase/Project created from this chat (if any) - created_phase_id = Column(String, ForeignKey("brainstorming_phases.id", ondelete="SET NULL"), nullable=True, index=True) + created_phase_id = Column( + String, ForeignKey("brainstorming_phases.id", ondelete="SET NULL"), nullable=True, index=True + ) created_project_id = Column(String, ForeignKey("projects.id", ondelete="SET NULL"), nullable=True, index=True) # Feature creation tracking (stored as JSON array of UUID strings for SQLite compatibility) @@ -119,8 +133,12 @@ class Thread(Base): # Relationships project = relationship("Project", back_populates="threads", foreign_keys=[project_id]) creator = relationship("User", foreign_keys=[created_by]) - comments = relationship("Comment", back_populates="thread", cascade="all, delete-orphan", order_by="Comment.created_at") - items = relationship("ThreadItem", back_populates="thread", cascade="all, delete-orphan", order_by="ThreadItem.created_at") + comments = relationship( + "Comment", back_populates="thread", cascade="all, delete-orphan", order_by="Comment.created_at" + ) + items = relationship( + "ThreadItem", back_populates="thread", cascade="all, delete-orphan", order_by="ThreadItem.created_at" + ) code_explorations = relationship( "CodeExplorationResult", back_populates="thread", @@ -152,6 +170,7 @@ def is_readonly(self): class Comment(Base): """Comment model for thread discussions.""" + __tablename__ = "comments" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) diff --git a/backend/app/models/thread_item.py b/backend/app/models/thread_item.py index e9facb1..1dcc669 
100644 --- a/backend/app/models/thread_item.py +++ b/backend/app/models/thread_item.py @@ -1,14 +1,19 @@ """Thread Item model for mixed content in threads (comments + MCQ follow-ups).""" + import enum import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, DateTime, ForeignKey, Enum as SQLEnum, JSON + +from sqlalchemy import JSON, Column, DateTime, ForeignKey, String, Text +from sqlalchemy import Enum as SQLEnum from sqlalchemy.orm import relationship + from app.database import Base class ThreadItemType(str, enum.Enum): """Types of thread items.""" + COMMENT = "comment" MCQ_FOLLOWUP = "mcq_followup" NO_FOLLOWUP_MESSAGE = "no_followup_message" @@ -21,11 +26,27 @@ class ThreadItemType(str, enum.Enum): class ThreadItem(Base): """Thread item model for polymorphic content (comments, MCQs, system messages).""" + __tablename__ = "thread_items" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) thread_id = Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=False, index=True) - item_type = Column(SQLEnum("comment", "mcq_followup", "no_followup_message", "mcq_answer", "implementation_created", "code_exploration", "web_search", "system", name="threaditemtype", create_type=False), nullable=False, index=True) + item_type = Column( + SQLEnum( + "comment", + "mcq_followup", + "no_followup_message", + "mcq_answer", + "implementation_created", + "code_exploration", + "web_search", + "system", + name="threaditemtype", + create_type=False, + ), + nullable=False, + index=True, + ) author_id = Column(String, ForeignKey("users.id"), nullable=False) content_data = Column(JSON, nullable=False) created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 902482b..4a3a1cc 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -3,18 +3,19 @@ Supports both password-based and OAuth authentication. """ -from datetime import datetime, UTC -from typing import Optional, TYPE_CHECKING + +from datetime import datetime +from typing import TYPE_CHECKING, Optional from uuid import UUID, uuid4 -from sqlalchemy import String, DateTime, func, Boolean, ForeignKey +from sqlalchemy import Boolean, DateTime, ForeignKey, String, func from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base if TYPE_CHECKING: - from app.models.user_identity import UserIdentity from app.models.organization import Organization + from app.models.user_identity import UserIdentity class User(Base): diff --git a/backend/app/models/user_conversation_status.py b/backend/app/models/user_conversation_status.py index ed54ed5..ac69d2f 100644 --- a/backend/app/models/user_conversation_status.py +++ b/backend/app/models/user_conversation_status.py @@ -2,10 +2,11 @@ Tracks the last read sequence number per user per conversation. 
""" -from datetime import datetime, timezone + +from datetime import datetime from uuid import UUID -from sqlalchemy import Boolean, DateTime, Integer, Enum, ForeignKey, func, Index, PrimaryKeyConstraint +from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Index, Integer, PrimaryKeyConstraint, func from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import Mapped, mapped_column, relationship diff --git a/backend/app/models/user_group.py b/backend/app/models/user_group.py index cc69422..c439a40 100644 --- a/backend/app/models/user_group.py +++ b/backend/app/models/user_group.py @@ -48,9 +48,7 @@ class UserGroup(Base): default=ProvisioningSource.MANUAL, nullable=False, ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -68,9 +66,7 @@ class UserGroup(Base): lazy="selectin", ) - __table_args__ = ( - UniqueConstraint("org_id", "name", name="uq_org_group_name"), - ) + __table_args__ = (UniqueConstraint("org_id", "name", name="uq_org_group_name"),) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/user_group_membership.py b/backend/app/models/user_group_membership.py index 0910462..1499c5a 100644 --- a/backend/app/models/user_group_membership.py +++ b/backend/app/models/user_group_membership.py @@ -46,9 +46,7 @@ class UserGroupMembership(Base): default=ProvisioningSource.MANUAL, nullable=False, ) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -57,14 +55,10 @@ class UserGroupMembership(Base): ) # Relationships - group: Mapped["UserGroup"] = relationship( - "UserGroup", back_populates="memberships", lazy="joined" - ) + group: Mapped["UserGroup"] = relationship("UserGroup", back_populates="memberships", lazy="joined") user: Mapped["User"] = relationship("User", lazy="joined") - __table_args__ = ( - UniqueConstraint("group_id", "user_id", name="uq_group_user"), - ) + __table_args__ = (UniqueConstraint("group_id", "user_id", name="uq_group_user"),) def __repr__(self) -> str: return f"" diff --git a/backend/app/models/user_identity.py b/backend/app/models/user_identity.py index dc25887..09ac566 100644 --- a/backend/app/models/user_identity.py +++ b/backend/app/models/user_identity.py @@ -11,19 +11,21 @@ - Check expires_at and proactively refresh tokens before expiration - Store tokens only when provider API access is needed (not just authentication) """ + from datetime import datetime -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from uuid import UUID, uuid4 -from sqlalchemy import String, DateTime, func, ForeignKey, UniqueConstraint -from sqlalchemy.dialects.postgresql import JSON, UUID as PGUUID +from sqlalchemy import DateTime, ForeignKey, String, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import UUID as PGUUID from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base if TYPE_CHECKING: - from app.models.user import User from 
app.models.identity_provider import IdentityProvider + from app.models.user import User class UserIdentity(Base): diff --git a/backend/app/models/user_question_session.py b/backend/app/models/user_question_session.py index b3229ea..191484c 100644 --- a/backend/app/models/user_question_session.py +++ b/backend/app/models/user_question_session.py @@ -6,30 +6,32 @@ """ import enum -from datetime import datetime, timezone -from typing import Optional, List, TYPE_CHECKING +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional from uuid import UUID, uuid4 -from sqlalchemy import String, Integer, DateTime, Text, Enum, ForeignKey, JSON, func +from sqlalchemy import JSON, DateTime, Enum, ForeignKey, Integer, String, Text, func from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base if TYPE_CHECKING: - from app.models.user import User from app.models.brainstorming_phase import BrainstormingPhase from app.models.job import Job + from app.models.user import User class UserQuestionSessionStatus(str, enum.Enum): """Status of a user question session.""" + ACTIVE = "active" ARCHIVED = "archived" class MessageRole(str, enum.Enum): """Role of a message in the chat.""" + USER = "user" ASSISTANT = "assistant" @@ -256,7 +258,4 @@ def get_unadded_questions(self) -> list: return [] added_ids = set(self.added_question_ids or []) - return [ - q for q in self.generated_questions - if q.get("temp_id") not in added_ids - ] + return [q for q in self.generated_questions if q.get("temp_id") not in added_ids] diff --git a/backend/app/models/validators/__init__.py b/backend/app/models/validators/__init__.py index 901c1e7..44d4c4f 100644 --- a/backend/app/models/validators/__init__.py +++ b/backend/app/models/validators/__init__.py @@ -3,6 +3,7 @@ This module registers SQLAlchemy event listeners for model validation. Import validators here to ensure they are registered when models are loaded. """ + from app.models.validators import phase_validators # noqa: F401 __all__ = ["phase_validators"] diff --git a/backend/app/models/validators/phase_validators.py b/backend/app/models/validators/phase_validators.py index a2736b2..ced8277 100644 --- a/backend/app/models/validators/phase_validators.py +++ b/backend/app/models/validators/phase_validators.py @@ -3,18 +3,20 @@ This module contains SQLAlchemy event listeners that validate phase-container relationships on insert and update operations. 
""" + import logging + from sqlalchemy import event, inspect from sqlalchemy.orm import Session from app.models.brainstorming_phase import BrainstormingPhase - logger = logging.getLogger(__name__) class PhaseContainerValidationError(Exception): """Raised when phase-container validation fails.""" + pass @@ -33,9 +35,7 @@ def _validate_container_sequence(target: BrainstormingPhase) -> None: PhaseContainerValidationError: If sequence is invalid """ if target.container_sequence is not None and target.container_sequence < 1: - raise PhaseContainerValidationError( - f"container_sequence must be >= 1, got {target.container_sequence}" - ) + raise PhaseContainerValidationError(f"container_sequence must be >= 1, got {target.container_sequence}") def _validate_container_project_consistency(target: BrainstormingPhase) -> None: @@ -56,14 +56,11 @@ def _validate_container_project_consistency(target: BrainstormingPhase) -> None: return from app.models.phase_container import PhaseContainer - container = db.query(PhaseContainer).filter( - PhaseContainer.id == target.container_id - ).first() + + container = db.query(PhaseContainer).filter(PhaseContainer.id == target.container_id).first() if container is None: - raise PhaseContainerValidationError( - f"Container {target.container_id} does not exist" - ) + raise PhaseContainerValidationError(f"Container {target.container_id} does not exist") if container.project_id != target.project_id: raise PhaseContainerValidationError( @@ -89,14 +86,11 @@ def _validate_container_not_archived(target: BrainstormingPhase) -> None: return from app.models.phase_container import PhaseContainer - container = db.query(PhaseContainer).filter( - PhaseContainer.id == target.container_id - ).first() + + container = db.query(PhaseContainer).filter(PhaseContainer.id == target.container_id).first() if container and container.archived_at is not None: - raise PhaseContainerValidationError( - f"Cannot assign phase to archived container {target.container_id}" - ) + raise PhaseContainerValidationError(f"Cannot assign phase to archived container {target.container_id}") def _validate_sequence_uniqueness(target: BrainstormingPhase) -> None: @@ -115,11 +109,15 @@ def _validate_sequence_uniqueness(target: BrainstormingPhase) -> None: if db is None: return - existing = db.query(BrainstormingPhase).filter( - BrainstormingPhase.container_id == target.container_id, - BrainstormingPhase.container_sequence == target.container_sequence, - BrainstormingPhase.id != target.id, - ).first() + existing = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.container_id == target.container_id, + BrainstormingPhase.container_sequence == target.container_sequence, + BrainstormingPhase.id != target.id, + ) + .first() + ) if existing: raise PhaseContainerValidationError( @@ -160,8 +158,7 @@ def validate_phase_on_update(mapper, connection, target): # Only validate if container-related fields changed container_changed = ( - insp.attrs.container_id.history.has_changes() or - insp.attrs.container_sequence.history.has_changes() + insp.attrs.container_id.history.has_changes() or insp.attrs.container_sequence.history.has_changes() ) if container_changed and target.container_id is not None: diff --git a/backend/app/models/vfs_metadata.py b/backend/app/models/vfs_metadata.py index f7cd8ee..275be40 100644 --- a/backend/app/models/vfs_metadata.py +++ b/backend/app/models/vfs_metadata.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime, timezone -from sqlalchemy import Column, String, Text, ForeignKey, 
DateTime, Index +from sqlalchemy import Column, DateTime, ForeignKey, Index, String, Text from sqlalchemy.dialects.postgresql import UUID as PostgresUUID from sqlalchemy.orm import relationship diff --git a/backend/app/permissions/__init__.py b/backend/app/permissions/__init__.py index 054cc1e..88521db 100644 --- a/backend/app/permissions/__init__.py +++ b/backend/app/permissions/__init__.py @@ -3,18 +3,18 @@ This module provides organization and project-level permission checking and role-based access control. """ -from app.permissions.context import OrgContext, get_org_context, ProjectContext, get_project_context +from app.permissions.context import OrgContext, ProjectContext, get_org_context, get_project_context from app.permissions.helpers import ( is_org_admin_or_higher, is_org_member_or_higher, is_org_owner, - require_org_role, - role_rank, is_project_admin_or_higher, is_project_member_or_higher, is_project_owner, project_role_rank, + require_org_role, require_project_role, + role_rank, ) __all__ = [ diff --git a/backend/app/permissions/context.py b/backend/app/permissions/context.py index 6ad6f9f..35eb8cb 100644 --- a/backend/app/permissions/context.py +++ b/backend/app/permissions/context.py @@ -8,7 +8,7 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import Organization, OrgMembership, User, Project, ProjectShare +from app.models import Organization, OrgMembership, Project, ProjectShare, User from app.services.org_service import OrgService from app.services.project_service import ProjectService from app.services.project_share_service import ProjectShareService @@ -175,9 +175,7 @@ def get_settings( # Load user's effective project share (direct or via group) # Use the actual project.id (UUID) for share lookup - project_share = ProjectShareService.get_user_effective_share( - db, project.id, current_user.id - ) + project_share = ProjectShareService.get_user_effective_share(db, project.id, current_user.id) if not project_share: # Use 404 instead of 403 to avoid leaking project existence # This is a privacy best practice (e.g., GitHub does this for private repos) @@ -219,7 +217,8 @@ def get_org_context_from_project( if isinstance(project_id, str): try: from uuid import UUID as PyUUID - project_id = PyUUID(project_id) if '-' in project_id else project_id + + project_id = PyUUID(project_id) if "-" in project_id else project_id except (ValueError, AttributeError): pass @@ -233,6 +232,7 @@ def get_org_context_from_project( # Verify user is org member (ensure UUIDs are used) from uuid import UUID as PyUUID + org_id = project.org_id if isinstance(project.org_id, PyUUID) else PyUUID(str(project.org_id)) user_id = current_user.id if isinstance(current_user.id, PyUUID) else PyUUID(str(current_user.id)) org_membership = OrgService.get_org_membership(db, org_id, user_id) diff --git a/backend/app/plugin_registry.py b/backend/app/plugin_registry.py index ca22a39..1d681b0 100644 --- a/backend/app/plugin_registry.py +++ b/backend/app/plugin_registry.py @@ -10,8 +10,8 @@ """ import logging -from dataclasses import dataclass, field -from typing import Any, Callable +from dataclasses import dataclass +from typing import Callable from fastapi import APIRouter @@ -39,7 +39,9 @@ class AuthProviderPlugin: class InvitationPlugin: """Plugin for enterprise invitation management (e.g. 
Scalekit invitations).""" - on_create: Callable | None = None # fn(org, emails, role, invited_by_user_id, group_ids, db) -> InvitationCreateResponse + on_create: Callable | None = ( + None # fn(org, emails, role, invited_by_user_id, group_ids, db) -> InvitationCreateResponse + ) on_cancel: Callable | None = None # fn(org, invitation, db) -> None on_resend: Callable | None = None # fn(org, invitation, db) -> InvitationSendResult on_remove_member: Callable | None = None # fn(db, org, user_id) -> None diff --git a/backend/app/routers/__init__.py b/backend/app/routers/__init__.py index f2f3c9a..d174f9d 100644 --- a/backend/app/routers/__init__.py +++ b/backend/app/routers/__init__.py @@ -1,6 +1,7 @@ """ API routers for the MFBT backend. """ + from app.routers import auth, orgs, projects, threads __all__ = ["auth", "orgs", "projects", "threads"] diff --git a/backend/app/routers/activity.py b/backend/app/routers/activity.py index 1720bd8..205a5ce 100644 --- a/backend/app/routers/activity.py +++ b/backend/app/routers/activity.py @@ -1,16 +1,14 @@ """Router for activity log operations.""" -from typing import Optional -from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ProjectRole -from app.permissions.context import get_project_context, ProjectContext +from app.models import ProjectRole, User +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role -from app.schemas.activity import ActivityLogResponse, ActivityListResponse +from app.schemas.activity import ActivityListResponse, ActivityLogResponse from app.services.activity_log_service import ActivityLogService router = APIRouter(tags=["activity"]) diff --git a/backend/app/routers/agent_api.py b/backend/app/routers/agent_api.py index c92afd2..ea4f594 100644 --- a/backend/app/routers/agent_api.py +++ b/backend/app/routers/agent_api.py @@ -6,29 +6,30 @@ from datetime import datetime from typing import Annotated, List, Optional -from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Header, status +from fastapi import APIRouter, Depends, Header, HTTPException, status from pydantic import BaseModel, ConfigDict from sqlalchemy.orm import Session from app.auth.dependencies import ApiKeyAuth from app.database import get_db -from app.models.user import User -from app.models.project import Project from app.models.brainstorming_phase import BrainstormingPhase -from app.models.final_spec import FinalSpec +from app.models.feature import Feature, FeatureStatus from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec from app.models.module import Module -from app.models.feature import Feature, FeatureStatus +from app.models.project import Project +from app.models.user import User router = APIRouter(prefix="/projects", tags=["Agent API"]) # ==================== Response Schemas ==================== + class FinalSpecResponse(BaseModel): """Response for final spec endpoint.""" + model_config = ConfigDict(from_attributes=True) content_markdown: str @@ -38,6 +39,7 @@ class FinalSpecResponse(BaseModel): class FinalPromptPlanResponse(BaseModel): """Response for final prompt plan endpoint.""" + model_config = ConfigDict(from_attributes=True) content_markdown: str @@ -47,6 
+49,7 @@ class FinalPromptPlanResponse(BaseModel): class FeatureResponse(BaseModel): """Feature data for agent consumption.""" + model_config = ConfigDict(from_attributes=True) id: str @@ -59,6 +62,7 @@ class FeatureResponse(BaseModel): class ModuleResponse(BaseModel): """Module with nested features for agent consumption.""" + model_config = ConfigDict(from_attributes=True) id: str @@ -69,11 +73,13 @@ class ModuleResponse(BaseModel): class ModulesFeaturesResponse(BaseModel): """Response for modules-features endpoint.""" + modules: List[ModuleResponse] # ==================== Helper Functions ==================== + async def get_api_key_auth( project_id: str, db: Annotated[Session, Depends(get_db)], @@ -87,6 +93,7 @@ async def get_api_key_auth( # ==================== Endpoints ==================== + @router.get( "/{project_id}/final-spec", response_model=FinalSpecResponse, diff --git a/backend/app/routers/analytics.py b/backend/app/routers/analytics.py index 446c2f2..86f0f7d 100644 --- a/backend/app/routers/analytics.py +++ b/backend/app/routers/analytics.py @@ -13,8 +13,8 @@ from app.auth.platform_admin import require_platform_admin from app.database import get_db -from app.models.user import User from app.models.organization import Organization +from app.models.user import User from app.schemas.analytics import ( EfficiencyOverviewResponse, OrgEfficiencyResponse, @@ -23,8 +23,8 @@ TopProjectsResponse, TopUsersResponse, ) -from app.services.analytics_service import AnalyticsService from app.services.analytics_cache import invalidate_analytics_cache +from app.services.analytics_service import AnalyticsService router = APIRouter(prefix="/analytics", tags=["analytics"]) @@ -146,6 +146,4 @@ def get_org_efficiency_overview( Includes top users and projects per organization. Platform admins only. """ - return AnalyticsService.get_org_efficiency_overview( - db, time_range, org_id, limit_users, limit_projects - ) + return AnalyticsService.get_org_efficiency_overview(db, time_range, org_id, limit_users, limit_projects) diff --git a/backend/app/routers/api_keys.py b/backend/app/routers/api_keys.py index 675f463..c8d9596 100644 --- a/backend/app/routers/api_keys.py +++ b/backend/app/routers/api_keys.py @@ -7,18 +7,18 @@ from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user +from app.config import settings from app.database import get_db from app.models.user import User from app.permissions.context import ProjectContext, get_project_context from app.schemas.api_key import ( ApiKeyCreate, ApiKeyCreateResponse, - ApiKeyResponse, ApiKeyList, + ApiKeyResponse, MCPConnectionConfig, ) from app.services.api_key_service import ApiKeyService -from app.config import settings # User-level API key endpoints router = APIRouter(prefix="/users/me/api-keys", tags=["API Keys"]) @@ -37,9 +37,7 @@ async def create_api_key( API keys can be used to authenticate MCP requests to any project the user has access to. """ - api_key, raw_key = ApiKeyService.create_api_key( - db, user_id=current_user.id, data=data - ) + api_key, raw_key = ApiKeyService.create_api_key(db, user_id=current_user.id, data=data) return ApiKeyService.to_response(api_key, raw_key=raw_key) @@ -71,9 +69,7 @@ async def revoke_api_key( The key will no longer be usable for authentication. 
""" - api_key = ApiKeyService.revoke_api_key( - db, key_id=key_id, user_id=current_user.id - ) + api_key = ApiKeyService.revoke_api_key(db, key_id=key_id, user_id=current_user.id) return ApiKeyService.to_response(api_key) diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index cc7fa00..aa6a883 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -7,6 +7,7 @@ Enterprise auth providers (e.g. Scalekit SSO) are integrated via the plugin registry — see app/plugin_registry.py. """ + import json import logging from typing import Annotated, Any @@ -15,9 +16,12 @@ from fastapi import APIRouter, Cookie, Depends, HTTPException, Request, Response, status from fastapi.responses import RedirectResponse from fastapi.security import OAuth2PasswordRequestForm +from jose import JWTError +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user +from app.auth.domain_validation import validate_signup_domain from app.auth.providers import ( get_configured_providers, get_provider_client, @@ -26,34 +30,26 @@ ) from app.auth.service import AuthService from app.auth.utils import create_access_token, decode_access_token -from jose import JWTError from app.config import settings -from app.database import get_db, get_async_db +from app.database import get_async_db, get_db from app.models.user import User -from app.models.organization import Organization -from sqlalchemy.ext.asyncio import AsyncSession +from app.plugin_registry import get_plugin_registry +from app.schemas.api_key import ApiKeyCreate from app.schemas.auth import ( - UserCreate, - UserLogin, - UserResponse, - TokenResponse, OrgMembershipResponse, - SwitchOrgRequest, RegistrationResponse, ResendVerificationRequest, ResendVerificationResponse, + SwitchOrgRequest, + TokenResponse, + UserCreate, + UserResponse, ) from app.schemas.oauth import OAuthProviderInfo -from app.services.user_service import UserService -from app.services.org_service import OrgService -from app.services.email_service import EmailService from app.services.api_key_service import ApiKeyService -from app.schemas.api_key import ApiKeyCreate -from app.services.invitation_service import InvitationService -from app.models.org_membership import OrgRole, ProvisioningSource -from app.auth.domain_validation import validate_signup_domain -from app.plugin_registry import get_plugin_registry - +from app.services.email_service import EmailService +from app.services.org_service import OrgService +from app.services.user_service import UserService logger = logging.getLogger(__name__) @@ -139,6 +135,7 @@ def _get_provider_info() -> dict[str, dict[str, str]]: def _get_known_provider_slugs() -> set[str]: """Get all known provider slugs (base + plugin).""" from app.auth.providers import _get_known_providers + return _get_known_providers() @@ -200,15 +197,14 @@ async def register( # Create sample onboarding project for new users try: from app.services.sample_project_service import SampleProjectService + SampleProjectService.create_sample_project(db, org.id, user.id) except Exception as e: _auth_log("warning", "sample_project_creation_failed", user_id=user.id, error=str(e)) # Auto-create default API key for MCP access try: - ApiKeyService.create_api_key( - db, user_id=user.id, data=ApiKeyCreate(name="Default MCP Key") - ) + ApiKeyService.create_api_key(db, user_id=user.id, data=ApiKeyCreate(name="Default MCP Key")) _auth_log("info", "default_api_key_created", user_id=user.id) except 
Exception as e: _auth_log("warning", "default_api_key_creation_failed", user_id=user.id, error=str(e)) @@ -595,8 +591,9 @@ def logout( except Exception as e: logger.warning(f"Plugin {slug} logout hook failed: {e}") - _auth_log("info", "logout", user_id=current_user.id, email=current_user.email, - plugin_logout=logout_result is not None) + _auth_log( + "info", "logout", user_id=current_user.id, email=current_user.email, plugin_logout=logout_result is not None + ) result = {"message": "Logged out successfully"} if logout_result and logout_result.get("redirect_url"): @@ -1165,15 +1162,14 @@ async def oauth_callback( if user.current_org_id: try: from app.services.sample_project_service import SampleProjectService + SampleProjectService.create_sample_project(db, user.current_org_id, user.id) except Exception as e: _auth_log("warning", "sample_project_creation_failed", user_id=user.id, error=str(e)) # Auto-create default API key for MCP access try: - ApiKeyService.create_api_key( - db, user_id=user.id, data=ApiKeyCreate(name="Default MCP Key") - ) + ApiKeyService.create_api_key(db, user_id=user.id, data=ApiKeyCreate(name="Default MCP Key")) _auth_log("info", "default_api_key_created", user_id=user.id) except Exception as e: _auth_log("warning", "default_api_key_creation_failed", user_id=user.id, error=str(e)) diff --git a/backend/app/routers/brainstorming_phases.py b/backend/app/routers/brainstorming_phases.py index d799d69..bf075fa 100644 --- a/backend/app/routers/brainstorming_phases.py +++ b/backend/app/routers/brainstorming_phases.py @@ -1,4 +1,5 @@ """Router for brainstorming phase operations.""" + import logging from typing import List from uuid import UUID @@ -8,24 +9,24 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, OrgRole, ProjectRole -from app.models.job import Job, JobType, JobStatus -from app.permissions.context import get_project_context, ProjectContext +from app.models import ProjectRole, User +from app.models.job import Job, JobStatus, JobType +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role from app.schemas.brainstorming_phase import ( + BrainstormFeatureGenerationStatusResponse, BrainstormingPhaseCreate, - BrainstormingPhaseUpdate, - BrainstormingPhaseResponse, BrainstormingPhaseListResponse, - BrainstormSpecGenerationStatusResponse, + BrainstormingPhaseResponse, + BrainstormingPhaseUpdate, BrainstormPromptPlanGenerationStatusResponse, - BrainstormFeatureGenerationStatusResponse, + BrainstormSpecGenerationStatusResponse, PhaseImplementationProgressResponse, ) +from app.services.activity_log_service import ActivityEventTypes, ActivityLogService from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.job_service import JobService from app.services.kafka_producer import get_kafka_producer -from app.services.activity_log_service import ActivityLogService, ActivityEventTypes logger = logging.getLogger(__name__) @@ -103,9 +104,7 @@ async def create_brainstorming_phase( if not success: # Log warning but don't fail the phase creation - logger.warning( - f"Failed to publish brainstorm generation job {job.id} for phase {phase.id}" - ) + logger.warning(f"Failed to publish brainstorm generation job {job.id} for phase {phase.id}") JobService.update_job_status( db=db, job_id=job.id, @@ -113,9 +112,7 @@ async def create_brainstorming_phase( error_message="Failed to publish job to Kafka", ) else: - 
logger.info( - f"Brainstorm generation job {job.id} queued for phase {phase.id}" - ) + logger.info(f"Brainstorm generation job {job.id} queued for phase {phase.id}") # Note: 2nd generation is triggered immediately after the first generation # completes successfully (PHASE_CREATED trigger handled in worker handler) @@ -528,10 +525,12 @@ async def generate_conversations( db.query(Job) .filter( Job.project_id == phase.project_id, - Job.job_type.in_([ - JobType.BRAINSTORM_CONVERSATION_GENERATE, - JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE, - ]), + Job.job_type.in_( + [ + JobType.BRAINSTORM_CONVERSATION_GENERATE, + JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE, + ] + ), Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]), ) .first() @@ -669,13 +668,11 @@ async def generate_modules_from_spec( # Check if Final Spec exists, or auto-create from latest draft if no threads from app.models.final_spec import FinalSpec - from app.models.spec_version import SpecVersion, SpecType + from app.models.spec_version import SpecType, SpecVersion from app.models.thread import Thread from app.services.finalization_service import FinalizationService - final_spec = db.query(FinalSpec).filter( - FinalSpec.brainstorming_phase_id == phase.id - ).first() + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() if not final_spec: # No FinalSpec exists - check if we can auto-create from latest draft @@ -698,9 +695,7 @@ async def generate_modules_from_spec( # Check if any draft has threads draft_ids = [str(d.id) for d in drafts] - has_threads = db.query(Thread).filter( - Thread.version_id.in_(draft_ids) - ).first() is not None + has_threads = db.query(Thread).filter(Thread.version_id.in_(draft_ids)).first() is not None if has_threads: # Threads exist - user must manually finalize to incorporate discussions @@ -855,7 +850,7 @@ def get_spec_generation_status( ) # Get spec drafts - from app.models.spec_version import SpecVersion, SpecType + from app.models.spec_version import SpecType, SpecVersion spec_drafts = ( db.query(SpecVersion) @@ -874,22 +869,17 @@ def get_spec_generation_status( # Check for final spec from app.models.final_spec import FinalSpec - final_spec = ( - db.query(FinalSpec) - .filter(FinalSpec.brainstorming_phase_id == phase.id) - .first() - ) + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() has_final_spec = final_spec is not None # Get conversation threads and their answer status - from app.models.thread import Thread, ContextType - from app.models.thread_item import ThreadItem - # Get all features for this phase and count answered ones # Uses same logic as Conversations page: feature is "answered" when its thread # has an MCQ item with selected_option_id set from app.models.feature import Feature, FeatureType from app.models.module import Module + from app.models.thread import ContextType, Thread + from app.models.thread_item import ThreadItem # Find modules for this phase modules = ( @@ -907,15 +897,19 @@ def get_spec_generation_status( from app.models.feature import FeatureVisibilityStatus features = ( - db.query(Feature) - .filter( - Feature.module_id.in_(module_ids), - Feature.status == "active", - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + ( + db.query(Feature) + .filter( + Feature.module_id.in_(module_ids), + Feature.status == "active", + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == 
FeatureVisibilityStatus.ACTIVE, + ) + .all() ) - .all() - ) if module_ids else [] + if module_ids + else [] + ) feature_ids = [str(f.id) for f in features] total_threads = len(features) # Count features, not threads @@ -942,23 +936,15 @@ def get_spec_generation_status( answered_feature_ids.add(context_id) answered_threads = len(answered_feature_ids) - thread_progress_percentage = ( - int((answered_threads / total_threads) * 100) - if total_threads > 0 - else 0 - ) + thread_progress_percentage = int((answered_threads / total_threads) * 100) if total_threads > 0 else 0 # Generate warnings warnings = [] if total_threads == 0: - warnings.append( - "No conversations found. Generate conversations first for better specification quality." - ) + warnings.append("No conversations found. Generate conversations first for better specification quality.") elif answered_threads == 0: - warnings.append( - "No conversations have been answered yet. The generated specification may be limited." - ) + warnings.append("No conversations have been answered yet. The generated specification may be limited.") elif thread_progress_percentage < 50: warnings.append( f"Only {thread_progress_percentage}% of conversations have responses. " @@ -967,8 +953,7 @@ def get_spec_generation_status( if has_existing_draft: warnings.append( - f"This will create a new draft version (v{latest_draft_version + 1}). " - f"Current: v{latest_draft_version}" + f"This will create a new draft version (v{latest_draft_version + 1}). Current: v{latest_draft_version}" ) return BrainstormSpecGenerationStatusResponse( @@ -1019,7 +1004,7 @@ def get_prompt_plan_generation_status( ) # Get prompt plan drafts - from app.models.spec_version import SpecVersion, SpecType + from app.models.spec_version import SpecType, SpecVersion plan_drafts = ( db.query(SpecVersion) @@ -1038,21 +1023,13 @@ def get_prompt_plan_generation_status( # Check for final prompt plan from app.models.final_prompt_plan import FinalPromptPlan - final_plan = ( - db.query(FinalPromptPlan) - .filter(FinalPromptPlan.brainstorming_phase_id == phase.id) - .first() - ) + final_plan = db.query(FinalPromptPlan).filter(FinalPromptPlan.brainstorming_phase_id == phase.id).first() has_final_plan = final_plan is not None # Check for spec (final or draft) from app.models.final_spec import FinalSpec - final_spec = ( - db.query(FinalSpec) - .filter(FinalSpec.brainstorming_phase_id == phase.id) - .first() - ) + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() has_final_spec = final_spec is not None spec_drafts = ( @@ -1082,18 +1059,13 @@ def get_prompt_plan_generation_status( warnings = [] if not can_generate: - warnings.append( - "No specification exists. Generate a specification first." - ) + warnings.append("No specification exists. Generate a specification first.") elif spec_source == "draft": - warnings.append( - f"No Final Spec found. Will use the latest draft (v{spec_draft_version}) instead." - ) + warnings.append(f"No Final Spec found. Will use the latest draft (v{spec_draft_version}) instead.") if has_existing_draft: warnings.append( - f"This will create a new draft version (v{latest_draft_version + 1}). " - f"Current: v{latest_draft_version}" + f"This will create a new draft version (v{latest_draft_version + 1}). 
Current: v{latest_draft_version}" ) return BrainstormPromptPlanGenerationStatusResponse( @@ -1146,57 +1118,72 @@ def get_feature_generation_status( # Get module and feature counts for IMPLEMENTATION modules only # (not brainstorming conversation aspects which are also stored as modules) + from app.models.feature import Feature, FeatureStatus, FeatureType from app.models.module import Module, ModuleType - from app.models.feature import Feature, FeatureType, FeatureStatus # Count implementation modules (using module_type field) - module_count = db.query(Module).filter( - Module.brainstorming_phase_id == phase.id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - ).count() + module_count = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == phase.id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + ) + .count() + ) has_existing_modules = module_count > 0 # Count implementation features (using feature_type field) - feature_count = db.query(Feature).join(Module).filter( - Module.brainstorming_phase_id == phase.id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.archived_at.is_(None), - Feature.feature_type == FeatureType.IMPLEMENTATION, - Feature.status == FeatureStatus.ACTIVE, - ).count() + feature_count = ( + db.query(Feature) + .join(Module) + .filter( + Module.brainstorming_phase_id == phase.id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.archived_at.is_(None), + Feature.feature_type == FeatureType.IMPLEMENTATION, + Feature.status == FeatureStatus.ACTIVE, + ) + .count() + ) # Check spec drafts - from app.models.spec_version import SpecVersion, SpecType + from app.models.spec_version import SpecType, SpecVersion - spec_drafts = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == phase.id, - SpecVersion.spec_type == SpecType.SPECIFICATION, - ).order_by(SpecVersion.version.desc()).all() + spec_drafts = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == phase.id, + SpecVersion.spec_type == SpecType.SPECIFICATION, + ) + .order_by(SpecVersion.version.desc()) + .all() + ) has_spec_draft = len(spec_drafts) > 0 spec_draft_version = spec_drafts[0].version if spec_drafts else None # Check final spec from app.models.final_spec import FinalSpec - final_spec = db.query(FinalSpec).filter( - FinalSpec.brainstorming_phase_id == phase.id - ).first() + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() has_final_spec = final_spec is not None # Check prompt plan drafts (also uses SpecVersion with different spec_type) - plan_drafts = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == phase.id, - SpecVersion.spec_type == SpecType.PROMPT_PLAN, - ).order_by(SpecVersion.version.desc()).all() + plan_drafts = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == phase.id, + SpecVersion.spec_type == SpecType.PROMPT_PLAN, + ) + .order_by(SpecVersion.version.desc()) + .all() + ) has_prompt_plan_draft = len(plan_drafts) > 0 # Check final prompt plan from app.models.final_prompt_plan import FinalPromptPlan - final_plan = db.query(FinalPromptPlan).filter( - FinalPromptPlan.brainstorming_phase_id == phase.id - ).first() + final_plan = db.query(FinalPromptPlan).filter(FinalPromptPlan.brainstorming_phase_id == phase.id).first() has_final_prompt_plan = final_plan is not None # Determine spec source and can_generate @@ -1215,22 +1202,14 @@ def get_feature_generation_status( warnings = [] 
if not has_spec_draft and not has_final_spec: - warnings.append( - "No specification exists. Generate a specification first." - ) + warnings.append("No specification exists. Generate a specification first.") elif not has_final_spec: - warnings.append( - f"Specification has not been finalized. Will use draft v{spec_draft_version}." - ) + warnings.append(f"Specification has not been finalized. Will use draft v{spec_draft_version}.") if not has_prompt_plan_draft and not has_final_prompt_plan: - warnings.append( - "No prompt plan exists. Consider generating one for better implementation guidance." - ) + warnings.append("No prompt plan exists. Consider generating one for better implementation guidance.") elif not has_final_prompt_plan: - warnings.append( - "Prompt plan has not been finalized. Will use the latest draft." - ) + warnings.append("Prompt plan has not been finalized. Will use the latest draft.") if has_existing_modules and can_generate: warnings.append( @@ -1454,14 +1433,10 @@ async def generate_brainstorm_prompt_plan( ) # Validate prerequisite: spec must exist (final or draft) - from app.models.spec_version import SpecVersion, SpecType from app.models.final_spec import FinalSpec + from app.models.spec_version import SpecType, SpecVersion - final_spec = ( - db.query(FinalSpec) - .filter(FinalSpec.brainstorming_phase_id == phase.id) - .first() - ) + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == phase.id).first() spec_draft = ( db.query(SpecVersion) @@ -1605,9 +1580,7 @@ def get_pending_questions_count( detail="Brainstorming phase not found", ) - counts = BrainstormingPhaseService.get_pending_questions_count( - db=db, brainstorming_phase_id=phase.id - ) + counts = BrainstormingPhaseService.get_pending_questions_count(db=db, brainstorming_phase_id=phase.id) return counts @@ -1647,9 +1620,7 @@ def get_pending_questions( detail="Brainstorming phase not found", ) - result = BrainstormingPhaseService.get_pending_questions( - db=db, brainstorming_phase_id=phase.id - ) + result = BrainstormingPhaseService.get_pending_questions(db=db, brainstorming_phase_id=phase.id) return result @@ -1850,7 +1821,6 @@ async def cancel_phase_generation( cancelled_jobs: Number of jobs cancelled cleared_flag: Name of the flag that was cleared """ - from typing import Literal from app.services.kafka_producer import get_sync_kafka_producer generation_type = request_data.get("generation_type") @@ -1893,9 +1863,7 @@ async def cancel_phase_generation( ) # Cancel any running jobs for this generation type - cancelled_count = JobService.cancel_jobs_for_phase_generation( - db, phase.id, generation_type - ) + cancelled_count = JobService.cancel_jobs_for_phase_generation(db, phase.id, generation_type) # Map generation type to flag name flag_map = { diff --git a/backend/app/routers/conversations.py b/backend/app/routers/conversations.py index 37c543b..cf6f079 100644 --- a/backend/app/routers/conversations.py +++ b/backend/app/routers/conversations.py @@ -5,33 +5,31 @@ These endpoints are the primary API for the unified chat frontend components. 
""" + import logging from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from sqlalchemy.orm import Session, joinedload from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ContextType, Thread, Comment -from app.models.thread import ProjectChatVisibility -from app.models.thread_item import ThreadItem, ThreadItemType +from app.models import ContextType, Thread, User from app.models.implementation import Implementation -from app.models.job import Job, JobType, JobStatus -from app.permissions.context import get_project_context, ProjectContext, get_org_context, OrgContext -from app.services.thread_service import ThreadService -from app.services.job_service import JobService -from app.services.kafka_producer import get_kafka_producer -from app.services.project_share_service import ProjectShareService -from app.services.project_service import ProjectService -from app.schemas.thread import ThreadResponse +from app.models.job import Job, JobStatus, JobType +from app.models.thread_item import ThreadItem, ThreadItemType +from app.permissions.context import ProjectContext, get_project_context from app.schemas.thread_item import ( - thread_item_to_response, ToggleReactionRequest, ToggleReactionResponse, + thread_item_to_response, ) +from app.services.job_service import JobService +from app.services.kafka_producer import get_kafka_producer +from app.services.project_share_service import ProjectShareService +from app.services.thread_service import ThreadService logger = logging.getLogger(__name__) @@ -42,14 +40,17 @@ # Pydantic schemas for unified conversations # ============================================= + class CreateProjectChatConversationRequest(BaseModel): """Request to create a new project chat conversation.""" + visibility: str = Field(default="private", description="Visibility: 'private' or 'team'") initial_message: Optional[str] = Field(None, description="Optional initial message to send") class ProjectChatConversationResponse(BaseModel): """Response for a project chat conversation.""" + id: str org_id: str project_id: Optional[str] @@ -110,12 +111,14 @@ class Config: class ProjectChatConversationWithItems(BaseModel): """Project chat conversation with its items.""" + conversation: ProjectChatConversationResponse items: List[dict] # ThreadItem responses class ProjectChatListItem(BaseModel): """List item for project chat sidebar.""" + id: str chat_title: Optional[str] visibility: str @@ -130,6 +133,7 @@ class ProjectChatListItem(BaseModel): class ProjectChatListResponse(BaseModel): """Response for listing project chats.""" + conversations: List[ProjectChatListItem] total: int has_more: bool @@ -137,17 +141,20 @@ class ProjectChatListResponse(BaseModel): class UpdateVisibilityRequest(BaseModel): """Request to update conversation visibility.""" + visibility: str = Field(..., description="'private' or 'team'") class SendMessageRequest(BaseModel): """Request to send a message in a conversation.""" + content: str = Field(..., min_length=1, description="Message content") images: Optional[List[dict]] = Field(None, description="Image attachments") class SendMessageResponse(BaseModel): """Response after sending a message.""" + item: dict # ThreadItem response job_id: Optional[str] = None # If AI response was triggered @@ -156,6 +163,7 @@ class 
SendMessageResponse(BaseModel): # Helper functions # ============================================= + def _thread_to_project_chat_response(thread: Thread) -> ProjectChatConversationResponse: """Convert a PROJECT_CHAT Thread to response schema.""" from app.utils.short_id import build_url_identifier @@ -175,7 +183,6 @@ def _thread_to_project_chat_response(thread: Thread) -> ProjectChatConversationR running_summary=thread.running_summary, short_id=thread.short_id, url_identifier=url_identifier, - # Generation state is_generating=thread.is_generating_ai_response, is_exploring_code=thread.is_exploring_code, @@ -183,40 +190,35 @@ def _thread_to_project_chat_response(thread: Thread) -> ProjectChatConversationR exploring_code_prompt=thread.exploring_code_prompt, searching_web_query=thread.searching_web_query, retry_status=thread.retry_status, - # Error state ai_error_message=thread.ai_error_message, ai_error_job_id=str(thread.ai_error_job_id) if thread.ai_error_job_id else None, ai_error_user_message=thread.ai_error_user_message, - # Phase proposal state proposed_title=thread.proposed_title, proposed_description=thread.proposed_description, ready_to_create_phase=thread.ready_to_create_phase, - # Feature proposal state ready_to_create_feature=thread.ready_to_create_feature, proposed_feature_title=thread.proposed_feature_title, proposed_feature_description=thread.proposed_feature_description, - proposed_feature_module_id=str(thread.proposed_feature_module_id) if thread.proposed_feature_module_id else None, + proposed_feature_module_id=str(thread.proposed_feature_module_id) + if thread.proposed_feature_module_id + else None, proposed_feature_module_title=thread.proposed_feature_module_title, proposed_feature_module_description=thread.proposed_feature_module_description, - # Project proposal state ready_to_create_project=thread.ready_to_create_project, proposed_project_name=thread.proposed_project_name, proposed_project_description=thread.proposed_project_description, proposed_project_tech_stack=thread.proposed_project_tech_stack, proposed_project_key=thread.proposed_project_key, - # Created entities created_phase_id=str(thread.created_phase_id) if thread.created_phase_id else None, created_project_id=str(thread.created_project_id) if thread.created_project_id else None, created_feature_ids=thread.created_feature_ids, - # Readonly flag is_readonly=thread.is_readonly, - created_at=thread.created_at.isoformat(), updated_at=thread.updated_at.isoformat(), ) @@ -226,6 +228,7 @@ def _thread_to_project_chat_response(thread: Thread) -> ProjectChatConversationR # PROJECT_CHAT specific endpoints # ============================================= + @router.post( "/projects/{project_id}/conversations", response_model=ProjectChatConversationResponse, @@ -356,23 +359,26 @@ async def list_project_conversations( last_message_preview = content[:100] + "..." 
if len(content) > 100 else content from app.utils.short_id import build_url_identifier + url_identifier = None if thread.short_id: title = thread.chat_title or "chat" url_identifier = build_url_identifier(title, thread.short_id) - items.append(ProjectChatListItem( - id=str(thread.id), - chat_title=thread.chat_title, - visibility=thread.visibility if thread.visibility else "private", - short_id=thread.short_id, - url_identifier=url_identifier, - created_at=thread.created_at.isoformat(), - updated_at=thread.updated_at.isoformat(), - is_readonly=thread.is_readonly, - message_count=message_count, - last_message_preview=last_message_preview, - )) + items.append( + ProjectChatListItem( + id=str(thread.id), + chat_title=thread.chat_title, + visibility=thread.visibility if thread.visibility else "private", + short_id=thread.short_id, + url_identifier=url_identifier, + created_at=thread.created_at.isoformat(), + updated_at=thread.updated_at.isoformat(), + is_readonly=thread.is_readonly, + message_count=message_count, + last_message_preview=last_message_preview, + ) + ) # Get total count total = ( @@ -418,6 +424,7 @@ async def get_conversation( else: # Org-scoped - verify org membership from app.services.org_membership_service import OrgMembershipService + membership = OrgMembershipService.get_membership(db, UUID(thread.org_id), current_user.id) if not membership: raise HTTPException(status_code=404, detail="Conversation not found") @@ -438,16 +445,11 @@ async def get_conversation( impl_ids = [ item.content_data.get("implementation_id") for item in thread_with_items.items - if item.item_type == ThreadItemType.IMPLEMENTATION_CREATED - and item.content_data.get("implementation_id") + if item.item_type == ThreadItemType.IMPLEMENTATION_CREATED and item.content_data.get("implementation_id") ] impl_map = {} if impl_ids: - implementations = ( - db.query(Implementation) - .filter(Implementation.id.in_(impl_ids)) - .all() - ) + implementations = db.query(Implementation).filter(Implementation.id.in_(impl_ids)).all() impl_map = {str(impl.id): impl for impl in implementations} items = [thread_item_to_response(item, impl_map) for item in thread_with_items.items] @@ -481,6 +483,7 @@ async def send_message( raise HTTPException(status_code=404, detail="Conversation not found") else: from app.services.org_membership_service import OrgMembershipService + membership = OrgMembershipService.get_membership(db, UUID(thread.org_id), current_user.id) if not membership: raise HTTPException(status_code=404, detail="Conversation not found") @@ -549,12 +552,7 @@ async def send_message( job_id = str(job.id) # Reload item with author - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item.id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item.id).options(joinedload(ThreadItem.author)).first() return SendMessageResponse( item=thread_item_to_response(item, {}), diff --git a/backend/app/routers/dashboard.py b/backend/app/routers/dashboard.py index 341cb9a..6e052a1 100644 --- a/backend/app/routers/dashboard.py +++ b/backend/app/routers/dashboard.py @@ -6,18 +6,17 @@ from sqlalchemy.orm import Session from app.database import get_db -from app.permissions.context import get_org_context, OrgContext +from app.permissions.context import OrgContext, get_org_context from app.schemas.dashboard import ( + AgentUsage, DashboardStatsResponse, LLMUsageDetailsResponse, MonthlyLLMUsage, - RecentLLMCall, - AgentUsage, PlanInfo, + RecentLLMCall, ) from 
app.services.dashboard_service import DashboardService - router = APIRouter(prefix="/dashboard", tags=["dashboard"]) @@ -50,9 +49,7 @@ def get_dashboard_stats( user_count=data["user_count"], project_count=data["project_count"], llm_usage_this_month=MonthlyLLMUsage(**data["llm_usage_this_month"]), - recent_llm_calls=[ - RecentLLMCall.model_validate(call) for call in data["recent_llm_calls"] - ], + recent_llm_calls=[RecentLLMCall.model_validate(call) for call in data["recent_llm_calls"]], plan_info=PlanInfo(**data["plan_info"]), ) diff --git a/backend/app/routers/drafts.py b/backend/app/routers/drafts.py index ec9fe72..2ef97d6 100644 --- a/backend/app/routers/drafts.py +++ b/backend/app/routers/drafts.py @@ -1,4 +1,5 @@ """Router for draft version operations (specs and prompt plans).""" + from typing import List from uuid import UUID @@ -7,22 +8,22 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ProjectRole +from app.models import ProjectRole, User from app.models.spec_version import SpecType from app.schemas.draft_version import ( + DraftListResponse, DraftVersionCreate, DraftVersionResponse, - DraftListResponse, ) from app.schemas.final_version import ( - FinalSpecResponse, FinalPromptPlanResponse, + FinalSpecResponse, ) +from app.services.activity_log_service import ActivityEventTypes, ActivityLogService from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.draft_version_service import DraftVersionService from app.services.finalization_service import FinalizationService from app.services.project_service import ProjectService -from app.services.activity_log_service import ActivityLogService, ActivityEventTypes router = APIRouter(tags=["drafts"]) diff --git a/backend/app/routers/email_templates.py b/backend/app/routers/email_templates.py index e402bab..1e10ce0 100644 --- a/backend/app/routers/email_templates.py +++ b/backend/app/routers/email_templates.py @@ -2,6 +2,7 @@ All endpoints require platform admin privileges. 
""" + import logging from typing import Annotated from uuid import UUID diff --git a/backend/app/routers/feature_content_versions.py b/backend/app/routers/feature_content_versions.py index 3d3802b..7012ba4 100644 --- a/backend/app/routers/feature_content_versions.py +++ b/backend/app/routers/feature_content_versions.py @@ -1,4 +1,5 @@ """Router for feature content version operations.""" + from typing import List, Optional from uuid import UUID @@ -7,11 +8,11 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ProjectRole +from app.models import ProjectRole, User from app.models.feature import Feature from app.models.feature_content_version import FeatureContentType from app.models.module import Module -from app.permissions.context import get_project_context, ProjectContext +from app.permissions.context import get_project_context from app.permissions.helpers import require_project_role from app.schemas.feature_content_version import ( ContentVersionCreate, @@ -235,7 +236,7 @@ async def generate_from_conversation( require_project_role(context, ProjectRole.MEMBER) # Import here to avoid circular imports - from app.models.job import Job, JobType, JobStatus + from app.models.job import JobStatus, JobType from app.services.job_service import JobService from app.services.kafka_producer import get_sync_kafka_producer diff --git a/backend/app/routers/features.py b/backend/app/routers/features.py index 7bbe501..a180407 100644 --- a/backend/app/routers/features.py +++ b/backend/app/routers/features.py @@ -1,66 +1,66 @@ """Router for feature operations (Phase 7 features within modules).""" + +import logging from typing import Dict, List, Optional, Set from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy.orm import Session +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import case, func from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import text, func, case, literal +from sqlalchemy.orm import Session +from sqlalchemy.orm.attributes import flag_modified from app.auth.dependencies import get_current_user -from app.database import get_db, get_async_db -from app.models import User, ProjectRole -from app.models.module import Module -from app.models.thread import Thread, ContextType +from app.database import get_async_db, get_db +from app.integrations.factory import get_adapter as get_bug_tracker_adapter +from app.models import ProjectRole, User +from app.models.feature import FeatureCompletionStatus, FeaturePriority, FeatureProvenance, FeatureType +from app.models.feature_content_version import FeatureContentType, FeatureContentVersion +from app.models.implementation import Implementation +from app.models.job import Job, JobStatus, JobType +from app.models.module import Module, ModuleProvenance, ModuleType +from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem -from app.permissions.context import get_project_context, ProjectContext +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role from app.schemas.feature import ( + ConnectorSourcesRequest, + CreateFeatureWithThreadRequest, + FeatureCompleteRequest, FeatureCreate, - FeatureUpdate, - FeatureResponse, + FeatureImportCommentResponse, + FeatureImportCommentsResponse, + FeatureImportRequest, FeatureListResponse, + FeatureResponse, FeatureRestoreRequest, - 
FeatureStartRequest, - FeatureCompleteRequest, - CreateFeatureWithThreadRequest, + FeatureSidebarData, + FeatureSortField, + FeatureUpdate, FeatureWithThreadResponse, + GitHubRepositoriesResponse, + GitHubRepository, IssueSearchRequest, IssueSearchResponse, - IssueSearchResult as IssueSearchResultSchema, - FeatureImportRequest, - FeatureImportCommentsResponse, - FeatureImportCommentResponse, - ConnectorSourcesRequest, - GitHubRepository, - GitHubRepositoriesResponse, JiraProject, JiraProjectsResponse, PaginatedFeaturesResponse, - FeatureSortField, SortOrder, - FeatureSidebarData, UnresolvedPointSchema, ) -from app.models.feature import FeatureType, FeatureProvenance, FeaturePriority, FeatureCompletionStatus -from app.models.feature_content_version import FeatureContentVersion, FeatureContentType -from app.models.implementation import Implementation -from app.models.module import ModuleProvenance, ModuleType +from app.schemas.feature import ( + IssueSearchResult as IssueSearchResultSchema, +) from app.schemas.implementation import ClearStatusNotesResponse -from app.services.feature_service import FeatureService -from app.services.feature_content_version_service import FeatureContentVersionService -from app.services.implementation_service import ImplementationService -from app.services.module_service import ModuleService +from app.services.activity_log_service import ActivityEventTypes, ActivityLogService from app.services.brainstorming_phase_service import BrainstormingPhaseService -from app.services.activity_log_service import ActivityLogService, ActivityEventTypes +from app.services.feature_content_version_service import FeatureContentVersionService from app.services.feature_import_service import FeatureImportService +from app.services.feature_service import FeatureService +from app.services.implementation_service import ImplementationService from app.services.integration_service import IntegrationService -from app.integrations.factory import get_adapter as get_bug_tracker_adapter -from app.models.job import Job, JobType, JobStatus from app.services.kafka_producer import get_sync_kafka_producer -from sqlalchemy.orm.attributes import flag_modified -import json -import logging +from app.services.module_service import ModuleService logger = logging.getLogger(__name__) @@ -233,12 +233,18 @@ def list_project_features( project_id: str, include_archived: bool = Query(False, description="Include archived features"), module_id: Optional[str] = Query(None, description="Filter by module ID or URL identifier"), - brainstorming_phase_id: Optional[str] = Query(None, description="Filter by brainstorming phase ID or URL identifier"), - feature_type: Optional[FeatureType] = Query(None, description="Filter by feature type (conversation or implementation)"), + brainstorming_phase_id: Optional[str] = Query( + None, description="Filter by brainstorming phase ID or URL identifier" + ), + feature_type: Optional[FeatureType] = Query( + None, description="Filter by feature type (conversation or implementation)" + ), # New filter parameters priority: Optional[List[FeaturePriority]] = Query(None, description="Filter by priority levels"), completion_status: Optional[List[FeatureCompletionStatus]] = Query(None, description="Filter by completion status"), - provenance: Optional[List[FeatureProvenance]] = Query(None, description="Filter by provenance (system, user, restored)"), + provenance: Optional[List[FeatureProvenance]] = Query( + None, description="Filter by provenance (system, user, restored)" + ), has_spec: 
Optional[bool] = Query(None, description="Filter by features with/without spec"), has_notes: Optional[bool] = Query(None, description="Filter by features with/without notes"), external_provider: Optional[str] = Query(None, description="Filter by external provider (github, jira, none)"), @@ -303,19 +309,10 @@ def list_project_features( impl_query = ( db.query( Implementation.feature_id, - func.count(Implementation.id).label('count'), - func.max(case( - (Implementation.spec_text.isnot(None), 1), - else_=0 - )).label('has_spec'), - func.max(case( - (Implementation.prompt_plan_text.isnot(None), 1), - else_=0 - )).label('has_prompt_plan'), - func.max(case( - (Implementation.implementation_notes.isnot(None), 1), - else_=0 - )).label('has_notes'), + func.count(Implementation.id).label("count"), + func.max(case((Implementation.spec_text.isnot(None), 1), else_=0)).label("has_spec"), + func.max(case((Implementation.prompt_plan_text.isnot(None), 1), else_=0)).label("has_prompt_plan"), + func.max(case((Implementation.implementation_notes.isnot(None), 1), else_=0)).label("has_notes"), ) .filter(Implementation.feature_id.in_(feature_uuids)) .group_by(Implementation.feature_id) @@ -323,10 +320,10 @@ def list_project_features( ) impl_data = { str(row.feature_id): { - 'count': row.count, - 'has_spec': bool(row.has_spec), - 'has_prompt_plan': bool(row.has_prompt_plan), - 'has_notes': bool(row.has_notes), + "count": row.count, + "has_spec": bool(row.has_spec), + "has_prompt_plan": bool(row.has_prompt_plan), + "has_notes": bool(row.has_notes), } for row in impl_query } @@ -353,9 +350,9 @@ def list_project_features( "completion_status": feature.completion_status, "has_description": bool(feature.description), # Include implementation content in has_* checks - "has_spec": bool(feature.spec_text) or impl_info.get('has_spec', False), - "has_prompt_plan": bool(feature.prompt_plan_text) or impl_info.get('has_prompt_plan', False), - "has_notes": bool(feature.implementation_notes) or impl_info.get('has_notes', False), + "has_spec": bool(feature.spec_text) or impl_info.get("has_spec", False), + "has_prompt_plan": bool(feature.prompt_plan_text) or impl_info.get("has_prompt_plan", False), + "has_notes": bool(feature.implementation_notes) or impl_info.get("has_notes", False), # External import fields "external_provider": feature.external_provider, "external_id": feature.external_id, @@ -363,7 +360,7 @@ def list_project_features( # Unresolved points count from thread "unresolved_count": unresolved_counts.get(feature_id_str, 0), # Implementation count for tooltip display - "implementation_count": impl_info.get('count', 0), + "implementation_count": impl_info.get("count", 0), # Short ID for URL-friendly identifiers "short_id": feature.short_id, "url_identifier": feature.url_identifier, @@ -568,10 +565,7 @@ def get_feature_sidebar_data( .all() ) mcq_total = len(mcq_items) - mcq_answered = sum( - 1 for item in mcq_items - if item.content_data and item.content_data.get("selected_option_id") - ) + mcq_answered = sum(1 for item in mcq_items if item.content_data and item.content_data.get("selected_option_id")) # Count comments comment_count = ( @@ -631,9 +625,7 @@ def update_feature( # Check project membership from app.services.project_service import ProjectService - project_membership = ProjectService.get_project_membership( - db, feature.module.project_id, current_user.id - ) + project_membership = ProjectService.get_project_membership(db, feature.module.project_id, current_user.id) if not project_membership or 
project_membership.role not in [ ProjectRole.OWNER, ProjectRole.ADMIN, @@ -688,9 +680,7 @@ def archive_feature( # Check project role (ADMIN or OWNER required for archiving) from app.services.project_service import ProjectService - project_membership = ProjectService.get_project_membership( - db, feature.module.project_id, current_user.id - ) + project_membership = ProjectService.get_project_membership(db, feature.module.project_id, current_user.id) if not project_membership or project_membership.role not in [ProjectRole.OWNER, ProjectRole.ADMIN]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -753,9 +743,7 @@ def restore_feature( # Check project role (ADMIN or OWNER required for restoring) from app.services.project_service import ProjectService - project_membership = ProjectService.get_project_membership( - db, feature.module.project_id, current_user.id - ) + project_membership = ProjectService.get_project_membership(db, feature.module.project_id, current_user.id) if not project_membership or project_membership.role not in [ProjectRole.OWNER, ProjectRole.ADMIN]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -824,9 +812,7 @@ def start_feature( # Check project membership from app.services.project_service import ProjectService - project_membership = ProjectService.get_project_membership( - db, feature.module.project_id, current_user.id - ) + project_membership = ProjectService.get_project_membership(db, feature.module.project_id, current_user.id) if not project_membership or project_membership.role not in [ ProjectRole.OWNER, ProjectRole.ADMIN, @@ -895,9 +881,7 @@ def complete_feature( # Check project membership from app.services.project_service import ProjectService - project_membership = ProjectService.get_project_membership( - db, feature.module.project_id, current_user.id - ) + project_membership = ProjectService.get_project_membership(db, feature.module.project_id, current_user.id) if not project_membership or project_membership.role not in [ ProjectRole.OWNER, ProjectRole.ADMIN, @@ -1177,9 +1161,7 @@ async def list_github_repositories_for_import( try: # Get the connector integration_service = IntegrationService(async_db) - config = await integration_service.get_config_by_id( - project_context.project.org_id, UUID(request.connector_id) - ) + config = await integration_service.get_config_by_id(project_context.project.org_id, UUID(request.connector_id)) if not config: raise HTTPException( @@ -1241,9 +1223,7 @@ async def list_jira_projects_for_import( try: # Get the connector integration_service = IntegrationService(async_db) - config = await integration_service.get_config_by_id( - project_context.project.org_id, UUID(request.connector_id) - ) + config = await integration_service.get_config_by_id(project_context.project.org_id, UUID(request.connector_id)) if not config: raise HTTPException( diff --git a/backend/app/routers/form_drafts.py b/backend/app/routers/form_drafts.py index d617f61..1363c3c 100644 --- a/backend/app/routers/form_drafts.py +++ b/backend/app/routers/form_drafts.py @@ -10,7 +10,7 @@ from app.database import get_db from app.models import User from app.models.form_draft import FormDraftType -from app.permissions.context import get_project_context, ProjectContext +from app.permissions.context import ProjectContext, get_project_context from app.schemas.form_draft import ( FormDraftListItem, FormDraftResponse, diff --git a/backend/app/routers/grounding.py b/backend/app/routers/grounding.py index 0952144..39ee8b8 100644 --- 
a/backend/app/routers/grounding.py +++ b/backend/app/routers/grounding.py @@ -1,9 +1,8 @@ """Grounding file management router for coding agent warm starts.""" import logging -from typing import Annotated, List, Optional -from uuid import UUID from datetime import datetime +from typing import Annotated, List, Optional from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, ConfigDict @@ -22,6 +21,7 @@ class GroundingFileResponse(BaseModel): """Response schema for grounding file.""" + model_config = ConfigDict(from_attributes=True) id: str @@ -102,13 +102,9 @@ async def get_grounding_file( """ # Ensure agents.md exists if filename == "agents.md": - grounding_file = GroundingService.get_or_create_agents_md( - db, project_context.project.id - ) + grounding_file = GroundingService.get_or_create_agents_md(db, project_context.project.id) else: - grounding_file = GroundingService.get_file( - db, project_context.project.id, filename - ) + grounding_file = GroundingService.get_file(db, project_context.project.id, filename) if not grounding_file: raise HTTPException( @@ -154,9 +150,7 @@ async def create_grounding_file( ) # Broadcast grounding file creation to WebSocket clients - GroundingService._broadcast_grounding_update( - db, project_context.project.id, grounding_file, "created" - ) + GroundingService._broadcast_grounding_update(db, project_context.project.id, grounding_file, "created") return _to_response(grounding_file) @@ -190,9 +184,7 @@ async def update_grounding_file( ) # Broadcast grounding file update to WebSocket clients - GroundingService._broadcast_grounding_update( - db, project_context.project.id, grounding_file, "written" - ) + GroundingService._broadcast_grounding_update(db, project_context.project.id, grounding_file, "written") return _to_response(grounding_file) @@ -223,9 +215,7 @@ async def delete_grounding_file( ) # Broadcast deletion BEFORE actually deleting (need the object for broadcast) - GroundingService._broadcast_grounding_update( - db, project_context.project.id, grounding_file, "deleted" - ) + GroundingService._broadcast_grounding_update(db, project_context.project.id, grounding_file, "deleted") GroundingService.delete_file(db, project_context.project.id, filename) @@ -237,6 +227,7 @@ async def delete_grounding_file( class GroundingFileBranchResponse(BaseModel): """Response schema for branch-specific grounding file.""" + model_config = ConfigDict(from_attributes=True) id: str @@ -299,9 +290,7 @@ async def list_user_branches( Returns branch-specific versions of agents.md for the logged-in user. Requires project membership. """ - branches = GroundingService.list_user_branches( - db, project_context.project.id, current_user.id - ) + branches = GroundingService.list_user_branches(db, project_context.project.id, current_user.id) return BranchListResponse( branches=[_to_branch_response(b) for b in branches], @@ -322,9 +311,7 @@ async def get_branch_file( Returns the branch-specific agents.md for the logged-in user. Requires project membership. """ - branch_file = GroundingService.get_branch_file( - db, project_context.project.id, current_user.id, branch_name - ) + branch_file = GroundingService.get_branch_file(db, project_context.project.id, current_user.id, branch_name) if not branch_file: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -348,14 +335,12 @@ async def merge_branch( into the global agents.md file. Requires project membership. 
""" - from app.services.job_service import JobService from app.models.job import JobType + from app.services.job_service import JobService from workers.core.helpers import publish_job_to_kafka # Verify branch exists - branch_file = GroundingService.get_branch_file( - db, project_context.project.id, current_user.id, branch_name - ) + branch_file = GroundingService.get_branch_file(db, project_context.project.id, current_user.id, branch_name) if not branch_file: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -374,9 +359,7 @@ async def merge_branch( GroundingService.set_branch_merging_flag(db, branch_file, True) # Broadcast branch update to notify UI of is_merging change - GroundingService._broadcast_branch_grounding_update( - db, project_context.project.id, branch_file, "merging_started" - ) + GroundingService._broadcast_branch_grounding_update(db, project_context.project.id, branch_file, "merging_started") # Create merge job job = JobService.create_job( @@ -437,14 +420,12 @@ async def pull_from_global( into the branch-specific agents.md file, preserving branch work. Requires project membership. """ - from app.services.job_service import JobService from app.models.job import JobType + from app.services.job_service import JobService from workers.core.helpers import publish_job_to_kafka # Verify branch exists - branch_file = GroundingService.get_branch_file( - db, project_context.project.id, current_user.id, branch_name - ) + branch_file = GroundingService.get_branch_file(db, project_context.project.id, current_user.id, branch_name) if not branch_file: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -459,9 +440,7 @@ async def pull_from_global( ) # Get global file to compare timestamps - global_file = GroundingService.get_or_create_agents_md( - db, project_context.project.id, current_user.id - ) + global_file = GroundingService.get_or_create_agents_md(db, project_context.project.id, current_user.id) # Check if pull is needed (global must be newer than branch) if global_file.content_updated_at <= branch_file.content_updated_at: @@ -474,9 +453,7 @@ async def pull_from_global( GroundingService.set_branch_merging_flag(db, branch_file, True) # Broadcast branch update to notify UI - GroundingService._broadcast_branch_grounding_update( - db, project_context.project.id, branch_file, "pulling_started" - ) + GroundingService._broadcast_branch_grounding_update(db, project_context.project.id, branch_file, "pulling_started") # Create pull job job = JobService.create_job( @@ -543,8 +520,8 @@ async def regenerate_grounding( Only available for brownfield projects (projects with GitHub repo configured). Requires project membership. 
""" - from app.services.job_service import JobService from app.models.job import JobType + from app.services.job_service import JobService from workers.core.helpers import publish_job_to_kafka project = project_context.project @@ -571,9 +548,7 @@ async def regenerate_grounding( GroundingService.set_generating_flag(db, agents_md, True) # Broadcast update to notify UI of is_generating change - GroundingService._broadcast_grounding_update( - db, project.id, agents_md, "generating_started" - ) + GroundingService._broadcast_grounding_update(db, project.id, agents_md, "generating_started") # Get default branch from primary repository default_branch = primary_repo.default_branch or "main" @@ -603,9 +578,7 @@ async def regenerate_grounding( logger.error(f"Failed to publish grounding regeneration job: {e}") # Clear is_generating flag if Kafka publish fails GroundingService.set_generating_flag(db, agents_md, False) - GroundingService._broadcast_grounding_update( - db, project.id, agents_md, "generating_failed" - ) + GroundingService._broadcast_grounding_update(db, project.id, agents_md, "generating_failed") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to queue grounding regeneration job", diff --git a/backend/app/routers/grounding_notes.py b/backend/app/routers/grounding_notes.py index ce74e2f..15b4a27 100644 --- a/backend/app/routers/grounding_notes.py +++ b/backend/app/routers/grounding_notes.py @@ -8,7 +8,7 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ProjectRole +from app.models import ProjectRole, User from app.permissions.context import get_project_context from app.permissions.helpers import require_project_role from app.schemas.grounding_note import ( diff --git a/backend/app/routers/images.py b/backend/app/routers/images.py index 2ebf1ee..aaa7b77 100644 --- a/backend/app/routers/images.py +++ b/backend/app/routers/images.py @@ -11,16 +11,15 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi.responses import Response from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.project_chat import ProjectChatMessage +from app.auth.dependencies import get_current_user +from app.database import SessionLocal, get_db +from app.models import User from app.models.brainstorming_phase import BrainstormingPhase from app.models.feature import Feature -from app.models import User -from app.database import SessionLocal, get_db, get_async_db -from app.auth.dependencies import get_current_user +from app.models.project_chat import ProjectChatMessage +from app.models.thread_item import ThreadItem, ThreadItemType from app.services.image_service import ImageService logger = logging.getLogger(__name__) @@ -82,11 +81,7 @@ async def serve_image( target_image = None # 1. 
diff --git a/backend/app/routers/grounding_notes.py b/backend/app/routers/grounding_notes.py
index ce74e2f..15b4a27 100644
--- a/backend/app/routers/grounding_notes.py
+++ b/backend/app/routers/grounding_notes.py
@@ -8,7 +8,7 @@
 
 from app.auth.dependencies import get_current_user
 from app.database import get_db
-from app.models import User, ProjectRole
+from app.models import ProjectRole, User
 from app.permissions.context import get_project_context
 from app.permissions.helpers import require_project_role
 from app.schemas.grounding_note import (
diff --git a/backend/app/routers/images.py b/backend/app/routers/images.py
index 2ebf1ee..aaa7b77 100644
--- a/backend/app/routers/images.py
+++ b/backend/app/routers/images.py
@@ -11,16 +11,15 @@
 from fastapi import APIRouter, Depends, HTTPException, Query, status
 from fastapi.responses import Response
 from pydantic import BaseModel
-from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session
 
-from app.models.thread_item import ThreadItem, ThreadItemType
-from app.models.project_chat import ProjectChatMessage
+from app.auth.dependencies import get_current_user
+from app.database import SessionLocal, get_db
+from app.models import User
 from app.models.brainstorming_phase import BrainstormingPhase
 from app.models.feature import Feature
-from app.models import User
-from app.database import SessionLocal, get_db, get_async_db
-from app.auth.dependencies import get_current_user
+from app.models.project_chat import ProjectChatMessage
+from app.models.thread_item import ThreadItem, ThreadItemType
 from app.services.image_service import ImageService
 
 logger = logging.getLogger(__name__)
@@ -82,11 +81,7 @@ async def serve_image(
     target_image = None
 
     # 1. Search thread item comment images
-    items = (
-        db.query(ThreadItem)
-        .filter(ThreadItem.item_type == ThreadItemType.COMMENT)
-        .all()
-    )
+    items = db.query(ThreadItem).filter(ThreadItem.item_type == ThreadItemType.COMMENT).all()
     for item in items:
         content_data = item.content_data or {}
         images = content_data.get("images", [])
@@ -101,7 +96,7 @@ async def serve_image(
     if not target_image:
         messages = db.query(ProjectChatMessage).filter(ProjectChatMessage.images.isnot(None)).all()
         for msg in messages:
-            for img in (msg.images or []):
+            for img in msg.images or []:
                 if img.get("id") == image_id:
                     target_image = img
                     break
@@ -110,11 +105,11 @@ async def serve_image(
 
     # 3. Search brainstorming phase description images
     if not target_image:
-        phases = db.query(BrainstormingPhase).filter(
-            BrainstormingPhase.description_image_attachments.isnot(None)
-        ).all()
+        phases = (
+            db.query(BrainstormingPhase).filter(BrainstormingPhase.description_image_attachments.isnot(None)).all()
+        )
         for phase in phases:
-            for img in (phase.description_image_attachments or []):
+            for img in phase.description_image_attachments or []:
                 if img.get("id") == image_id:
                     target_image = img
                     break
@@ -123,11 +118,9 @@ async def serve_image(
 
     # 4. Search feature description images
     if not target_image:
-        features = db.query(Feature).filter(
-            Feature.description_image_attachments.isnot(None)
-        ).all()
+        features = db.query(Feature).filter(Feature.description_image_attachments.isnot(None)).all()
         for feature in features:
-            for img in (feature.description_image_attachments or []):
+            for img in feature.description_image_attachments or []:
                 if img.get("id") == image_id:
                     target_image = img
                     break
@@ -205,11 +198,7 @@ def _find_image_by_id(db: Session, image_id: str) -> dict | None:
         Image metadata dict or None if not found
     """
     # 1. Search thread item comment images
-    items = (
-        db.query(ThreadItem)
-        .filter(ThreadItem.item_type == ThreadItemType.COMMENT)
-        .all()
-    )
+    items = db.query(ThreadItem).filter(ThreadItem.item_type == ThreadItemType.COMMENT).all()
    for item in items:
         content_data = item.content_data or {}
         images = content_data.get("images", [])
@@ -220,25 +209,21 @@ def _find_image_by_id(db: Session, image_id: str) -> dict | None:
     # 2. Search pre-phase discussion message images
     messages = db.query(ProjectChatMessage).filter(ProjectChatMessage.images.isnot(None)).all()
     for msg in messages:
-        for img in (msg.images or []):
+        for img in msg.images or []:
             if img.get("id") == image_id:
                 return img
 
     # 3. Search brainstorming phase description images
-    phases = db.query(BrainstormingPhase).filter(
-        BrainstormingPhase.description_image_attachments.isnot(None)
-    ).all()
+    phases = db.query(BrainstormingPhase).filter(BrainstormingPhase.description_image_attachments.isnot(None)).all()
     for phase in phases:
-        for img in (phase.description_image_attachments or []):
+        for img in phase.description_image_attachments or []:
             if img.get("id") == image_id:
                 return img
 
     # 4. Search feature description images
-    features = db.query(Feature).filter(
-        Feature.description_image_attachments.isnot(None)
-    ).all()
+    features = db.query(Feature).filter(Feature.description_image_attachments.isnot(None)).all()
     for feature in features:
-        for img in (feature.description_image_attachments or []):
+        for img in feature.description_image_attachments or []:
             if img.get("id") == image_id:
                 return img
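The `for img in (msg.images or []):` hunks in images.py only drop redundant parentheses; behavior is unchanged. The `or []` guard is what makes iterating a possibly-NULL JSON column safe without a separate None check. A tiny stand-in (the function and field names here are illustrative, not the app's real models):

    def image_ids(images: list[dict] | None) -> list[str]:
        # `images or []` yields [] when the column is NULL, so the loop is safe.
        return [img["id"] for img in images or [] if "id" in img]

    print(image_ids(None))                # []
    print(image_ids([{"id": "a1"}, {}]))  # ['a1']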
diff --git a/backend/app/routers/implementations.py b/backend/app/routers/implementations.py
index d82ba27..f610b42 100644
--- a/backend/app/routers/implementations.py
+++ b/backend/app/routers/implementations.py
@@ -1,6 +1,7 @@
 """Router for implementation operations."""
+
 import logging
-from typing import List, Optional
+from typing import List
 from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException, status
@@ -11,31 +12,33 @@
 
 from app.auth.dependencies import get_current_user
 from app.database import get_db
-from app.models import User, ProjectRole
+from app.models import ProjectRole, User
 from app.models.feature import Feature
 from app.models.module import Module
 from app.models.project import Project
-from app.models.thread import Thread, ContextType
-from app.services.org_service import OrgService
-from app.permissions.context import get_project_context, ProjectContext
+from app.models.thread import ContextType, Thread
+from app.permissions.context import ProjectContext, get_project_context
 from app.permissions.helpers import require_project_role
 from app.schemas.implementation import (
+    ClearStatusNotesResponse,
     ImplementationCreate,
-    ImplementationUpdate,
-    ImplementationResponse,
     ImplementationListItem,
+    ImplementationResponse,
     ImplementationSidebarData,
-    ClearStatusNotesResponse,
+    ImplementationUpdate,
 )
 from app.services.implementation_service import ImplementationService
+from app.services.org_service import OrgService
 from app.services.thread_service import ThreadService
 
 
 class AutoGenerateResponse(BaseModel):
     """Response schema for auto-generate endpoint."""
+
     spec_job_id: str
     implementation_id: str
 
+
 router = APIRouter(tags=["implementations"])
@@ -45,10 +48,7 @@ def _get_feature_or_404(db: Session, feature_id: str) -> Feature:
     feature = FeatureService.get_by_identifier(db, feature_id)
     if not feature:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Feature not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Feature not found")
     return feature
@@ -315,10 +315,7 @@ async def update_implementation(
     )
 
     if not impl:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Implementation not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Implementation not found")
 
     return _implementation_to_response(impl)
@@ -345,10 +342,7 @@ async def mark_implementation_complete(
     )
 
     if not impl:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Implementation not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Implementation not found")
 
     return _implementation_to_response(impl)
@@ -374,10 +368,7 @@ async def reopen_implementation(
     )
 
     if not impl:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Implementation not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Implementation not found")
 
     return _implementation_to_response(impl)
@@ -405,10 +396,7 @@ async def clear_implementation_status_notes(
     )
 
     if not impl:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Implementation not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Implementation not found")
 
     # Broadcast WebSocket update so frontend refreshes
     ImplementationService.broadcast_implementation_updated(db, impl, "clear_status_notes")
@@ -469,10 +457,7 @@ async def set_implementation_primary(
     )
 
     if not impl:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Implementation not found"
-        )
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Implementation not found")
 
     return _implementation_to_response(impl)
@@ -501,7 +486,7 @@ async def delete_implementation(
     if not success:
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Cannot delete implementation. It may not exist or is the only one for this feature."
+            detail="Cannot delete implementation. It may not exist or is the only one for this feature.",
         )
     return None
@@ -523,7 +508,7 @@ async def auto_generate_implementation_content(
     Creates a spec generation job that will automatically chain to prompt
     plan generation upon completion.
     """
-    from app.models.job import JobType, JobStatus
+    from app.models.job import JobStatus, JobType
     from app.services.job_service import JobService
     from app.services.kafka_producer import get_sync_kafka_producer
@@ -620,6 +605,7 @@ async def auto_generate_implementation_content(
 
 class CancelGenerationResponse(BaseModel):
     """Response schema for cancel-generation endpoint."""
+
     cancelled_jobs: int
     cleared_flags: List[str]
@@ -685,14 +671,16 @@ async def cancel_implementation_generation(
 
     # Find the ThreadItem that shows "Implementation Created" in conversation
     # We'll delete it AFTER successfully deleting the implementation
-    from app.models.thread_item import ThreadItem as ThreadItemModel, ThreadItemType
     from sqlalchemy.orm import joinedload
+
+    from app.models.thread_item import ThreadItem as ThreadItemModel
+    from app.models.thread_item import ThreadItemType
+
     impl_thread_item = (
         db.query(ThreadItemModel)
         .filter(
             ThreadItemModel.item_type == ThreadItemType.IMPLEMENTATION_CREATED,
-            ThreadItemModel.content_data.op('->>')('implementation_id') == deleted_impl_id_str,
+            ThreadItemModel.content_data.op("->>")("implementation_id") == deleted_impl_id_str,
         )
         .options(joinedload(ThreadItemModel.author))
         .first()
     )
@@ -708,9 +696,7 @@ async def cancel_implementation_generation(
     thread_item_deleted = False
     if impl_thread_item:
         # Broadcast thread item deletion BEFORE deleting (need item data for broadcast)
-        ThreadService._broadcast_thread_item_update(
-            db, impl_thread_item, "thread_item_deleted"
-        )
+        ThreadService._broadcast_thread_item_update(db, impl_thread_item, "thread_item_deleted")
         db.delete(impl_thread_item)
         db.commit()
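The `op("->>")` hunk in implementations.py changes only the quote style (Ruff's formatter normalizes to double quotes); the generated SQL is identical. For context, `->>` extracts a JSON field as text so it can be compared to a string. A self-contained sketch, with an illustrative table rather than the app's real model:

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects.postgresql import JSONB

    thread_items = Table(
        "thread_items", MetaData(),
        Column("id", String, primary_key=True),
        Column("content_data", JSONB),
    )

    # `->>` pulls a JSON field out as text; quote style of the Python literal
    # does not affect the compiled SQL.
    expr = thread_items.c.content_data.op("->>")("implementation_id") == "abc"
    print(expr.compile(compile_kwargs={"literal_binds": True}))
    # thread_items.content_data ->> 'implementation_id' = 'abc'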
diff --git a/backend/app/routers/inbox.py b/backend/app/routers/inbox.py
index d13d67d..b9eef35 100644
--- a/backend/app/routers/inbox.py
+++ b/backend/app/routers/inbox.py
@@ -2,7 +2,6 @@
 
 import logging
 from typing import Optional
-from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException, Query
 from pydantic import BaseModel
@@ -10,20 +9,20 @@
 
 from app.auth.dependencies import get_current_user
 from app.database import get_db
-from app.models.user import User
 from app.models.inbox_mention import InboxConversationType
-from app.services.inbox_badge_service import InboxBadgeService
-from app.services.inbox_conversation_service import InboxConversationService
-from app.services.inbox_mention_service import InboxMentionService
-from app.services.inbox_status_service import InboxStatusService
-from app.services.org_service import OrgService
+from app.models.user import User
 from app.schemas.inbox_badge import OrgBadgeCountsResponse
 from app.schemas.inbox_conversation import (
+    ConversationSortField,
     InboxConversationsRequest,
     InboxConversationsResponse,
-    ConversationSortField,
     SortOrder,
 )
+from app.services.inbox_badge_service import InboxBadgeService
+from app.services.inbox_conversation_service import InboxConversationService
+from app.services.inbox_mention_service import InboxMentionService
+from app.services.inbox_status_service import InboxStatusService
+from app.services.org_service import OrgService
 
 
 class UpdateReadPositionRequest(BaseModel):
@@ -33,6 +32,7 @@ class UpdateReadPositionRequest(BaseModel):
     conversation_id: str
     sequence_number: int
 
+
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["inbox"])
diff --git a/backend/app/routers/inbox_deep_link.py b/backend/app/routers/inbox_deep_link.py
index 7696987..ce317c0 100644
--- a/backend/app/routers/inbox_deep_link.py
+++ b/backend/app/routers/inbox_deep_link.py
@@ -1,7 +1,7 @@
 """Router for inbox deep link resolution."""
+
 import logging
 from typing import Optional
-from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy.orm import Session
diff --git a/backend/app/routers/inbox_follows.py b/backend/app/routers/inbox_follows.py
index e10295a..2e73930 100644
--- a/backend/app/routers/inbox_follows.py
+++ b/backend/app/routers/inbox_follows.py
@@ -8,29 +8,29 @@
 
 from app.auth.dependencies import get_current_user
 from app.database import get_db
-from app.models.user import User
-from app.models.inbox_follow import InboxFollowType, InboxThreadType
+from app.models.inbox_follow import InboxThreadType
 from app.models.inbox_mention import InboxConversationType
-from app.services.inbox_follow_service import InboxFollowService
-from app.services.inbox_badge_service import InboxBadgeService
-from app.services.inbox_conversation_service import InboxConversationService
-from app.services.project_service import ProjectService
-from app.services.project_share_service import ProjectShareService
+from app.models.user import User
 from app.schemas.inbox_badge import BadgeCountsResponse, ConversationBadgeCount
 from app.schemas.inbox_conversation import (
+    ConversationSortField,
     InboxConversationsRequest,
     InboxConversationsResponse,
-    ConversationSortField,
     SortOrder,
 )
 from app.schemas.inbox_follow import (
+    EffectiveFollowsResponse,
     FollowThreadRequest,
     InboxFollowResponse,
     ProjectFollowStatusResponse,
     ThreadFollowStatusResponse,
-    EffectiveFollowsResponse,
     UnfollowResponse,
 )
+from app.services.inbox_badge_service import InboxBadgeService
+from app.services.inbox_conversation_service import InboxConversationService
+from app.services.inbox_follow_service import InboxFollowService
+from app.services.project_service import ProjectService
+from app.services.project_share_service import ProjectShareService
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["inbox-follows"])
@@ -144,9 +144,7 @@ def unfollow_project(
 def unfollow_thread(
     project_id: str,
     thread_id: str = Query(..., description="Thread ID to unfollow"),
-    thread_type: InboxThreadType = Query(
-        ..., description="Type: feature, phase, project_chat"
-    ),
+    thread_type: InboxThreadType = Query(..., description="Type: feature, phase, project_chat"),
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ) -> UnfollowResponse:
@@ -165,8 +163,7 @@ def unfollow_thread(
 
     if result:
         logger.info(
-            f"User {current_user.id} unfollowed thread {thread_id} "
-            f"(type={thread_type.value}) in project {project.id}"
+            f"User {current_user.id} unfollowed thread {thread_id} (type={thread_type.value}) in project {project.id}"
         )
         return UnfollowResponse(success=True, message="Successfully unfollowed thread")
     else:
@@ -203,14 +200,8 @@ def get_effective_follows(
     )
 
     return EffectiveFollowsResponse(
-        project_follow=(
-            InboxFollowResponse.model_validate(project_follow)
-            if project_follow
-            else None
-        ),
-        thread_follows=[
-            InboxFollowResponse.model_validate(f) for f in thread_follows
-        ],
+        project_follow=(InboxFollowResponse.model_validate(project_follow) if project_follow else None),
+        thread_follows=[InboxFollowResponse.model_validate(f) for f in thread_follows],
         is_following_project=project_follow is not None,
     )
@@ -246,9 +237,7 @@ def get_project_follow_status(
 def get_thread_follow_status(
     project_id: str,
     thread_id: str = Query(..., description="Thread ID to check"),
-    thread_type: InboxThreadType = Query(
-        ..., description="Type: feature, phase, project_chat"
-    ),
+    thread_type: InboxThreadType = Query(..., description="Type: feature, phase, project_chat"),
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ) -> ThreadFollowStatusResponse:
@@ -294,9 +283,7 @@ def get_badge_counts(
 
     return BadgeCountsResponse(
         unread_mentions=counts["unread_mentions"],
-        conversations=[
-            ConversationBadgeCount(**c) for c in counts["conversations"]
-        ],
+        conversations=[ConversationBadgeCount(**c) for c in counts["conversations"]],
         total_unread=counts["total_unread"],
     )
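The logger hunks in inbox_follows.py replace implicit concatenation of two adjacent f-strings with a single literal; both spellings build the same string, and the single-line form evidently fits the project's raised line limit (120 columns is inferred from the reflowed lines, not stated anywhere in the diff). A runnable stand-in with illustrative values:

    user_id, thread_id, thread_type, project_id = 7, "t-1", "feature", "p-9"

    split = (
        f"User {user_id} unfollowed thread {thread_id} "
        f"(type={thread_type}) in project {project_id}"
    )
    joined = f"User {user_id} unfollowed thread {thread_id} (type={thread_type}) in project {project_id}"
    assert split == joined  # adjacent string literals concatenate at compile time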
diff --git a/backend/app/routers/integrations.py b/backend/app/routers/integrations.py
index e3a8044..7bbb271 100644
--- a/backend/app/routers/integrations.py
+++ b/backend/app/routers/integrations.py
@@ -1,4 +1,5 @@
 """Integration config router for managing external integrations (bug trackers and LLMs)."""
+
 import json
 import logging
 from uuid import UUID
@@ -13,22 +14,22 @@
 logger = logging.getLogger(__name__)
 
 from app.database import get_async_db
-from app.models import User, OrgRole
+from app.integrations.factory import get_adapter as get_bug_tracker_adapter
+from app.models import OrgRole, User
 from app.models.integration_config import IntegrationVisibility
-from app.permissions.context import get_org_context, OrgContext
-from app.permissions.helpers import require_org_role, is_org_admin_or_higher
+from app.permissions.context import OrgContext, get_org_context
+from app.permissions.helpers import is_org_admin_or_higher, require_org_role
 from app.schemas.integration_config import (
     IntegrationConfigCreate,
-    IntegrationConfigUpdate,
     IntegrationConfigResponse,
     IntegrationConfigShareCreate,
-    IntegrationConfigShareResponse,
     IntegrationConfigShareListResponse,
+    IntegrationConfigShareResponse,
+    IntegrationConfigUpdate,
 )
-from app.services.integration_service import IntegrationService
 from app.services.integration_config_share_service import IntegrationConfigShareService
+from app.services.integration_service import IntegrationService
 from app.services.llm_adapters import get_llm_adapter
-from app.integrations.factory import get_adapter as get_bug_tracker_adapter
 
 router = APIRouter(tags=["integrations"])
@@ -332,6 +333,7 @@ async def delete_integration_config(
     # If this is a GitHub connector, clear connector reference from project repositories
     if config.provider == "github":
         from sqlalchemy import update
+
         from app.models.project_repository import ProjectRepository
 
         await db.execute(
@@ -497,8 +499,9 @@ async def list_org_github_repos(
     require_org_role(org_context, OrgRole.MEMBER)
 
     from sqlalchemy import select
-    from app.models.llm_preference import LLMPreference
+    from app.models.integration_config import IntegrationConfig
+    from app.models.llm_preference import LLMPreference
 
     # Get the Code Explorer GitHub connector
     config = None
@@ -509,9 +512,7 @@ async def list_org_github_repos(
     pref = pref_result.scalar_one_or_none()
 
     if pref and pref.code_explorer_github_config_id:
-        config_stmt = select(IntegrationConfig).where(
-            IntegrationConfig.id == pref.code_explorer_github_config_id
-        )
+        config_stmt = select(IntegrationConfig).where(IntegrationConfig.id == pref.code_explorer_github_config_id)
         config_result = await db.execute(config_stmt)
         config = config_result.scalar_one_or_none()
@@ -896,10 +897,7 @@ async def github_oauth_initiate(
     )
 
     # Build redirect URI for callback (static path - org_id is in the state parameter)
-    redirect_uri = (
-        f"{settings.base_url}{settings.api_v1_prefix}"
-        f"/integrations/github/oauth/callback"
-    )
+    redirect_uri = f"{settings.base_url}{settings.api_v1_prefix}/integrations/github/oauth/callback"
 
     # Get authorization URL with the client_id from UI/ENV
     auth_url = service.get_authorization_url(state_token, redirect_uri, client_id=client_id)
diff --git a/backend/app/routers/invitations.py b/backend/app/routers/invitations.py
index da53c9c..1468bb3 100644
--- a/backend/app/routers/invitations.py
+++ b/backend/app/routers/invitations.py
@@ -19,6 +19,7 @@
 from app.models.org_invitation_group import OrgInvitationGroup
 from app.models.org_membership import ProvisioningSource
 from app.permissions import OrgContext, get_org_context, require_org_role
+from app.plugin_registry import get_plugin_registry
 from app.schemas.invitation import (
     InvitationCreateRequest,
     InvitationCreateResponse,
@@ -26,11 +27,9 @@
     InvitationResponse,
     InvitationSendResult,
 )
-from app.models.user import User
 from app.services.email_service import EmailService
 from app.services.invitation_service import InvitationService
 from app.services.user_group_service import UserGroupService
-from app.plugin_registry import get_plugin_registry
 
 logger = logging.getLogger(__name__)
@@ -93,9 +92,7 @@ async def create_invitations(
     # Check if enterprise invitation plugin is available AND org is SSO-linked
     registry = get_plugin_registry()
     use_enterprise = (
-        registry.invitation_plugin
-        and registry.invitation_plugin.on_create
-        and org_context.org.organization_id
+        registry.invitation_plugin and registry.invitation_plugin.on_create and org_context.org.organization_id
     )
 
     if use_enterprise:
@@ -114,9 +111,7 @@ async def create_invitations(
     for email in request.emails:
         # Check for existing pending invitation
-        existing = InvitationService.get_pending_invitation_for_email(
-            db, org_id, email
-        )
+        existing = InvitationService.get_pending_invitation_for_email(db, org_id, email)
         if existing:
             results.append(
                 InvitationSendResult(
@@ -140,9 +135,7 @@ async def create_invitations(
             )
             for old_inv in old_invitations:
                 # Delete associated group assignments first
-                db.query(OrgInvitationGroup).filter(
-                    OrgInvitationGroup.invitation_id == old_inv.id
-                ).delete()
+                db.query(OrgInvitationGroup).filter(OrgInvitationGroup.invitation_id == old_inv.id).delete()
                 db.delete(old_inv)
             if old_invitations:
                 db.commit()
@@ -177,13 +170,12 @@ async def create_invitations(
                     success=True,
                     invitation_id=invitation.id,
                     email_sent=email_result.success,
-                    error=f"Invitation created but email failed: {email_result.message}" if not email_result.success else None,
+                    error=f"Invitation created but email failed: {email_result.message}"
+                    if not email_result.success
+                    else None,
                 )
             )
-            logger.info(
-                f"Invitation created: {email} -> {org_context.org.name} "
-                f"(email_sent={email_result.success})"
-            )
+            logger.info(f"Invitation created: {email} -> {org_context.org.name} (email_sent={email_result.success})")
 
         except Exception as e:
             logger.exception(f"Failed to create invitation for {email}: {e}")
@@ -252,14 +244,10 @@ def list_invitations(
     )
 
     # Get invitations
-    invitations = InvitationService.list_org_invitations(
-        db, org_id, invitation_status
-    )
+    invitations = InvitationService.list_org_invitations(db, org_id, invitation_status)
 
     # Transform to response
-    response_items = [
-        InvitationResponse.from_orm_with_groups(inv) for inv in invitations
-    ]
+    response_items = [InvitationResponse.from_orm_with_groups(inv) for inv in invitations]
 
     return InvitationListResponse(
         invitations=response_items,
@@ -411,8 +399,7 @@ async def resend_invitation(
     )
 
     logger.info(
-        f"Invitation resent: {invitation.email} -> {org_context.org.name} "
-        f"(email_sent={email_result.success})"
+        f"Invitation resent: {invitation.email} -> {org_context.org.name} (email_sent={email_result.success})"
     )
     return InvitationSendResult(
         email=invitation.email,
diff --git a/backend/app/routers/invite_acceptance.py b/backend/app/routers/invite_acceptance.py
index f9aa246..d51b2d4 100644
--- a/backend/app/routers/invite_acceptance.py
+++ b/backend/app/routers/invite_acceptance.py
@@ -14,13 +14,13 @@
 from app.database import get_db
 from app.models.org_membership import OrgRole
 from app.models.user import User
+from app.plugin_registry import get_plugin_registry
 from app.schemas.invitation import (
     InviteAcceptRequest,
     InviteAcceptResponse,
     InviteValidationResponse,
     OrgSummary,
 )
-from app.plugin_registry import get_plugin_registry
 from app.services.invitation_service import InvitationService
 from app.services.org_service import OrgService
 from app.services.sample_project_service import SampleProjectService
@@ -154,9 +154,7 @@ def accept_invite(
         has_owned_org = any(role == OrgRole.OWNER for _, role in user_orgs)
 
         if not has_owned_org:
-            personal_org_name = (
-                f"{current_user.display_name or current_user.email}'s Organization"
-            )
+            personal_org_name = f"{current_user.display_name or current_user.email}'s Organization"
             personal_org, _ = OrgService.create_org_with_owner(
                 db=db,
                 name=personal_org_name,
@@ -168,9 +166,7 @@ def accept_invite(
                 registry.plan_plugin.on_org_created(db, personal_org, current_user)
             # Create sample onboarding project in the personal org
             try:
-                SampleProjectService.create_sample_project(
-                    db, personal_org.id, current_user.id
-                )
+                SampleProjectService.create_sample_project(db, personal_org.id, current_user.id)
             except Exception as e:
                 logger.warning(f"Failed to create sample project for personal org: {e}")
             # Note: current_org_id stays as invitation.org_id (set above)
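The `error=...` hunk in invitations.py shows how the formatter handles a conditional expression in a keyword argument: when the whole expression no longer fits on one line it is broken after the value, with `if` and `else` on their own lines. A small runnable stand-in (the function and keys are illustrative):

    def build_result(success: bool, message: str) -> dict:
        return {
            "error": f"Invitation created but email failed: {message}"
            if not success
            else None,
        }

    print(build_result(False, "SMTP timeout"))  # {'error': 'Invitation created but email failed: SMTP timeout'}
    print(build_result(True, ""))               # {'error': None}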
diff --git a/backend/app/routers/jobs.py b/backend/app/routers/jobs.py
index 48bbe8a..72e0c1b 100644
--- a/backend/app/routers/jobs.py
+++ b/backend/app/routers/jobs.py
@@ -1,15 +1,16 @@
 """Jobs API endpoints."""
 
+from typing import Optional
+from uuid import UUID
+
 from fastapi import APIRouter, Depends, HTTPException, Query, status
 from sqlalchemy.orm import Session
-from uuid import UUID
-from typing import Optional
 
 from app.auth.dependencies import get_current_user
 from app.database import get_db
 from app.models import User
 from app.models.job import JobStatus
-from app.permissions.context import get_org_context, OrgContext
+from app.permissions.context import OrgContext, get_org_context
 from app.schemas.job import JobListResponse
 from app.services.job_service import JobService
 from app.services.project_service import ProjectService
diff --git a/backend/app/routers/llm_call_logs.py b/backend/app/routers/llm_call_logs.py
index 81d0629..ed952a8 100644
--- a/backend/app/routers/llm_call_logs.py
+++ b/backend/app/routers/llm_call_logs.py
@@ -1,9 +1,10 @@
 """LLM Call Logs API endpoints for Agent Log feature."""
 
+from typing import Dict, List, Optional
+from uuid import UUID
+
 from fastapi import APIRouter, Depends, HTTPException, Query, status
 from sqlalchemy.orm import Session
-from uuid import UUID
-from typing import Optional, List, Dict
 
 from app.auth.dependencies import get_current_user
 from app.auth.platform_admin import is_platform_admin, require_platform_admin
@@ -12,10 +13,10 @@
 from app.models import User
 from app.models.job import JobType
 from app.models.project import Project
-from app.permissions.context import get_org_context, OrgContext, get_project_context, ProjectContext
-from app.schemas.llm_call_log import LLMCallLogDetail, LLMCallLogSummary, JobWithCallLogs, AgentLogListResponse
-from app.services.llm_call_log_service import LLMCallLogService
+from app.permissions.context import OrgContext, ProjectContext, get_org_context, get_project_context
+from app.schemas.llm_call_log import AgentLogListResponse, JobWithCallLogs, LLMCallLogDetail, LLMCallLogSummary
 from app.services.job_service import JobService
+from app.services.llm_call_log_service import LLMCallLogService
 from app.services.org_service import OrgService
@@ -98,26 +99,28 @@ def list_platform_agent_logs(
         # Get user info
         job_user_info = user_info.get(job.triggered_by_user_id, {}) if job.triggered_by_user_id else {}
 
-        result.append(JobWithCallLogs(
-            id=job.id,
-            job_type=job.job_type.value,
-            status=job.status.value,
-            model_used=job.model_used,
-            total_prompt_tokens=job.total_prompt_tokens,
-            total_completion_tokens=job.total_completion_tokens,
-            total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
-            created_at=job.created_at,
-            started_at=job.started_at,
-            finished_at=job.finished_at,
-            project_id=job.project_id,
-            project_name=project_names.get(job.project_id) if job.project_id else None,
-            triggered_by_user_id=job.triggered_by_user_id,
-            triggered_by_user_email=job_user_info.get("email"),
-            triggered_by_user_display_name=job_user_info.get("display_name"),
-            duration_seconds=duration_seconds,
-            call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
-            call_count=len(call_logs),
-        ))
+        result.append(
+            JobWithCallLogs(
+                id=job.id,
+                job_type=job.job_type.value,
+                status=job.status.value,
+                model_used=job.model_used,
+                total_prompt_tokens=job.total_prompt_tokens,
+                total_completion_tokens=job.total_completion_tokens,
+                total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
+                created_at=job.created_at,
+                started_at=job.started_at,
+                finished_at=job.finished_at,
+                project_id=job.project_id,
+                project_name=project_names.get(job.project_id) if job.project_id else None,
+                triggered_by_user_id=job.triggered_by_user_id,
+                triggered_by_user_email=job_user_info.get("email"),
+                triggered_by_user_display_name=job_user_info.get("display_name"),
+                duration_seconds=duration_seconds,
+                call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
+                call_count=len(call_logs),
+            )
+        )
 
     return AgentLogListResponse(items=result, total=total, limit=limit, offset=offset)
@@ -183,26 +186,28 @@ def list_agent_logs(
         # Get user info
         job_user_info = user_info.get(job.triggered_by_user_id, {}) if job.triggered_by_user_id else {}
 
-        result.append(JobWithCallLogs(
-            id=job.id,
-            job_type=job.job_type.value,
-            status=job.status.value,
-            model_used=job.model_used,
-            total_prompt_tokens=job.total_prompt_tokens,
-            total_completion_tokens=job.total_completion_tokens,
-            total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
-            created_at=job.created_at,
-            started_at=job.started_at,
-            finished_at=job.finished_at,
-            project_id=job.project_id,
-            project_name=project_names.get(job.project_id) if job.project_id else None,
-            triggered_by_user_id=job.triggered_by_user_id,
-            triggered_by_user_email=job_user_info.get("email"),
-            triggered_by_user_display_name=job_user_info.get("display_name"),
-            duration_seconds=duration_seconds,
-            call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
-            call_count=len(call_logs),
-        ))
+        result.append(
+            JobWithCallLogs(
+                id=job.id,
+                job_type=job.job_type.value,
+                status=job.status.value,
+                model_used=job.model_used,
+                total_prompt_tokens=job.total_prompt_tokens,
+                total_completion_tokens=job.total_completion_tokens,
+                total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
+                created_at=job.created_at,
+                started_at=job.started_at,
+                finished_at=job.finished_at,
+                project_id=job.project_id,
+                project_name=project_names.get(job.project_id) if job.project_id else None,
+                triggered_by_user_id=job.triggered_by_user_id,
+                triggered_by_user_email=job_user_info.get("email"),
+                triggered_by_user_display_name=job_user_info.get("display_name"),
+                duration_seconds=duration_seconds,
+                call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
+                call_count=len(call_logs),
+            )
+        )
 
     return AgentLogListResponse(items=result, total=total, limit=limit, offset=offset)
@@ -278,26 +283,28 @@ def list_project_agent_logs(
         # Get user info
         job_user_info = user_info.get(job.triggered_by_user_id, {}) if job.triggered_by_user_id else {}
 
-        result.append(JobWithCallLogs(
-            id=job.id,
-            job_type=job.job_type.value,
-            status=job.status.value,
-            model_used=job.model_used,
-            total_prompt_tokens=job.total_prompt_tokens,
-            total_completion_tokens=job.total_completion_tokens,
-            total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
-            created_at=job.created_at,
-            started_at=job.started_at,
-            finished_at=job.finished_at,
-            project_id=job.project_id,
-            project_name=project_name,
-            triggered_by_user_id=job.triggered_by_user_id,
-            triggered_by_user_email=job_user_info.get("email"),
-            triggered_by_user_display_name=job_user_info.get("display_name"),
-            duration_seconds=duration_seconds,
-            call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
-            call_count=len(call_logs),
-        ))
+        result.append(
+            JobWithCallLogs(
+                id=job.id,
+                job_type=job.job_type.value,
+                status=job.status.value,
+                model_used=job.model_used,
+                total_prompt_tokens=job.total_prompt_tokens,
+                total_completion_tokens=job.total_completion_tokens,
+                total_cost_usd=float(job.total_cost_usd) if job.total_cost_usd else None,
+                created_at=job.created_at,
+                started_at=job.started_at,
+                finished_at=job.finished_at,
+                project_id=job.project_id,
+                project_name=project_name,
+                triggered_by_user_id=job.triggered_by_user_id,
+                triggered_by_user_email=job_user_info.get("email"),
+                triggered_by_user_display_name=job_user_info.get("display_name"),
+                duration_seconds=duration_seconds,
+                call_logs=[LLMCallLogSummary.model_validate(log) for log in call_logs],
+                call_count=len(call_logs),
+            )
+        )
 
     return AgentLogListResponse(items=result, total=total, limit=limit, offset=offset)
@@ -308,6 +315,7 @@ def _get_feature_name(db: Session, feature_id_str: Optional[str]) -> tuple[Optio
         return None, None
     try:
         from app.models.feature import Feature
+
        feature_id = UUID(feature_id_str)
         feature = db.query(Feature.id, Feature.title).filter(Feature.id == feature_id).first()
         if feature:
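The three large llm_call_logs.py hunks are Ruff's "magic trailing comma" at work: because the `JobWithCallLogs(...)` argument list ends in a trailing comma, the formatter keeps it one argument per line and wraps the outer `result.append(...)` around it instead of hugging both calls on one line. A tiny stand-in (JobSummary is illustrative, not the app's schema):

    from dataclasses import dataclass

    @dataclass
    class JobSummary:
        id: int
        status: str
        call_count: int

    result: list[JobSummary] = []
    result.append(
        JobSummary(
            id=1,
            status="succeeded",
            call_count=3,  # the trailing comma keeps this call expanded
        )
    )
    print(result)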
diff --git a/backend/app/routers/llm_preferences.py b/backend/app/routers/llm_preferences.py
index e1b83f7..db8a307 100644
--- a/backend/app/routers/llm_preferences.py
+++ b/backend/app/routers/llm_preferences.py
@@ -1,4 +1,5 @@
 """LLM Preference router for managing organization-level LLM selections."""
+
 from uuid import UUID
 
 from fastapi import APIRouter, Depends, HTTPException, status
@@ -6,12 +7,12 @@
 
 from app.auth.dependencies import get_current_user
 from app.database import get_async_db
-from app.models import User, OrgRole
-from app.permissions.context import get_org_context, OrgContext
+from app.models import OrgRole, User
+from app.permissions.context import OrgContext, get_org_context
 from app.permissions.helpers import require_org_role
 from app.schemas.llm_preference import (
-    LLMPreferenceUpdate,
     LLMPreferenceResponse,
+    LLMPreferenceUpdate,
 )
 from app.services.llm_preference_service import LLMPreferenceService
diff --git a/backend/app/routers/mcp_call_logs.py b/backend/app/routers/mcp_call_logs.py
index 5b6d537..6bcaee1 100644
--- a/backend/app/routers/mcp_call_logs.py
+++ b/backend/app/routers/mcp_call_logs.py
@@ -1,19 +1,20 @@
 """MCP Call Logs API endpoints for MCP Log feature."""
 
+from typing import Optional
+from uuid import UUID
+
 from fastapi import APIRouter, Depends, HTTPException, Query, status
 from sqlalchemy.orm import Session
-from uuid import UUID
-from typing import Optional
 
 from app.auth.dependencies import get_current_user
 from app.auth.trial import require_strictly_exempt
 from app.database import get_db
 from app.models import User
-from app.permissions.context import get_org_context, OrgContext
+from app.models.mcp_call_log import MCPCallLog
+from app.permissions.context import OrgContext, get_org_context
 from app.schemas.mcp_call_log import MCPCallLogDetail, MCPCallLogSummary, MCPLogListResponse
 from app.services.mcp_call_log_service import MCPCallLogService
 from app.services.org_service import OrgService
-from app.models.mcp_call_log import MCPCallLog
 
 
 def _to_summary(log: MCPCallLog) -> MCPCallLogSummary:
@@ -56,6 +57,7 @@ def _to_detail(log: MCPCallLog) -> MCPCallLogDetail:
         created_at=log.created_at,
     )
 
+
 router = APIRouter(tags=["mcp-logs"])
diff --git a/backend/app/routers/mcp_http.py b/backend/app/routers/mcp_http.py
index fd368b7..dbbf13f 100644
--- a/backend/app/routers/mcp_http.py
+++ b/backend/app/routers/mcp_http.py
@@ -4,40 +4,39 @@
 import logging
 from datetime import datetime, timezone
 from typing import Annotated, Any
-from urllib.parse import urlencode, urlparse, parse_qs
+from urllib.parse import urlencode
 from uuid import UUID
 
-from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status, Form
+from fastapi import APIRouter, Depends, Form, Header, HTTPException, Request, Response
 from fastapi.responses import JSONResponse, RedirectResponse
 from pydantic import BaseModel, Field
 from sqlalchemy.orm import Session
 
 from app.auth.dependencies import DualAuth, get_current_user
+from app.config import settings
 from app.database import get_db
-from app.models.user import User
-from app.models.project import Project
 from app.models.project_membership import ProjectRole
-from app.services.project_service import ProjectService
-from app.services.project_share_service import ProjectShareService
+from app.models.user import User
+from app.permissions.helpers import project_role_rank
 from app.services.mcp_call_log_service import MCPCallLogService
 from app.services.mcp_oauth_service import MCPOAuthService
-from app.permissions.helpers import project_role_rank
-from app.config import settings
+from app.services.project_service import ProjectService
+from app.services.project_share_service import ProjectShareService
 
 logger = logging.getLogger(__name__)
 
 # Import MCP VFS tools
 from app.mcp.tools.read_me_first import read_me_first
-from app.mcp.tools.vfs_ls import vfs_ls
 from app.mcp.tools.vfs_cat import vfs_cat
+from app.mcp.tools.vfs_find import vfs_find
+from app.mcp.tools.vfs_grep import vfs_grep
 from app.mcp.tools.vfs_head import vfs_head
+from app.mcp.tools.vfs_ls import vfs_ls
+from app.mcp.tools.vfs_sed import vfs_sed
+from app.mcp.tools.vfs_set_metadata import vfs_set_metadata
 from app.mcp.tools.vfs_tail import vfs_tail
-from app.mcp.tools.vfs_grep import vfs_grep
-from app.mcp.tools.vfs_find import vfs_find
 from app.mcp.tools.vfs_tree import vfs_tree
 from app.mcp.tools.vfs_write import vfs_write
-from app.mcp.tools.vfs_sed import vfs_sed
-from app.mcp.tools.vfs_set_metadata import vfs_set_metadata
 
 router = APIRouter(tags=["MCP HTTP Transport"])
 oauth_router = APIRouter(tags=["OAuth Discovery"])  # Separate router for root-level OAuth endpoints
@@ -155,6 +154,7 @@ def jsonrpc_response(request_id: int | str | None, result: Any = None, error: di
 
 # Pydantic models for OAuth
 class OAuthAuthorizeRequest(BaseModel):
     """Request body for consent approval."""
+
     client_id: str
     redirect_uri: str
     code_challenge: str
@@ -167,6 +167,7 @@ class OAuthAuthorizeRequest(BaseModel):
 
 class OAuthTokenRequest(BaseModel):
     """Token request - can be authorization_code or refresh_token grant."""
+
     grant_type: str
     code: str | None = None
     redirect_uri: str | None = None
@@ -207,7 +208,7 @@ async def oauth_protected_resource_metadata():
     """
     raise HTTPException(
         status_code=404,
-        detail="Use resource-level discovery at /api/v1/projects/{project_id}/mcp/.well-known/oauth-protected-resource"
+        detail="Use resource-level discovery at /api/v1/projects/{project_id}/mcp/.well-known/oauth-protected-resource",
     )
@@ -462,6 +463,7 @@ async def oauth_revoke(
 
 class OAuthConsentRequest(BaseModel):
     """Request body for user consent."""
+
     client_id: str
     redirect_uri: str
     code_challenge: str
@@ -474,6 +476,7 @@ class OAuthConsentRequest(BaseModel):
 
 class OAuthConsentResponse(BaseModel):
     """Response for consent endpoint."""
+
     redirect_url: str
@@ -846,9 +849,7 @@ async def mcp_endpoint(
         if not request.params:
             return jsonrpc_response(
                 request.id,
-                error={
-                    "code": -32602, "message": "Invalid params", "data": "Missing params"
-                },
+                error={"code": -32602, "message": "Invalid params", "data": "Missing params"},
             )
 
         tool_name = request.params.get("name")
@@ -861,9 +862,7 @@ async def mcp_endpoint(
         if not tool_name:
             return jsonrpc_response(
                 request.id,
-                error={
-                    "code": -32602, "message": "Invalid params", "data": "Missing tool name"
-                },
+                error={"code": -32602, "message": "Invalid params", "data": "Missing tool name"},
             )
 
         # Check if tool exists
@@ -975,13 +974,8 @@ async def mcp_endpoint(
             # Wrap result in MCP content format
             # MCP tools/call responses must have content array with text blocks
             mcp_result = {
-                "content": [
-                    {
-                        "type": "text",
-                        "text": json.dumps(result, indent=2, default=str)
-                    }
-                ],
-                "isError": False
+                "content": [{"type": "text", "text": json.dumps(result, indent=2, default=str)}],
+                "isError": False,
             }
 
             # Log the successful call
@@ -1032,9 +1026,7 @@ async def mcp_endpoint(
             return jsonrpc_response(request.id, error=error_obj)
         except ValueError as e:
             # Tool-specific error (e.g., project not found)
-            error_obj = {
-                "code": -32000, "message": "Tool execution error", "data": str(e)
-            }
+            error_obj = {"code": -32000, "message": "Tool execution error", "data": str(e)}
             finished_at = datetime.now(timezone.utc)
             _log_mcp_call(
                 db=db,
@@ -1055,9 +1047,7 @@ async def mcp_endpoint(
             return jsonrpc_response(request.id, error=error_obj)
         except Exception as e:
             # Unexpected error
-            error_obj = {
-                "code": -32603, "message": "Internal error", "data": str(e)
-            }
+            error_obj = {"code": -32603, "message": "Internal error", "data": str(e)}
             finished_at = datetime.now(timezone.utc)
             _log_mcp_call(
                 db=db,
@@ -1081,217 +1071,129 @@ async def mcp_endpoint(
         elif request.method == "tools/list":
            # Define JSON Schema for each VFS tool
            tool_schemas = {
-                "readMeFirst": {
-                    "type": "object",
-                    "properties": {},
-                    "additionalProperties": False
-                },
+                "readMeFirst": {"type": "object", "properties": {}, "additionalProperties": False},
                "ls": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "Directory path to list",
-                            "default": "/"
-                        },
-                        "long": {
-                            "type": "boolean",
-                            "description": "Include detailed metadata",
-                            "default": False
-                        },
-                        "all": {
-                            "type": "boolean",
-                            "description": "Include hidden files",
-                            "default": False
-                        }
+                        "path": {"type": "string", "description": "Directory path to list", "default": "/"},
+                        "long": {"type": "boolean", "description": "Include detailed metadata", "default": False},
+                        "all": {"type": "boolean", "description": "Include hidden files", "default": False},
                    },
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "cat": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "File path to read"
-                        },
+                        "path": {"type": "string", "description": "File path to read"},
                        "branch_name": {
                            "type": "string",
-                            "description": "Git branch name (use current git branch, or 'main' if not in a git repo)"
-                        }
+                            "description": "Git branch name (use current git branch, or 'main' if not in a git repo)",
+                        },
                    },
                    "required": ["path"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "head": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "File path to read"
-                        },
-                        "lines": {
-                            "type": "integer",
-                            "description": "Number of lines to display",
-                            "default": 10
-                        }
+                        "path": {"type": "string", "description": "File path to read"},
+                        "lines": {"type": "integer", "description": "Number of lines to display", "default": 10},
                    },
                    "required": ["path"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "tail": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "File path to read"
-                        },
-                        "lines": {
-                            "type": "integer",
-                            "description": "Number of lines to display",
-                            "default": 10
-                        }
+                        "path": {"type": "string", "description": "File path to read"},
+                        "lines": {"type": "integer", "description": "Number of lines to display", "default": 10},
                    },
                    "required": ["path"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "grep": {
                    "type": "object",
                    "properties": {
-                        "pattern": {
-                            "type": "string",
-                            "description": "Regex pattern to search for"
-                        },
-                        "path": {
-                            "type": "string",
-                            "description": "Starting path for search",
-                            "default": "/"
-                        },
-                        "ignore_case": {
-                            "type": "boolean",
-                            "description": "Case-insensitive search",
-                            "default": False
-                        },
-                        "context": {
-                            "type": "integer",
-                            "description": "Lines of context around matches",
-                            "default": 0
-                        }
+                        "pattern": {"type": "string", "description": "Regex pattern to search for"},
+                        "path": {"type": "string", "description": "Starting path for search", "default": "/"},
+                        "ignore_case": {"type": "boolean", "description": "Case-insensitive search", "default": False},
+                        "context": {"type": "integer", "description": "Lines of context around matches", "default": 0},
                    },
                    "required": ["pattern"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "find": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "Starting path for search",
-                            "default": "/"
-                        },
-                        "name": {
-                            "type": "string",
-                            "description": "Glob pattern for name matching (e.g., '*.md')"
-                        },
-                        "type": {
-                            "type": "string",
-                            "enum": ["f", "d"],
-                            "description": "'f' for files, 'd' for directories"
-                        }
+                        "path": {"type": "string", "description": "Starting path for search", "default": "/"},
+                        "name": {"type": "string", "description": "Glob pattern for name matching (e.g., '*.md')"},
+                        "type": {"type": "string", "enum": ["f", "d"], "description": "'f' for files, 'd' for directories"},
                    },
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "tree": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "Root path for the tree",
-                            "default": "/"
-                        },
-                        "depth": {
-                            "type": "integer",
-                            "description": "Maximum depth to display",
-                            "default": 3
-                        }
+                        "path": {"type": "string", "description": "Root path for the tree", "default": "/"},
+                        "depth": {"type": "integer", "description": "Maximum depth to display", "default": 3},
                    },
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "write": {
                    "type": "object",
                    "properties": {
                        "path": {
                            "type": "string",
-                            "description": "File path to write (only notes.md files or /for-coding-agents/*)"
-                        },
-                        "content": {
-                            "type": "string",
-                            "description": "Content to write"
-                        },
-                        "append": {
-                            "type": "boolean",
-                            "description": "Append to existing content",
-                            "default": True
+                            "description": "File path to write (only notes.md files or /for-coding-agents/*)",
                        },
+                        "content": {"type": "string", "description": "Content to write"},
+                        "append": {"type": "boolean", "description": "Append to existing content", "default": True},
                        "branch_name": {
                            "type": "string",
-                            "description": "Git branch name (use current git branch, or 'main' if not in a git repo)"
+                            "description": "Git branch name (use current git branch, or 'main' if not in a git repo)",
                        },
                        "repo_path": {
                            "type": "string",
-                            "description": "Remote repository URL (e.g., 'https://github.com/org/repo'). Ignore if no remote is set."
-                        }
+                            "description": "Remote repository URL (e.g., 'https://github.com/org/repo'). Ignore if no remote is set.",
+                        },
                    },
                    "required": ["path", "content"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "sed": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "File path (only notes.md or /for-coding-agents/*)"
-                        },
-                        "pattern": {
-                            "type": "string",
-                            "description": "Regex pattern to match"
-                        },
-                        "replacement": {
-                            "type": "string",
-                            "description": "Replacement string"
-                        },
+                        "path": {"type": "string", "description": "File path (only notes.md or /for-coding-agents/*)"},
+                        "pattern": {"type": "string", "description": "Regex pattern to match"},
+                        "replacement": {"type": "string", "description": "Replacement string"},
                        "flags": {
                            "type": "string",
                            "description": "Flags: g=global, i=case-insensitive, m=multiline",
-                            "default": ""
-                        }
+                            "default": "",
+                        },
                    },
                    "required": ["path", "pattern", "replacement"],
-                    "additionalProperties": False
+                    "additionalProperties": False,
                },
                "setMetadataValueForKey": {
                    "type": "object",
                    "properties": {
-                        "path": {
-                            "type": "string",
-                            "description": "Directory path"
-                        },
+                        "path": {"type": "string", "description": "Directory path"},
                        "key": {
                            "type": "string",
-                            "description": "Metadata key (e.g., 'completion_status', 'is_complete', 'in_progress')"
+                            "description": "Metadata key (e.g., 'completion_status', 'is_complete', 'in_progress')",
                        },
-                        "value": {
-                            "description": "Metadata value (string, number, boolean, or JSON)"
-                        }
+                        "value": {"description": "Metadata value (string, number, boolean, or JSON)"},
                    },
                    "required": ["path", "key", "value"],
-                    "additionalProperties": False
-                }
+                    "additionalProperties": False,
+                },
            }

            # Add coding_agent_name to all tool schemas for analytics
            coding_agent_schema = {
                "type": "string",
-                "description": "Optional: Identify your coding agent (e.g., 'claude_code', 'cursor', 'cline'). Helps with analytics."
+                "description": "Optional: Identify your coding agent (e.g., 'claude_code', 'cursor', 'cline'). Helps with analytics.",
            }
            for schema in tool_schemas.values():
                schema["properties"]["coding_agent_name"] = coding_agent_schema
@@ -1300,7 +1202,9 @@ async def mcp_endpoint(
                 {
                     "name": name,
                     "description": func.__doc__.strip() if func.__doc__ else "No description",
-                    "inputSchema": tool_schemas.get(name, {"type": "object", "properties": {}, "additionalProperties": False})
+                    "inputSchema": tool_schemas.get(
+                        name, {"type": "object", "properties": {}, "additionalProperties": False}
+                    ),
                 }
                 for name, func in MCP_TOOLS.items()
             ]
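The big tools/list hunk collapses each property sub-dict onto one line; the JSON Schemas themselves are unchanged. A sketch of checking tool-call arguments against the compact "head" schema, assuming the third-party `jsonschema` package is available (it is not a dependency this patch introduces):

    from jsonschema import ValidationError, validate

    head_schema = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "File path to read"},
            "lines": {"type": "integer", "description": "Number of lines to display", "default": 10},
        },
        "required": ["path"],
        "additionalProperties": False,
    }

    validate(instance={"path": "/notes.md", "lines": 5}, schema=head_schema)  # passes
    try:
        validate(instance={"lines": 5}, schema=head_schema)  # missing "path"
    except ValidationError as e:
        print(e.message)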
or implementation)"), + brainstorming_phase_id: Optional[str] = Query( + None, description="Filter by brainstorming phase ID or URL identifier" + ), + module_type: Optional[ModuleType] = Query( + None, description="Filter by module type (conversation or implementation)" + ), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), project_context: ProjectContext = Depends(get_project_context), diff --git a/backend/app/routers/org_chats.py b/backend/app/routers/org_chats.py index 9ee2298..b581cd2 100644 --- a/backend/app/routers/org_chats.py +++ b/backend/app/routers/org_chats.py @@ -1,31 +1,30 @@ """API endpoints for org-scoped pre-phase discussions (chats).""" + import logging from typing import Optional -from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user from app.database import get_db +from app.models.job import Job, JobStatus, JobType from app.models.user import User -from app.models.job import Job, JobType, JobStatus -from app.services.project_chat_service import ProjectChatService -from app.services.job_service import JobService -from app.services.kafka_producer import get_kafka_producer -from app.permissions.context import get_org_context, OrgContext +from app.permissions.context import OrgContext, get_org_context from app.schemas.project_chat import ( + CreateProjectFromProjectChatRequest, + CreateProjectFromProjectChatResponse, + OrgProjectChatListItem, + OrgProjectChatListResponse, + ProjectChatMessageResponse, ProjectChatResponse, ProjectChatWithMessages, - ProjectChatMessageResponse, SendMessageRequest, SendMessageResponse, - OrgProjectChatListItem, - OrgProjectChatListResponse, - CreateProjectFromProjectChatRequest, - CreateProjectFromProjectChatResponse, ) - +from app.services.job_service import JobService +from app.services.kafka_producer import get_kafka_producer +from app.services.project_chat_service import ProjectChatService logger = logging.getLogger(__name__) @@ -164,8 +163,7 @@ async def get_org_discussion( if discussion.project_id is not None: # This is a project-scoped discussion - redirect them to use the project endpoint raise HTTPException( - status_code=400, - detail="This discussion is project-scoped. Use the project endpoint instead." + status_code=400, detail="This discussion is project-scoped. Use the project endpoint instead." 
) messages = sorted(discussion.messages, key=lambda m: m.created_at) @@ -197,10 +195,7 @@ async def delete_org_discussion( # Prevent deletion if a project was created from this discussion if discussion.created_project_id is not None: - raise HTTPException( - status_code=400, - detail="Cannot delete a discussion that has created a project" - ) + raise HTTPException(status_code=400, detail="Cannot delete a discussion that has created a project") ProjectChatService.delete_project_chat( db=db, @@ -233,8 +228,7 @@ async def send_message( # Check if discussion is readonly (project already created) if discussion.is_readonly: raise HTTPException( - status_code=400, - detail="Cannot send messages to a discussion that has already created a project" + status_code=400, detail="Cannot send messages to a discussion that has already created a project" ) # Add user message to discussion first @@ -256,10 +250,7 @@ async def send_message( # Check if already generating (only relevant when we're about to generate) if discussion.is_generating: - raise HTTPException( - status_code=400, - detail="Bot is already generating a response" - ) + raise HTTPException(status_code=400, detail="Bot is already generating a response") # Create job - note: project_id is None for org-scoped discussions job = Job( @@ -313,10 +304,7 @@ async def send_message( detail="Failed to queue response generation job", ) - logger.info( - f"Created org-scoped pre-phase discussion response job {job.id} " - f"for discussion {project_chat_id}" - ) + logger.info(f"Created org-scoped pre-phase discussion response job {job.id} for discussion {project_chat_id}") return SendMessageResponse( job_id=job.id, @@ -348,23 +336,16 @@ async def retry_message( # Check if discussion is readonly if discussion.is_readonly: raise HTTPException( - status_code=400, - detail="Cannot retry messages in a discussion that has already created a project" + status_code=400, detail="Cannot retry messages in a discussion that has already created a project" ) # Check if there's an error to retry if not discussion.ai_error_user_message: - raise HTTPException( - status_code=400, - detail="No failed message to retry" - ) + raise HTTPException(status_code=400, detail="No failed message to retry") # Check if already generating if discussion.is_generating: - raise HTTPException( - status_code=400, - detail="Bot is already generating a response" - ) + raise HTTPException(status_code=400, detail="Bot is already generating a response") # Get the user message that failed user_message = discussion.ai_error_user_message @@ -420,10 +401,7 @@ async def retry_message( detail="Failed to queue response generation job", ) - logger.info( - f"Created org-scoped pre-phase discussion retry job {job.id} " - f"for discussion {discussion.id}" - ) + logger.info(f"Created org-scoped pre-phase discussion retry job {job.id} for discussion {discussion.id}") return SendMessageResponse( job_id=job.id, diff --git a/backend/app/routers/orgs.py b/backend/app/routers/orgs.py index ad4f4d2..a50a2ff 100644 --- a/backend/app/routers/orgs.py +++ b/backend/app/routers/orgs.py @@ -6,6 +6,7 @@ Enterprise SSO operations (Scalekit sync) are handled via the plugin registry — see app/plugin_registry.py. 
""" + import logging from typing import Annotated, Optional from uuid import UUID @@ -17,9 +18,8 @@ from app.database import get_db from app.models import OrgMembership, OrgRole, User from app.models.user_group_membership import UserGroupMembership -from app.models.user_identity import UserIdentity -from app.models.identity_provider import IdentityProvider from app.permissions import OrgContext, get_org_context, require_org_role +from app.plugin_registry import get_plugin_registry from app.schemas.org import ( OrgMemberResponse, OrgMemberRoleUpdateRequest, @@ -29,12 +29,11 @@ ) from app.schemas.project_share import ( ShareableSubject, - ShareableSubjectType, ShareableSubjectsResponse, + ShareableSubjectType, ) from app.services.org_service import OrgService from app.services.user_group_service import UserGroupService -from app.plugin_registry import get_plugin_registry logger = logging.getLogger(__name__) @@ -173,10 +172,7 @@ def search_shareable_subjects( # Filter by search term if search_term: email_match = search_term in user.email.lower() - name_match = ( - user.display_name - and search_term in user.display_name.lower() - ) + name_match = user.display_name and search_term in user.display_name.lower() if not (email_match or name_match): continue @@ -198,10 +194,7 @@ def search_shareable_subjects( # Filter by search term if search_term: name_match = search_term in group.name.lower() - desc_match = ( - group.description - and search_term in group.description.lower() - ) + desc_match = group.description and search_term in group.description.lower() if not (name_match or desc_match): continue @@ -273,11 +266,7 @@ async def update_org( if updated_org: # If enterprise plugin provides on_rename_org and org is SSO-linked, sync registry = get_plugin_registry() - if ( - registry.invitation_plugin - and registry.invitation_plugin.on_rename_org - and updated_org.organization_id - ): + if registry.invitation_plugin and registry.invitation_plugin.on_rename_org and updated_org.organization_id: try: await registry.invitation_plugin.on_rename_org( org=updated_org, @@ -285,9 +274,7 @@ async def update_org( ) except Exception as e: # Log but don't fail the update if sync fails - logger.warning( - f"Failed to sync org name via plugin for org {updated_org.id}: {e}" - ) + logger.warning(f"Failed to sync org name via plugin for org {updated_org.id}: {e}") return OrgResponse.model_validate(updated_org) @@ -407,10 +394,7 @@ def update_org_member_role( db.commit() db.refresh(membership) - logger.info( - f"Updated role for user {user_id} in org {org_id} to {new_role.value} " - f"by user {org_context.user.id}" - ) + logger.info(f"Updated role for user {user_id} in org {org_id} to {new_role.value} by user {org_context.user.id}") return OrgMemberResponse( user_id=membership.user.id, @@ -482,11 +466,7 @@ async def remove_org_member( # If enterprise plugin provides on_remove_member and org is SSO-linked, sync registry = get_plugin_registry() - if ( - registry.invitation_plugin - and registry.invitation_plugin.on_remove_member - and org_context.org.organization_id - ): + if registry.invitation_plugin and registry.invitation_plugin.on_remove_member and org_context.org.organization_id: try: await registry.invitation_plugin.on_remove_member( db=db, @@ -494,9 +474,7 @@ async def remove_org_member( user_id=user_id, ) except Exception as e: - logger.warning( - f"Failed to remove member via plugin (continuing with local removal): {e}" - ) + logger.warning(f"Failed to remove member via plugin (continuing with local removal): 
{e}") # Remove from all groups in this org group_memberships = ( @@ -515,7 +493,4 @@ async def remove_org_member( db.delete(membership) db.commit() - logger.info( - f"Removed member {user_id} from org {org_id} " - f"(removed from {len(group_memberships)} groups)" - ) + logger.info(f"Removed member {user_id} from org {org_id} (removed from {len(group_memberships)} groups)") diff --git a/backend/app/routers/phase_containers.py b/backend/app/routers/phase_containers.py index 53840be..6c277ce 100644 --- a/backend/app/routers/phase_containers.py +++ b/backend/app/routers/phase_containers.py @@ -1,34 +1,34 @@ """Router for phase container operations.""" + import logging from typing import List -from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ProjectRole -from app.models.job import JobType, JobStatus -from app.permissions.context import get_project_context, ProjectContext +from app.models import ProjectRole, User +from app.models.job import JobStatus, JobType +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role -from app.schemas.phase_container import ( - PhaseContainerCreate, - PhaseContainerUpdate, - PhaseContainerResponse, - PhaseContainerListResponse, - ExtensionCreateRequest, - ExtensionPreviewResponse, -) from app.schemas.brainstorming_phase import ( BrainstormingPhaseListResponse, BrainstormingPhaseResponse, ) -from app.services.phase_container_service import PhaseContainerService +from app.schemas.phase_container import ( + ExtensionCreateRequest, + ExtensionPreviewResponse, + PhaseContainerCreate, + PhaseContainerListResponse, + PhaseContainerResponse, + PhaseContainerUpdate, +) +from app.services.activity_log_service import ActivityEventTypes, ActivityLogService from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.job_service import JobService from app.services.kafka_producer import get_kafka_producer -from app.services.activity_log_service import ActivityLogService, ActivityEventTypes +from app.services.phase_container_service import PhaseContainerService logger = logging.getLogger(__name__) @@ -656,9 +656,7 @@ def get_extension_preview( Requires MEMBER role or higher. """ - container, _ = _resolve_container_with_membership( - db, identifier, current_user, ProjectRole.MEMBER - ) + container, _ = _resolve_container_with_membership(db, identifier, current_user, ProjectRole.MEMBER) preview = PhaseContainerService.get_extension_preview(db, container.id) if preview is None: @@ -689,9 +687,7 @@ async def create_extension_phase( Requires MEMBER role or higher. 
""" - container, _ = _resolve_container_with_membership( - db, identifier, current_user, ProjectRole.MEMBER - ) + container, _ = _resolve_container_with_membership(db, identifier, current_user, ProjectRole.MEMBER) try: phase = PhaseContainerService.create_extension_phase( @@ -749,9 +745,7 @@ async def create_extension_phase( ) if not success: - logger.warning( - f"Failed to publish brainstorm generation job {job.id} for extension phase {phase.id}" - ) + logger.warning(f"Failed to publish brainstorm generation job {job.id} for extension phase {phase.id}") JobService.update_job_status( db=db, job_id=job.id, @@ -759,9 +753,7 @@ async def create_extension_phase( error_message="Failed to publish job to Kafka", ) else: - logger.info( - f"Brainstorm generation job {job.id} queued for extension phase {phase.id}" - ) + logger.info(f"Brainstorm generation job {job.id} queued for extension phase {phase.id}") db.commit() return phase diff --git a/backend/app/routers/plan_recommendations.py b/backend/app/routers/plan_recommendations.py index 4e90efa..9103322 100644 --- a/backend/app/routers/plan_recommendations.py +++ b/backend/app/routers/plan_recommendations.py @@ -10,7 +10,8 @@ from typing import Annotated from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status as http_status +from fastapi import APIRouter, Depends, HTTPException +from fastapi import status as http_status from sqlalchemy.orm import Session from app.auth.platform_admin import require_platform_admin @@ -128,9 +129,7 @@ def dismiss_recommendation( Platform admins only. Enterprise only. """ _require_plan_plugin() - recommendation = PlanRecommendationService.dismiss_recommendation( - db, recommendation_id, user.id - ) + recommendation = PlanRecommendationService.dismiss_recommendation(db, recommendation_id, user.id) if not recommendation: raise HTTPException(status_code=404, detail="Recommendation not found") diff --git a/backend/app/routers/platform_settings.py b/backend/app/routers/platform_settings.py index cd5567c..7a2b42a 100644 --- a/backend/app/routers/platform_settings.py +++ b/backend/app/routers/platform_settings.py @@ -2,6 +2,7 @@ All endpoints in this router require platform admin privileges. 
""" + import json import logging from typing import Annotated @@ -21,7 +22,6 @@ GitHubOAuthEnvConfigResponse, GitHubOAuthSettingsResponse, GitHubOAuthSettingsUpdate, - OrgPlanInfo, PlatformAdminCheckResponse, PlatformConnectorCreate, PlatformConnectorResponse, @@ -38,12 +38,10 @@ UserPlanSearchResult, UserPlanUpdateRequest, UserPlanUpdateResponse, - UserTrialInfo, WebSearchEnvConfigResponse, ) from app.services.platform_settings_service import PlatformSettingsService - logger = logging.getLogger(__name__) router = APIRouter(prefix="/platform", tags=["platform-settings"]) @@ -51,6 +49,7 @@ # ==================== Platform Admin Check ==================== + @router.get( "/is-admin", response_model=PlatformAdminCheckResponse, @@ -68,6 +67,7 @@ async def check_platform_admin( # ==================== Platform Connectors ==================== + @router.get( "/connectors", response_model=list[PlatformConnectorResponse], @@ -75,9 +75,7 @@ async def check_platform_admin( async def list_platform_connectors( _: Annotated[User, Depends(require_platform_admin)], db: Annotated[AsyncSession, Depends(get_async_db)], - connector_type: str | None = Query( - None, description="Filter by connector type (llm, email, object_storage)" - ), + connector_type: str | None = Query(None, description="Filter by connector type (llm, email, object_storage)"), ) -> list[PlatformConnectorResponse]: """List all platform connectors. @@ -143,10 +141,7 @@ async def create_platform_connector( detail=str(e), ) - logger.info( - f"Platform connector created: {connector.id} " - f"({connector.connector_type}/{connector.provider})" - ) + logger.info(f"Platform connector created: {connector.id} ({connector.connector_type}/{connector.provider})") return PlatformConnectorResponse.model_validate(connector) @@ -279,7 +274,7 @@ async def test_platform_connector( success=False, message=f"Model not found: {str(e)}", ) - except litellm.RateLimitError as e: + except litellm.RateLimitError: # Rate limit means auth worked, just hitting limits return PlatformConnectorTestResult( success=True, @@ -306,6 +301,7 @@ async def test_platform_connector( ) import httpx + try: async with httpx.AsyncClient() as client: response = await client.get( @@ -458,7 +454,7 @@ async def test_platform_connector( success=False, message=f"Authentication failed: {str(e)}", ) - except litellm.RateLimitError as e: + except litellm.RateLimitError: # Rate limit means auth worked, just hitting limits return PlatformConnectorTestResult( success=True, @@ -508,6 +504,7 @@ async def test_platform_connector( # ==================== Platform Settings ==================== + @router.get( "/settings", response_model=PlatformSettingsResponse, @@ -640,6 +637,7 @@ async def update_freemium_settings( # ==================== Test Email ==================== + @router.post( "/send-test-email", response_model=SendTestEmailResponse, @@ -693,9 +691,7 @@ async def send_test_email( "Content-Type": "application/json", }, json={ - "personalizations": [ - {"to": [{"email": request.to_email}]} - ], + "personalizations": [{"to": [{"email": request.to_email}]}], "from": {"email": from_email, "name": from_name}, "subject": "MFBT Platform - Test Email", "content": [ @@ -747,6 +743,7 @@ async def send_test_email( # ==================== Email Environment Config ==================== + @router.get( "/email-env-config", response_model=EmailEnvConfigResponse, @@ -772,6 +769,7 @@ async def check_email_env_config( # ==================== Web Search Environment Config ==================== + @router.get( 
"/web-search-env-config", response_model=WebSearchEnvConfigResponse, @@ -989,12 +987,13 @@ async def list_platform_projects( Returns: List of projects with org info """ - from sqlalchemy import select, func, union - from app.models.project import Project, ProjectStatus + from sqlalchemy import func, select, union + + from app.models.org_membership import OrgMembership from app.models.organization import Organization + from app.models.project import Project, ProjectStatus from app.models.project_share import ProjectShare, ShareSubjectType from app.models.user_group_membership import UserGroupMembership - from app.models.org_membership import OrgMembership # Build the query with join to get org name query = ( @@ -1020,50 +1019,33 @@ async def list_platform_projects( # Apply user access filter if user_id provided if user_id: # Get project IDs from direct user shares - direct_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.USER, - ProjectShare.subject_id == user_id, - ) + direct_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.USER, + ProjectShare.subject_id == user_id, ) # Get group IDs the user belongs to user_group_ids = ( - select(UserGroupMembership.group_id) - .where(UserGroupMembership.user_id == user_id) - .scalar_subquery() + select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id).scalar_subquery() ) # Get project IDs from group shares - group_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.GROUP, - ProjectShare.subject_id.in_(user_group_ids), - ) + group_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.GROUP, + ProjectShare.subject_id.in_(user_group_ids), ) # Get org IDs user belongs to - user_org_ids = ( - select(OrgMembership.org_id) - .where(OrgMembership.user_id == user_id) - .scalar_subquery() - ) + user_org_ids = select(OrgMembership.org_id).where(OrgMembership.user_id == user_id).scalar_subquery() # Get project IDs from org shares (where user is org member) - org_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.ORG, - ProjectShare.subject_id.in_(user_org_ids), - ) + org_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.ORG, + ProjectShare.subject_id.in_(user_org_ids), ) # Combine with UNION and filter - accessible_project_ids = union( - direct_query, group_query, org_query - ).subquery() + accessible_project_ids = union(direct_query, group_query, org_query).subquery() query = query.where(Project.id.in_(select(accessible_project_ids.c.project_id))) # Order by name and limit @@ -1079,17 +1061,13 @@ async def list_platform_projects( name=row.name, org_id=row.org_id, org_name=row.org_name, - status=row.status.value if hasattr(row.status, 'value') else str(row.status), + status=row.status.value if hasattr(row.status, "value") else str(row.status), ) for row in rows ] # Get total count (without limit) - count_query = ( - select(func.count()) - .select_from(Project) - .where(Project.deleted_at.is_(None)) - ) + count_query = select(func.count()).select_from(Project).where(Project.deleted_at.is_(None)) if not include_archived: count_query = count_query.where(Project.status != ProjectStatus.ARCHIVED) if search: diff --git a/backend/app/routers/project_chat_images.py b/backend/app/routers/project_chat_images.py index ab2abde..52bfa53 100644 --- 
a/backend/app/routers/project_chat_images.py +++ b/backend/app/routers/project_chat_images.py @@ -1,6 +1,6 @@ """API routes for project-chat discussion image uploads.""" + import logging -from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status @@ -8,12 +8,12 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.database import get_db, get_async_db from app.auth.dependencies import get_current_user -from app.models import User, Project +from app.database import get_async_db, get_db +from app.models import User from app.models.job import JobType -from app.models.project_chat import ProjectChat, ProjectChatMessage -from app.permissions.context import get_org_context_from_project, get_org_context +from app.models.project_chat import ProjectChatMessage +from app.permissions.context import get_org_context, get_org_context_from_project from app.services.image_service import ImageService from app.services.job_service import JobService from app.services.kafka_producer import get_sync_kafka_producer @@ -76,8 +76,8 @@ async def upload_discussion_image_project_scoped( ) -> ImageUploadResponse: """Upload an image for a project-scoped project-chat discussion message.""" # Verify project access - from app.services.project_service import ProjectService from app.services.project_chat_service import ProjectChatService + from app.services.project_service import ProjectService project = ProjectService.get_by_identifier(db, project_id) if not project: @@ -192,7 +192,7 @@ def _find_discussion_image_id_by_s3_key(db: Session, s3_key: str) -> str | None: """Find image ID by S3 key in project-chat discussion messages.""" messages = db.query(ProjectChatMessage).filter(ProjectChatMessage.images.isnot(None)).all() for msg in messages: - for img in (msg.images or []): + for img in msg.images or []: if img.get("s3_key") == s3_key or img.get("thumbnail_s3_key") == s3_key: return img.get("id") return None @@ -247,7 +247,7 @@ def _find_image_metadata_by_id(db: Session, image_id: str) -> dict | None: """ messages = db.query(ProjectChatMessage).filter(ProjectChatMessage.images.isnot(None)).all() for msg in messages: - for img in (msg.images or []): + for img in msg.images or []: if img.get("id") == image_id: return img return None @@ -283,8 +283,8 @@ async def annotate_discussion_image_project_scoped( Job ID for tracking the annotation job """ # Verify project access - from app.services.project_service import ProjectService from app.services.project_chat_service import ProjectChatService + from app.services.project_service import ProjectService project = ProjectService.get_by_identifier(db, project_id) if not project: diff --git a/backend/app/routers/project_chats.py b/backend/app/routers/project_chats.py index 371baa1..b2340af 100644 --- a/backend/app/routers/project_chats.py +++ b/backend/app/routers/project_chats.py @@ -1,45 +1,44 @@ """API endpoints for pre-phase discussions.""" + import logging from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, status, Query +from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy.orm import Session +from starlette.responses import Response -from app.database import get_db from app.auth.dependencies import get_current_user -from app.models.user import User -from app.models.project import Project -from app.models.job import Job, JobType, JobStatus +from app.database import get_db +from 
app.models.job import Job, JobStatus, JobType from app.models.project_chat import ProjectChatMessage, ProjectChatMessageType, ProjectChatVisibility -from starlette.responses import Response -from app.services.project_chat_service import ProjectChatService -from app.services.job_service import JobService -from app.services.kafka_producer import get_kafka_producer -from app.services.project_share_service import ProjectShareService -from app.services.project_service import ProjectService +from app.models.user import User from app.schemas.project_chat import ( - CreateProjectChatRequest, - ProjectChatResponse, - ProjectChatWithMessages, - ProjectChatMessageResponse, - SendMessageRequest, - SendMessageResponse, - CreatePhaseFromProjectChatResponse, + CreatedFeatureInfo, CreateFeatureFromProjectChatRequest, CreateFeatureFromProjectChatResponse, - CreatedFeatureInfo, + CreatePhaseFromProjectChatResponse, + CreateProjectChatRequest, ProjectChatListItem, ProjectChatListResponse, + ProjectChatMessageResponse, + ProjectChatResponse, ProjectChatStartOverResponse, + ProjectChatWithMessages, + SendMessageRequest, + SendMessageResponse, UpdateVisibilityRequest, ) from app.schemas.thread_item import ( + Reaction, ToggleReactionRequest, ToggleReactionResponse, - Reaction, ) - +from app.services.job_service import JobService +from app.services.kafka_producer import get_kafka_producer +from app.services.project_chat_service import ProjectChatService +from app.services.project_service import ProjectService +from app.services.project_share_service import ProjectShareService logger = logging.getLogger(__name__) @@ -148,15 +147,11 @@ async def create_project_chat( if request and request.target_container_id: from app.services.phase_container_service import PhaseContainerService - container = PhaseContainerService.get_by_identifier( - db, request.target_container_id - ) + container = PhaseContainerService.get_by_identifier(db, request.target_container_id) if not container or container.project_id != project.id: raise HTTPException(status_code=404, detail="Container not found") if container.archived_at is not None: - raise HTTPException( - status_code=400, detail="Cannot target an archived container" - ) + raise HTTPException(status_code=400, detail="Cannot target an archived container") resolved_container_id = container.id discussion = ProjectChatService.create_project_chat( @@ -224,10 +219,7 @@ async def create_project_chat( detail="Failed to queue response generation job", ) - logger.info( - f"Created pre-phase discussion {discussion.id} with initial message, " - f"job {job.id} queued" - ) + logger.info(f"Created pre-phase discussion {discussion.id} with initial message, job {job.id} queued") # Update response with initial message info response.initial_message_id = message.id @@ -382,10 +374,7 @@ async def delete_project_chat( # Prevent deletion if a phase was created from this discussion if discussion.created_phase_id is not None: - raise HTTPException( - status_code=400, - detail="Cannot delete a discussion that has created a phase" - ) + raise HTTPException(status_code=400, detail="Cannot delete a discussion that has created a phase") ProjectChatService.delete_project_chat( db=db, @@ -425,8 +414,7 @@ async def send_message( # Check if discussion is readonly (phase already created) if discussion.is_readonly: raise HTTPException( - status_code=400, - detail="Cannot send messages to a discussion that has already created a phase" + status_code=400, detail="Cannot send messages to a discussion that has already 
created a phase" ) # Add user message to discussion first @@ -460,10 +448,7 @@ async def send_message( # Always queue a job - gating agent in worker decides if bot should respond # Check if already generating if discussion.is_generating: - raise HTTPException( - status_code=400, - detail="Bot is already generating a response" - ) + raise HTTPException(status_code=400, detail="Bot is already generating a response") # Create job job = Job( @@ -521,10 +506,7 @@ async def send_message( detail="Failed to queue response generation job", ) - logger.info( - f"Created pre-phase discussion response job {job.id} " - f"for discussion {discussion_id}" - ) + logger.info(f"Created pre-phase discussion response job {job.id} for discussion {discussion_id}") return SendMessageResponse( job_id=job.id, @@ -563,23 +545,16 @@ async def retry_message( # Check if discussion is readonly (phase already created) if discussion.is_readonly: raise HTTPException( - status_code=400, - detail="Cannot retry messages in a discussion that has already created a phase" + status_code=400, detail="Cannot retry messages in a discussion that has already created a phase" ) # Check if there's an error to retry if not discussion.ai_error_user_message: - raise HTTPException( - status_code=400, - detail="No failed message to retry" - ) + raise HTTPException(status_code=400, detail="No failed message to retry") # Check if already generating if discussion.is_generating: - raise HTTPException( - status_code=400, - detail="Bot is already generating a response" - ) + raise HTTPException(status_code=400, detail="Bot is already generating a response") # Get the user message that failed user_message = discussion.ai_error_user_message @@ -635,10 +610,7 @@ async def retry_message( detail="Failed to queue response generation job", ) - logger.info( - f"Created pre-phase discussion retry job {job.id} " - f"for discussion {discussion.id}" - ) + logger.info(f"Created pre-phase discussion retry job {job.id} for discussion {discussion.id}") # Note: We don't add a new user message for retry - we reuse the stored one # The message_id we return is a placeholder since we're not creating a new message @@ -684,10 +656,7 @@ async def cancel_discussion_response( # Check if discussion is readonly (phase already created) if discussion.is_readonly: - raise HTTPException( - status_code=400, - detail="Cannot cancel in a discussion that has already created a phase" - ) + raise HTTPException(status_code=400, detail="Cannot cancel in a discussion that has already created a phase") # Cancel any running jobs for this discussion cancelled = JobService.cancel_jobs_for_project_chat(db, discussion.id) @@ -740,8 +709,7 @@ async def cancel_discussion_response( ProjectChatService.broadcast_project_chat_updated(db, discussion) logger.info( - f"Cancelled discussion {discussion.id} response, " - f"jobs cancelled: {cancelled}, system message created" + f"Cancelled discussion {discussion.id} response, jobs cancelled: {cancelled}, system message created" ) return ProjectChatMessageResponse.from_message(system_message) @@ -787,7 +755,6 @@ async def create_phase_from_project_chat( # Trigger the initial MCQ generation job for the new phase # This follows the pattern from brainstorming_phases router - from app.services.brainstorming_phase_service import BrainstormingPhaseService job = Job( org_id=project.org_id, @@ -825,10 +792,7 @@ async def create_phase_from_project_chat( db.refresh(discussion) ProjectChatService.broadcast_project_chat_updated(db, discussion) - logger.info( - 
f"Created phase {phase.id} from discussion {discussion.id} " - f"and triggered MCQ generation" - ) + logger.info(f"Created phase {phase.id} from discussion {discussion.id} and triggered MCQ generation") return CreatePhaseFromProjectChatResponse( phase_id=phase.id, @@ -1024,6 +988,7 @@ async def delete_message( try: from uuid import UUID as PyUUID + ProjectChatService.delete_message( db=db, project_chat_id=discussion.id, @@ -1071,6 +1036,7 @@ async def start_over_from_message( try: from uuid import UUID as PyUUID + result = ProjectChatService.start_over_from_message( db=db, project_chat_id=discussion.id, @@ -1119,19 +1085,13 @@ async def update_visibility( # Only creator can change visibility if discussion.created_by != current_user.id: - raise HTTPException( - status_code=403, - detail="Only the creator can change visibility" - ) + raise HTTPException(status_code=403, detail="Only the creator can change visibility") # Parse visibility value try: visibility = ProjectChatVisibility(request.visibility) except ValueError: - raise HTTPException( - status_code=400, - detail="Invalid visibility value. Must be 'private' or 'team'" - ) + raise HTTPException(status_code=400, detail="Invalid visibility value. Must be 'private' or 'team'") # Update visibility try: @@ -1185,6 +1145,7 @@ async def toggle_reaction( try: from uuid import UUID as PyUUID + message, action = ProjectChatService.toggle_reaction( db=db, message_id=PyUUID(message_id), @@ -1216,6 +1177,7 @@ async def toggle_reaction( class TypingIndicatorRequest(BaseModel): """Schema for typing indicator request.""" + typing: bool = Field(..., description="True if user started typing, False if stopped") @@ -1238,7 +1200,9 @@ def send_typing_indicator( """ from app.services.typing_indicator_service import TypingIndicatorService - logger.info(f"[TypingIndicator] Project chat endpoint called: project={project_id}, discussion={discussion_id}, typing={request.typing}, user={current_user.id}") + logger.info( + f"[TypingIndicator] Project chat endpoint called: project={project_id}, discussion={discussion_id}, typing={request.typing}, user={current_user.id}" + ) # Look up project and verify user access project = ProjectService.get_by_identifier(db, project_id) diff --git a/backend/app/routers/project_repositories.py b/backend/app/routers/project_repositories.py index 7f4e349..4057e22 100644 --- a/backend/app/routers/project_repositories.py +++ b/backend/app/routers/project_repositories.py @@ -1,25 +1,25 @@ """Project repositories router for managing multi-repo support.""" + import logging -from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app.auth.dependencies import get_current_user -from app.database import get_db, get_async_db -from app.models import User, ProjectRole, JobType -from app.permissions.context import get_project_context, ProjectContext +from app.database import get_async_db, get_db +from app.models import JobType, ProjectRole, User +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role from app.schemas.project_repository import ( ProjectRepositoryCreate, - ProjectRepositoryUpdate, ProjectRepositoryResponse, + ProjectRepositoryUpdate, ReorderRepositoriesRequest, ) -from app.services.project_repository_service import ProjectRepositoryService from app.services.github_repo_service import GitHubRepoService from app.services.job_service import 
JobService +from app.services.project_repository_service import ProjectRepositoryService from workers.core.helpers import publish_job_to_kafka logger = logging.getLogger(__name__) @@ -49,15 +49,13 @@ async def create_project_repository( project = project_context.project # Validate the GitHub connector - from app.services.integration_service import IntegrationService - from app.permissions.helpers import role_rank from app.models.org_membership import OrgRole + from app.permissions.helpers import role_rank + from app.services.integration_service import IntegrationService from app.services.org_service import OrgService int_service = IntegrationService(async_db) - config = await int_service.get_config_by_id( - project.org_id, repo_data.github_integration_config_id - ) + config = await int_service.get_config_by_id(project.org_id, repo_data.github_integration_config_id) if not config: raise HTTPException( @@ -152,7 +150,7 @@ async def create_project_repository( except Exception as e: # Log but don't fail repo creation if grounding generation fails try: - if 'agents_md' in locals(): + if "agents_md" in locals(): GroundingService.set_generating_flag(db, agents_md, False) except Exception: pass @@ -198,9 +196,7 @@ def get_project_repository( """ require_project_role(project_context, ProjectRole.VIEWER) - repo = ProjectRepositoryService.get_repository_by_slug( - db, project_context.project.id, slug - ) + repo = ProjectRepositoryService.get_repository_by_slug(db, project_context.project.id, slug) if not repo: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -240,15 +236,13 @@ async def update_project_repository( # Validate new connector if provided if update_data.github_integration_config_id: - from app.services.integration_service import IntegrationService - from app.permissions.helpers import role_rank from app.models.org_membership import OrgRole + from app.permissions.helpers import role_rank + from app.services.integration_service import IntegrationService from app.services.org_service import OrgService int_service = IntegrationService(async_db) - config = await int_service.get_config_by_id( - project.org_id, update_data.github_integration_config_id - ) + config = await int_service.get_config_by_id(project.org_id, update_data.github_integration_config_id) if not config: raise HTTPException( @@ -308,9 +302,7 @@ def delete_project_repository( require_project_role(project_context, ProjectRole.ADMIN) # Get the repo first for broadcast - repo = ProjectRepositoryService.get_repository_by_slug( - db, project_context.project.id, slug - ) + repo = ProjectRepositoryService.get_repository_by_slug(db, project_context.project.id, slug) if not repo: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -320,9 +312,7 @@ def delete_project_repository( project = project_context.project # Delete the repository - deleted = ProjectRepositoryService.delete_repository( - db, project.id, slug - ) + deleted = ProjectRepositoryService.delete_repository(db, project.id, slug) if not deleted: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -359,7 +349,7 @@ def delete_project_repository( logger.info(f"Triggered grounding regeneration after repo deletion for project {project.id}") except Exception as e: try: - if 'agents_md' in locals(): + if "agents_md" in locals(): GroundingService.set_generating_flag(db, agents_md, False) except Exception: pass @@ -385,9 +375,7 @@ def reorder_project_repositories( require_project_role(project_context, ProjectRole.ADMIN) try: - repos = 
ProjectRepositoryService.reorder_repositories( - db, project_context.project.id, reorder_data.slug_order - ) + repos = ProjectRepositoryService.reorder_repositories(db, project_context.project.id, reorder_data.slug_order) except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/app/routers/project_shares.py b/backend/app/routers/project_shares.py index 280a89e..84320ec 100644 --- a/backend/app/routers/project_shares.py +++ b/backend/app/routers/project_shares.py @@ -14,7 +14,7 @@ from app.models.project_share import ProjectShare, ShareSubjectType from app.models.user_group import UserGroup from app.models.user_group_membership import UserGroupMembership -from app.permissions.context import get_project_context, ProjectContext +from app.permissions.context import ProjectContext, get_project_context from app.permissions.helpers import require_project_role from app.schemas.project_share import ( GroupSummary, @@ -66,9 +66,7 @@ def create_project_share( # Validate subject based on type if share_data.subject_type == ShareSubjectType.USER: # Verify user is a member of the org - org_membership = OrgService.get_org_membership( - db, org_id, share_data.subject_id - ) + org_membership = OrgService.get_org_membership(db, org_id, share_data.subject_id) if not org_membership: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -219,9 +217,7 @@ def delete_project_share( # Prevent removing the last owner if share.role == ProjectRole.OWNER: - all_shares = ProjectShareService.list_project_shares( - db, project_context.project.id - ) + all_shares = ProjectShareService.list_project_shares(db, project_context.project.id) owner_shares = [s for s in all_shares if s.role == ProjectRole.OWNER] if len(owner_shares) <= 1: raise HTTPException( @@ -232,16 +228,13 @@ def delete_project_share( # Remove the share ProjectShareService.remove_share(db, share_id) - logger.info( - f"Share {share_id} removed from project {project_id} " - f"by {current_user.email}" - ) + logger.info(f"Share {share_id} removed from project {project_id} by {current_user.email}") def _build_share_response(db: Session, share: ProjectShare) -> ProjectShareResponse: """Build a ProjectShareResponse with enriched subject data.""" - from app.models.organization import Organization from app.models.org_membership import OrgMembership + from app.models.organization import Organization user_summary = None group_summary = None diff --git a/backend/app/routers/projects.py b/backend/app/routers/projects.py index 6b76464..f83c38e 100644 --- a/backend/app/routers/projects.py +++ b/backend/app/routers/projects.py @@ -1,4 +1,6 @@ """Project router for CRUD operations on projects.""" + +import logging from typing import Optional from uuid import UUID @@ -7,35 +9,33 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, Project, ProjectType, ProjectStatus, OrgRole, ProjectRole, JobType -from app.permissions.context import get_org_context, OrgContext, get_project_context, ProjectContext +from app.models import JobType, OrgRole, ProjectRole, ProjectType, User +from app.permissions.context import OrgContext, ProjectContext, get_org_context, get_project_context from app.permissions.helpers import require_org_role, require_project_role +from app.schemas.bug_sync_history import BugSyncHistoryResponse from app.schemas.project import ( - ProjectCreate, - ProjectUpdate, - ProjectResponse, - FeatureCreate, BugfixCreate, - ProjectMemberCreate, - 
ProjectMembershipResponse, - ProjectMemberResponse, + CheckPrefixResponse, CloneProjectRequest, - LoadSampleProjectRequest, - OwnerInfo, - MyProjectRoleResponse, + FeatureCreate, GeneratePrefixRequest, GeneratePrefixResponse, - CheckPrefixResponse, + LoadSampleProjectRequest, + MyProjectRoleResponse, + OwnerInfo, ProjectAccessibleUserResponse, + ProjectCreate, + ProjectMemberCreate, + ProjectMemberResponse, + ProjectMembershipResponse, + ProjectResponse, + ProjectUpdate, ) -from app.schemas.bug_sync_history import BugSyncHistoryResponse from app.services.job_service import JobService from app.services.project_service import ProjectService from app.services.project_share_service import ProjectShareService from workers.core.helpers import publish_job_to_kafka -import logging - logger = logging.getLogger(__name__) router = APIRouter(tags=["projects"]) @@ -267,9 +267,7 @@ def list_all_accessible_projects( for project in projects: project_dict = ProjectResponse.model_validate(project).model_dump() project_dict["owner"] = OwnerInfo( - id=project.creator.id, - email=project.creator.email, - display_name=project.creator.display_name + id=project.creator.id, email=project.creator.email, display_name=project.creator.display_name ) result.append(ProjectResponse(**project_dict)) @@ -323,9 +321,7 @@ def list_org_projects( for project in projects: project_dict = ProjectResponse.model_validate(project).model_dump() project_dict["owner"] = OwnerInfo( - id=project.creator.id, - email=project.creator.email, - display_name=project.creator.display_name + id=project.creator.id, email=project.creator.email, display_name=project.creator.display_name ) result.append(ProjectResponse(**project_dict)) @@ -681,6 +677,7 @@ def create_bugfix( except Exception as e: # Log error but don't fail project creation import logging + logger = logging.getLogger(__name__) logger.error(f"Failed to auto-trigger bug sync for {bugfix.id}: {e}") @@ -773,9 +770,7 @@ def add_project_member( # Verify the user being added is a member of the org from app.services.org_service import OrgService - org_membership = OrgService.get_org_membership( - db, project_context.project.org_id, member_data.user_id - ) + org_membership = OrgService.get_org_membership(db, project_context.project.org_id, member_data.user_id) if not org_membership: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -987,9 +982,10 @@ def get_bug_sync_history( # Query sync history # Use project.id (UUID), not project_id (short URL identifier) - from app.models.bug_sync_history import BugSyncHistory from sqlalchemy import select + from app.models.bug_sync_history import BugSyncHistory + stmt = ( select(BugSyncHistory) .where(BugSyncHistory.project_id == project.id) diff --git a/backend/app/routers/team_roles.py b/backend/app/routers/team_roles.py index 17b8e56..c163423 100644 --- a/backend/app/routers/team_roles.py +++ b/backend/app/routers/team_roles.py @@ -7,20 +7,20 @@ from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, OrgRole, ProjectRole -from app.permissions.context import get_org_context, OrgContext, get_project_context, ProjectContext +from app.models import OrgRole, ProjectRole, User +from app.permissions.context import OrgContext, ProjectContext, get_org_context, get_project_context from app.permissions.helpers import require_org_role, require_project_role from app.schemas.team_role import ( - TeamRoleDefinitionCreate, - TeamRoleDefinitionUpdate, - TeamRoleDefinitionResponse, + 
AvailableUserResponse, ProjectTeamAssignmentCreate, ProjectTeamAssignmentResponse, ProjectTeamAssignmentWithUser, - ProjectTeamRoleGroup, ProjectTeamResponse, - AvailableUserResponse, + ProjectTeamRoleGroup, TeamMemberInfo, + TeamRoleDefinitionCreate, + TeamRoleDefinitionResponse, + TeamRoleDefinitionUpdate, ) from app.services.team_role_service import TeamRoleService @@ -272,9 +272,7 @@ def list_project_team_assignments( """ require_project_role(project_context, ProjectRole.VIEWER) - assignments = TeamRoleService.get_project_team_assignments( - db, project_context.project.id - ) + assignments = TeamRoleService.get_project_team_assignments(db, project_context.project.id) return [ ProjectTeamAssignmentWithUser( diff --git a/backend/app/routers/testing.py b/backend/app/routers/testing.py index 402a66c..bdca40d 100644 --- a/backend/app/routers/testing.py +++ b/backend/app/routers/testing.py @@ -3,7 +3,7 @@ Provides endpoints for loading test data and debugging features. """ -from typing import Annotated + from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status @@ -18,7 +18,6 @@ from app.schemas.testing import LoadFakeUsersResponse from app.services.user_service import UserService - router = APIRouter(prefix="/api/v1", tags=["testing"]) diff --git a/backend/app/routers/thread_images.py b/backend/app/routers/thread_images.py index df4ae2a..15092ab 100644 --- a/backend/app/routers/thread_images.py +++ b/backend/app/routers/thread_images.py @@ -1,16 +1,15 @@ """API routes for thread image uploads.""" + import logging -from typing import Optional from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status -from fastapi.responses import RedirectResponse from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.database import get_db, get_async_db from app.auth.dependencies import get_current_user -from app.models import User, Thread, Project +from app.database import get_async_db, get_db +from app.models import Project, Thread, User from app.models.thread_item import ThreadItem, ThreadItemType from app.permissions.context import get_org_context_from_project from app.services.image_service import ImageService @@ -107,11 +106,7 @@ async def upload_thread_image( def _find_image_id_by_s3_key(db: Session, s3_key: str) -> str | None: """Find image ID by S3 key in thread item images.""" - items = ( - db.query(ThreadItem) - .filter(ThreadItem.item_type == ThreadItemType.COMMENT) - .all() - ) + items = db.query(ThreadItem).filter(ThreadItem.item_type == ThreadItemType.COMMENT).all() for item in items: content_data = item.content_data or {} images = content_data.get("images", []) diff --git a/backend/app/routers/thread_items.py b/backend/app/routers/thread_items.py index e1259ac..ad97361 100644 --- a/backend/app/routers/thread_items.py +++ b/backend/app/routers/thread_items.py @@ -1,19 +1,20 @@ """API routes for thread items (comments, MCQ follow-ups, system messages).""" + import logging + from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.database import get_db from app.auth.dependencies import get_current_user -from app.models import User, Thread, OrgMembership +from app.database import get_db +from app.models import Thread, User from app.models.thread_item import ThreadItem from app.permissions.context import get_org_context_from_project from app.schemas.thread_item import ( CreateCommentItem, CreateMCQAnswer, - UpdateThreadItem, - 
StartOverResponse, DownstreamItemsResponse, + StartOverResponse, ToggleReactionRequest, ToggleReactionResponse, thread_item_to_response, @@ -201,10 +202,7 @@ def start_over_from_item( # Check if user is author (only authors can start over from their own items) is_author = str(item.author_id) == str(current_user.id) if not is_author: - raise HTTPException( - status_code=403, - detail="Only the author of this item can start over from here" - ) + raise HTTPException(status_code=403, detail="Only the author of this item can start over from here") # Start over from this item try: diff --git a/backend/app/routers/threads.py b/backend/app/routers/threads.py index 7acbed7..45d8583 100644 --- a/backend/app/routers/threads.py +++ b/backend/app/routers/threads.py @@ -1,25 +1,27 @@ """Thread and comment REST API endpoints.""" + import logging from typing import List, Optional from uuid import UUID + from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, Field from sqlalchemy.orm import Session, joinedload from app.auth.dependencies import get_current_user from app.database import get_db -from app.models import User, ContextType, Thread, Comment, ProjectRole -from app.models.thread_item import ThreadItem, ThreadItemType +from app.models import Comment, ContextType, ProjectRole, Thread, User from app.models.implementation import Implementation -from app.permissions.context import get_project_context, ProjectContext +from app.models.thread_item import ThreadItem, ThreadItemType +from app.permissions.context import ProjectContext, get_project_context from app.schemas.thread import ( - ThreadCreate, - ThreadUpdate, - ThreadResponse, - ThreadListResponse, CommentCreate, - CommentUpdate, CommentResponse, + CommentUpdate, + ThreadCreate, + ThreadListResponse, + ThreadResponse, + ThreadUpdate, ) from app.schemas.thread_item import thread_item_to_response from app.services.thread_service import ThreadService @@ -30,6 +32,7 @@ # Schema for creating version-anchored threads class VersionThreadCreate(BaseModel): """Schema for creating a thread on a draft version block.""" + block_id: str = Field(..., min_length=1, max_length=100) title: Optional[str] = Field(None, min_length=1, max_length=200) context_type: ContextType = ContextType.SPEC_DRAFT @@ -38,14 +41,18 @@ class VersionThreadCreate(BaseModel): # Schemas for @MFBTAI mention class AIMentionRequest(BaseModel): """Schema for @MFBTAI mention request.""" + message_text: str = Field(..., min_length=1, description="The message text containing @MFBTAI") feature_id: str = Field(..., description="Feature ID for context") additional_context: Optional[str] = Field(None, description="Optional additional context for the AI assistant") - trigger_type: Optional[str] = Field(None, description="Optional trigger type (e.g., 'feature_created', 'manual_start')") + trigger_type: Optional[str] = Field( + None, description="Optional trigger type (e.g., 'feature_created', 'manual_start')" + ) class SpecDraftAIMentionRequest(BaseModel): """Schema for @MFBTAI mention in spec/prompt plan draft discussions.""" + message_text: str = Field(..., min_length=1, description="The message text containing @MFBTAI") version_id: str = Field(..., description="Draft version ID for context") block_id: str = Field(..., description="Block ID being discussed") @@ -54,6 +61,7 @@ class SpecDraftAIMentionRequest(BaseModel): class AIMentionResponse(BaseModel): """Schema for @MFBTAI async response (job accepted).""" + job_id: str = Field(..., description="ID of the 
background job processing the AI mention") status: str = Field(default="accepted", description="Job acceptance status") message: str = Field(default="AI mention job queued successfully", description="Status message") @@ -62,6 +70,7 @@ class AIMentionResponse(BaseModel): # Schema for typing indicator class TypingIndicatorRequest(BaseModel): """Schema for typing indicator request.""" + typing: bool = Field(..., description="True if user started typing, False if stopped") @@ -114,7 +123,7 @@ def get_thread( .filter(Thread.id == thread_id) .options( joinedload(Thread.comments).joinedload(Comment.author), # Legacy - joinedload(Thread.items).joinedload(ThreadItem.author) + joinedload(Thread.items).joinedload(ThreadItem.author), ) .first() ) @@ -124,6 +133,7 @@ def get_thread( # Verify user has access to the project (check org membership) from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) # Build implementation map for implementation_created items @@ -131,16 +141,11 @@ def get_thread( impl_ids = [ item.content_data.get("implementation_id") for item in thread.items - if item.item_type == ThreadItemType.IMPLEMENTATION_CREATED - and item.content_data.get("implementation_id") + if item.item_type == ThreadItemType.IMPLEMENTATION_CREATED and item.content_data.get("implementation_id") ] impl_map = {} if impl_ids: - implementations = ( - db.query(Implementation) - .filter(Implementation.id.in_(impl_ids)) - .all() - ) + implementations = db.query(Implementation).filter(Implementation.id.in_(impl_ids)).all() impl_map = {str(impl.id): impl for impl in implementations} # Convert to response model with items @@ -164,6 +169,7 @@ def update_thread( # Verify user has access to the project from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) updated_thread = ThreadService.update_thread(db, thread_id, thread_update.title) @@ -184,6 +190,7 @@ def create_comment( raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) # Note: Proactive conversation triggers are now created from the decision summarizer @@ -208,6 +215,7 @@ def update_comment( """Update a comment's body.""" # Get comment to verify it exists and check thread access from app.models import Comment + comment = db.query(Comment).filter(Comment.id == comment_id).first() if not comment: raise HTTPException(status_code=404, detail=f"Comment {comment_id} not found") @@ -215,9 +223,10 @@ def update_comment( # Verify user has access to the thread's project thread = ThreadService.get_thread_by_id(db, comment.thread_id) if not thread: - raise HTTPException(status_code=404, detail=f"Thread not found") + raise HTTPException(status_code=404, detail="Thread not found") from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) updated_comment = ThreadService.update_comment(db, comment_id, comment_update.body_markdown) @@ -233,6 +242,7 @@ def delete_comment( """Delete a comment.""" # Get comment to verify it exists and check thread access from app.models import Comment + comment = db.query(Comment).filter(Comment.id == comment_id).first() if not comment: raise HTTPException(status_code=404, detail=f"Comment {comment_id} not found") @@ -240,9 +250,10 @@ def 
delete_comment( # Verify user has access to the thread's project thread = ThreadService.get_thread_by_id(db, comment.thread_id) if not thread: - raise HTTPException(status_code=404, detail=f"Thread not found") + raise HTTPException(status_code=404, detail="Thread not found") from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) ThreadService.delete_comment(db, comment_id) @@ -261,6 +272,7 @@ def approve_thread( # Verify user has access to the project from app.permissions.context import get_org_context_from_project + get_org_context_from_project(db, str(thread.project_id), current_user) # Toggle pending_approval flag @@ -301,11 +313,12 @@ async def invoke_ai_mention( - Risks / Watchouts """ import logging + + from app.models.job import JobType + from app.models.project import Project from app.permissions.context import get_org_context_from_project from app.services.job_service import JobService from app.services.kafka_producer import get_kafka_producer - from app.models.job import JobType - from app.models.project import Project logger = logging.getLogger(__name__) @@ -319,6 +332,7 @@ async def invoke_ai_mention( # Validate feature exists from app.models.feature import Feature + feature = db.query(Feature).filter(Feature.id == request.feature_id).first() if not feature: raise HTTPException(status_code=404, detail=f"Feature {request.feature_id} not found") @@ -373,10 +387,7 @@ async def invoke_ai_mention( thread.is_generating_ai_response = False db.commit() logger.error(f"Failed to publish AI mention job {job.id} to Kafka") - raise HTTPException( - status_code=500, - detail="Failed to queue AI mention job" - ) + raise HTTPException(status_code=500, detail="Failed to queue AI mention job") logger.info(f"AI mention job {job.id} queued for thread {thread_id}") @@ -500,9 +511,7 @@ def cancel_thread_ai_response( db.refresh(thread) ThreadService._broadcast_thread_update(db, thread, "thread_updated") - logger.info( - f"Cancelled thread {thread_id} jobs: {cancelled}" - ) + logger.info(f"Cancelled thread {thread_id} jobs: {cancelled}") # ============= Typing Indicator Endpoint ============= @@ -527,7 +536,9 @@ def send_typing_indicator( from app.permissions.context import get_org_context_from_project from app.services.typing_indicator_service import TypingIndicatorService - logger.info(f"[TypingIndicator] Endpoint called: thread={thread_id}, typing={request.typing}, user={current_user.id}") + logger.info( + f"[TypingIndicator] Endpoint called: thread={thread_id}, typing={request.typing}, user={current_user.id}" + ) # Get thread to check permissions thread = db.query(Thread).filter(Thread.id == thread_id).first() @@ -539,9 +550,10 @@ def send_typing_indicator( # Get org_id from the project from app.models.project import Project + project = db.query(Project).filter(Project.id == thread.project_id).first() if not project: - raise HTTPException(status_code=404, detail=f"Project not found") + raise HTTPException(status_code=404, detail="Project not found") # Get user's display name user_name = current_user.display_name or current_user.email @@ -740,11 +752,12 @@ async def invoke_spec_draft_ai_mention( - Broadcast to clients via WebSocket (thread_item_created event) """ import logging + + from app.models.job import JobType + from app.models.project import Project from app.permissions.context import get_org_context_from_project from app.services.job_service import JobService from app.services.kafka_producer import 
get_kafka_producer - from app.models.job import JobType - from app.models.project import Project logger = logging.getLogger(__name__) @@ -757,8 +770,8 @@ async def invoke_spec_draft_ai_mention( get_org_context_from_project(db, str(thread.project_id), current_user) # Validate version exists and is accessible - from app.services.draft_version_service import DraftVersionService from app.services.brainstorming_phase_service import BrainstormingPhaseService + from app.services.draft_version_service import DraftVersionService draft = DraftVersionService.get_draft(db=db, version_id=UUID(request.version_id)) if not draft: @@ -817,10 +830,7 @@ async def invoke_spec_draft_ai_mention( thread.is_generating_ai_response = False db.commit() logger.error(f"Failed to publish spec draft AI mention job {job.id} to Kafka") - raise HTTPException( - status_code=500, - detail="Failed to queue AI mention job" - ) + raise HTTPException(status_code=500, detail="Failed to queue AI mention job") logger.info(f"Spec draft AI mention job {job.id} queued for thread {thread_id}") diff --git a/backend/app/routers/user_groups.py b/backend/app/routers/user_groups.py index 2ab598e..a8a84cb 100644 --- a/backend/app/routers/user_groups.py +++ b/backend/app/routers/user_groups.py @@ -67,9 +67,7 @@ def create_group( created_by_user_id=org_context.user.id, description=request.description, ) - logger.info( - f"Group '{request.name}' created in org {org_id} by {org_context.user.email}" - ) + logger.info(f"Group '{request.name}' created in org {org_id} by {org_context.user.email}") return UserGroupResponse.from_group(group) except IntegrityError: db.rollback() @@ -196,9 +194,7 @@ def update_group( name=request.name, description=request.description, ) - logger.info( - f"Group {group_id} updated in org {org_id} by {org_context.user.email}" - ) + logger.info(f"Group {group_id} updated in org {org_id} by {org_context.user.email}") return UserGroupResponse.from_group(updated_group) except IntegrityError: db.rollback() @@ -244,9 +240,7 @@ def delete_group( ) UserGroupService.delete_group(db, group_id) - logger.info( - f"Group {group_id} deleted from org {org_id} by {org_context.user.email}" - ) + logger.info(f"Group {group_id} deleted from org {org_id} by {org_context.user.email}") @router.get( @@ -355,9 +349,7 @@ def add_group_member( group_id=group_id, user_id=request.user_id, ) - logger.info( - f"User {request.user_id} added to group {group_id} by {org_context.user.email}" - ) + logger.info(f"User {request.user_id} added to group {group_id} by {org_context.user.email}") return GroupMemberResponse.from_membership(membership) except IntegrityError: db.rollback() @@ -411,6 +403,4 @@ def remove_group_member( detail="User is not a member of this group", ) - logger.info( - f"User {user_id} removed from group {group_id} by {org_context.user.email}" - ) + logger.info(f"User {user_id} removed from group {group_id} by {org_context.user.email}") diff --git a/backend/app/routers/user_question_sessions.py b/backend/app/routers/user_question_sessions.py index c8e9b46..691306a 100644 --- a/backend/app/routers/user_question_sessions.py +++ b/backend/app/routers/user_question_sessions.py @@ -1,4 +1,5 @@ """API endpoints for user question sessions.""" + import logging from typing import List from uuid import UUID @@ -6,25 +7,24 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.database import get_db from app.auth.dependencies import get_current_user -from app.models.user import User +from 
app.database import get_db from app.models.brainstorming_phase import BrainstormingPhase -from app.models.job import Job, JobType, JobStatus -from app.services.user_question_session_service import UserQuestionSessionService -from app.services.job_service import JobService -from app.services.kafka_producer import get_kafka_producer -from app.services.project_share_service import ProjectShareService +from app.models.job import Job, JobStatus, JobType +from app.models.user import User from app.schemas.user_question_session import ( - UserQuestionSessionResponse, - UserQuestionSessionWithMessagesResponse, - UserQuestionMessageResponse, - GenerateQuestionsRequest, - GenerateQuestionsResponse, AddQuestionsRequest, AddQuestionsResponse, + GenerateQuestionsRequest, + GenerateQuestionsResponse, + UserQuestionMessageResponse, + UserQuestionSessionResponse, + UserQuestionSessionWithMessagesResponse, ) - +from app.services.job_service import JobService +from app.services.kafka_producer import get_kafka_producer +from app.services.project_share_service import ProjectShareService +from app.services.user_question_session_service import UserQuestionSessionService logger = logging.getLogger(__name__) @@ -174,15 +174,10 @@ async def generate_questions( # Check if session can add more questions if not session.can_add_more_questions: - raise HTTPException( - status_code=400, - detail="Session has reached the maximum number of questions (5)" - ) + raise HTTPException(status_code=400, detail="Session has reached the maximum number of questions (5)") # Get phase for org_id - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == session.brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == session.brainstorming_phase_id).first() if not phase: raise HTTPException(status_code=404, detail="Phase not found") @@ -238,10 +233,7 @@ async def generate_questions( detail="Failed to queue question generation job", ) - logger.info( - f"Created user-initiated question generation job {job.id} " - f"for session {session_id}" - ) + logger.info(f"Created user-initiated question generation job {job.id} for session {session_id}") return GenerateQuestionsResponse( job_id=str(job.id), diff --git a/backend/app/routers/websocket.py b/backend/app/routers/websocket.py index d1ca5d9..9310f4b 100644 --- a/backend/app/routers/websocket.py +++ b/backend/app/routers/websocket.py @@ -4,7 +4,7 @@ from typing import Annotated from uuid import UUID -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, status +from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect, status from jose import JWTError from sqlalchemy.orm import Session diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index 9ae2f8d..b07915f 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -1,61 +1,70 @@ """ Pydantic schemas for request/response validation. 
""" -from app.schemas.auth import UserCreate, UserLogin, UserResponse, TokenResponse -from app.schemas.org import OrgResponse, OrgMembershipResponse -from app.schemas.project import ( - ProjectCreate, - ProjectUpdate, - ProjectResponse, - FeatureCreate, - BugfixCreate, - ProjectMemberCreate, - ProjectMembershipResponse, - ProjectMemberResponse, -) -from app.schemas.job import JobResponse, JobDetailResponse -from app.schemas.spec import SpecVersionResponse, SpecGenerationStatusResponse, PromptPlanGenerationStatusResponse -from app.schemas.thread import ( - ThreadCreate, - ThreadUpdate, - ThreadResponse, - ThreadListResponse, - CommentCreate, - CommentUpdate, - CommentResponse, + +from app.schemas.activity import ( + ActivityListResponse, + ActivityLogResponse, ) +from app.schemas.auth import TokenResponse, UserCreate, UserLogin, UserResponse from app.schemas.brainstorming_phase import ( BrainstormingPhaseCreate, - BrainstormingPhaseUpdate, - BrainstormingPhaseResponse, BrainstormingPhaseListResponse, + BrainstormingPhaseResponse, + BrainstormingPhaseUpdate, ) -from app.schemas.module import ( - ModuleCreate, - ModuleUpdate, - ModuleResponse, - ModuleListResponse, +from app.schemas.draft_version import ( + DraftListResponse, + DraftVersionCreate, + DraftVersionResponse, ) from app.schemas.feature import ( FeatureCreate as Phase7FeatureCreate, - FeatureUpdate as Phase7FeatureUpdate, - FeatureResponse as Phase7FeatureResponse, +) +from app.schemas.feature import ( FeatureListResponse as Phase7FeatureListResponse, +) +from app.schemas.feature import ( + FeatureResponse as Phase7FeatureResponse, +) +from app.schemas.feature import ( FeatureRestoreRequest, ) -from app.schemas.draft_version import ( - DraftVersionCreate, - DraftVersionResponse, - DraftListResponse, +from app.schemas.feature import ( + FeatureUpdate as Phase7FeatureUpdate, ) from app.schemas.final_version import ( FinalizeRequest, - FinalSpecResponse, FinalPromptPlanResponse, + FinalSpecResponse, ) -from app.schemas.activity import ( - ActivityLogResponse, - ActivityListResponse, +from app.schemas.job import JobDetailResponse, JobResponse +from app.schemas.module import ( + ModuleCreate, + ModuleListResponse, + ModuleResponse, + ModuleUpdate, +) +from app.schemas.org import OrgMembershipResponse, OrgResponse +from app.schemas.project import ( + BugfixCreate, + FeatureCreate, + ProjectCreate, + ProjectMemberCreate, + ProjectMemberResponse, + ProjectMembershipResponse, + ProjectResponse, + ProjectUpdate, +) +from app.schemas.spec import PromptPlanGenerationStatusResponse, SpecGenerationStatusResponse, SpecVersionResponse +from app.schemas.thread import ( + CommentCreate, + CommentResponse, + CommentUpdate, + ThreadCreate, + ThreadListResponse, + ThreadResponse, + ThreadUpdate, ) __all__ = [ diff --git a/backend/app/schemas/activity.py b/backend/app/schemas/activity.py index a18116a..575e76a 100644 --- a/backend/app/schemas/activity.py +++ b/backend/app/schemas/activity.py @@ -1,6 +1,7 @@ """Schemas for activity logs.""" + from datetime import datetime -from typing import Optional, Any, List +from typing import Any, List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, field_validator @@ -19,7 +20,7 @@ class ActivityLogResponse(BaseModel): event_metadata: Optional[dict] = None created_at: datetime - @field_validator('id', 'entity_id', 'actor_id', mode='before') + @field_validator("id", "entity_id", "actor_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert 
UUID to string before validation.""" diff --git a/backend/app/schemas/analytics.py b/backend/app/schemas/analytics.py index d14e430..89db2c9 100644 --- a/backend/app/schemas/analytics.py +++ b/backend/app/schemas/analytics.py @@ -35,9 +35,7 @@ class TopUserEntry(BaseModel): total_credits: float = Field(description="Total credits (tokens / 100k)") total_cost_usd: Optional[float] = Field(default=None, description="Total cost in USD") call_count: int = Field(description="Number of LLM calls made") - percentage_of_total: float = Field( - description="Percentage of total platform usage in the time range" - ) + percentage_of_total: float = Field(description="Percentage of total platform usage in the time range") class TopProjectEntry(BaseModel): @@ -53,9 +51,7 @@ class TopProjectEntry(BaseModel): total_credits: float = Field(description="Total credits (tokens / 100k)") total_cost_usd: Optional[float] = Field(default=None, description="Total cost in USD") call_count: int = Field(description="Number of LLM calls made") - percentage_of_total: float = Field( - description="Percentage of total platform usage in the time range" - ) + percentage_of_total: float = Field(description="Percentage of total platform usage in the time range") class EfficiencyMetrics(BaseModel): @@ -64,9 +60,7 @@ class EfficiencyMetrics(BaseModel): model_config = ConfigDict(from_attributes=True) avg_tokens_per_call: float = Field(description="Average tokens per LLM call") - avg_cost_per_call: Optional[float] = Field( - default=None, description="Average cost per LLM call in USD" - ) + avg_cost_per_call: Optional[float] = Field(default=None, description="Average cost per LLM call in USD") total_calls: int = Field(description="Total number of LLM calls") total_tokens: int = Field(description="Total tokens consumed") total_credits: float = Field(description="Total credits (tokens / 100k)") @@ -81,16 +75,10 @@ class TopUsersResponse(BaseModel): time_range: TimeRange = Field(description="The time range used for the query") start_date: date = Field(description="Start date of the query range") end_date: date = Field(description="End date of the query range") - org_id: Optional[UUID] = Field( - default=None, description="Organization filter (null for platform-wide)" - ) + org_id: Optional[UUID] = Field(default=None, description="Organization filter (null for platform-wide)") users: list[TopUserEntry] = Field(description="List of top users by usage") - total_platform_tokens: int = Field( - description="Total tokens across all users in the time range" - ) - total_platform_credits: float = Field( - description="Total credits across all users in the time range" - ) + total_platform_tokens: int = Field(description="Total tokens across all users in the time range") + total_platform_credits: float = Field(description="Total credits across all users in the time range") class TopProjectsResponse(BaseModel): @@ -101,16 +89,10 @@ class TopProjectsResponse(BaseModel): time_range: TimeRange = Field(description="The time range used for the query") start_date: date = Field(description="Start date of the query range") end_date: date = Field(description="End date of the query range") - org_id: Optional[UUID] = Field( - default=None, description="Organization filter (null for platform-wide)" - ) + org_id: Optional[UUID] = Field(default=None, description="Organization filter (null for platform-wide)") projects: list[TopProjectEntry] = Field(description="List of top projects by usage") - total_platform_tokens: int = Field( - description="Total tokens 
across all projects in the time range" - ) - total_platform_credits: float = Field( - description="Total credits across all projects in the time range" - ) + total_platform_tokens: int = Field(description="Total tokens across all projects in the time range") + total_platform_credits: float = Field(description="Total credits across all projects in the time range") class EfficiencyOverviewResponse(BaseModel): @@ -121,9 +103,7 @@ class EfficiencyOverviewResponse(BaseModel): time_range: TimeRange = Field(description="The time range used for the query") start_date: date = Field(description="Start date of the query range") end_date: date = Field(description="End date of the query range") - org_id: Optional[UUID] = Field( - default=None, description="Organization filter (null for platform-wide)" - ) + org_id: Optional[UUID] = Field(default=None, description="Organization filter (null for platform-wide)") metrics: EfficiencyMetrics = Field(description="Efficiency metrics") @@ -151,25 +131,13 @@ class PlanEfficiencyResponse(BaseModel): description="Usage efficiency as percentage of plan limit (None for unlimited plans)", ) tokens_used: int = Field(description="Total tokens consumed in the period") - tokens_limit: Optional[int] = Field( - default=None, description="Plan token limit (None for unlimited)" - ) - tokens_remaining: Optional[int] = Field( - default=None, description="Tokens remaining in plan (None for unlimited)" - ) - is_over_limit: bool = Field( - default=False, description="Whether usage exceeds the plan limit" - ) - period_start: Optional[date] = Field( - default=None, description="Start of current billing/tracking period" - ) - period_end: Optional[date] = Field( - default=None, description="End of current billing/tracking period" - ) + tokens_limit: Optional[int] = Field(default=None, description="Plan token limit (None for unlimited)") + tokens_remaining: Optional[int] = Field(default=None, description="Tokens remaining in plan (None for unlimited)") + is_over_limit: bool = Field(default=False, description="Whether usage exceeds the plan limit") + period_start: Optional[date] = Field(default=None, description="Start of current billing/tracking period") + period_end: Optional[date] = Field(default=None, description="End of current billing/tracking period") credits_used: float = Field(description="Credits consumed (tokens / 100k)") - credits_limit: Optional[float] = Field( - default=None, description="Plan limit in credits (None for unlimited)" - ) + credits_limit: Optional[float] = Field(default=None, description="Plan limit in credits (None for unlimited)") # MFBT-053: Org Efficiency Overview Schemas @@ -193,16 +161,10 @@ class EfficiencyScore(BaseModel): default=None, description="Usage efficiency as percentage of plan limit" ) plan_type: str = Field(description="Type of plan (monthly, lifetime, unlimited)") - threshold_color: Optional[ThresholdColor] = Field( - default=None, description="Color-coded threshold indicator" - ) - recommendation_text: Optional[str] = Field( - default=None, description="Recommendation based on efficiency level" - ) + threshold_color: Optional[ThresholdColor] = Field(default=None, description="Color-coded threshold indicator") + recommendation_text: Optional[str] = Field(default=None, description="Recommendation based on efficiency level") tokens_used: int = Field(description="Total tokens consumed in the period") - tokens_limit: Optional[int] = Field( - default=None, description="Plan token limit (None for unlimited)" - ) + tokens_limit: Optional[int] 
= Field(default=None, description="Plan token limit (None for unlimited)") class UserEfficiencyEntry(BaseModel): @@ -214,9 +176,7 @@ class UserEfficiencyEntry(BaseModel): email: str = Field(description="User's email address") display_name: Optional[str] = Field(default=None, description="User's display name") tokens_used: int = Field(description="Total tokens consumed by this user") - percentage_of_org: float = Field( - description="Percentage of organization's total usage" - ) + percentage_of_org: float = Field(description="Percentage of organization's total usage") class ProjectEfficiencyEntry(BaseModel): @@ -227,9 +187,7 @@ class ProjectEfficiencyEntry(BaseModel): project_id: UUID = Field(description="Project's unique identifier") project_name: str = Field(description="Project's name") tokens_used: int = Field(description="Total tokens consumed by this project") - percentage_of_org: float = Field( - description="Percentage of organization's total usage" - ) + percentage_of_org: float = Field(description="Percentage of organization's total usage") class OrganizationEfficiency(BaseModel): @@ -240,12 +198,8 @@ class OrganizationEfficiency(BaseModel): org_id: UUID = Field(description="Organization's unique identifier") org_name: str = Field(description="Organization's name") efficiency: EfficiencyScore = Field(description="Efficiency score and metrics") - users: list[UserEfficiencyEntry] = Field( - default_factory=list, description="Top users by usage" - ) - projects: list[ProjectEfficiencyEntry] = Field( - default_factory=list, description="Top projects by usage" - ) + users: list[UserEfficiencyEntry] = Field(default_factory=list, description="Top users by usage") + projects: list[ProjectEfficiencyEntry] = Field(default_factory=list, description="Top projects by usage") class OrgEfficiencyResponse(BaseModel): @@ -256,6 +210,4 @@ class OrgEfficiencyResponse(BaseModel): time_range: TimeRange = Field(description="The time range used for the query") start_date: date = Field(description="Start date of the query range") end_date: date = Field(description="End date of the query range") - organizations: list[OrganizationEfficiency] = Field( - description="List of organizations with efficiency metrics" - ) + organizations: list[OrganizationEfficiency] = Field(description="List of organizations with efficiency metrics") diff --git a/backend/app/schemas/api_key.py b/backend/app/schemas/api_key.py index f16d756..ab03aea 100644 --- a/backend/app/schemas/api_key.py +++ b/backend/app/schemas/api_key.py @@ -21,8 +21,12 @@ class ApiKeyResponse(BaseModel): last_used_at: datetime | None revoked: bool key_preview: str = Field(..., description="Partial key for identification (e.g., 'mfbtsk-...cb29f34f')") - api_key: str | None = Field(None, description="Full API key (available for keys created after encryption was enabled)") - is_legacy: bool = Field(False, description="True if key was created before encryption was enabled (cannot be viewed)") + api_key: str | None = Field( + None, description="Full API key (available for keys created after encryption was enabled)" + ) + is_legacy: bool = Field( + False, description="True if key was created before encryption was enabled (cannot be viewed)" + ) model_config = {"from_attributes": True} diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index a67599d..2812d0b 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -4,6 +4,7 @@ These schemas are used for request validation and response serialization in the authentication 
endpoints. """ + from datetime import datetime from typing import Optional from uuid import UUID @@ -123,6 +124,7 @@ class SwitchOrgRequest(BaseModel): # Email Verification Schemas + class RegistrationResponse(BaseModel): """Schema for user registration response (with email verification status).""" diff --git a/backend/app/schemas/brainstorming_phase.py b/backend/app/schemas/brainstorming_phase.py index 59426c2..674d0e8 100644 --- a/backend/app/schemas/brainstorming_phase.py +++ b/backend/app/schemas/brainstorming_phase.py @@ -1,6 +1,7 @@ """Schemas for brainstorming phases.""" + from datetime import datetime -from typing import Optional, Any +from typing import Any, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -16,7 +17,7 @@ class BrainstormingPhaseCreate(BaseModel): description: str = Field( ..., min_length=10, - description="Detailed description of what to brainstorm. Required for AI-generated aspects and questions." + description="Detailed description of what to brainstorm. Required for AI-generated aspects and questions.", ) @@ -56,7 +57,7 @@ class BrainstormingPhaseResponse(BaseModel): phase_subtype: Optional[str] = None container_sequence: Optional[int] = None - @field_validator('id', 'project_id', 'created_by', 'container_id', mode='before') + @field_validator("id", "project_id", "created_by", "container_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -97,7 +98,7 @@ class BrainstormingPhaseListResponse(BaseModel): phase_subtype: Optional[str] = None container_sequence: Optional[int] = None - @field_validator('id', 'project_id', 'container_id', mode='before') + @field_validator("id", "project_id", "container_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -164,7 +165,7 @@ class ModuleProgressResponse(BaseModel): progress_percent: float = 0.0 next_feature: Optional[str] = None - @field_validator('module_id', mode='before') + @field_validator("module_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: if isinstance(value, UUID): @@ -184,7 +185,7 @@ class PhaseImplementationProgressResponse(BaseModel): next_feature: Optional[str] = None modules: list[ModuleProgressResponse] = [] - @field_validator('phase_id', mode='before') + @field_validator("phase_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: if isinstance(value, UUID): diff --git a/backend/app/schemas/bug_sync_history.py b/backend/app/schemas/bug_sync_history.py index 03f2782..2168b3b 100644 --- a/backend/app/schemas/bug_sync_history.py +++ b/backend/app/schemas/bug_sync_history.py @@ -1,4 +1,5 @@ """Bug sync history schemas.""" + from datetime import datetime from typing import Any from uuid import UUID @@ -15,8 +16,6 @@ class BugSyncHistoryResponse(BaseModel): project_id: UUID synced_at: datetime status: str = Field(..., description="Sync status (success/error)") - imported_data_json: dict[str, Any] | None = Field( - None, description="Imported ticket data" - ) + imported_data_json: dict[str, Any] | None = Field(None, description="Imported ticket data") error_message: str | None = Field(None, description="Error message if failed") triggered_by: str = Field(..., description="Who triggered sync (system/user/agent)") diff --git a/backend/app/schemas/dashboard.py b/backend/app/schemas/dashboard.py index 0162be1..089276f 100644 --- 
a/backend/app/schemas/dashboard.py +++ b/backend/app/schemas/dashboard.py @@ -1,8 +1,9 @@ """ Pydantic schemas for Dashboard API responses. """ + from datetime import datetime -from typing import Optional, List +from typing import List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict @@ -35,6 +36,7 @@ class PlanInfo(BaseModel): class RecentLLMCall(BaseModel): """Summary of a recent LLM call for dashboard display.""" + model_config = ConfigDict(from_attributes=True) id: UUID diff --git a/backend/app/schemas/draft_version.py b/backend/app/schemas/draft_version.py index 802eb07..eff8834 100644 --- a/backend/app/schemas/draft_version.py +++ b/backend/app/schemas/draft_version.py @@ -1,6 +1,7 @@ """Schemas for draft versions.""" + from datetime import datetime -from typing import Optional, Any, List +from typing import Any, List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -30,7 +31,7 @@ class DraftVersionResponse(BaseModel): created_by: str created_at: datetime - @field_validator('id', 'brainstorming_phase_id', 'created_by', mode='before') + @field_validator("id", "brainstorming_phase_id", "created_by", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -51,7 +52,7 @@ class DraftListResponse(BaseModel): created_by: str created_at: datetime - @field_validator('id', 'brainstorming_phase_id', 'created_by', mode='before') + @field_validator("id", "brainstorming_phase_id", "created_by", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" diff --git a/backend/app/schemas/email_template.py b/backend/app/schemas/email_template.py index cf1c384..68c0899 100644 --- a/backend/app/schemas/email_template.py +++ b/backend/app/schemas/email_template.py @@ -1,4 +1,5 @@ """Email template schemas.""" + from datetime import datetime from typing import Any from uuid import UUID diff --git a/backend/app/schemas/feature.py b/backend/app/schemas/feature.py index 3546b7d..454e466 100644 --- a/backend/app/schemas/feature.py +++ b/backend/app/schemas/feature.py @@ -1,12 +1,13 @@ """Schemas for features.""" + from datetime import datetime from enum import Enum -from typing import Optional, Any, List +from typing import Any, List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from app.models.feature import FeatureProvenance, FeatureStatus, FeaturePriority, FeatureType, FeatureCompletionStatus +from app.models.feature import FeatureCompletionStatus, FeaturePriority, FeatureProvenance, FeatureStatus, FeatureType class FeatureSortField(str, Enum): @@ -117,7 +118,7 @@ class FeatureResponse(BaseModel): short_id: str url_identifier: str - @field_validator('id', 'module_id', 'created_by', 'completed_by_id', mode='before') + @field_validator("id", "module_id", "created_by", "completed_by_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string before validation.""" @@ -177,7 +178,7 @@ class FeatureListResponse(BaseModel): short_id: str url_identifier: str - @field_validator('id', 'module_id', mode='before') + @field_validator("id", "module_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -230,6 +231,7 @@ def from_feature( # Import-related schemas # 
============================================ + class IssueSearchRequest(BaseModel): """Schema for searching external issues.""" @@ -279,7 +281,7 @@ class FeatureImportCommentResponse(BaseModel): source_created_at: datetime order_index: int - @field_validator('id', 'feature_id', mode='before') + @field_validator("id", "feature_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -362,8 +364,8 @@ class FeatureSidebarData(BaseModel): has_prompt_plan: bool = False has_notes: bool = False - @model_validator(mode='after') - def fallback_decision_summary_short(self) -> 'FeatureSidebarData': + @model_validator(mode="after") + def fallback_decision_summary_short(self) -> "FeatureSidebarData": """Fallback decision_summary_short to decision_summary if not set.""" if self.decision_summary_short is None and self.decision_summary is not None: self.decision_summary_short = self.decision_summary diff --git a/backend/app/schemas/feature_content_version.py b/backend/app/schemas/feature_content_version.py index 344aafb..25b46e9 100644 --- a/backend/app/schemas/feature_content_version.py +++ b/backend/app/schemas/feature_content_version.py @@ -1,6 +1,7 @@ """Schemas for feature content versions.""" + from datetime import datetime -from typing import Optional, Any +from typing import Any, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator diff --git a/backend/app/schemas/final_version.py b/backend/app/schemas/final_version.py index 029d95d..d688c7d 100644 --- a/backend/app/schemas/final_version.py +++ b/backend/app/schemas/final_version.py @@ -1,6 +1,7 @@ """Schemas for final versions (specs and prompt plans).""" + from datetime import datetime -from typing import Optional, Any +from typing import Any, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, field_validator @@ -26,7 +27,7 @@ class FinalSpecResponse(BaseModel): generated_at: datetime created_by: str - @field_validator('id', 'brainstorming_phase_id', 'created_by', mode='before') + @field_validator("id", "brainstorming_phase_id", "created_by", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -34,7 +35,7 @@ def convert_uuid_to_str(cls, value: Any) -> str: return str(value) return value - @field_validator('generated_from_version_id', mode='before') + @field_validator("generated_from_version_id", mode="before") @classmethod def convert_optional_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string, allowing None.""" @@ -58,7 +59,7 @@ class FinalPromptPlanResponse(BaseModel): generated_at: datetime created_by: str - @field_validator('id', 'brainstorming_phase_id', 'created_by', mode='before') + @field_validator("id", "brainstorming_phase_id", "created_by", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" @@ -66,7 +67,7 @@ def convert_uuid_to_str(cls, value: Any) -> str: return str(value) return value - @field_validator('generated_from_version_id', mode='before') + @field_validator("generated_from_version_id", mode="before") @classmethod def convert_optional_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string, allowing None.""" diff --git a/backend/app/schemas/form_draft.py b/backend/app/schemas/form_draft.py index 37df1f5..d8e471e 100644 --- a/backend/app/schemas/form_draft.py +++ b/backend/app/schemas/form_draft.py 
@@ -12,9 +12,7 @@
 class FormDraftUpsertRequest(BaseModel):
     """Schema for creating or updating a form draft."""

-    id: Optional[str] = Field(
-        None, description="If provided, updates existing draft; otherwise creates new"
-    )
+    id: Optional[str] = Field(None, description="If provided, updates existing draft; otherwise creates new")
     draft_type: FormDraftType
     name: str = Field(..., min_length=1, max_length=255)
     content: Dict[str, Any]
diff --git a/backend/app/schemas/grounding_note.py b/backend/app/schemas/grounding_note.py
index 5e7d869..2747353 100644
--- a/backend/app/schemas/grounding_note.py
+++ b/backend/app/schemas/grounding_note.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from typing import Optional

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict


 class GroundingNoteCreate(BaseModel):
@@ -33,11 +33,7 @@ def from_version(cls, version) -> "GroundingNoteVersionListResponse":
             id=str(version.id),
             version=version.version,
             edit_source=version.edit_source,
-            creator_display_name=(
-                version.creator.display_name or version.creator.email
-                if version.creator
-                else None
-            ),
+            creator_display_name=(version.creator.display_name or version.creator.email if version.creator else None),
             created_at=version.created_at,
         )

@@ -66,10 +62,6 @@ def from_version(cls, version) -> "GroundingNoteVersionResponse":
             content_markdown=version.content_markdown,
             is_active=version.is_active,
             edit_source=version.edit_source,
-            creator_display_name=(
-                version.creator.display_name or version.creator.email
-                if version.creator
-                else None
-            ),
+            creator_display_name=(version.creator.display_name or version.creator.email if version.creator else None),
             created_at=version.created_at,
         )
diff --git a/backend/app/schemas/identity.py b/backend/app/schemas/identity.py
index e68a0f0..649626f 100644
--- a/backend/app/schemas/identity.py
+++ b/backend/app/schemas/identity.py
@@ -3,6 +3,7 @@
 Schemas for identity providers and user identities used in authentication.
""" + from datetime import datetime from typing import Optional from uuid import UUID diff --git a/backend/app/schemas/implementation.py b/backend/app/schemas/implementation.py index 2a6084a..df74de6 100644 --- a/backend/app/schemas/implementation.py +++ b/backend/app/schemas/implementation.py @@ -1,7 +1,7 @@ """Schemas for implementations.""" + from datetime import datetime from typing import Optional -from uuid import UUID from pydantic import BaseModel, ConfigDict, Field diff --git a/backend/app/schemas/inbox_badge.py b/backend/app/schemas/inbox_badge.py index eaf784a..f6d239f 100644 --- a/backend/app/schemas/inbox_badge.py +++ b/backend/app/schemas/inbox_badge.py @@ -1,4 +1,5 @@ """Schemas for inbox badge count aggregation.""" + from pydantic import BaseModel diff --git a/backend/app/schemas/inbox_conversation.py b/backend/app/schemas/inbox_conversation.py index 348e4a1..a5ecf5d 100644 --- a/backend/app/schemas/inbox_conversation.py +++ b/backend/app/schemas/inbox_conversation.py @@ -1,4 +1,5 @@ """Schemas for inbox conversation aggregation API.""" + import enum from datetime import datetime from typing import Optional diff --git a/backend/app/schemas/inbox_deep_link.py b/backend/app/schemas/inbox_deep_link.py index 90616ce..bcd3efa 100644 --- a/backend/app/schemas/inbox_deep_link.py +++ b/backend/app/schemas/inbox_deep_link.py @@ -1,4 +1,5 @@ """Schemas for inbox deep link resolution API.""" + from typing import Optional from pydantic import BaseModel diff --git a/backend/app/schemas/inbox_event.py b/backend/app/schemas/inbox_event.py index 46ff687..7db43de 100644 --- a/backend/app/schemas/inbox_event.py +++ b/backend/app/schemas/inbox_event.py @@ -1,7 +1,9 @@ """Schema definitions for inbox WebSocket events.""" + from datetime import datetime from enum import Enum from typing import List, Optional + from pydantic import BaseModel diff --git a/backend/app/schemas/inbox_follow.py b/backend/app/schemas/inbox_follow.py index c0a2e19..5f04dc9 100644 --- a/backend/app/schemas/inbox_follow.py +++ b/backend/app/schemas/inbox_follow.py @@ -8,7 +8,6 @@ from app.models.inbox_follow import InboxFollowType, InboxThreadType - # Request Schemas @@ -16,9 +15,7 @@ class FollowThreadRequest(BaseModel): """Request body for following a thread.""" thread_id: str = Field(..., description="Thread ID to follow") - thread_type: InboxThreadType = Field( - ..., description="Type: feature, phase, project_chat" - ) + thread_type: InboxThreadType = Field(..., description="Type: feature, phase, project_chat") # Response Schemas diff --git a/backend/app/schemas/integration_config.py b/backend/app/schemas/integration_config.py index 2592e5d..426bbf3 100644 --- a/backend/app/schemas/integration_config.py +++ b/backend/app/schemas/integration_config.py @@ -1,4 +1,5 @@ """Integration config schemas.""" + from datetime import datetime from typing import Any, Optional from uuid import UUID @@ -7,7 +8,7 @@ from app.models.integration_config import IntegrationVisibility from app.models.integration_config_share import IntegrationShareSubjectType -from app.schemas.project_share import UserSummary, GroupSummary +from app.schemas.project_share import GroupSummary, UserSummary class IntegrationConfigBase(BaseModel): @@ -17,12 +18,8 @@ class IntegrationConfigBase(BaseModel): ..., description="Provider name (github, jira, anthropic, openai, azure-openai, google-vertex, deepseek, aws-bedrock)", ) - display_name: str = Field( - ..., description="Display name for this integration (e.g., 'Production OpenAI')" - ) - config_json: 
dict[str, Any] | None = Field( - None, description="Optional provider-specific configuration" - ) + display_name: str = Field(..., description="Display name for this integration (e.g., 'Production OpenAI')") + config_json: dict[str, Any] | None = Field(None, description="Optional provider-specific configuration") class IntegrationConfigCreate(IntegrationConfigBase): @@ -40,12 +37,8 @@ class IntegrationConfigUpdate(BaseModel): display_name: str | None = Field(None, description="New display name") token: str | None = Field(None, description="New authentication token") - config_json: dict[str, Any] | None = Field( - None, description="Updated configuration" - ) - visibility: IntegrationVisibility | None = Field( - None, description="New visibility level" - ) + config_json: dict[str, Any] | None = Field(None, description="Updated configuration") + visibility: IntegrationVisibility | None = Field(None, description="New visibility level") class IntegrationConfigResponse(IntegrationConfigBase): diff --git a/backend/app/schemas/invitation.py b/backend/app/schemas/invitation.py index 7e9c49c..416dfeb 100644 --- a/backend/app/schemas/invitation.py +++ b/backend/app/schemas/invitation.py @@ -1,12 +1,10 @@ """Invitation schemas for request/response validation.""" from datetime import datetime -from typing import Optional from uuid import UUID from pydantic import BaseModel, Field, field_validator -from app.models.org_invitation import InvitationStatus from app.models.org_membership import OrgRole diff --git a/backend/app/schemas/llm_call_log.py b/backend/app/schemas/llm_call_log.py index 370a168..9f3df95 100644 --- a/backend/app/schemas/llm_call_log.py +++ b/backend/app/schemas/llm_call_log.py @@ -1,8 +1,9 @@ """ Pydantic schemas for LLM Call Log API responses. """ + from datetime import datetime -from typing import Optional, List +from typing import List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict @@ -10,6 +11,7 @@ class LLMCallLogSummary(BaseModel): """Summary view for list display (excludes full request/response).""" + model_config = ConfigDict(from_attributes=True) id: UUID @@ -26,6 +28,7 @@ class LLMCallLogSummary(BaseModel): class LLMCallLogDetail(BaseModel): """Full detail view including request/response.""" + model_config = ConfigDict(from_attributes=True) id: UUID @@ -64,6 +67,7 @@ class LLMCallLogDetail(BaseModel): class JobWithCallLogs(BaseModel): """Job summary with its call logs for Agent Log page.""" + model_config = ConfigDict(from_attributes=True) # Job info @@ -95,6 +99,7 @@ class JobWithCallLogs(BaseModel): class AgentLogListResponse(BaseModel): """Response for listing agent logs.""" + items: List[JobWithCallLogs] total: int limit: int diff --git a/backend/app/schemas/llm_preference.py b/backend/app/schemas/llm_preference.py index dee0b36..f84e167 100644 --- a/backend/app/schemas/llm_preference.py +++ b/backend/app/schemas/llm_preference.py @@ -1,4 +1,5 @@ """LLM Preference schemas.""" + from datetime import datetime from typing import Literal from uuid import UUID diff --git a/backend/app/schemas/mcp_call_log.py b/backend/app/schemas/mcp_call_log.py index 1411587..67289c6 100644 --- a/backend/app/schemas/mcp_call_log.py +++ b/backend/app/schemas/mcp_call_log.py @@ -1,8 +1,9 @@ """ Pydantic schemas for MCP Call Log API responses. 
""" + from datetime import datetime -from typing import Optional, List, Any +from typing import Any, List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict @@ -10,6 +11,7 @@ class MCPCallLogSummary(BaseModel): """Summary view for list display (excludes full request/response).""" + model_config = ConfigDict(from_attributes=True) id: UUID @@ -27,6 +29,7 @@ class MCPCallLogSummary(BaseModel): class MCPCallLogDetail(BaseModel): """Full detail view including request/response.""" + model_config = ConfigDict(from_attributes=True) id: UUID @@ -56,6 +59,7 @@ class MCPCallLogDetail(BaseModel): class MCPLogListResponse(BaseModel): """Response for listing MCP logs.""" + items: List[MCPCallLogSummary] total: int limit: int diff --git a/backend/app/schemas/module.py b/backend/app/schemas/module.py index 12361fa..70f16e4 100644 --- a/backend/app/schemas/module.py +++ b/backend/app/schemas/module.py @@ -1,6 +1,7 @@ """Schemas for modules.""" + from datetime import datetime -from typing import Optional, Any +from typing import Any, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -47,7 +48,7 @@ class ModuleResponse(BaseModel): short_id: str url_identifier: str - @field_validator('id', 'project_id', 'brainstorming_phase_id', 'created_by', mode='before') + @field_validator("id", "project_id", "brainstorming_phase_id", "created_by", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string before validation.""" @@ -80,7 +81,7 @@ class ModuleListResponse(BaseModel): short_id: str url_identifier: str - @field_validator('id', 'project_id', 'brainstorming_phase_id', mode='before') + @field_validator("id", "project_id", "brainstorming_phase_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string before validation.""" @@ -102,7 +103,7 @@ class ModuleArchiveResponse(BaseModel): archived_at: datetime archived_features_count: int - @field_validator('id', 'project_id', mode='before') + @field_validator("id", "project_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string before validation.""" @@ -110,4 +111,4 @@ def convert_uuid_to_str(cls, value: Any) -> Optional[str]: return None if isinstance(value, UUID): return str(value) - return value \ No newline at end of file + return value diff --git a/backend/app/schemas/notification.py b/backend/app/schemas/notification.py index c097f3e..4b767ea 100644 --- a/backend/app/schemas/notification.py +++ b/backend/app/schemas/notification.py @@ -1,16 +1,19 @@ """ Notification schemas for request/response validation. 
""" + +from datetime import datetime +from typing import Optional from uuid import UUID + from pydantic import BaseModel, Field -from typing import Optional -from datetime import datetime from app.models import NotificationChannel class NotificationPreferenceBase(BaseModel): """Base notification preference schema.""" + channel: NotificationChannel enabled: bool channel_config: Optional[str] = None @@ -18,17 +21,20 @@ class NotificationPreferenceBase(BaseModel): class NotificationPreferenceCreate(NotificationPreferenceBase): """Schema for creating a notification preference.""" + pass class NotificationPreferenceUpdate(BaseModel): """Schema for updating a notification preference.""" + enabled: Optional[bool] = None channel_config: Optional[str] = None class NotificationPreferenceResponse(NotificationPreferenceBase): """Schema for notification preference responses.""" + id: UUID user_id: UUID created_at: datetime @@ -40,11 +46,13 @@ class Config: class NotificationProjectMuteCreate(BaseModel): """Schema for creating a project mute.""" + project_id: UUID class NotificationProjectMuteResponse(BaseModel): """Schema for project mute responses.""" + id: UUID user_id: UUID project_id: UUID @@ -56,11 +64,13 @@ class Config: class NotificationThreadWatchCreate(BaseModel): """Schema for creating a thread watch.""" + thread_id: str class NotificationThreadWatchResponse(BaseModel): """Schema for thread watch responses.""" + id: UUID user_id: UUID thread_id: str @@ -72,6 +82,7 @@ class Config: class NotificationEventCreate(BaseModel): """Schema for creating a notification event.""" + event_type: str = Field(..., description="Event type (e.g., 'thread_comment_added')") project_id: UUID related_entity_id: Optional[str] = None diff --git a/backend/app/schemas/oauth.py b/backend/app/schemas/oauth.py index 26edd1e..9452690 100644 --- a/backend/app/schemas/oauth.py +++ b/backend/app/schemas/oauth.py @@ -3,6 +3,7 @@ Schemas for normalized user information from OAuth providers. """ + from typing import Any from pydantic import BaseModel, Field diff --git a/backend/app/schemas/org.py b/backend/app/schemas/org.py index ab415b5..18dec1a 100644 --- a/backend/app/schemas/org.py +++ b/backend/app/schemas/org.py @@ -4,6 +4,7 @@ These schemas are used for request validation and response serialization in the organization endpoints. 
""" + from datetime import datetime from typing import Optional from uuid import UUID @@ -57,9 +58,7 @@ class OrgMemberResponse(BaseModel): user_id: UUID = Field(..., description="User ID") email: str = Field(..., description="User email address") display_name: Optional[str] = Field(None, description="User display name") - role: str = Field( - ..., description="Role in organization (owner, admin, member, viewer)" - ) + role: str = Field(..., description="Role in organization (owner, admin, member, viewer)") joined_at: datetime = Field(..., description="When user joined organization") diff --git a/backend/app/schemas/phase_container.py b/backend/app/schemas/phase_container.py index a347a9e..248cc29 100644 --- a/backend/app/schemas/phase_container.py +++ b/backend/app/schemas/phase_container.py @@ -1,6 +1,7 @@ """Schemas for phase containers.""" + from datetime import datetime -from typing import Optional, Any, List +from typing import Any, List, Optional from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -36,7 +37,7 @@ class PhaseContainerResponse(BaseModel): created_at: datetime updated_at: datetime - @field_validator('id', 'project_id', 'archived_by_id', mode='before') + @field_validator("id", "project_id", "archived_by_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> Optional[str]: """Convert UUID to string before validation.""" @@ -59,7 +60,7 @@ class PhaseContainerListResponse(BaseModel): archived_at: Optional[datetime] = None created_at: datetime - @field_validator('id', 'project_id', mode='before') + @field_validator("id", "project_id", mode="before") @classmethod def convert_uuid_to_str(cls, value: Any) -> str: """Convert UUID to string before validation.""" diff --git a/backend/app/schemas/plan_recommendation.py b/backend/app/schemas/plan_recommendation.py index fa1112f..e1df2f6 100644 --- a/backend/app/schemas/plan_recommendation.py +++ b/backend/app/schemas/plan_recommendation.py @@ -20,22 +20,12 @@ class EfficiencyStreakInfo(BaseModel): model_config = ConfigDict(from_attributes=True) has_streak: bool = Field(description="Whether a qualifying streak was found") - action: Optional[RecommendationAction] = Field( - default=None, description="Recommended action if streak qualifies" - ) + action: Optional[RecommendationAction] = Field(default=None, description="Recommended action if streak qualifies") consecutive_days: int = Field(default=0, description="Number of consecutive days in streak") - avg_efficiency_percent: float = Field( - default=0.0, description="Average efficiency during streak" - ) - streak_start_date: Optional[date] = Field( - default=None, description="Start date of the streak" - ) - streak_end_date: Optional[date] = Field( - default=None, description="End date of the streak" - ) - avg_daily_tokens_used: int = Field( - default=0, description="Average daily tokens used during streak" - ) + avg_efficiency_percent: float = Field(default=0.0, description="Average efficiency during streak") + streak_start_date: Optional[date] = Field(default=None, description="Start date of the streak") + streak_end_date: Optional[date] = Field(default=None, description="End date of the streak") + avg_daily_tokens_used: int = Field(default=0, description="Average daily tokens used during streak") class PlanRecommendationBase(BaseModel): @@ -47,13 +37,9 @@ class PlanRecommendationBase(BaseModel): avg_efficiency_percent: float = Field(description="Average efficiency percentage") streak_start_date: date = 
Field(description="Start of the streak") streak_end_date: date = Field(description="End of the streak") - tokens_allocated_at_recommendation: int = Field( - description="Token allocation when recommendation was made" - ) + tokens_allocated_at_recommendation: int = Field(description="Token allocation when recommendation was made") avg_daily_tokens_used: int = Field(description="Average daily token usage") - recommended_allocation: Optional[int] = Field( - default=None, description="Suggested new allocation" - ) + recommended_allocation: Optional[int] = Field(default=None, description="Suggested new allocation") savings_or_increase_percent: Optional[float] = Field( default=None, description="Estimated savings/increase percentage" ) @@ -67,12 +53,8 @@ class PlanRecommendationResponse(PlanRecommendationBase): id: UUID = Field(description="Recommendation ID") org_id: UUID = Field(description="Organization ID") status: RecommendationStatus = Field(description="Current status") - dismissed_at: Optional[datetime] = Field( - default=None, description="When dismissed" - ) - dismissed_by_user_id: Optional[UUID] = Field( - default=None, description="User who dismissed" - ) + dismissed_at: Optional[datetime] = Field(default=None, description="When dismissed") + dismissed_by_user_id: Optional[UUID] = Field(default=None, description="User who dismissed") created_at: datetime = Field(description="When created") updated_at: datetime = Field(description="When last updated") @@ -80,9 +62,7 @@ class PlanRecommendationResponse(PlanRecommendationBase): class DismissRecommendationRequest(BaseModel): """Request to dismiss a recommendation.""" - reason: Optional[str] = Field( - default=None, description="Optional reason for dismissal" - ) + reason: Optional[str] = Field(default=None, description="Optional reason for dismissal") class EvaluateRecommendationResponse(BaseModel): @@ -92,15 +72,11 @@ class EvaluateRecommendationResponse(BaseModel): org_id: UUID = Field(description="Organization ID") org_name: str = Field(description="Organization name") - recommendation_created: bool = Field( - description="Whether a new recommendation was created" - ) + recommendation_created: bool = Field(description="Whether a new recommendation was created") recommendation: Optional[PlanRecommendationResponse] = Field( default=None, description="The recommendation if created or already exists" ) - streak_info: EfficiencyStreakInfo = Field( - description="Information about the analyzed efficiency streak" - ) + streak_info: EfficiencyStreakInfo = Field(description="Information about the analyzed efficiency streak") class ProcessAllOrgsResponse(BaseModel): @@ -109,9 +85,5 @@ class ProcessAllOrgsResponse(BaseModel): model_config = ConfigDict(from_attributes=True) orgs_processed: int = Field(description="Number of organizations processed") - recommendations_created: int = Field( - description="Number of new recommendations created" - ) - recommendations_expired: int = Field( - description="Number of recommendations that were expired" - ) + recommendations_created: int = Field(description="Number of new recommendations created") + recommendations_expired: int = Field(description="Number of recommendations that were expired") diff --git a/backend/app/schemas/platform_settings.py b/backend/app/schemas/platform_settings.py index 62e76f7..c68a507 100644 --- a/backend/app/schemas/platform_settings.py +++ b/backend/app/schemas/platform_settings.py @@ -1,11 +1,11 @@ """Platform settings schemas.""" + from datetime import datetime from typing 
import Any, Literal from uuid import UUID from pydantic import BaseModel, ConfigDict, Field - # ==================== Platform Connector Schemas ==================== PlatformConnectorTypeStr = Literal["llm", "email", "object_storage", "code_explorer", "web_search"] @@ -14,16 +14,12 @@ class PlatformConnectorBase(BaseModel): """Base platform connector schema.""" - connector_type: PlatformConnectorTypeStr = Field( - ..., description="Type of connector (llm, email, object_storage)" - ) + connector_type: PlatformConnectorTypeStr = Field(..., description="Type of connector (llm, email, object_storage)") provider: str = Field( ..., description="Provider name (anthropic, openai, sendgrid, aws-s3, etc.)", ) - display_name: str = Field( - ..., description="Human-readable name for this connector" - ) + display_name: str = Field(..., description="Human-readable name for this connector") config_json: dict[str, Any] | None = Field( None, description="Provider-specific configuration (model, region, from_email, etc.)" ) @@ -32,9 +28,7 @@ class PlatformConnectorBase(BaseModel): class PlatformConnectorCreate(PlatformConnectorBase): """Schema for creating a platform connector.""" - credentials: str = Field( - ..., description="API key or credentials (will be encrypted)" - ) + credentials: str = Field(..., description="API key or credentials (will be encrypted)") class PlatformConnectorUpdate(BaseModel): @@ -42,9 +36,7 @@ class PlatformConnectorUpdate(BaseModel): display_name: str | None = Field(None, description="New display name") credentials: str | None = Field(None, description="New credentials (will be encrypted)") - config_json: dict[str, Any] | None = Field( - None, description="Updated configuration" - ) + config_json: dict[str, Any] | None = Field(None, description="Updated configuration") is_active: bool | None = Field(None, description="Active status") @@ -66,14 +58,10 @@ class PlatformConnectorResponse(BaseModel): class PlatformConnectorTest(BaseModel): """Schema for testing a connector before creating.""" - connector_type: PlatformConnectorTypeStr = Field( - ..., description="Type of connector" - ) + connector_type: PlatformConnectorTypeStr = Field(..., description="Type of connector") provider: str = Field(..., description="Provider name") credentials: str = Field(..., description="API key or credentials to test") - config_json: dict[str, Any] | None = Field( - None, description="Provider-specific configuration" - ) + config_json: dict[str, Any] | None = Field(None, description="Provider-specific configuration") class PlatformConnectorTestResult(BaseModel): @@ -85,6 +73,7 @@ class PlatformConnectorTestResult(BaseModel): # ==================== Platform Settings Schemas ==================== + class PlatformSettingsResponse(BaseModel): """Schema for platform settings response.""" @@ -110,65 +99,32 @@ class PlatformSettingsResponse(BaseModel): class PlatformSettingsUpdate(BaseModel): """Schema for updating platform settings.""" - main_llm_connector_id: UUID | None = Field( - None, description="ID of main (heavy) LLM connector" - ) - lightweight_llm_connector_id: UUID | None = Field( - None, description="ID of lightweight LLM connector" - ) - email_connector_id: UUID | None = Field( - None, description="ID of email connector (Sendgrid)" - ) - object_storage_connector_id: UUID | None = Field( - None, description="ID of object storage connector (S3)" - ) + main_llm_connector_id: UUID | None = Field(None, description="ID of main (heavy) LLM connector") + lightweight_llm_connector_id: UUID | None = 
Field(None, description="ID of lightweight LLM connector") + email_connector_id: UUID | None = Field(None, description="ID of email connector (Sendgrid)") + object_storage_connector_id: UUID | None = Field(None, description="ID of object storage connector (S3)") code_explorer_connector_id: UUID | None = Field( None, description="ID of code explorer connector (Anthropic API for Claude Code)" ) - code_explorer_enabled: bool | None = Field( - None, description="Enable code explorer feature" - ) - base_url: str | None = Field( - None, description="Base URL for the platform (e.g., https://mfbt.example.com)" - ) - mock_discovery_enabled: bool | None = Field( - None, description="Enable mock discovery mode" - ) - mock_discovery_question_limit: int | None = Field( - None, description="Question limit for mock discovery" - ) - mock_discovery_delay_seconds: int | None = Field( - None, description="Delay in seconds for mock discovery" - ) + code_explorer_enabled: bool | None = Field(None, description="Enable code explorer feature") + base_url: str | None = Field(None, description="Base URL for the platform (e.g., https://mfbt.example.com)") + mock_discovery_enabled: bool | None = Field(None, description="Enable mock discovery mode") + mock_discovery_question_limit: int | None = Field(None, description="Question limit for mock discovery") + mock_discovery_delay_seconds: int | None = Field(None, description="Delay in seconds for mock discovery") # Flags to explicitly clear connector IDs - clear_main_llm: bool = Field( - False, description="Set to true to clear main LLM connector" - ) - clear_lightweight_llm: bool = Field( - False, description="Set to true to clear lightweight LLM connector" - ) - clear_email: bool = Field( - False, description="Set to true to clear email connector" - ) - clear_object_storage: bool = Field( - False, description="Set to true to clear object storage connector" - ) - clear_code_explorer: bool = Field( - False, description="Set to true to clear code explorer connector" - ) - web_search_connector_id: UUID | None = Field( - None, description="ID of web search connector (Tavily)" - ) - web_search_enabled: bool | None = Field( - None, description="Enable web search feature" - ) - clear_web_search: bool = Field( - False, description="Set to true to clear web search connector" - ) + clear_main_llm: bool = Field(False, description="Set to true to clear main LLM connector") + clear_lightweight_llm: bool = Field(False, description="Set to true to clear lightweight LLM connector") + clear_email: bool = Field(False, description="Set to true to clear email connector") + clear_object_storage: bool = Field(False, description="Set to true to clear object storage connector") + clear_code_explorer: bool = Field(False, description="Set to true to clear code explorer connector") + web_search_connector_id: UUID | None = Field(None, description="ID of web search connector (Tavily)") + web_search_enabled: bool | None = Field(None, description="Enable web search feature") + clear_web_search: bool = Field(False, description="Set to true to clear web search connector") # ==================== Platform Admin Check Schema ==================== + class PlatformAdminCheckResponse(BaseModel): """Schema for platform admin check response.""" @@ -177,6 +133,7 @@ class PlatformAdminCheckResponse(BaseModel): # ==================== Test Email Schema ==================== + class SendTestEmailRequest(BaseModel): """Schema for sending a test email.""" @@ -192,18 +149,13 @@ class SendTestEmailResponse(BaseModel): # 
==================== Email Environment Config Schema ====================
+
 class EmailEnvConfigResponse(BaseModel):
     """Schema for email environment config check response."""

-    configured: bool = Field(
-        ..., description="Whether email env vars are configured with non-empty values"
-    )
-    from_email: str | None = Field(
-        None, description="From email address (if configured)"
-    )
-    from_name: str | None = Field(
-        None, description="From name (if configured)"
-    )
+    configured: bool = Field(..., description="Whether email env vars are configured with non-empty values")
+    from_email: str | None = Field(None, description="From email address (if configured)")
+    from_name: str | None = Field(None, description="From name (if configured)")


 # ==================== User Plan Management Schemas ====================
@@ -274,24 +226,12 @@ class UserPlanUpdateRequest(BaseModel):
     """Request to update user/org plan values."""

     # Optional fields - only update what's provided
-    plan_end_date: datetime | None = Field(
-        None, description="Extend trial by setting new end date"
-    )
-    plan_llm_tokens_total: int | None = Field(
-        None, description="Set total token limit"
-    )
-    plan_llm_tokens_per_month: int | None = Field(
-        None, description="Set monthly token limit"
-    )
-    plan_max_projects: int | None = Field(
-        None, description="Set max projects limit"
-    )
-    plan_max_users: int | None = Field(
-        None, description="Set max users limit"
-    )
-    reset_token_usage: bool = Field(
-        False, description="Reset plan_llm_tokens_used to 0"
-    )
+    plan_end_date: datetime | None = Field(None, description="Extend trial by setting new end date")
+    plan_llm_tokens_total: int | None = Field(None, description="Set total token limit")
+    plan_llm_tokens_per_month: int | None = Field(None, description="Set monthly token limit")
+    plan_max_projects: int | None = Field(None, description="Set max projects limit")
+    plan_max_users: int | None = Field(None, description="Set max users limit")
+    reset_token_usage: bool = Field(False, description="Reset plan_llm_tokens_used to 0")


 class UserPlanUpdateResponse(BaseModel):
@@ -308,29 +248,19 @@ class UserPlanUpdateResponse(BaseModel):
 class FreemiumSettingsResponse(BaseModel):
     """Response containing freemium plan configuration."""

-    freemium_initial_tokens: int = Field(
-        description="Initial tokens granted to new users on signup"
-    )
-    freemium_weekly_topup_tokens: int = Field(
-        description="Tokens added each Monday (additive, up to max)"
-    )
-    freemium_max_tokens: int = Field(
-        description="Maximum token balance for freemium users"
-    )
+    freemium_initial_tokens: int = Field(description="Initial tokens granted to new users on signup")
+    freemium_weekly_topup_tokens: int = Field(description="Tokens added each Monday (additive, up to max)")
+    freemium_max_tokens: int = Field(description="Maximum token balance for freemium users")


 class FreemiumSettingsUpdate(BaseModel):
     """Request to update freemium plan configuration."""

-    freemium_initial_tokens: int | None = Field(
-        None, ge=0, description="Initial tokens granted to new users on signup"
-    )
+    freemium_initial_tokens: int | None = Field(None, ge=0, description="Initial tokens granted to new users on signup")
     freemium_weekly_topup_tokens: int | None = Field(
         None, ge=0, description="Tokens added each Monday (additive, up to max)"
     )
-    freemium_max_tokens: int | None = Field(
-        None, ge=0, description="Maximum token balance for freemium users"
-    )
+    freemium_max_tokens: int | None = Field(None, ge=0, description="Maximum token balance for freemium users")


 # ==================== Web Search Environment Config Schema ====================
@@ -339,9 +269,7 @@ class FreemiumSettingsUpdate(BaseModel):
 class WebSearchEnvConfigResponse(BaseModel):
     """Schema for web search environment config check response."""

-    configured: bool = Field(
-        ..., description="Whether web search env vars are configured with non-empty values"
-    )
+    configured: bool = Field(..., description="Whether web search env vars are configured with non-empty values")


 # ==================== GitHub OAuth Settings Schemas ====================
@@ -350,37 +278,23 @@ class WebSearchEnvConfigResponse(BaseModel):
 class GitHubOAuthEnvConfigResponse(BaseModel):
     """Schema for GitHub OAuth environment config check response."""

-    configured: bool = Field(
-        ..., description="Whether GitHub OAuth env vars are configured with non-empty values"
-    )
+    configured: bool = Field(..., description="Whether GitHub OAuth env vars are configured with non-empty values")


 class GitHubOAuthSettingsResponse(BaseModel):
     """Schema for GitHub OAuth settings status response."""

-    has_client_id: bool = Field(
-        ..., description="Whether a client ID is configured (UI or ENV)"
-    )
-    has_client_secret: bool = Field(
-        ..., description="Whether a client secret is configured (UI or ENV)"
-    )
-    is_fully_configured: bool = Field(
-        ..., description="Whether both client ID and secret are configured"
-    )
-    source: str | None = Field(
-        None, description="Configuration source: 'ui', 'env', or None if not configured"
-    )
+    has_client_id: bool = Field(..., description="Whether a client ID is configured (UI or ENV)")
+    has_client_secret: bool = Field(..., description="Whether a client secret is configured (UI or ENV)")
+    is_fully_configured: bool = Field(..., description="Whether both client ID and secret are configured")
+    source: str | None = Field(None, description="Configuration source: 'ui', 'env', or None if not configured")


 class GitHubOAuthSettingsUpdate(BaseModel):
     """Schema for updating GitHub OAuth settings."""

-    client_id: str | None = Field(
-        None, description="GitHub OAuth App Client ID"
-    )
-    client_secret: str | None = Field(
-        None, description="GitHub OAuth App Client Secret"
-    )
+    client_id: str | None = Field(None, description="GitHub OAuth App Client ID")
+    client_secret: str | None = Field(None, description="GitHub OAuth App Client Secret")
     clear_credentials: bool = Field(
         False, description="Set to true to clear all GitHub OAuth credentials from UI config"
     )
diff --git a/backend/app/schemas/project.py b/backend/app/schemas/project.py
index 4938422..0230edb 100644
--- a/backend/app/schemas/project.py
+++ b/backend/app/schemas/project.py
@@ -1,4 +1,5 @@
 """Project schemas."""
+
 from datetime import datetime
 from enum import Enum
 from typing import Optional
@@ -6,7 +7,7 @@
 from pydantic import BaseModel, ConfigDict, Field

-from app.models.project import ProjectType, ProjectStatus
+from app.models.project import ProjectStatus, ProjectType
 from app.models.project_membership import ProjectRole
@@ -26,7 +27,9 @@ class ProjectCreate(BaseModel):
     short_description: Optional[str] = None
     idea_text: Optional[str] = None
     key: Optional[str] = Field(None, max_length=100)
-    project_tech_stack: Optional[str] = Field(None, max_length=20, description="High-level tech stack type (web_fullstack, mobile, desktop, etc.)")
+    project_tech_stack: Optional[str] = Field(
+        None, max_length=20, description="High-level tech stack type (web_fullstack, mobile, desktop, etc.)"
+    )


 class ProjectUpdate(BaseModel):
diff --git a/backend/app/schemas/project_chat.py b/backend/app/schemas/project_chat.py
index 1d129e3..34671ec 100644
--- a/backend/app/schemas/project_chat.py
+++ b/backend/app/schemas/project_chat.py
@@ -1,4 +1,5 @@
 """Schemas for project chat API endpoints."""
+
 from datetime import datetime
 from typing import List, Optional
 from uuid import UUID
@@ -10,12 +11,14 @@ class MCQOption(BaseModel):
     """Schema for an MCQ option in bot responses."""
+
     id: str
     text: str


 class MCQAnswerContext(BaseModel):
     """MCQ answer context stored with user messages when answering an MCQ."""
+
     question_text: str
     question_message_id: str
     choices: List[MCQOption]
@@ -25,6 +28,7 @@ class MCQAnswerContext(BaseModel):
 class ProjectChatMessageAuthor(BaseModel):
     """Author information for project chat messages."""
+
     id: str
     email: str
     display_name: str
@@ -38,6 +42,7 @@ def from_user(cls, user) -> "ProjectChatMessageAuthor":
 class ProjectChatMessageResponse(BaseModel):
     """Response schema for a project chat message."""
+
     id: UUID
     project_chat_id: UUID
     message_type: str
@@ -61,9 +66,7 @@ def from_message(cls, message) -> "ProjectChatMessageResponse":
         # Parse reactions from response_data
         reactions = []
         if message.response_data and "reactions" in message.response_data:
-            reactions = [
-                Reaction(**r) for r in message.response_data["reactions"]
-            ]
+            reactions = [Reaction(**r) for r in message.response_data["reactions"]]

         return cls(
             id=message.id,
@@ -88,6 +91,7 @@ def mcq_options(self) -> Optional[List[MCQOption]]:
 class ProjectChatResponse(BaseModel):
     """Response schema for a project chat."""
+
     id: UUID
     org_id: UUID
     project_id: Optional[UUID] = None  # Nullable for org-scoped project chats
@@ -141,18 +145,21 @@ class ProjectChatResponse(BaseModel):
 class ProjectChatWithMessages(BaseModel):
     """Response schema for a project chat with all messages."""
+
     project_chat: ProjectChatResponse
     messages: List[ProjectChatMessageResponse]


 class CreateProjectChatRequest(BaseModel):
     """Request schema for creating a project chat."""
+
     initial_message_content: Optional[str] = None  # Auto-post this message after creation
     target_container_id: Optional[str] = None  # Target container for extension creation


 class SendMessageRequest(BaseModel):
     """Request schema for sending a user message."""
+
     content: str = Field(..., min_length=1, max_length=32000)
     images: Optional[List[ImageAttachment]] = Field(None, max_length=10)
     mcq_answer: Optional[MCQAnswerContext] = None
@@ -160,17 +167,20 @@ class SendMessageResponse(BaseModel):
     """Response schema for send message endpoint."""
+
     job_id: Optional[UUID] = None
     message_id: UUID


 class CreatePhaseFromProjectChatRequest(BaseModel):
     """Request schema for creating a phase from project chat."""
+
     pass  # Uses proposed_title and proposed_description from project chat


 class CreatePhaseFromProjectChatResponse(BaseModel):
     """Response schema for create phase from project chat endpoint."""
+
     phase_id: UUID
     phase_title: str
     phase_short_id: Optional[str] = None
@@ -180,6 +190,7 @@ class CreateFeatureFromProjectChatRequest(BaseModel):
     """Request schema for creating a feature from project chat."""
+
     # Optional overrides - if not provided, use proposed values from project chat
     module_id: Optional[str] = None  # Use existing module (overrides proposal)
     new_module_title: Optional[str] = None  # Create new module (overrides proposal)
@@ -188,6 +199,7 @@ class CreateFeatureFromProjectChatResponse(BaseModel):
     """Response schema for create feature from project chat endpoint."""
+
     feature_id: UUID
     feature_key: str
     feature_title: str
@@ -199,6 +211,7 @@ class CreatedFeatureInfo(BaseModel):
     """Info about a feature created from this project chat."""
+
     id: UUID
     feature_key: str
     title: str
@@ -208,6 +221,7 @@ class ProjectChatListItem(BaseModel):
     """Summary schema for project chat list (sidebar)."""
+
     id: UUID
     chat_title: Optional[str] = None
     proposed_title: Optional[str] = None
@@ -227,6 +241,7 @@ class ProjectChatListResponse(BaseModel):
     """Paginated list response for project chats."""
+
     project_chats: List[ProjectChatListItem]
     total: int
     has_more: bool
@@ -234,8 +249,10 @@
 # Org-scoped project chat schemas

+
 class OrgProjectChatListItem(BaseModel):
     """Summary schema for org-scoped project chat list (sidebar)."""
+
     id: UUID
     chat_title: Optional[str] = None
     proposed_title: Optional[str] = None
@@ -253,6 +270,7 @@ class OrgProjectChatListResponse(BaseModel):
     """Paginated list response for org-scoped project chats."""
+
     project_chats: List[OrgProjectChatListItem]
     total: int
     has_more: bool
@@ -260,6 +278,7 @@ class CreateProjectFromProjectChatRequest(BaseModel):
     """Request schema for creating a project from project chat."""
+
     # Optional overrides
     project_name: Optional[str] = None
     project_key: Optional[str] = None
@@ -267,6 +286,7 @@ class CreateProjectFromProjectChatResponse(BaseModel):
     """Response schema for create project from project chat endpoint."""
+
     project_id: UUID
     project_name: str
     phase_id: UUID
@@ -275,14 +295,16 @@ class ProjectChatStartOverResponse(BaseModel):
     """Response schema for start-over operation."""
+
     deleted_count: int
     deleted_message_ids: List[str]


 class UpdateVisibilityRequest(BaseModel):
     """Request schema for updating project chat visibility."""
+
     visibility: str = Field(
         ...,
         pattern="^(private|team)$",
-        description="Visibility: 'private' (only creator) or 'team' (all project members)"
+        description="Visibility: 'private' (only creator) or 'team' (all project members)",
     )
diff --git a/backend/app/schemas/project_repository.py b/backend/app/schemas/project_repository.py
index 74d85fb..897478d 100644
--- a/backend/app/schemas/project_repository.py
+++ b/backend/app/schemas/project_repository.py
@@ -1,4 +1,5 @@
 """ProjectRepository schemas for API requests/responses."""
+
 from datetime import datetime
 from typing import Optional
 from uuid import UUID
diff --git a/backend/app/schemas/project_share.py b/backend/app/schemas/project_share.py
index 66527d0..0c0f95b 100644
--- a/backend/app/schemas/project_share.py
+++ b/backend/app/schemas/project_share.py
@@ -10,7 +10,6 @@
 from app.models.project_membership import ProjectRole
 from app.models.project_share import ShareSubjectType

-
 # ============================================================================
 # Request Schemas
 # ============================================================================
diff --git a/backend/app/schemas/spec.py b/backend/app/schemas/spec.py
index 4f966b3..890ef45 100644
--- a/backend/app/schemas/spec.py
+++ b/backend/app/schemas/spec.py
@@ -2,6 +2,7 @@
 from datetime import datetime
 from uuid import UUID
+
 from pydantic import BaseModel, ConfigDict
diff --git a/backend/app/schemas/thread.py b/backend/app/schemas/thread.py
index 8bf53da..ed85e00 100644
--- a/backend/app/schemas/thread.py
+++ b/backend/app/schemas/thread.py
@@ -1,14 +1,18 @@
 """Schemas for threads and comments."""
+
 from datetime import datetime
-from typing import Optional, List, Any, Union
+from typing import Any, List, Optional
 from uuid import UUID
-from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator, model_validator
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
+
 from app.models import ContextType


 # Thread schemas
 class ThreadCreate(BaseModel):
     """Schema for creating a thread."""
+
     context_type: ContextType
     context_id: Optional[str] = None
     title: Optional[str] = Field(None, min_length=1, max_length=200)
@@ -16,18 +20,20 @@ class ThreadUpdate(BaseModel):
     """Schema for updating a thread."""
+
     title: str = Field(..., min_length=1, max_length=200)


 class AuthorInfo(BaseModel):
     """Schema for author information."""
+
     id: str
     email: str
     display_name: Optional[str] = None

     model_config = ConfigDict(from_attributes=True)

-    @field_validator('id', mode='before')
+    @field_validator("id", mode="before")
     @classmethod
     def convert_uuid_to_str(cls, value: Any) -> str:
         """Convert UUID to string before validation."""
@@ -35,21 +41,22 @@ def convert_uuid_to_str(cls, value: Any) -> str:
             return str(value)
         return value

-    @field_validator('display_name', mode='before')
+    @field_validator("display_name", mode="before")
     @classmethod
     def fallback_display_name(cls, value: Any, info) -> Optional[str]:
         """Fallback to email prefix if display_name is None."""
         if value is not None:
             return value
         # Access email from the data being validated
-        email = info.data.get('email')
+        email = info.data.get("email")
         if email:
-            return email.split('@')[0]
+            return email.split("@")[0]
         return None


 class CommentResponse(BaseModel):
     """Schema for comment response."""
+
     id: str
     thread_id: str
     author_id: str
@@ -60,7 +67,7 @@ class CommentResponse(BaseModel):
     model_config = ConfigDict(from_attributes=True)

-    @field_validator('id', 'thread_id', 'author_id', mode='before')
+    @field_validator("id", "thread_id", "author_id", mode="before")
     @classmethod
     def convert_uuid_to_str(cls, value: Any) -> str:
         """Convert UUID to string before validation."""
@@ -71,6 +78,7 @@ def convert_uuid_to_str(cls, value: Any) -> str:
 class ThreadResponse(BaseModel):
     """Schema for thread response with thread items."""
+
     id: str
     project_id: str
     context_type: ContextType
@@ -104,7 +112,7 @@ class ThreadResponse(BaseModel):
     model_config = ConfigDict(from_attributes=True)

-    @field_validator('id', 'project_id', 'context_id', 'created_by', mode='before')
+    @field_validator("id", "project_id", "context_id", "created_by", mode="before")
     @classmethod
     def convert_uuid_to_str(cls, value: Any) -> Optional[str]:
         """Convert UUID to string before validation."""
@@ -114,8 +122,8 @@ def convert_uuid_to_str(cls, value: Any) -> Optional[str]:
             return str(value)
         return value

-    @model_validator(mode='after')
-    def fallback_decision_summary_short(self) -> 'ThreadResponse':
+    @model_validator(mode="after")
+    def fallback_decision_summary_short(self) -> "ThreadResponse":
         """Fallback decision_summary_short to decision_summary if not set."""
         if self.decision_summary_short is None and self.decision_summary is not None:
             self.decision_summary_short = self.decision_summary
@@ -124,6 +132,7 @@ def fallback_decision_summary_short(self) -> 'ThreadResponse':
 class ThreadListResponse(BaseModel):
     """Schema for thread list response (without comments)."""
+
     id: str
     project_id: str
     context_type: ContextType
@@ -147,14 +156,14 @@ class ThreadListResponse(BaseModel):
     model_config = ConfigDict(from_attributes=True)

-    @model_validator(mode='after')
-    def set_thread_id(self) -> 'ThreadListResponse':
+    @model_validator(mode="after")
+    def set_thread_id(self) -> "ThreadListResponse":
         """Set thread_id as alias for id for frontend compatibility."""
         if self.thread_id is None:
             self.thread_id = self.id
         return self

-    @field_validator('id', 'project_id', 'context_id', 'created_by', mode='before')
+    @field_validator("id", "project_id", "context_id", "created_by", mode="before")
     @classmethod
     def convert_uuid_to_str(cls, value: Any) -> Optional[str]:
         """Convert UUID to string before validation."""
@@ -168,9 +177,11 @@ def convert_uuid_to_str(cls, value: Any) -> Optional[str]:
 # Comment schemas
 class CommentCreate(BaseModel):
     """Schema for creating a comment."""
+
     body_markdown: str = Field(..., min_length=1, max_length=32000)


 class CommentUpdate(BaseModel):
     """Schema for updating a comment."""
+
     body_markdown: str = Field(..., min_length=1, max_length=32000)
diff --git a/backend/app/schemas/thread_item.py b/backend/app/schemas/thread_item.py
index 259cf17..ae3f8df 100644
--- a/backend/app/schemas/thread_item.py
+++ b/backend/app/schemas/thread_item.py
@@ -1,7 +1,10 @@
 """Schemas for thread items (mixed content)."""
+
 from datetime import datetime
-from typing import List, Optional, Union, Literal
+from typing import List, Literal, Optional, Union
+
 from pydantic import BaseModel, ConfigDict, Field
+
 from app.models.thread_item import ThreadItemType
@@ -49,6 +52,7 @@ class ImageAttachment(BaseModel):
 # Author info (embedded in items)
 class AuthorInfo(BaseModel):
     """Author information embedded in thread items."""
+
     model_config = ConfigDict(from_attributes=True)

     id: str
@@ -69,6 +73,7 @@ def from_user(cls, user) -> "AuthorInfo":
 # Base thread item response
 class ThreadItemBase(BaseModel):
     """Base thread item fields."""
+
     model_config = ConfigDict(from_attributes=True)

     id: str
@@ -82,6 +87,7 @@ class ThreadItemBase(BaseModel):
 # Comment item
 class CommentItemResponse(ThreadItemBase):
     """Response model for comment items."""
+
     item_type: Literal[ThreadItemType.COMMENT] = ThreadItemType.COMMENT
     body_markdown: str
     images: List[ImageAttachment] = Field(default_factory=list)
@@ -120,6 +126,7 @@ def from_thread_item(cls, item):
 # MCQ choice
 class MCQChoice(BaseModel):
     """MCQ choice option."""
+
     id: str
     label: str
@@ -127,6 +134,7 @@ class MCQChoice(BaseModel):
 # MCQ follow-up item
 class MCQFollowupItemResponse(ThreadItemBase):
     """Response model for MCQ follow-up items."""
+
     item_type: Literal[ThreadItemType.MCQ_FOLLOWUP] = ThreadItemType.MCQ_FOLLOWUP
     question_text: str
     choices: List[MCQChoice]
@@ -173,6 +181,7 @@ def from_thread_item(cls, item):
 # No follow-up message item
 class NoFollowupMessageItemResponse(ThreadItemBase):
     """Response model for no follow-up message items."""
+
     item_type: Literal[ThreadItemType.NO_FOLLOWUP_MESSAGE] = ThreadItemType.NO_FOLLOWUP_MESSAGE
     message: str
@@ -193,6 +202,7 @@ def from_thread_item(cls, item):
 # MCQ Answer item - shows the user's answer to an MCQ in the timeline
 class MCQAnswerItemResponse(ThreadItemBase):
     """Response model for MCQ answer items - shows the user's answer to an MCQ."""
+
     item_type: Literal[ThreadItemType.MCQ_ANSWER] = ThreadItemType.MCQ_ANSWER
     original_mcq_item_id: str
     question_text: str
@@ -223,6 +233,7 @@ def from_thread_item(cls, item):
 # Implementation created marker item
 class ImplementationCreatedItemResponse(ThreadItemBase):
     """Response model for implementation created marker items."""
+
     item_type: Literal[ThreadItemType.IMPLEMENTATION_CREATED] = ThreadItemType.IMPLEMENTATION_CREATED
     implementation_id: str
     implementation_name: str
@@ -239,11 +250,7 @@ def from_thread_item(cls, item, implementation=None):
             implementation: Optional Implementation model with generation flags
         """
         # Use current implementation name if available, otherwise fall back to snapshot
-        impl_name = (
-            implementation.name
-            if implementation
-            else item.content_data.get("implementation_name", "")
-        )
+        impl_name = implementation.name if implementation else item.content_data.get("implementation_name", "")
         return cls(
             id=str(item.id),
             thread_id=str(item.thread_id),
@@ -261,6 +268,7 @@ def from_thread_item(cls, item, implementation=None):
 # Code exploration item - shows results from code explorer
 class CodeExplorationItemResponse(ThreadItemBase):
     """Response model for code exploration result items."""
+
     item_type: Literal[ThreadItemType.CODE_EXPLORATION] = ThreadItemType.CODE_EXPLORATION
     exploration_id: str
     prompt: str
@@ -302,6 +310,7 @@ def from_thread_item(cls, item):
 # System message item - shows system messages like cancellation notices
 class SystemItemResponse(ThreadItemBase):
     """Response model for system message items."""
+
     item_type: Literal[ThreadItemType.SYSTEM] = ThreadItemType.SYSTEM
     message: str
@@ -322,6 +331,7 @@ def from_thread_item(cls, item):
 # Web search result item - displays web search results in conversation
 class WebSearchResultItem(BaseModel):
     """Individual web search result."""
+
     title: str
     url: str
     content: Optional[str] = None
@@ -330,6 +340,7 @@ class WebSearchResultItem(BaseModel):
 class WebSearchItemResponse(ThreadItemBase):
     """Response model for web search result items."""
+
     item_type: Literal[ThreadItemType.WEB_SEARCH] = ThreadItemType.WEB_SEARCH
     query: str
     answer: Optional[str] = None
@@ -362,18 +373,29 @@ def from_thread_item(cls, item):
 # Union type for polymorphic responses
-ThreadItemResponse = Union[CommentItemResponse, MCQFollowupItemResponse, NoFollowupMessageItemResponse, MCQAnswerItemResponse, ImplementationCreatedItemResponse, CodeExplorationItemResponse, SystemItemResponse, WebSearchItemResponse]
+ThreadItemResponse = Union[
+    CommentItemResponse,
+    MCQFollowupItemResponse,
+    NoFollowupMessageItemResponse,
+    MCQAnswerItemResponse,
+    ImplementationCreatedItemResponse,
+    CodeExplorationItemResponse,
+    SystemItemResponse,
+    WebSearchItemResponse,
+]


 # Request schemas
 class CreateCommentItem(BaseModel):
     """Request to create a comment item."""
+
     body_markdown: str = Field(..., min_length=1, max_length=32000)
     images: Optional[List[ImageAttachment]] = Field(None, max_length=10)


 class CreateMCQAnswer(BaseModel):
     """Request to answer an MCQ follow-up."""
+
     selected_option_id: str
     free_text: Optional[str] = Field(None, max_length=32000)
     force_change: bool = False  # If True, delete downstream items when changing answer
@@ -381,11 +403,13 @@ class UpdateThreadItem(BaseModel):
     """Request to update a thread item (e.g., edit comment)."""
+
     body_markdown: Optional[str] = Field(None, min_length=1, max_length=32000)


 class StartOverResponse(BaseModel):
     """Response for start-over operation (bulk deletion)."""
+
     deleted_count: int
     deleted_item_ids: List[str]
     deleted_implementation_ids: List[str] = []
@@ -393,6 +417,7 @@ class StartOverResponse(BaseModel):
 class DownstreamItemsResponse(BaseModel):
     """Response for checking downstream items from an MCQ."""
+
     has_downstream: bool
     downstream_count: int
diff --git a/backend/app/schemas/user_question_session.py b/backend/app/schemas/user_question_session.py
index 9a41bbe..defe851 100644
--- a/backend/app/schemas/user_question_session.py
+++ b/backend/app/schemas/user_question_session.py
@@ -1,19 +1,21 @@
 """Schemas for user question session API endpoints."""
+
 from datetime import datetime
-from typing import List, Optional, Any
-from uuid import UUID
+from typing import List, Optional

 from pydantic import BaseModel, Field


 class MCQChoiceSchema(BaseModel):
     """Schema for an MCQ choice."""
+
     id: str
     label: str


 class MCQSchema(BaseModel):
     """Schema for an MCQ."""
+
     question_text: str
     choices: List[MCQChoiceSchema]
     explanation: Optional[str] = None
@@ -21,6 +23,7 @@ class GeneratedQuestionPreview(BaseModel):
     """Schema for a generated question preview (before adding to phase)."""
+
     temp_id: str
     aspect_title: str
     title: str
@@ -31,6 +34,7 @@ class UserQuestionMessageResponse(BaseModel):
     """Response schema for a user question message."""
+
     id: str
     role: str
     content: str
@@ -44,6 +48,7 @@ class UserQuestionSessionResponse(BaseModel):
     """Response schema for a user question session."""
+
     id: str
     brainstorming_phase_id: str
     title: Optional[str] = None
@@ -61,35 +66,41 @@ class UserQuestionSessionWithMessagesResponse(BaseModel):
     """Response schema for a session with messages."""
+
     session: UserQuestionSessionResponse
     messages: List[UserQuestionMessageResponse]


 class CreateSessionRequest(BaseModel):
     """Request schema for creating a session."""
+
     pass  # No body needed, phase_id is in URL


 class GenerateQuestionsRequest(BaseModel):
     """Request schema for generating questions."""
+
     user_prompt: str = Field(..., min_length=1, max_length=2000)
     num_questions: int = Field(default=3, ge=1, le=5)


 class GenerateQuestionsResponse(BaseModel):
     """Response schema for generate questions endpoint."""
+
     job_id: str
     message_id: str


 class AddQuestionsRequest(BaseModel):
     """Request schema for adding questions to the phase."""
+
     message_id: str
     temp_question_ids: List[str]


 class AddQuestionsResponse(BaseModel):
     """Response schema for add questions endpoint."""
+
     added_count: int
     feature_ids: List[str]
     session_limit_reached: bool
diff --git a/backend/app/services/activity_log_service.py b/backend/app/services/activity_log_service.py
index a52cf8d..e86df4e 100644
--- a/backend/app/services/activity_log_service.py
+++ b/backend/app/services/activity_log_service.py
@@ -1,13 +1,15 @@
 """Service for managing activity logs."""
-from typing import Optional, List, Any
+
+from typing import List, Optional
 from uuid import UUID
-from sqlalchemy.orm import Session
+
 from sqlalchemy import or_
+from sqlalchemy.orm import Session

 from app.models.activity_log import ActivityLog
 from app.models.brainstorming_phase import BrainstormingPhase
-from app.models.module import Module
 from app.models.feature import Feature
+from app.models.module import Module


 class ActivityEventTypes:
@@ -159,34 +161,22 @@ def get_activity_for_project(
         """
         # Get all brainstorming phase IDs for the project
         phase_ids = [
-            str(p.id) for p in
-            db.query(BrainstormingPhase.id)
-            .filter(BrainstormingPhase.project_id == project_id)
-            .all()
+            str(p.id) for p in db.query(BrainstormingPhase.id).filter(BrainstormingPhase.project_id == project_id).all()
         ]

         # Get all module IDs for the project
-        module_ids = [
-            str(m.id) for m in
-            db.query(Module.id)
-            .filter(Module.project_id == project_id)
-            .all()
-        ]
+        module_ids = [str(m.id) for m in db.query(Module.id).filter(Module.project_id == project_id).all()]

         # Get all feature IDs for the project's modules
         feature_ids = []
         if module_ids:
             feature_ids = [
-                str(f.id) for f in
-                db.query(Feature.id)
-                .filter(Feature.module_id.in_([UUID(mid) for mid in module_ids]))
-                .all()
+                str(f.id)
+                for f in db.query(Feature.id).filter(Feature.module_id.in_([UUID(mid) for mid in module_ids])).all()
             ]

         # Build the query with OR conditions for all entity types
-        conditions = [
-            (ActivityLog.entity_type == "project") & (ActivityLog.entity_id == str(project_id))
-        ]
+        conditions = [(ActivityLog.entity_type == "project") & (ActivityLog.entity_id == str(project_id))]

         if phase_ids:
             conditions.append(
@@ -194,20 +184,12 @@ def get_activity_for_project(
             )

         if module_ids:
-            conditions.append(
-                (ActivityLog.entity_type == "module") & (ActivityLog.entity_id.in_(module_ids))
-            )
+            conditions.append((ActivityLog.entity_type == "module") & (ActivityLog.entity_id.in_(module_ids)))

         if feature_ids:
-            conditions.append(
-                (ActivityLog.entity_type == "feature") & (ActivityLog.entity_id.in_(feature_ids))
-            )
+            conditions.append((ActivityLog.entity_type == "feature") & (ActivityLog.entity_id.in_(feature_ids)))

-        query = (
-            db.query(ActivityLog)
-            .filter(or_(*conditions))
-            .order_by(ActivityLog.created_at.desc())
-        )
+        query = db.query(ActivityLog).filter(or_(*conditions)).order_by(ActivityLog.created_at.desc())

         if offset:
             query = query.offset(offset)
diff --git a/backend/app/services/agent_utils.py b/backend/app/services/agent_utils.py
index da3ddcd..ef834d5 100644
--- a/backend/app/services/agent_utils.py
+++ b/backend/app/services/agent_utils.py
@@ -3,13 +3,14 @@
 The Agent user is a special system user that represents actions taken by
 automated agents (like LLM-based coding agents) rather than human users.
 """
-from typing import Optional
+
 from uuid import UUID
+
 from sqlalchemy.orm import Session
+
 from app.models.user import User
 from app.services.user_service import UserService

-
 # Special email and display name for the Agent user
 AGENT_EMAIL = "agent@mfbt.system"
 AGENT_DISPLAY_NAME = "MFBT Agent"
@@ -41,6 +42,7 @@ def get_or_create_agent_user(db: Session) -> User:
     # Create new agent user with a secure random password
     # The password is not meant to be used - agents authenticate via other means
     import secrets
+
     random_password = secrets.token_urlsafe(32)

     agent_user = UserService.create_user(
diff --git a/backend/app/services/analytics_cache.py b/backend/app/services/analytics_cache.py
index fcc9744..47679fb 100644
--- a/backend/app/services/analytics_cache.py
+++ b/backend/app/services/analytics_cache.py
@@ -85,9 +85,7 @@ def _generate_cache_key(prefix: str, *args: Any, **kwargs: Any) -> str:
     }

     # Create a hash of the arguments for a compact key
-    key_hash = hashlib.md5(
-        json.dumps(key_data, sort_keys=True).encode()
-    ).hexdigest()[:12]
+    key_hash = hashlib.md5(json.dumps(key_data, sort_keys=True).encode()).hexdigest()[:12]

     return f"{prefix}:{key_hash}"
diff --git a/backend/app/services/analytics_service.py b/backend/app/services/analytics_service.py
index b14bbff..fb17023 100644
--- a/backend/app/services/analytics_service.py
+++ b/backend/app/services/analytics_service.py
@@ -10,7 +10,7 @@
 from typing import Optional
 from uuid import UUID

-from sqlalchemy import func, or_
+from sqlalchemy import func
 from sqlalchemy.orm import Session

 from app.models.daily_usage_summary import DailyUsageSummary
@@ -143,13 +143,9 @@ def _get_realtime_user_usage(
         # Date filter based on dialect
         if dialect == "postgresql":
-            query = query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date
-            )
+            query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date)
         else:
-            query = query.filter(
-                func.date(LLMUsageLog.created_at) == target_date
-            )
+            query = query.filter(func.date(LLMUsageLog.created_at) == target_date)

         if org_id:
             query = query.filter(LLMUsageLog.org_id == org_id)
@@ -228,13 +224,9 @@ def _get_realtime_project_usage(
         # Date filter based on dialect
         if dialect == "postgresql":
-            query = query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date
-            )
+            query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date)
         else:
-            query = query.filter(
-                func.date(LLMUsageLog.created_at) == target_date
-            )
+            query = query.filter(func.date(LLMUsageLog.created_at) == target_date)

         if org_id:
             query = query.filter(LLMUsageLog.org_id == org_id)
@@ -261,9 +253,9 @@ def _merge_usage_dicts(historical: dict, realtime: dict) -> dict:
                 # Add values together
                 merged[key] = {
                     "total_tokens": merged[key]["total_tokens"] + data["total_tokens"],
-                    "total_cost_usd": (
-                        (merged[key]["total_cost_usd"] or 0) + (data["total_cost_usd"] or 0)
-                    ) if (merged[key]["total_cost_usd"] is not None or data["total_cost_usd"] is not None) else None,
+                    "total_cost_usd": ((merged[key]["total_cost_usd"] or 0) + (data["total_cost_usd"] or 0))
+                    if (merged[key]["total_cost_usd"] is not None or data["total_cost_usd"] is not None)
+                    else None,
                     "call_count": merged[key]["call_count"] + data["call_count"],
                 }
             else:
@@ -301,9 +293,7 @@ def get_top_users(
         historical_end = end_date - timedelta(days=1) if end_date == today else end_date
         historical = {}
         if start_date <= historical_end:
-            historical = AnalyticsService._get_historical_user_usage(
-                db, start_date, historical_end, org_id
-            )
+            historical = AnalyticsService._get_historical_user_usage(db, start_date, historical_end, org_id)

         # Get real-time data for today if today is in range
         realtime = {}
@@ -329,10 +319,7 @@ def get_top_users(
         # Get user info for all user_ids
         user_ids = list(merged.keys())
-        users_info = {
-            u.id: u
-            for u in db.query(User).filter(User.id.in_(user_ids)).all()
-        }
+        users_info = {u.id: u for u in db.query(User).filter(User.id.in_(user_ids)).all()}

         # Build sorted entries
         entries = []
@@ -343,16 +330,18 @@ def get_top_users(
             percentage = (data["total_tokens"] / total_tokens * 100) if total_tokens > 0 else 0

-            entries.append(TopUserEntry(
-                user_id=user_id,
-                email=user.email,
-                display_name=user.display_name,
-                total_tokens=data["total_tokens"],
-                total_credits=tokens_to_credits(data["total_tokens"]),
-                total_cost_usd=data["total_cost_usd"],
-                call_count=data["call_count"],
-                percentage_of_total=round(percentage, 2),
-            ))
+            entries.append(
+                TopUserEntry(
+                    user_id=user_id,
+                    email=user.email,
+                    display_name=user.display_name,
+                    total_tokens=data["total_tokens"],
+                    total_credits=tokens_to_credits(data["total_tokens"]),
+                    total_cost_usd=data["total_cost_usd"],
+                    call_count=data["call_count"],
+                    percentage_of_total=round(percentage, 2),
+                )
+            )

         # Sort by total_tokens descending and apply limit
         entries.sort(key=lambda x: x.total_tokens, reverse=True)
@@ -398,9 +387,7 @@ def get_top_projects(
         historical_end = end_date - timedelta(days=1) if end_date == today else end_date
         historical = {}
         if start_date <= historical_end:
-            historical = AnalyticsService._get_historical_project_usage(
-                db, start_date, historical_end, org_id
-            )
+            historical = AnalyticsService._get_historical_project_usage(db, start_date, historical_end, org_id)

         # Get real-time data for today if today is in range
         realtime = {}
@@ -444,17 +431,19 @@ def get_top_projects(
             project, org = project_org
             percentage = (data["total_tokens"] / total_tokens * 100) if total_tokens > 0 else 0

-            entries.append(TopProjectEntry(
-                project_id=project_id,
-                project_name=project.name,
-                org_id=org.id,
-                org_name=org.name,
-                total_tokens=data["total_tokens"],
-                total_credits=tokens_to_credits(data["total_tokens"]),
-                total_cost_usd=data["total_cost_usd"],
-                call_count=data["call_count"],
-                percentage_of_total=round(percentage, 2),
-            ))
+            entries.append(
+                TopProjectEntry(
+                    project_id=project_id,
+                    project_name=project.name,
+                    org_id=org.id,
+                    org_name=org.name,
+                    total_tokens=data["total_tokens"],
+                    total_credits=tokens_to_credits(data["total_tokens"]),
+                    total_cost_usd=data["total_cost_usd"],
+                    call_count=data["call_count"],
+                    percentage_of_total=round(percentage, 2),
+                )
+            )

         # Sort by total_tokens descending and apply limit
         entries.sort(key=lambda x: x.total_tokens, reverse=True)
@@ -515,13 +504,9 @@ def _get_realtime_totals(
         # Date filter based on dialect
         if dialect == "postgresql":
-            query = query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date
-            )
+            query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date)
         else:
-            query = query.filter(
-                func.date(LLMUsageLog.created_at) == target_date
-            )
+            query = query.filter(func.date(LLMUsageLog.created_at) == target_date)

         if org_id:
             query = query.filter(LLMUsageLog.org_id == org_id)
@@ -562,9 +547,7 @@ def get_efficiency_metrics(
         historical_end = end_date - timedelta(days=1) if end_date == today else end_date
         historical = {"total_tokens": 0, "total_cost_usd": None, "call_count": 0}
         if start_date <= historical_end:
-            historical = AnalyticsService._get_historical_totals(
-                db, start_date, historical_end, org_id
-            )
+            historical = AnalyticsService._get_historical_totals(db, start_date, historical_end, org_id)

         # Get real-time data for today if today is in range
         realtime = {"total_tokens": 0, "total_cost_usd": None, "call_count": 0}
@@ -640,34 +623,30 @@ def _get_current_month_usage(
         historical_tokens = 0
         if first_day <= historical_end:
-            result = db.query(
-                func.coalesce(func.sum(DailyUsageSummary.total_tokens), 0)
-            ).filter(
-                DailyUsageSummary.org_id == org_id,
-                DailyUsageSummary.date >= first_day,
-                DailyUsageSummary.date <= historical_end,
-            ).scalar()
+            result = (
+                db.query(func.coalesce(func.sum(DailyUsageSummary.total_tokens), 0))
+                .filter(
+                    DailyUsageSummary.org_id == org_id,
+                    DailyUsageSummary.date >= first_day,
+                    DailyUsageSummary.date <= historical_end,
+                )
+                .scalar()
+            )
             historical_tokens = int(result)

         # Get real-time data for today
         dialect = AnalyticsService._get_dialect(db)
         realtime_query = db.query(
-            func.coalesce(
-                func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens), 0
-            )
+            func.coalesce(func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens), 0)
         ).filter(
             LLMUsageLog.org_id == org_id,
         )

         if dialect == "postgresql":
-            realtime_query = realtime_query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today
-            )
+            realtime_query = realtime_query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today)
         else:
-            realtime_query = realtime_query.filter(
-                func.date(LLMUsageLog.created_at) == today
-            )
+            realtime_query = realtime_query.filter(func.date(LLMUsageLog.created_at) == today)

         realtime_tokens = int(realtime_query.scalar())
@@ -858,15 +837,20 @@ def _get_org_user_usage(
         historical = {}
         if start_date <= historical_end:
-            results = db.query(
-                DailyUsageSummary.user_id,
-                func.sum(DailyUsageSummary.total_tokens).label("total_tokens"),
-            ).filter(
-                DailyUsageSummary.org_id == org_id,
-                DailyUsageSummary.user_id.isnot(None),
-                DailyUsageSummary.date >= start_date,
-                DailyUsageSummary.date <= historical_end,
-            ).group_by(DailyUsageSummary.user_id).all()
+            results = (
+                db.query(
+                    DailyUsageSummary.user_id,
+                    func.sum(DailyUsageSummary.total_tokens).label("total_tokens"),
+                )
+                .filter(
+                    DailyUsageSummary.org_id == org_id,
+                    DailyUsageSummary.user_id.isnot(None),
+                    DailyUsageSummary.date >= start_date,
+                    DailyUsageSummary.date <= historical_end,
+                )
+                .group_by(DailyUsageSummary.user_id)
+                .all()
+            )

             for row in results:
                 historical[row.user_id] = int(row.total_tokens or 0)
@@ -885,9 +869,7 @@ def _get_org_user_usage(
         )

         if dialect == "postgresql":
-            query = query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today
-            )
+            query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today)
         else:
             query = query.filter(func.date(LLMUsageLog.created_at) == today)
@@ -925,15 +907,20 @@ def _get_org_project_usage(
         historical = {}
         if start_date <= historical_end:
-            results = db.query(
-                DailyUsageSummary.project_id,
-                func.sum(DailyUsageSummary.total_tokens).label("total_tokens"),
-            ).filter(
-                DailyUsageSummary.org_id == org_id,
-                DailyUsageSummary.project_id.isnot(None),
-                DailyUsageSummary.date >= start_date,
-                DailyUsageSummary.date <= historical_end,
-            ).group_by(DailyUsageSummary.project_id).all()
+            results = (
+                db.query(
+                    DailyUsageSummary.project_id,
+                    func.sum(DailyUsageSummary.total_tokens).label("total_tokens"),
+                )
+                .filter(
+                    DailyUsageSummary.org_id == org_id,
+                    DailyUsageSummary.project_id.isnot(None),
+                    DailyUsageSummary.date >= start_date,
+                    DailyUsageSummary.date <= historical_end,
+                )
+                .group_by(DailyUsageSummary.project_id)
+                .all()
+            )

             for row in results:
                 historical[row.project_id] = int(row.total_tokens or 0)
@@ -952,9 +939,7 @@ def _get_org_project_usage(
         )

         if dialect == "postgresql":
-            query = query.filter(
-                func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today
-            )
+            query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == today)
         else:
             query = query.filter(func.date(LLMUsageLog.created_at) == today)
@@ -1023,61 +1008,59 @@ get_org_efficiency_overview(
             total_org_tokens = plan_efficiency.tokens_used

             # Get top users
-            user_usage = AnalyticsService._get_org_user_usage(
-                db, org.id, start_date, end_date, limit_users
-            )
+            user_usage = AnalyticsService._get_org_user_usage(db, org.id, start_date, end_date, limit_users)

             # Fetch user info
             user_ids = [u[0] for u in user_usage]
-            users_info = {
-                u.id: u
-                for u in db.query(User).filter(User.id.in_(user_ids)).all()
-            } if user_ids else {}
+            users_info = {u.id: u for u in db.query(User).filter(User.id.in_(user_ids)).all()} if user_ids else {}

             user_entries = []
             for user_id, tokens in user_usage:
                 user = users_info.get(user_id)
                 if user:
                     percentage = (tokens / total_org_tokens * 100) if total_org_tokens > 0 else 0
-                    user_entries.append(UserEfficiencyEntry(
-                        user_id=user_id,
-                        email=user.email,
-                        display_name=user.display_name,
-                        tokens_used=tokens,
-                        percentage_of_org=round(percentage, 2),
-                    ))
+                    user_entries.append(
+                        UserEfficiencyEntry(
+                            user_id=user_id,
+                            email=user.email,
+                            display_name=user.display_name,
+                            tokens_used=tokens,
+                            percentage_of_org=round(percentage, 2),
+                        )
+                    )

             # Get top projects
-            project_usage = AnalyticsService._get_org_project_usage(
-                db, org.id, start_date, end_date, limit_projects
-            )
+            project_usage = AnalyticsService._get_org_project_usage(db, org.id, start_date, end_date, limit_projects)

             # Fetch project info
             project_ids = [p[0] for p in project_usage]
-            projects_info = {
-                p.id: p
-                for p in db.query(Project).filter(Project.id.in_(project_ids)).all()
-            } if project_ids else {}
+            projects_info = (
+                {p.id: p for p in db.query(Project).filter(Project.id.in_(project_ids)).all()} if project_ids else {}
+            )

             project_entries = []
             for project_id, tokens in project_usage:
                 project = projects_info.get(project_id)
                 if project:
                     percentage = (tokens / total_org_tokens * 100) if total_org_tokens > 0 else 0
-                    project_entries.append(ProjectEfficiencyEntry(
-                        project_id=project_id,
-                        project_name=project.name,
-                        tokens_used=tokens,
-                        percentage_of_org=round(percentage, 2),
-                    ))
-
-            org_efficiencies.append(OrganizationEfficiency(
-                org_id=org.id,
-                org_name=org.name,
-                efficiency=efficiency_score,
-                users=user_entries,
-                projects=project_entries,
-            ))
+                    project_entries.append(
+                        ProjectEfficiencyEntry(
+                            project_id=project_id,
+                            project_name=project.name,
+                            tokens_used=tokens,
+                            percentage_of_org=round(percentage, 2),
+                        )
+                    )
+
+            org_efficiencies.append(
+                OrganizationEfficiency(
+                    org_id=org.id,
+                    org_name=org.name,
+                    efficiency=efficiency_score,
+                    users=user_entries,
+                    projects=project_entries,
+                )
+            )

         return OrgEfficiencyResponse(
             time_range=time_range,
diff --git a/backend/app/services/api_key_service.py b/backend/app/services/api_key_service.py
index 284828d..5ee0b67 100644
--- a/backend/app/services/api_key_service.py
+++ b/backend/app/services/api_key_service.py
@@ -3,17 +3,17 @@
 from datetime import datetime, timezone
 from uuid import UUID

-from sqlalchemy.orm import Session
 from fastapi import HTTPException, status
+from sqlalchemy.orm import Session

+from app.auth.api_key_utils import generate_api_key, get_key_preview, hash_api_key, hash_api_key_sha256
+from app.auth.encryption_utils import decrypt_api_key, encrypt_api_key
 from app.models.api_key import ApiKey
 from app.models.project import Project
-from app.auth.api_key_utils import generate_api_key, hash_api_key, hash_api_key_sha256, get_key_preview
-from app.auth.encryption_utils import encrypt_api_key, decrypt_api_key
 from app.schemas.api_key import (
     ApiKeyCreate,
-    ApiKeyResponse,
     ApiKeyCreateResponse,
+    ApiKeyResponse,
     MCPConnectionConfig,
 )
@@ -76,12 +76,7 @@ def list_api_keys(db: Session, user_id: UUID) -> list[ApiKey]:
         Returns:
             List of API keys (non-revoked and revoked)
         """
-        return (
-            db.query(ApiKey)
-            .filter(ApiKey.user_id == user_id)
-            .order_by(ApiKey.created_at.desc())
-            .all()
-        )
+        return db.query(ApiKey).filter(ApiKey.user_id == user_id).order_by(ApiKey.created_at.desc()).all()

     @staticmethod
     def get_api_key(db: Session, key_id: UUID, user_id: UUID) -> ApiKey | None:
@@ -96,11 +91,7 @@ def get_api_key(db: Session, key_id: UUID, user_id: UUID) -> ApiKey | None:
         Returns:
             ApiKey if found and belongs to user, None otherwise
         """
-        return (
-            db.query(ApiKey)
-            .filter(ApiKey.id == key_id, ApiKey.user_id == user_id)
-            .first()
-        )
+        return db.query(ApiKey).filter(ApiKey.id == key_id, ApiKey.user_id == user_id).first()

     @staticmethod
     def revoke_api_key(db: Session, key_id: UUID, user_id: UUID) -> ApiKey:
@@ -121,9 +112,7 @@ def revoke_api_key(db: Session, key_id: UUID, user_id: UUID) -> ApiKey:
         api_key = ApiKeyService.get_api_key(db, key_id, user_id)

         if not api_key:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
-            )
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API key not found")

         api_key.revoked = True
         db.commit()
@@ -148,9 +137,7 @@ def delete_api_key(db: Session, key_id: UUID, user_id: UUID) -> None:
         api_key = ApiKeyService.get_api_key(db, key_id, user_id)

         if not api_key:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
-            )
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="API key not found")

         if not api_key.revoked:
             raise HTTPException(
@@ -162,9 +149,7 @@ def delete_api_key(db: Session, key_id: UUID, user_id: UUID) -> None:
         db.commit()

     @staticmethod
-    def get_mcp_connection_config(
-        db: Session, project_id: UUID, base_url: str
-    ) -> MCPConnectionConfig:
+    def get_mcp_connection_config(db: Session, project_id: UUID, base_url: str) -> MCPConnectionConfig:
         """
         Get MCP connection configuration for a project.
@@ -182,16 +167,12 @@ def get_mcp_connection_config(
         project = db.query(Project).filter(Project.id == project_id).first()

         if not project:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
-            )
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Project not found")

         # Construct MCP URL using the short URL identifier
         mcp_url = f"{base_url}/api/v1/projects/{project.url_identifier}/mcp"

-        return MCPConnectionConfig(
-            mcp_url=mcp_url, project_id=project_id, project_key=project.key
-        )
+        return MCPConnectionConfig(mcp_url=mcp_url, project_id=project_id, project_key=project.key)

     @staticmethod
     def to_response(api_key: ApiKey, raw_key: str | None = None) -> ApiKeyResponse | ApiKeyCreateResponse:
diff --git a/backend/app/services/brainstorming_phase_service.py b/backend/app/services/brainstorming_phase_service.py
index c0ea069..aabf566 100644
--- a/backend/app/services/brainstorming_phase_service.py
+++ b/backend/app/services/brainstorming_phase_service.py
@@ -1,36 +1,29 @@
 """Service for managing brainstorming phases."""
-import os
-import base64
+
 import logging
 from dataclasses import asdict
-from typing import Optional, List, Callable, Dict, Any
+from typing import Any, Callable, Dict, List, Optional
 from uuid import UUID
+
 from sqlalchemy.orm import Session, joinedload
-from cryptography.fernet import Fernet
-from cryptography.hazmat.primitives import hashes
-from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

 from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType
-from app.models.project import Project
-from app.models.module import Module, ModuleProvenance, ModuleType
-from app.models.feature import Feature, FeatureProvenance, FeaturePriority, FeatureType, FeatureVisibilityStatus
-from app.models.thread import Thread, ContextType
-from app.models.thread_item import ThreadItem, ThreadItemType
-from app.models.spec_version import SpecVersion, SpecType
-from app.models.final_spec import FinalSpec
+from app.models.feature import Feature, FeaturePriority, FeatureProvenance, FeatureType, FeatureVisibilityStatus
 from app.models.final_prompt_plan import FinalPromptPlan
+from app.models.final_spec import FinalSpec
+from app.models.module import Module, ModuleProvenance, ModuleType
+from app.models.project import Project
 from app.models.prompt_plan_coverage import PromptPlanCoverageReport
-from app.services.platform_settings_service import require_llm_config_sync
+from app.models.spec_version import SpecType, SpecVersion
+from app.models.thread import ContextType, Thread
+from app.models.thread_item import ThreadItem, ThreadItemType
 from app.services.grounding_service import GroundingService
-
+from app.services.platform_settings_service import require_llm_config_sync

 logger = logging.getLogger(__name__)


-def _build_existing_conversation_context(
-    db: Session,
-    brainstorming_phase_id: UUID
-) -> "ExistingConversationContext":
+def _build_existing_conversation_context(db: Session, brainstorming_phase_id: UUID) -> "ExistingConversationContext":
     """
     Build rich context about existing aspects and questions for a brainstorming phase.
@@ -48,19 +41,21 @@ def _build_existing_conversation_context(
         ExistingConversationContext with all aspects and their decision summaries
     """
     from app.agents.brainstorm_conversation.types import (
-        ExistingConversationContext,
+        AspectCategory,
         ExistingAspect,
+        ExistingConversationContext,
         ExistingQuestionWithAnswer,
-        AspectCategory,
         QuestionPriority,
     )
     from app.services.thread_service import ThreadService

     # Fetch all modules (aspects) for this phase
-    modules = db.query(Module).filter(
-        Module.brainstorming_phase_id == brainstorming_phase_id,
-        Module.archived_at.is_(None)
-    ).order_by(Module.order_index).all()
+    modules = (
+        db.query(Module)
+        .filter(Module.brainstorming_phase_id == brainstorming_phase_id, Module.archived_at.is_(None))
+        .order_by(Module.order_index)
+        .all()
+    )

     existing_aspects = []
     total_questions = 0
@@ -68,21 +63,26 @@ def _build_existing_conversation_context(
     for module in modules:
         # Fetch features (questions) for this module - only ACTIVE visibility (exclude pending)
-        features = db.query(Feature).filter(
-            Feature.module_id == module.id,
-            Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
-            Feature.archived_at.is_(None)
-        ).all()
+        features = (
+            db.query(Feature)
+            .filter(
+                Feature.module_id == module.id,
+                Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
+                Feature.archived_at.is_(None),
+            )
+            .all()
+        )

         questions = []
         module_answered = 0

         for feature in features:
             # Get the thread for this feature
-            thread = db.query(Thread).filter(
-                Thread.context_type == ContextType.BRAINSTORM_FEATURE,
-                Thread.context_id == str(feature.id)
-            ).first()
+            thread = (
+                db.query(Thread)
+                .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id))
+                .first()
+            )

             # Get decision context using ThreadService helper
             decision_context = {"type": "empty"}
@@ -106,10 +106,7 @@ def _build_existing_conversation_context(
                 decision_summary = decision_context.get("summary")
                 # Extract question text from unresolved points
                 raw_points = decision_context.get("unresolved_points", [])
-                unresolved_points = [
-                    p.get("question", str(p)) if isinstance(p, dict) else str(p)
-                    for p in raw_points
-                ]
+                unresolved_points = [p.get("question", str(p)) if isinstance(p, dict) else str(p) for p in raw_points]
                 status = "answered"
                 module_answered += 1
@@ -127,21 +124,22 @@ def _build_existing_conversation_context(
                 "optional": QuestionPriority.OPTIONAL,
             }
             priority = priority_map.get(
-                feature.priority.value if feature.priority else "important",
-                QuestionPriority.IMPORTANT
+                feature.priority.value if feature.priority else "important", QuestionPriority.IMPORTANT
             )

-            questions.append(ExistingQuestionWithAnswer(
-                question_id=str(feature.id),
-                question_title=feature.title,
-                question_description=feature.spec_text or "",
-                aspect_title=module.title,
-                priority=priority,
-                status=status,
-                decision_type=decision_type,
-                decision_summary=decision_summary,
-                unresolved_points=unresolved_points,
-            ))
+            questions.append(
+                ExistingQuestionWithAnswer(
+                    question_id=str(feature.id),
+                    question_title=feature.title,
+                    question_description=feature.spec_text or "",
+                    aspect_title=module.title,
+                    priority=priority,
+                    status=status,
+                    decision_type=decision_type,
+                    decision_summary=decision_summary,
+                    unresolved_points=unresolved_points,
+                )
+            )

         total_questions += len(questions)
         total_answered += module_answered
@@ -150,15 +148,17 @@ def _build_existing_conversation_context(
         # The LLM-generated aspects have categories, but stored modules don't retain them
         category = AspectCategory.BUSINESS_LOGIC

-        existing_aspects.append(ExistingAspect(
-            aspect_id=str(module.id),
-            title=module.title,
-            description=module.description or "",
-            category=category,
-            questions=questions,
-            total_questions=len(questions),
-            answered_questions=module_answered,
-        ))
+        existing_aspects.append(
+            ExistingAspect(
+                aspect_id=str(module.id),
+                title=module.title,
+                description=module.description or "",
+                category=category,
+                questions=questions,
+                total_questions=len(questions),
+                answered_questions=module_answered,
+            )
+        )

     return ExistingConversationContext(
         aspects=existing_aspects,
@@ -196,9 +196,9 @@ def _build_cross_project_context(
         CrossProjectContext with all cross-phase and project decisions
     """
     from app.agents.brainstorm_conversation.types import (
-        CrossProjectContext,
         CrossPhaseContext,
         CrossPhaseDecision,
+        CrossProjectContext,
         ProjectFeatureDecision,
     )
@@ -206,48 +206,59 @@
     project_features_context = []

     # 1. Query other brainstorming phases (not current, not archived)
-    other_phases = db.query(BrainstormingPhase).filter(
-        BrainstormingPhase.project_id == project_id,
-        BrainstormingPhase.id != current_phase_id,
-        BrainstormingPhase.archived_at.is_(None)
-    ).order_by(BrainstormingPhase.created_at).limit(max_phases).all()
+    other_phases = (
+        db.query(BrainstormingPhase)
+        .filter(
+            BrainstormingPhase.project_id == project_id,
+            BrainstormingPhase.id != current_phase_id,
+            BrainstormingPhase.archived_at.is_(None),
+        )
+        .order_by(BrainstormingPhase.created_at)
+        .limit(max_phases)
+        .all()
+    )

     for phase in other_phases:
         decisions = []

         # Get modules for this phase
-        modules = db.query(Module).filter(
-            Module.brainstorming_phase_id == phase.id,
-            Module.archived_at.is_(None)
-        ).all()
+        modules = db.query(Module).filter(Module.brainstorming_phase_id == phase.id, Module.archived_at.is_(None)).all()

         for module in modules:
             # Get ACTIVE features (questions) with threads that have decisions
-            features = db.query(Feature).filter(
-                Feature.module_id == module.id,
-                Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
-                Feature.archived_at.is_(None)
-            ).all()
+            features = (
+                db.query(Feature)
+                .filter(
+                    Feature.module_id == module.id,
+                    Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
+                    Feature.archived_at.is_(None),
+                )
+                .all()
+            )

             for feature in features:
                 # Get thread for this feature
-                thread = db.query(Thread).filter(
-                    Thread.context_type == ContextType.BRAINSTORM_FEATURE,
-                    Thread.context_id == str(feature.id)
-                ).first()
+                thread = (
+                    db.query(Thread)
+                    .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id))
+                    .first()
+                )

                 # Only include if thread has decision_summary_short or decision_summary
                 if thread and (thread.decision_summary_short or thread.decision_summary):
                     summary = thread.decision_summary_short or (
-                        thread.decision_summary[:100] + "..." if len(thread.decision_summary or "") > 100
+                        thread.decision_summary[:100] + "..."
+                        if len(thread.decision_summary or "") > 100
                         else thread.decision_summary
                     )
                     if summary:
-                        decisions.append(CrossPhaseDecision(
-                            question_title=feature.title,
-                            decision_summary_short=summary,
-                            aspect_title=module.title,
-                        ))
+                        decisions.append(
+                            CrossPhaseDecision(
+                                question_title=feature.title,
+                                decision_summary_short=summary,
+                                aspect_title=module.title,
+                            )
+                        )

                 # Cap decisions per phase
                 if len(decisions) >= max_decisions_per_phase:
@@ -263,47 +274,54 @@
         if len(description) > 200:
             description = description[:200] + "..."

-        other_phases_context.append(CrossPhaseContext(
-            phase_id=str(phase.id),
-            phase_title=phase.title,
-            phase_description=description,
-            decisions=decisions,
-        ))
+        other_phases_context.append(
+            CrossPhaseContext(
+                phase_id=str(phase.id),
+                phase_title=phase.title,
+                phase_description=description,
+                decisions=decisions,
+            )
+        )

     # 2. Query project-level features (module.brainstorming_phase_id IS NULL)
-    project_modules = db.query(Module).filter(
-        Module.project_id == project_id,
-        Module.brainstorming_phase_id.is_(None),
-        Module.archived_at.is_(None)
-    ).all()
+    project_modules = (
+        db.query(Module)
+        .filter(Module.project_id == project_id, Module.brainstorming_phase_id.is_(None), Module.archived_at.is_(None))
+        .all()
+    )

     for module in project_modules:
         # Get IMPLEMENTATION features (not CONVERSATION)
-        features = db.query(Feature).filter(
-            Feature.module_id == module.id,
-            Feature.feature_type == FeatureType.IMPLEMENTATION,
-            Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
-            Feature.archived_at.is_(None)
-        ).all()
+        features = (
+            db.query(Feature)
+            .filter(
+                Feature.module_id == module.id,
+                Feature.feature_type == FeatureType.IMPLEMENTATION,
+                Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
+                Feature.archived_at.is_(None),
+            )
+            .all()
+        )

         for feature in features:
             # Get thread for this feature (could be SPEC or GENERAL context type)
-            thread = db.query(Thread).filter(
-                Thread.context_id == str(feature.id)
-            ).first()
+            thread = db.query(Thread).filter(Thread.context_id == str(feature.id)).first()

             # Only include if thread has decision summary
             if thread and (thread.decision_summary_short or thread.decision_summary):
                 summary = thread.decision_summary_short or (
-                    thread.decision_summary[:100] + "..." if len(thread.decision_summary or "") > 100
+                    thread.decision_summary[:100] + "..."
+                    if len(thread.decision_summary or "") > 100
                     else thread.decision_summary
                 )
                 if summary:
-                    project_features_context.append(ProjectFeatureDecision(
-                        feature_title=feature.title,
-                        module_title=module.title,
-                        decision_summary_short=summary,
-                    ))
+                    project_features_context.append(
+                        ProjectFeatureDecision(
+                            feature_title=feature.title,
+                            module_title=module.title,
+                            decision_summary_short=summary,
+                        )
+                    )

             # Cap project features
             if len(project_features_context) >= max_project_features:
@@ -343,9 +361,9 @@ def load_sibling_phases(
         SiblingPhasesContext with all sibling phases, or None if container not found
     """
     from app.agents.brainstorm_conversation.types import (
+        CrossPhaseDecision,
         SiblingPhaseContext,
         SiblingPhasesContext,
-        CrossPhaseDecision,
     )
    from app.services.phase_container_service import PhaseContainerService
@@ -365,9 +383,7 @@ def load_sibling_phases(
        query = query.filter(BrainstormingPhase.id != exclude_phase_id)

     # Order by container_sequence
-    siblings = query.order_by(
-        BrainstormingPhase.container_sequence.nullsfirst()
-    ).limit(max_phases).all()
+    siblings = query.order_by(BrainstormingPhase.container_sequence.nullsfirst()).limit(max_phases).all()

     sibling_contexts = []
@@ -375,25 +391,27 @@ def load_sibling_phases(
         decisions = []

         # Get modules for this phase
-        modules = db.query(Module).filter(
-            Module.brainstorming_phase_id == phase.id,
-            Module.archived_at.is_(None)
-        ).all()
+        modules = db.query(Module).filter(Module.brainstorming_phase_id == phase.id, Module.archived_at.is_(None)).all()

         for module in modules:
             # Get ACTIVE features with threads that have decisions
-            features = db.query(Feature).filter(
-                Feature.module_id == module.id,
-                Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
-                Feature.archived_at.is_(None)
-            ).all()
+            features = (
+                db.query(Feature)
+                .filter(
+                    Feature.module_id == module.id,
+                    Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
+                    Feature.archived_at.is_(None),
+                )
+                .all()
+            )

             for feature in features:
                 # Get thread for this feature
-                thread = db.query(Thread).filter(
-                    Thread.context_type == ContextType.BRAINSTORM_FEATURE,
-                    Thread.context_id == str(feature.id)
-                ).first()
+                thread = (
+                    db.query(Thread)
+                    .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id))
+                    .first()
+                )

                 # Only include if thread has decision summary
                 if thread and (thread.decision_summary_short or thread.decision_summary):
@@ -403,11 +421,13 @@ def load_sibling_phases(
                         else thread.decision_summary
                     )
                     if summary:
-                        decisions.append(CrossPhaseDecision(
-                            question_title=feature.title,
-                            decision_summary_short=summary,
-                            aspect_title=module.title,
-                        ))
+                        decisions.append(
+                            CrossPhaseDecision(
+                                question_title=feature.title,
+                                decision_summary_short=summary,
+                                aspect_title=module.title,
+                            )
+                        )

                 # Cap decisions per phase
                 if len(decisions) >= max_decisions_per_phase:
@@ -422,19 +442,19 @@ def load_sibling_phases(
             description = description[:200] + "..."

         # Get phase subtype (default to INITIAL_SPEC if not set)
-        phase_subtype = (
-            phase.phase_subtype.value if phase.phase_subtype else "INITIAL_SPEC"
-        )
+        phase_subtype = phase.phase_subtype.value if phase.phase_subtype else "INITIAL_SPEC"

-        sibling_contexts.append(SiblingPhaseContext(
-            phase_id=str(phase.id),
-            phase_title=phase.title,
-            phase_subtype=phase_subtype,
-            container_sequence=phase.container_sequence or 0,
-            description=description,
-            decisions=decisions,
-            implementation_analysis=phase.code_exploration_output,
-        ))
+        sibling_contexts.append(
+            SiblingPhaseContext(
+                phase_id=str(phase.id),
+                phase_title=phase.title,
+                phase_subtype=phase_subtype,
+                container_sequence=phase.container_sequence or 0,
+                description=description,
+                decisions=decisions,
+                implementation_analysis=phase.code_exploration_output,
+            )
+        )

     return SiblingPhasesContext(
         container_id=str(container.id),
@@ -458,18 +478,20 @@ def _phase_has_implementations_for_analysis(db: Session, phase_id: UUID) -> bool
         True if phase has implementations worth analyzing
     """
     # Check for IMPLEMENTATION modules
-    impl_count = db.query(Module).filter(
-        Module.brainstorming_phase_id == phase_id,
-        Module.module_type == ModuleType.IMPLEMENTATION,
-        Module.archived_at.is_(None),
-    ).count()
+    impl_count = (
+        db.query(Module)
+        .filter(
+            Module.brainstorming_phase_id == phase_id,
+            Module.module_type == ModuleType.IMPLEMENTATION,
+            Module.archived_at.is_(None),
+        )
+        .count()
+    )

     if impl_count > 0:
         return True

     # Check for 3+ answered questions (phases with meaningful decisions)
-    phase = db.query(BrainstormingPhase).filter(
-        BrainstormingPhase.id == phase_id
-    ).first()
+    phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first()
     return phase is not None and (phase.active_answered or 0) >= 3
@@ -483,9 +505,7 @@ def _cache_phase_analysis(db: Session, phase_id: UUID, analysis: str) -> None:
     """
     from datetime import datetime, timezone

-    phase = db.query(BrainstormingPhase).filter(
-        BrainstormingPhase.id == phase_id
-    ).first()
+    phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first()
     if phase:
         phase.code_exploration_output = analysis
         phase.code_exploration_cached_at = datetime.now(timezone.utc)
@@ -513,13 +533,13 @@ async def enrich_sibling_phases_with_analysis(
     Returns:
         SiblingPhasesContext with implementation_analysis populated where available
     """
-    from app.agents.brainstorm_conversation.types import SiblingPhasesContext
-    from app.services.code_explorer_client import code_explorer_client
     from app.models import Project
-    from workers.handlers.code_explorer import get_code_explorer_api_key, get_github_token_for_org

     # Check if code explorer is enabled
     from app.models.platform_settings import PlatformSettings
+    from app.services.code_explorer_client import code_explorer_client
+    from workers.handlers.code_explorer import get_code_explorer_api_key, get_github_token_for_org
+
     settings = db.query(PlatformSettings).first()
     if not settings or not settings.code_explorer_enabled:
         logger.debug("Code explorer disabled, skipping sibling phase analysis")
@@ -543,18 +563,18 @@ async def enrich_sibling_phases_with_analysis(
         github_token = None
         if repo.github_integration_config_id:
             try:
-                github_token = await get_github_token_for_org(
-                    db, org_id, repo.github_integration_config_id
-                )
+                github_token = await get_github_token_for_org(db, org_id, repo.github_integration_config_id)
             except Exception as e:
                 logger.warning(f"Failed to get GitHub token for repo {repo.slug}: {e}")

-        repos.append({
-            "slug": repo.slug,
-            "repo_url": repo.repo_url,
-            "branch": repo.default_branch or "main",
-            "github_token": github_token,
-            "user_remarks": repo.user_remarks,
-        })
+        repos.append(
+            {
+                "slug": repo.slug,
+                "repo_url": repo.repo_url,
+                "branch": repo.default_branch or "main",
+                "github_token": github_token,
+                "user_remarks": repo.user_remarks,
+            }
+        )

     if not repos:
         return sibling_context
@@ -568,10 +588,12 @@ async def enrich_sibling_phases_with_analysis(
         if not _phase_has_implementations_for_analysis(db, UUID(phase.phase_id)):
             # Not worth analyzing
             continue
-        phases_to_analyze.append({
-            "phase_id": phase.phase_id,
-            "phase_title": phase.phase_title,
-        })
+        phases_to_analyze.append(
+            {
+                "phase_id": phase.phase_id,
+                "phase_title": phase.phase_title,
+            }
+        )

     if not phases_to_analyze:
         logger.debug("No sibling phases need implementation analysis")
@@ -601,9 +623,7 @@ async def enrich_sibling_phases_with_analysis(
                 # Cache the analysis
                 _cache_phase_analysis(db, UUID(phase.phase_id), analysis["implementation_summary"])

-        logger.info(
-            f"Enriched {len(result['analyses'])} sibling phases with implementation analysis"
-        )
+        logger.info(f"Enriched {len(result['analyses'])} sibling phases with implementation analysis")

     except Exception as e:
         logger.warning(f"Failed to enrich sibling phases with analysis: {e}")
@@ -634,10 +654,14 @@ def _build_tech_stack_context(
     from app.models.project_chat import ProjectChat

     # Find pre-phase discussion that created this project
-    discussion = db.query(ProjectChat).filter(
-        ProjectChat.created_project_id == project_id,
-        ProjectChat.proposed_project_tech_stack.isnot(None),
-    ).first()
+    discussion = (
+        db.query(ProjectChat)
+        .filter(
+            ProjectChat.created_project_id == project_id,
+            ProjectChat.proposed_project_tech_stack.isnot(None),
+        )
+        .first()
+    )

     if not discussion or not discussion.proposed_project_tech_stack:
         return None
@@ -687,7 +711,7 @@ def _build_spec_summary_from_json(spec_version) -> Optional[str]:
         if body:
             truncated = body[:300]
             if len(body) > 300:
-                truncated = truncated.rsplit(' ', 1)[0] + "..."
+                truncated = truncated.rsplit(" ", 1)[0] + "..."
             lines.append(truncated)
             lines.append("")
@@ -751,17 +775,13 @@ get_brainstorming_phase(
         Returns:
             The BrainstormingPhase if found and not archived (or include_archived=True), None otherwise
         """
-        query = db.query(BrainstormingPhase).filter(
-            BrainstormingPhase.id == phase_id
-        )
+        query = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id)
         if not include_archived:
             query = query.filter(BrainstormingPhase.archived_at.is_(None))
         return query.first()

     @staticmethod
-    def get_by_identifier(
-        db: Session, identifier: str, include_archived: bool = False
-    ) -> Optional[BrainstormingPhase]:
+    def get_by_identifier(db: Session, identifier: str, include_archived: bool = False) -> Optional[BrainstormingPhase]:
         """Get a brainstorming phase by UUID, short_id, or URL identifier.

         This method supports backward compatibility with existing UUID-based URLs
@@ -784,9 +804,7 @@ def get_by_identifier(
         if is_uuid(identifier):
             try:
                 uuid_val = UUID(identifier)
-                return BrainstormingPhaseService.get_brainstorming_phase(
-                    db, uuid_val, include_archived
-                )
+                return BrainstormingPhaseService.get_brainstorming_phase(db, uuid_val, include_archived)
             except ValueError:
                 pass
@@ -813,10 +831,10 @@ def list_project_phases(
         Returns:
             List of BrainstormingPhase objects for the project
         """
-        query = db.query(BrainstormingPhase).options(
-            joinedload(BrainstormingPhase.creator)
-        ).filter(
-            BrainstormingPhase.project_id == project_id
+        query = (
+            db.query(BrainstormingPhase)
+            .options(joinedload(BrainstormingPhase.creator))
+            .filter(BrainstormingPhase.project_id == project_id)
         )
         if not include_archived:
             query = query.filter(BrainstormingPhase.archived_at.is_(None))
@@ -852,9 +870,7 @@ def refresh_phase_question_stats(
         from sqlalchemy import and_, func

         # Get the phase
-        phase = db.query(BrainstormingPhase).filter(
-            BrainstormingPhase.id == phase_id
-        ).first()
+        phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first()

         if not phase:
             logger.warning(f"[refresh_phase_question_stats] Phase {phase_id} not found")

         # Get module IDs for this phase
         module_ids = [
-            m.id for m in db.query(Module.id).filter(
+            m.id
+            for m in db.query(Module.id)
+            .filter(
                 Module.brainstorming_phase_id == phase_id,
                 Module.archived_at.is_(None),
-            ).all()
+            )
+            .all()
         ]

         if not module_ids:
@@ -877,48 +896,63 @@
             return

         # Count active questions
-        active_total = db.query(func.count(Feature.id)).filter(
-            and_(
-                Feature.module_id.in_(module_ids),
-                Feature.feature_type == FeatureType.CONVERSATION,
-                Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
-                Feature.archived_at.is_(None),
+        active_total = (
+            db.query(func.count(Feature.id))
+            .filter(
+                and_(
+                    Feature.module_id.in_(module_ids),
+                    Feature.feature_type == FeatureType.CONVERSATION,
+                    Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
+                    Feature.archived_at.is_(None),
+                )
             )
-        ).scalar() or 0
+            .scalar()
+            or 0
+        )

         # Count pending questions
-        pending_total = db.query(func.count(Feature.id)).filter(
-            and_(
-                Feature.module_id.in_(module_ids),
-                Feature.feature_type == FeatureType.CONVERSATION,
-                Feature.visibility_status == FeatureVisibilityStatus.PENDING,
-                Feature.archived_at.is_(None),
+        pending_total = (
+            db.query(func.count(Feature.id))
+            .filter(
+                and_(
+                    Feature.module_id.in_(module_ids),
+                    Feature.feature_type == FeatureType.CONVERSATION,
+                    Feature.visibility_status == FeatureVisibilityStatus.PENDING,
+                    Feature.archived_at.is_(None),
+                )
             )
-        ).scalar() or 0
+            .scalar()
+            or 0
+        )

         # Count answered active questions
         # A question is answered if its thread has an MCQ item with selected_option_id set
         active_feature_ids = [
-            str(f.id) for f in db.query(Feature.id).filter(
+            str(f.id)
+            for f in db.query(Feature.id)
+            .filter(
                 and_(
                     Feature.module_id.in_(module_ids),
                     Feature.feature_type == FeatureType.CONVERSATION,
                     Feature.visibility_status == FeatureVisibilityStatus.ACTIVE,
                     Feature.archived_at.is_(None),
                 )
-            ).all()
+            )
+            .all()
         ]

         active_answered = 0
         if active_feature_ids:
             # Query for distinct context_ids that have answered MCQs
-            answered_query = db.query(func.count(func.distinct(Thread.context_id))).join(
-                ThreadItem, ThreadItem.thread_id == Thread.id
-            ).filter(
-                Thread.context_type ==
ContextType.BRAINSTORM_FEATURE, - Thread.context_id.in_(active_feature_ids), - ThreadItem.item_type == ThreadItemType.MCQ_FOLLOWUP, - ThreadItem.content_data.op('->>')('selected_option_id').isnot(None), + answered_query = ( + db.query(func.count(func.distinct(Thread.context_id))) + .join(ThreadItem, ThreadItem.thread_id == Thread.id) + .filter( + Thread.context_type == ContextType.BRAINSTORM_FEATURE, + Thread.context_id.in_(active_feature_ids), + ThreadItem.item_type == ThreadItemType.MCQ_FOLLOWUP, + ThreadItem.content_data.op("->>")("selected_option_id").isnot(None), + ) ) active_answered = answered_query.scalar() or 0 @@ -952,9 +986,7 @@ def delete_brainstorming_phase( Returns: True if the phase was deleted, False if not found """ - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first() if phase is None: return False @@ -983,10 +1015,14 @@ def archive_brainstorming_phase( """ from datetime import datetime, timezone - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id, - BrainstormingPhase.archived_at.is_(None), - ).first() + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == phase_id, + BrainstormingPhase.archived_at.is_(None), + ) + .first() + ) if phase is None: return None @@ -1010,10 +1046,14 @@ def restore_brainstorming_phase( Returns: The restored BrainstormingPhase if found and was archived, None otherwise """ - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id, - BrainstormingPhase.archived_at.isnot(None), - ).first() + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == phase_id, + BrainstormingPhase.archived_at.isnot(None), + ) + .first() + ) if phase is None: return None @@ -1051,42 +1091,63 @@ def cleanup_downstream_content_for_spec_regeneration( """ # 1. Delete IMPLEMENTATION modules (features cascade delete via FK) # CRITICAL FILTERS: phase_id, module_type=IMPLEMENTATION, provenance=SYSTEM - modules_deleted = db.query(Module).filter( - Module.brainstorming_phase_id == phase_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.provenance == ModuleProvenance.SYSTEM, - ).delete(synchronize_session='fetch') + modules_deleted = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == phase_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.provenance == ModuleProvenance.SYSTEM, + ) + .delete(synchronize_session="fetch") + ) # 2. Get prompt plan version IDs for this phase (needed to delete coverage reports) prompt_plan_version_ids = [ - v.id for v in db.query(SpecVersion.id).filter( + v.id + for v in db.query(SpecVersion.id) + .filter( SpecVersion.brainstorming_phase_id == phase_id, SpecVersion.spec_type == SpecType.PROMPT_PLAN, - ).all() + ) + .all() ] # 3. Delete prompt plan coverage reports (FK doesn't have CASCADE) coverage_reports_deleted = 0 if prompt_plan_version_ids: - coverage_reports_deleted = db.query(PromptPlanCoverageReport).filter( - PromptPlanCoverageReport.spec_version_id.in_(prompt_plan_version_ids) - ).delete(synchronize_session='fetch') + coverage_reports_deleted = ( + db.query(PromptPlanCoverageReport) + .filter(PromptPlanCoverageReport.spec_version_id.in_(prompt_plan_version_ids)) + .delete(synchronize_session="fetch") + ) # 4. 
Delete prompt plan versions from spec_versions - prompt_plan_versions_deleted = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == phase_id, - SpecVersion.spec_type == SpecType.PROMPT_PLAN, - ).delete(synchronize_session='fetch') + prompt_plan_versions_deleted = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == phase_id, + SpecVersion.spec_type == SpecType.PROMPT_PLAN, + ) + .delete(synchronize_session="fetch") + ) # 5. Delete final prompt plan - final_prompt_plan_deleted = db.query(FinalPromptPlan).filter( - FinalPromptPlan.brainstorming_phase_id == phase_id, - ).delete(synchronize_session='fetch') + final_prompt_plan_deleted = ( + db.query(FinalPromptPlan) + .filter( + FinalPromptPlan.brainstorming_phase_id == phase_id, + ) + .delete(synchronize_session="fetch") + ) # 6. Delete final spec - final_spec_deleted = db.query(FinalSpec).filter( - FinalSpec.brainstorming_phase_id == phase_id, - ).delete(synchronize_session='fetch') + final_spec_deleted = ( + db.query(FinalSpec) + .filter( + FinalSpec.brainstorming_phase_id == phase_id, + ) + .delete(synchronize_session="fetch") + ) db.commit() @@ -1120,11 +1181,15 @@ def cleanup_downstream_content_for_prompt_plan_regeneration( """ # Delete IMPLEMENTATION modules (features cascade delete via FK) # CRITICAL FILTERS: phase_id, module_type=IMPLEMENTATION, provenance=SYSTEM - modules_deleted = db.query(Module).filter( - Module.brainstorming_phase_id == phase_id, - Module.module_type == ModuleType.IMPLEMENTATION, - Module.provenance == ModuleProvenance.SYSTEM, - ).delete(synchronize_session='fetch') + modules_deleted = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == phase_id, + Module.module_type == ModuleType.IMPLEMENTATION, + Module.provenance == ModuleProvenance.SYSTEM, + ) + .delete(synchronize_session="fetch") + ) db.commit() @@ -1150,9 +1215,7 @@ def update_brainstorming_phase( Returns: The updated BrainstormingPhase if found, None otherwise """ - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first() if phase is None: return None @@ -1203,19 +1266,17 @@ async def generate_aspects_and_questions( Raises: ValueError: If phase not found, no description, or LLM config missing """ - from app.services.module_service import ModuleService - from app.services.feature_service import FeatureService - from app.services.agent_utils import get_or_create_agent_user from app.agents.brainstorm_conversation import ( - create_orchestrator, BrainstormConversationContext, PhaseType, + create_orchestrator, ) + from app.services.agent_utils import get_or_create_agent_user + from app.services.feature_service import FeatureService + from app.services.module_service import ModuleService # 1. 
Load brainstorming phase and validate description exists - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == brainstorming_phase_id).first() if not phase: raise ValueError(f"Brainstorming phase {brainstorming_phase_id} not found") @@ -1267,9 +1328,7 @@ async def generate_aspects_and_questions( # Build cross-phase and project-level context cross_project_context = _build_cross_project_context( - db, - project_id=phase.project_id, - current_phase_id=brainstorming_phase_id + db, project_id=phase.project_id, current_phase_id=brainstorming_phase_id ) # Build tech stack context from pre-phase discussion @@ -1283,9 +1342,7 @@ async def generate_aspects_and_questions( sibling_phases_context = None if phase.container_id: sibling_phases_context = load_sibling_phases( - db=db, - container_id=phase.container_id, - exclude_phase_id=phase.id + db=db, container_id=phase.container_id, exclude_phase_id=phase.id ) # Enrich with implementation analysis from code explorer if sibling_phases_context and sibling_phases_context.sibling_phases: @@ -1363,9 +1420,7 @@ async def generate_aspects_and_questions( "important": FeaturePriority.IMPORTANT, "optional": FeaturePriority.OPTIONAL, } - feature_priority = priority_mapping.get( - gen_question.priority.value, FeaturePriority.IMPORTANT - ) + feature_priority = priority_mapping.get(gen_question.priority.value, FeaturePriority.IMPORTANT) # Create Feature (representing a Clarification Question) - CONVERSATION type for brainstorming feature = FeatureService.create_feature( @@ -1422,7 +1477,7 @@ async def generate_aspects_and_questions( # 9. Extract LLM usage stats before closing orchestrator llm_usage = None - if hasattr(orchestrator, 'model_client') and hasattr(orchestrator.model_client, 'get_usage_stats'): + if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"): usage_stats = orchestrator.model_client.get_usage_stats() llm_usage = { "model": usage_stats.get("model"), @@ -1451,10 +1506,7 @@ async def generate_aspects_and_questions( return result_dict @staticmethod - def _load_features_with_mcq_and_thread_discussions( - db: Session, - brainstorming_phase_id: UUID - ) -> tuple: + def _load_features_with_mcq_and_thread_discussions(db: Session, brainstorming_phase_id: UUID) -> tuple: """ Load all features with their answered MCQs and thread discussions with comments. 
@@ -1467,10 +1519,11 @@ def _load_features_with_mcq_and_thread_discussions( """ from sqlalchemy.orm import joinedload - modules = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None) - ).all() + modules = ( + db.query(Module) + .filter(Module.brainstorming_phase_id == brainstorming_phase_id, Module.archived_at.is_(None)) + .all() + ) aspects = [] clarification_questions = [] @@ -1483,21 +1536,28 @@ def _load_features_with_mcq_and_thread_discussions( # Load only conversation-type features (MCQ clarification questions) # for this module - exclude implementation features and pending visibility - features = db.query(Feature).filter( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None) - ).all() + features = ( + db.query(Feature) + .filter( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .all() + ) for feature in features: # Find the thread for this feature and load its items with authors - thread = db.query(Thread).filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id == str(feature.id), - ).options( - joinedload(Thread.items).joinedload(ThreadItem.author) - ).first() + thread = ( + db.query(Thread) + .filter( + Thread.context_type == ContextType.BRAINSTORM_FEATURE, + Thread.context_id == str(feature.id), + ) + .options(joinedload(Thread.items).joinedload(ThreadItem.author)) + .first() + ) # Load ALL answered MCQs for this feature answered_mcqs = [] @@ -1518,13 +1578,15 @@ def _load_features_with_mcq_and_thread_discussions( selected_label = choice.get("label", selected_id) break - answered_mcqs.append({ - "question_text": content.get("question_text", ""), - "selected_option_id": selected_id, - "selected_label": selected_label, - "free_text": content.get("free_text"), - "choices": content.get("choices", []), - }) + answered_mcqs.append( + { + "question_text": content.get("question_text", ""), + "selected_option_id": selected_id, + "selected_label": selected_label, + "free_text": content.get("free_text"), + "choices": content.get("choices", []), + } + ) elif item.item_type == ThreadItemType.COMMENT.value: author_name = "Unknown" @@ -1567,25 +1629,29 @@ def _load_features_with_mcq_and_thread_discussions( # Unanswered questions would cause the LLM to infer/invent requirements if answered_mcqs: module_has_answered_questions = True - clarification_questions.append({ - "id": str(feature.id), - "title": feature.title, - "description": feature.spec_text or "", - "spec_text": feature.spec_text or "", - "category": feature.category or "General", - "priority": feature.priority.value if feature.priority else "important", - "answered_mcqs": answered_mcqs, - }) + clarification_questions.append( + { + "id": str(feature.id), + "title": feature.title, + "description": feature.spec_text or "", + "spec_text": feature.spec_text or "", + "category": feature.category or "General", + "priority": feature.priority.value if feature.priority else "important", + "answered_mcqs": answered_mcqs, + } + ) # Only include modules (aspects) that have at least one answered question # LLM-generated modules without user input should not be passed to spec generation if module_has_answered_questions: - aspects.append({ - "id": str(module.id), - "title": module.title, - 
"description": module.description or "", - "category": "General", - }) + aspects.append( + { + "id": str(module.id), + "title": module.title, + "description": module.description or "", + "category": "General", + } + ) return aspects, clarification_questions, thread_discussions @@ -1617,16 +1683,14 @@ async def generate_brainstorm_spec( Raises: ValueError: If phase not found, no modules/features, or LLM config missing """ - from app.models.spec_version import SpecVersion, SpecType from app.agents.brainstorm_spec import ( - create_orchestrator, BrainstormSpecContext, + create_orchestrator, ) + from app.models.spec_version import SpecType, SpecVersion # 1. Load brainstorming phase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == brainstorming_phase_id).first() if not phase: raise ValueError(f"Brainstorming phase {brainstorming_phase_id} not found") @@ -1643,25 +1707,19 @@ async def generate_brainstorm_spec( # 4. Load modules, features with MCQ answers, and thread discussions with comments aspects, clarification_questions, thread_discussions = ( - BrainstormingPhaseService._load_features_with_mcq_and_thread_discussions( - db, brainstorming_phase_id - ) + BrainstormingPhaseService._load_features_with_mcq_and_thread_discussions(db, brainstorming_phase_id) ) # 4.5 Build cross-project context for system awareness cross_project_context = _build_cross_project_context( - db, - project_id=phase.project_id, - current_phase_id=brainstorming_phase_id + db, project_id=phase.project_id, current_phase_id=brainstorming_phase_id ) # 4.6 Load sibling phases context for container-based phases sibling_phases_context = None if phase.container_id: sibling_phases_context = load_sibling_phases( - db=db, - container_id=phase.container_id, - exclude_phase_id=phase.id + db=db, container_id=phase.container_id, exclude_phase_id=phase.id ) # Enrich with implementation analysis from code explorer if sibling_phases_context and sibling_phases_context.sibling_phases: @@ -1707,10 +1765,14 @@ async def generate_brainstorm_spec( # 8. Save spec as SpecVersion # Find next version number - existing_versions = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == brainstorming_phase_id, - SpecVersion.spec_type == SpecType.SPECIFICATION - ).count() + existing_versions = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == brainstorming_phase_id, + SpecVersion.spec_type == SpecType.SPECIFICATION, + ) + .count() + ) spec_version = SpecVersion( project_id=phase.project_id, @@ -1727,7 +1789,7 @@ async def generate_brainstorm_spec( # 9. Extract LLM usage stats before closing orchestrator llm_usage = None - if hasattr(orchestrator, 'model_client') and hasattr(orchestrator.model_client, 'get_usage_stats'): + if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"): usage_stats = orchestrator.model_client.get_usage_stats() llm_usage = { "model": usage_stats.get("model"), @@ -1783,16 +1845,14 @@ async def generate_brainstorm_prompt_plan( Raises: ValueError: If phase not found, no modules/features, or LLM config missing """ - from app.models.spec_version import SpecVersion, SpecType from app.agents.brainstorm_prompt_plan import ( - create_orchestrator, BrainstormPromptPlanContext, + create_orchestrator, ) + from app.models.spec_version import SpecType, SpecVersion # 1. 
Load brainstorming phase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == brainstorming_phase_id).first() if not phase: raise ValueError(f"Brainstorming phase {brainstorming_phase_id} not found") @@ -1809,17 +1869,20 @@ async def generate_brainstorm_prompt_plan( # 4. Load modules, features with MCQ answers, and thread discussions with comments aspects, clarification_questions, thread_discussions = ( - BrainstormingPhaseService._load_features_with_mcq_and_thread_discussions( - db, brainstorming_phase_id - ) + BrainstormingPhaseService._load_features_with_mcq_and_thread_discussions(db, brainstorming_phase_id) ) # 5. Load existing brainstorm spec if available (use summaries if available) brainstorm_spec = None - existing_spec = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == brainstorming_phase_id, - SpecVersion.spec_type == SpecType.SPECIFICATION - ).order_by(SpecVersion.version.desc()).first() + existing_spec = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == brainstorming_phase_id, + SpecVersion.spec_type == SpecType.SPECIFICATION, + ) + .order_by(SpecVersion.version.desc()) + .first() + ) if existing_spec: # Try summaries from content_json first brainstorm_spec = _build_spec_summary_from_json(existing_spec) @@ -1831,9 +1894,7 @@ async def generate_brainstorm_prompt_plan( # 6. Build cross-project context for system awareness cross_project_context = _build_cross_project_context( - db, - project_id=phase.project_id, - current_phase_id=brainstorming_phase_id + db, project_id=phase.project_id, current_phase_id=brainstorming_phase_id ) # 6b. Load grounding summary (supplementary signal for project maturity) @@ -1856,9 +1917,7 @@ async def generate_brainstorm_prompt_plan( sibling_phases_context = None if phase.container_id: sibling_phases_context = load_sibling_phases( - db=db, - container_id=phase.container_id, - exclude_phase_id=phase.id + db=db, container_id=phase.container_id, exclude_phase_id=phase.id ) # Enrich with implementation analysis from code explorer if sibling_phases_context and sibling_phases_context.sibling_phases: @@ -1906,10 +1965,14 @@ async def generate_brainstorm_prompt_plan( # 9. Save prompt plan as SpecVersion # Find next version number - existing_versions = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == brainstorming_phase_id, - SpecVersion.spec_type == SpecType.PROMPT_PLAN - ).count() + existing_versions = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == brainstorming_phase_id, + SpecVersion.spec_type == SpecType.PROMPT_PLAN, + ) + .count() + ) spec_version = SpecVersion( project_id=phase.project_id, @@ -1926,7 +1989,7 @@ async def generate_brainstorm_prompt_plan( # 10. 
Extract LLM usage stats before closing orchestrator llm_usage = None - if hasattr(orchestrator, 'model_client') and hasattr(orchestrator.model_client, 'get_usage_stats'): + if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"): usage_stats = orchestrator.model_client.get_usage_stats() llm_usage = { "model": usage_stats.get("model"), @@ -1980,10 +2043,14 @@ def get_pending_questions_count( from sqlalchemy import and_ # Get all modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == brainstorming_phase_id, + Module.archived_at.is_(None), + ) + .all() + ) module_ids = [m.id for m in modules] @@ -1997,45 +2064,61 @@ def get_pending_questions_count( } # Count pending questions - pending_questions = db.query(Feature).filter( - and_( - Feature.module_id.in_(module_ids), - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.PENDING, - Feature.archived_at.is_(None), + pending_questions = ( + db.query(Feature) + .filter( + and_( + Feature.module_id.in_(module_ids), + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.PENDING, + Feature.archived_at.is_(None), + ) ) - ).count() + .count() + ) # Count active questions - active_questions = db.query(Feature).filter( - and_( - Feature.module_id.in_(module_ids), - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None), + active_questions = ( + db.query(Feature) + .filter( + and_( + Feature.module_id.in_(module_ids), + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) ) - ).count() + .count() + ) # Count pending aspects (modules that only have pending questions) pending_aspects = 0 for module in modules: - module_active = db.query(Feature).filter( - and_( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None), + module_active = ( + db.query(Feature) + .filter( + and_( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) ) - ).count() + .count() + ) - module_pending = db.query(Feature).filter( - and_( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.PENDING, - Feature.archived_at.is_(None), + module_pending = ( + db.query(Feature) + .filter( + and_( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.PENDING, + Feature.archived_at.is_(None), + ) ) - ).count() + .count() + ) # Only count as pending aspect if it has no active questions but has pending ones if module_active == 0 and module_pending > 0: @@ -2066,7 +2149,7 @@ def _count_answered_questions( Returns: Count of answered ACTIVE questions """ - from sqlalchemy import and_, cast, String, text, func + from sqlalchemy import String, and_, cast, func, text # Determine the database dialect for JSON query syntax dialect_name = db.bind.dialect.name if db.bind 
else "postgresql" @@ -2089,7 +2172,7 @@ def _count_answered_questions( and_( Thread.context_type == ContextType.BRAINSTORM_FEATURE, uuid_compare, - ) + ), ) .join(ThreadItem, ThreadItem.thread_id == Thread.id) .filter( @@ -2105,14 +2188,10 @@ def _count_answered_questions( # Add JSON filter based on dialect if dialect_name == "sqlite": # SQLite uses json_extract - query = query.filter( - text("json_extract(thread_items.content_data, '$.selected_option_id') IS NOT NULL") - ) + query = query.filter(text("json_extract(thread_items.content_data, '$.selected_option_id') IS NOT NULL")) else: # PostgreSQL uses ->> operator - query = query.filter( - text("thread_items.content_data->>'selected_option_id' IS NOT NULL") - ) + query = query.filter(text("thread_items.content_data->>'selected_option_id' IS NOT NULL")) return query.distinct().count() @@ -2169,13 +2248,17 @@ def get_generation_preflight_status( - skip_code: str - Machine-readable code for logging - details: dict - Diagnostic counts """ - from sqlalchemy import and_, func + from sqlalchemy import func # Load phase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id, - BrainstormingPhase.archived_at.is_(None), - ).first() + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == brainstorming_phase_id, + BrainstormingPhase.archived_at.is_(None), + ) + .first() + ) if not phase: return { @@ -2206,9 +2289,7 @@ def get_generation_preflight_status( } # Get counts for remaining checks - counts = BrainstormingPhaseService.get_pending_questions_count( - db, brainstorming_phase_id - ) + counts = BrainstormingPhaseService.get_pending_questions_count(db, brainstorming_phase_id) pending_questions = counts["pending_questions"] active_questions = counts["active_questions"] @@ -2222,15 +2303,19 @@ def get_generation_preflight_status( } # Count active aspects (modules with at least one ACTIVE question) - active_aspects = db.query(func.count(func.distinct(Module.id))).join( - Feature, Feature.module_id == Module.id - ).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None), - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, - Feature.archived_at.is_(None), - ).scalar() or 0 + active_aspects = ( + db.query(func.count(func.distinct(Module.id))) + .join(Feature, Feature.module_id == Module.id) + .filter( + Module.brainstorming_phase_id == brainstorming_phase_id, + Module.archived_at.is_(None), + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.ACTIVE, + Feature.archived_at.is_(None), + ) + .scalar() + or 0 + ) # Check 4: Phase not fully explored (only if there are existing questions) if active_questions > 0: @@ -2253,9 +2338,7 @@ def get_generation_preflight_status( # Check 5: User engagement >= 15% (only if there are existing questions) if active_questions > 0: - answered_count = BrainstormingPhaseService._count_answered_questions( - db, brainstorming_phase_id - ) + answered_count = BrainstormingPhaseService._count_answered_questions(db, brainstorming_phase_id) answered_ratio = answered_count / active_questions if answered_ratio < 0.15: @@ -2297,39 +2380,56 @@ def get_pending_questions( from sqlalchemy import and_ # Get all modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None), - ).order_by(Module.order_index).all() + modules = ( + 
db.query(Module) + .filter( + Module.brainstorming_phase_id == brainstorming_phase_id, + Module.archived_at.is_(None), + ) + .order_by(Module.order_index) + .all() + ) aspects = [] for module in modules: # Get pending questions for this module - pending_features = db.query(Feature).filter( - and_( - Feature.module_id == module.id, - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.PENDING, - Feature.archived_at.is_(None), + pending_features = ( + db.query(Feature) + .filter( + and_( + Feature.module_id == module.id, + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.PENDING, + Feature.archived_at.is_(None), + ) ) - ).all() + .all() + ) if pending_features: questions = [] for feature in pending_features: # Get MCQ data from thread if it exists mcq_data = None - thread = db.query(Thread).filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id == str(feature.id), - ).first() + thread = ( + db.query(Thread) + .filter( + Thread.context_type == ContextType.BRAINSTORM_FEATURE, + Thread.context_id == str(feature.id), + ) + .first() + ) if thread: - mcq_item = db.query(ThreadItem).filter( - ThreadItem.thread_id == thread.id, - ThreadItem.item_type == ThreadItemType.MCQ_FOLLOWUP, - ).first() + mcq_item = ( + db.query(ThreadItem) + .filter( + ThreadItem.thread_id == thread.id, + ThreadItem.item_type == ThreadItemType.MCQ_FOLLOWUP, + ) + .first() + ) if mcq_item and mcq_item.content_data: mcq_data = { @@ -2338,26 +2438,28 @@ def get_pending_questions( "explanation": mcq_item.content_data.get("explanation"), } - questions.append({ - "id": str(feature.id), - "feature_key": feature.feature_key, - "title": feature.title, - "spec_text": feature.spec_text, - "priority": feature.priority.value if feature.priority else "important", - "category": feature.category, - "mcq": mcq_data, - }) - - aspects.append({ - "id": str(module.id), - "name": module.title, - "description": module.description, - "questions": questions, - }) - - counts = BrainstormingPhaseService.get_pending_questions_count( - db, brainstorming_phase_id - ) + questions.append( + { + "id": str(feature.id), + "feature_key": feature.feature_key, + "title": feature.title, + "spec_text": feature.spec_text, + "priority": feature.priority.value if feature.priority else "important", + "category": feature.category, + "mcq": mcq_data, + } + ) + + aspects.append( + { + "id": str(module.id), + "name": module.title, + "description": module.description, + "questions": questions, + } + ) + + counts = BrainstormingPhaseService.get_pending_questions_count(db, brainstorming_phase_id) return { "aspects": aspects, @@ -2393,12 +2495,9 @@ def activate_pending_questions( ValueError: If activating would exceed max questions limit """ from sqlalchemy import and_ - from app.services.agent_utils import get_or_create_agent_user # Get current counts - counts = BrainstormingPhaseService.get_pending_questions_count( - db, brainstorming_phase_id - ) + counts = BrainstormingPhaseService.get_pending_questions_count(db, brainstorming_phase_id) # Check if we can activate all requested questions available_slots = counts["max_questions"] - counts["active_questions"] @@ -2409,22 +2508,30 @@ def activate_pending_questions( ) # Get all modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + 
Module.brainstorming_phase_id == brainstorming_phase_id, + Module.archived_at.is_(None), + ) + .all() + ) module_ids = [m.id for m in modules] # Get the features to activate - features = db.query(Feature).filter( - and_( - Feature.id.in_(question_ids), - Feature.module_id.in_(module_ids), - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.PENDING, - Feature.archived_at.is_(None), + features = ( + db.query(Feature) + .filter( + and_( + Feature.id.in_(question_ids), + Feature.module_id.in_(module_ids), + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.PENDING, + Feature.archived_at.is_(None), + ) ) - ).all() + .all() + ) activated_count = 0 skipped_count = len(question_ids) - len(features) @@ -2441,9 +2548,7 @@ def activate_pending_questions( BrainstormingPhaseService.refresh_phase_question_stats(db, brainstorming_phase_id) # Get updated counts - updated_counts = BrainstormingPhaseService.get_pending_questions_count( - db, brainstorming_phase_id - ) + updated_counts = BrainstormingPhaseService.get_pending_questions_count(db, brainstorming_phase_id) return { "activated_count": activated_count, @@ -2477,21 +2582,29 @@ def dismiss_pending_questions( from sqlalchemy import and_ # Get all modules for this phase - modules = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id, - Module.archived_at.is_(None), - ).all() + modules = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == brainstorming_phase_id, + Module.archived_at.is_(None), + ) + .all() + ) module_ids = [m.id for m in modules] # Get the features to dismiss (only PENDING visibility, CONVERSATION type) - features = db.query(Feature).filter( - and_( - Feature.id.in_(question_ids), - Feature.module_id.in_(module_ids), - Feature.feature_type == FeatureType.CONVERSATION, - Feature.visibility_status == FeatureVisibilityStatus.PENDING, + features = ( + db.query(Feature) + .filter( + and_( + Feature.id.in_(question_ids), + Feature.module_id.in_(module_ids), + Feature.feature_type == FeatureType.CONVERSATION, + Feature.visibility_status == FeatureVisibilityStatus.PENDING, + ) ) - ).all() + .all() + ) dismissed_count = 0 skipped_count = len(question_ids) - len(features) @@ -2500,10 +2613,14 @@ def dismiss_pending_questions( if feature_ids_to_delete: # Delete associated threads and thread items # Thread items have cascade delete, but we need to find threads by context_id - threads = db.query(Thread).filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id.in_([str(fid) for fid in feature_ids_to_delete]), - ).all() + threads = ( + db.query(Thread) + .filter( + Thread.context_type == ContextType.BRAINSTORM_FEATURE, + Thread.context_id.in_([str(fid) for fid in feature_ids_to_delete]), + ) + .all() + ) for thread in threads: # Delete thread items first @@ -2518,9 +2635,13 @@ def dismiss_pending_questions( # Check if any modules are now empty (no features left) and delete them for module_id in set(f.module_id for f in features): - remaining_features = db.query(Feature).filter( - Feature.module_id == module_id, - ).count() + remaining_features = ( + db.query(Feature) + .filter( + Feature.module_id == module_id, + ) + .count() + ) if remaining_features == 0: # Delete the empty module db.query(Module).filter(Module.id == module_id).delete() diff --git a/backend/app/services/bug_sync_service.py b/backend/app/services/bug_sync_service.py index 2234cf2..58f0e15 100644 --- 
a/backend/app/services/bug_sync_service.py +++ b/backend/app/services/bug_sync_service.py @@ -1,4 +1,5 @@ """Bug sync service.""" + import json import logging from dataclasses import asdict @@ -26,9 +27,7 @@ def __init__(self, db: AsyncSession): self.db = db self.integration_service = IntegrationService(db) - async def sync_bug_ticket( - self, project_id: UUID, triggered_by: str = "system" - ) -> BugSyncHistory: + async def sync_bug_ticket(self, project_id: UUID, triggered_by: str = "system") -> BugSyncHistory: """Sync bug ticket data for a bugfix project. Args: @@ -50,15 +49,10 @@ async def sync_bug_ticket( raise ValueError(f"Project not found: {project_id}") if project.type != ProjectType.BUGFIX.value: - raise ValueError( - f"Project {project_id} is not a bugfix project (type: {project.type})" - ) + raise ValueError(f"Project {project_id} is not a bugfix project (type: {project.type})") if not project.external_ticket_id or not project.external_system: - raise ValueError( - f"Project {project_id} does not have external_ticket_id " - f"or external_system configured" - ) + raise ValueError(f"Project {project_id} does not have external_ticket_id or external_system configured") history = BugSyncHistory( project_id=project_id, @@ -68,9 +62,7 @@ async def sync_bug_ticket( try: # Get adapter for the external system - adapter = await self.integration_service.get_adapter( - project.org_id, project.external_system - ) + adapter = await self.integration_service.get_adapter(project.org_id, project.external_system) # Fetch ticket data ticket_data = await adapter.fetch_ticket(project.external_ticket_id) @@ -83,10 +75,7 @@ async def sync_bug_ticket( history.status = "success" history.imported_data_json = asdict(ticket_data) - logger.info( - f"Successfully synced ticket {project.external_ticket_id} " - f"for project {project_id}" - ) + logger.info(f"Successfully synced ticket {project.external_ticket_id} for project {project_id}") except Exception as e: logger.error( @@ -129,9 +118,7 @@ def _format_imported_data(self, ticket_data) -> str: author = comment.get("author", "unknown") body = comment.get("body", "") created = comment.get("created_at", "") - parts.append( - f"### Comment {i} by {author} ({created})\n\n{body}\n" - ) + parts.append(f"### Comment {i} by {author} ({created})\n\n{body}\n") if ticket_data.metadata: parts.append(f"## Metadata\n\n```json\n{json.dumps(ticket_data.metadata, indent=2)}\n```\n") diff --git a/backend/app/services/code_explorer_client.py b/backend/app/services/code_explorer_client.py index 9be0e0f..0df604f 100644 --- a/backend/app/services/code_explorer_client.py +++ b/backend/app/services/code_explorer_client.py @@ -110,12 +110,14 @@ async def explore( logger.info(f"Calling code-explorer for {len(repos)} repos mode={mode}") elif repo_url: # Convert legacy format to repos list - repos = [{ - "slug": repo_url.rstrip("/").split("/")[-1].replace(".git", ""), - "repo_url": repo_url, - "branch": branch, - "github_token": github_token, - }] + repos = [ + { + "slug": repo_url.rstrip("/").split("/")[-1].replace(".git", ""), + "repo_url": repo_url, + "branch": branch, + "github_token": github_token, + } + ] logger.info(f"Calling code-explorer for repo {repo_url} branch {branch} mode={mode}") else: raise ValueError("Either 'repos' or 'repo_url' must be provided") @@ -295,10 +297,7 @@ def _build_implementation_analysis_prompt( sibling_phases: list[dict], ) -> str: """Build the prompt for implementation analysis.""" - phase_list = "\n".join( - f"- {p['phase_title']} (ID: 
{p['phase_id']})" - for p in sibling_phases - ) + phase_list = "\n".join(f"- {p['phase_title']} (ID: {p['phase_id']})" for p in sibling_phases) return f"""Analyze the implementation files in this codebase that relate to these phases: @@ -344,13 +343,15 @@ def _parse_implementation_analysis( # Try to find content related to this phase phase_section = output - analyses.append(ImplementationAnalysis( - phase_id=phase_id, - phase_title=phase_title, - file_paths=[], # Would be populated by parsing file mentions - implementation_summary=phase_section[:500] if phase_section else output[:500], - key_patterns=[], # Would be populated by parsing pattern mentions - )) + analyses.append( + ImplementationAnalysis( + phase_id=phase_id, + phase_title=phase_title, + file_paths=[], # Would be populated by parsing file mentions + implementation_summary=phase_section[:500] if phase_section else output[:500], + key_patterns=[], # Would be populated by parsing pattern mentions + ) + ) return analyses diff --git a/backend/app/services/daily_usage_summary_service.py b/backend/app/services/daily_usage_summary_service.py index 068103d..335759b 100644 --- a/backend/app/services/daily_usage_summary_service.py +++ b/backend/app/services/daily_usage_summary_service.py @@ -4,17 +4,16 @@ This service provides operations for the daily_usage_summary table which stores pre-aggregated LLM usage data for efficient analytics. """ + import logging -from datetime import datetime, timezone, date, timedelta -from typing import Optional, List, Dict, Any +from datetime import date, datetime, timezone +from typing import Any, Dict, List, Optional from uuid import UUID -from decimal import Decimal +from sqlalchemy import func, inspect, text from sqlalchemy.orm import Session -from sqlalchemy import func, text, inspect -from sqlalchemy.dialects.postgresql import insert as pg_insert -from app.models.daily_usage_summary import DailyUsageSummary, SENTINEL_UUID +from app.models.daily_usage_summary import SENTINEL_UUID, DailyUsageSummary from app.models.llm_usage_log import LLMUsageLog logger = logging.getLogger(__name__) @@ -62,8 +61,7 @@ def validate_source_schema(db: Session) -> Dict[str, Any]: # Get types for required columns column_types = { - col: str(actual_columns.get(col, "MISSING")) - for col in DailyUsageSummaryService.REQUIRED_SOURCE_COLUMNS + col: str(actual_columns.get(col, "MISSING")) for col in DailyUsageSummaryService.REQUIRED_SOURCE_COLUMNS } return { @@ -115,14 +113,10 @@ def aggregate_for_date( # Use database-specific date extraction if dialect == "postgresql": # PostgreSQL: use timezone conversion - query = query.filter( - func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date - ) + query = query.filter(func.date(func.timezone("UTC", LLMUsageLog.created_at)) == target_date) else: # SQLite: use simple date function (created_at is already stored as UTC) - query = query.filter( - func.date(LLMUsageLog.created_at) == target_date - ) + query = query.filter(func.date(LLMUsageLog.created_at) == target_date) query = query.group_by( LLMUsageLog.org_id, @@ -167,32 +161,43 @@ def aggregate_for_date( updated_at = :updated_at """) - db.execute(upsert_sql, { - "org_id": row.org_id, - "user_id": row.user_id, - "project_id": row.project_id, - "date": target_date, - "total_tokens": int(row.total_tokens or 0), - "prompt_tokens": int(row.prompt_tokens or 0), - "completion_tokens": int(row.completion_tokens or 0), - "total_cost_usd": row.total_cost_usd, - "call_count": row.call_count, - "created_at": now, - "updated_at": 
now, - "sentinel_uuid": SENTINEL_UUID, - }) + db.execute( + upsert_sql, + { + "org_id": row.org_id, + "user_id": row.user_id, + "project_id": row.project_id, + "date": target_date, + "total_tokens": int(row.total_tokens or 0), + "prompt_tokens": int(row.prompt_tokens or 0), + "completion_tokens": int(row.completion_tokens or 0), + "total_cost_usd": row.total_cost_usd, + "call_count": row.call_count, + "created_at": now, + "updated_at": now, + "sentinel_uuid": SENTINEL_UUID, + }, + ) else: # SQLite: Use manual check-and-insert/update pattern from uuid import uuid4 # Check for existing record - existing = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == row.org_id, - DailyUsageSummary.date == target_date, - # Handle NULL comparisons properly - (DailyUsageSummary.user_id == row.user_id) if row.user_id else DailyUsageSummary.user_id.is_(None), - (DailyUsageSummary.project_id == row.project_id) if row.project_id else DailyUsageSummary.project_id.is_(None), - ).first() + existing = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == row.org_id, + DailyUsageSummary.date == target_date, + # Handle NULL comparisons properly + (DailyUsageSummary.user_id == row.user_id) + if row.user_id + else DailyUsageSummary.user_id.is_(None), + (DailyUsageSummary.project_id == row.project_id) + if row.project_id + else DailyUsageSummary.project_id.is_(None), + ) + .first() + ) if existing: # Update existing record @@ -285,17 +290,21 @@ def get_org_totals_for_period( Dict with total_tokens, prompt_tokens, completion_tokens, total_cost_usd, call_count """ - result = db.query( - func.coalesce(func.sum(DailyUsageSummary.total_tokens), 0).label("total_tokens"), - func.coalesce(func.sum(DailyUsageSummary.prompt_tokens), 0).label("prompt_tokens"), - func.coalesce(func.sum(DailyUsageSummary.completion_tokens), 0).label("completion_tokens"), - func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), - func.coalesce(func.sum(DailyUsageSummary.call_count), 0).label("call_count"), - ).filter( - DailyUsageSummary.org_id == org_id, - DailyUsageSummary.date >= start_date, - DailyUsageSummary.date <= end_date, - ).first() + result = ( + db.query( + func.coalesce(func.sum(DailyUsageSummary.total_tokens), 0).label("total_tokens"), + func.coalesce(func.sum(DailyUsageSummary.prompt_tokens), 0).label("prompt_tokens"), + func.coalesce(func.sum(DailyUsageSummary.completion_tokens), 0).label("completion_tokens"), + func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), + func.coalesce(func.sum(DailyUsageSummary.call_count), 0).label("call_count"), + ) + .filter( + DailyUsageSummary.org_id == org_id, + DailyUsageSummary.date >= start_date, + DailyUsageSummary.date <= end_date, + ) + .first() + ) return { "total_tokens": int(result.total_tokens), @@ -324,20 +333,22 @@ def get_usage_by_user( Returns: List of dicts with user_id, total_tokens, total_cost_usd, call_count """ - results = db.query( - DailyUsageSummary.user_id, - func.sum(DailyUsageSummary.total_tokens).label("total_tokens"), - func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), - func.sum(DailyUsageSummary.call_count).label("call_count"), - ).filter( - DailyUsageSummary.org_id == org_id, - DailyUsageSummary.date >= start_date, - DailyUsageSummary.date <= end_date, - ).group_by( - DailyUsageSummary.user_id - ).order_by( - func.sum(DailyUsageSummary.total_tokens).desc() - ).all() + results = ( + db.query( + DailyUsageSummary.user_id, + func.sum(DailyUsageSummary.total_tokens).label("total_tokens"), + 
func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), + func.sum(DailyUsageSummary.call_count).label("call_count"), + ) + .filter( + DailyUsageSummary.org_id == org_id, + DailyUsageSummary.date >= start_date, + DailyUsageSummary.date <= end_date, + ) + .group_by(DailyUsageSummary.user_id) + .order_by(func.sum(DailyUsageSummary.total_tokens).desc()) + .all() + ) return [ { @@ -368,20 +379,22 @@ def get_usage_by_project( Returns: List of dicts with project_id, total_tokens, total_cost_usd, call_count """ - results = db.query( - DailyUsageSummary.project_id, - func.sum(DailyUsageSummary.total_tokens).label("total_tokens"), - func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), - func.sum(DailyUsageSummary.call_count).label("call_count"), - ).filter( - DailyUsageSummary.org_id == org_id, - DailyUsageSummary.date >= start_date, - DailyUsageSummary.date <= end_date, - ).group_by( - DailyUsageSummary.project_id - ).order_by( - func.sum(DailyUsageSummary.total_tokens).desc() - ).all() + results = ( + db.query( + DailyUsageSummary.project_id, + func.sum(DailyUsageSummary.total_tokens).label("total_tokens"), + func.sum(DailyUsageSummary.total_cost_usd).label("total_cost_usd"), + func.sum(DailyUsageSummary.call_count).label("call_count"), + ) + .filter( + DailyUsageSummary.org_id == org_id, + DailyUsageSummary.date >= start_date, + DailyUsageSummary.date <= end_date, + ) + .group_by(DailyUsageSummary.project_id) + .order_by(func.sum(DailyUsageSummary.total_tokens).desc()) + .all() + ) return [ { @@ -410,9 +423,7 @@ def delete_summaries_before_date( Returns: Number of records deleted """ - query = db.query(DailyUsageSummary).filter( - DailyUsageSummary.date < cutoff_date - ) + query = db.query(DailyUsageSummary).filter(DailyUsageSummary.date < cutoff_date) if org_id: query = query.filter(DailyUsageSummary.org_id == org_id) diff --git a/backend/app/services/dashboard_service.py b/backend/app/services/dashboard_service.py index 808763b..f227b24 100644 --- a/backend/app/services/dashboard_service.py +++ b/backend/app/services/dashboard_service.py @@ -4,13 +4,13 @@ This service provides organization-scoped dashboard metrics including user counts, project counts, and LLM usage summaries. 
""" + from datetime import datetime, timezone -from typing import Dict, Any, List +from typing import Any, Dict from uuid import UUID from sqlalchemy.orm import Session -from app.models.organization import Organization from app.models.org_membership import OrgMembership from app.models.project import Project from app.models.user import User @@ -42,11 +42,15 @@ def get_org_stats(db: Session, org_id: UUID) -> Dict[str, int]: ) # Count active projects (non-archived, non-deleted) - project_count = db.query(Project).filter( - Project.org_id == org_id, - Project.deleted_at.is_(None), - Project.status != "archived", - ).count() + project_count = ( + db.query(Project) + .filter( + Project.org_id == org_id, + Project.deleted_at.is_(None), + Project.status != "archived", + ) + .count() + ) return { "user_count": user_count, @@ -117,14 +121,10 @@ def get_dashboard_data( # Get current month usage now = datetime.now(timezone.utc) - monthly_usage = LLMUsageLogService.get_monthly_usage( - db, org_id, now.year, now.month - ) + monthly_usage = LLMUsageLogService.get_monthly_usage(db, org_id, now.year, now.month) # Get recent calls and transform to include user name - recent_logs = LLMUsageLogService.list_recent_usage( - db, org_id, limit=recent_calls_limit - ) + recent_logs = LLMUsageLogService.list_recent_usage(db, org_id, limit=recent_calls_limit) recent_calls = [] for log in recent_logs: call_dict = { @@ -144,9 +144,7 @@ def get_dashboard_data( recent_calls.append(call_dict) # Get plan info - plan_info = DashboardService.get_plan_info( - db, org_id, stats["user_count"], stats["project_count"] - ) + plan_info = DashboardService.get_plan_info(db, org_id, stats["user_count"], stats["project_count"]) return { "user_count": stats["user_count"], diff --git a/backend/app/services/draft_version_service.py b/backend/app/services/draft_version_service.py index 6e908a8..4814084 100644 --- a/backend/app/services/draft_version_service.py +++ b/backend/app/services/draft_version_service.py @@ -1,10 +1,12 @@ """Service for managing draft versions of specs and prompt plans.""" -from typing import Optional, List, Any + +from typing import List, Optional from uuid import UUID -from sqlalchemy.orm import Session + from sqlalchemy import func +from sqlalchemy.orm import Session -from app.models.spec_version import SpecVersion, SpecType +from app.models.spec_version import SpecType, SpecVersion class DraftVersionService: @@ -105,9 +107,7 @@ def _create_draft( The created SpecVersion """ # Get next version number - next_version = DraftVersionService._get_next_version_number( - db, brainstorming_phase_id, spec_type - ) + next_version = DraftVersionService._get_next_version_number(db, brainstorming_phase_id, spec_type) # Deactivate previous active drafts db.query(SpecVersion).filter( diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py index 7699766..f970c56 100644 --- a/backend/app/services/email_service.py +++ b/backend/app/services/email_service.py @@ -37,6 +37,7 @@ async def _get_template_service(self): """Get or create the template service (lazy-loaded).""" if self._template_service is None: from app.services.email_template_service import EmailTemplateService + self._template_service = EmailTemplateService(self.db) return self._template_service @@ -122,13 +123,9 @@ async def get_email_config(self) -> tuple[str | None, dict | None, str | None]: # 1. 
Try UI-configured connector first if platform_settings.email_connector_id: - connector = await self._settings_service.get_connector( - platform_settings.email_connector_id - ) + connector = await self._settings_service.get_connector(platform_settings.email_connector_id) if connector and connector.is_active: - api_key = self._settings_service._decrypt_credentials( - connector.encrypted_credentials - ) + api_key = self._settings_service._decrypt_credentials(connector.encrypted_credentials) config = connector.config_json or {} return api_key, config, None @@ -649,9 +646,7 @@ async def _send_via_sendgrid( ) else: error_body = response.text - logger.warning( - f"Sendgrid returned status {response.status_code}: {error_body}" - ) + logger.warning(f"Sendgrid returned status {response.status_code}: {error_body}") return EmailSendResult( success=False, message=f"Sendgrid error: {response.status_code}", diff --git a/backend/app/services/email_template_service.py b/backend/app/services/email_template_service.py index 6cfebfa..23b6fe2 100644 --- a/backend/app/services/email_template_service.py +++ b/backend/app/services/email_template_service.py @@ -1,4 +1,5 @@ """Email template service for managing and rendering templates.""" + import html import logging import re @@ -61,9 +62,7 @@ async def get_template_by_id(self, template_id: UUID) -> EmailTemplate | None: result = await self.db.execute(stmt) return result.scalar_one_or_none() - async def get_template_by_key( - self, key: str | EmailTemplateKey - ) -> EmailTemplate | None: + async def get_template_by_key(self, key: str | EmailTemplateKey) -> EmailTemplate | None: """Get a template by its unique key. Args: @@ -128,21 +127,13 @@ async def update_template( # If updating body, validate mandatory variables new_body = body_markdown if body_markdown is not None else template.body_markdown - new_subject = ( - subject_template if subject_template is not None else template.subject_template - ) - new_mandatory = ( - mandatory_variables - if mandatory_variables is not None - else template.mandatory_variables - ) + new_subject = subject_template if subject_template is not None else template.subject_template + new_mandatory = mandatory_variables if mandatory_variables is not None else template.mandatory_variables # Validate that all mandatory variables are present validation = self.validate_template(new_subject, new_body, new_mandatory) if not validation["valid"]: - raise ValueError( - f"Template is missing mandatory variables: {', '.join(validation['missing_variables'])}" - ) + raise ValueError(f"Template is missing mandatory variables: {', '.join(validation['missing_variables'])}") if display_name is not None: template.display_name = display_name @@ -192,9 +183,7 @@ def validate_template( return { "valid": len(missing) == 0, "missing_variables": missing, - "message": ( - "Template is valid" if len(missing) == 0 else f"Missing: {', '.join(missing)}" - ), + "message": ("Template is valid" if len(missing) == 0 else f"Missing: {', '.join(missing)}"), } def extract_variables(self, content: str) -> list[str]: @@ -243,9 +232,7 @@ def render_template( return subject, html_body - def _substitute_variables( - self, template: str, variables: dict[str, str], escape_html: bool = True - ) -> str: + def _substitute_variables(self, template: str, variables: dict[str, str], escape_html: bool = True) -> str: """Substitute {{variable}} placeholders with values. 

         Args:
@@ -403,9 +390,7 @@ def _parse_inline_formatting(self, text: str) -> list[dict]:
             if part.startswith("**") and part.endswith("**"):
                 # Bold text
-                content.append(
-                    {"type": "text", "text": part[2:-2], "marks": [{"type": "bold"}]}
-                )
+                content.append({"type": "text", "text": part[2:-2], "marks": [{"type": "bold"}]})
             elif part.startswith("[") and "](" in part:
                 # Link
                 match = re.match(r"\[([^\]]+)\]\(([^)]+)\)", part)
@@ -419,9 +404,7 @@ def _parse_inline_formatting(self, text: str) -> list[dict]:
                 )
             elif part.startswith("{{") and part.endswith("}}"):
                 # Variable - render with special styling
-                content.append(
-                    {"type": "text", "text": part, "marks": [{"type": "code"}]}
-                )
+                content.append({"type": "text", "text": part, "marks": [{"type": "code"}]})
             else:
                 content.append({"type": "text", "text": part})
diff --git a/backend/app/services/feature_content_version_service.py b/backend/app/services/feature_content_version_service.py
index 330df46..b7b47e5 100644
--- a/backend/app/services/feature_content_version_service.py
+++ b/backend/app/services/feature_content_version_service.py
@@ -1,13 +1,14 @@
 """Service for managing feature content versions."""
+
 import logging
-from typing import Optional, List
+from typing import List, Optional
 from uuid import UUID
-from datetime import datetime, timezone
-from sqlalchemy.orm import Session, joinedload
+
 from sqlalchemy import func
+from sqlalchemy.orm import Session, joinedload

-from app.models.feature_content_version import FeatureContentVersion, FeatureContentType
 from app.models.feature import Feature
+from app.models.feature_content_version import FeatureContentType, FeatureContentVersion
 from app.models.module import Module
 from app.models.project import Project
 from app.services.kafka_producer import get_sync_kafka_producer
@@ -48,9 +49,7 @@ def create_version(
             The created FeatureContentVersion
         """
         # Get next version number
-        next_version = FeatureContentVersionService._get_next_version_number(
-            db, feature_id, content_type
-        )
+        next_version = FeatureContentVersionService._get_next_version_number(db, feature_id, content_type)

         # Deactivate previous active versions
         db.query(FeatureContentVersion).filter(
@@ -259,15 +258,17 @@ def broadcast_batch_version_updates(
         # Build version data array
         versions_data = []
         for version in versions:
-            versions_data.append({
-                "id": str(version.id),
-                "feature_id": str(version.feature_id),
-                "content_type": version.content_type.value,
-                "version": version.version,
-                "edit_source": version.edit_source,
-                "creator_display_name": version.creator.display_name if version.creator else None,
-                "created_at": version.created_at.isoformat(),
-            })
+            versions_data.append(
+                {
+                    "id": str(version.id),
+                    "feature_id": str(version.feature_id),
+                    "content_type": version.content_type.value,
+                    "version": version.version,
+                    "edit_source": version.edit_source,
+                    "creator_display_name": version.creator.display_name if version.creator else None,
+                    "created_at": version.created_at.isoformat(),
+                }
+            )

         message = {
             "type": "feature_content_versions_batch_created",
@@ -283,9 +284,7 @@ def broadcast_batch_version_updates(
             )

             if success:
-                logger.info(
-                    f"Broadcasted batch version update via Kafka: {len(versions)} versions"
-                )
+                logger.info(f"Broadcasted batch version update via Kafka: {len(versions)} versions")

         except Exception as e:
             logger.error(f"Failed to broadcast batch version update: {e}", exc_info=True)
diff --git a/backend/app/services/feature_import_service.py b/backend/app/services/feature_import_service.py
index 260efdf..c585a9b 100644
--- a/backend/app/services/feature_import_service.py
+++ b/backend/app/services/feature_import_service.py
@@ -1,10 +1,10 @@
 """Service for importing external issues as features."""
+
 import logging
 from datetime import datetime, timezone
 from typing import Optional
 from uuid import UUID

-from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session

@@ -12,17 +12,16 @@
 from app.integrations.factory import get_adapter
 from app.models.feature import (
     Feature,
+    FeaturePriority,
     FeatureProvenance,
     FeatureStatus,
-    FeaturePriority,
     FeatureType,
 )
-from app.models.feature_import_comment import FeatureImportComment
 from app.models.feature_content_version import FeatureContentType
-from app.models.integration_config import IntegrationConfig
+from app.models.feature_import_comment import FeatureImportComment
 from app.models.module import Module
-from app.services.feature_service import FeatureService
 from app.services.feature_content_version_service import FeatureContentVersionService
+from app.services.feature_service import FeatureService
 from app.services.integration_service import IntegrationService

 logger = logging.getLogger(__name__)
@@ -66,9 +65,7 @@ async def search_issues(
             raise ValueError(f"Connector {connector_id} not found")

         if config.provider != provider:
-            raise ValueError(
-                f"Connector provider mismatch: expected {provider}, got {config.provider}"
-            )
+            raise ValueError(f"Connector provider mismatch: expected {provider}, got {config.provider}")

         # Get decrypted token and create adapter
         token = integration_service._decrypt_token(config.encrypted_token)
@@ -128,10 +125,7 @@ async def import_issue(
         )

         if existing:
-            raise ValueError(
-                f"This issue is already imported as feature '{existing.title}' "
-                f"({existing.feature_key})"
-            )
+            raise ValueError(f"This issue is already imported as feature '{existing.title}' ({existing.feature_key})")

         # Get decrypted token and create adapter
         token = integration_service._decrypt_token(config.encrypted_token)
@@ -141,9 +135,7 @@ async def import_issue(
         ticket_data = await adapter.fetch_ticket(external_id)

         # Normalize state category
-        state_category = FeatureImportService._normalize_state_category(
-            provider, ticket_data.state
-        )
+        state_category = FeatureImportService._normalize_state_category(provider, ticket_data.state)

         # Extract author from metadata
         author = ticket_data.metadata.get("author", "unknown")
@@ -194,9 +186,7 @@ async def import_issue(
                     feature_id=feature.id,
                     author=comment.get("author", "unknown"),
                     body_markdown=comment.get("body", ""),
-                    source_created_at=FeatureImportService._parse_timestamp(
-                        comment.get("created_at")
-                    ),
+                    source_created_at=FeatureImportService._parse_timestamp(comment.get("created_at")),
                     order_index=idx,
                 )
                 sync_db.add(import_comment)
@@ -221,9 +211,7 @@ async def import_issue(
             # Log but don't fail the import if version creation fails
             logger.warning(f"Failed to create spec version for {feature.feature_key}: {e}")

-        logger.info(
-            f"Imported issue {external_id} from {provider} as feature {feature.feature_key}"
-        )
+        logger.info(f"Imported issue {external_id} from {provider} as feature {feature.feature_key}")

         return feature
diff --git a/backend/app/services/feature_service.py b/backend/app/services/feature_service.py
index 74d14d3..0febb2b 100644
--- a/backend/app/services/feature_service.py
+++ b/backend/app/services/feature_service.py
@@ -1,16 +1,24 @@
 """Service for managing features."""
-from typing import Optional, List, Tuple
-from uuid import UUID
+
 from datetime import datetime, timezone
+from typing import List, Optional, Tuple
+from uuid import UUID
+
+from sqlalchemy import asc, case, desc, func
 from sqlalchemy.orm import Session
-from sqlalchemy import func, case, asc, desc
-from sqlalchemy.exc import IntegrityError

-from app.config import settings
-from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeaturePriority, FeatureType, FeatureCompletionStatus, FeatureVisibilityStatus
+from app.models.feature import (
+    Feature,
+    FeatureCompletionStatus,
+    FeaturePriority,
+    FeatureProvenance,
+    FeatureStatus,
+    FeatureType,
+    FeatureVisibilityStatus,
+)
 from app.models.module import Module
 from app.models.project import Project
-from app.services.activity_log_service import ActivityLogService, ActivityEventTypes
+from app.services.activity_log_service import ActivityEventTypes, ActivityLogService


 class FeatureService:
@@ -127,11 +135,7 @@ def get_by_identifier(db: Session, identifier: str) -> Optional[Feature]:
         # Extract short_id and query
         short_id = extract_short_id(identifier)
-        return (
-            db.query(Feature)
-            .filter(Feature.short_id == short_id)
-            .first()
-        )
+        return db.query(Feature).filter(Feature.short_id == short_id).first()

     @staticmethod
     def list_features(
@@ -251,7 +255,7 @@ def list_features(
                 (Feature.priority == FeaturePriority.MUST_HAVE, 1),
                 (Feature.priority == FeaturePriority.IMPORTANT, 2),
                 (Feature.priority == FeaturePriority.OPTIONAL, 3),
-                else_=4
+                else_=4,
             )
             query = query.order_by(order_func(priority_order))
         elif sort_by == "completion_status":
@@ -260,7 +264,7 @@ def list_features(
                 (Feature.completion_status == FeatureCompletionStatus.PENDING, 1),
                 (Feature.completion_status == FeatureCompletionStatus.IN_PROGRESS, 2),
                 (Feature.completion_status == FeatureCompletionStatus.COMPLETED, 3),
-                else_=4
+                else_=4,
             )
             query = query.order_by(order_func(status_order))
         else:
@@ -363,10 +367,7 @@ def generate_feature_key(
         # Use MAX instead of COUNT for robustness (handles deleted features)
         # We need to join through modules to get all features for a project
         max_number = (
-            db.query(func.max(Feature.feature_key_number))
-            .join(Module)
-            .filter(Module.project_id == project_id)
-            .scalar()
+            db.query(func.max(Feature.feature_key_number)).join(Module).filter(Module.project_id == project_id).scalar()
         ) or 0

         next_number = max_number + 1
@@ -420,20 +421,13 @@ def archive_system_features(
         ids_to_archive = [fid[0] for fid in feature_ids]

         # Get feature details before archiving (for activity logging)
-        features_to_archive = (
-            db.query(Feature)
-            .filter(Feature.id.in_(ids_to_archive))
-            .all()
-        )
+        features_to_archive = db.query(Feature).filter(Feature.id.in_(ids_to_archive)).all()

         # Update features by ID
         result = (
             db.query(Feature)
             .filter(Feature.id.in_(ids_to_archive))
-            .update(
-                {"status": FeatureStatus.ARCHIVED, "archived_at": now},
-                synchronize_session="fetch"
-            )
+            .update({"status": FeatureStatus.ARCHIVED, "archived_at": now}, synchronize_session="fetch")
         )
         db.commit()
@@ -624,8 +618,9 @@ def _broadcast_feature_update(db: Session, feature: Feature, update_type: str):
             update_type: Type of update (completion_status, completion_summary, notes)
         """
         import logging
-        from app.services.kafka_producer import get_sync_kafka_producer
+
         from app.schemas.feature import FeatureListResponse
+        from app.services.kafka_producer import get_sync_kafka_producer

         logger = logging.getLogger(__name__)
@@ -650,22 +645,15 @@ def _broadcast_feature_update(db: Session, feature: Feature, update_type: str):
         feature_data = FeatureListResponse.from_feature(feature).model_dump(mode="json")

-        impl_stats = db.query(
-            func.max(case(
-                (Implementation.spec_text.isnot(None), 1),
-                else_=0
-            )).label('has_spec'),
-            func.max(case(
-                (Implementation.prompt_plan_text.isnot(None), 1),
-                else_=0
-            )).label('has_prompt_plan'),
-            func.max(case(
-                (Implementation.implementation_notes.isnot(None), 1),
-                else_=0
-            )).label('has_notes'),
-        ).filter(
-            Implementation.feature_id == feature.id
-        ).first()
+        impl_stats = (
+            db.query(
+                func.max(case((Implementation.spec_text.isnot(None), 1), else_=0)).label("has_spec"),
+                func.max(case((Implementation.prompt_plan_text.isnot(None), 1), else_=0)).label("has_prompt_plan"),
+                func.max(case((Implementation.implementation_notes.isnot(None), 1), else_=0)).label("has_notes"),
+            )
+            .filter(Implementation.feature_id == feature.id)
+            .first()
+        )

         if impl_stats:
             feature_data["has_spec"] = feature_data["has_spec"] or bool(impl_stats.has_spec)
@@ -690,10 +678,7 @@ def _broadcast_feature_update(db: Session, feature: Feature, update_type: str):
             )

             if success:
-                logger.info(
-                    f"Broadcasted feature update via Kafka: feature_id={feature.id}, "
-                    f"update_type={update_type}"
-                )
+                logger.info(f"Broadcasted feature update via Kafka: feature_id={feature.id}, update_type={update_type}")

         except Exception as e:
             logger.error(f"Failed to broadcast feature update: {e}", exc_info=True)
diff --git a/backend/app/services/finalization_service.py b/backend/app/services/finalization_service.py
index 6b1603b..05233f6 100644
--- a/backend/app/services/finalization_service.py
+++ b/backend/app/services/finalization_service.py
@@ -1,11 +1,13 @@
 """Service for finalizing specs and prompt plans."""
+
 from typing import Optional
 from uuid import UUID
+
 from sqlalchemy.orm import Session

-from app.models.spec_version import SpecVersion, SpecType
-from app.models.final_spec import FinalSpec
 from app.models.final_prompt_plan import FinalPromptPlan
+from app.models.final_spec import FinalSpec
+from app.models.spec_version import SpecType, SpecVersion
@@ -43,9 +45,7 @@ def generate_final_spec(
             raise ValueError(f"Draft version {draft_version_id} is not a specification")

         # Delete any existing final spec for this phase
-        db.query(FinalSpec).filter(
-            FinalSpec.brainstorming_phase_id == draft.brainstorming_phase_id
-        ).delete()
+        db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == draft.brainstorming_phase_id).delete()

         # Create new final spec
         final_spec = FinalSpec(
@@ -123,9 +123,7 @@ def get_final_spec(
         Returns:
             The FinalSpec if found, None otherwise
         """
-        return db.query(FinalSpec).filter(
-            FinalSpec.brainstorming_phase_id == brainstorming_phase_id
-        ).first()
+        return db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == brainstorming_phase_id).first()

     @staticmethod
     def get_final_prompt_plan(
@@ -141,9 +139,9 @@ def get_final_prompt_plan(
         Returns:
             The FinalPromptPlan if found, None otherwise
         """
-        return db.query(FinalPromptPlan).filter(
-            FinalPromptPlan.brainstorming_phase_id == brainstorming_phase_id
-        ).first()
+        return (
+            db.query(FinalPromptPlan).filter(FinalPromptPlan.brainstorming_phase_id == brainstorming_phase_id).first()
+        )

     @staticmethod
     def _synthesize_from_threads(
diff --git a/backend/app/services/github_integration_oauth_service.py b/backend/app/services/github_integration_oauth_service.py
index 0d46cbb..f59e279 100644
--- a/backend/app/services/github_integration_oauth_service.py
+++ b/backend/app/services/github_integration_oauth_service.py
@@ -58,10 +58,7 @@ def is_configured() -> bool:
         Returns:
             True if both client ID and secret are set in ENV
         """
-        return bool(
-            settings.github_integration_oauth_client_id
-            and settings.github_integration_oauth_client_secret
-        )
+        return bool(settings.github_integration_oauth_client_id and settings.github_integration_oauth_client_secret)

     @staticmethod
     def is_configured_with_db(db: Session) -> bool:
@@ -102,8 +99,7 @@ def _get_credentials_sync(self, sync_db: Session) -> tuple[str, str]:
         if not client_id or not client_secret:
             raise ValueError(
-                "GitHub OAuth is not configured. "
-                "Please configure it in Platform Settings or via environment variables."
+                "GitHub OAuth is not configured. Please configure it in Platform Settings or via environment variables."
             )

         self._cached_credentials = (client_id, client_secret)
@@ -143,9 +139,7 @@ async def create_state(

         return state_token

-    async def validate_and_consume_state(
-        self, state_token: str
-    ) -> GitHubOAuthState | None:
+    async def validate_and_consume_state(self, state_token: str) -> GitHubOAuthState | None:
         """Validate state token and consume it (single-use).

         Args:
@@ -263,9 +257,7 @@ async def cleanup_expired_states(self) -> int:
         Returns:
             Number of expired states deleted
         """
-        stmt = delete(GitHubOAuthState).where(
-            GitHubOAuthState.expires_at <= datetime.now(timezone.utc)
-        )
+        stmt = delete(GitHubOAuthState).where(GitHubOAuthState.expires_at <= datetime.now(timezone.utc))
         result = await self.db.execute(stmt)
         await self.db.commit()
         return result.rowcount
diff --git a/backend/app/services/grounding_note_service.py b/backend/app/services/grounding_note_service.py
index 1cc4c8e..9a8005a 100644
--- a/backend/app/services/grounding_note_service.py
+++ b/backend/app/services/grounding_note_service.py
@@ -66,9 +66,7 @@ def create_version(
         db.commit()
         db.refresh(new_version)

-        logger.info(
-            f"Created grounding note version {next_version} for project {project_id}"
-        )
+        logger.info(f"Created grounding note version {next_version} for project {project_id}")
         return new_version

     @staticmethod
@@ -131,8 +129,4 @@ def get_version(
         Returns:
             The GroundingNoteVersion or None if not found
         """
-        return (
-            db.query(GroundingNoteVersion)
-            .filter(GroundingNoteVersion.id == version_id)
-            .first()
-        )
+        return db.query(GroundingNoteVersion).filter(GroundingNoteVersion.id == version_id).first()
diff --git a/backend/app/services/grounding_service.py b/backend/app/services/grounding_service.py
index 17752d4..df8cb67 100644
--- a/backend/app/services/grounding_service.py
+++ b/backend/app/services/grounding_service.py
@@ -1,16 +1,15 @@
 """Service for managing grounding files for coding agent warm starts."""

 import os
-from typing import Optional, List
-from uuid import UUID
 from datetime import datetime, timezone
+from typing import List, Optional
+from uuid import UUID
+
 from sqlalchemy.orm import Session

-from app.config import settings
 from app.models.grounding_file import GroundingFile
 from app.models.grounding_file_branch import GroundingFileBranch

-
 # Template for agents.md - the main grounding file
 AGENTS_MD_TEMPLATE = """## Architecture
@@ -198,9 +197,7 @@ def create_file(
             ValueError: If filename has invalid extension or already exists
         """
         if not GroundingService.is_valid_filename(filename):
-            raise ValueError(
-                f"Invalid file extension. Allowed: {', '.join(ALLOWED_EXTENSIONS)}"
-            )
+            raise ValueError(f"Invalid file extension. Allowed: {', '.join(ALLOWED_EXTENSIONS)}")

         # Check if file already exists
         existing = GroundingService.get_file(db, project_id, filename)
@@ -364,8 +361,9 @@ def _broadcast_grounding_update(
         update_type: Type of update (created, written, appended, deleted)
         """
         import logging
-        from app.services.kafka_producer import get_sync_kafka_producer
+
+        from app.models.project import Project
+        from app.services.kafka_producer import get_sync_kafka_producer

         logger = logging.getLogger(__name__)
@@ -373,9 +371,7 @@ def _broadcast_grounding_update(
         # Get project to find org_id
         project = db.query(Project).filter(Project.id == project_id).first()
         if not project:
-            logger.warning(
-                f"Cannot broadcast grounding update: project {project_id} not found"
-            )
+            logger.warning(f"Cannot broadcast grounding update: project {project_id} not found")
             return

         # Serialize grounding file to JSON-compatible format
@@ -475,9 +471,7 @@ def get_or_create_branch_file(
         Returns:
             The GroundingFileBranch (existing or newly created)
         """
-        existing = GroundingService.get_branch_file(
-            db, project_id, user_id, branch_name, filename
-        )
+        existing = GroundingService.get_branch_file(db, project_id, user_id, branch_name, filename)
         if existing:
             return existing
@@ -601,9 +595,7 @@ def update_branch_summary(
         Returns:
             The updated GroundingFileBranch, or None if not found
         """
-        branch_file = GroundingService.get_branch_file(
-            db, project_id, user_id, branch_name, filename
-        )
+        branch_file = GroundingService.get_branch_file(db, project_id, user_id, branch_name, filename)
         if not branch_file:
             return None
@@ -636,9 +628,7 @@ def mark_branch_merged(
         Returns:
             The updated GroundingFileBranch, or None if not found
         """
-        branch_file = GroundingService.get_branch_file(
-            db, project_id, user_id, branch_name, filename
-        )
+        branch_file = GroundingService.get_branch_file(db, project_id, user_id, branch_name, filename)
         if not branch_file:
             return None
@@ -688,8 +678,9 @@ def _broadcast_branch_grounding_update(
         update_type: Type of update (created, written, appended, merged)
         """
         import logging
-        from app.services.kafka_producer import get_sync_kafka_producer
+
+        from app.models.project import Project
+        from app.services.kafka_producer import get_sync_kafka_producer

         logger = logging.getLogger(__name__)
@@ -697,9 +688,7 @@ def _broadcast_branch_grounding_update(
         # Get project to find org_id
         project = db.query(Project).filter(Project.id == project_id).first()
         if not project:
-            logger.warning(
-                f"Cannot broadcast branch grounding update: project {project_id} not found"
-            )
+            logger.warning(f"Cannot broadcast branch grounding update: project {project_id} not found")
             return

         # Serialize branch file to JSON-compatible format
@@ -712,11 +701,7 @@ def _broadcast_branch_grounding_update(
             "repo_path": branch_file.repo_path,
             "is_merged": branch_file.is_merged,
             "is_merging": branch_file.is_merging,
-            "merged_at": (
-                branch_file.merged_at.isoformat()
-                if branch_file.merged_at
-                else None
-            ),
+            "merged_at": (branch_file.merged_at.isoformat() if branch_file.merged_at else None),
             "last_synced_with_global_at": (
                 branch_file.last_synced_with_global_at.isoformat()
                 if branch_file.last_synced_with_global_at
@@ -752,6 +737,4 @@ def _broadcast_branch_grounding_update(
             )

         except Exception as e:
-            logger.error(
-                f"Failed to broadcast branch grounding update: {e}", exc_info=True
-            )
+            logger.error(f"Failed to broadcast branch grounding update: {e}", exc_info=True)
diff --git a/backend/app/services/image_service.py b/backend/app/services/image_service.py
index 048908b..12d3a1f 100644
--- a/backend/app/services/image_service.py
+++ b/backend/app/services/image_service.py
@@ -7,10 +7,9 @@
 import logging
 import time
 import uuid
-from dataclasses import dataclass, asdict
+from dataclasses import asdict, dataclass
 from datetime import datetime, timezone
 from io import BytesIO
-from typing import BinaryIO
 from uuid import UUID

 import boto3
@@ -89,18 +88,14 @@ async def _get_s3_config(self) -> tuple[dict | None, dict | None, str | None]:
         if not settings.object_storage_connector_id:
             return None, None, "No object storage connector configured"

-        connector = await self._settings_service.get_connector(
-            settings.object_storage_connector_id
-        )
+        connector = await self._settings_service.get_connector(settings.object_storage_connector_id)
         if not connector:
             return None, None, "Object storage connector not found"

         if not connector.is_active:
             return None, None, "Object storage connector is not active"

-        credentials_str = self._settings_service._decrypt_credentials(
-            connector.encrypted_credentials
-        )
+        credentials_str = self._settings_service._decrypt_credentials(connector.encrypted_credentials)
         try:
             credentials = json.loads(credentials_str)
         except json.JSONDecodeError:
@@ -200,9 +195,7 @@ def _create_thumbnail(self, image: Image.Image, content_type: str) -> bytes:

         return buffer.getvalue()

-    def _generate_s3_keys(
-        self, org_id: str, project_id: str, content_type: str
-    ) -> tuple[str, str]:
+    def _generate_s3_keys(self, org_id: str, project_id: str, content_type: str) -> tuple[str, str]:
         """Generate S3 keys for image and thumbnail.

         Args:
@@ -248,7 +241,9 @@ def _generate_s3_keys_for_project_chat(
         thumb_ext = "gif" if content_type == "image/gif" else "jpg"

         if project_id:
-            base_path = f"orgs/{org_id}/projects/{project_id}/discussions/{project_chat_id}/images/{now.year}/{now.month:02d}"
+            base_path = (
+                f"orgs/{org_id}/projects/{project_id}/discussions/{project_chat_id}/images/{now.year}/{now.month:02d}"
+            )
         else:
             base_path = f"orgs/{org_id}/discussions/{project_chat_id}/images/{now.year}/{now.month:02d}"
@@ -291,9 +286,7 @@ async def upload_image(
             thumbnail_content = self._create_thumbnail(image, content_type)

             # Generate S3 keys
-            image_key, thumbnail_key = self._generate_s3_keys(
-                org_id, project_id, content_type
-            )
+            image_key, thumbnail_key = self._generate_s3_keys(org_id, project_id, content_type)

             # Get S3 client
             s3_client, bucket = await self._get_s3_client()
@@ -307,9 +300,7 @@ async def upload_image(
             )

             # Upload thumbnail
-            thumb_content_type = (
-                "image/gif" if content_type == "image/gif" else "image/jpeg"
-            )
+            thumb_content_type = "image/gif" if content_type == "image/gif" else "image/jpeg"
             s3_client.put_object(
                 Bucket=bucket,
                 Key=thumbnail_key,
@@ -343,9 +334,7 @@ async def upload_image(
         except Exception as e:
             logger.exception(f"Failed to upload image: {e}")
-            return ImageUploadResult(
-                success=False, message=f"Failed to upload image: {str(e)}"
-            )
+            return ImageUploadResult(success=False, message=f"Failed to upload image: {str(e)}")

     async def upload_image_for_project_chat(
         self,
@@ -401,9 +390,7 @@ async def upload_image_for_project_chat(
             )

             # Upload thumbnail
-            thumb_content_type = (
-                "image/gif" if content_type == "image/gif" else "image/jpeg"
-            )
+            thumb_content_type = "image/gif" if content_type == "image/gif" else "image/jpeg"
             s3_client.put_object(
                 Bucket=bucket,
                 Key=thumbnail_key,
@@ -437,9 +424,7 @@ async def upload_image_for_project_chat(
         except Exception as e:
             logger.exception(f"Failed to upload discussion image: {e}")
-            return ImageUploadResult(
-                success=False, message=f"Failed to upload image: {str(e)}"
-            )
+            return ImageUploadResult(success=False, message=f"Failed to upload image: {str(e)}")

     async def get_signed_url(self, s3_key: str, expiry: int | None = None) -> str:
         """Generate a signed URL for an S3 object.
@@ -505,10 +490,9 @@ def get_image_bytes(s3_key: str) -> bytes:
         Raises:
             ValueError: If S3 is not configured or fetch fails
         """
-        from sqlalchemy.orm import Session
         from app.database import SessionLocal
-        from app.models.platform_settings import PlatformSettings
         from app.models.platform_connector import PlatformConnector
+        from app.models.platform_settings import PlatformSettings

         # Get S3 config from platform settings (sync)
         db = SessionLocal()
@@ -518,19 +502,16 @@ def get_image_bytes(s3_key: str) -> bytes:
                 raise ValueError("No object storage connector configured")

             connector = (
-                db.query(PlatformConnector)
-                .filter(PlatformConnector.id == settings.object_storage_connector_id)
-                .first()
+                db.query(PlatformConnector).filter(PlatformConnector.id == settings.object_storage_connector_id).first()
             )
             if not connector or not connector.is_active:
                 raise ValueError("Object storage connector not found or inactive")

             # Decrypt credentials
             from app.services.platform_settings_service import PlatformSettingsService
+
             platform_service = PlatformSettingsService(db)
-            credentials_str = platform_service._decrypt_credentials(
-                connector.encrypted_credentials
-            )
+            credentials_str = platform_service._decrypt_credentials(connector.encrypted_credentials)
             credentials = json.loads(credentials_str)
             config = connector.config_json or {}
@@ -618,10 +599,7 @@ def generate_signed_image_url(
         sig = hmac.new(key, message.encode(), hashlib.sha256).hexdigest()

         base_url = settings.base_url.rstrip("/")
-        return (
-            f"{base_url}/api/v1/images/{image_id}"
-            f"?max_width={max_width}&max_height={max_height}&exp={exp}&sig={sig}"
-        )
+        return f"{base_url}/api/v1/images/{image_id}?max_width={max_width}&max_height={max_height}&exp={exp}&sig={sig}"

     @staticmethod
     def verify_image_signature(
diff --git a/backend/app/services/implementation_service.py b/backend/app/services/implementation_service.py
index c722b8c..b0a2169 100644
--- a/backend/app/services/implementation_service.py
+++ b/backend/app/services/implementation_service.py
@@ -1,13 +1,15 @@
 """Service for managing implementations."""
+
 import logging
-from typing import Optional, List
-from uuid import UUID
 from datetime import datetime, timezone
-from sqlalchemy.orm import Session
+from typing import List, Optional
+from uuid import UUID
+
 from sqlalchemy import asc
+from sqlalchemy.orm import Session

-from app.models.implementation import Implementation
 from app.models.feature import Feature, FeatureCompletionStatus
+from app.models.implementation import Implementation
 from app.models.module import Module
 from app.models.project import Project
 from app.services.kafka_producer import get_sync_kafka_producer
@@ -38,11 +40,7 @@ def sync_feature_completion_from_implementations(
         if not feature:
             return

-        implementations = (
-            db.query(Implementation)
-            .filter(Implementation.feature_id == feature_id)
-            .all()
-        )
+        implementations = db.query(Implementation).filter(Implementation.feature_id == feature_id).all()

         # If no implementations, status is pending
         if not implementations:
@@ -94,11 +92,7 @@ def create_implementation(
         """
         # Calculate order_index if not provided
         if order_index is None:
-            max_order = (
-                db.query(Implementation)
-                .filter(Implementation.feature_id == feature_id)
-                .count()
-            )
+            max_order = db.query(Implementation).filter(Implementation.feature_id == feature_id).count()
             order_index = max_order

         # If this is primary, unset any existing primary
@@ -137,11 +131,7 @@ def get_implementation(
         Returns:
             The Implementation if found, None otherwise
         """
-        return (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        return db.query(Implementation).filter(Implementation.id == implementation_id).first()

     @staticmethod
     def list_implementations(
@@ -213,11 +203,7 @@ def update_implementation(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None
@@ -256,11 +242,7 @@ def mark_complete(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None
@@ -293,11 +275,7 @@ def mark_incomplete(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None
@@ -306,9 +284,7 @@ def mark_incomplete(
         implementation.completed_by_id = None

         # Sync feature completion status (will become pending since not all are complete)
-        ImplementationService.sync_feature_completion_from_implementations(
-            db, implementation.feature_id
-        )
+        ImplementationService.sync_feature_completion_from_implementations(db, implementation.feature_id)

         db.commit()
         db.refresh(implementation)
@@ -339,11 +315,7 @@ def clear_status_and_notes(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None
@@ -357,9 +329,7 @@ def clear_status_and_notes(
         implementation.notes_updated_at = None

         # Sync feature completion status (derived from all implementations)
-        ImplementationService.sync_feature_completion_from_implementations(
-            db, implementation.feature_id
-        )
+        ImplementationService.sync_feature_completion_from_implementations(db, implementation.feature_id)

         db.commit()
         db.refresh(implementation)
@@ -390,11 +360,7 @@ def clear_status_and_notes_bulk(
         Returns:
             List of updated Implementation objects
         """
-        implementations = (
-            db.query(Implementation)
-            .filter(Implementation.feature_id == feature_id)
-            .all()
-        )
+        implementations = db.query(Implementation).filter(Implementation.feature_id == feature_id).all()

         # Clear implementation-level status and notes
         for impl in implementations:
@@ -407,9 +373,7 @@ def clear_status_and_notes_bulk(
             impl.notes_updated_at = None

         # Sync feature completion status (derived from all implementations)
-        ImplementationService.sync_feature_completion_from_implementations(
-            db, feature_id
-        )
+        ImplementationService.sync_feature_completion_from_implementations(db, feature_id)

         db.commit()
         for impl in implementations:
@@ -455,9 +419,7 @@ def clear_status_and_notes_by_module(
         # Sync feature completion for each affected feature
         for feature_id in feature_ids:
-            ImplementationService.sync_feature_completion_from_implementations(
-                db, feature_id
-            )
+            ImplementationService.sync_feature_completion_from_implementations(db, feature_id)

         db.commit()
         for impl in implementations:
@@ -504,9 +466,7 @@ def clear_status_and_notes_by_project(
         # Sync feature completion for each affected feature
         for feature_id in feature_ids:
-            ImplementationService.sync_feature_completion_from_implementations(
-                db, feature_id
-            )
+            ImplementationService.sync_feature_completion_from_implementations(db, feature_id)

         db.commit()
         for impl in implementations:
@@ -527,11 +487,7 @@ def set_primary(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None
@@ -566,21 +522,13 @@ def delete_implementation(
         Returns:
             True if deleted, False if not found or is the only implementation (and not forced)
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return False

         # Check if it's the only implementation (unless force is True)
         if not force:
-            count = (
-                db.query(Implementation)
-                .filter(Implementation.feature_id == implementation.feature_id)
-                .count()
-            )
+            count = db.query(Implementation).filter(Implementation.feature_id == implementation.feature_id).count()
             if count <= 1:
                 return False
@@ -619,18 +567,12 @@ def append_note(
         Returns:
             The updated Implementation if found, None otherwise
         """
-        implementation = (
-            db.query(Implementation)
-            .filter(Implementation.id == implementation_id)
-            .first()
-        )
+        implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first()
         if not implementation:
             return None

         if implementation.implementation_notes:
-            implementation.implementation_notes = (
-                implementation.implementation_notes + "\n\n" + note
-            )
+            implementation.implementation_notes = implementation.implementation_notes + "\n\n" + note
         else:
             implementation.implementation_notes = note
@@ -657,34 +599,23 @@ def broadcast_implementation_created(
         """
         try:
             # Get feature to find module
-            feature = (
-                db.query(Feature)
-                .filter(Feature.id == implementation.feature_id)
-                .first()
-            )
+            feature = db.query(Feature).filter(Feature.id == implementation.feature_id).first()
             if not feature:
                 logger.warning(
-                    f"Cannot broadcast implementation_created: "
-                    f"feature {implementation.feature_id} not found"
+                    f"Cannot broadcast implementation_created: feature {implementation.feature_id} not found"
                 )
                 return

             # Get module to find project
             module = db.query(Module).filter(Module.id == feature.module_id).first()
             if not module:
-                logger.warning(
-                    f"Cannot broadcast implementation_created: "
-                    f"module {feature.module_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementation_created: module {feature.module_id} not found")
                 return

             # Get project to find org_id
             project = db.query(Project).filter(Project.id == module.project_id).first()
             if not project:
-                logger.warning(
-                    f"Cannot broadcast implementation_created: "
-                    f"project {module.project_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementation_created: project {module.project_id} not found")
                 return

             # Create the message payload
@@ -745,28 +676,19 @@ def broadcast_implementations_deleted(
             # Get feature to find module
             feature = db.query(Feature).filter(Feature.id == feature_id).first()
             if not feature:
-                logger.warning(
-                    f"Cannot broadcast implementations_deleted: "
-                    f"feature {feature_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementations_deleted: feature {feature_id} not found")
                 return

             # Get module to find project
             module = db.query(Module).filter(Module.id == feature.module_id).first()
             if not module:
-                logger.warning(
-                    f"Cannot broadcast implementations_deleted: "
-                    f"module {feature.module_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementations_deleted: module {feature.module_id} not found")
                 return

             # Get project to find org_id
             project = db.query(Project).filter(Project.id == module.project_id).first()
             if not project:
-                logger.warning(
-                    f"Cannot broadcast implementations_deleted: "
-                    f"project {module.project_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementations_deleted: project {module.project_id} not found")
                 return

             # Create the message payload
@@ -809,34 +731,23 @@ def broadcast_implementation_updated(
         """
         try:
             # Get feature to find module
-            feature = (
-                db.query(Feature)
-                .filter(Feature.id == implementation.feature_id)
-                .first()
-            )
+            feature = db.query(Feature).filter(Feature.id == implementation.feature_id).first()
             if not feature:
                 logger.warning(
-                    f"Cannot broadcast implementation_updated: "
-                    f"feature {implementation.feature_id} not found"
+                    f"Cannot broadcast implementation_updated: feature {implementation.feature_id} not found"
                 )
                 return

             # Get module to find project
             module = db.query(Module).filter(Module.id == feature.module_id).first()
             if not module:
-                logger.warning(
-                    f"Cannot broadcast implementation_updated: "
-                    f"module {feature.module_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementation_updated: module {feature.module_id} not found")
                 return

             # Get project to find org_id
             project = db.query(Project).filter(Project.id == module.project_id).first()
             if not project:
-                logger.warning(
-                    f"Cannot broadcast implementation_updated: "
-                    f"project {module.project_id} not found"
-                )
+                logger.warning(f"Cannot broadcast implementation_updated: project {module.project_id} not found")
                 return

             # Create the message payload
@@ -859,9 +770,7 @@ def broadcast_implementation_updated(
                     "order_index": implementation.order_index,
                     "created_at": implementation.created_at.isoformat(),
                     "notes_updated_at": (
-                        implementation.notes_updated_at.isoformat()
-                        if implementation.notes_updated_at
-                        else None
+                        implementation.notes_updated_at.isoformat() if implementation.notes_updated_at else None
                     ),
                     "notes_updated_by_agent": implementation.notes_updated_by_agent,
                 },
@@ -884,9 +793,8 @@ def broadcast_implementation_updated(
             # The features list only subscribes to feature_updated events, and
             # _broadcast_feature_update now aggregates has_* across implementations.
from app.services.feature_service import FeatureService - FeatureService._broadcast_feature_update( - db, feature, f"implementation_{update_type}" - ) + + FeatureService._broadcast_feature_update(db, feature, f"implementation_{update_type}") except Exception as e: logger.error(f"Failed to broadcast implementation_updated: {e}", exc_info=True) diff --git a/backend/app/services/inbox_badge_service.py b/backend/app/services/inbox_badge_service.py index 986159d..1485e17 100644 --- a/backend/app/services/inbox_badge_service.py +++ b/backend/app/services/inbox_badge_service.py @@ -1,4 +1,5 @@ """Service layer for inbox badge count aggregation.""" + import logging from uuid import UUID @@ -38,9 +39,7 @@ def get_project_badge_counts( - total_unread: Sum of all unread counts """ # Get unread mentions count for this project - unread_mentions = InboxMentionService.get_unread_mentions_count( - db, user_id, project_id - ) + unread_mentions = InboxMentionService.get_unread_mentions_count(db, user_id, project_id) # Get all followed conversations for this user followed_statuses = InboxStatusService.get_followed_conversations(db, user_id) @@ -59,11 +58,13 @@ def get_project_badge_counts( ) if count is not None and count > 0: - conversations.append({ - "conversation_id": status.conversation_id, - "conversation_type": status.conversation_type.value, - "unread_count": count, - }) + conversations.append( + { + "conversation_id": status.conversation_id, + "conversation_type": status.conversation_type.value, + "unread_count": count, + } + ) total_unread += count return { @@ -97,17 +98,19 @@ def _get_conversation_unread_count( # Verify project chat belongs to this project try: project_chat_id = UUID(conversation_id) - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id, - ProjectChat.project_id == project_id, - ).first() + project_chat = ( + db.query(ProjectChat) + .filter( + ProjectChat.id == project_chat_id, + ProjectChat.project_id == project_id, + ) + .first() + ) if not project_chat: return None - return InboxStatusService.get_unread_count_for_project_chat( - db, user_id, project_chat_id - ) + return InboxStatusService.get_unread_count_for_project_chat(db, user_id, project_chat_id) except ValueError: logger.warning(f"Invalid project_chat_id: {conversation_id}") return None @@ -115,17 +118,19 @@ def _get_conversation_unread_count( else: # FEATURE or PHASE - both use Thread # Verify thread belongs to this project - thread = db.query(Thread).filter( - Thread.id == conversation_id, - Thread.project_id == str(project_id), - ).first() + thread = ( + db.query(Thread) + .filter( + Thread.id == conversation_id, + Thread.project_id == str(project_id), + ) + .first() + ) if not thread: return None - return InboxStatusService.get_unread_count_for_thread( - db, user_id, conversation_id, conversation_type - ) + return InboxStatusService.get_unread_count_for_thread(db, user_id, conversation_id, conversation_type) @staticmethod def get_org_badge_counts( diff --git a/backend/app/services/inbox_broadcast_service.py b/backend/app/services/inbox_broadcast_service.py index 89eb476..b7ee8a9 100644 --- a/backend/app/services/inbox_broadcast_service.py +++ b/backend/app/services/inbox_broadcast_service.py @@ -1,4 +1,5 @@ """Service layer for broadcasting inbox events via WebSocket.""" + import logging from datetime import datetime, timezone from typing import List, Optional @@ -49,18 +50,26 @@ def get_recipients_for_thread( List of user IDs who should receive the broadcast """ # Get thread-level 
followers - thread_follows = db.query(InboxFollow.user_id).filter( - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, - ).all() + thread_follows = ( + db.query(InboxFollow.user_id) + .filter( + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) + .all() + ) thread_follower_ids = {f.user_id for f in thread_follows} # Get project-level followers - project_follows = db.query(InboxFollow.user_id).filter( - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, - ).all() + project_follows = ( + db.query(InboxFollow.user_id) + .filter( + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) + .all() + ) project_follower_ids = {f.user_id for f in project_follows} # Combine (union) both sets @@ -95,19 +104,27 @@ def get_recipients_for_project_chat( List of user IDs who should receive the broadcast """ # Get thread-level followers for this chat - thread_follows = db.query(InboxFollow.user_id).filter( - InboxFollow.thread_id == str(project_chat_id), - InboxFollow.thread_type == InboxThreadType.PROJECT_CHAT, - InboxFollow.follow_type == InboxFollowType.THREAD, - ).all() + thread_follows = ( + db.query(InboxFollow.user_id) + .filter( + InboxFollow.thread_id == str(project_chat_id), + InboxFollow.thread_type == InboxThreadType.PROJECT_CHAT, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) + .all() + ) all_recipients = {f.user_id for f in thread_follows} # For project-scoped chats, also include project followers if project_id: - project_follows = db.query(InboxFollow.user_id).filter( - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, - ).all() + project_follows = ( + db.query(InboxFollow.user_id) + .filter( + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) + .all() + ) all_recipients.update(f.user_id for f in project_follows) # Exclude the specified user @@ -176,9 +193,7 @@ def broadcast_new_message( ) else: if not project_id: - logger.warning( - f"Cannot broadcast for {conversation_type} without project_id" - ) + logger.warning(f"Cannot broadcast for {conversation_type} without project_id") return False recipients = InboxBroadcastService.get_recipients_for_thread( db=db, @@ -189,9 +204,7 @@ def broadcast_new_message( ) if not recipients: - logger.debug( - f"No recipients for inbox broadcast: {conversation_type}:{conversation_id}" - ) + logger.debug(f"No recipients for inbox broadcast: {conversation_type}:{conversation_id}") return True # Not an error, just no one to notify # Build deep link URL for direct navigation @@ -214,9 +227,7 @@ def broadcast_new_message( "conversation_type": conversation_type, "project_id": str(project_id) if project_id else None, "payload": { - "message_preview": InboxBroadcastService._truncate_preview( - message_preview - ), + "message_preview": InboxBroadcastService._truncate_preview(message_preview), "sequence_number": message_sequence, "author_id": str(author_id), "author_name": author_name, @@ -233,13 +244,10 @@ def broadcast_new_message( if success: logger.info( - f"Broadcasted inbox_new_message to {len(recipients)} recipients: " - f"{conversation_type}:{conversation_id}" + f"Broadcasted inbox_new_message to {len(recipients)} recipients: {conversation_type}:{conversation_id}" ) else: - logger.error( - f"Failed to broadcast 
inbox_new_message: {conversation_type}:{conversation_id}" - ) + logger.error(f"Failed to broadcast inbox_new_message: {conversation_type}:{conversation_id}") return success @@ -303,9 +311,7 @@ def broadcast_mention_added( "payload": { "mentioner_id": str(mentioner_id), "mentioner_name": mentioner_name, - "message_preview": InboxBroadcastService._truncate_preview( - message_preview - ), + "message_preview": InboxBroadcastService._truncate_preview(message_preview), "deep_link_url": deep_link_url, }, "timestamp": datetime.now(timezone.utc).isoformat(), @@ -319,14 +325,10 @@ def broadcast_mention_added( if success: logger.info( - f"Broadcasted inbox_mention_added to user {mentioned_user_id}: " - f"{conversation_type}:{conversation_id}" + f"Broadcasted inbox_mention_added to user {mentioned_user_id}: {conversation_type}:{conversation_id}" ) else: - logger.error( - f"Failed to broadcast inbox_mention_added: " - f"{conversation_type}:{conversation_id}" - ) + logger.error(f"Failed to broadcast inbox_mention_added: {conversation_type}:{conversation_id}") return success @@ -378,8 +380,7 @@ def broadcast_read_status_changed( if success: logger.debug( - f"Broadcasted inbox_read_status_changed for user {user_id}: " - f"{conversation_type}:{conversation_id}" + f"Broadcasted inbox_read_status_changed for user {user_id}: {conversation_type}:{conversation_id}" ) return success @@ -431,8 +432,7 @@ def broadcast_badge_update( if success: logger.debug( - f"Broadcasted inbox_badge_updated for user {user_id}: " - f"unread={unread_count}, mentions={unread_mentions}" + f"Broadcasted inbox_badge_updated for user {user_id}: unread={unread_count}, mentions={unread_mentions}" ) return success diff --git a/backend/app/services/inbox_conversation_service.py b/backend/app/services/inbox_conversation_service.py index 747ff42..cda0d3f 100644 --- a/backend/app/services/inbox_conversation_service.py +++ b/backend/app/services/inbox_conversation_service.py @@ -1,4 +1,5 @@ """Service layer for inbox conversation aggregation.""" + import logging import time from typing import Optional @@ -7,21 +8,20 @@ from sqlalchemy import and_, desc from sqlalchemy.orm import Session, joinedload +from app.models.brainstorming_phase import BrainstormingPhase +from app.models.feature import Feature from app.models.inbox_mention import InboxConversationType, InboxMention from app.models.project import Project from app.models.project_chat import ProjectChat, ProjectChatMessage -from app.models.thread import Thread, ContextType +from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem -from app.models.feature import Feature -from app.models.module import Module -from app.models.brainstorming_phase import BrainstormingPhase from app.models.user_conversation_status import UserConversationStatus from app.schemas.inbox_conversation import ( - UnifiedConversation, + ConversationSortField, InboxConversationsRequest, InboxConversationsResponse, - ConversationSortField, SortOrder, + UnifiedConversation, ) from app.services.inbox_status_service import InboxStatusService from app.services.project_share_service import ProjectShareService @@ -98,18 +98,14 @@ def get_user_inbox_conversations( conversations.extend(phases) # Enrich with unread counts - conversations = InboxConversationService._enrich_with_unread_counts( - db, user_id, conversations - ) + conversations = InboxConversationService._enrich_with_unread_counts(db, user_id, conversations) # Apply unread_only filter if request.unread_only: conversations = [c for 
c in conversations if c.unread_count > 0] # Sort conversations - conversations = InboxConversationService._sort_conversations( - conversations, request.sort_by, request.sort_order - ) + conversations = InboxConversationService._sort_conversations(conversations, request.sort_by, request.sort_order) # Calculate pagination total = len(conversations) @@ -187,9 +183,7 @@ def get_org_inbox_conversations( ) # Build maps of project ID to project name and short_id for display - projects = db.query(Project).filter( - Project.id.in_(accessible_project_ids) - ).all() + projects = db.query(Project).filter(Project.id.in_(accessible_project_ids)).all() project_info_map = {p.id: {"name": p.name, "short_id": p.short_id} for p in projects} # Collect conversations from all accessible projects @@ -219,9 +213,7 @@ def get_org_inbox_conversations( all_conversations.extend(phases) # Enrich with unread counts - all_conversations = InboxConversationService._enrich_with_unread_counts( - db, user_id, all_conversations - ) + all_conversations = InboxConversationService._enrich_with_unread_counts(db, user_id, all_conversations) # Apply unread_only filter if request.unread_only: @@ -264,13 +256,17 @@ def _get_followed_conversation_ids( conversation_type: InboxConversationType, ) -> list[str]: """Get IDs of conversations the user is following.""" - statuses = db.query(UserConversationStatus).filter( - and_( - UserConversationStatus.user_id == user_id, - UserConversationStatus.conversation_type == conversation_type, - UserConversationStatus.is_followed == True, # noqa: E712 + statuses = ( + db.query(UserConversationStatus) + .filter( + and_( + UserConversationStatus.user_id == user_id, + UserConversationStatus.conversation_type == conversation_type, + UserConversationStatus.is_followed == True, # noqa: E712 + ) ) - ).all() + .all() + ) return [s.conversation_id for s in statuses] @staticmethod @@ -281,14 +277,19 @@ def _get_mentioned_conversation_ids( project_id: UUID, ) -> list[str]: """Get IDs of conversations where the user has unread mentions.""" - mentions = db.query(InboxMention.conversation_id).filter( - and_( - InboxMention.user_id == user_id, - InboxMention.conversation_type == conversation_type, - InboxMention.project_id == project_id, - InboxMention.is_read == False, # noqa: E712 + mentions = ( + db.query(InboxMention.conversation_id) + .filter( + and_( + InboxMention.user_id == user_id, + InboxMention.conversation_type == conversation_type, + InboxMention.project_id == project_id, + InboxMention.is_read == False, # noqa: E712 + ) ) - ).distinct().all() + .distinct() + .all() + ) return [m[0] for m in mentions] @staticmethod @@ -329,21 +330,29 @@ def _get_project_chat_conversations( return [] # Query ProjectChats that are followed or have unread mentions in this project - chats = db.query(ProjectChat).filter( - and_( - ProjectChat.project_id == project_id, - ProjectChat.id.in_(all_uuids), + chats = ( + db.query(ProjectChat) + .filter( + and_( + ProjectChat.project_id == project_id, + ProjectChat.id.in_(all_uuids), + ) ) - ).options( - joinedload(ProjectChat.creator), - ).all() + .options( + joinedload(ProjectChat.creator), + ) + .all() + ) conversations = [] for chat in chats: # Get last message for preview - last_message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == chat.id - ).order_by(desc(ProjectChatMessage.created_at)).first() + last_message = ( + db.query(ProjectChatMessage) + .filter(ProjectChatMessage.project_chat_id == chat.id) + 
.order_by(desc(ProjectChatMessage.created_at)) + .first() + ) last_preview = None last_message_at = None @@ -361,20 +370,22 @@ def _get_project_chat_conversations( chat_slug = slugify(title) url_id = f"{chat_slug}-{chat.short_id}" if chat_slug else chat.short_id - conversations.append(UnifiedConversation( - id=str(chat.id), - conversation_type=InboxConversationType.PROJECT_CHAT, - title=title, - last_activity_at=chat.updated_at, - unread_count=0, # Will be enriched later - project_id=project_id, - project_name=project_name, - project_short_id=project_short_id, - last_message_preview=last_preview, - last_message_at=last_message_at, - last_message_author=last_author, - url_identifier=url_id, - )) + conversations.append( + UnifiedConversation( + id=str(chat.id), + conversation_type=InboxConversationType.PROJECT_CHAT, + title=title, + last_activity_at=chat.updated_at, + unread_count=0, # Will be enriched later + project_id=project_id, + project_name=project_name, + project_short_id=project_short_id, + last_message_preview=last_preview, + last_message_at=last_message_at, + last_message_author=last_author, + url_identifier=url_id, + ) + ) return conversations @@ -417,13 +428,17 @@ def _get_feature_conversations( return [] # Query Threads that are followed or have unread mentions, in this project, and are feature threads - threads = db.query(Thread).filter( - and_( - Thread.project_id == str(project_id), - Thread.id.in_([str(u) for u in all_uuids]), - Thread.context_type == ContextType.BRAINSTORM_FEATURE, + threads = ( + db.query(Thread) + .filter( + and_( + Thread.project_id == str(project_id), + Thread.id.in_([str(u) for u in all_uuids]), + Thread.context_type == ContextType.BRAINSTORM_FEATURE, + ) ) - ).all() + .all() + ) conversations = [] for thread in threads: @@ -437,9 +452,9 @@ def _get_feature_conversations( if thread.context_id: try: context_uuid = UUID(thread.context_id) - feature = db.query(Feature).filter( - Feature.id == context_uuid - ).options(joinedload(Feature.module)).first() + feature = ( + db.query(Feature).filter(Feature.id == context_uuid).options(joinedload(Feature.module)).first() + ) if feature: feature_key = feature.feature_key feature_short_id = feature.short_id @@ -449,9 +464,12 @@ def _get_feature_conversations( pass # Invalid UUID, skip feature lookup # Get last message for preview - last_item = db.query(ThreadItem).filter( - ThreadItem.thread_id == thread.id - ).order_by(desc(ThreadItem.created_at)).first() + last_item = ( + db.query(ThreadItem) + .filter(ThreadItem.thread_id == thread.id) + .order_by(desc(ThreadItem.created_at)) + .first() + ) last_preview = None last_message_at = None @@ -474,22 +492,24 @@ def _get_feature_conversations( else: url_id = feature_short_id - conversations.append(UnifiedConversation( - id=str(thread.id), - conversation_type=InboxConversationType.FEATURE, - title=title, - last_activity_at=thread.updated_at, - unread_count=0, # Will be enriched later - project_id=project_id, - project_name=project_name, - project_short_id=project_short_id, - last_message_preview=last_preview, - last_message_at=last_message_at, - last_message_author=last_author, - feature_key=feature_key, - module_title=module_title, - url_identifier=url_id, - )) + conversations.append( + UnifiedConversation( + id=str(thread.id), + conversation_type=InboxConversationType.FEATURE, + title=title, + last_activity_at=thread.updated_at, + unread_count=0, # Will be enriched later + project_id=project_id, + project_name=project_name, + project_short_id=project_short_id, + 
                    last_message_preview=last_preview,
+                    last_message_at=last_message_at,
+                    last_message_author=last_author,
+                    feature_key=feature_key,
+                    module_title=module_title,
+                    url_identifier=url_id,
+                )
+            )

         return conversations

@@ -503,9 +523,7 @@ def _get_phase_conversations(
     ) -> list[UnifiedConversation]:
         """Get Phase thread conversations the user follows or has unread mentions in."""
         # Get followed conversation IDs (thread IDs)
-        followed_ids = InboxConversationService._get_followed_conversation_ids(
-            db, user_id, InboxConversationType.PHASE
-        )
+        followed_ids = InboxConversationService._get_followed_conversation_ids(db, user_id, InboxConversationType.PHASE)

         # Get conversation IDs with unread mentions
         mentioned_ids = InboxConversationService._get_mentioned_conversation_ids(
@@ -531,23 +549,26 @@ def _get_phase_conversations(
         if not all_uuids:
             return []

-        from app.models.feature import Feature
-        from app.models.module import Module
-
         # Query Threads that are followed or have unread mentions, in this project, and are phase threads
         # Note: BRAINSTORM_FEATURE is handled by _get_feature_conversations, not here
-        threads = db.query(Thread).filter(
-            and_(
-                Thread.project_id == str(project_id),
-                Thread.id.in_([str(u) for u in all_uuids]),
-                Thread.context_type.in_([
-                    ContextType.SPEC,
-                    ContextType.GENERAL,
-                    ContextType.SPEC_DRAFT,
-                    ContextType.PROMPT_PLAN_DRAFT,
-                ]),
+        threads = (
+            db.query(Thread)
+            .filter(
+                and_(
+                    Thread.project_id == str(project_id),
+                    Thread.id.in_([str(u) for u in all_uuids]),
+                    Thread.context_type.in_(
+                        [
+                            ContextType.SPEC,
+                            ContextType.GENERAL,
+                            ContextType.SPEC_DRAFT,
+                            ContextType.PROMPT_PLAN_DRAFT,
+                        ]
+                    ),
+                )
             )
-        ).all()
+            .all()
+        )

         conversations = []
         for thread in threads:
@@ -558,9 +579,7 @@ def _get_phase_conversations(
             if thread.context_id:
                 try:
                     context_uuid = UUID(thread.context_id)
-                    phase = db.query(BrainstormingPhase).filter(
-                        BrainstormingPhase.id == context_uuid
-                    ).first()
+                    phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == context_uuid).first()
                     if phase:
                         phase_title = phase.title
                         phase_short_id = phase.short_id
@@ -568,9 +587,12 @@ def _get_phase_conversations(
                    pass  # Invalid UUID, skip phase lookup

            # Get last message for preview
-            last_item = db.query(ThreadItem).filter(
-                ThreadItem.thread_id == thread.id
-            ).order_by(desc(ThreadItem.created_at)).first()
+            last_item = (
+                db.query(ThreadItem)
+                .filter(ThreadItem.thread_id == thread.id)
+                .order_by(desc(ThreadItem.created_at))
+                .first()
+            )

            last_preview = None
            last_message_at = None
@@ -593,21 +615,23 @@ def _get_phase_conversations(
            else:
                url_id = phase_short_id

-            conversations.append(UnifiedConversation(
-                id=str(thread.id),
-                conversation_type=InboxConversationType.PHASE,
-                title=title,
-                last_activity_at=thread.updated_at,
-                unread_count=0,  # Will be enriched later
-                project_id=project_id,
-                project_name=project_name,
-                project_short_id=project_short_id,
-                last_message_preview=last_preview,
-                last_message_at=last_message_at,
-                last_message_author=last_author,
-                phase_title=phase_title,
-                url_identifier=url_id,
-            ))
+            conversations.append(
+                UnifiedConversation(
+                    id=str(thread.id),
+                    conversation_type=InboxConversationType.PHASE,
+                    title=title,
+                    last_activity_at=thread.updated_at,
+                    unread_count=0,  # Will be enriched later
+                    project_id=project_id,
+                    project_name=project_name,
+                    project_short_id=project_short_id,
+                    last_message_preview=last_preview,
+                    last_message_at=last_message_at,
+                    last_message_author=last_author,
+                    phase_title=phase_title,
+                    url_identifier=url_id,
+                )
+            )

        return conversations

@@ -620,20 +644,23 @@ def _enrich_with_unread_counts(
        """Enrich conversations with unread counts, first unread sequence, and mention status."""
        for conv in conversations:
            # Check if user is actually following this conversation
-            is_following = InboxStatusService.is_following(
-                db, user_id, conv.conversation_type, conv.id
-            )
+            is_following = InboxStatusService.is_following(db, user_id, conv.conversation_type, conv.id)

            # Check for unread mentions in this conversation
            # Get all unread mentions to count them
-            unread_mentions = db.query(InboxMention).filter(
-                and_(
-                    InboxMention.user_id == user_id,
-                    InboxMention.conversation_type == conv.conversation_type,
-                    InboxMention.conversation_id == conv.id,
-                    InboxMention.is_read == False,  # noqa: E712
+            unread_mentions = (
+                db.query(InboxMention)
+                .filter(
+                    and_(
+                        InboxMention.user_id == user_id,
+                        InboxMention.conversation_type == conv.conversation_type,
+                        InboxMention.conversation_id == conv.id,
+                        InboxMention.is_read == False,  # noqa: E712
+                    )
                )
-            ).order_by(InboxMention.message_sequence.asc()).all()
+                .order_by(InboxMention.message_sequence.asc())
+                .all()
+            )

            conv.has_unread_mention = len(unread_mentions) > 0
            first_unread_mention = unread_mentions[0] if unread_mentions else None
@@ -641,9 +668,7 @@ def _enrich_with_unread_counts(
            if is_following:
                # User is following: show all unread messages since their watermark
                if conv.conversation_type == InboxConversationType.PROJECT_CHAT:
-                    conv.unread_count = InboxStatusService.get_unread_count_for_project_chat(
-                        db, user_id, UUID(conv.id)
-                    )
+                    conv.unread_count = InboxStatusService.get_unread_count_for_project_chat(db, user_id, UUID(conv.id))
                    conv.first_unread_sequence = InboxStatusService.get_first_unread_sequence_for_project_chat(
                        db, user_id, UUID(conv.id)
                    )
@@ -657,7 +682,10 @@ def _enrich_with_unread_counts(

                # If there's an unread mention that's earlier than first_unread_sequence, use it
                if first_unread_mention:
-                    if conv.first_unread_sequence is None or first_unread_mention.message_sequence < conv.first_unread_sequence:
+                    if (
+                        conv.first_unread_sequence is None
+                        or first_unread_mention.message_sequence < conv.first_unread_sequence
+                    ):
                        conv.first_unread_sequence = first_unread_mention.message_sequence
            else:
                # User is NOT following but has mentions: only show unread mention count
@@ -700,7 +728,7 @@ def _truncate_preview(text: str) -> str:
        text = " ".join(text.split())
        if len(text) <= PREVIEW_MAX_LENGTH:
            return text
-        return text[:PREVIEW_MAX_LENGTH - 3] + "..."
+        return text[: PREVIEW_MAX_LENGTH - 3] + "..."

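Note for reviewers: the _truncate_preview hunk above is whitespace-only (the slice gains a space after the colon per formatter style), so behaviour is unchanged. As a sanity check, the rule it implements reduces to the following standalone sketch; the PREVIEW_MAX_LENGTH value of 100 is assumed here for illustration, the service defines the real constant elsewhere:

    PREVIEW_MAX_LENGTH = 100  # assumed value for this sketch only

    def truncate_preview(text: str) -> str:
        """Collapse whitespace runs, then cap the preview with a trailing '...'."""
        text = " ".join(text.split())  # newlines and repeated spaces become single spaces
        if len(text) <= PREVIEW_MAX_LENGTH:
            return text
        return text[: PREVIEW_MAX_LENGTH - 3] + "..."

    assert truncate_preview("a  b\nc") == "a b c"
    assert len(truncate_preview("word " * 50)) == PREVIEW_MAX_LENGTH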
@staticmethod def resolve_conversation_for_deep_link( @@ -724,17 +752,11 @@ def resolve_conversation_for_deep_link( Dictionary with project and conversation info, or None if not found/unauthorized """ if conversation_type == InboxConversationType.PROJECT_CHAT: - return InboxConversationService._resolve_project_chat( - db, user_id, conversation_id - ) + return InboxConversationService._resolve_project_chat(db, user_id, conversation_id) elif conversation_type == InboxConversationType.FEATURE: - return InboxConversationService._resolve_feature_thread( - db, user_id, conversation_id - ) + return InboxConversationService._resolve_feature_thread(db, user_id, conversation_id) elif conversation_type == InboxConversationType.PHASE: - return InboxConversationService._resolve_phase_thread( - db, user_id, conversation_id - ) + return InboxConversationService._resolve_phase_thread(db, user_id, conversation_id) return None @staticmethod @@ -756,6 +778,7 @@ def _resolve_project_chat( # Verify user has access to the project project = chat.project from app.services.project_service import ProjectService + membership = ProjectService.get_project_membership(db, project.id, user_id) if not membership: return None @@ -793,6 +816,7 @@ def _resolve_feature_thread( # Verify user has access from app.services.project_service import ProjectService + membership = ProjectService.get_project_membership(db, project.id, user_id) if not membership: return None @@ -856,6 +880,7 @@ def _resolve_phase_thread( # Verify user has access from app.services.project_service import ProjectService + membership = ProjectService.get_project_membership(db, project.id, user_id) if not membership: return None @@ -868,9 +893,7 @@ def _resolve_phase_thread( if thread.context_id: try: context_uuid = UUID(thread.context_id) - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == context_uuid - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == context_uuid).first() if phase: phase_title = phase.title phase_short_id = phase.short_id diff --git a/backend/app/services/inbox_follow_service.py b/backend/app/services/inbox_follow_service.py index ef4631e..83698a6 100644 --- a/backend/app/services/inbox_follow_service.py +++ b/backend/app/services/inbox_follow_service.py @@ -1,4 +1,5 @@ """Service layer for inbox follow operations.""" + import logging from typing import Optional from uuid import UUID @@ -32,13 +33,17 @@ def follow_project( Created or existing InboxFollow """ # Check if already following - existing = db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + existing = ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).first() + .first() + ) if existing: return existing @@ -72,13 +77,17 @@ def unfollow_project( Returns: True if unfollowed, False if not following """ - result = db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + result = ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).delete() + .delete() + ) db.commit() @@ -109,15 +118,19 @@ def follow_thread( Created or existing InboxFollow """ # Check 
if already following - existing = db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, + existing = ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).first() + .first() + ) if existing: return existing @@ -133,9 +146,7 @@ def follow_thread( db.commit() db.refresh(follow) - logger.info( - f"User {user_id} followed thread {thread_id} (type={thread_type.value})" - ) + logger.info(f"User {user_id} followed thread {thread_id} (type={thread_type.value})") return follow @staticmethod @@ -157,14 +168,18 @@ def unfollow_thread( Returns: True if unfollowed, False if not following """ - result = db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, + result = ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).delete() + .delete() + ) db.commit() @@ -190,13 +205,18 @@ def is_following_project( Returns: True if following, False otherwise """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).first() is not None + .first() + is not None + ) @staticmethod def is_following_thread( @@ -217,14 +237,19 @@ def is_following_thread( Returns: True if following, False otherwise """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).first() is not None + .first() + is not None + ) @staticmethod def get_project_follows( @@ -241,12 +266,16 @@ def get_project_follows( Returns: List of InboxFollow records for project follows """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).all() + .all() + ) @staticmethod def get_thread_follows_in_project( @@ -265,13 +294,17 @@ def get_thread_follows_in_project( Returns: List of InboxFollow records for thread follows """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.THREAD, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + 
InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).all() + .all() + ) @staticmethod def get_users_following_project( @@ -288,12 +321,16 @@ def get_users_following_project( Returns: List of user IDs """ - follows = db.query(InboxFollow.user_id).filter( - and_( - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + follows = ( + db.query(InboxFollow.user_id) + .filter( + and_( + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).all() + .all() + ) return [f.user_id for f in follows] @staticmethod @@ -313,13 +350,17 @@ def get_users_following_thread( Returns: List of user IDs """ - follows = db.query(InboxFollow.user_id).filter( - and_( - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, + follows = ( + db.query(InboxFollow.user_id) + .filter( + and_( + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).all() + .all() + ) return [f.user_id for f in follows] @staticmethod @@ -370,13 +411,17 @@ def get_project_follow( Returns: InboxFollow if exists, None otherwise """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.project_id == project_id, - InboxFollow.follow_type == InboxFollowType.PROJECT, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.project_id == project_id, + InboxFollow.follow_type == InboxFollowType.PROJECT, + ) ) - ).first() + .first() + ) @staticmethod def get_thread_follow( @@ -397,11 +442,15 @@ def get_thread_follow( Returns: InboxFollow if exists, None otherwise """ - return db.query(InboxFollow).filter( - and_( - InboxFollow.user_id == user_id, - InboxFollow.thread_id == thread_id, - InboxFollow.thread_type == thread_type, - InboxFollow.follow_type == InboxFollowType.THREAD, + return ( + db.query(InboxFollow) + .filter( + and_( + InboxFollow.user_id == user_id, + InboxFollow.thread_id == thread_id, + InboxFollow.thread_type == thread_type, + InboxFollow.follow_type == InboxFollowType.THREAD, + ) ) - ).first() + .first() + ) diff --git a/backend/app/services/inbox_mention_service.py b/backend/app/services/inbox_mention_service.py index e177100..6d6374b 100644 --- a/backend/app/services/inbox_mention_service.py +++ b/backend/app/services/inbox_mention_service.py @@ -1,4 +1,5 @@ """Service layer for inbox mention operations.""" + import logging from datetime import datetime, timezone from typing import Optional @@ -7,7 +8,7 @@ from sqlalchemy import and_, desc from sqlalchemy.orm import Session -from app.models.inbox_mention import InboxMention, InboxConversationType +from app.models.inbox_mention import InboxConversationType, InboxMention from app.services.mention_utils import extract_user_mentions logger = logging.getLogger(__name__) @@ -44,12 +45,16 @@ def create_mention( Created InboxMention """ # Check if mention already exists - existing = db.query(InboxMention).filter( - and_( - InboxMention.user_id == user_id, - InboxMention.message_id == message_id, + existing = ( + db.query(InboxMention) + .filter( + and_( + InboxMention.user_id == user_id, + InboxMention.message_id == message_id, + ) ) - ).first() + .first() + ) if existing: return existing @@ -105,9 +110,7 @@ def create_mentions_from_message( mentioned_user_ids = extract_user_mentions(body_markdown) # Exclude self-mentions - mentioned_user_ids = [ - 
uid for uid in mentioned_user_ids if str(uid) != str(author_id) - ] + mentioned_user_ids = [uid for uid in mentioned_user_ids if str(uid) != str(author_id)] mentions = [] for user_id in mentioned_user_ids: @@ -143,9 +146,7 @@ def create_mentions_from_message( message_sequence=message_sequence, ) except Exception as e: - logger.warning( - f"Failed to broadcast mention notification for user {user_id}: {e}" - ) + logger.warning(f"Failed to broadcast mention notification for user {user_id}: {e}") except Exception as e: logger.warning(f"Failed to create mention for user {user_id}: {e}") @@ -166,9 +167,7 @@ def mark_mention_read( Returns: Updated InboxMention or None if not found """ - mention = db.query(InboxMention).filter( - InboxMention.id == mention_id - ).first() + mention = db.query(InboxMention).filter(InboxMention.id == mention_id).first() if not mention: return None @@ -201,17 +200,23 @@ def mark_mentions_read_by_user_and_conversation( Number of mentions marked as read """ now = datetime.now(timezone.utc) - result = db.query(InboxMention).filter( - and_( - InboxMention.user_id == user_id, - InboxMention.conversation_type == conversation_type, - InboxMention.conversation_id == conversation_id, - InboxMention.is_read == False, # noqa: E712 + result = ( + db.query(InboxMention) + .filter( + and_( + InboxMention.user_id == user_id, + InboxMention.conversation_type == conversation_type, + InboxMention.conversation_id == conversation_id, + InboxMention.is_read == False, # noqa: E712 + ) + ) + .update( + { + InboxMention.is_read: True, + InboxMention.read_at: now, + } ) - ).update({ - InboxMention.is_read: True, - InboxMention.read_at: now, - }) + ) db.commit() return result @@ -276,9 +281,7 @@ def get_unread_mentions( if project_id: query = query.filter(InboxMention.project_id == project_id) - return query.order_by( - desc(InboxMention.mentioned_at) - ).offset(offset).limit(limit).all() + return query.order_by(desc(InboxMention.mentioned_at)).offset(offset).limit(limit).all() @staticmethod def get_all_mentions( @@ -301,16 +304,12 @@ def get_all_mentions( Returns: List of InboxMention records """ - query = db.query(InboxMention).filter( - InboxMention.user_id == user_id - ) + query = db.query(InboxMention).filter(InboxMention.user_id == user_id) if project_id: query = query.filter(InboxMention.project_id == project_id) - return query.order_by( - desc(InboxMention.mentioned_at) - ).offset(offset).limit(limit).all() + return query.order_by(desc(InboxMention.mentioned_at)).offset(offset).limit(limit).all() @staticmethod def get_mentions_in_conversation( @@ -329,12 +328,17 @@ def get_mentions_in_conversation( Returns: List of InboxMention records """ - return db.query(InboxMention).filter( - and_( - InboxMention.conversation_type == conversation_type, - InboxMention.conversation_id == conversation_id, + return ( + db.query(InboxMention) + .filter( + and_( + InboxMention.conversation_type == conversation_type, + InboxMention.conversation_id == conversation_id, + ) ) - ).order_by(InboxMention.mentioned_at).all() + .order_by(InboxMention.mentioned_at) + .all() + ) @staticmethod def delete_mentions_for_message( @@ -351,8 +355,6 @@ def delete_mentions_for_message( Returns: Number of mentions deleted """ - result = db.query(InboxMention).filter( - InboxMention.message_id == message_id - ).delete() + result = db.query(InboxMention).filter(InboxMention.message_id == message_id).delete() db.commit() return result diff --git a/backend/app/services/inbox_status_service.py 
b/backend/app/services/inbox_status_service.py index a862f48..13dd1f6 100644 --- a/backend/app/services/inbox_status_service.py +++ b/backend/app/services/inbox_status_service.py @@ -1,17 +1,18 @@ """Service layer for inbox status operations (read watermarks).""" + import logging from datetime import datetime, timezone from typing import Optional from uuid import UUID from sqlalchemy import and_, case, func -from sqlalchemy.orm import Session from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.orm import Session -from app.models.user_conversation_status import UserConversationStatus from app.models.inbox_mention import InboxConversationType from app.models.project_chat import ProjectChatMessage from app.models.thread_item import ThreadItem +from app.models.user_conversation_status import UserConversationStatus logger = logging.getLogger(__name__) @@ -46,36 +47,46 @@ def update_read_position( Updated or created UserConversationStatus """ # Use INSERT ... ON CONFLICT DO UPDATE for upsert - stmt = insert(UserConversationStatus).values( - user_id=user_id, - conversation_type=conversation_type, - conversation_id=conversation_id, - last_read_sequence=sequence_number, - is_followed=True, - updated_at=datetime.now(timezone.utc), - ).on_conflict_do_update( - constraint="pk_user_conversation_status", - set_={ - # Use CASE instead of GREATEST() for SQLite compatibility in tests - "last_read_sequence": case( - (UserConversationStatus.last_read_sequence > sequence_number, - UserConversationStatus.last_read_sequence), - else_=sequence_number - ), - "updated_at": datetime.now(timezone.utc), - }, + stmt = ( + insert(UserConversationStatus) + .values( + user_id=user_id, + conversation_type=conversation_type, + conversation_id=conversation_id, + last_read_sequence=sequence_number, + is_followed=True, + updated_at=datetime.now(timezone.utc), + ) + .on_conflict_do_update( + constraint="pk_user_conversation_status", + set_={ + # Use CASE instead of GREATEST() for SQLite compatibility in tests + "last_read_sequence": case( + ( + UserConversationStatus.last_read_sequence > sequence_number, + UserConversationStatus.last_read_sequence, + ), + else_=sequence_number, + ), + "updated_at": datetime.now(timezone.utc), + }, + ) ) db.execute(stmt) db.commit() # Fetch the updated status - status = db.query(UserConversationStatus).filter( - and_( - UserConversationStatus.user_id == user_id, - UserConversationStatus.conversation_type == conversation_type, - UserConversationStatus.conversation_id == conversation_id, + status = ( + db.query(UserConversationStatus) + .filter( + and_( + UserConversationStatus.user_id == user_id, + UserConversationStatus.conversation_type == conversation_type, + UserConversationStatus.conversation_id == conversation_id, + ) ) - ).first() + .first() + ) # Broadcast read status change via WebSocket if org_id and status: @@ -91,9 +102,7 @@ def update_read_position( new_watermark=status.last_read_sequence, ) except Exception as e: - logger.warning( - f"Failed to broadcast read status change for user {user_id}: {e}" - ) + logger.warning(f"Failed to broadcast read status change for user {user_id}: {e}") return status @@ -116,13 +125,17 @@ def get_status( Returns: UserConversationStatus or None if not found """ - return db.query(UserConversationStatus).filter( - and_( - UserConversationStatus.user_id == user_id, - UserConversationStatus.conversation_type == conversation_type, - UserConversationStatus.conversation_id == conversation_id, + return ( + db.query(UserConversationStatus) + 
.filter( + and_( + UserConversationStatus.user_id == user_id, + UserConversationStatus.conversation_type == conversation_type, + UserConversationStatus.conversation_id == conversation_id, + ) ) - ).first() + .first() + ) @staticmethod def get_last_read_sequence( @@ -143,9 +156,7 @@ def get_last_read_sequence( Returns: Last read sequence number (0 if never read) """ - status = InboxStatusService.get_status( - db, user_id, conversation_type, conversation_id - ) + status = InboxStatusService.get_status(db, user_id, conversation_type, conversation_id) return status.last_read_sequence if status else 0 @staticmethod @@ -172,12 +183,16 @@ def get_unread_count_for_project_chat( str(project_chat_id), ) - return db.query(ProjectChatMessage).filter( - and_( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.sequence_number > last_read, + return ( + db.query(ProjectChatMessage) + .filter( + and_( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.sequence_number > last_read, + ) ) - ).count() + .count() + ) @staticmethod def get_first_unread_sequence_for_project_chat( @@ -206,12 +221,17 @@ def get_first_unread_sequence_for_project_chat( ) # Get the last (highest) unread sequence - most recent unread message - last_unread = db.query(ProjectChatMessage.sequence_number).filter( - and_( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.sequence_number > last_read, + last_unread = ( + db.query(ProjectChatMessage.sequence_number) + .filter( + and_( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.sequence_number > last_read, + ) ) - ).order_by(ProjectChatMessage.sequence_number.desc()).first() + .order_by(ProjectChatMessage.sequence_number.desc()) + .first() + ) return last_unread[0] if last_unread else None @@ -241,12 +261,16 @@ def get_unread_count_for_thread( thread_id, ) - return db.query(ThreadItem).filter( - and_( - ThreadItem.thread_id == thread_id, - ThreadItem.sequence_number > last_read, + return ( + db.query(ThreadItem) + .filter( + and_( + ThreadItem.thread_id == thread_id, + ThreadItem.sequence_number > last_read, + ) ) - ).count() + .count() + ) @staticmethod def get_first_unread_sequence_for_thread( @@ -277,12 +301,17 @@ def get_first_unread_sequence_for_thread( ) # Get the last (highest) unread sequence - most recent unread message - last_unread = db.query(ThreadItem.sequence_number).filter( - and_( - ThreadItem.thread_id == thread_id, - ThreadItem.sequence_number > last_read, + last_unread = ( + db.query(ThreadItem.sequence_number) + .filter( + and_( + ThreadItem.thread_id == thread_id, + ThreadItem.sequence_number > last_read, + ) ) - ).order_by(ThreadItem.sequence_number.desc()).first() + .order_by(ThreadItem.sequence_number.desc()) + .first() + ) return last_unread[0] if last_unread else None @@ -308,30 +337,38 @@ def toggle_follow( Updated UserConversationStatus """ # Upsert with follow status - stmt = insert(UserConversationStatus).values( - user_id=user_id, - conversation_type=conversation_type, - conversation_id=conversation_id, - last_read_sequence=0, - is_followed=follow, - updated_at=datetime.now(timezone.utc), - ).on_conflict_do_update( - constraint="pk_user_conversation_status", - set_={ - "is_followed": follow, - "updated_at": datetime.now(timezone.utc), - }, + stmt = ( + insert(UserConversationStatus) + .values( + user_id=user_id, + conversation_type=conversation_type, + conversation_id=conversation_id, + last_read_sequence=0, + is_followed=follow, + 
updated_at=datetime.now(timezone.utc), + ) + .on_conflict_do_update( + constraint="pk_user_conversation_status", + set_={ + "is_followed": follow, + "updated_at": datetime.now(timezone.utc), + }, + ) ) db.execute(stmt) db.commit() - return db.query(UserConversationStatus).filter( - and_( - UserConversationStatus.user_id == user_id, - UserConversationStatus.conversation_type == conversation_type, - UserConversationStatus.conversation_id == conversation_id, + return ( + db.query(UserConversationStatus) + .filter( + and_( + UserConversationStatus.user_id == user_id, + UserConversationStatus.conversation_type == conversation_type, + UserConversationStatus.conversation_id == conversation_id, + ) ) - ).first() + .first() + ) @staticmethod def is_following( @@ -352,9 +389,7 @@ def is_following( Returns: True if following, False otherwise """ - status = InboxStatusService.get_status( - db, user_id, conversation_type, conversation_id - ) + status = InboxStatusService.get_status(db, user_id, conversation_type, conversation_id) return status.is_followed if status else False @staticmethod @@ -382,9 +417,7 @@ def get_followed_conversations( ) if conversation_type: - query = query.filter( - UserConversationStatus.conversation_type == conversation_type - ) + query = query.filter(UserConversationStatus.conversation_type == conversation_type) return query.all() @@ -411,13 +444,15 @@ def mark_all_as_read( """ # Find max sequence based on conversation type if conversation_type == InboxConversationType.PROJECT_CHAT: - max_seq = db.query(func.max(ProjectChatMessage.sequence_number)).filter( - ProjectChatMessage.project_chat_id == conversation_id - ).scalar() + max_seq = ( + db.query(func.max(ProjectChatMessage.sequence_number)) + .filter(ProjectChatMessage.project_chat_id == conversation_id) + .scalar() + ) else: - max_seq = db.query(func.max(ThreadItem.sequence_number)).filter( - ThreadItem.thread_id == conversation_id - ).scalar() + max_seq = ( + db.query(func.max(ThreadItem.sequence_number)).filter(ThreadItem.thread_id == conversation_id).scalar() + ) max_seq = max_seq or 0 diff --git a/backend/app/services/integration_config_share_service.py b/backend/app/services/integration_config_share_service.py index df9f4b6..d95d35c 100644 --- a/backend/app/services/integration_config_share_service.py +++ b/backend/app/services/integration_config_share_service.py @@ -1,4 +1,5 @@ """Integration config share service.""" + from uuid import UUID from sqlalchemy import select @@ -48,9 +49,7 @@ async def create_share( # Check if share already exists existing = await self.get_share(integration_config_id, subject_type, subject_id) if existing: - raise ValueError( - f"Share already exists for {subject_type.value} {subject_id}" - ) + raise ValueError(f"Share already exists for {subject_type.value} {subject_id}") # Validate subject exists if subject_type == IntegrationShareSubjectType.USER: @@ -118,9 +117,7 @@ async def get_share_by_id(self, share_id: UUID) -> IntegrationConfigShare | None result = await self.db.execute(stmt) return result.scalar_one_or_none() - async def list_shares( - self, integration_config_id: UUID - ) -> list[IntegrationConfigShare]: + async def list_shares(self, integration_config_id: UUID) -> list[IntegrationConfigShare]: """List all shares for an integration config. 
Args: @@ -132,9 +129,7 @@ async def list_shares( stmt = ( select(IntegrationConfigShare) .options(selectinload(IntegrationConfigShare.created_by)) - .where( - IntegrationConfigShare.integration_config_id == integration_config_id - ) + .where(IntegrationConfigShare.integration_config_id == integration_config_id) .order_by(IntegrationConfigShare.created_at) ) result = await self.db.execute(stmt) @@ -176,9 +171,7 @@ async def delete_shares_for_config(self, integration_config_id: UUID) -> int: await self.db.commit() return count - async def enrich_share_with_subject( - self, share: IntegrationConfigShare - ) -> dict: + async def enrich_share_with_subject(self, share: IntegrationConfigShare) -> dict: """Enrich a share with user or group details. Args: @@ -215,11 +208,10 @@ async def enrich_share_with_subject( if group: # Get member count from sqlalchemy import func + from app.models.user_group_membership import UserGroupMembership - count_stmt = select(func.count()).where( - UserGroupMembership.group_id == group.id - ) + count_stmt = select(func.count()).where(UserGroupMembership.group_id == group.id) count_result = await self.db.execute(count_stmt) member_count = count_result.scalar() or 0 diff --git a/backend/app/services/integration_service.py b/backend/app/services/integration_service.py index 081fed6..9bdb2d8 100644 --- a/backend/app/services/integration_service.py +++ b/backend/app/services/integration_service.py @@ -1,4 +1,5 @@ """Integration config service.""" + import base64 import os from typing import Any @@ -28,9 +29,7 @@ def __init__(self, db: AsyncSession, encryption_key: str | None = None): encryption_key: Encryption key (defaults to env var ENCRYPTION_KEY) """ self.db = db - self._encryption_key = encryption_key or os.getenv( - "ENCRYPTION_KEY", "default-insecure-key-change-me" - ) + self._encryption_key = encryption_key or os.getenv("ENCRYPTION_KEY", "default-insecure-key-change-me") self._fernet = self._get_fernet() def _get_fernet(self) -> Fernet: @@ -125,9 +124,7 @@ async def create_config( return config - async def get_config( - self, organization_id: UUID, provider: str - ) -> IntegrationConfig | None: + async def get_config(self, organization_id: UUID, provider: str) -> IntegrationConfig | None: """Get integration config for an organization and provider (legacy method). Note: This method returns the first config for the provider. @@ -147,9 +144,7 @@ async def get_config( result = await self.db.execute(stmt) return result.scalar_one_or_none() - async def get_config_by_id( - self, organization_id: UUID, config_id: UUID - ) -> IntegrationConfig | None: + async def get_config_by_id(self, organization_id: UUID, config_id: UUID) -> IntegrationConfig | None: """Get integration config by ID. 
Args: @@ -195,26 +190,21 @@ async def list_configs( # For regular users, filter by visibility # Get user's group IDs for group share checking - group_stmt = select(UserGroupMembership.group_id).where( - UserGroupMembership.user_id == user_id - ) + group_stmt = select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id) group_result = await self.db.execute(group_stmt) user_group_ids = [row[0] for row in group_result.fetchall()] # Build subquery for configs user has explicit shares to - share_subquery = ( - select(IntegrationConfigShare.integration_config_id) - .where( - or_( - # Direct user share - (IntegrationConfigShare.subject_type == IntegrationShareSubjectType.USER) - & (IntegrationConfigShare.subject_id == user_id), - # Group share (if user is in any groups) - (IntegrationConfigShare.subject_type == IntegrationShareSubjectType.GROUP) - & (IntegrationConfigShare.subject_id.in_(user_group_ids)) - if user_group_ids - else False, - ) + share_subquery = select(IntegrationConfigShare.integration_config_id).where( + or_( + # Direct user share + (IntegrationConfigShare.subject_type == IntegrationShareSubjectType.USER) + & (IntegrationConfigShare.subject_id == user_id), + # Group share (if user is in any groups) + (IntegrationConfigShare.subject_type == IntegrationShareSubjectType.GROUP) + & (IntegrationConfigShare.subject_id.in_(user_group_ids)) + if user_group_ids + else False, ) ) @@ -281,9 +271,7 @@ async def update_config( return config - async def delete_config( - self, organization_id: UUID, config_id: UUID - ) -> bool: + async def delete_config(self, organization_id: UUID, config_id: UUID) -> bool: """Delete an integration config. Args: @@ -317,10 +305,7 @@ async def get_adapter(self, organization_id: UUID, provider: str): """ config = await self.get_config(organization_id, provider) if not config: - raise ValueError( - f"No integration config found for organization {organization_id} " - f"and provider {provider}" - ) + raise ValueError(f"No integration config found for organization {organization_id} and provider {provider}") # Decrypt token token = self._decrypt_token(config.encrypted_token) @@ -367,9 +352,7 @@ async def can_user_access( return True # Check for group share - group_stmt = select(UserGroupMembership.group_id).where( - UserGroupMembership.user_id == user_id - ) + group_stmt = select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id) group_result = await self.db.execute(group_stmt) user_group_ids = [row[0] for row in group_result.fetchall()] @@ -422,8 +405,6 @@ async def get_share_count(self, config_id: UUID) -> int: """ from sqlalchemy import func - stmt = select(func.count()).where( - IntegrationConfigShare.integration_config_id == config_id - ) + stmt = select(func.count()).where(IntegrationConfigShare.integration_config_id == config_id) result = await self.db.execute(stmt) return result.scalar() or 0 diff --git a/backend/app/services/invitation_service.py b/backend/app/services/invitation_service.py index 7480cb0..c548f2c 100644 --- a/backend/app/services/invitation_service.py +++ b/backend/app/services/invitation_service.py @@ -99,11 +99,7 @@ def get_invitation_by_id(db: Session, invitation_id: UUID) -> Optional[OrgInvita Returns: OrgInvitation or None if not found """ - return ( - db.query(OrgInvitation) - .filter(OrgInvitation.id == invitation_id) - .first() - ) + return db.query(OrgInvitation).filter(OrgInvitation.id == invitation_id).first() @staticmethod def get_invitation_by_token(db: Session, token: str) -> 
Optional[OrgInvitation]: @@ -116,16 +112,10 @@ def get_invitation_by_token(db: Session, token: str) -> Optional[OrgInvitation]: Returns: OrgInvitation or None if not found """ - return ( - db.query(OrgInvitation) - .filter(OrgInvitation.token == token) - .first() - ) + return db.query(OrgInvitation).filter(OrgInvitation.token == token).first() @staticmethod - def validate_invitation( - db: Session, token: str - ) -> tuple[Optional[OrgInvitation], Optional[str]]: + def validate_invitation(db: Session, token: str) -> tuple[Optional[OrgInvitation], Optional[str]]: """Validate an invitation token. Args: @@ -165,11 +155,7 @@ def cancel_invitation(db: Session, invitation_id: UUID) -> Optional[OrgInvitatio Returns: Updated OrgInvitation or None if not found or not pending """ - invitation = ( - db.query(OrgInvitation) - .filter(OrgInvitation.id == invitation_id) - .first() - ) + invitation = db.query(OrgInvitation).filter(OrgInvitation.id == invitation_id).first() if not invitation or invitation.status != InvitationStatus.PENDING: return None @@ -197,11 +183,7 @@ def mark_as_accepted( Returns: Updated OrgInvitation or None if not found """ - invitation = ( - db.query(OrgInvitation) - .filter(OrgInvitation.id == invitation_id) - .first() - ) + invitation = db.query(OrgInvitation).filter(OrgInvitation.id == invitation_id).first() if not invitation: return None @@ -289,9 +271,7 @@ def accept_invitation( group = group_assignment.group if group: # Check if already a member (skip if exists) - if not UserGroupService.is_user_in_group( - db, group.id, accepting_user.id - ): + if not UserGroupService.is_user_in_group(db, group.id, accepting_user.id): try: UserGroupService.add_member( db=db, @@ -303,9 +283,7 @@ def accept_invitation( except IntegrityError: # Race condition: user was added between check and insert db.rollback() - logger.debug( - f"User {accepting_user.id} already in group {group.id}, skipping" - ) + logger.debug(f"User {accepting_user.id} already in group {group.id}, skipping") # Step 5: Mark invitation as accepted InvitationService.mark_as_accepted( @@ -346,9 +324,7 @@ def list_org_invitations( return query.order_by(OrgInvitation.created_at.desc()).all() @staticmethod - def get_pending_invitation_for_email( - db: Session, org_id: UUID, email: str - ) -> Optional[OrgInvitation]: + def get_pending_invitation_for_email(db: Session, org_id: UUID, email: str) -> Optional[OrgInvitation]: """Get a pending invitation for a specific email in an org. Args: diff --git a/backend/app/services/job_service.py b/backend/app/services/job_service.py index f3dbb61..e836237 100644 --- a/backend/app/services/job_service.py +++ b/backend/app/services/job_service.py @@ -3,18 +3,16 @@ Handles creation, updating, and querying of jobs. 
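Note for reviewers: the job_service.py hunks below only reflow the stuck-job monitor; the per-type timeout semantics are untouched. For reference, the check amounts to the sketch here. JOB_TIMEOUT_MINUTES and DEFAULT_TIMEOUT_MINUTES do exist in the module (keyed by the JobType enum), but the string keys, the fallback value of 30, and the is_stuck helper are stand-ins for illustration:

    from datetime import UTC, datetime, timedelta

    JOB_TIMEOUT_MINUTES = {"bug_sync": 20, "archive_export": 60}  # illustrative keys/values
    DEFAULT_TIMEOUT_MINUTES = 30  # assumed fallback; the module defines the real default

    def is_stuck(job_type: str, started_at: datetime, now: datetime | None = None) -> bool:
        """True once a RUNNING job has outlived its per-type timeout."""
        now = now or datetime.now(UTC)
        if started_at.tzinfo is None:  # naive timestamp (e.g. SQLite): treat as UTC
            started_at = started_at.replace(tzinfo=UTC)
        minutes = JOB_TIMEOUT_MINUTES.get(job_type, DEFAULT_TIMEOUT_MINUTES)
        return now - started_at > timedelta(minutes=minutes)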
""" -import asyncio -from datetime import datetime, UTC + +from datetime import UTC, datetime from typing import List, Optional from uuid import UUID from sqlalchemy.orm import Session -from app.config import settings -from app.models.job import Job, JobType, JobStatus +from app.models.job import Job, JobStatus, JobType from app.schemas.job import JobListResponse - # Job timeout configuration (minutes) # Used by the stuck job monitor to determine when a RUNNING job is considered stuck JOB_TIMEOUT_MINUTES: dict[JobType, int] = { @@ -23,13 +21,11 @@ JobType.MENTION_NOTIFICATION: 5, # Email sending only JobType.GROUNDING_UPDATE: 10, JobType.GROUNDING_SUMMARIZE: 10, - # Medium jobs - 20 min (expected: 2-10 min) JobType.COLLAB_THREAD_DECISION_SUMMARIZE: 20, JobType.COLLAB_THREAD_AI_MENTION: 20, JobType.BUG_SYNC: 20, JobType.PHASE_ANALYSIS: 20, - # Long LLM jobs - 45 min (expected: 10-30 min) JobType.BRAINSTORM_CONVERSATION_GENERATE: 45, JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE: 45, @@ -38,12 +34,10 @@ JobType.MODULE_FEATURE_GENERATE: 45, JobType.FEATURE_CONTENT_GENERATE: 45, JobType.USER_INITIATED_QUESTION_GENERATE: 30, - # Slack jobs - 20 min (expected: 2-10 min) JobType.SLACK_AI_MENTION: 20, JobType.SLACK_CODE_EXPLORE: 20, JobType.SLACK_WEB_SEARCH: 20, - # Archive - 60 min (can be slow for large projects) JobType.ARCHIVE_EXPORT: 60, } @@ -190,6 +184,7 @@ def update_job_status( # Broadcast job update via WebSocket if job has an org_id import logging + logger = logging.getLogger(__name__) logger.info(f"update_job_status: job_id={job.id}, org_id={job.org_id}, status={status}") @@ -220,6 +215,7 @@ def _broadcast_job_update(job: Job, db: Session = None): db: Optional database session for enriching with project name """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -420,22 +416,24 @@ def mark_stuck_jobs_failed(db: Session) -> int: Number of jobs marked as failed """ import logging - from datetime import timedelta, timezone + from datetime import timedelta logger = logging.getLogger(__name__) now = datetime.now(UTC) stuck_count = 0 # ======= Check RUNNING jobs ======= - running_jobs = db.query(Job).filter( - Job.status == JobStatus.RUNNING, - Job.started_at.isnot(None), - ).all() + running_jobs = ( + db.query(Job) + .filter( + Job.status == JobStatus.RUNNING, + Job.started_at.isnot(None), + ) + .all() + ) for job in running_jobs: - timeout_minutes = JOB_TIMEOUT_MINUTES.get( - job.job_type, DEFAULT_TIMEOUT_MINUTES - ) + timeout_minutes = JOB_TIMEOUT_MINUTES.get(job.job_type, DEFAULT_TIMEOUT_MINUTES) # Handle both naive (SQLite) and aware (Postgres) datetimes started_at = job.started_at @@ -448,8 +446,7 @@ def mark_stuck_jobs_failed(db: Session) -> int: job.status = JobStatus.FAILED job.finished_at = now job.error_message = ( - f"Job exceeded {timeout_minutes} minute timeout " - f"(started at {job.started_at.isoformat()})" + f"Job exceeded {timeout_minutes} minute timeout (started at {job.started_at.isoformat()})" ) stuck_count += 1 @@ -467,9 +464,13 @@ def mark_stuck_jobs_failed(db: Session) -> int: JobService._clear_phase_generation_flag_for_job(db, job, logger) # ======= Check QUEUED jobs ======= - queued_jobs = db.query(Job).filter( - Job.status == JobStatus.QUEUED, - ).all() + queued_jobs = ( + db.query(Job) + .filter( + Job.status == JobStatus.QUEUED, + ) + .all() + ) for job in queued_jobs: # Handle both naive (SQLite) and aware (Postgres) datetimes @@ -528,15 +529,11 @@ def _clear_phase_generation_flag_for_job(db: 
Session, job: Job, logger) -> None: try: from app.models.brainstorming_phase import BrainstormingPhase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == UUID(phase_id_str) - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == UUID(phase_id_str)).first() if phase: setattr(phase, flag_name, False) - logger.info( - f"Cleared {flag_name} for stuck job {job.id} (phase {phase_id_str})" - ) + logger.info(f"Cleared {flag_name} for stuck job {job.id} (phase {phase_id_str})") except Exception as e: logger.error(f"Failed to clear phase generation flag for stuck job: {e}") @@ -580,9 +577,7 @@ def cancel_jobs_for_implementation(db: Session, implementation_id: UUID) -> int: job.error_message = "Cancelled due to start-over operation" cancelled_count += 1 - logger.info( - f"Cancelled job {job.id} for implementation {implementation_id}" - ) + logger.info(f"Cancelled job {job.id} for implementation {implementation_id}") # Broadcast job update to WebSocket clients if job.org_id: @@ -634,9 +629,7 @@ def cancel_jobs_for_project_chat(db: Session, project_chat_id: UUID) -> int: job.error_message = "Cancelled by user" cancelled_count += 1 - logger.info( - f"Cancelled job {job.id} (type={job.job_type}) for discussion {project_chat_id}" - ) + logger.info(f"Cancelled job {job.id} (type={job.job_type}) for discussion {project_chat_id}") # Broadcast job update to WebSocket clients if job.org_id: @@ -688,9 +681,7 @@ def cancel_jobs_for_thread(db: Session, thread_id: UUID) -> int: job.error_message = "Cancelled by user" cancelled_count += 1 - logger.info( - f"Cancelled job {job.id} (type={job.job_type}) for thread {thread_id}" - ) + logger.info(f"Cancelled job {job.id} (type={job.job_type}) for thread {thread_id}") # Broadcast job update to WebSocket clients if job.org_id: @@ -699,9 +690,7 @@ def cancel_jobs_for_thread(db: Session, thread_id: UUID) -> int: return cancelled_count @staticmethod - def cancel_jobs_for_phase_generation( - db: Session, phase_id: UUID, generation_type: str - ) -> int: + def cancel_jobs_for_phase_generation(db: Session, phase_id: UUID, generation_type: str) -> int: """ Cancel all QUEUED/RUNNING jobs for a brainstorming phase generation. @@ -750,12 +739,10 @@ def cancel_jobs_for_phase_generation( job.error_message = "Cancelled by user" cancelled_count += 1 - logger.info( - f"Cancelled job {job.id} (type={job.job_type}) for phase {phase_id}" - ) + logger.info(f"Cancelled job {job.id} (type={job.job_type}) for phase {phase_id}") # Broadcast job update to WebSocket clients if job.org_id: JobService._broadcast_job_update(job, db=db) - return cancelled_count \ No newline at end of file + return cancelled_count diff --git a/backend/app/services/kafka_producer.py b/backend/app/services/kafka_producer.py index 502a7b0..7764675 100644 --- a/backend/app/services/kafka_producer.py +++ b/backend/app/services/kafka_producer.py @@ -3,6 +3,7 @@ Handles publishing messages to Kafka topics. """ + import asyncio import json import logging @@ -14,7 +15,8 @@ from aiokafka import AIOKafkaProducer from aiokafka.errors import KafkaError from kafka import KafkaProducer -from kafka.errors import KafkaError as SyncKafkaError, NoBrokersAvailable +from kafka.errors import KafkaError as SyncKafkaError +from kafka.errors import NoBrokersAvailable from app.config import settings @@ -112,12 +114,10 @@ async def start(self): backoff = min(backoff * self.backoff_multiplier, self.max_backoff) logger.error( - f"Failed to start Kafka producer after {self.max_retries} attempts. 
" - f"Kafka brokers: {self.bootstrap_servers}" + f"Failed to start Kafka producer after {self.max_retries} attempts. Kafka brokers: {self.bootstrap_servers}" ) raise KafkaConnectionError( - f"Unable to connect to Kafka at {self.bootstrap_servers} " - f"after {self.max_retries} retries" + f"Unable to connect to Kafka at {self.bootstrap_servers} after {self.max_retries} retries" ) async def stop(self): @@ -310,8 +310,7 @@ def start(self) -> None: f"Kafka brokers: {self.bootstrap_servers}" ) raise KafkaConnectionError( - f"Unable to connect to Kafka at {self.bootstrap_servers} " - f"after {self.max_retries} retries" + f"Unable to connect to Kafka at {self.bootstrap_servers} after {self.max_retries} retries" ) def stop(self) -> None: diff --git a/backend/app/services/llm_adapters.py b/backend/app/services/llm_adapters.py index 23eef8e..1ed525b 100644 --- a/backend/app/services/llm_adapters.py +++ b/backend/app/services/llm_adapters.py @@ -1,4 +1,5 @@ """LLM provider adapters for testing connections.""" + from abc import ABC, abstractmethod from typing import Any @@ -232,9 +233,10 @@ class AWSBedrockAdapter(LLMAdapter): async def test_connection(self, prompt: str = "2+2=?") -> dict[str, Any]: """Test AWS Bedrock API connection.""" try: - import boto3 import json + import boto3 + # Bedrock requires region in config region = self.config.get("region", "us-east-1") model_id = self.config.get("model", "anthropic.claude-3-sonnet-20240229-v1:0") @@ -260,16 +262,20 @@ async def test_connection(self, prompt: str = "2+2=?") -> dict[str, Any]: # Prepare request body based on model if "anthropic" in model_id: - body = json.dumps({ - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": self.config.get("max_tokens", 100), - "messages": [{"role": "user", "content": prompt}], - }) + body = json.dumps( + { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": self.config.get("max_tokens", 100), + "messages": [{"role": "user", "content": prompt}], + } + ) elif "meta" in model_id: # Llama models - body = json.dumps({ - "prompt": prompt, - "max_gen_len": self.config.get("max_tokens", 100), - }) + body = json.dumps( + { + "prompt": prompt, + "max_gen_len": self.config.get("max_tokens", 100), + } + ) else: return { "success": False, @@ -330,9 +336,6 @@ def get_llm_adapter(provider: str, api_key: str, config: dict[str, Any] | None = adapter_class = adapters.get(provider.lower()) if not adapter_class: - raise ValueError( - f"Unsupported LLM provider: {provider}. " - f"Supported providers: {', '.join(adapters.keys())}" - ) + raise ValueError(f"Unsupported LLM provider: {provider}. Supported providers: {', '.join(adapters.keys())}") return adapter_class(api_key=api_key, config=config) diff --git a/backend/app/services/llm_call_log_service.py b/backend/app/services/llm_call_log_service.py index e92a4e3..c399ca2 100644 --- a/backend/app/services/llm_call_log_service.py +++ b/backend/app/services/llm_call_log_service.py @@ -1,16 +1,16 @@ """ Service for managing LLM call logs. 
""" + from datetime import datetime, timezone -from typing import Optional, List -from uuid import UUID from decimal import Decimal +from typing import List, Optional +from uuid import UUID from sqlalchemy.orm import Session -from sqlalchemy import select -from app.models.llm_call_log import LLMCallLog from app.models.job import Job, JobType +from app.models.llm_call_log import LLMCallLog class LLMCallLogService: @@ -72,9 +72,7 @@ def get_call_log(db: Session, call_log_id: UUID) -> Optional[LLMCallLog]: @staticmethod def list_call_logs_for_job(db: Session, job_id: UUID) -> List[LLMCallLog]: """List all call logs for a specific job, ordered by start time.""" - return db.query(LLMCallLog).filter( - LLMCallLog.job_id == job_id - ).order_by(LLMCallLog.started_at).all() + return db.query(LLMCallLog).filter(LLMCallLog.job_id == job_id).order_by(LLMCallLog.started_at).all() @staticmethod def list_jobs_with_call_logs( diff --git a/backend/app/services/llm_mock.py b/backend/app/services/llm_mock.py index 1ebeb92..247993d 100644 --- a/backend/app/services/llm_mock.py +++ b/backend/app/services/llm_mock.py @@ -13,7 +13,6 @@ from app.models.discovery import DiscoveryQuestion - # Question templates by category QUESTION_TEMPLATES = { "architecture": [ @@ -252,9 +251,7 @@ def generate_discovery_questions_mock( rng = random.Random(seed) # Normalize existing question texts for comparison - existing_normalized = { - _normalize_text(q.question_text) for q in existing_questions - } + existing_normalized = {_normalize_text(q.question_text) for q in existing_questions} # Generate one question per category categories = ["architecture", "data", "ux", "ui", "security"] diff --git a/backend/app/services/llm_mock_prompt_plan.py b/backend/app/services/llm_mock_prompt_plan.py index 168c301..62128a4 100644 --- a/backend/app/services/llm_mock_prompt_plan.py +++ b/backend/app/services/llm_mock_prompt_plan.py @@ -1,7 +1,6 @@ """Mock LLM service for generating prompt plans.""" import hashlib -import json import random import re from uuid import UUID @@ -72,7 +71,9 @@ def _parse_spec_sections(spec_markdown: str) -> dict: sections["data_model"] = "data model" in content_lower or "database" in content_lower or "entities" in content_lower sections["architecture"] = "architecture" in content_lower or "system design" in content_lower sections["api"] = "api" in content_lower or "endpoints" in content_lower or "rest" in content_lower - sections["security"] = "security" in content_lower or "authentication" in content_lower or "authorization" in content_lower + sections["security"] = ( + "security" in content_lower or "authentication" in content_lower or "authorization" in content_lower + ) sections["performance"] = "performance" in content_lower or "scalability" in content_lower sections["testing"] = "testing" in content_lower or "test strategy" in content_lower @@ -90,8 +91,8 @@ def _extract_project_name(spec_markdown: str, sections: dict, rng: random.Random if line.startswith("# "): title = line[2:].strip() # Remove common suffixes - title = re.sub(r'\s*[-–—:]\s*(Technical\s+)?Specification.*$', '', title, flags=re.IGNORECASE) - title = re.sub(r'\s*Spec(ification)?$', '', title, flags=re.IGNORECASE) + title = re.sub(r"\s*[-–—:]\s*(Technical\s+)?Specification.*$", "", title, flags=re.IGNORECASE) + title = re.sub(r"\s*Spec(ification)?$", "", title, flags=re.IGNORECASE) return title.strip() if title.strip() else "Application" return "Application" @@ -103,98 +104,118 @@ def _generate_phases(sections: dict, rng: random.Random) -> 
list[dict]: phase_index = 1 # Phase 1: Always start with setup - phases.append({ - "phase_index": phase_index, - "title": "Project Setup & Infrastructure", - "description": "Initialize repository, setup development environment, configure build tools, establish CI/CD pipeline, and create basic project structure.", - "test_plan": "- Repository initialized with README and basic documentation\n- Build scripts execute successfully\n- CI/CD pipeline passes health checks\n- Development environment setup documented", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Project Setup & Infrastructure", + "description": "Initialize repository, setup development environment, configure build tools, establish CI/CD pipeline, and create basic project structure.", + "test_plan": "- Repository initialized with README and basic documentation\n- Build scripts execute successfully\n- CI/CD pipeline passes health checks\n- Development environment setup documented", + } + ) phase_index += 1 # Phase 2: Database/Data Model if present if sections["data_model"]: - phases.append({ - "phase_index": phase_index, - "title": "Database Schema & Models", - "description": "Design and implement database schema, create ORM models for core entities, setup migrations, and establish relationships between entities.", - "test_plan": "- Database migrations apply cleanly\n- All models have unit tests\n- Relationships function correctly\n- CRUD operations work for all entities", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Database Schema & Models", + "description": "Design and implement database schema, create ORM models for core entities, setup migrations, and establish relationships between entities.", + "test_plan": "- Database migrations apply cleanly\n- All models have unit tests\n- Relationships function correctly\n- CRUD operations work for all entities", + } + ) phase_index += 1 # Phase 3-5: Core API implementation if sections["api"]: - phases.append({ - "phase_index": phase_index, - "title": "Authentication & User Management", - "description": "Implement user registration, login/logout, password management, and session handling. Setup JWT or session-based authentication.", - "test_plan": "- Users can register with valid credentials\n- Login returns proper tokens/sessions\n- Protected endpoints require authentication\n- Password hashing works correctly", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Authentication & User Management", + "description": "Implement user registration, login/logout, password management, and session handling. 
Setup JWT or session-based authentication.", + "test_plan": "- Users can register with valid credentials\n- Login returns proper tokens/sessions\n- Protected endpoints require authentication\n- Password hashing works correctly", + } + ) phase_index += 1 - phases.append({ - "phase_index": phase_index, - "title": "Core Business Logic & Services", - "description": "Implement primary business logic, service layer classes, domain-specific operations, and data validation rules.", - "test_plan": "- Service methods have comprehensive unit tests\n- Business rules are enforced correctly\n- Error handling covers edge cases\n- Data validation prevents invalid states", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Core Business Logic & Services", + "description": "Implement primary business logic, service layer classes, domain-specific operations, and data validation rules.", + "test_plan": "- Service methods have comprehensive unit tests\n- Business rules are enforced correctly\n- Error handling covers edge cases\n- Data validation prevents invalid states", + } + ) phase_index += 1 - phases.append({ - "phase_index": phase_index, - "title": "API Endpoints Implementation", - "description": "Develop RESTful API endpoints for core resources, implement request/response handling, add input validation, and setup proper HTTP status codes.", - "test_plan": "- All endpoints return correct status codes\n- Request validation catches invalid input\n- Response format is consistent\n- Integration tests cover happy and error paths", - }) + phases.append( + { + "phase_index": phase_index, + "title": "API Endpoints Implementation", + "description": "Develop RESTful API endpoints for core resources, implement request/response handling, add input validation, and setup proper HTTP status codes.", + "test_plan": "- All endpoints return correct status codes\n- Request validation catches invalid input\n- Response format is consistent\n- Integration tests cover happy and error paths", + } + ) phase_index += 1 # Phase: Authorization/Permissions if security mentioned if sections["security"]: - phases.append({ - "phase_index": phase_index, - "title": "Authorization & Security", - "description": "Implement role-based access control (RBAC), permission checks, secure sensitive endpoints, add rate limiting, and ensure proper security headers.", - "test_plan": "- Role-based permissions enforced correctly\n- Unauthorized access returns 403\n- Security headers present in responses\n- Rate limiting prevents abuse", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Authorization & Security", + "description": "Implement role-based access control (RBAC), permission checks, secure sensitive endpoints, add rate limiting, and ensure proper security headers.", + "test_plan": "- Role-based permissions enforced correctly\n- Unauthorized access returns 403\n- Security headers present in responses\n- Rate limiting prevents abuse", + } + ) phase_index += 1 # Phase: Performance optimization if mentioned if sections["performance"]: - phases.append({ - "phase_index": phase_index, - "title": "Performance Optimization & Caching", - "description": "Add caching layer, optimize database queries, implement connection pooling, add indexes, and tune application performance.", - "test_plan": "- Response times meet performance targets\n- Database queries use appropriate indexes\n- Cache hit rates are acceptable\n- Load testing shows scalability", - }) + phases.append( + { + "phase_index": phase_index, + "title": 
"Performance Optimization & Caching", + "description": "Add caching layer, optimize database queries, implement connection pooling, add indexes, and tune application performance.", + "test_plan": "- Response times meet performance targets\n- Database queries use appropriate indexes\n- Cache hit rates are acceptable\n- Load testing shows scalability", + } + ) phase_index += 1 # Phase: Testing infrastructure if mentioned if sections["testing"]: - phases.append({ - "phase_index": phase_index, - "title": "Testing Infrastructure & Coverage", - "description": "Enhance test coverage, add integration tests, setup E2E testing, implement test fixtures, and ensure critical paths are fully tested.", - "test_plan": "- Code coverage meets minimum threshold (80%+)\n- Critical paths have 100% coverage\n- Integration tests cover main workflows\n- E2E tests validate user journeys", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Testing Infrastructure & Coverage", + "description": "Enhance test coverage, add integration tests, setup E2E testing, implement test fixtures, and ensure critical paths are fully tested.", + "test_plan": "- Code coverage meets minimum threshold (80%+)\n- Critical paths have 100% coverage\n- Integration tests cover main workflows\n- E2E tests validate user journeys", + } + ) phase_index += 1 # Phase: Documentation (always include) - phases.append({ - "phase_index": phase_index, - "title": "Documentation & Deployment", - "description": "Complete API documentation, write deployment guides, create user documentation, setup monitoring and logging, and prepare for production deployment.", - "test_plan": "- API documentation is complete and accurate\n- Deployment guide enables successful setup\n- Monitoring dashboards are functional\n- Health check endpoints respond correctly", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Documentation & Deployment", + "description": "Complete API documentation, write deployment guides, create user documentation, setup monitoring and logging, and prepare for production deployment.", + "test_plan": "- API documentation is complete and accurate\n- Deployment guide enables successful setup\n- Monitoring dashboards are functional\n- Health check endpoints respond correctly", + } + ) phase_index += 1 # Add a final polish phase if we have many sections (complex project) complexity_score = sum(1 for v in sections.values() if v) if complexity_score >= 5: - phases.append({ - "phase_index": phase_index, - "title": "Final Polish & Production Readiness", - "description": "Address technical debt, optimize user experience, perform security audit, finalize error handling, and ensure production readiness.", - "test_plan": "- Security audit passes\n- Performance benchmarks met\n- Error messages are user-friendly\n- Production checklist completed", - }) + phases.append( + { + "phase_index": phase_index, + "title": "Final Polish & Production Readiness", + "description": "Address technical debt, optimize user experience, perform security audit, finalize error handling, and ensure production readiness.", + "test_plan": "- Security audit passes\n- Performance benchmarks met\n- Error messages are user-friendly\n- Production checklist completed", + } + ) return phases diff --git a/backend/app/services/llm_mock_spec.py b/backend/app/services/llm_mock_spec.py index 047116d..65324ba 100644 --- a/backend/app/services/llm_mock_spec.py +++ b/backend/app/services/llm_mock_spec.py @@ -37,7 +37,9 @@ def generate_specification_mock( - Testing 
Strategy """ # Create deterministic seed from inputs - seed_data = f"{str(project_id)}:{idea_text}:{json.dumps(sorted(discovery_answers, key=lambda x: x.get('question', '')))}" + seed_data = ( + f"{str(project_id)}:{idea_text}:{json.dumps(sorted(discovery_answers, key=lambda x: x.get('question', '')))}" + ) seed = int(hashlib.sha256(seed_data.encode()).hexdigest(), 16) % (10**8) rng = random.Random(seed) @@ -130,7 +132,7 @@ def _generate_overview(idea_text: str, rng: random.Random) -> str: if not idea_text: return "This application provides core functionality for users.\n" - return f"""This project aims to {idea_text.lower().strip('.')}. The system will be designed with scalability, maintainability, and user experience as primary concerns. + return f"""This project aims to {idea_text.lower().strip(".")}. The system will be designed with scalability, maintainability, and user experience as primary concerns. **Key Objectives:** - Deliver a robust and reliable solution @@ -192,7 +194,9 @@ def _generate_data_model(database: str | None, discovery_answers: list[dict], rn """ -def _generate_architecture(tech_stack: list[str], architecture_notes: list[str], discovery_answers: list[dict], rng: random.Random) -> str: +def _generate_architecture( + tech_stack: list[str], architecture_notes: list[str], discovery_answers: list[dict], rng: random.Random +) -> str: """Generate Architecture section.""" parts = [] diff --git a/backend/app/services/llm_preference_service.py b/backend/app/services/llm_preference_service.py index ca8797e..73cbee2 100644 --- a/backend/app/services/llm_preference_service.py +++ b/backend/app/services/llm_preference_service.py @@ -1,4 +1,5 @@ """LLM Preference service.""" + from uuid import UUID from sqlalchemy import select @@ -27,9 +28,7 @@ async def get_preference(self, organization_id: UUID) -> LLMPreference | None: Returns: LLM preference or None if not found """ - stmt = select(LLMPreference).where( - LLMPreference.organization_id == organization_id - ) + stmt = select(LLMPreference).where(LLMPreference.organization_id == organization_id) result = await self.db.execute(stmt) return result.scalar_one_or_none() diff --git a/backend/app/services/llm_usage_log_service.py b/backend/app/services/llm_usage_log_service.py index d92b707..064d9b0 100644 --- a/backend/app/services/llm_usage_log_service.py +++ b/backend/app/services/llm_usage_log_service.py @@ -4,15 +4,15 @@ This service provides operations for the lightweight llm_usage_logs table which tracks all LLM calls (including standalone calls outside of job context). 
""" -import json + import logging from datetime import datetime, timezone -from typing import Optional, List, Dict, Any -from uuid import UUID from decimal import Decimal +from typing import Any, Dict, List, Optional +from uuid import UUID +from sqlalchemy import extract, func from sqlalchemy.orm import Session -from sqlalchemy import func, extract from app.models.llm_usage_log import LLMUsageLog from app.services.kafka_producer import get_sync_kafka_producer @@ -119,7 +119,9 @@ def _broadcast_usage_created(log_entry: LLMUsageLog): "created_at": log_entry.created_at.isoformat() if log_entry.created_at else None, "llm_call_log_id": str(log_entry.llm_call_log_id) if log_entry.llm_call_log_id else None, "triggered_by_user_id": str(log_entry.triggered_by_user_id) if log_entry.triggered_by_user_id else None, - "triggered_by_user_name": log_entry.triggered_by_user.display_name if log_entry.triggered_by_user else None, + "triggered_by_user_name": log_entry.triggered_by_user.display_name + if log_entry.triggered_by_user + else None, "duration_ms": log_entry.duration_ms, } @@ -138,9 +140,7 @@ def _broadcast_usage_created(log_entry: LLMUsageLog): ) if success: - logger.debug( - f"Broadcasted LLM usage log via Kafka: id={log_entry.id}" - ) + logger.debug(f"Broadcasted LLM usage log via Kafka: id={log_entry.id}") except Exception as e: # Log error but don't fail the usage log creation @@ -165,15 +165,19 @@ def get_monthly_usage( Returns: Dict with total_prompt_tokens, total_completion_tokens, total_tokens, total_cost_usd """ - result = db.query( - func.coalesce(func.sum(LLMUsageLog.prompt_tokens), 0).label("prompt_tokens"), - func.coalesce(func.sum(LLMUsageLog.completion_tokens), 0).label("completion_tokens"), - func.sum(LLMUsageLog.cost_usd).label("cost_usd"), - ).filter( - LLMUsageLog.org_id == org_id, - extract("year", LLMUsageLog.created_at) == year, - extract("month", LLMUsageLog.created_at) == month, - ).first() + result = ( + db.query( + func.coalesce(func.sum(LLMUsageLog.prompt_tokens), 0).label("prompt_tokens"), + func.coalesce(func.sum(LLMUsageLog.completion_tokens), 0).label("completion_tokens"), + func.sum(LLMUsageLog.cost_usd).label("cost_usd"), + ) + .filter( + LLMUsageLog.org_id == org_id, + extract("year", LLMUsageLog.created_at) == year, + extract("month", LLMUsageLog.created_at) == month, + ) + .first() + ) prompt_tokens = int(result.prompt_tokens) if result.prompt_tokens else 0 completion_tokens = int(result.completion_tokens) if result.completion_tokens else 0 @@ -205,21 +209,23 @@ def get_usage_by_agent( Returns: List of dicts with agent_name, agent_display_name, total_tokens, total_cost_usd, call_count """ - results = db.query( - LLMUsageLog.agent_name, - func.max(LLMUsageLog.agent_display_name).label("agent_display_name"), - func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens).label("total_tokens"), - func.sum(LLMUsageLog.cost_usd).label("total_cost"), - func.count(LLMUsageLog.id).label("call_count"), - ).filter( - LLMUsageLog.org_id == org_id, - extract("year", LLMUsageLog.created_at) == year, - extract("month", LLMUsageLog.created_at) == month, - ).group_by( - LLMUsageLog.agent_name - ).order_by( - func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens).desc() - ).all() + results = ( + db.query( + LLMUsageLog.agent_name, + func.max(LLMUsageLog.agent_display_name).label("agent_display_name"), + func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens).label("total_tokens"), + func.sum(LLMUsageLog.cost_usd).label("total_cost"), + 
func.count(LLMUsageLog.id).label("call_count"), + ) + .filter( + LLMUsageLog.org_id == org_id, + extract("year", LLMUsageLog.created_at) == year, + extract("month", LLMUsageLog.created_at) == month, + ) + .group_by(LLMUsageLog.agent_name) + .order_by(func.sum(LLMUsageLog.prompt_tokens + LLMUsageLog.completion_tokens).desc()) + .all() + ) return [ { @@ -249,11 +255,13 @@ def list_recent_usage( Returns: List of LLMUsageLog entries ordered by created_at descending """ - return db.query(LLMUsageLog).filter( - LLMUsageLog.org_id == org_id - ).order_by( - LLMUsageLog.created_at.desc() - ).limit(limit).all() + return ( + db.query(LLMUsageLog) + .filter(LLMUsageLog.org_id == org_id) + .order_by(LLMUsageLog.created_at.desc()) + .limit(limit) + .all() + ) @staticmethod def count_usage_logs( diff --git a/backend/app/services/mcp_call_log_service.py b/backend/app/services/mcp_call_log_service.py index 72a18c4..910085b 100644 --- a/backend/app/services/mcp_call_log_service.py +++ b/backend/app/services/mcp_call_log_service.py @@ -1,12 +1,13 @@ """ Service for managing MCP call logs. """ + from datetime import datetime, timezone -from typing import Optional, List +from typing import List, Optional from uuid import UUID -from sqlalchemy.orm import Session, joinedload from sqlalchemy import desc +from sqlalchemy.orm import Session, joinedload from app.models.mcp_call_log import MCPCallLog @@ -63,12 +64,7 @@ def create_call_log( @staticmethod def get_call_log(db: Session, call_log_id: UUID) -> Optional[MCPCallLog]: """Get a specific call log by ID.""" - return ( - db.query(MCPCallLog) - .options(joinedload(MCPCallLog.project)) - .filter(MCPCallLog.id == call_log_id) - .first() - ) + return db.query(MCPCallLog).options(joinedload(MCPCallLog.project)).filter(MCPCallLog.id == call_log_id).first() @staticmethod def list_call_logs( @@ -86,11 +82,7 @@ def list_call_logs( Returns logs ordered by creation date (most recent first). 
""" - query = ( - db.query(MCPCallLog) - .options(joinedload(MCPCallLog.project)) - .filter(MCPCallLog.org_id == org_id) - ) + query = db.query(MCPCallLog).options(joinedload(MCPCallLog.project)).filter(MCPCallLog.org_id == org_id) if project_id: query = query.filter(MCPCallLog.project_id == project_id) diff --git a/backend/app/services/mcp_image_service.py b/backend/app/services/mcp_image_service.py index 539a119..8acb2b8 100644 --- a/backend/app/services/mcp_image_service.py +++ b/backend/app/services/mcp_image_service.py @@ -30,9 +30,7 @@ def cleanup_expired_submissions(db: Session) -> int: now = datetime.now(timezone.utc) # Find and delete expired submissions - expired = db.query(MCPImageSubmission).filter( - MCPImageSubmission.expires_at < now - ).all() + expired = db.query(MCPImageSubmission).filter(MCPImageSubmission.expires_at < now).all() if not expired: return 0 diff --git a/backend/app/services/mcp_oauth_service.py b/backend/app/services/mcp_oauth_service.py index 4ea6221..cce3597 100644 --- a/backend/app/services/mcp_oauth_service.py +++ b/backend/app/services/mcp_oauth_service.py @@ -2,24 +2,23 @@ import base64 import hashlib +import re import secrets from datetime import datetime, timedelta, timezone from typing import Any from uuid import UUID -import re import bcrypt import jwt from sqlalchemy.orm import Session +from app.auth.encryption_utils import encrypt_api_key from app.config import settings from app.models.mcp_oauth_client import MCPOAuthClient from app.models.mcp_oauth_code import MCPOAuthAuthorizationCode from app.models.mcp_oauth_token import MCPOAuthToken -from app.models.user import User from app.models.project import Project -from app.auth.encryption_utils import encrypt_api_key, decrypt_api_key - +from app.models.user import User # Token expiry constants ACCESS_TOKEN_EXPIRY_HOURS = 1 @@ -343,6 +342,7 @@ def _generate_access_token( # Get user email for sub claim from app.database import SessionLocal + with SessionLocal() as db: user = db.query(User).filter(User.id == user_id).first() user_email = user.email if user else str(user_id) @@ -484,9 +484,7 @@ def revoke_token( for tr in token_records: # Check refresh token - if tr.refresh_token_hash and bcrypt.checkpw( - token.encode("utf-8"), tr.refresh_token_hash.encode("utf-8") - ): + if tr.refresh_token_hash and bcrypt.checkpw(token.encode("utf-8"), tr.refresh_token_hash.encode("utf-8")): tr.revoked = True tr.revoked_at = datetime.now(timezone.utc) db.commit() @@ -516,6 +514,7 @@ def validate_access_token( Tuple of (User, project_id) if valid, None otherwise """ import logging + logger = logging.getLogger(__name__) try: diff --git a/backend/app/services/mention_utils.py b/backend/app/services/mention_utils.py index 2a4e002..68f8c62 100644 --- a/backend/app/services/mention_utils.py +++ b/backend/app/services/mention_utils.py @@ -216,9 +216,9 @@ def resolve_phase_feature_mention( ("brainstorming_phase", UUID("...")) """ # Import here to avoid circular imports - from app.services.phase_container_service import PhaseContainerService from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.feature_service import FeatureService + from app.services.phase_container_service import PhaseContainerService # Try phase container first container = PhaseContainerService.get_by_identifier(db, identifier) diff --git a/backend/app/services/module_service.py b/backend/app/services/module_service.py index 5560eb3..e35d445 100644 --- a/backend/app/services/module_service.py +++ 
b/backend/app/services/module_service.py @@ -1,23 +1,20 @@ """Service for managing modules.""" -import os -import base64 -from typing import Optional, List, Callable, Tuple -from uuid import UUID + from datetime import datetime, timezone -from sqlalchemy.orm import Session +from typing import Callable, List, Optional, Tuple +from uuid import UUID + from sqlalchemy import func -from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from sqlalchemy.orm import Session -from app.models.module import Module, ModuleProvenance, ModuleType from app.models.brainstorming_phase import BrainstormingPhase -from app.models.final_spec import FinalSpec +from app.models.feature import Feature, FeaturePriority, FeatureProvenance, FeatureStatus, FeatureType +from app.models.feature_content_version import FeatureContentType from app.models.final_prompt_plan import FinalPromptPlan -from app.models.spec_version import SpecVersion, SpecType +from app.models.final_spec import FinalSpec +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.project import Project -from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType, FeaturePriority -from app.models.feature_content_version import FeatureContentType +from app.models.spec_version import SpecType, SpecVersion from app.services.feature_content_version_service import FeatureContentVersionService from app.services.implementation_service import ImplementationService from app.services.platform_settings_service import require_llm_config_sync @@ -55,9 +52,7 @@ def generate_module_key( # Use MAX instead of COUNT for robustness (handles deleted modules) max_number = ( - db.query(func.max(Module.module_key_number)) - .filter(Module.project_id == project_id) - .scalar() + db.query(func.max(Module.module_key_number)).filter(Module.project_id == project_id).scalar() ) or 0 next_number = max_number + 1 @@ -156,11 +151,7 @@ def get_by_identifier(db: Session, identifier: str) -> Optional[Module]: # Extract short_id and query short_id = extract_short_id(identifier) - return ( - db.query(Module) - .filter(Module.short_id == short_id, Module.archived_at.is_(None)) - .first() - ) + return db.query(Module).filter(Module.short_id == short_id, Module.archived_at.is_(None)).first() @staticmethod def list_modules( @@ -235,7 +226,7 @@ def archive_module_with_features( Returns: Tuple of (archived Module, count of archived features) """ - from app.services.activity_log_service import ActivityLogService, ActivityEventTypes + from app.services.activity_log_service import ActivityEventTypes, ActivityLogService module = db.query(Module).filter(Module.id == module_id).first() if module is None: @@ -244,10 +235,14 @@ def archive_module_with_features( now = datetime.now(timezone.utc) # Get active features in this module - active_features = db.query(Feature).filter( - Feature.module_id == module_id, - Feature.status == FeatureStatus.ACTIVE, - ).all() + active_features = ( + db.query(Feature) + .filter( + Feature.module_id == module_id, + Feature.status == FeatureStatus.ACTIVE, + ) + .all() + ) # Archive each feature and log activity archived_count = 0 @@ -387,13 +382,11 @@ async def generate_modules_from_final_spec( Raises: ValueError: If phase not found, no final spec, or LLM config missing """ + from app.agents.module_feature import ModuleFeatureContext, create_orchestrator from app.services.feature_service import FeatureService - from 
app.agents.module_feature import create_orchestrator, ModuleFeatureContext # 1. Load brainstorming phase and verify final spec exists - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == brainstorming_phase_id).first() if not phase: raise ValueError(f"Brainstorming phase {brainstorming_phase_id} not found") @@ -401,9 +394,7 @@ async def generate_modules_from_final_spec( raise ValueError(f"Phase {brainstorming_phase_id} does not belong to project {project_id}") # 2. Load final spec - final_spec = db.query(FinalSpec).filter( - FinalSpec.brainstorming_phase_id == brainstorming_phase_id - ).first() + final_spec = db.query(FinalSpec).filter(FinalSpec.brainstorming_phase_id == brainstorming_phase_id).first() if not final_spec: raise ValueError(f"No Final Spec found for phase {brainstorming_phase_id}. Generate a Final Spec first.") @@ -412,19 +403,24 @@ async def generate_modules_from_final_spec( prompt_plan_json = None # First try to get FinalPromptPlan - final_prompt_plan = db.query(FinalPromptPlan).filter( - FinalPromptPlan.brainstorming_phase_id == brainstorming_phase_id - ).first() + final_prompt_plan = ( + db.query(FinalPromptPlan).filter(FinalPromptPlan.brainstorming_phase_id == brainstorming_phase_id).first() + ) if final_prompt_plan: prompt_plan_markdown = final_prompt_plan.content_markdown prompt_plan_json = final_prompt_plan.content_json else: # Fall back to latest draft prompt plan - latest_prompt_plan = db.query(SpecVersion).filter( - SpecVersion.brainstorming_phase_id == brainstorming_phase_id, - SpecVersion.spec_type == SpecType.PROMPT_PLAN, - ).order_by(SpecVersion.version.desc()).first() + latest_prompt_plan = ( + db.query(SpecVersion) + .filter( + SpecVersion.brainstorming_phase_id == brainstorming_phase_id, + SpecVersion.spec_type == SpecType.PROMPT_PLAN, + ) + .order_by(SpecVersion.version.desc()) + .first() + ) if latest_prompt_plan: prompt_plan_markdown = latest_prompt_plan.content_markdown @@ -449,10 +445,9 @@ async def generate_modules_from_final_spec( # 6.5. Build cross-phase context for coherence from app.services.brainstorming_phase_service import _build_cross_project_context + cross_project_context = _build_cross_project_context( - db, - project_id=project_id, - current_phase_id=brainstorming_phase_id + db, project_id=project_id, current_phase_id=brainstorming_phase_id ) # 6.6. Extract pending topics from prompt plan JSON if available @@ -463,6 +458,7 @@ async def generate_modules_from_final_spec( # 6.7. Prepare phase description images for context from app.agents.module_feature.types import ImageAttachmentInfo + phase_description_images = phase.description_image_attachments or [] phase_image_info = [ ImageAttachmentInfo( @@ -602,7 +598,7 @@ async def generate_modules_from_final_spec( # 10. 
Extract LLM usage stats before closing orchestrator llm_usage = None - if hasattr(orchestrator, 'model_client') and hasattr(orchestrator.model_client, 'get_usage_stats'): + if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"): usage_stats = orchestrator.model_client.get_usage_stats() llm_usage = { "model": usage_stats.get("model"), diff --git a/backend/app/services/notification_adapters/__init__.py b/backend/app/services/notification_adapters/__init__.py index eed1a8d..fbd88e4 100644 --- a/backend/app/services/notification_adapters/__init__.py +++ b/backend/app/services/notification_adapters/__init__.py @@ -3,6 +3,7 @@ This module exports notification adapters for various channels (Email, Slack, Teams). """ + from app.services.notification_adapters.base import NotificationAdapter from app.services.notification_adapters.email import EmailAdapter from app.services.notification_adapters.slack import SlackAdapter diff --git a/backend/app/services/notification_adapters/base.py b/backend/app/services/notification_adapters/base.py index 5922cad..b30383c 100644 --- a/backend/app/services/notification_adapters/base.py +++ b/backend/app/services/notification_adapters/base.py @@ -3,8 +3,9 @@ Defines the contract that all notification channel adapters must follow. """ + from abc import ABC, abstractmethod -from typing import Dict, Any +from typing import Any, Dict class NotificationAdapter(ABC): @@ -22,7 +23,7 @@ async def send_notification( subject: str, body: str, metadata: Dict[str, Any], - channel_config: Dict[str, Any] | None = None + channel_config: Dict[str, Any] | None = None, ) -> bool: """ Send a notification through this channel. diff --git a/backend/app/services/notification_adapters/email.py b/backend/app/services/notification_adapters/email.py index 3b241dc..f0015e9 100644 --- a/backend/app/services/notification_adapters/email.py +++ b/backend/app/services/notification_adapters/email.py @@ -3,8 +3,9 @@ Sends notifications via SMTP email. """ + import logging -from typing import Dict, Any +from typing import Any, Dict from app.services.notification_adapters.base import NotificationAdapter @@ -34,7 +35,7 @@ async def send_notification( subject: str, body: str, metadata: Dict[str, Any], - channel_config: Dict[str, Any] | None = None + channel_config: Dict[str, Any] | None = None, ) -> bool: """ Send an email notification. @@ -50,10 +51,7 @@ async def send_notification( True if email was sent successfully """ if self.mock: - logger.info( - f"[MOCK EMAIL] To: {recipient}, Subject: {subject}, " - f"Body: {body[:50]}..., Metadata: {metadata}" - ) + logger.info(f"[MOCK EMAIL] To: {recipient}, Subject: {subject}, Body: {body[:50]}..., Metadata: {metadata}") return True # In production, implement actual SMTP sending: diff --git a/backend/app/services/notification_adapters/slack.py b/backend/app/services/notification_adapters/slack.py index 28e1609..ed8a867 100644 --- a/backend/app/services/notification_adapters/slack.py +++ b/backend/app/services/notification_adapters/slack.py @@ -3,8 +3,9 @@ Sends notifications to Slack via webhook or API. """ + import logging -from typing import Dict, Any +from typing import Any, Dict from app.services.notification_adapters.base import NotificationAdapter @@ -33,7 +34,7 @@ async def send_notification( subject: str, body: str, metadata: Dict[str, Any], - channel_config: Dict[str, Any] | None = None + channel_config: Dict[str, Any] | None = None, ) -> bool: """ Send a Slack notification. 
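The adapter hunks above only reflow code; the channel contract itself is unchanged: every adapter subclasses NotificationAdapter, carries a mock flag, and implements an async send_notification(...) -> bool. A minimal sketch of a hypothetical concrete adapter under that contract (the WebhookAdapter name, the httpx call, and the exact parameter order are illustrative assumptions, not part of this patch):

import logging
from typing import Any, Dict

import httpx  # assumed HTTP client, not a dependency introduced by this patch

from app.services.notification_adapters.base import NotificationAdapter

logger = logging.getLogger(__name__)


class WebhookAdapter(NotificationAdapter):
    """Hypothetical adapter posting notifications to a generic webhook."""

    def __init__(self, mock: bool = True):
        self.mock = mock

    async def send_notification(
        self,
        recipient: str,
        subject: str,
        body: str,
        metadata: Dict[str, Any],
        channel_config: Dict[str, Any] | None = None,
    ) -> bool:
        # Mirror the mock short-circuit used by the Email/Slack/Teams adapters.
        if self.mock:
            logger.info(f"[MOCK WEBHOOK] To: {recipient}, Subject: {subject}, Body: {body[:50]}...")
            return True
        url = (channel_config or {}).get("url")
        if not url:
            return False
        async with httpx.AsyncClient() as client:
            resp = await client.post(url, json={"subject": subject, "body": body, "metadata": metadata})
        return resp.status_code < 400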
@@ -49,10 +50,7 @@ async def send_notification( True if message was sent successfully """ if self.mock: - logger.info( - f"[MOCK SLACK] To: {recipient}, Subject: {subject}, " - f"Body: {body[:50]}..., Metadata: {metadata}" - ) + logger.info(f"[MOCK SLACK] To: {recipient}, Subject: {subject}, Body: {body[:50]}..., Metadata: {metadata}") return True # In production, implement actual Slack sending: diff --git a/backend/app/services/notification_adapters/teams.py b/backend/app/services/notification_adapters/teams.py index 4b7c1af..553f534 100644 --- a/backend/app/services/notification_adapters/teams.py +++ b/backend/app/services/notification_adapters/teams.py @@ -3,8 +3,9 @@ Sends notifications to Teams via webhook. """ + import logging -from typing import Dict, Any +from typing import Any, Dict from app.services.notification_adapters.base import NotificationAdapter @@ -33,7 +34,7 @@ async def send_notification( subject: str, body: str, metadata: Dict[str, Any], - channel_config: Dict[str, Any] | None = None + channel_config: Dict[str, Any] | None = None, ) -> bool: """ Send a Teams notification. @@ -49,10 +50,7 @@ async def send_notification( True if message was sent successfully """ if self.mock: - logger.info( - f"[MOCK TEAMS] To: {recipient}, Subject: {subject}, " - f"Body: {body[:50]}..., Metadata: {metadata}" - ) + logger.info(f"[MOCK TEAMS] To: {recipient}, Subject: {subject}, Body: {body[:50]}..., Metadata: {metadata}") return True # In production, implement actual Teams sending: diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 825b23f..f99adb3 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -3,18 +3,19 @@ Handles notification preferences, project mutes, thread watches, and event enqueueing. 
""" -from uuid import UUID + from typing import List, Optional -from sqlalchemy.orm import Session +from uuid import UUID + from sqlalchemy import and_ +from sqlalchemy.orm import Session from app.models import ( - NotificationPreference, + JobType, NotificationChannel, + NotificationPreference, NotificationProjectMute, NotificationThreadWatch, - User, - JobType, ) from app.services.job_service import JobService @@ -25,28 +26,20 @@ class NotificationService: # Preference Management @staticmethod - def get_user_preferences( - db: Session, - user_id: UUID - ) -> List[NotificationPreference]: + def get_user_preferences(db: Session, user_id: UUID) -> List[NotificationPreference]: """Get all notification preferences for a user.""" - return db.query(NotificationPreference).filter( - NotificationPreference.user_id == user_id - ).all() + return db.query(NotificationPreference).filter(NotificationPreference.user_id == user_id).all() @staticmethod def get_user_preference( - db: Session, - user_id: UUID, - channel: NotificationChannel + db: Session, user_id: UUID, channel: NotificationChannel ) -> Optional[NotificationPreference]: """Get a specific notification preference for a user.""" - return db.query(NotificationPreference).filter( - and_( - NotificationPreference.user_id == user_id, - NotificationPreference.channel == channel - ) - ).first() + return ( + db.query(NotificationPreference) + .filter(and_(NotificationPreference.user_id == user_id, NotificationPreference.channel == channel)) + .first() + ) @staticmethod def create_preference( @@ -54,15 +47,10 @@ def create_preference( user_id: UUID, channel: NotificationChannel, enabled: bool = True, - channel_config: Optional[str] = None + channel_config: Optional[str] = None, ) -> NotificationPreference: """Create a notification preference for a user.""" - pref = NotificationPreference( - user_id=user_id, - channel=channel, - enabled=enabled, - channel_config=channel_config - ) + pref = NotificationPreference(user_id=user_id, channel=channel, enabled=enabled, channel_config=channel_config) db.add(pref) db.commit() db.refresh(pref) @@ -70,15 +58,10 @@ def create_preference( @staticmethod def update_preference( - db: Session, - preference_id: UUID, - enabled: Optional[bool] = None, - channel_config: Optional[str] = None + db: Session, preference_id: UUID, enabled: Optional[bool] = None, channel_config: Optional[str] = None ) -> Optional[NotificationPreference]: """Update a notification preference.""" - pref = db.query(NotificationPreference).filter( - NotificationPreference.id == preference_id - ).first() + pref = db.query(NotificationPreference).filter(NotificationPreference.id == preference_id).first() if not pref: return None @@ -95,9 +78,7 @@ def update_preference( @staticmethod def delete_preference(db: Session, preference_id: UUID) -> bool: """Delete a notification preference.""" - pref = db.query(NotificationPreference).filter( - NotificationPreference.id == preference_id - ).first() + pref = db.query(NotificationPreference).filter(NotificationPreference.id == preference_id).first() if not pref: return False @@ -109,70 +90,47 @@ def delete_preference(db: Session, preference_id: UUID) -> bool: # Project Mute Management @staticmethod - def get_user_project_mutes( - db: Session, - user_id: UUID - ) -> List[NotificationProjectMute]: + def get_user_project_mutes(db: Session, user_id: UUID) -> List[NotificationProjectMute]: """Get all project mutes for a user.""" - return db.query(NotificationProjectMute).filter( - NotificationProjectMute.user_id == 
user_id - ).all() + return db.query(NotificationProjectMute).filter(NotificationProjectMute.user_id == user_id).all() @staticmethod - def is_project_muted( - db: Session, - user_id: UUID, - project_id: UUID - ) -> bool: + def is_project_muted(db: Session, user_id: UUID, project_id: UUID) -> bool: """Check if a user has muted a project.""" - mute = db.query(NotificationProjectMute).filter( - and_( - NotificationProjectMute.user_id == user_id, - NotificationProjectMute.project_id == project_id - ) - ).first() + mute = ( + db.query(NotificationProjectMute) + .filter(and_(NotificationProjectMute.user_id == user_id, NotificationProjectMute.project_id == project_id)) + .first() + ) return mute is not None @staticmethod - def mute_project( - db: Session, - user_id: UUID, - project_id: UUID - ) -> NotificationProjectMute: + def mute_project(db: Session, user_id: UUID, project_id: UUID) -> NotificationProjectMute: """Mute a project for a user.""" # Check if already muted - existing = db.query(NotificationProjectMute).filter( - and_( - NotificationProjectMute.user_id == user_id, - NotificationProjectMute.project_id == project_id - ) - ).first() + existing = ( + db.query(NotificationProjectMute) + .filter(and_(NotificationProjectMute.user_id == user_id, NotificationProjectMute.project_id == project_id)) + .first() + ) if existing: return existing - mute = NotificationProjectMute( - user_id=user_id, - project_id=project_id - ) + mute = NotificationProjectMute(user_id=user_id, project_id=project_id) db.add(mute) db.commit() db.refresh(mute) return mute @staticmethod - def unmute_project( - db: Session, - user_id: UUID, - project_id: UUID - ) -> bool: + def unmute_project(db: Session, user_id: UUID, project_id: UUID) -> bool: """Unmute a project for a user.""" - mute = db.query(NotificationProjectMute).filter( - and_( - NotificationProjectMute.user_id == user_id, - NotificationProjectMute.project_id == project_id - ) - ).first() + mute = ( + db.query(NotificationProjectMute) + .filter(and_(NotificationProjectMute.user_id == user_id, NotificationProjectMute.project_id == project_id)) + .first() + ) if not mute: return False @@ -184,81 +142,53 @@ def unmute_project( # Thread Watch Management @staticmethod - def get_user_thread_watches( - db: Session, - user_id: UUID - ) -> List[NotificationThreadWatch]: + def get_user_thread_watches(db: Session, user_id: UUID) -> List[NotificationThreadWatch]: """Get all thread watches for a user.""" - return db.query(NotificationThreadWatch).filter( - NotificationThreadWatch.user_id == user_id - ).all() + return db.query(NotificationThreadWatch).filter(NotificationThreadWatch.user_id == user_id).all() @staticmethod - def get_thread_watchers( - db: Session, - thread_id: str - ) -> List[UUID]: + def get_thread_watchers(db: Session, thread_id: str) -> List[UUID]: """Get all users watching a thread.""" - watches = db.query(NotificationThreadWatch).filter( - NotificationThreadWatch.thread_id == thread_id - ).all() + watches = db.query(NotificationThreadWatch).filter(NotificationThreadWatch.thread_id == thread_id).all() return [watch.user_id for watch in watches] @staticmethod - def is_watching_thread( - db: Session, - user_id: UUID, - thread_id: str - ) -> bool: + def is_watching_thread(db: Session, user_id: UUID, thread_id: str) -> bool: """Check if a user is watching a thread.""" - watch = db.query(NotificationThreadWatch).filter( - and_( - NotificationThreadWatch.user_id == user_id, - NotificationThreadWatch.thread_id == thread_id - ) - ).first() + watch = ( + 
db.query(NotificationThreadWatch) + .filter(and_(NotificationThreadWatch.user_id == user_id, NotificationThreadWatch.thread_id == thread_id)) + .first() + ) return watch is not None @staticmethod - def watch_thread( - db: Session, - user_id: UUID, - thread_id: str - ) -> NotificationThreadWatch: + def watch_thread(db: Session, user_id: UUID, thread_id: str) -> NotificationThreadWatch: """Watch a thread for notifications.""" # Check if already watching - existing = db.query(NotificationThreadWatch).filter( - and_( - NotificationThreadWatch.user_id == user_id, - NotificationThreadWatch.thread_id == thread_id - ) - ).first() + existing = ( + db.query(NotificationThreadWatch) + .filter(and_(NotificationThreadWatch.user_id == user_id, NotificationThreadWatch.thread_id == thread_id)) + .first() + ) if existing: return existing - watch = NotificationThreadWatch( - user_id=user_id, - thread_id=thread_id - ) + watch = NotificationThreadWatch(user_id=user_id, thread_id=thread_id) db.add(watch) db.commit() db.refresh(watch) return watch @staticmethod - def unwatch_thread( - db: Session, - user_id: UUID, - thread_id: str - ) -> bool: + def unwatch_thread(db: Session, user_id: UUID, thread_id: str) -> bool: """Unwatch a thread.""" - watch = db.query(NotificationThreadWatch).filter( - and_( - NotificationThreadWatch.user_id == user_id, - NotificationThreadWatch.thread_id == thread_id - ) - ).first() + watch = ( + db.query(NotificationThreadWatch) + .filter(and_(NotificationThreadWatch.user_id == user_id, NotificationThreadWatch.thread_id == thread_id)) + .first() + ) if not watch: return False @@ -278,7 +208,7 @@ def enqueue_notification_event( body: str, recipients: List[UUID], related_entity_id: Optional[str] = None, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ) -> str: """ Enqueue a notification event for processing by the fanout worker. @@ -303,14 +233,9 @@ def enqueue_notification_event( "body": body, "recipients": [str(uid) for uid in recipients], "related_entity_id": related_entity_id, - "metadata": metadata or {} + "metadata": metadata or {}, } - job = JobService.create_job( - db=db, - job_type=JobType.NOTIFICATION_FANOUT, - payload=payload, - project_id=project_id - ) + job = JobService.create_job(db=db, job_type=JobType.NOTIFICATION_FANOUT, payload=payload, project_id=project_id) return str(job.id) diff --git a/backend/app/services/org_service.py b/backend/app/services/org_service.py index de011ed..0b1ff40 100644 --- a/backend/app/services/org_service.py +++ b/backend/app/services/org_service.py @@ -3,14 +3,15 @@ Follows the service layer pattern established in user_service.py. """ + from typing import Optional from uuid import UUID from sqlalchemy import and_ from sqlalchemy.orm import Session -from app.models.organization import Organization from app.models.org_membership import OrgMembership, OrgRole +from app.models.organization import Organization from app.models.provisioning import ProvisioningSource from app.services.team_role_service import TeamRoleService @@ -19,9 +20,7 @@ class OrgService: """Service class for organization-related operations.""" @staticmethod - def create_org_with_owner( - db: Session, name: str, owner_user_id: UUID - ) -> tuple[Organization, OrgMembership]: + def create_org_with_owner(db: Session, name: str, owner_user_id: UUID) -> tuple[Organization, OrgMembership]: """ Create a new organization and assign the specified user as owner. 
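For orientation, a usage sketch of the thread-watch API whose queries were condensed above; db is assumed to be an open SQLAlchemy Session and the IDs are made up:

from uuid import uuid4

from app.services.notification_service import NotificationService

user_id = uuid4()
thread_id = "thread-123"  # thread IDs are plain strings in this API

# watch_thread is idempotent: a second call returns the existing row.
watch = NotificationService.watch_thread(db, user_id, thread_id)
assert NotificationService.watch_thread(db, user_id, thread_id).id == watch.id

assert NotificationService.is_watching_thread(db, user_id, thread_id)
assert user_id in NotificationService.get_thread_watchers(db, thread_id)

assert NotificationService.unwatch_thread(db, user_id, thread_id) is True
assert NotificationService.unwatch_thread(db, user_id, thread_id) is False  # nothing left to remove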
@@ -48,9 +47,7 @@ def create_org_with_owner( db.flush() # Get the org.id before creating membership # Create owner membership - membership = OrgMembership( - org_id=org.id, user_id=owner_user_id, role=OrgRole.OWNER - ) + membership = OrgMembership(org_id=org.id, user_id=owner_user_id, role=OrgRole.OWNER) db.add(membership) # Seed default team roles for the new org @@ -99,17 +96,11 @@ def find_or_create_org_by_external_id( from app.plugin_registry import get_plugin_registry # Try to find existing org by external ID - org = ( - db.query(Organization) - .filter(Organization.organization_id == external_org_id) - .first() - ) + org = db.query(Organization).filter(Organization.organization_id == external_org_id).first() if org: # Organization exists - add user as member if not already a member - existing_membership = OrgService.get_org_membership( - db, org.id, initial_user_id - ) + existing_membership = OrgService.get_org_membership(db, org.id, initial_user_id) if not existing_membership: OrgService.add_org_member( db, @@ -161,17 +152,13 @@ def get_user_orgs(db: Session, user_id: UUID) -> list[tuple[Organization, OrgRol List of (Organization, OrgRole) tuples """ # Query memberships - relationships will auto-load org due to lazy="joined" - memberships = ( - db.query(OrgMembership).filter(OrgMembership.user_id == user_id).all() - ) + memberships = db.query(OrgMembership).filter(OrgMembership.user_id == user_id).all() # Extract org and role from each membership return [(membership.org, membership.role) for membership in memberships] @staticmethod - def get_org_membership( - db: Session, org_id: UUID, user_id: UUID - ) -> Optional[OrgMembership]: + def get_org_membership(db: Session, org_id: UUID, user_id: UUID) -> Optional[OrgMembership]: """ Get a user's membership in a specific organization. @@ -185,11 +172,7 @@ def get_org_membership( """ return ( db.query(OrgMembership) - .filter( - and_( - OrgMembership.org_id == org_id, OrgMembership.user_id == user_id - ) - ) + .filter(and_(OrgMembership.org_id == org_id, OrgMembership.user_id == user_id)) .first() ) diff --git a/backend/app/services/phase_container_service.py b/backend/app/services/phase_container_service.py index 0f6e2bd..3e96730 100644 --- a/backend/app/services/phase_container_service.py +++ b/backend/app/services/phase_container_service.py @@ -1,17 +1,18 @@ """Service for managing phase containers.""" + import logging from datetime import datetime, timezone -from typing import Optional, List +from typing import List, Optional from uuid import UUID + from sqlalchemy.orm import Session -from app.models.phase_container import PhaseContainer from app.models.brainstorming_phase import ( BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype, ) - +from app.models.phase_container import PhaseContainer logger = logging.getLogger(__name__) @@ -68,9 +69,7 @@ def get_container( return query.first() @staticmethod - def get_by_identifier( - db: Session, identifier: str, include_archived: bool = False - ) -> Optional[PhaseContainer]: + def get_by_identifier(db: Session, identifier: str, include_archived: bool = False) -> Optional[PhaseContainer]: """Get a phase container by UUID, short_id, or URL identifier. 
This method supports backward compatibility with existing UUID-based URLs @@ -93,9 +92,7 @@ def get_by_identifier( if is_uuid(identifier): try: uuid_val = UUID(identifier) - return PhaseContainerService.get_container( - db, uuid_val, include_archived - ) + return PhaseContainerService.get_container(db, uuid_val, include_archived) except ValueError: pass @@ -204,9 +201,7 @@ def restore_container( Returns: PhaseContainer: Restored container, or None if not found """ - container = PhaseContainerService.get_container( - db, container_id, include_archived=True - ) + container = PhaseContainerService.get_container(db, container_id, include_archived=True) if not container: return None @@ -232,9 +227,7 @@ def get_container_phases( Returns: List[BrainstormingPhase]: Phases ordered by container_sequence """ - query = db.query(BrainstormingPhase).filter( - BrainstormingPhase.container_id == container_id - ) + query = db.query(BrainstormingPhase).filter(BrainstormingPhase.container_id == container_id) if not include_archived: query = query.filter(BrainstormingPhase.archived_at.is_(None)) return query.order_by(BrainstormingPhase.container_sequence).all() @@ -277,10 +270,14 @@ def assign_phase_to_container( Returns: BrainstormingPhase: Updated phase, or None if phase not found """ - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id, - BrainstormingPhase.archived_at.is_(None), - ).first() + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == phase_id, + BrainstormingPhase.archived_at.is_(None), + ) + .first() + ) if not phase: return None @@ -307,9 +304,13 @@ def remove_phase_from_container( Returns: BrainstormingPhase: Updated phase, or None if phase not found """ - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == phase_id, - ).first() + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == phase_id, + ) + .first() + ) if not phase: return None @@ -380,13 +381,15 @@ def get_extension_preview(db: Session, container_id: UUID) -> dict: for phase in phases: if phase.phase_subtype == PhaseSubtype.EXTENSION: extension_number = (phase.container_sequence or 1) - 1 - sibling_extensions.append({ - "phase_id": str(phase.id), - "title": phase.title, - "extension_number": extension_number, - "created_at": phase.created_at, - "description": phase.description, - }) + sibling_extensions.append( + { + "phase_id": str(phase.id), + "title": phase.title, + "extension_number": extension_number, + "created_at": phase.created_at, + "description": phase.description, + } + ) # Get next sequence next_container_sequence = PhaseContainerService.get_next_sequence(db, container_id) diff --git a/backend/app/services/phase_progress_service.py b/backend/app/services/phase_progress_service.py index ce9d63d..65881b8 100644 --- a/backend/app/services/phase_progress_service.py +++ b/backend/app/services/phase_progress_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session -from app.models.feature import Feature, FeatureType, FeatureStatus +from app.models.feature import Feature, FeatureStatus, FeatureType from app.models.module import Module, ModuleType @@ -108,9 +108,7 @@ def get_phase_progress(db: Session, phase_id: UUID) -> PhaseProgress: .all() ) - total, completed, pending, in_prog, pct, mod_next = ( - PhaseProgressService.compute_feature_stats(features) - ) + total, completed, pending, in_prog, pct, mod_next = PhaseProgressService.compute_feature_stats(features) mod_progress = ModuleProgress( module_id=module.id, @@ -135,9 
+133,7 @@ def get_phase_progress(db: Session, phase_id: UUID) -> PhaseProgress: phase_next_feature = mod_next phase.progress_percent = ( - round(phase.completed_features / phase.total_features * 100, 1) - if phase.total_features > 0 - else 0 + round(phase.completed_features / phase.total_features * 100, 1) if phase.total_features > 0 else 0 ) phase.next_feature = phase_next_feature @@ -165,9 +161,7 @@ def get_module_progress(db: Session, module_id: UUID) -> ModuleProgress: .all() ) - total, completed, pending, in_prog, pct, next_feat = ( - PhaseProgressService.compute_feature_stats(features) - ) + total, completed, pending, in_prog, pct, next_feat = PhaseProgressService.compute_feature_stats(features) return ModuleProgress( module_id=module_id, diff --git a/backend/app/services/plan_recommendation_service.py b/backend/app/services/plan_recommendation_service.py index b528fc3..875c6a4 100644 --- a/backend/app/services/plan_recommendation_service.py +++ b/backend/app/services/plan_recommendation_service.py @@ -164,9 +164,7 @@ def analyze_efficiency_streak( today = datetime.now(timezone.utc).date() start_date = today - timedelta(days=lookback_days) - daily_usage = PlanRecommendationService.get_org_daily_usage( - db, org.id, start_date, today - ) + daily_usage = PlanRecommendationService.get_org_daily_usage(db, org.id, start_date, today) if not daily_usage: return EfficiencyStreakInfo(has_streak=False) @@ -403,11 +401,7 @@ def dismiss_recommendation( Returns: Updated recommendation or None if not found """ - recommendation = ( - db.query(PlanRecommendation) - .filter(PlanRecommendation.id == recommendation_id) - .first() - ) + recommendation = db.query(PlanRecommendation).filter(PlanRecommendation.id == recommendation_id).first() if not recommendation: return None @@ -432,25 +426,17 @@ def expire_stale_recommendations(db: Session) -> int: Number of recommendations expired """ active_recs = ( - db.query(PlanRecommendation) - .filter(PlanRecommendation.status == RecommendationStatus.ACTIVE) - .all() + db.query(PlanRecommendation).filter(PlanRecommendation.status == RecommendationStatus.ACTIVE).all() ) expired_count = 0 for rec in active_recs: - org = ( - db.query(Organization) - .filter(Organization.id == rec.org_id) - .first() - ) + org = db.query(Organization).filter(Organization.id == rec.org_id).first() if not org: continue # Check recent efficiency - streak_info = PlanRecommendationService.analyze_efficiency_streak( - db, org, lookback_days=7 - ) + streak_info = PlanRecommendationService.analyze_efficiency_streak(db, org, lookback_days=7) # If no streak found in recent days, pattern has changed if not streak_info.has_streak: diff --git a/backend/app/services/plan_service.py b/backend/app/services/plan_service.py index d24ef8c..568698e 100644 --- a/backend/app/services/plan_service.py +++ b/backend/app/services/plan_service.py @@ -3,8 +3,9 @@ Handles token-based limits for trial and subscription plans. 
""" + from datetime import datetime, timedelta, timezone -from typing import Optional, Any +from typing import Any, Optional from uuid import UUID from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from app.models.organization import Organization from app.models.user import User - # Trial token allocation (legacy, kept for backwards compatibility) TRIAL_TOKEN_ALLOCATION = 5_000_000 # 5M tokens @@ -71,6 +71,7 @@ def setup_trial_plan(db: Session, org: Organization, user: User) -> Organization Updated Organization instance """ from datetime import timedelta + from app.auth.trial import TRIAL_DURATION_DAYS org.plan_name = "trial" @@ -337,9 +338,7 @@ def search_users_by_email(db: Session, email_query: str, limit: int = 20) -> lis ) @staticmethod - async def search_users_by_email_async( - db: "AsyncSession", email_query: str, limit: int = 20 - ) -> list[User]: + async def search_users_by_email_async(db: "AsyncSession", email_query: str, limit: int = 20) -> list[User]: """ Search users by email substring (case-insensitive). Async version. Platform admin only - searches across entire platform. @@ -400,9 +399,7 @@ def get_user_primary_org(db: Session, user_id: UUID) -> Optional[Organization]: return None @staticmethod - async def get_user_primary_org_async( - db: "AsyncSession", user_id: UUID - ) -> Optional[Organization]: + async def get_user_primary_org_async(db: "AsyncSession", user_id: UUID) -> Optional[Organization]: """ Get user's primary org (current_org_id or first org if not set). Async version. @@ -448,9 +445,7 @@ async def get_user_primary_org_async( return None @staticmethod - def get_org_plan_details( - db: Session, org_id: UUID, year: int, month: int - ) -> dict[str, Any]: + def get_org_plan_details(db: Session, org_id: UUID, year: int, month: int) -> dict[str, Any]: """ Get comprehensive plan details including current usage counts. 
@@ -534,13 +529,10 @@ async def get_org_plan_details_async( project_count = project_count_result.scalar() or 0 # Get monthly cost - cost_stmt = ( - select(func.sum(LLMUsageLog.cost_usd)) - .filter( - LLMUsageLog.org_id == org_id, - func.extract("year", LLMUsageLog.created_at) == year, - func.extract("month", LLMUsageLog.created_at) == month, - ) + cost_stmt = select(func.sum(LLMUsageLog.cost_usd)).filter( + LLMUsageLog.org_id == org_id, + func.extract("year", LLMUsageLog.created_at) == year, + func.extract("month", LLMUsageLog.created_at) == month, ) cost_result = await db.execute(cost_stmt) cost_this_month = cost_result.scalar() diff --git a/backend/app/services/platform_settings_service.py b/backend/app/services/platform_settings_service.py index d881b06..9a5bfd6 100644 --- a/backend/app/services/platform_settings_service.py +++ b/backend/app/services/platform_settings_service.py @@ -1,4 +1,5 @@ """Platform settings service.""" + import base64 import os from typing import Any @@ -40,9 +41,7 @@ def __init__(self, db: AsyncSession, encryption_key: str | None = None): encryption_key: Encryption key (defaults to env var ENCRYPTION_KEY) """ self.db = db - self._encryption_key = encryption_key or os.getenv( - "ENCRYPTION_KEY", "default-insecure-key-change-me" - ) + self._encryption_key = encryption_key or os.getenv("ENCRYPTION_KEY", "default-insecure-key-change-me") self._fernet = self._get_fernet() def _get_fernet(self) -> Fernet: @@ -241,9 +240,7 @@ async def list_connectors( Returns: List of platform connectors """ - stmt = select(PlatformConnector).order_by( - PlatformConnector.connector_type, PlatformConnector.provider - ) + stmt = select(PlatformConnector).order_by(PlatformConnector.connector_type, PlatformConnector.provider) if connector_type: type_value = connector_type.value if isinstance(connector_type, PlatformConnectorType) else connector_type @@ -430,7 +427,9 @@ async def get_llm_config_for_agents(self) -> dict[str, Any]: if lightweight_connector and lightweight_connector.is_active: result["lightweight"] = { "provider": lightweight_connector.provider, - "model": lightweight_connector.config_json.get("model") if lightweight_connector.config_json else None, + "model": lightweight_connector.config_json.get("model") + if lightweight_connector.config_json + else None, "api_key": self._decrypt_credentials(lightweight_connector.encrypted_credentials), } @@ -560,19 +559,12 @@ async def get_github_oauth_credentials(self) -> tuple[str | None, str | None]: platform_settings.github_oauth_client_id_encrypted and platform_settings.github_oauth_client_secret_encrypted ): - client_id = self._decrypt_credentials( - platform_settings.github_oauth_client_id_encrypted - ) - client_secret = self._decrypt_credentials( - platform_settings.github_oauth_client_secret_encrypted - ) + client_id = self._decrypt_credentials(platform_settings.github_oauth_client_id_encrypted) + client_secret = self._decrypt_credentials(platform_settings.github_oauth_client_secret_encrypted) return (client_id, client_secret) # Fall back to environment variables - if ( - app_settings.github_integration_oauth_client_id - and app_settings.github_integration_oauth_client_secret - ): + if app_settings.github_integration_oauth_client_id and app_settings.github_integration_oauth_client_secret: return ( app_settings.github_integration_oauth_client_id, app_settings.github_integration_oauth_client_secret, @@ -605,7 +597,6 @@ def get_llm_config_sync(db) -> dict[str, Any]: Raises: ValueError: If no LLM configuration found """ - from 
sqlalchemy.orm import Session # Get or create platform settings settings = db.query(PlatformSettings).first() @@ -642,9 +633,7 @@ def decrypt_credentials(encrypted: str) -> str: # Get main LLM connector if settings.main_llm_connector_id: main_connector = ( - db.query(PlatformConnector) - .filter(PlatformConnector.id == settings.main_llm_connector_id) - .first() + db.query(PlatformConnector).filter(PlatformConnector.id == settings.main_llm_connector_id).first() ) if main_connector and main_connector.is_active: result["main"] = { @@ -656,9 +645,7 @@ def decrypt_credentials(encrypted: str) -> str: # Get lightweight LLM connector if settings.lightweight_llm_connector_id: lightweight_connector = ( - db.query(PlatformConnector) - .filter(PlatformConnector.id == settings.lightweight_llm_connector_id) - .first() + db.query(PlatformConnector).filter(PlatformConnector.id == settings.lightweight_llm_connector_id).first() ) if lightweight_connector and lightweight_connector.is_active: result["lightweight"] = { @@ -685,8 +672,7 @@ def require_llm_config_sync(db) -> dict[str, Any]: config = get_llm_config_sync(db) if not config["main"]: raise ValueError( - "No LLM configured at platform level. " - "Please ask a platform admin to configure an LLM in Platform Settings." + "No LLM configured at platform level. Please ask a platform admin to configure an LLM in Platform Settings." ) return config @@ -728,11 +714,7 @@ def get_web_search_config_sync(db) -> dict[str, Any]: # Priority 1: UI-configured connector if settings.web_search_connector_id: - connector = ( - db.query(PlatformConnector) - .filter(PlatformConnector.id == settings.web_search_connector_id) - .first() - ) + connector = db.query(PlatformConnector).filter(PlatformConnector.id == settings.web_search_connector_id).first() if connector and connector.is_active: # Decrypt credentials encryption_key = os.getenv("ENCRYPTION_KEY", "default-insecure-key-change-me") @@ -797,10 +779,7 @@ def get_github_oauth_credentials_sync(db) -> tuple[str | None, str | None]: db.refresh(settings) # Check UI configuration first (takes precedence) - if ( - settings.github_oauth_client_id_encrypted - and settings.github_oauth_client_secret_encrypted - ): + if settings.github_oauth_client_id_encrypted and settings.github_oauth_client_secret_encrypted: # Decrypt credentials encryption_key = os.getenv("ENCRYPTION_KEY", "default-insecure-key-change-me") kdf = PBKDF2HMAC( @@ -812,12 +791,8 @@ def get_github_oauth_credentials_sync(db) -> tuple[str | None, str | None]: key = base64.urlsafe_b64encode(kdf.derive(encryption_key.encode())) fernet = Fernet(key) - client_id_encrypted = base64.b64decode( - settings.github_oauth_client_id_encrypted.encode() - ) - client_secret_encrypted = base64.b64decode( - settings.github_oauth_client_secret_encrypted.encode() - ) + client_id_encrypted = base64.b64decode(settings.github_oauth_client_id_encrypted.encode()) + client_secret_encrypted = base64.b64decode(settings.github_oauth_client_secret_encrypted.encode()) client_id = fernet.decrypt(client_id_encrypted).decode() client_secret = fernet.decrypt(client_secret_encrypted).decode() @@ -825,10 +800,7 @@ def get_github_oauth_credentials_sync(db) -> tuple[str | None, str | None]: return (client_id, client_secret) # Fall back to environment variables - if ( - app_settings.github_integration_oauth_client_id - and app_settings.github_integration_oauth_client_secret - ): + if app_settings.github_integration_oauth_client_id and app_settings.github_integration_oauth_client_secret: return ( 
app_settings.github_integration_oauth_client_id, app_settings.github_integration_oauth_client_secret, diff --git a/backend/app/services/prefix_service.py b/backend/app/services/prefix_service.py index eb57deb..1c8bfa7 100644 --- a/backend/app/services/prefix_service.py +++ b/backend/app/services/prefix_service.py @@ -1,4 +1,5 @@ """Service for generating and validating project ticket prefixes.""" + import logging import re from typing import List, Optional, Tuple @@ -125,9 +126,7 @@ async def generate_prefix( # Try LLM generation try: - prefix = await PrefixService._generate_with_llm( - project_name, existing_prefixes, llm_settings, org_id - ) + prefix = await PrefixService._generate_with_llm(project_name, existing_prefixes, llm_settings, org_id) if prefix and PrefixService.is_prefix_available(db, org_id, prefix): return prefix else: @@ -193,10 +192,12 @@ async def _generate_with_llm( ) # Call LLM - result = await model_client.create([ - SystemMessage(content=system_message), - UserMessage(content=user_message, source="user"), - ]) + result = await model_client.create( + [ + SystemMessage(content=system_message), + UserMessage(content=user_message, source="user"), + ] + ) # Clean and validate the response prefix = result.content.strip().upper() diff --git a/backend/app/services/project_chat_service.py b/backend/app/services/project_chat_service.py index 34cf4d5..e59090e 100644 --- a/backend/app/services/project_chat_service.py +++ b/backend/app/services/project_chat_service.py @@ -1,35 +1,33 @@ """Service for managing project chats.""" + import logging -from datetime import datetime, timezone -from typing import Optional, List, Dict, Any, Tuple +from typing import Any, Dict, List, Optional, Tuple from uuid import UUID from sqlalchemy import func, or_ from sqlalchemy.orm import Session, joinedload +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype +from app.models.feature import Feature, FeaturePriority, FeatureProvenance, FeatureType, FeatureVisibilityStatus +from app.models.feature_content_version import FeatureContentType +from app.models.grounding_file import GroundingFile +from app.models.module import Module, ModuleProvenance, ModuleType +from app.models.project import Project, ProjectType from app.models.project_chat import ( ProjectChat, ProjectChatMessage, ProjectChatMessageType, ProjectChatVisibility, ) -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype -from app.models.project import Project, ProjectType -from app.models.organization import Organization -from app.models.grounding_file import GroundingFile -from app.models.module import Module, ModuleProvenance, ModuleType -from app.models.feature import Feature, FeatureProvenance, FeatureType, FeaturePriority, FeatureVisibilityStatus -from app.models.thread import Thread, ContextType +from app.models.thread import ContextType, Thread from app.models.user import User -from app.models.feature_content_version import FeatureContentType from app.services.brainstorming_phase_service import BrainstormingPhaseService -from app.services.project_service import ProjectService -from app.services.module_service import ModuleService -from app.services.feature_service import FeatureService from app.services.feature_content_version_service import FeatureContentVersionService +from app.services.feature_service import FeatureService +from app.services.module_service import ModuleService +from app.services.project_service import ProjectService from 
app.utils.short_id import build_url_identifier - logger = logging.getLogger(__name__) @@ -69,17 +67,11 @@ def create_project_chat( db.commit() db.refresh(project_chat) - logger.info( - f"Created project chat {project_chat.id} for project {project_id}" - ) + logger.info(f"Created project chat {project_chat.id} for project {project_id}") return project_chat @staticmethod - def create_org_project_chat( - db: Session, - org_id: UUID, - created_by: UUID - ) -> ProjectChat: + def create_org_project_chat(db: Session, org_id: UUID, created_by: UUID) -> ProjectChat: """ Create a new org-scoped project chat (for creating a new project). @@ -103,16 +95,11 @@ def create_org_project_chat( db.commit() db.refresh(project_chat) - logger.info( - f"Created org-scoped project chat {project_chat.id} for org {org_id}" - ) + logger.info(f"Created org-scoped project chat {project_chat.id} for org {org_id}") return project_chat @staticmethod - def get_project_chat( - db: Session, - project_chat_id: UUID - ) -> Optional[ProjectChat]: + def get_project_chat(db: Session, project_chat_id: UUID) -> Optional[ProjectChat]: """ Get a project chat by ID with messages loaded. @@ -123,16 +110,15 @@ def get_project_chat( Returns: ProjectChat with messages, or None if not found """ - return db.query(ProjectChat).options( - joinedload(ProjectChat.messages) - ).filter( - ProjectChat.id == project_chat_id - ).first() + return ( + db.query(ProjectChat) + .options(joinedload(ProjectChat.messages)) + .filter(ProjectChat.id == project_chat_id) + .first() + ) @staticmethod - def get_by_identifier( - db: Session, identifier: str - ) -> Optional[ProjectChat]: + def get_by_identifier(db: Session, identifier: str) -> Optional[ProjectChat]: """Get a project chat by UUID, short_id, or URL identifier. This method supports backward compatibility with existing UUID-based URLs @@ -160,17 +146,15 @@ def get_by_identifier( # Extract short_id and query short_id = extract_short_id(identifier) - return db.query(ProjectChat).options( - joinedload(ProjectChat.messages) - ).filter( - ProjectChat.short_id == short_id - ).first() + return ( + db.query(ProjectChat) + .options(joinedload(ProjectChat.messages)) + .filter(ProjectChat.short_id == short_id) + .first() + ) @staticmethod - def get_project_chat_for_project( - db: Session, - project_id: UUID - ) -> Optional[ProjectChat]: + def get_project_chat_for_project(db: Session, project_id: UUID) -> Optional[ProjectChat]: """ Get the most recent project chat for a project. @@ -181,19 +165,16 @@ def get_project_chat_for_project( Returns: ProjectChat with messages, or None if not found """ - return db.query(ProjectChat).options( - joinedload(ProjectChat.messages) - ).filter( - ProjectChat.project_id == project_id - ).order_by( - ProjectChat.created_at.desc() - ).first() + return ( + db.query(ProjectChat) + .options(joinedload(ProjectChat.messages)) + .filter(ProjectChat.project_id == project_id) + .order_by(ProjectChat.created_at.desc()) + .first() + ) @staticmethod - def list_project_chats( - db: Session, - project_id: UUID - ) -> List[ProjectChat]: + def list_project_chats(db: Session, project_id: UUID) -> List[ProjectChat]: """ List all project chats for a project. 
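The identifier resolution reformatted above accepts three interchangeable forms, per its docstring: a raw UUID (backward compatibility with older URLs), a bare short_id, or a slugified URL identifier ending in the short_id. A usage sketch with made-up identifier values:

from app.services.project_chat_service import ProjectChatService

# All three forms resolve the same chat (values are illustrative):
chat = ProjectChatService.get_by_identifier(db, "8f14e45f-ceea-4a3b-9d5c-1a2b3c4d5e6f")  # legacy UUID URL
chat = ProjectChatService.get_by_identifier(db, "aB3xYz")                                # bare short_id
chat = ProjectChatService.get_by_identifier(db, "quarterly-roadmap-chat-aB3xYz")         # slug + short_id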
@@ -204,19 +185,16 @@ def list_project_chats( Returns: List of ProjectChat objects """ - return db.query(ProjectChat).filter( - ProjectChat.project_id == project_id - ).order_by( - ProjectChat.created_at.desc() - ).all() + return ( + db.query(ProjectChat) + .filter(ProjectChat.project_id == project_id) + .order_by(ProjectChat.created_at.desc()) + .all() + ) @staticmethod def list_project_chats_paginated( - db: Session, - project_id: UUID, - user_id: UUID, - limit: int = 20, - offset: int = 0 + db: Session, project_id: UUID, user_id: UUID, limit: int = 20, offset: int = 0 ) -> Tuple[List[Dict[str, Any]], int, bool]: """ List project chats for a project with pagination for sidebar. @@ -241,42 +219,49 @@ def list_project_chats_paginated( ) # Get total count with visibility filter - total = db.query(func.count(ProjectChat.id)).filter( - ProjectChat.project_id == project_id, - visibility_filter, - ).scalar() or 0 + total = ( + db.query(func.count(ProjectChat.id)) + .filter( + ProjectChat.project_id == project_id, + visibility_filter, + ) + .scalar() + or 0 + ) # Subquery for message counts - message_count_subq = db.query( - ProjectChatMessage.project_chat_id, - func.count(ProjectChatMessage.id).label("message_count") - ).group_by( - ProjectChatMessage.project_chat_id - ).subquery() + message_count_subq = ( + db.query(ProjectChatMessage.project_chat_id, func.count(ProjectChatMessage.id).label("message_count")) + .group_by(ProjectChatMessage.project_chat_id) + .subquery() + ) # Get chats with message count via subquery join # Select specific columns to avoid eager loading relationships - chats = db.query( - ProjectChat.id, - ProjectChat.chat_title, - ProjectChat.proposed_title, - ProjectChat.created_at, - ProjectChat.updated_at, - ProjectChat.created_phase_id, - ProjectChat.created_feature_ids, - ProjectChat.short_id, - ProjectChat.visibility, - ProjectChat.created_by, - func.coalesce(message_count_subq.c.message_count, 0).label("message_count") - ).outerjoin( - message_count_subq, - message_count_subq.c.project_chat_id == ProjectChat.id - ).filter( - ProjectChat.project_id == project_id, - visibility_filter, - ).order_by( - ProjectChat.created_at.desc() - ).offset(offset).limit(limit).all() + chats = ( + db.query( + ProjectChat.id, + ProjectChat.chat_title, + ProjectChat.proposed_title, + ProjectChat.created_at, + ProjectChat.updated_at, + ProjectChat.created_phase_id, + ProjectChat.created_feature_ids, + ProjectChat.short_id, + ProjectChat.visibility, + ProjectChat.created_by, + func.coalesce(message_count_subq.c.message_count, 0).label("message_count"), + ) + .outerjoin(message_count_subq, message_count_subq.c.project_chat_id == ProjectChat.id) + .filter( + ProjectChat.project_id == project_id, + visibility_filter, + ) + .order_by(ProjectChat.created_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) has_more = offset + limit < total @@ -285,29 +270,29 @@ def list_project_chats_paginated( # Build url_identifier from chat_title/proposed_title and short_id title = row.chat_title or row.proposed_title or "chat" from app.utils.short_id import slugify + url_identifier = f"{slugify(title)}-{row.short_id}" - result.append({ - "id": row.id, - "chat_title": row.chat_title, - "proposed_title": row.proposed_title, - "created_at": row.created_at, - "updated_at": row.updated_at, - "created_by": row.created_by, - "created_phase_id": row.created_phase_id, - "created_feature_count": len(row.created_feature_ids) if row.created_feature_ids else 0, - "message_count": row.message_count, - "short_id": 
row.short_id, - "url_identifier": url_identifier, - "visibility": row.visibility.value if row.visibility else "private", - }) + result.append( + { + "id": row.id, + "chat_title": row.chat_title, + "proposed_title": row.proposed_title, + "created_at": row.created_at, + "updated_at": row.updated_at, + "created_by": row.created_by, + "created_phase_id": row.created_phase_id, + "created_feature_count": len(row.created_feature_ids) if row.created_feature_ids else 0, + "message_count": row.message_count, + "short_id": row.short_id, + "url_identifier": url_identifier, + "visibility": row.visibility.value if row.visibility else "private", + } + ) return result, total, has_more @staticmethod - def get_org_project_chat( - db: Session, - org_id: UUID - ) -> Optional[ProjectChat]: + def get_org_project_chat(db: Session, org_id: UUID) -> Optional[ProjectChat]: """ Get the most recent org-scoped project chat (no project yet). @@ -318,22 +303,20 @@ def get_org_project_chat( Returns: ProjectChat with messages, or None if not found """ - return db.query(ProjectChat).options( - joinedload(ProjectChat.messages) - ).filter( - ProjectChat.org_id == org_id, - ProjectChat.project_id.is_(None), # Org-scoped only - ).order_by( - ProjectChat.created_at.desc() - ).first() + return ( + db.query(ProjectChat) + .options(joinedload(ProjectChat.messages)) + .filter( + ProjectChat.org_id == org_id, + ProjectChat.project_id.is_(None), # Org-scoped only + ) + .order_by(ProjectChat.created_at.desc()) + .first() + ) @staticmethod def list_org_project_chats_paginated( - db: Session, - org_id: UUID, - user_id: UUID, - limit: int = 20, - offset: int = 0 + db: Session, org_id: UUID, user_id: UUID, limit: int = 20, offset: int = 0 ) -> Tuple[List[Dict[str, Any]], int, bool]: """ List org-scoped project chats for a specific user with pagination. 
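Both *_paginated helpers in this file reduce has_more to a single comparison against the filtered total; the contract in isolation:

def page_meta(total: int, offset: int, limit: int) -> dict:
    # Mirrors the diff: has_more = offset + limit < total.
    return {"total": total, "has_more": offset + limit < total}

assert page_meta(total=45, offset=20, limit=20)["has_more"] is True
assert page_meta(total=45, offset=40, limit=20)["has_more"] is False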
@@ -349,44 +332,49 @@ def list_org_project_chats_paginated( Tuple of (chats with message counts, total count, has_more) """ # Subquery for message counts - message_count_subq = db.query( - ProjectChatMessage.project_chat_id, - func.count(ProjectChatMessage.id).label("message_count") - ).group_by( - ProjectChatMessage.project_chat_id - ).subquery() + message_count_subq = ( + db.query(ProjectChatMessage.project_chat_id, func.count(ProjectChatMessage.id).label("message_count")) + .group_by(ProjectChatMessage.project_chat_id) + .subquery() + ) # Get total count of org-scoped chats WITH at least one message - total = db.query(func.count(ProjectChat.id)).join( - message_count_subq, - message_count_subq.c.project_chat_id == ProjectChat.id - ).filter( - ProjectChat.org_id == org_id, - ProjectChat.created_by == user_id, # User-specific - ProjectChat.project_id.is_(None), # Org-scoped only - ).scalar() or 0 + total = ( + db.query(func.count(ProjectChat.id)) + .join(message_count_subq, message_count_subq.c.project_chat_id == ProjectChat.id) + .filter( + ProjectChat.org_id == org_id, + ProjectChat.created_by == user_id, # User-specific + ProjectChat.project_id.is_(None), # Org-scoped only + ) + .scalar() + or 0 + ) # Get chats with message count via inner join (excludes empty chats) - chats = db.query( - ProjectChat.id, - ProjectChat.chat_title, - ProjectChat.proposed_title, - ProjectChat.proposed_project_name, - ProjectChat.created_at, - ProjectChat.updated_at, - ProjectChat.created_project_id, - ProjectChat.short_id, - message_count_subq.c.message_count.label("message_count") - ).join( - message_count_subq, - message_count_subq.c.project_chat_id == ProjectChat.id - ).filter( - ProjectChat.org_id == org_id, - ProjectChat.created_by == user_id, # User-specific - ProjectChat.project_id.is_(None), # Org-scoped only - ).order_by( - ProjectChat.created_at.desc() - ).offset(offset).limit(limit).all() + chats = ( + db.query( + ProjectChat.id, + ProjectChat.chat_title, + ProjectChat.proposed_title, + ProjectChat.proposed_project_name, + ProjectChat.created_at, + ProjectChat.updated_at, + ProjectChat.created_project_id, + ProjectChat.short_id, + message_count_subq.c.message_count.label("message_count"), + ) + .join(message_count_subq, message_count_subq.c.project_chat_id == ProjectChat.id) + .filter( + ProjectChat.org_id == org_id, + ProjectChat.created_by == user_id, # User-specific + ProjectChat.project_id.is_(None), # Org-scoped only + ) + .order_by(ProjectChat.created_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) has_more = offset + limit < total @@ -395,27 +383,25 @@ def list_org_project_chats_paginated( # Compute url_identifier from chat_title and short_id (mirrors model property) title = row.chat_title or "chat" url_identifier = build_url_identifier(title, row.short_id) - result.append({ - "id": row.id, - "chat_title": row.chat_title, - "proposed_title": row.proposed_title, - "proposed_project_name": row.proposed_project_name, - "created_at": row.created_at, - "updated_at": row.updated_at, - "created_project_id": row.created_project_id, - "short_id": row.short_id, - "url_identifier": url_identifier, - "message_count": row.message_count, - }) + result.append( + { + "id": row.id, + "chat_title": row.chat_title, + "proposed_title": row.proposed_title, + "proposed_project_name": row.proposed_project_name, + "created_at": row.created_at, + "updated_at": row.updated_at, + "created_project_id": row.created_project_id, + "short_id": row.short_id, + "url_identifier": url_identifier, + 
"message_count": row.message_count, + } + ) return result, total, has_more @staticmethod - def set_created_phase( - db: Session, - project_chat_id: UUID, - phase_id: UUID - ) -> None: + def set_created_phase(db: Session, project_chat_id: UUID, phase_id: UUID) -> None: """ Link a project chat to the phase it created. @@ -424,9 +410,7 @@ def set_created_phase( project_chat_id: ID of the project chat phase_id: ID of the created phase """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if project_chat: project_chat.created_phase_id = phase_id @@ -495,7 +479,7 @@ def add_bot_message( project_chat_id: UUID, content: str, response_data: Optional[Dict[str, Any]] = None, - job_id: Optional[UUID] = None + job_id: Optional[UUID] = None, ) -> ProjectChatMessage: """ Add a bot message and update project chat state. @@ -514,9 +498,7 @@ def add_bot_message( Created ProjectChatMessage """ # Get current project chat state for snapshot - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() # Capture snapshot BEFORE updating summary_snapshot = project_chat.running_summary if project_chat else None @@ -549,6 +531,7 @@ def add_bot_message( if "proposed_feature_module_id" in response_data: module_id = response_data["proposed_feature_module_id"] from uuid import UUID as PyUUID + project_chat.proposed_feature_module_id = PyUUID(module_id) if module_id else None if "proposed_feature_module_title" in response_data: project_chat.proposed_feature_module_title = response_data["proposed_feature_module_title"] @@ -623,10 +606,10 @@ def add_web_search_message( content_lines.append("\n**Sources:**") for i, result in enumerate(results[:5], 1): content_lines.append(f"{i}. [{result['title']}]({result['url']})") - if result.get('content'): + if result.get("content"): # Truncate content to first 600 chars - content = result['content'][:600] - if len(result['content']) > 600: + content = result["content"][:600] + if len(result["content"]) > 600: content += "..." content_lines.append(f" > {content}") @@ -655,9 +638,7 @@ def add_web_search_message( db.refresh(message) # Broadcast the message - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if project_chat: ProjectChatService.broadcast_message_created(db, project_chat, message) @@ -666,11 +647,7 @@ def add_web_search_message( return message @staticmethod - def update_project_chat_summary( - db: Session, - project_chat_id: UUID, - running_summary: str - ) -> None: + def update_project_chat_summary(db: Session, project_chat_id: UUID, running_summary: str) -> None: """ Update the running summary of a project chat. 
@@ -679,20 +656,14 @@ def update_project_chat_summary( project_chat_id: ID of the project chat running_summary: New running summary """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if project_chat: project_chat.running_summary = running_summary db.commit() @staticmethod - def set_generating( - db: Session, - project_chat_id: UUID, - is_generating: bool - ) -> None: + def set_generating(db: Session, project_chat_id: UUID, is_generating: bool) -> None: """ Set the generating flag for a project chat. @@ -701,20 +672,14 @@ def set_generating( project_chat_id: ID of the project chat is_generating: Whether the bot is generating a response """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if project_chat: project_chat.is_generating = is_generating db.commit() @staticmethod - def update_visibility( - db: Session, - project_chat_id: UUID, - visibility: ProjectChatVisibility - ) -> ProjectChat: + def update_visibility(db: Session, project_chat_id: UUID, visibility: ProjectChatVisibility) -> ProjectChat: """ Update the visibility of a project chat. @@ -729,9 +694,7 @@ def update_visibility( Raises: ValueError: If project chat not found """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat not found") @@ -743,18 +706,12 @@ def update_visibility( # Broadcast the update ProjectChatService.broadcast_project_chat_updated(db, project_chat) - logger.info( - f"Updated visibility for project chat {project_chat_id} to {visibility.value}" - ) + logger.info(f"Updated visibility for project chat {project_chat_id} to {visibility.value}") return project_chat @staticmethod def set_ai_error_state( - db: Session, - project_chat_id: UUID, - error_message: str, - job_id: Optional[UUID], - user_message: Optional[str] + db: Session, project_chat_id: UUID, error_message: str, job_id: Optional[UUID], user_message: Optional[str] ) -> None: """ Set AI error state on project chat for persistence across page refreshes. 
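set_ai_error_state persists the failure on the chat row so a page refresh can re-render it and offer a retry. Its body is not fully visible in these hunks, so this is only a schematic of the state it appears to write (field names taken from the broadcast payload later in this diff; the is_generating handling is an assumption):

def set_ai_error(chat: dict, error_message: str, user_message: str | None) -> dict:
    # Persisted fields survive a refresh; clear_ai_error_state is the inverse.
    chat["ai_error_message"] = error_message
    chat["ai_error_user_message"] = user_message  # kept so the user can retry
    chat["is_generating"] = False  # assumption: a failed job stops generating
    return chat

chat = set_ai_error({"id": "c1", "is_generating": True}, "LLM timeout", "draft a spec")
assert chat["is_generating"] is False and chat["ai_error_message"] == "LLM timeout"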
@@ -766,9 +723,7 @@ def set_ai_error_state( job_id: The failed job ID user_message: The user's original message (for retry) """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: logger.warning(f"Cannot set AI error state: project chat {project_chat_id} not found") @@ -793,9 +748,7 @@ def clear_ai_error_state(db: Session, project_chat_id: UUID) -> None: db: Database session project_chat_id: ID of the project chat """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: logger.warning(f"Cannot clear AI error state: project chat {project_chat_id} not found") @@ -812,10 +765,7 @@ def clear_ai_error_state(db: Session, project_chat_id: UUID) -> None: logger.info(f"Cleared AI error state for project chat {project_chat_id}") @staticmethod - def delete_project_chat( - db: Session, - project_chat_id: UUID - ) -> bool: + def delete_project_chat(db: Session, project_chat_id: UUID) -> bool: """ Delete a project chat and all its messages. @@ -826,9 +776,7 @@ def delete_project_chat( Returns: True if deleted, False if not found """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: return False @@ -840,10 +788,7 @@ def delete_project_chat( return True @staticmethod - def _collect_project_chat_images( - db: Session, - project_chat_id: UUID - ) -> list[dict] | None: + def _collect_project_chat_images(db: Session, project_chat_id: UUID) -> list[dict] | None: """Collect all unique images from user messages in the project chat. Args: @@ -853,11 +798,15 @@ def _collect_project_chat_images( Returns: List of image attachment dicts, or None if no images """ - messages = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.message_type == ProjectChatMessageType.USER, - ProjectChatMessage.images.isnot(None), - ).all() + messages = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.message_type == ProjectChatMessageType.USER, + ProjectChatMessage.images.isnot(None), + ) + .all() + ) all_images = [] seen_ids = set() @@ -872,11 +821,7 @@ def _collect_project_chat_images( return all_images if all_images else None @staticmethod - def create_phase_from_project_chat( - db: Session, - project_chat_id: UUID, - user_id: UUID - ) -> BrainstormingPhase: + def create_phase_from_project_chat(db: Session, project_chat_id: UUID, user_id: UUID) -> BrainstormingPhase: """ Create a brainstorming phase from a project chat's proposed values. 
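_collect_project_chat_images, reformatted above, de-duplicates attachments across all user messages before they are reused as a phase or feature description. The same pass over plain dicts (the "id" key on each attachment is assumed):

def collect_unique_images(messages: list[dict]) -> list[dict] | None:
    all_images, seen_ids = [], set()
    for msg in messages:
        for img in msg.get("images") or []:
            if img["id"] not in seen_ids:  # first occurrence wins
                seen_ids.add(img["id"])
                all_images.append(img)
    return all_images or None  # None when no images, as in the diff

msgs = [{"images": [{"id": "a"}, {"id": "b"}]}, {"images": [{"id": "a"}]}]
assert [i["id"] for i in collect_unique_images(msgs)] == ["a", "b"]
assert collect_unique_images([{"images": None}]) is None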
@@ -896,12 +841,10 @@ def create_phase_from_project_chat( Raises: ValueError: If project chat not found, not ready, or missing proposed values """ - from app.services.phase_container_service import PhaseContainerService from app.models.phase_container import PhaseContainer + from app.services.phase_container_service import PhaseContainerService - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat not found") @@ -919,20 +862,14 @@ def create_phase_from_project_chat( # Auto-switch to TEAM visibility when creating a phase if project_chat.visibility == ProjectChatVisibility.PRIVATE: project_chat.visibility = ProjectChatVisibility.TEAM - logger.info( - f"Auto-switched project chat {project_chat_id} to TEAM visibility on phase creation" - ) + logger.info(f"Auto-switched project chat {project_chat_id} to TEAM visibility on phase creation") # Collect images from user messages in the project chat - description_images = ProjectChatService._collect_project_chat_images( - db, project_chat_id - ) + description_images = ProjectChatService._collect_project_chat_images(db, project_chat_id) if project_chat.target_container_id: # Extension path: create phase inside an existing container - container = PhaseContainerService.get_container( - db, project_chat.target_container_id - ) + container = PhaseContainerService.get_container(db, project_chat.target_container_id) if not container: raise ValueError("Target container not found") if container.archived_at is not None: @@ -952,8 +889,7 @@ def create_phase_from_project_chat( db.flush() logger.info( - f"Created extension phase {phase.id} in container {container.id} " - f"from project chat {project_chat_id}" + f"Created extension phase {phase.id} in container {container.id} from project chat {project_chat_id}" ) else: # Standalone path: create phase and auto-wrap in a new container @@ -991,8 +927,7 @@ def create_phase_from_project_chat( db.flush() logger.info( - f"Created phase {phase.id} with auto-container {container.id} " - f"from project chat {project_chat_id}" + f"Created phase {phase.id} with auto-container {container.id} from project chat {project_chat_id}" ) # Link the project chat to the created phase @@ -1028,9 +963,7 @@ def create_feature_from_project_chat( Raises: ValueError: If project chat not found or missing required fields """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat not found") @@ -1041,9 +974,7 @@ def create_feature_from_project_chat( # Auto-switch to TEAM visibility when creating a feature if project_chat.visibility == ProjectChatVisibility.PRIVATE: project_chat.visibility = ProjectChatVisibility.TEAM - logger.info( - f"Auto-switched project chat {project_chat_id} to TEAM visibility on feature creation" - ) + logger.info(f"Auto-switched project chat {project_chat_id} to TEAM visibility on feature creation") # Determine module (override > proposed existing > proposed new) module = None @@ -1065,9 +996,7 @@ def create_feature_from_project_chat( ) is_new_module = True elif project_chat.proposed_feature_module_id: - module = db.query(Module).filter( - Module.id == project_chat.proposed_feature_module_id - ).first() + module = db.query(Module).filter(Module.id == 
project_chat.proposed_feature_module_id).first() if not module: raise ValueError("Proposed module not found") elif project_chat.proposed_feature_module_title: @@ -1085,9 +1014,7 @@ def create_feature_from_project_chat( raise ValueError("No module specified for feature") # Collect images from user messages in the project chat - description_images = ProjectChatService._collect_project_chat_images( - db, project_chat_id - ) + description_images = ProjectChatService._collect_project_chat_images(db, project_chat_id) # Create feature feature = FeatureService.create_feature( @@ -1166,9 +1093,7 @@ def get_created_features( Returns: List of feature info dicts """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat or not project_chat.created_feature_ids: return [] @@ -1176,30 +1101,24 @@ def get_created_features( # Convert string UUIDs to UUID objects for the query feature_uuids = [UUID(fid) for fid in project_chat.created_feature_ids] - features = db.query(Feature).options( - joinedload(Feature.module) - ).filter( - Feature.id.in_(feature_uuids) - ).all() + features = db.query(Feature).options(joinedload(Feature.module)).filter(Feature.id.in_(feature_uuids)).all() result = [] for feature in features: - result.append({ - "id": str(feature.id), - "feature_key": feature.feature_key, - "title": feature.title, - "module_id": str(feature.module_id), - "module_title": feature.module.title if feature.module else "Unknown", - }) + result.append( + { + "id": str(feature.id), + "feature_key": feature.feature_key, + "title": feature.title, + "module_id": str(feature.module_id), + "module_title": feature.module.title if feature.module else "Unknown", + } + ) return result @staticmethod - def get_conversation_history( - db: Session, - project_chat_id: UUID, - limit: int = 10 - ) -> List[Dict[str, Any]]: + def get_conversation_history(db: Session, project_chat_id: UUID, limit: int = 10) -> List[Dict[str, Any]]: """ Get recent conversation history for agent context. @@ -1211,11 +1130,13 @@ def get_conversation_history( Returns: List of message dicts with type and content """ - messages = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == project_chat_id - ).order_by( - ProjectChatMessage.created_at.desc() - ).limit(limit).all() + messages = ( + db.query(ProjectChatMessage) + .filter(ProjectChatMessage.project_chat_id == project_chat_id) + .order_by(ProjectChatMessage.created_at.desc()) + .limit(limit) + .all() + ) # Reverse to get chronological order messages = list(reversed(messages)) @@ -1230,10 +1151,7 @@ def get_conversation_history( ] @staticmethod - def build_context_for_agent( - db: Session, - project_chat_id: UUID - ) -> Dict[str, Any]: + def build_context_for_agent(db: Session, project_chat_id: UUID) -> Dict[str, Any]: """ Build the full context needed for the agent. 
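get_conversation_history, just above, fetches the newest `limit` rows and then reverses them so the agent context reads chronologically. The shape of that window trick, with sorting standing in for the SQL ORDER BY:

def last_n_chronological(timestamps: list[int], limit: int = 10) -> list[int]:
    newest_first = sorted(timestamps, reverse=True)[:limit]  # ORDER BY created_at DESC LIMIT n
    return list(reversed(newest_first))  # back to chronological order

assert last_n_chronological([1, 2, 3, 4], limit=2) == [3, 4]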
@@ -1244,12 +1162,12 @@ def build_context_for_agent( Returns: Dict with project info, grounding context, phases, and conversation """ - project_chat = db.query(ProjectChat).options( - joinedload(ProjectChat.project), - joinedload(ProjectChat.messages) - ).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = ( + db.query(ProjectChat) + .options(joinedload(ProjectChat.project), joinedload(ProjectChat.messages)) + .filter(ProjectChat.id == project_chat_id) + .first() + ) if not project_chat: raise ValueError("Project chat not found") @@ -1257,10 +1175,11 @@ def build_context_for_agent( project = project_chat.project # Check for grounding (brownfield indicator) - grounding_file = db.query(GroundingFile).filter( - GroundingFile.project_id == project.id, - GroundingFile.filename == "agents.md" - ).first() + grounding_file = ( + db.query(GroundingFile) + .filter(GroundingFile.project_id == project.id, GroundingFile.filename == "agents.md") + .first() + ) has_grounding = False grounding_summary = None @@ -1269,10 +1188,11 @@ def build_context_for_agent( grounding_summary = grounding_file.summary # Get existing phases - existing_phases = db.query(BrainstormingPhase).filter( - BrainstormingPhase.project_id == project.id, - BrainstormingPhase.archived_at.is_(None) - ).all() + existing_phases = ( + db.query(BrainstormingPhase) + .filter(BrainstormingPhase.project_id == project.id, BrainstormingPhase.archived_at.is_(None)) + .all() + ) phase_summaries = [ { @@ -1305,11 +1225,7 @@ def build_context_for_agent( } @staticmethod - def broadcast_message_created( - db: Session, - project_chat: ProjectChat, - message: ProjectChatMessage - ) -> None: + def broadcast_message_created(db: Session, project_chat: ProjectChat, message: ProjectChatMessage) -> None: """ Broadcast a message creation via Kafka for WebSocket delivery. @@ -1359,10 +1275,7 @@ def broadcast_message_created( logger.info(f"Broadcasted project chat message created: project_chat_id={project_chat.id}") @staticmethod - def broadcast_project_chat_updated( - db: Session, - project_chat: ProjectChat - ) -> None: + def broadcast_project_chat_updated(db: Session, project_chat: ProjectChat) -> None: """ Broadcast a project chat state update via Kafka for WebSocket delivery. 
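In build_context_for_agent, the presence of an agents.md grounding file is what marks a project as brownfield; the condition is split across context lines above, so this reconstruction is a best guess:

def grounding_flags(grounding_file: dict | None) -> tuple[bool, str | None]:
    # has_grounding doubles as the brownfield indicator.
    if grounding_file is not None:
        return True, grounding_file.get("summary")
    return False, None

assert grounding_flags({"summary": "FastAPI monolith"}) == (True, "FastAPI monolith")
assert grounding_flags(None) == (False, None)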
@@ -1405,7 +1318,9 @@ def broadcast_project_chat_updated( "ready_to_create_feature": project_chat.ready_to_create_feature, "proposed_feature_title": project_chat.proposed_feature_title, "proposed_feature_description": project_chat.proposed_feature_description, - "proposed_feature_module_id": str(project_chat.proposed_feature_module_id) if project_chat.proposed_feature_module_id else None, + "proposed_feature_module_id": str(project_chat.proposed_feature_module_id) + if project_chat.proposed_feature_module_id + else None, "proposed_feature_module_title": project_chat.proposed_feature_module_title, "proposed_feature_module_description": project_chat.proposed_feature_module_description, # Project proposal fields (org-scoped only) @@ -1420,7 +1335,9 @@ def broadcast_project_chat_updated( "is_searching_web": project_chat.is_searching_web, "exploring_code_prompt": project_chat.exploring_code_prompt, "searching_web_query": project_chat.searching_web_query, - "last_exploration_id": str(project_chat.last_exploration_id) if project_chat.last_exploration_id else None, + "last_exploration_id": str(project_chat.last_exploration_id) + if project_chat.last_exploration_id + else None, "ai_error_message": project_chat.ai_error_message, "ai_error_user_message": project_chat.ai_error_user_message, "retry_status": project_chat.retry_status, @@ -1465,29 +1382,23 @@ def _trigger_mention_notifications( author_id: ID of the user who made the mention mentioned_user_ids: List of user UUIDs to notify """ - from app.services.kafka_producer import get_sync_kafka_producer - from app.services.job_service import JobService from app.models.job import JobType + from app.services.job_service import JobService + from app.services.kafka_producer import get_sync_kafka_producer # Skip for org-scoped project chats (no project to link to) if not project_chat.project_id: - logger.debug( - f"Skipping mention notifications for org-scoped project chat {project_chat.id}" - ) + logger.debug(f"Skipping mention notifications for org-scoped project chat {project_chat.id}") return # Skip user notifications for private chats (only MFBTAI can be mentioned) if project_chat.visibility == ProjectChatVisibility.PRIVATE: - logger.debug( - f"Skipping mention notifications for private project chat {project_chat.id}" - ) + logger.debug(f"Skipping mention notifications for private project chat {project_chat.id}") return project = db.query(Project).filter(Project.id == project_chat.project_id).first() if not project: - logger.warning( - f"Cannot trigger mention notifications: project {project_chat.project_id} not found" - ) + logger.warning(f"Cannot trigger mention notifications: project {project_chat.project_id} not found") return # Create the job @@ -1525,9 +1436,7 @@ def _trigger_mention_notifications( key=str(project_chat.id), ) - logger.info( - f"Triggered project chat mention notification job {job.id} for {len(mentioned_user_ids)} users" - ) + logger.info(f"Triggered project chat mention notification job {job.id} for {len(mentioned_user_ids)} users") except Exception as e: logger.error(f"Failed to publish project chat mention notification job to Kafka: {e}") @@ -1556,9 +1465,7 @@ def create_project_from_project_chat( Raises: ValueError: If project chat not found, not org-scoped, not ready, or missing values """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat 
not found") @@ -1600,9 +1507,7 @@ def create_project_from_project_chat( ) # Collect images from user messages in the project chat - description_images = ProjectChatService._collect_project_chat_images( - db, project_chat_id - ) + description_images = ProjectChatService._collect_project_chat_images(db, project_chat_id) # Create the first brainstorming phase phase = BrainstormingPhaseService.create_brainstorming_phase( @@ -1621,9 +1526,7 @@ def create_project_from_project_chat( project_chat.created_phase_id = phase.id db.commit() - logger.info( - f"Created project {project.id} and phase {phase.id} from project chat {project_chat_id}" - ) + logger.info(f"Created project {project.id} and phase {phase.id} from project chat {project_chat_id}") return project, phase @@ -1652,9 +1555,7 @@ def delete_message( Raises: ValueError: If message not found, not a user message, or not the author """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat not found") @@ -1662,10 +1563,14 @@ def delete_message( if project_chat.is_readonly: raise ValueError("Cannot delete messages from a readonly project chat") - message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.id == message_id, - ProjectChatMessage.project_chat_id == project_chat_id, - ).first() + message = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.id == message_id, + ProjectChatMessage.project_chat_id == project_chat_id, + ) + .first() + ) if not message: raise ValueError("Message not found") @@ -1677,11 +1582,16 @@ def delete_message( raise ValueError("Only the author can delete their messages") # Find the previous message with a snapshot to restore running_summary - previous_message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.created_at < message.created_at, - ProjectChatMessage.summary_snapshot.isnot(None), - ).order_by(ProjectChatMessage.created_at.desc()).first() + previous_message = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.created_at < message.created_at, + ProjectChatMessage.summary_snapshot.isnot(None), + ) + .order_by(ProjectChatMessage.created_at.desc()) + .first() + ) # Broadcast deletion BEFORE deleting ProjectChatService.broadcast_message_deleted(db, project_chat, message_id) @@ -1724,9 +1634,7 @@ def start_over_from_message( Raises: ValueError: If message not found, not a user message, or not the author """ - project_chat = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + project_chat = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not project_chat: raise ValueError("Project chat not found") @@ -1734,10 +1642,14 @@ def start_over_from_message( if project_chat.is_readonly: raise ValueError("Cannot start over in a readonly project chat") - message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.id == message_id, - ProjectChatMessage.project_chat_id == project_chat_id, - ).first() + message = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.id == message_id, + ProjectChatMessage.project_chat_id == project_chat_id, + ) + .first() + ) if not message: raise ValueError("Message not found") @@ -1749,25 +1661,32 @@ def start_over_from_message( raise ValueError("Only the author can start over from 
their messages") # Find all messages to delete (including and after the target message) - messages_to_delete = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.created_at >= message.created_at, - ).all() + messages_to_delete = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.created_at >= message.created_at, + ) + .all() + ) deleted_message_ids = [str(m.id) for m in messages_to_delete] deleted_count = len(messages_to_delete) # Find the previous message with a snapshot to restore running_summary - previous_message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.project_chat_id == project_chat_id, - ProjectChatMessage.created_at < message.created_at, - ProjectChatMessage.summary_snapshot.isnot(None), - ).order_by(ProjectChatMessage.created_at.desc()).first() + previous_message = ( + db.query(ProjectChatMessage) + .filter( + ProjectChatMessage.project_chat_id == project_chat_id, + ProjectChatMessage.created_at < message.created_at, + ProjectChatMessage.summary_snapshot.isnot(None), + ) + .order_by(ProjectChatMessage.created_at.desc()) + .first() + ) # Broadcast bulk deletion BEFORE deleting - ProjectChatService.broadcast_messages_bulk_deleted( - db, project_chat, deleted_message_ids - ) + ProjectChatService.broadcast_messages_bulk_deleted(db, project_chat, deleted_message_ids) # Delete all the messages for m in messages_to_delete: @@ -1814,8 +1733,7 @@ def start_over_from_message( ProjectChatService.broadcast_project_chat_updated(db, project_chat) logger.info( - f"Start over from message {message_id} in project chat {project_chat_id}, " - f"deleted {deleted_count} messages" + f"Start over from message {message_id} in project chat {project_chat_id}, deleted {deleted_count} messages" ) return { @@ -1824,11 +1742,7 @@ def start_over_from_message( } @staticmethod - def broadcast_message_deleted( - db: Session, - project_chat: ProjectChat, - message_id: UUID - ) -> None: + def broadcast_message_deleted(db: Session, project_chat: ProjectChat, message_id: UUID) -> None: """ Broadcast a message deletion via Kafka for WebSocket delivery. @@ -1858,11 +1772,7 @@ def broadcast_message_deleted( logger.info(f"Broadcasted project chat message deleted: message_id={message_id}") @staticmethod - def broadcast_messages_bulk_deleted( - db: Session, - project_chat: ProjectChat, - deleted_message_ids: List[str] - ) -> None: + def broadcast_messages_bulk_deleted(db: Session, project_chat: ProjectChat, deleted_message_ids: List[str]) -> None: """ Broadcast bulk message deletion via Kafka for WebSocket delivery. 
@@ -1921,9 +1831,7 @@ def toggle_reaction( from sqlalchemy.orm.attributes import flag_modified # Get the message - message = db.query(ProjectChatMessage).filter( - ProjectChatMessage.id == message_id - ).first() + message = db.query(ProjectChatMessage).filter(ProjectChatMessage.id == message_id).first() if not message: raise ValueError("Message not found") @@ -1944,8 +1852,7 @@ def toggle_reaction( # Get user display name for the tooltip reacting_user = db.query(User).filter(User.id == user_id).first() # user_id is already UUID user_display_name = ( - reacting_user.display_name or reacting_user.email.split("@")[0] - if reacting_user else "Unknown" + reacting_user.display_name or reacting_user.email.split("@")[0] if reacting_user else "Unknown" ) # Find existing reaction for this emoji @@ -1977,13 +1884,15 @@ def toggle_reaction( existing_reaction["count"] += 1 else: # Create new reaction - reactions.append({ - "emoji": emoji, - "emoji_native": emoji_native, - "user_ids": [user_id_str], - "user_names": {user_id_str: user_display_name}, - "count": 1, - }) + reactions.append( + { + "emoji": emoji, + "emoji_native": emoji_native, + "user_ids": [user_id_str], + "user_names": {user_id_str: user_display_name}, + "count": 1, + } + ) response_data["reactions"] = reactions message.response_data = response_data @@ -1998,17 +1907,11 @@ def toggle_reaction( project_chat = message.project_chat ProjectChatService.broadcast_message_updated(db, project_chat, message) - logger.info( - f"Toggled reaction {emoji} on message {message_id} by user {user_id}: {action}" - ) + logger.info(f"Toggled reaction {emoji} on message {message_id} by user {user_id}: {action}") return message, action @staticmethod - def broadcast_message_updated( - db: Session, - project_chat: ProjectChat, - message: ProjectChatMessage - ) -> None: + def broadcast_message_updated(db: Session, project_chat: ProjectChat, message: ProjectChatMessage) -> None: """ Broadcast a message update via Kafka for WebSocket delivery. @@ -2018,7 +1921,6 @@ def broadcast_message_updated( message: The updated message """ from app.services.kafka_producer import get_sync_kafka_producer - from app.schemas.thread_item import Reaction # Use org_id directly from project_chat org_id = project_chat.org_id diff --git a/backend/app/services/project_repository_service.py b/backend/app/services/project_repository_service.py index 9f391b2..1f583a5 100644 --- a/backend/app/services/project_repository_service.py +++ b/backend/app/services/project_repository_service.py @@ -1,4 +1,5 @@ """ProjectRepository service for managing project repositories.""" + import logging from typing import Optional from uuid import UUID @@ -41,9 +42,7 @@ def _clear_phase_exploration_caches(db: Session, project_id: UUID) -> int: ) ) if updated > 0: - logger.info( - f"Cleared code exploration cache for {updated} phases in project {project_id}" - ) + logger.info(f"Cleared code exploration cache for {updated} phases in project {project_id}") return updated @@ -120,9 +119,7 @@ def create_repository( return repo @staticmethod - def get_repository_by_id( - db: Session, project_id: UUID, repo_id: UUID - ) -> Optional[ProjectRepository]: + def get_repository_by_id(db: Session, project_id: UUID, repo_id: UUID) -> Optional[ProjectRepository]: """Get a repository by ID. 
Args: @@ -143,9 +140,7 @@ def get_repository_by_id( ) @staticmethod - def get_repository_by_slug( - db: Session, project_id: UUID, slug: str - ) -> Optional[ProjectRepository]: + def get_repository_by_slug(db: Session, project_id: UUID, slug: str) -> Optional[ProjectRepository]: """Get a repository by slug. Args: @@ -184,9 +179,7 @@ def list_repositories(db: Session, project_id: UUID) -> list[ProjectRepository]: ) @staticmethod - def get_primary_repository( - db: Session, project_id: UUID - ) -> Optional[ProjectRepository]: + def get_primary_repository(db: Session, project_id: UUID) -> Optional[ProjectRepository]: """Get the primary repository (lowest sort_order). Args: @@ -280,9 +273,7 @@ def delete_repository(db: Session, project_id: UUID, slug: str) -> bool: return True @staticmethod - def reorder_repositories( - db: Session, project_id: UUID, slug_order: list[str] - ) -> list[ProjectRepository]: + def reorder_repositories(db: Session, project_id: UUID, slug_order: list[str]) -> list[ProjectRepository]: """Reorder repositories. Args: @@ -296,9 +287,7 @@ def reorder_repositories( Raises: ValueError: If any slug is not found or if slugs are missing """ - existing_repos = { - r.slug: r for r in ProjectRepositoryService.list_repositories(db, project_id) - } + existing_repos = {r.slug: r for r in ProjectRepositoryService.list_repositories(db, project_id)} # Verify all provided slugs exist for slug in slug_order: @@ -345,9 +334,7 @@ def clear_connector_for_repos(db: Session, connector_id: UUID) -> int: return updated @staticmethod - def broadcast_repository_update( - db: Session, repo: ProjectRepository, action: str - ) -> None: + def broadcast_repository_update(db: Session, repo: ProjectRepository, action: str) -> None: """Broadcast a repository update via WebSocket. 
Args: diff --git a/backend/app/services/project_service.py b/backend/app/services/project_service.py index 7f9de1e..54973df 100644 --- a/backend/app/services/project_service.py +++ b/backend/app/services/project_service.py @@ -1,22 +1,21 @@ """Project service for managing projects.""" + +from datetime import datetime, timezone from typing import Dict, Optional from uuid import UUID, uuid4 -from datetime import datetime, timezone -import uuid -from sqlalchemy.orm import Session -from sqlalchemy import and_ from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session -from app.models import Project, ProjectType, ProjectStatus, Organization, ProjectRole, User, SpecVersion -from app.models.project_share import ProjectShare, ShareSubjectType +from app.models import Organization, Project, ProjectRole, ProjectStatus, ProjectType, User from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module from app.models.feature import Feature from app.models.feature_content_version import FeatureContentVersion -from app.models.final_spec import FinalSpec from app.models.final_prompt_plan import FinalPromptPlan -from app.models.thread import Thread, ContextType +from app.models.final_spec import FinalSpec +from app.models.module import Module +from app.models.project_share import ProjectShare, ShareSubjectType +from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem from app.services.project_share_service import ProjectShareService @@ -196,11 +195,7 @@ def get_by_identifier(db: Session, identifier: str) -> Optional[Project]: # Extract short_id and query short_id = extract_short_id(identifier) - return ( - db.query(Project) - .filter(Project.short_id == short_id, Project.deleted_at.is_(None)) - .first() - ) + return db.query(Project).filter(Project.short_id == short_id, Project.deleted_at.is_(None)).first() @staticmethod def list_org_projects( @@ -281,9 +276,7 @@ def update_project( new_key = new_key.upper() # Validate uniqueness within org - if not PrefixService.is_prefix_available( - db, project.org_id, new_key, exclude_project_id=project_id - ): + if not PrefixService.is_prefix_available(db, project.org_id, new_key, exclude_project_id=project_id): raise ValueError(f"Prefix '{new_key}' is already in use") project.key = new_key @@ -397,9 +390,7 @@ def clone_project( db.rollback() # If it's a key uniqueness issue and we can retry, generate new key if "uq_projects_org_key" in str(e) and attempt < max_retries - 1: - cloned_key = ProjectService._generate_unique_key( - db, source.org_id, cloned_name - ) + cloned_key = ProjectService._generate_unique_key(db, source.org_id, cloned_name) else: raise @@ -419,11 +410,7 @@ def clone_project( # Clone phases if requested if include_phases: - source_phases = ( - db.query(BrainstormingPhase) - .filter(BrainstormingPhase.project_id == project_id) - .all() - ) + source_phases = db.query(BrainstormingPhase).filter(BrainstormingPhase.project_id == project_id).all() for phase in source_phases: new_phase = BrainstormingPhase( @@ -465,9 +452,7 @@ def clone_project( db.add(new_final_plan) # Clone modules (both phase-linked and orphaned) - source_modules = ( - db.query(Module).filter(Module.project_id == project_id).all() - ) + source_modules = db.query(Module).filter(Module.project_id == project_id).all() for module in source_modules: new_phase_id = None @@ -493,12 +478,7 @@ def clone_project( module_id_map[module.id] = new_module.id # Clone features - source_features = ( - 
db.query(Feature) - .join(Module) - .filter(Module.project_id == project_id) - .all() - ) + source_features = db.query(Feature).join(Module).filter(Module.project_id == project_id).all() for feature in source_features: new_module_id = module_id_map.get(feature.module_id) @@ -558,19 +538,12 @@ def clone_project( # Clone threads based on option if include_threads != "none": - source_threads = ( - db.query(Thread) - .filter(Thread.project_id == str(project_id)) - .all() - ) + source_threads = db.query(Thread).filter(Thread.project_id == str(project_id)).all() for thread in source_threads: # Map context_id if it's a feature new_context_id = thread.context_id - if ( - thread.context_type == ContextType.BRAINSTORM_FEATURE - and thread.context_id - ): + if thread.context_type == ContextType.BRAINSTORM_FEATURE and thread.context_id: try: old_feature_id = UUID(thread.context_id) if old_feature_id in feature_id_map: @@ -764,9 +737,7 @@ def list_project_children( Returns: list[Project]: List of child projects """ - query = db.query(Project).filter( - Project.parent_project_id == parent_project_id - ) + query = db.query(Project).filter(Project.parent_project_id == parent_project_id) if type_filter: query = query.filter(Project.type == type_filter.value) diff --git a/backend/app/services/project_share_service.py b/backend/app/services/project_share_service.py index 0384fa3..97a08df 100644 --- a/backend/app/services/project_share_service.py +++ b/backend/app/services/project_share_service.py @@ -3,12 +3,11 @@ from typing import Optional from uuid import UUID -from sqlalchemy import or_, select +from sqlalchemy import select from sqlalchemy.orm import Session from app.models.project_membership import ProjectRole from app.models.project_share import ProjectShare, ShareSubjectType -from app.models.user_group import UserGroup from app.models.user_group_membership import UserGroupMembership @@ -139,9 +138,7 @@ def create_org_share( Returns: The created or updated ProjectShare """ - return ProjectShareService.create_share( - db, project_id, ShareSubjectType.ORG, org_id, role, created_by_user_id - ) + return ProjectShareService.create_share(db, project_id, ShareSubjectType.ORG, org_id, role, created_by_user_id) @staticmethod def get_org_share(db: Session, project_id: UUID) -> Optional[ProjectShare]: @@ -249,9 +246,7 @@ def list_project_shares(db: Session, project_id: UUID) -> list[ProjectShare]: ) @staticmethod - def get_user_direct_share( - db: Session, project_id: UUID, user_id: UUID - ) -> Optional[ProjectShare]: + def get_user_direct_share(db: Session, project_id: UUID, user_id: UUID) -> Optional[ProjectShare]: """Get a user's direct share for a project. Args: @@ -273,9 +268,7 @@ def get_user_direct_share( ) @staticmethod - def get_user_group_shares( - db: Session, project_id: UUID, user_id: UUID - ) -> list[ProjectShare]: + def get_user_group_shares(db: Session, project_id: UUID, user_id: UUID) -> list[ProjectShare]: """Get group shares that apply to a user for a project. 
Args: @@ -288,9 +281,7 @@ def get_user_group_shares( """ # Get all group IDs the user belongs to user_group_ids = ( - select(UserGroupMembership.group_id) - .where(UserGroupMembership.user_id == user_id) - .scalar_subquery() + select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id).scalar_subquery() ) return ( @@ -304,9 +295,7 @@ def get_user_group_shares( ) @staticmethod - def _get_user_org_share( - db: Session, project_id: UUID, user_id: UUID - ) -> Optional[ProjectShare]: + def _get_user_org_share(db: Session, project_id: UUID, user_id: UUID) -> Optional[ProjectShare]: """Get org share if user is a member of the project's org. Args: @@ -317,8 +306,8 @@ def _get_user_org_share( Returns: ProjectShare or None if no org share or user not in org """ - from app.models.project import Project from app.models.org_membership import OrgMembership + from app.models.project import Project project = db.query(Project).filter(Project.id == project_id).first() if not project: @@ -348,9 +337,7 @@ def _get_user_org_share( ) @staticmethod - def get_user_effective_share( - db: Session, project_id: UUID, user_id: UUID - ) -> Optional[ProjectShare]: + def get_user_effective_share(db: Session, project_id: UUID, user_id: UUID) -> Optional[ProjectShare]: """Get a user's effective share for a project. Returns the share with the highest role. Precedence for equal roles: @@ -374,9 +361,9 @@ def get_user_effective_share( # Type precedence (higher is preferred when roles are equal) type_precedence = { - ShareSubjectType.USER: 2, # Direct share - highest + ShareSubjectType.USER: 2, # Direct share - highest ShareSubjectType.GROUP: 1, # Group share - middle - ShareSubjectType.ORG: 0, # Org share - lowest + ShareSubjectType.ORG: 0, # Org share - lowest } # Get direct share @@ -389,11 +376,7 @@ def get_user_effective_share( org_share = ProjectShareService._get_user_org_share(db, project_id, user_id) # Combine all shares - all_shares = ( - ([direct_share] if direct_share else []) - + group_shares - + ([org_share] if org_share else []) - ) + all_shares = ([direct_share] if direct_share else []) + group_shares + ([org_share] if org_share else []) if not all_shares: return None @@ -409,9 +392,7 @@ def get_user_effective_share( return all_shares[0] @staticmethod - def get_user_effective_role( - db: Session, project_id: UUID, user_id: UUID - ) -> Optional[ProjectRole]: + def get_user_effective_role(db: Session, project_id: UUID, user_id: UUID) -> Optional[ProjectRole]: """Get a user's effective role for a project. Args: @@ -452,9 +433,7 @@ def user_has_project_access(db: Session, project_id: UUID, user_id: UUID) -> boo # Check group access user_group_ids = ( - select(UserGroupMembership.group_id) - .where(UserGroupMembership.user_id == user_id) - .scalar_subquery() + select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id).scalar_subquery() ) group_access = ( @@ -494,9 +473,7 @@ def list_user_project_shares(db: Session, user_id: UUID) -> list[ProjectShare]: ) @staticmethod - def list_user_accessible_project_ids( - db: Session, user_id: UUID, org_id: Optional[UUID] = None - ) -> list[UUID]: + def list_user_accessible_project_ids(db: Session, user_id: UUID, org_id: Optional[UUID] = None) -> list[UUID]: """Get all project IDs a user has access to (direct + group + org shares). 
Args: @@ -507,52 +484,38 @@ def list_user_accessible_project_ids( Returns: List of project IDs the user can access """ - from app.models.project import Project from app.models.org_membership import OrgMembership + from app.models.project import Project # Get project IDs from direct user shares - direct_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.USER, - ProjectShare.subject_id == user_id, - ) + direct_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.USER, + ProjectShare.subject_id == user_id, ) # Get group IDs the user belongs to user_group_ids = ( - select(UserGroupMembership.group_id) - .where(UserGroupMembership.user_id == user_id) - .scalar_subquery() + select(UserGroupMembership.group_id).where(UserGroupMembership.user_id == user_id).scalar_subquery() ) # Get project IDs from group shares - group_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.GROUP, - ProjectShare.subject_id.in_(user_group_ids), - ) + group_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.GROUP, + ProjectShare.subject_id.in_(user_group_ids), ) # Get org IDs user belongs to - user_org_ids = ( - select(OrgMembership.org_id) - .where(OrgMembership.user_id == user_id) - .scalar_subquery() - ) + user_org_ids = select(OrgMembership.org_id).where(OrgMembership.user_id == user_id).scalar_subquery() # Get project IDs from org shares (where user is org member) - org_query = ( - select(ProjectShare.project_id) - .where( - ProjectShare.subject_type == ShareSubjectType.ORG, - ProjectShare.subject_id.in_(user_org_ids), - ) + org_query = select(ProjectShare.project_id).where( + ProjectShare.subject_type == ShareSubjectType.ORG, + ProjectShare.subject_id.in_(user_org_ids), ) # Combine with UNION and create subquery from sqlalchemy import union + combined = union(direct_query, group_query, org_query).subquery() # If org_id provided, filter by org @@ -584,12 +547,10 @@ def list_user_accessible_projects( Returns: List of Project objects the user can access (excludes archived) """ - from app.models.project import Project, ProjectType, ProjectStatus + from app.models.project import Project, ProjectStatus # Get all accessible project IDs - project_ids = ProjectShareService.list_user_accessible_project_ids( - db=db, user_id=user_id - ) + project_ids = ProjectShareService.list_user_accessible_project_ids(db=db, user_id=user_id) if not project_ids: return [] @@ -625,8 +586,8 @@ def list_project_accessible_users(db: Session, project_id: UUID) -> list["User"] Returns: List of User objects who can access the project """ - from app.models.project import Project from app.models.org_membership import OrgMembership + from app.models.project import Project from app.models.user import User project = db.query(Project).filter(Project.id == project_id).first() @@ -658,9 +619,7 @@ def list_project_accessible_users(db: Session, project_id: UUID) -> list["User"] ) for share in group_shares: members = ( - db.query(UserGroupMembership.user_id) - .filter(UserGroupMembership.group_id == share.subject_id) - .all() + db.query(UserGroupMembership.user_id).filter(UserGroupMembership.group_id == share.subject_id).all() ) for (uid,) in members: user_ids.add(uid) @@ -675,11 +634,7 @@ def list_project_accessible_users(db: Session, project_id: UUID) -> list["User"] .first() ) if org_share: - org_members = ( - db.query(OrgMembership.user_id) - 
.filter(OrgMembership.org_id == project.org_id) - .all() - ) + org_members = db.query(OrgMembership.user_id).filter(OrgMembership.org_id == project.org_id).all() for (uid,) in org_members: user_ids.add(uid) @@ -687,10 +642,6 @@ def list_project_accessible_users(db: Session, project_id: UUID) -> list["User"] if not user_ids: return [] - users = ( - db.query(User) - .filter(User.id.in_(user_ids), User.is_active == True) - .all() - ) + users = db.query(User).filter(User.id.in_(user_ids), User.is_active == True).all() return users diff --git a/backend/app/services/sample_project_service.py b/backend/app/services/sample_project_service.py index f57851f..e8e08d0 100644 --- a/backend/app/services/sample_project_service.py +++ b/backend/app/services/sample_project_service.py @@ -15,28 +15,27 @@ from sqlalchemy.orm import Session from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype -from app.models.phase_container import PhaseContainer -from app.models.spec_version import SpecVersion, SpecType from app.models.feature import ( Feature, - FeatureType, FeaturePriority, FeatureProvenance, FeatureStatus, + FeatureType, FeatureVisibilityStatus, ) -from app.models.feature_content_version import FeatureContentVersion, FeatureContentType -from app.models.module import Module, ModuleType, ModuleProvenance +from app.models.feature_content_version import FeatureContentType, FeatureContentVersion +from app.models.implementation import Implementation +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.org_membership import OrgMembership, OrgRole -from app.models.project import Project, ProjectType, ProjectStatus +from app.models.phase_container import PhaseContainer +from app.models.project import Project, ProjectStatus, ProjectType from app.models.provisioning import ProvisioningSource -from app.models.thread import Thread, ContextType +from app.models.spec_version import SpecType, SpecVersion +from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem, ThreadItemType from app.models.user import User -from app.models.implementation import Implementation from app.services.agent_utils import get_or_create_agent_user - logger = logging.getLogger(__name__) @@ -58,9 +57,7 @@ def _get_sample_project_count(db: Session, org_id: UUID) -> int: ) @staticmethod - def _generate_unique_name_and_key( - db: Session, org_id: UUID, base_name: str, base_key: str - ) -> Tuple[str, str]: + def _generate_unique_name_and_key(db: Session, org_id: UUID, base_name: str, base_key: str) -> Tuple[str, str]: """Generate unique name and key for sample project based on existing count.""" count = SampleProjectService._get_sample_project_count(db, org_id) if count == 0: @@ -122,9 +119,7 @@ def create_sample_project( # Generate unique name and key for this instance base_name = seed_data["project"]["name"] base_key = seed_data["project"]["key"] - unique_name, unique_key = SampleProjectService._generate_unique_name_and_key( - db, org_id, base_name, base_key - ) + unique_name, unique_key = SampleProjectService._generate_unique_name_and_key(db, org_id, base_name, base_key) # Build ref->id mappings as we create entities ref_map: Dict[str, UUID] = { @@ -133,9 +128,7 @@ def create_sample_project( } # 1. 
Create sample users (Jane, Max) with is_active=False - sample_user_map = SampleProjectService._create_sample_users( - db, org_id, seed_data.get("sample_users", []) - ) + sample_user_map = SampleProjectService._create_sample_users(db, org_id, seed_data.get("sample_users", [])) ref_map.update(sample_user_map) # 2. Create project with is_sample=True @@ -146,9 +139,7 @@ def create_sample_project( # 3. Create containers for container_data in seed_data.get("containers", []): - container = SampleProjectService._create_container( - db, project.id, container_data - ) + container = SampleProjectService._create_container(db, project.id, container_data) ref_map[container_data["ref"]] = container.id # 4. Create phases (with container linkage) @@ -156,9 +147,7 @@ def create_sample_project( for phase_data in seed_data.get("phases", []): container_ref = phase_data.get("container_ref") container_id = ref_map.get(container_ref) if container_ref else None - phase = SampleProjectService._create_phase( - db, project.id, user_id, phase_data, container_id=container_id - ) + phase = SampleProjectService._create_phase(db, project.id, user_id, phase_data, container_id=container_id) ref_map[phase_data["ref"]] = phase.id created_phases.append(phase) @@ -167,43 +156,33 @@ def create_sample_project( phase_ref = spec_data.get("phase_ref") phase_id = ref_map.get(phase_ref) if phase_id: - SampleProjectService._create_spec_draft( - db, phase_id, user_id, spec_data, SpecType.SPECIFICATION - ) + SampleProjectService._create_spec_draft(db, phase_id, user_id, spec_data, SpecType.SPECIFICATION) # 5. Create prompt plan drafts for phases for plan_data in seed_data.get("prompt_plans", []): phase_ref = plan_data.get("phase_ref") phase_id = ref_map.get(phase_ref) if phase_id: - SampleProjectService._create_spec_draft( - db, phase_id, user_id, plan_data, SpecType.PROMPT_PLAN - ) + SampleProjectService._create_spec_draft(db, phase_id, user_id, plan_data, SpecType.PROMPT_PLAN) # 6. Create modules for module_data in seed_data.get("modules", []): phase_ref = module_data.get("phase_ref") phase_id = ref_map.get(phase_ref) if phase_ref else None - module = SampleProjectService._create_module( - db, project.id, phase_id, user_id, module_data - ) + module = SampleProjectService._create_module(db, project.id, phase_id, user_id, module_data) ref_map[module_data["ref"]] = module.id # 7. Create features for feature_data in seed_data.get("features", []): module_id = ref_map[feature_data["module_ref"]] - feature = SampleProjectService._create_feature( - db, project.id, module_id, user_id, feature_data - ) + feature = SampleProjectService._create_feature(db, project.id, module_id, user_id, feature_data) ref_map[feature_data["ref"]] = feature.id # 8. Create threads for thread_data in seed_data.get("threads", []): feature_ref = thread_data.get("feature_ref") feature_id = ref_map.get(feature_ref) if feature_ref else None - thread = SampleProjectService._create_thread( - db, project.id, feature_id, user_id, thread_data - ) + thread = SampleProjectService._create_thread(db, project.id, feature_id, user_id, thread_data) ref_map[thread_data["ref"]] = thread.id # 9. 
Create thread items @@ -211,21 +190,15 @@ def create_sample_project( signed_up_user = db.query(User).filter(User.id == user_id).first() for item_data in seed_data.get("thread_items", []): thread_id = ref_map[item_data["thread_ref"]] - author_id = SampleProjectService._resolve_author( - item_data.get("author_ref", "mfbtai"), ref_map - ) - SampleProjectService._create_thread_item( - db, str(thread_id), str(author_id), item_data, signed_up_user - ) + author_id = SampleProjectService._resolve_author(item_data.get("author_ref", "mfbtai"), ref_map) + SampleProjectService._create_thread_item(db, str(thread_id), str(author_id), item_data, signed_up_user) db.commit() logger.info(f"Created sample project {project.id} for org {org_id}") return project, created_phases @staticmethod - def _create_sample_users( - db: Session, org_id: UUID, sample_users: list - ) -> Dict[str, UUID]: + def _create_sample_users(db: Session, org_id: UUID, sample_users: list) -> Dict[str, UUID]: """ Create sample users with is_active=False. @@ -283,8 +256,8 @@ def _create_project( unique_key: str, ) -> Project: """Create the sample project.""" - from app.services.project_share_service import ProjectShareService from app.models.project_share import ProjectRole + from app.services.project_share_service import ProjectShareService project = Project( id=uuid4(), @@ -315,9 +288,7 @@ def _create_project( return project @staticmethod - def _create_container( - db: Session, project_id: UUID, data: Dict[str, Any] - ) -> PhaseContainer: + def _create_container(db: Session, project_id: UUID, data: Dict[str, Any]) -> PhaseContainer: """Create a phase container.""" container = PhaseContainer( project_id=project_id, @@ -649,9 +620,7 @@ def _create_thread( decision_summary_short=data.get("decision_summary_short"), unresolved_points=data.get("unresolved_points"), suggested_implementation_name=data.get("suggested_implementation_name"), - show_create_implementation_button=data.get( - "show_create_implementation_button", False - ), + show_create_implementation_button=data.get("show_create_implementation_button", False), created_by=str(user_id), created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), @@ -685,11 +654,7 @@ def _create_thread_item( content_data = data.get("content_data", {}).copy() # If this is an MCQ with a selected_option_id, populate answered_by/answered_at - if ( - item_type == ThreadItemType.MCQ_FOLLOWUP.value - and content_data.get("selected_option_id") - and signed_up_user - ): + if item_type == ThreadItemType.MCQ_FOLLOWUP.value and content_data.get("selected_option_id") and signed_up_user: answered_at = datetime.now(timezone.utc) content_data["answered_at"] = answered_at.isoformat() content_data["answered_by"] = { diff --git a/backend/app/services/spec_service.py b/backend/app/services/spec_service.py index e7a39e0..a9f3e36 100644 --- a/backend/app/services/spec_service.py +++ b/backend/app/services/spec_service.py @@ -2,9 +2,10 @@ from typing import Optional from uuid import UUID + from sqlalchemy.orm import Session -from app.models.spec_version import SpecVersion, SpecType +from app.models.spec_version import SpecType, SpecVersion class SpecService: diff --git a/backend/app/services/team_role_service.py b/backend/app/services/team_role_service.py index 6bec672..08dcd60 100644 --- a/backend/app/services/team_role_service.py +++ b/backend/app/services/team_role_service.py @@ -8,12 +8,12 @@ from sqlalchemy.orm import Session from app.models import ( - TeamRoleDefinition, - ProjectTeamAssignment, 
DEFAULT_TEAM_ROLES, + OrgMembership, Project, + ProjectTeamAssignment, + TeamRoleDefinition, User, - OrgMembership, ) @@ -85,9 +85,7 @@ def get_org_role_definitions(db: Session, org_id: UUID) -> list[TeamRoleDefiniti ) @staticmethod - def get_role_definition( - db: Session, role_id: UUID - ) -> Optional[TeamRoleDefinition]: + def get_role_definition(db: Session, role_id: UUID) -> Optional[TeamRoleDefinition]: """Get a specific role definition by ID. Args: @@ -97,16 +95,10 @@ def get_role_definition( Returns: TeamRoleDefinition or None if not found """ - return ( - db.query(TeamRoleDefinition) - .filter(TeamRoleDefinition.id == role_id) - .first() - ) + return db.query(TeamRoleDefinition).filter(TeamRoleDefinition.id == role_id).first() @staticmethod - def get_role_definition_by_key( - db: Session, org_id: UUID, role_key: str - ) -> Optional[TeamRoleDefinition]: + def get_role_definition_by_key(db: Session, org_id: UUID, role_key: str) -> Optional[TeamRoleDefinition]: """Get a specific role definition by org and key. Args: @@ -138,9 +130,7 @@ def get_next_order_index(db: Session, org_id: UUID) -> int: int: Next order index """ max_index = ( - db.query(func.max(TeamRoleDefinition.order_index)) - .filter(TeamRoleDefinition.org_id == org_id) - .scalar() + db.query(func.max(TeamRoleDefinition.order_index)).filter(TeamRoleDefinition.org_id == org_id).scalar() ) return (max_index or 0) + 1 @@ -277,9 +267,7 @@ def delete_role_definition(db: Session, role_id: UUID) -> None: db.commit() @staticmethod - def reset_role_to_default( - db: Session, role_id: UUID - ) -> TeamRoleDefinition: + def reset_role_to_default(db: Session, role_id: UUID) -> TeamRoleDefinition: """Reset a role to its default title/description. Args: @@ -379,7 +367,7 @@ def assign_user_to_role( .first() ) if existing: - raise ValueError(f"User already has this role on this project") + raise ValueError("User already has this role on this project") assignment = ProjectTeamAssignment( project_id=project_id, @@ -403,11 +391,7 @@ def remove_assignment_by_id(db: Session, assignment_id: UUID) -> None: Raises: ValueError: If assignment not found """ - assignment = ( - db.query(ProjectTeamAssignment) - .filter(ProjectTeamAssignment.id == assignment_id) - .first() - ) + assignment = db.query(ProjectTeamAssignment).filter(ProjectTeamAssignment.id == assignment_id).first() if not assignment: raise ValueError("Assignment not found") @@ -415,9 +399,7 @@ def remove_assignment_by_id(db: Session, assignment_id: UUID) -> None: db.commit() @staticmethod - def get_project_team( - db: Session, project_id: UUID - ) -> dict[UUID, list[ProjectTeamAssignment]]: + def get_project_team(db: Session, project_id: UUID) -> dict[UUID, list[ProjectTeamAssignment]]: """Get all team assignments for a project, grouped by role definition ID. 
Args: @@ -427,11 +409,7 @@ def get_project_team( Returns: dict mapping role definition IDs to lists of assignments """ - assignments = ( - db.query(ProjectTeamAssignment) - .filter(ProjectTeamAssignment.project_id == project_id) - .all() - ) + assignments = db.query(ProjectTeamAssignment).filter(ProjectTeamAssignment.project_id == project_id).all() # Group by role definition result: dict[UUID, list[ProjectTeamAssignment]] = {} @@ -444,9 +422,7 @@ def get_project_team( return result @staticmethod - def get_project_team_assignments( - db: Session, project_id: UUID - ) -> list[ProjectTeamAssignment]: + def get_project_team_assignments(db: Session, project_id: UUID) -> list[ProjectTeamAssignment]: """Get all team assignments for a project. Args: @@ -456,16 +432,10 @@ def get_project_team_assignments( Returns: list of all assignments """ - return ( - db.query(ProjectTeamAssignment) - .filter(ProjectTeamAssignment.project_id == project_id) - .all() - ) + return db.query(ProjectTeamAssignment).filter(ProjectTeamAssignment.project_id == project_id).all() @staticmethod - def get_user_roles_on_project( - db: Session, project_id: UUID, user_id: UUID - ) -> list[TeamRoleDefinition]: + def get_user_roles_on_project(db: Session, project_id: UUID, user_id: UUID) -> list[TeamRoleDefinition]: """Get all roles a user has on a project. Args: diff --git a/backend/app/services/thread_service.py b/backend/app/services/thread_service.py index 86580b0..dade900 100644 --- a/backend/app/services/thread_service.py +++ b/backend/app/services/thread_service.py @@ -1,15 +1,16 @@ """Service layer for thread and comment operations.""" + import logging from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from uuid import UUID -from sqlalchemy import func + from sqlalchemy.orm import Session, joinedload -from app.config import settings -from app.models import Thread, Comment, ContextType, Project, User -from app.models.thread_item import ThreadItem, ThreadItemType + +from app.models import Comment, ContextType, Project, Thread, User from app.models.feature import Feature -from app.models.job import Job, JobType, JobStatus +from app.models.job import Job, JobStatus, JobType +from app.models.thread_item import ThreadItem, ThreadItemType from app.services.mention_utils import extract_user_mentions, has_mfbtai_mention logger = logging.getLogger(__name__) @@ -33,6 +34,7 @@ def _touch_linked_feature(db: Session, thread: Thread) -> None: try: # Convert string context_id to UUID for Feature query from uuid import UUID as PyUUID + feature_id = PyUUID(thread.context_id) if isinstance(thread.context_id, str) else thread.context_id feature = db.query(Feature).filter(Feature.id == feature_id).first() if feature: @@ -52,12 +54,7 @@ def _should_show_create_implementation_button(db: Session, thread_id: str) -> bo there is conversation content. 
""" # Get all items ordered by created_at - items = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == thread_id) - .order_by(ThreadItem.created_at) - .all() - ) + items = db.query(ThreadItem).filter(ThreadItem.thread_id == thread_id).order_by(ThreadItem.created_at).all() if not items: return False # No conversation at all @@ -126,9 +123,7 @@ def get_thread_by_id(db: Session, thread_id: str) -> Optional[Thread]: return ( db.query(Thread) .filter(Thread.id == thread_id) - .options( - joinedload(Thread.comments).joinedload(Comment.author) - ) + .options(joinedload(Thread.comments).joinedload(Comment.author)) .first() ) @@ -186,12 +181,7 @@ def create_comment( db.flush() # Reload comment with author information - comment = ( - db.query(Comment) - .filter(Comment.id == comment.id) - .options(joinedload(Comment.author)) - .first() - ) + comment = db.query(Comment).filter(Comment.id == comment.id).options(joinedload(Comment.author)).first() # Broadcast comment creation via WebSocket # TODO: Fix UUID serialization issue in broadcast @@ -258,8 +248,9 @@ def _broadcast_thread_update(db: Session, thread: Thread, event_type: str): event_type: Type of event (e.g., "thread_created") """ import logging - from app.services.kafka_producer import get_sync_kafka_producer + from app.schemas.thread import ThreadListResponse + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -290,10 +281,7 @@ def _broadcast_thread_update(db: Session, thread: Thread, event_type: str): ) if success: - logger.info( - f"Broadcasted thread update via Kafka: thread_id={thread.id}, " - f"event={event_type}" - ) + logger.info(f"Broadcasted thread update via Kafka: thread_id={thread.id}, event={event_type}") except Exception as e: logger.error(f"Failed to broadcast thread update: {e}", exc_info=True) @@ -309,8 +297,9 @@ def _broadcast_comment_update(db: Session, comment: Comment, event_type: str): event_type: Type of event (e.g., "comment_created", "comment_updated", "comment_deleted") """ import logging - from app.services.kafka_producer import get_sync_kafka_producer + from app.schemas.thread import CommentResponse + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -346,10 +335,7 @@ def _broadcast_comment_update(db: Session, comment: Comment, event_type: str): ) if success: - logger.info( - f"Broadcasted comment update via Kafka: comment_id={comment.id}, " - f"event={event_type}" - ) + logger.info(f"Broadcasted comment update via Kafka: comment_id={comment.id}, event={event_type}") except Exception as e: logger.error(f"Failed to broadcast comment update: {e}", exc_info=True) @@ -379,9 +365,10 @@ def create_comment_item( body_markdown: Comment text in markdown format images: Optional list of image attachment metadata dicts """ - from app.services.agent_utils import AGENT_EMAIL, get_or_create_agent_user from uuid import UUID as PyUUID + from app.services.agent_utils import AGENT_EMAIL + # Verify thread exists thread = db.query(Thread).filter(Thread.id == thread_id).first() if not thread: @@ -423,12 +410,7 @@ def create_comment_item( db.refresh(item) # Reload item with author information - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item.id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item.id).options(joinedload(ThreadItem.author)).first() # Broadcast item creation via WebSocket ThreadService._broadcast_thread_item_update(db, item, 
"thread_item_created") @@ -437,22 +419,13 @@ def create_comment_item( # Skip if message only tags @MFBTAI (no other users mentioned) # because the conversation is just starting an AI interaction try: - only_mfbtai_mentioned = ( - has_mfbtai_mention(body_markdown) - and not extract_user_mentions(body_markdown) - ) + only_mfbtai_mentioned = has_mfbtai_mention(body_markdown) and not extract_user_mentions(body_markdown) if only_mfbtai_mentioned: - logger.debug( - f"Skipping decision summary for thread {thread_id}: only @MFBTAI mentioned" - ) + logger.debug(f"Skipping decision summary for thread {thread_id}: only @MFBTAI mentioned") else: - ThreadService.trigger_decision_summary( - db, thread_id, triggered_by_user_id=UUID(author_id) - ) + ThreadService.trigger_decision_summary(db, thread_id, triggered_by_user_id=UUID(author_id)) except Exception as e: - logger.warning( - f"Failed to trigger decision summary for thread {thread_id}: {e}" - ) + logger.warning(f"Failed to trigger decision summary for thread {thread_id}: {e}") # Trigger mention notifications (async, non-blocking) # Only for user comments, not agent comments @@ -460,17 +433,11 @@ def create_comment_item( try: mentioned_user_ids = extract_user_mentions(body_markdown) # Exclude self-mentions - mentioned_user_ids = [ - uid for uid in mentioned_user_ids if str(uid) != str(author_id) - ] + mentioned_user_ids = [uid for uid in mentioned_user_ids if str(uid) != str(author_id)] if mentioned_user_ids: - ThreadService._trigger_mention_notifications( - db, thread, item, author_id, mentioned_user_ids - ) + ThreadService._trigger_mention_notifications(db, thread, item, author_id, mentioned_user_ids) except Exception as e: - logger.warning( - f"Failed to trigger mention notifications for thread {thread_id}: {e}" - ) + logger.warning(f"Failed to trigger mention notifications for thread {thread_id}: {e}") return item @@ -534,10 +501,7 @@ def create_mcq_item( # Append standard choices if not already present (LLM may include them) existing_ids = {c.get("id") for c in choices} - all_choices = choices + [ - c for c in ThreadService.STANDARD_MCQ_CHOICES - if c["id"] not in existing_ids - ] + all_choices = choices + [c for c in ThreadService.STANDARD_MCQ_CHOICES if c["id"] not in existing_ids] # Build content_data content_data = { @@ -566,12 +530,7 @@ def create_mcq_item( db.refresh(item) # Reload item with author information - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item.id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item.id).options(joinedload(ThreadItem.author)).first() # Broadcast item creation via WebSocket ThreadService._broadcast_thread_item_update(db, item, "thread_item_created") @@ -650,6 +609,7 @@ def answer_mcq_item( HTTPException: If changing answer with downstream items and force_change=False """ import logging + from fastapi import HTTPException from sqlalchemy.orm.attributes import flag_modified @@ -704,9 +664,7 @@ def answer_mcq_item( deleted_item_ids = [str(i.id) for i in downstream_items] # Broadcast bulk deletion before deleting - ThreadService._broadcast_thread_items_bulk_deleted( - db, str(item.thread_id), deleted_item_ids - ) + ThreadService._broadcast_thread_items_bulk_deleted(db, str(item.thread_id), deleted_item_ids) # Delete downstream items for downstream_item in downstream_items: @@ -718,9 +676,7 @@ def answer_mcq_item( thread.unresolved_points = None thread.last_summarized_item_id = None - logger.info( - f"Deleted {len(deleted_item_ids)} 
downstream items for MCQ answer change: {item_id}" - ) + logger.info(f"Deleted {len(deleted_item_ids)} downstream items for MCQ answer change: {item_id}") # Update content_data with answer item.content_data["selected_option_id"] = selected_option_id @@ -805,19 +761,13 @@ def answer_mcq_item( from app.services.brainstorming_phase_service import BrainstormingPhaseService # Get feature → module → phase chain - feature = db.query(Feature).filter( - Feature.id == UUID(thread.context_id) - ).first() + feature = db.query(Feature).filter(Feature.id == UUID(thread.context_id)).first() if feature: module = db.query(Module).filter(Module.id == feature.module_id).first() if module and module.brainstorming_phase_id: - BrainstormingPhaseService.refresh_phase_question_stats( - db, module.brainstorming_phase_id - ) + BrainstormingPhaseService.refresh_phase_question_stats(db, module.brainstorming_phase_id) except Exception as e: - logger.warning( - f"Failed to refresh phase question stats after MCQ answer: {e}" - ) + logger.warning(f"Failed to refresh phase question stats after MCQ answer: {e}") # Broadcast item update via WebSocket ThreadService._broadcast_thread_item_update(db, item, "thread_item_updated") @@ -830,9 +780,7 @@ def answer_mcq_item( triggered_by_user_id=UUID(answerer_id) if answerer_id else None, ) except Exception as e: - logger.warning( - f"Failed to trigger decision summary for thread {item.thread_id}: {e}" - ) + logger.warning(f"Failed to trigger decision summary for thread {item.thread_id}: {e}") # Trigger MFBTAI if this was an MFBTAI-generated MCQ source = item.content_data.get("source") @@ -840,9 +788,7 @@ def answer_mcq_item( try: ThreadService._trigger_mfbtai_on_mcq_answer(db, item, answerer_id) except Exception as e: - logger.warning( - f"Failed to trigger MFBTAI for MCQ answer {item_id}: {e}" - ) + logger.warning(f"Failed to trigger MFBTAI for MCQ answer {item_id}: {e}") return item @@ -884,17 +830,12 @@ def _trigger_mfbtai_on_mcq_answer( pending_mfbtai_mcqs = [ item for item in all_mcqs_in_thread - if ( - item.content_data.get("source") == "mfbtai" - and not item.content_data.get("selected_option_id") - ) + if (item.content_data.get("source") == "mfbtai" and not item.content_data.get("selected_option_id")) ] if pending_mfbtai_mcqs: # Other MFBTAI MCQs still pending - don't trigger yet - logger.info( - f"Thread {thread.id} has {len(pending_mfbtai_mcqs)} pending MFBTAI MCQs - deferring trigger" - ) + logger.info(f"Thread {thread.id} has {len(pending_mfbtai_mcqs)} pending MFBTAI MCQs - deferring trigger") return # Get feature for context @@ -935,6 +876,7 @@ def _trigger_mfbtai_on_mcq_answer( # Update payload with job_id and mark as modified for SQLAlchemy from sqlalchemy.orm.attributes import flag_modified + job.payload["job_id"] = str(job.id) flag_modified(job, "payload") db.commit() @@ -945,9 +887,7 @@ def _trigger_mfbtai_on_mcq_answer( topic="mfbt.collab_thread.ai_mention", message=job.payload, ) - logger.info( - f"Triggered MFBTAI for MCQ answer: thread={thread.id}, job={job.id}" - ) + logger.info(f"Triggered MFBTAI for MCQ answer: thread={thread.id}, job={job.id}") except Exception as e: # Job is created, just Kafka failed - log and continue logger.error(f"Failed to publish MFBTAI job to Kafka: {e}") @@ -1031,6 +971,7 @@ def delete_thread_item(db: Session, item_id: str) -> None: ThreadService.trigger_decision_summary(db, thread_id) except Exception as e: import logging + logging.getLogger(__name__).warning( f"Failed to trigger decision summary after delete for thread 
{thread_id}: {e}" ) @@ -1063,12 +1004,7 @@ def toggle_reaction( """ from app.models.thread_item import ThreadItemType - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item_id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item_id).options(joinedload(ThreadItem.author)).first() if not item: raise ValueError(f"Thread item {item_id} not found") @@ -1089,12 +1025,13 @@ def toggle_reaction( # Get user display name for the tooltip from uuid import UUID as PyUUID + from app.models import User + user_uuid = PyUUID(user_id) if isinstance(user_id, str) else user_id reacting_user = db.query(User).filter(User.id == user_uuid).first() user_display_name = ( - reacting_user.display_name or reacting_user.email.split("@")[0] - if reacting_user else "Unknown" + reacting_user.display_name or reacting_user.email.split("@")[0] if reacting_user else "Unknown" ) if reaction_index is not None: @@ -1126,13 +1063,15 @@ def toggle_reaction( reaction["count"] = len(user_ids) else: # New emoji reaction - reactions.append({ - "emoji": emoji, - "emoji_native": emoji_native, - "user_ids": [user_id], - "user_names": {user_id: user_display_name}, - "count": 1, - }) + reactions.append( + { + "emoji": emoji, + "emoji_native": emoji_native, + "user_ids": [user_id], + "user_names": {user_id: user_display_name}, + "count": 1, + } + ) # Update content_data (create new dict to trigger SQLAlchemy change detection) content_data["reactions"] = reactions @@ -1141,6 +1080,7 @@ def toggle_reaction( # Mark column as modified to ensure SQLAlchemy saves the change from sqlalchemy.orm.attributes import flag_modified + flag_modified(item, "content_data") db.commit() @@ -1178,15 +1118,11 @@ def start_over_from_item( PermissionError: If user is not the author of the starting item """ import logging + logger = logging.getLogger(__name__) # Get the target item with author loaded - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item_id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item_id).options(joinedload(ThreadItem.author)).first() if not item: raise ValueError(f"Thread item {item_id} not found") @@ -1232,9 +1168,7 @@ def start_over_from_item( if implementations_to_delete and thread.context_type == ContextType.BRAINSTORM_FEATURE: feature_id_for_broadcast = UUID(thread.context_id) total_implementations = ( - db.query(Implementation) - .filter(Implementation.feature_id == feature_id_for_broadcast) - .count() + db.query(Implementation).filter(Implementation.feature_id == feature_id_for_broadcast).count() ) if len(implementations_to_delete) >= total_implementations: @@ -1264,15 +1198,11 @@ def start_over_from_item( deleted_count = len(items_to_delete) # Broadcast bulk deletion before deleting - ThreadService._broadcast_thread_items_bulk_deleted( - db, thread_id, deleted_item_ids - ) + ThreadService._broadcast_thread_items_bulk_deleted(db, thread_id, deleted_item_ids) # Broadcast implementation deletions if deleted_impl_ids and feature_id_for_broadcast: - ImplementationService.broadcast_implementations_deleted( - db, deleted_impl_ids, feature_id_for_broadcast - ) + ImplementationService.broadcast_implementations_deleted(db, deleted_impl_ids, feature_id_for_broadcast) # Find the item immediately before the first deleted one (for snapshot recovery) prev_item = ( @@ -1310,7 +1240,9 @@ def start_over_from_item( # Re-evaluate button visibility and clear suggested name after deletion # The button 
should only show if there are substantive items after the last impl marker - thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button(db, thread_id) + thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button( + db, thread_id + ) thread.suggested_implementation_name = None db.commit() @@ -1331,37 +1263,21 @@ def start_over_from_item( # 2. We did a full reset (no snapshot was available) # When snapshot is available, all items are already summarized if prev_item and prev_item.summary_snapshot is None: - remaining_count = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == thread_id) - .count() - ) + remaining_count = db.query(ThreadItem).filter(ThreadItem.thread_id == thread_id).count() if remaining_count > 0: try: - ThreadService.trigger_decision_summary( - db, thread_id, triggered_by_user_id=UUID(user_id) - ) + ThreadService.trigger_decision_summary(db, thread_id, triggered_by_user_id=UUID(user_id)) except Exception as e: - logger.warning( - f"Failed to trigger decision summary after start-over for thread {thread_id}: {e}" - ) + logger.warning(f"Failed to trigger decision summary after start-over for thread {thread_id}: {e}") elif prev_item is None: # No previous item means all items were deleted or we deleted from first item # Check if any items remain (shouldn't happen but be safe) - remaining_count = ( - db.query(ThreadItem) - .filter(ThreadItem.thread_id == thread_id) - .count() - ) + remaining_count = db.query(ThreadItem).filter(ThreadItem.thread_id == thread_id).count() if remaining_count > 0: try: - ThreadService.trigger_decision_summary( - db, thread_id, triggered_by_user_id=UUID(user_id) - ) + ThreadService.trigger_decision_summary(db, thread_id, triggered_by_user_id=UUID(user_id)) except Exception as e: - logger.warning( - f"Failed to trigger decision summary after start-over for thread {thread_id}: {e}" - ) + logger.warning(f"Failed to trigger decision summary after start-over for thread {thread_id}: {e}") logger.info( f"Start over from item {item_id}: deleted {deleted_count} items, " @@ -1388,6 +1304,7 @@ def _broadcast_thread_items_bulk_deleted( deleted_item_ids: List of deleted item IDs """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1423,8 +1340,7 @@ def _broadcast_thread_items_bulk_deleted( if success: logger.info( - f"Broadcasted bulk deletion via Kafka: thread_id={thread_id}, " - f"count={len(deleted_item_ids)}" + f"Broadcasted bulk deletion via Kafka: thread_id={thread_id}, count={len(deleted_item_ids)}" ) except Exception as e: @@ -1438,8 +1354,9 @@ def _broadcast_thread_item_update(db: Session, item: ThreadItem, event_type: str Similar to _broadcast_comment_update but for thread items. 
""" import logging - from app.services.kafka_producer import get_sync_kafka_producer + from app.schemas.thread_item import thread_item_to_response + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1477,10 +1394,7 @@ def _broadcast_thread_item_update(db: Session, item: ThreadItem, event_type: str ) if success: - logger.info( - f"Broadcasted thread item update via Kafka: item_id={item.id}, " - f"event={event_type}" - ) + logger.info(f"Broadcasted thread item update via Kafka: item_id={item.id}, event={event_type}") except Exception as e: logger.error(f"Failed to broadcast thread item update: {e}", exc_info=True) @@ -1511,6 +1425,7 @@ def broadcast_decision_summary_update( show_create_implementation_button: Whether to show Create Implementation button """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1572,6 +1487,7 @@ def broadcast_thread_updated( thread: The Thread object that was updated """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1605,8 +1521,7 @@ def broadcast_thread_updated( if success: logger.debug( - f"Broadcasted thread update via Kafka: thread_id={thread.id}, " - f"retry_status={thread.retry_status}" + f"Broadcasted thread update via Kafka: thread_id={thread.id}, retry_status={thread.retry_status}" ) except Exception as e: @@ -1634,6 +1549,7 @@ def set_ai_error_state( user_message: The user's original message (for retry) """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1690,6 +1606,7 @@ def clear_ai_error_state(db: Session, thread_id: str) -> None: thread_id: ID of the thread to clear error state """ import logging + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1792,12 +1709,7 @@ def list_threads_for_version( Returns: List of Thread objects for the version """ - return ( - db.query(Thread) - .filter(Thread.version_id == version_id) - .order_by(Thread.created_at) - .all() - ) + return db.query(Thread).filter(Thread.version_id == version_id).order_by(Thread.created_at).all() @staticmethod def get_block_thread( @@ -1875,8 +1787,9 @@ def trigger_decision_summary( The created Job if one was created, None if skipped. 
""" import logging - from app.services.kafka_producer import get_sync_kafka_producer + from app.services.job_service import JobService + from app.services.kafka_producer import get_sync_kafka_producer logger = logging.getLogger(__name__) @@ -1900,18 +1813,14 @@ def trigger_decision_summary( db.query(ThreadItem) .filter( ThreadItem.thread_id == thread_id, - ThreadItem.created_at > ( - db.query(ThreadItem.created_at) - .filter(ThreadItem.id == last_summarized) - .scalar_subquery() - ) + ThreadItem.created_at + > (db.query(ThreadItem.created_at).filter(ThreadItem.id == last_summarized).scalar_subquery()), ) .count() ) if unprocessed_count == 0: logger.debug( - f"Skipping decision summary for thread {thread_id}: " - f"no unprocessed items since {last_summarized}" + f"Skipping decision summary for thread {thread_id}: no unprocessed items since {last_summarized}" ) return None @@ -1923,7 +1832,7 @@ def trigger_decision_summary( Job.job_type == JobType.COLLAB_THREAD_DECISION_SUMMARIZE, Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]), ) - .filter(Job.payload.op('->>')('thread_id') == str(thread_id)) + .filter(Job.payload.op("->>")("thread_id") == str(thread_id)) .first() ) @@ -2002,15 +1911,13 @@ def _trigger_mention_notifications( author_id: ID of the user who made the mention mentioned_user_ids: List of user UUIDs to notify """ - from app.services.kafka_producer import get_sync_kafka_producer from app.services.job_service import JobService + from app.services.kafka_producer import get_sync_kafka_producer # Get project for org_id project = db.query(Project).filter(Project.id == thread.project_id).first() if not project: - logger.warning( - f"Cannot trigger mention notifications: project {thread.project_id} not found" - ) + logger.warning(f"Cannot trigger mention notifications: project {thread.project_id} not found") return # Create the job @@ -2048,9 +1955,7 @@ def _trigger_mention_notifications( key=str(thread.id), ) - logger.info( - f"Triggered mention notification job {job.id} for {len(mentioned_user_ids)} users" - ) + logger.info(f"Triggered mention notification job {job.id} for {len(mentioned_user_ids)} users") except Exception as e: logger.error(f"Failed to publish mention notification job to Kafka: {e}") @@ -2276,7 +2181,6 @@ def create_code_exploration_item( ValueError: If thread not found """ from app.services.agent_utils import get_or_create_agent_user - from app.models.code_exploration_result import CodeExplorationResult thread = db.query(Thread).filter(Thread.id == thread_id).first() if not thread: @@ -2292,9 +2196,7 @@ def create_code_exploration_item( "branch": exploration_result.branch, "repo_url": exploration_result.repo_url, "execution_time_seconds": ( - float(exploration_result.execution_time_seconds) - if exploration_result.execution_time_seconds - else None + float(exploration_result.execution_time_seconds) if exploration_result.execution_time_seconds else None ), "prompt_tokens": exploration_result.prompt_tokens, "completion_tokens": exploration_result.completion_tokens, @@ -2317,12 +2219,7 @@ def create_code_exploration_item( db.refresh(item) # Reload with author - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item.id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item.id).options(joinedload(ThreadItem.author)).first() # Broadcast the new thread item ThreadService._broadcast_thread_item_update(db, item, "thread_item_created") @@ -2435,12 +2332,7 @@ def create_web_search_item( db.refresh(item) 
# Reload with author - item = ( - db.query(ThreadItem) - .filter(ThreadItem.id == item.id) - .options(joinedload(ThreadItem.author)) - .first() - ) + item = db.query(ThreadItem).filter(ThreadItem.id == item.id).options(joinedload(ThreadItem.author)).first() # Broadcast the new thread item ThreadService._broadcast_thread_item_update(db, item, "thread_item_created") @@ -2515,7 +2407,6 @@ def create_project_chat_thread( The created Thread instance """ from app.utils.short_id import generate_short_id - from app.models.thread import ProjectChatVisibility thread = Thread( project_id=project_id, @@ -2549,7 +2440,7 @@ def get_project_chat_by_identifier( Returns: The Thread instance or None """ - from app.utils.short_id import is_uuid, extract_short_id + from app.utils.short_id import extract_short_id, is_uuid # Try UUID first if is_uuid(identifier): @@ -2611,13 +2502,7 @@ def list_project_chat_threads( if created_by is not None: query = query.filter(Thread.created_by == created_by) - return ( - query - .order_by(Thread.updated_at.desc()) - .offset(offset) - .limit(limit) - .all() - ) + return query.order_by(Thread.updated_at.desc()).offset(offset).limit(limit).all() @staticmethod def update_project_chat_proposal_state( diff --git a/backend/app/services/typing_indicator_service.py b/backend/app/services/typing_indicator_service.py index ab2f158..8e274e4 100644 --- a/backend/app/services/typing_indicator_service.py +++ b/backend/app/services/typing_indicator_service.py @@ -8,7 +8,7 @@ import json import logging from datetime import datetime, timezone -from typing import List, Dict, Optional +from typing import Dict, List, Optional from app.services.analytics_cache import get_redis_client from app.services.kafka_producer import get_sync_kafka_producer @@ -67,11 +67,13 @@ def set_typing( try: key = TypingIndicatorService._get_key(thread_id, user_id) - value = json.dumps({ - "user_id": user_id, - "user_name": user_name, - "started_at": datetime.now(timezone.utc).isoformat(), - }) + value = json.dumps( + { + "user_id": user_id, + "user_name": user_name, + "started_at": datetime.now(timezone.utc).isoformat(), + } + ) # Set with TTL - will auto-expire client.setex(key, TypingIndicatorService.TYPING_TTL_SECONDS, value) @@ -164,10 +166,12 @@ def get_typers(thread_id: str) -> List[Dict[str, str]]: value = client.get(key) if value: data = json.loads(value) - typers.append({ - "user_id": data["user_id"], - "user_name": data["user_name"], - }) + typers.append( + { + "user_id": data["user_id"], + "user_name": data["user_name"], + } + ) return typers diff --git a/backend/app/services/user_group_service.py b/backend/app/services/user_group_service.py index 3d59d79..6b6a9b9 100644 --- a/backend/app/services/user_group_service.py +++ b/backend/app/services/user_group_service.py @@ -66,9 +66,7 @@ def get_group_by_id(db: Session, group_id: UUID) -> Optional[UserGroup]: return db.query(UserGroup).filter(UserGroup.id == group_id).first() @staticmethod - def get_group_by_name( - db: Session, org_id: UUID, name: str - ) -> Optional[UserGroup]: + def get_group_by_name(db: Session, org_id: UUID, name: str) -> Optional[UserGroup]: """Get a group by name within an organization. 
Args: @@ -79,11 +77,7 @@ def get_group_by_name( Returns: UserGroup or None if not found """ - return ( - db.query(UserGroup) - .filter(UserGroup.org_id == org_id, UserGroup.name == name) - .first() - ) + return db.query(UserGroup).filter(UserGroup.org_id == org_id, UserGroup.name == name).first() @staticmethod def list_org_groups(db: Session, org_id: UUID) -> list[UserGroup]: @@ -96,12 +90,7 @@ def list_org_groups(db: Session, org_id: UUID) -> list[UserGroup]: Returns: List of UserGroup objects """ - return ( - db.query(UserGroup) - .filter(UserGroup.org_id == org_id) - .order_by(UserGroup.name) - .all() - ) + return db.query(UserGroup).filter(UserGroup.org_id == org_id).order_by(UserGroup.name).all() @staticmethod def update_group( diff --git a/backend/app/services/user_question_session_service.py b/backend/app/services/user_question_session_service.py index e585961..c336b6b 100644 --- a/backend/app/services/user_question_session_service.py +++ b/backend/app/services/user_question_session_service.py @@ -1,34 +1,31 @@ """Service for managing user question sessions.""" + import logging -import json -import uuid -from datetime import datetime, timezone -from typing import Optional, List, Dict, Any +from typing import Any, Dict, List, Optional from uuid import UUID from sqlalchemy.orm import Session, joinedload -from app.models.user_question_session import ( - UserQuestionSession, - UserQuestionSessionStatus, - UserQuestionMessage, - MessageRole, -) from app.models.brainstorming_phase import BrainstormingPhase -from app.models.module import Module, ModuleProvenance, ModuleType from app.models.feature import ( Feature, - FeatureProvenance, FeaturePriority, + FeatureProvenance, FeatureType, FeatureVisibilityStatus, ) -from app.models.thread import Thread, ContextType +from app.models.module import Module, ModuleProvenance, ModuleType +from app.models.thread import ContextType, Thread from app.models.thread_item import ThreadItem, ThreadItemType +from app.models.user_question_session import ( + MessageRole, + UserQuestionMessage, + UserQuestionSession, + UserQuestionSessionStatus, +) from app.services.feature_service import FeatureService from app.services.module_service import ModuleService - logger = logging.getLogger(__name__) @@ -38,11 +35,7 @@ class UserQuestionSessionService: MAX_QUESTIONS_PER_SESSION = 5 @staticmethod - def create_session( - db: Session, - brainstorming_phase_id: UUID, - user_id: UUID - ) -> UserQuestionSession: + def create_session(db: Session, brainstorming_phase_id: UUID, user_id: UUID) -> UserQuestionSession: """ Create a new user question session. @@ -64,17 +57,11 @@ def create_session( db.commit() db.refresh(session) - logger.info( - f"Created user question session {session.id} for phase {brainstorming_phase_id}" - ) + logger.info(f"Created user question session {session.id} for phase {brainstorming_phase_id}") return session @staticmethod - def list_sessions( - db: Session, - brainstorming_phase_id: UUID, - user_id: UUID - ) -> List[UserQuestionSession]: + def list_sessions(db: Session, brainstorming_phase_id: UUID, user_id: UUID) -> List[UserQuestionSession]: """ List all sessions for a phase and user. 
@@ -86,17 +73,18 @@ def list_sessions( Returns: List of UserQuestionSession objects """ - return db.query(UserQuestionSession).filter( - UserQuestionSession.brainstorming_phase_id == brainstorming_phase_id, - UserQuestionSession.user_id == user_id - ).order_by(UserQuestionSession.created_at.desc()).all() + return ( + db.query(UserQuestionSession) + .filter( + UserQuestionSession.brainstorming_phase_id == brainstorming_phase_id, + UserQuestionSession.user_id == user_id, + ) + .order_by(UserQuestionSession.created_at.desc()) + .all() + ) @staticmethod - def get_session( - db: Session, - session_id: UUID, - user_id: UUID - ) -> Optional[UserQuestionSession]: + def get_session(db: Session, session_id: UUID, user_id: UUID) -> Optional[UserQuestionSession]: """ Get a session by ID. @@ -108,17 +96,14 @@ def get_session( Returns: UserQuestionSession or None if not found/unauthorized """ - return db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id, - UserQuestionSession.user_id == user_id - ).first() + return ( + db.query(UserQuestionSession) + .filter(UserQuestionSession.id == session_id, UserQuestionSession.user_id == user_id) + .first() + ) @staticmethod - def get_session_with_messages( - db: Session, - session_id: UUID, - user_id: UUID - ) -> Optional[UserQuestionSession]: + def get_session_with_messages(db: Session, session_id: UUID, user_id: UUID) -> Optional[UserQuestionSession]: """ Get a session with all messages. @@ -130,19 +115,16 @@ def get_session_with_messages( Returns: UserQuestionSession with messages loaded, or None """ - return db.query(UserQuestionSession).options( - joinedload(UserQuestionSession.messages) - ).filter( - UserQuestionSession.id == session_id, - UserQuestionSession.user_id == user_id - ).first() + return ( + db.query(UserQuestionSession) + .options(joinedload(UserQuestionSession.messages)) + .filter(UserQuestionSession.id == session_id, UserQuestionSession.user_id == user_id) + .first() + ) @staticmethod def add_user_message( - db: Session, - session_id: UUID, - content: str, - job_id: Optional[UUID] = None + db: Session, session_id: UUID, content: str, job_id: Optional[UUID] = None ) -> UserQuestionMessage: """ Add a user message to the session. @@ -165,9 +147,7 @@ def add_user_message( db.add(message) # Update session title from first message if not set - session = db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id - ).first() + session = db.query(UserQuestionSession).filter(UserQuestionSession.id == session_id).first() if session and not session.title: # Truncate to 100 chars for title session.title = content[:100] if len(content) > 100 else content @@ -182,7 +162,7 @@ def add_assistant_message( session_id: UUID, content: str, generated_questions: Optional[List[Dict[str, Any]]] = None, - job_id: Optional[UUID] = None + job_id: Optional[UUID] = None, ) -> UserQuestionMessage: """ Add an assistant message with generated questions. @@ -210,10 +190,7 @@ def add_assistant_message( return message @staticmethod - def get_session_history( - db: Session, - session_id: UUID - ) -> List[Dict[str, Any]]: + def get_session_history(db: Session, session_id: UUID) -> List[Dict[str, Any]]: """ Get the conversation history for a session. 
@@ -224,21 +201,17 @@ def get_session_history( Returns: List of message dicts with role and content """ - messages = db.query(UserQuestionMessage).filter( - UserQuestionMessage.session_id == session_id - ).order_by(UserQuestionMessage.created_at).all() + messages = ( + db.query(UserQuestionMessage) + .filter(UserQuestionMessage.session_id == session_id) + .order_by(UserQuestionMessage.created_at) + .all() + ) - return [ - {"role": msg.role.value, "content": msg.content} - for msg in messages - ] + return [{"role": msg.role.value, "content": msg.content} for msg in messages] @staticmethod - def can_add_questions( - db: Session, - session_id: UUID, - count: int = 1 - ) -> bool: + def can_add_questions(db: Session, session_id: UUID, count: int = 1) -> bool: """ Check if questions can be added to the session. @@ -250,9 +223,7 @@ def can_add_questions( Returns: True if questions can be added, False otherwise """ - session = db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id - ).first() + session = db.query(UserQuestionSession).filter(UserQuestionSession.id == session_id).first() if not session: return False @@ -261,11 +232,7 @@ def can_add_questions( @staticmethod def add_questions_to_phase( - db: Session, - session_id: UUID, - message_id: UUID, - temp_question_ids: List[str], - user_id: UUID + db: Session, session_id: UUID, message_id: UUID, temp_question_ids: List[str], user_id: UUID ) -> Dict[str, Any]: """ Add selected questions from a message to the phase as ACTIVE features. @@ -281,19 +248,21 @@ def add_questions_to_phase( Dict with added_count, feature_ids, session_limit_reached """ # Get the session - session = db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id, - UserQuestionSession.user_id == user_id - ).first() + session = ( + db.query(UserQuestionSession) + .filter(UserQuestionSession.id == session_id, UserQuestionSession.user_id == user_id) + .first() + ) if not session: raise ValueError("Session not found") # Get the message - message = db.query(UserQuestionMessage).filter( - UserQuestionMessage.id == message_id, - UserQuestionMessage.session_id == session_id - ).first() + message = ( + db.query(UserQuestionMessage) + .filter(UserQuestionMessage.id == message_id, UserQuestionMessage.session_id == session_id) + .first() + ) if not message: raise ValueError("Message not found") @@ -304,25 +273,16 @@ def add_questions_to_phase( # Check limit remaining = UserQuestionSessionService.MAX_QUESTIONS_PER_SESSION - session.questions_added if remaining <= 0: - return { - "added_count": 0, - "feature_ids": [], - "session_limit_reached": True - } + return {"added_count": 0, "feature_ids": [], "session_limit_reached": True} # Filter questions to add - questions_to_add = [ - q for q in message.generated_questions - if q.get("temp_id") in temp_question_ids - ] + questions_to_add = [q for q in message.generated_questions if q.get("temp_id") in temp_question_ids] # Limit to remaining budget questions_to_add = questions_to_add[:remaining] # Get phase for module lookup/creation - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == session.brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == session.brainstorming_phase_id).first() if not phase: raise ValueError("Phase not found") @@ -333,17 +293,19 @@ def add_questions_to_phase( for q in questions_to_add: # Get or create module for the aspect aspect_title = q.get("aspect_title", "User Questions") - module = 
db.query(Module).filter( - Module.brainstorming_phase_id == phase.id, - Module.title == aspect_title, - Module.archived_at.is_(None) - ).first() + module = ( + db.query(Module) + .filter( + Module.brainstorming_phase_id == phase.id, + Module.title == aspect_title, + Module.archived_at.is_(None), + ) + .first() + ) if not module: # Create new module - max_order = db.query(Module).filter( - Module.brainstorming_phase_id == phase.id - ).count() + max_order = db.query(Module).filter(Module.brainstorming_phase_id == phase.id).count() # Generate module key module_key, module_key_number = ModuleService.generate_module_key(db, phase.project_id) @@ -432,14 +394,11 @@ def add_questions_to_phase( return { "added_count": len(added_feature_ids), "feature_ids": added_feature_ids, - "session_limit_reached": session.questions_added >= UserQuestionSessionService.MAX_QUESTIONS_PER_SESSION + "session_limit_reached": session.questions_added >= UserQuestionSessionService.MAX_QUESTIONS_PER_SESSION, } @staticmethod - def build_session_context_for_agent( - db: Session, - session_id: UUID - ) -> Dict[str, Any]: + def build_session_context_for_agent(db: Session, session_id: UUID) -> Dict[str, Any]: """ Build the context needed for the agent pipeline from a session. @@ -450,18 +409,17 @@ def build_session_context_for_agent( Returns: Dict with phase info and session history """ - session = db.query(UserQuestionSession).options( - joinedload(UserQuestionSession.messages) - ).filter( - UserQuestionSession.id == session_id - ).first() + session = ( + db.query(UserQuestionSession) + .options(joinedload(UserQuestionSession.messages)) + .filter(UserQuestionSession.id == session_id) + .first() + ) if not session: raise ValueError("Session not found") - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == session.brainstorming_phase_id - ).first() + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == session.brainstorming_phase_id).first() if not phase: raise ValueError("Phase not found") @@ -500,9 +458,7 @@ def set_ai_error_state( job_id: The failed job ID user_prompt: The user's original prompt (for retry) """ - session = db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id - ).first() + session = db.query(UserQuestionSession).filter(UserQuestionSession.id == session_id).first() if not session: logger.warning(f"Cannot set AI error state: session {session_id} not found") @@ -526,9 +482,7 @@ def clear_ai_error_state(db: Session, session_id: UUID) -> None: db: Database session session_id: ID of the session to clear error state """ - session = db.query(UserQuestionSession).filter( - UserQuestionSession.id == session_id - ).first() + session = db.query(UserQuestionSession).filter(UserQuestionSession.id == session_id).first() if not session: logger.warning(f"Cannot clear AI error state: session {session_id} not found") diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index e22e9e6..a8eeec5 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -3,15 +3,16 @@ Follows the service layer pattern established in job_service.py. 
""" + import secrets -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional from uuid import UUID from sqlalchemy.orm import Session -from app.models.user import User from app.auth.utils import hash_password, verify_password +from app.models.user import User class UserService: @@ -159,10 +160,14 @@ def verify_email_token(db: Session, token: str) -> Optional[User]: """ # Find users with non-expired tokens now = datetime.now(timezone.utc) - users_with_tokens = db.query(User).filter( - User.email_verification_token.isnot(None), - User.email_verification_token_expires_at > now, - ).all() + users_with_tokens = ( + db.query(User) + .filter( + User.email_verification_token.isnot(None), + User.email_verification_token_expires_at > now, + ) + .all() + ) # Try to match the token against each user's hash for user in users_with_tokens: diff --git a/backend/app/utils/deep_link.py b/backend/app/utils/deep_link.py index 6c87212..e0909fa 100644 --- a/backend/app/utils/deep_link.py +++ b/backend/app/utils/deep_link.py @@ -1,4 +1,5 @@ """Utilities for building inbox deep link URLs.""" + import re from typing import Optional diff --git a/backend/app/websocket/broadcast_consumer.py b/backend/app/websocket/broadcast_consumer.py index 66ecc36..88d62d5 100644 --- a/backend/app/websocket/broadcast_consumer.py +++ b/backend/app/websocket/broadcast_consumer.py @@ -141,10 +141,7 @@ async def _handle_message(self, message: dict): # Log the broadcast attempt connection_count = manager.get_connection_count(org_id) - logger.info( - f"Broadcasting {message_type} to org {org_id} " - f"({connection_count} connections)" - ) + logger.info(f"Broadcasting {message_type} to org {org_id} ({connection_count} connections)") # Forward to WebSocket clients # Remove org_id from message before sending (clients don't need it) diff --git a/backend/app/websocket/manager.py b/backend/app/websocket/manager.py index 299d48c..1a5acc0 100644 --- a/backend/app/websocket/manager.py +++ b/backend/app/websocket/manager.py @@ -37,7 +37,9 @@ async def connect(self, websocket: WebSocket, org_id: UUID): self.active_connections[org_id_str] = [] self.active_connections[org_id_str].append(websocket) - logger.info(f"WebSocket connected for org {org_id}. Total connections: {len(self.active_connections[org_id_str])}") + logger.info( + f"WebSocket connected for org {org_id}. Total connections: {len(self.active_connections[org_id_str])}" + ) async def disconnect(self, websocket: WebSocket, org_id: UUID): """ @@ -52,7 +54,9 @@ async def disconnect(self, websocket: WebSocket, org_id: UUID): if org_id_str in self.active_connections: if websocket in self.active_connections[org_id_str]: self.active_connections[org_id_str].remove(websocket) - logger.info(f"WebSocket disconnected for org {org_id}. Remaining connections: {len(self.active_connections[org_id_str])}") + logger.info( + f"WebSocket disconnected for org {org_id}. 
Remaining connections: {len(self.active_connections[org_id_str])}" + ) # Clean up empty connection lists if not self.active_connections[org_id_str]: diff --git a/backend/export_mock_discovery_data.py b/backend/export_mock_discovery_data.py index d7937a6..bd55f7f 100644 --- a/backend/export_mock_discovery_data.py +++ b/backend/export_mock_discovery_data.py @@ -2,24 +2,30 @@ """ One-time script to export AutoPrompt v2 discovery questions to mock_discovery_data.json """ + import json import sys + from sqlalchemy import create_engine, select + from app.models import DiscoveryQuestion # AutoPrompt v2 project ID AUTOPROMPT_V2_PROJECT_ID = "d9df3d05-f736-42b9-a60b-0424f906515a" + def export_mock_data(): """Export discovery questions from AutoPrompt v2 project to JSON file.""" # Create database connection - engine = create_engine('postgresql://mfbt:iammfbt@localhost/mfbt_dev') + engine = create_engine("postgresql://mfbt:iammfbt@localhost/mfbt_dev") with engine.connect() as conn: # Query all discovery questions from AutoPrompt v2 project - stmt = select(DiscoveryQuestion).where( - DiscoveryQuestion.project_id == AUTOPROMPT_V2_PROJECT_ID - ).order_by(DiscoveryQuestion.created_at) + stmt = ( + select(DiscoveryQuestion) + .where(DiscoveryQuestion.project_id == AUTOPROMPT_V2_PROJECT_ID) + .order_by(DiscoveryQuestion.created_at) + ) result = conn.execute(stmt) questions = result.fetchall() @@ -50,7 +56,7 @@ def export_mock_data(): # Write to JSON file output_path = "app/agents/discovery/mock_discovery_data.json" - with open(output_path, 'w') as f: + with open(output_path, "w") as f: json.dump(mock_data, f, indent=2, ensure_ascii=False) print(f"Successfully exported {len(mock_data)} questions to {output_path}") @@ -60,14 +66,15 @@ def export_mock_data(): priorities = {} depths = {} for q in mock_data: - categories[q['category']] = categories.get(q['category'], 0) + 1 - priorities[q['priority']] = priorities.get(q['priority'], 0) + 1 - depths[q['depth']] = depths.get(q['depth'], 0) + 1 + categories[q["category"]] = categories.get(q["category"], 0) + 1 + priorities[q["priority"]] = priorities.get(q["priority"], 0) + 1 + depths[q["depth"]] = depths.get(q["depth"], 0) + 1 print("\nStatistics:") print(f" Categories: {dict(sorted(categories.items()))}") print(f" Priorities: {dict(sorted(priorities.items()))}") print(f" Depths: {dict(sorted(depths.items()))}") + if __name__ == "__main__": export_mock_data() diff --git a/backend/mcp_server.py b/backend/mcp_server.py index 6883960..c0ee37a 100755 --- a/backend/mcp_server.py +++ b/backend/mcp_server.py @@ -33,7 +33,7 @@ # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ async def main(): # Convert database URL to string and hide credentials db_url_str = str(settings.database_url) - db_display = db_url_str.split('@')[-1] if '@' in db_url_str else "local" + db_display = db_url_str.split("@")[-1] if "@" in db_url_str else "local" logger.info(f"Database: {db_display}") try: diff --git a/backend/scripts/fix_missing_content_versions.py b/backend/scripts/fix_missing_content_versions.py index 0e2fac9..c91ddf5 100644 --- a/backend/scripts/fix_missing_content_versions.py +++ b/backend/scripts/fix_missing_content_versions.py @@ -1,10 +1,12 @@ """Script to create content versions for features that have spec_text/prompt_plan_text but no content versions.""" + import sys 
-sys.path.insert(0, '/home/shuveb/Projects/mfbt/backend')
+
+sys.path.insert(0, "/home/shuveb/Projects/mfbt/backend")
 
 from app.database import SessionLocal
 from app.models.feature import Feature
-from app.models.feature_content_version import FeatureContentVersion, FeatureContentType
+from app.models.feature_content_version import FeatureContentType, FeatureContentVersion
 from app.services.feature_content_version_service import FeatureContentVersionService
 
@@ -12,15 +14,19 @@ def fix_missing_content_versions():
     db = SessionLocal()
     try:
         # Find features with spec_text but no spec content version
-        features_needing_spec = db.query(Feature).filter(
-            Feature.spec_text.isnot(None),
-            Feature.spec_text != '',
-            ~Feature.id.in_(
-                db.query(FeatureContentVersion.feature_id).filter(
-                    FeatureContentVersion.content_type == FeatureContentType.SPEC
-                )
+        features_needing_spec = (
+            db.query(Feature)
+            .filter(
+                Feature.spec_text.isnot(None),
+                Feature.spec_text != "",
+                ~Feature.id.in_(
+                    db.query(FeatureContentVersion.feature_id).filter(
+                        FeatureContentVersion.content_type == FeatureContentType.SPEC
+                    )
+                ),
             )
-        ).all()
+            .all()
+        )
 
         print(f"Found {len(features_needing_spec)} features needing spec content versions")
         for feature in features_needing_spec:
@@ -35,15 +41,19 @@ def fix_missing_content_versions():
         )
 
         # Find features with prompt_plan_text but no prompt_plan content version
-        features_needing_plan = db.query(Feature).filter(
-            Feature.prompt_plan_text.isnot(None),
-            Feature.prompt_plan_text != '',
-            ~Feature.id.in_(
-                db.query(FeatureContentVersion.feature_id).filter(
-                    FeatureContentVersion.content_type == FeatureContentType.PROMPT_PLAN
-                )
+        features_needing_plan = (
+            db.query(Feature)
+            .filter(
+                Feature.prompt_plan_text.isnot(None),
+                Feature.prompt_plan_text != "",
+                ~Feature.id.in_(
+                    db.query(FeatureContentVersion.feature_id).filter(
+                        FeatureContentVersion.content_type == FeatureContentType.PROMPT_PLAN
+                    )
+                ),
             )
-        ).all()
+            .all()
+        )
 
         print(f"Found {len(features_needing_plan)} features needing prompt_plan content versions")
         for feature in features_needing_plan:
diff --git a/backend/scripts/mcp_stdio_proxy.py b/backend/scripts/mcp_stdio_proxy.py
index 248c451..540f9a5 100755
--- a/backend/scripts/mcp_stdio_proxy.py
+++ b/backend/scripts/mcp_stdio_proxy.py
@@ -21,8 +21,8 @@
 import json
 import os
 import sys
-import urllib.request
 import urllib.error
+import urllib.request
 
 
 def send_request(url: str, api_key: str, payload: dict) -> dict:
@@ -62,35 +62,43 @@ def main():
     parser = argparse.ArgumentParser(description="MCP Stdio Proxy for MFBT")
-    parser.add_argument("--url", help="MCP HTTP endpoint URL",
-                        default=os.environ.get("MFBT_MCP_URL"))
-    parser.add_argument("--api-key", help="API key for authentication",
-                        default=os.environ.get("MFBT_API_KEY"))
+    parser.add_argument("--url", help="MCP HTTP endpoint URL", default=os.environ.get("MFBT_MCP_URL"))
+    parser.add_argument("--api-key", help="API key for authentication", default=os.environ.get("MFBT_API_KEY"))
     args = parser.parse_args()
 
     url = args.url
     api_key = args.api_key
 
     if not url:
-        print(json.dumps({
-            "jsonrpc": "2.0",
-            "id": None,
-            "error": {
-                "code": -32600,
-                "message": "Missing MCP URL. Set --url or MFBT_MCP_URL environment variable.",
-            },
-        }), flush=True)
+        print(
+            json.dumps(
+                {
+                    "jsonrpc": "2.0",
+                    "id": None,
+                    "error": {
+                        "code": -32600,
+                        "message": "Missing MCP URL. Set --url or MFBT_MCP_URL environment variable.",
+                    },
+                }
+            ),
+            flush=True,
+        )
         sys.exit(1)
 
     if not api_key:
-        print(json.dumps({
-            "jsonrpc": "2.0",
-            "id": None,
-            "error": {
-                "code": -32600,
-                "message": "Missing API key. Set --api-key or MFBT_API_KEY environment variable.",
-            },
-        }), flush=True)
+        print(
+            json.dumps(
+                {
+                    "jsonrpc": "2.0",
+                    "id": None,
+                    "error": {
+                        "code": -32600,
+                        "message": "Missing API key. Set --api-key or MFBT_API_KEY environment variable.",
+                    },
+                }
+            ),
+            flush=True,
+        )
         sys.exit(1)
 
     # Read JSON-RPC requests from stdin, send to HTTP endpoint, write responses to stdout
@@ -102,14 +110,19 @@ def main():
         try:
             request = json.loads(line)
         except json.JSONDecodeError as e:
-            print(json.dumps({
-                "jsonrpc": "2.0",
-                "id": None,
-                "error": {
-                    "code": -32700,
-                    "message": f"Parse error: {e}",
-                },
-            }), flush=True)
+            print(
+                json.dumps(
+                    {
+                        "jsonrpc": "2.0",
+                        "id": None,
+                        "error": {
+                            "code": -32700,
+                            "message": f"Parse error: {e}",
+                        },
+                    }
+                ),
+                flush=True,
+            )
             continue
 
         # Forward request to HTTP endpoint
diff --git a/backend/tests/agents/brainstorm_conversation/test_input_validator.py b/backend/tests/agents/brainstorm_conversation/test_input_validator.py
index f480ea8..2fcf4c8 100644
--- a/backend/tests/agents/brainstorm_conversation/test_input_validator.py
+++ b/backend/tests/agents/brainstorm_conversation/test_input_validator.py
@@ -1,15 +1,16 @@
 """Tests for the InputValidatorAgent."""
 
 import json
-import pytest
 from unittest.mock import AsyncMock, MagicMock, patch
 
+import pytest
+
 from app.agents.brainstorm_conversation.input_validator import (
+    DEFAULT_CLARIFICATION,
+    SYSTEM_PROMPT,
     InputValidatorAgent,
     ValidationResult,
     create_input_validator,
-    SYSTEM_PROMPT,
-    DEFAULT_CLARIFICATION,
 )
 
 
@@ -58,14 +59,14 @@ async def test_validate_valid_input(self, mock_model_client):
         """Test that valid input returns is_valid=True."""
         # Setup mock response
         mock_response = MagicMock()
-        mock_response.chat_message.content = json.dumps({
-            "is_valid": True,
-            "reason": "Clear topic about authentication",
-        })
-
-        with patch(
-            "app.agents.brainstorm_conversation.input_validator.AssistantAgent"
-        ) as MockAgent:
+        mock_response.chat_message.content = json.dumps(
+            {
+                "is_valid": True,
+                "reason": "Clear topic about authentication",
+            }
+        )
+
+        with patch("app.agents.brainstorm_conversation.input_validator.AssistantAgent") as MockAgent:
             mock_agent_instance = AsyncMock()
             mock_agent_instance.on_messages.return_value = mock_response
             MockAgent.return_value = mock_agent_instance
@@ -81,15 +82,15 @@ async def test_validate_invalid_input(self, mock_model_client):
         """Test that invalid input returns is_valid=False with clarification."""
         # Setup mock response
         mock_response = MagicMock()
-        mock_response.chat_message.content = json.dumps({
-            "is_valid": False,
-            "reason": "Nonsensical input",
-            "clarification": "Could you describe what you'd like to explore?",
-        })
-
-        with patch(
-            "app.agents.brainstorm_conversation.input_validator.AssistantAgent"
-        ) as MockAgent:
+        mock_response.chat_message.content = json.dumps(
+            {
+                "is_valid": False,
+                "reason": "Nonsensical input",
+                "clarification": "Could you describe what you'd like to explore?",
+            }
+        )
+
+        with patch("app.agents.brainstorm_conversation.input_validator.AssistantAgent") as MockAgent:
             mock_agent_instance = AsyncMock()
             mock_agent_instance.on_messages.return_value = mock_response
             MockAgent.return_value = mock_agent_instance
@@ -104,9 +105,7 @@
     @pytest.mark.asyncio
     async def test_validate_defaults_to_valid_on_error(self, mock_model_client):
         """Test that validation defaults to valid on any error (fail-open)."""
-        with patch(
-            "app.agents.brainstorm_conversation.input_validator.AssistantAgent"
-        ) as MockAgent:
+        with patch("app.agents.brainstorm_conversation.input_validator.AssistantAgent") as MockAgent:
             mock_agent_instance = AsyncMock()
             mock_agent_instance.on_messages.side_effect = Exception("API error")
             MockAgent.return_value = mock_agent_instance
@@ -121,10 +120,12 @@ async def test_validate_defaults_to_valid_on_error(self, mock_model_client):
     async def test_validate_with_phase_context(self, mock_model_client):
         """Test validation includes phase context in prompt."""
         mock_response = MagicMock()
-        mock_response.chat_message.content = json.dumps({
-            "is_valid": True,
-            "reason": "Valid within project context",
-        })
+        mock_response.chat_message.content = json.dumps(
+            {
+                "is_valid": True,
+                "reason": "Valid within project context",
+            }
+        )
 
         captured_prompt = None
 
@@ -133,9 +134,7 @@ async def capture_prompt(messages, *args, **kwargs):
             captured_prompt = messages[0].content
             return mock_response
 
-        with patch(
-            "app.agents.brainstorm_conversation.input_validator.AssistantAgent"
-        ) as MockAgent:
+        with patch("app.agents.brainstorm_conversation.input_validator.AssistantAgent") as MockAgent:
             mock_agent_instance = AsyncMock()
             mock_agent_instance.on_messages = capture_prompt
             MockAgent.return_value = mock_agent_instance
@@ -156,15 +155,15 @@ async def capture_prompt(messages, *args, **kwargs):
     async def test_validate_uses_default_clarification_when_missing(self, mock_model_client):
         """Test that default clarification is used when LLM doesn't provide one."""
         mock_response = MagicMock()
-        mock_response.chat_message.content = json.dumps({
-            "is_valid": False,
-            "reason": "Too vague",
-            # No clarification provided
-        })
-
-        with patch(
-            "app.agents.brainstorm_conversation.input_validator.AssistantAgent"
-        ) as MockAgent:
+        mock_response.chat_message.content = json.dumps(
+            {
+                "is_valid": False,
+                "reason": "Too vague",
+                # No clarification provided
+            }
+        )
+
+        with patch("app.agents.brainstorm_conversation.input_validator.AssistantAgent") as MockAgent:
             mock_agent_instance = AsyncMock()
             mock_agent_instance.on_messages.return_value = mock_response
             MockAgent.return_value = mock_agent_instance
diff --git a/backend/tests/agents/collab_thread_assistant/test_mcq_parser.py b/backend/tests/agents/collab_thread_assistant/test_mcq_parser.py
index d8bf200..c3e7a30 100644
--- a/backend/tests/agents/collab_thread_assistant/test_mcq_parser.py
+++ b/backend/tests/agents/collab_thread_assistant/test_mcq_parser.py
@@ -1,12 +1,12 @@
 """
 Tests for the MCQ parser module.
""" -import pytest + from app.agents.collab_thread_assistant.mcq_parser import ( - parse_mfbtai_response, + MAX_MCQS_PER_RESPONSE, ParsedMCQ, ParsedResponse, - MAX_MCQS_PER_RESPONSE, + parse_mfbtai_response, ) @@ -144,10 +144,10 @@ def test_parse_limits_mcq_count(self): """More than MAX_MCQS_PER_RESPONSE MCQs are truncated.""" # Create response with 5 MCQs questions = [ - {"question_text": f"Q{i}?", "choices": [{"id": f"opt{i}", "label": f"Option {i}"}]} - for i in range(5) + {"question_text": f"Q{i}?", "choices": [{"id": f"opt{i}", "label": f"Option {i}"}]} for i in range(5) ] import json + response = f"[MFBT_MCQ]{json.dumps({'questions': questions})}[/MFBT_MCQ]" result = parse_mfbtai_response(response) diff --git a/backend/tests/agents/test_brainstorm_code_explorer_stage.py b/backend/tests/agents/test_brainstorm_code_explorer_stage.py index ccf2f11..ce1e641 100644 --- a/backend/tests/agents/test_brainstorm_code_explorer_stage.py +++ b/backend/tests/agents/test_brainstorm_code_explorer_stage.py @@ -1,21 +1,21 @@ """Tests for the code explorer stage in the brainstorm conversation pipeline.""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 +import pytest + from app.agents.brainstorm_conversation.code_explorer_stage import ( - run_code_exploration, _build_exploration_prompt, + run_code_exploration, ) from app.agents.brainstorm_conversation.types import ( - CodeExplorationContext, - SummarizedPhaseContext, ClassificationResult, + CodeExplorationContext, PhaseComplexity, + SummarizedPhaseContext, ) - # Patch path for code_explorer_client at the location where it's imported CODE_EXPLORER_CLIENT_PATH = "app.agents.brainstorm_conversation.code_explorer_stage.code_explorer_client" @@ -155,9 +155,7 @@ def classification(self): ) @pytest.mark.asyncio - async def test_returns_none_when_explorer_disabled( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_returns_none_when_explorer_disabled(self, mock_db, mock_project, summarized_context, classification): """Test returns None when code explorer is disabled.""" # Mock platform settings with code_explorer_enabled=False mock_settings = MagicMock() @@ -176,9 +174,7 @@ async def test_returns_none_when_explorer_disabled( assert result is None @pytest.mark.asyncio - async def test_returns_none_when_no_settings( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_returns_none_when_no_settings(self, mock_db, mock_project, summarized_context, classification): """Test returns None when no platform settings exist.""" mock_db.query.return_value.first.return_value = None @@ -194,9 +190,7 @@ async def test_returns_none_when_no_settings( assert result is None @pytest.mark.asyncio - async def test_returns_none_when_no_repositories( - self, mock_db, summarized_context, classification - ): + async def test_returns_none_when_no_repositories(self, mock_db, summarized_context, classification): """Test returns None when project has no repositories.""" # Mock platform settings with code_explorer_enabled=True mock_settings = MagicMock() @@ -220,9 +214,7 @@ async def test_returns_none_when_no_repositories( assert result is None @pytest.mark.asyncio - async def test_returns_none_when_api_key_missing( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_returns_none_when_api_key_missing(self, mock_db, mock_project, summarized_context, classification): """Test returns None when API key is not configured.""" # Mock platform 
settings with code_explorer_enabled=True mock_settings = MagicMock() @@ -263,14 +255,18 @@ async def test_returns_context_on_successful_exploration( "completion_tokens": 500, } - with patch( - "workers.handlers.code_explorer.get_code_explorer_api_key", - return_value="test-api-key", - ), patch( - "workers.handlers.code_explorer.get_github_token_for_org", - new_callable=AsyncMock, - return_value="test-github-token", - ), patch(CODE_EXPLORER_CLIENT_PATH) as mock_client: + with ( + patch( + "workers.handlers.code_explorer.get_code_explorer_api_key", + return_value="test-api-key", + ), + patch( + "workers.handlers.code_explorer.get_github_token_for_org", + new_callable=AsyncMock, + return_value="test-github-token", + ), + patch(CODE_EXPLORER_CLIENT_PATH) as mock_client, + ): mock_client.explore = AsyncMock(return_value=mock_result) result = await run_code_exploration( @@ -292,9 +288,7 @@ async def test_returns_context_on_successful_exploration( assert result.exploration_prompt is not None @pytest.mark.asyncio - async def test_returns_none_on_exploration_failure( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_returns_none_on_exploration_failure(self, mock_db, mock_project, summarized_context, classification): """Test returns None when exploration fails.""" # Mock platform settings mock_settings = MagicMock() @@ -308,14 +302,18 @@ async def test_returns_none_on_exploration_failure( "error_code": "TIMEOUT", } - with patch( - "workers.handlers.code_explorer.get_code_explorer_api_key", - return_value="test-api-key", - ), patch( - "workers.handlers.code_explorer.get_github_token_for_org", - new_callable=AsyncMock, - return_value="test-github-token", - ), patch(CODE_EXPLORER_CLIENT_PATH) as mock_client: + with ( + patch( + "workers.handlers.code_explorer.get_code_explorer_api_key", + return_value="test-api-key", + ), + patch( + "workers.handlers.code_explorer.get_github_token_for_org", + new_callable=AsyncMock, + return_value="test-github-token", + ), + patch(CODE_EXPLORER_CLIENT_PATH) as mock_client, + ): mock_client.explore = AsyncMock(return_value=mock_result) result = await run_code_exploration( @@ -330,9 +328,7 @@ async def test_returns_none_on_exploration_failure( assert result is None @pytest.mark.asyncio - async def test_returns_none_on_empty_output( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_returns_none_on_empty_output(self, mock_db, mock_project, summarized_context, classification): """Test returns None when exploration returns empty output.""" # Mock platform settings mock_settings = MagicMock() @@ -345,14 +341,18 @@ async def test_returns_none_on_empty_output( "output": " ", # Whitespace only } - with patch( - "workers.handlers.code_explorer.get_code_explorer_api_key", - return_value="test-api-key", - ), patch( - "workers.handlers.code_explorer.get_github_token_for_org", - new_callable=AsyncMock, - return_value="test-github-token", - ), patch(CODE_EXPLORER_CLIENT_PATH) as mock_client: + with ( + patch( + "workers.handlers.code_explorer.get_code_explorer_api_key", + return_value="test-api-key", + ), + patch( + "workers.handlers.code_explorer.get_github_token_for_org", + new_callable=AsyncMock, + return_value="test-github-token", + ), + patch(CODE_EXPLORER_CLIENT_PATH) as mock_client, + ): mock_client.explore = AsyncMock(return_value=mock_result) result = await run_code_exploration( @@ -367,23 +367,25 @@ async def test_returns_none_on_empty_output( assert result is None @pytest.mark.asyncio - async def 
test_handles_exception_gracefully( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_handles_exception_gracefully(self, mock_db, mock_project, summarized_context, classification): """Test handles exceptions gracefully and returns None.""" # Mock platform settings mock_settings = MagicMock() mock_settings.code_explorer_enabled = True mock_db.query.return_value.first.return_value = mock_settings - with patch( - "workers.handlers.code_explorer.get_code_explorer_api_key", - return_value="test-api-key", - ), patch( - "workers.handlers.code_explorer.get_github_token_for_org", - new_callable=AsyncMock, - return_value="test-github-token", - ), patch(CODE_EXPLORER_CLIENT_PATH) as mock_client: + with ( + patch( + "workers.handlers.code_explorer.get_code_explorer_api_key", + return_value="test-api-key", + ), + patch( + "workers.handlers.code_explorer.get_github_token_for_org", + new_callable=AsyncMock, + return_value="test-github-token", + ), + patch(CODE_EXPLORER_CLIENT_PATH) as mock_client, + ): mock_client.explore = AsyncMock(side_effect=Exception("Connection error")) result = await run_code_exploration( @@ -398,9 +400,7 @@ async def test_handles_exception_gracefully( assert result is None @pytest.mark.asyncio - async def test_continues_without_github_token( - self, mock_db, mock_project, summarized_context, classification - ): + async def test_continues_without_github_token(self, mock_db, mock_project, summarized_context, classification): """Test continues exploration even if GitHub token retrieval fails.""" # Mock platform settings mock_settings = MagicMock() @@ -414,14 +414,18 @@ async def test_continues_without_github_token( "execution_time_seconds": 3.0, } - with patch( - "workers.handlers.code_explorer.get_code_explorer_api_key", - return_value="test-api-key", - ), patch( - "workers.handlers.code_explorer.get_github_token_for_org", - new_callable=AsyncMock, - side_effect=Exception("Token retrieval failed"), - ), patch(CODE_EXPLORER_CLIENT_PATH) as mock_client: + with ( + patch( + "workers.handlers.code_explorer.get_code_explorer_api_key", + return_value="test-api-key", + ), + patch( + "workers.handlers.code_explorer.get_github_token_for_org", + new_callable=AsyncMock, + side_effect=Exception("Token retrieval failed"), + ), + patch(CODE_EXPLORER_CLIENT_PATH) as mock_client, + ): mock_client.explore = AsyncMock(return_value=mock_result) result = await run_code_exploration( diff --git a/backend/tests/agents/test_brainstorm_spec_types.py b/backend/tests/agents/test_brainstorm_spec_types.py index 9236f66..1f79988 100644 --- a/backend/tests/agents/test_brainstorm_spec_types.py +++ b/backend/tests/agents/test_brainstorm_spec_types.py @@ -1,9 +1,8 @@ """Tests for brainstorm spec types and summary generation.""" -import pytest from app.agents.brainstorm_spec.types import ( - SpecSectionContent, BrainstormSpecification, + SpecSectionContent, ) @@ -41,14 +40,16 @@ class TestBrainstormSpecification: def test_to_json_includes_summary(self): """Test that to_json includes summary field.""" - spec = BrainstormSpecification(sections=[ - SpecSectionContent( - id="exec", - title="Executive Summary", - body_markdown="Full content", - summary="Brief summary", - ) - ]) + spec = BrainstormSpecification( + sections=[ + SpecSectionContent( + id="exec", + title="Executive Summary", + body_markdown="Full content", + summary="Brief summary", + ) + ] + ) result = spec.to_json() assert "sections" in result @@ -61,27 +62,31 @@ def test_to_json_includes_summary(self): def 
test_to_json_includes_empty_summary(self): """Test that to_json includes empty summary when not provided.""" - spec = BrainstormSpecification(sections=[ - SpecSectionContent( - id="exec", - title="Executive Summary", - body_markdown="Full content", - ) - ]) + spec = BrainstormSpecification( + sections=[ + SpecSectionContent( + id="exec", + title="Executive Summary", + body_markdown="Full content", + ) + ] + ) result = spec.to_json() assert result["sections"][0]["summary"] == "" def test_to_markdown_unchanged(self): """Test that to_markdown still works as before (full content).""" - spec = BrainstormSpecification(sections=[ - SpecSectionContent( - id="exec", - title="Executive Summary", - body_markdown="Full detailed content here.", - summary="Brief summary.", - ) - ]) + spec = BrainstormSpecification( + sections=[ + SpecSectionContent( + id="exec", + title="Executive Summary", + body_markdown="Full detailed content here.", + summary="Brief summary.", + ) + ] + ) result = spec.to_markdown() assert "## Executive Summary" in result @@ -91,20 +96,22 @@ def test_to_markdown_unchanged(self): def test_to_summary_markdown_uses_summaries(self): """Test that to_summary_markdown outputs section summaries.""" - spec = BrainstormSpecification(sections=[ - SpecSectionContent( - id="exec", - title="Executive Summary", - body_markdown="Very long detailed content that should not appear...", - summary="Brief overview for downstream processing.", - ), - SpecSectionContent( - id="goals", - title="Goals", - body_markdown="Detailed goals...", - summary="Key objectives are A, B, and C.", - ), - ]) + spec = BrainstormSpecification( + sections=[ + SpecSectionContent( + id="exec", + title="Executive Summary", + body_markdown="Very long detailed content that should not appear...", + summary="Brief overview for downstream processing.", + ), + SpecSectionContent( + id="goals", + title="Goals", + body_markdown="Detailed goals...", + summary="Key objectives are A, B, and C.", + ), + ] + ) result = spec.to_summary_markdown() assert "## Executive Summary" in result @@ -117,14 +124,16 @@ def test_to_summary_markdown_uses_summaries(self): def test_to_summary_markdown_fallback_without_summary(self): """Test fallback when section has no summary.""" long_content = "A" * 400 # 400 chars - spec = BrainstormSpecification(sections=[ - SpecSectionContent( - id="exec", - title="Executive Summary", - body_markdown=long_content, - summary="", # No summary - ) - ]) + spec = BrainstormSpecification( + sections=[ + SpecSectionContent( + id="exec", + title="Executive Summary", + body_markdown=long_content, + summary="", # No summary + ) + ] + ) result = spec.to_summary_markdown() # Should have truncated content as fallback (300 chars max) @@ -136,14 +145,16 @@ def test_to_summary_markdown_fallback_without_summary(self): def test_to_summary_markdown_short_content_no_truncation(self): """Test that short content without summary is not truncated.""" short_content = "Short content here." 
-        spec = BrainstormSpecification(sections=[
-            SpecSectionContent(
-                id="exec",
-                title="Executive Summary",
-                body_markdown=short_content,
-                summary="",  # No summary
-            )
-        ])
+        spec = BrainstormSpecification(
+            sections=[
+                SpecSectionContent(
+                    id="exec",
+                    title="Executive Summary",
+                    body_markdown=short_content,
+                    summary="",  # No summary
+                )
+            ]
+        )
 
         result = spec.to_summary_markdown()
 
         assert short_content in result
@@ -152,26 +163,28 @@ def test_to_summary_markdown_multiple_sections(self):
         """Test summary markdown with multiple sections."""
-        spec = BrainstormSpecification(sections=[
-            SpecSectionContent(
-                id="sec1",
-                title="1. First Section",
-                body_markdown="Content 1",
-                summary="Summary 1",
-            ),
-            SpecSectionContent(
-                id="sec2",
-                title="2. Second Section",
-                body_markdown="Content 2",
-                summary="Summary 2",
-            ),
-            SpecSectionContent(
-                id="sec3",
-                title="3. Third Section",
-                body_markdown="Content 3",
-                summary="Summary 3",
-            ),
-        ])
+        spec = BrainstormSpecification(
+            sections=[
+                SpecSectionContent(
+                    id="sec1",
+                    title="1. First Section",
+                    body_markdown="Content 1",
+                    summary="Summary 1",
+                ),
+                SpecSectionContent(
+                    id="sec2",
+                    title="2. Second Section",
+                    body_markdown="Content 2",
+                    summary="Summary 2",
+                ),
+                SpecSectionContent(
+                    id="sec3",
+                    title="3. Third Section",
+                    body_markdown="Content 3",
+                    summary="Summary 3",
+                ),
+            ]
+        )
 
         result = spec.to_summary_markdown()
 
         assert "## 1. First Section" in result
diff --git a/backend/tests/agents/test_brainstorm_spec_writer.py b/backend/tests/agents/test_brainstorm_spec_writer.py
index c38ece1..48a1435 100644
--- a/backend/tests/agents/test_brainstorm_spec_writer.py
+++ b/backend/tests/agents/test_brainstorm_spec_writer.py
@@ -1,7 +1,9 @@
 """Tests for brainstorm spec writer agent parsing logic."""
-import pytest
+
 from unittest.mock import MagicMock
 
+import pytest
+
 from app.agents.brainstorm_spec.writer import WriterAgent
 
 
@@ -113,10 +115,10 @@ def test_generate_fallback_summary_breaks_at_word_boundary(self, writer):
 
     def test_parse_multiline_json_content(self, writer):
         """Test parsing JSON with multiline content."""
-        response = '''{
+        response = """{
             "content": "# Header\\n\\n## Subheader\\n\\n- Item 1\\n- Item 2",
             "summary": "A multi-paragraph summary."
-        }'''
+        }"""
 
         content, summary = writer._parse_writer_response(response, "test_section")
 
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_assistant.py b/backend/tests/agents/test_collab_thread_assistant/test_assistant.py
index d726464..d2e4a99 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_assistant.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_assistant.py
@@ -1,12 +1,13 @@
 """Tests for the Collab Thread Assistant agent."""
 
-import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
 from uuid import uuid4
-from unittest.mock import MagicMock, AsyncMock, patch
+
+import pytest
 
 from app.agents.collab_thread_assistant.assistant import (
-    CollabThreadAssistant,
     SYSTEM_PROMPT,
+    CollabThreadAssistant,
 )
 from app.agents.collab_thread_assistant.types import (
     CollabThreadContext,
@@ -279,9 +280,7 @@ async def test_respond_calls_agent(self, assistant, sample_context):
 - Consider scalability needs."""
 
-        with patch.object(
-            assistant, '_create_agent', return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(return_value=mock_response)
 
@@ -298,9 +297,7 @@ async def test_respond_handles_list_content(self, assistant, sample_context):
         mock_response = MagicMock()
         mock_response.chat_message.content = ["Part 1", "Part 2", "Part 3"]
 
-        with patch.object(
-            assistant, '_create_agent', return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(return_value=mock_response)
 
@@ -311,9 +308,7 @@ async def test_respond_handles_list_content(self, assistant, sample_context):
     @pytest.mark.asyncio
     async def test_respond_fallback_on_error(self, assistant, sample_context):
         """Test fallback response on LLM error."""
-        with patch.object(
-            assistant, '_create_agent', return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(side_effect=Exception("API Error"))
 
@@ -333,9 +328,7 @@ async def test_respond_with_logger(self, mock_model_client, sample_context):
         mock_response = MagicMock()
         mock_response.chat_message.content = "Response text"
 
-        with patch.object(
-            assistant, '_create_agent', return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(return_value=mock_response)
 
@@ -348,8 +341,7 @@ async def test_respond_with_logger(self, mock_model_client, sample_context):
     def test_generate_fallback_response(self, assistant):
         """Test fallback response generation."""
         response = assistant._generate_fallback_response(
-            user_message="@MFBTAI What should I do?",
-            error="Connection timeout"
+            user_message="@MFBTAI What should I do?", error="Connection timeout"
         )
 
         assert "## Summary" in response
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_context_loader.py b/backend/tests/agents/test_collab_thread_assistant/test_context_loader.py
index bc82c8d..b3382ae 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_context_loader.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_context_loader.py
@@ -1,24 +1,25 @@
 """Tests for context_loader module of the Collab Thread Assistant."""
 
-import pytest
-from uuid import uuid4
 from datetime import datetime, timezone
 from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+import pytest
 
 from app.agents.collab_thread_assistant.context_loader import (
-    token_count,
-    load_thread,
+    _format_messages_as_text,
+    load_brainstorming_phase_context,
     load_feature_files,
-    load_grounding_files,
     load_files,
-    load_brainstorming_phase_context,
-    _format_messages_as_text,
+    load_grounding_files,
+    load_thread,
+    token_count,
 )
 from app.agents.collab_thread_assistant.types import ThreadMessage
-from app.models.thread_item import ThreadItemType
-from app.models.thread import ContextType
-from app.models.module import ModuleType
 from app.models.feature import FeatureVisibilityStatus
+from app.models.module import ModuleType
+from app.models.thread import ContextType
+from app.models.thread_item import ThreadItemType
 
 
 class TestTokenCount:
@@ -529,9 +530,7 @@ def test_returns_none_when_feature_not_found(self):
         mock_query.first.return_value = None
         mock_db.query.return_value = mock_query
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), str(uuid4())
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), str(uuid4()))
 
         assert result is None
 
@@ -550,9 +549,7 @@ def test_returns_none_when_no_brainstorming_phase(self):
         mock_query.first.return_value = mock_feature
         mock_db.query.return_value = mock_query
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), str(uuid4())
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), str(uuid4()))
 
         assert result is None
 
@@ -635,9 +632,7 @@ def query_side_effect(model):
 
         mock_db.query.side_effect = query_side_effect
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), feature_id
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), feature_id)
 
         assert result is not None
         assert result.phase_id == mock_phase.id
@@ -716,9 +711,7 @@ def query_side_effect(model):
 
         mock_db.query.side_effect = query_side_effect
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), feature_id
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), feature_id)
 
         # Should return None since no answered questions and no comments
         assert result is None
@@ -772,9 +765,7 @@ def query_side_effect(model):
 
         mock_db.query.side_effect = query_side_effect
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), feature_id, exclude_current_feature=True
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), feature_id, exclude_current_feature=True)
 
         # Should return None since only current feature's thread exists
         assert result is None
@@ -841,9 +832,7 @@ def query_side_effect(model):
 
         mock_db.query.side_effect = query_side_effect
 
-        result = load_brainstorming_phase_context(
-            mock_db, str(uuid4()), feature_id
-        )
+        result = load_brainstorming_phase_context(mock_db, str(uuid4()), feature_id)
 
         assert result is not None
         # Check feature_contexts (new grouped structure)
@@ -859,9 +848,7 @@ def query_side_effect(model):
 class TestLoadCrossProjectContext:
     """Tests for load_cross_project_context function."""
 
-    def _create_mock_phase_with_decisions(
-        self, phase_id, phase_title, phase_description, decisions
-    ):
+    def _create_mock_phase_with_decisions(self, phase_id, phase_title, phase_description, decisions):
         """Helper to create a mock phase with decision context."""
         mock_phase = MagicMock()
         mock_phase.id = phase_id
@@ -879,9 +866,7 @@ def _create_mock_module(self, module_id, title, brainstorming_phase_id=None):
         mock_module.archived_at = None
         return mock_module
 
-    def _create_mock_feature_with_thread(
-        self, feature_id, title, decision_summary_short=None, decision_summary=None
-    ):
+    def _create_mock_feature_with_thread(self, feature_id, title, decision_summary_short=None, decision_summary=None):
         """Helper to create a mock feature with a thread that has decision summary."""
         mock_feature = MagicMock()
         mock_feature.id = feature_id
@@ -936,16 +921,13 @@ def query_side_effect(model):
     def test_loads_decisions_from_other_phases(self):
         """Test loading decisions from other brainstorming phases."""
         from app.agents.collab_thread_assistant.context_loader import load_cross_project_context
-        from app.models.feature import FeatureType
 
         mock_db = MagicMock()
         project_id = str(uuid4())
 
         # Create phase with decisions
         phase_id = uuid4()
-        mock_phase = self._create_mock_phase_with_decisions(
-            phase_id, "Phase 1", "First planning phase", []
-        )
+        mock_phase = self._create_mock_phase_with_decisions(phase_id, "Phase 1", "First planning phase", [])
 
         # Create module for the phase
         module_id = uuid4()
@@ -982,8 +964,8 @@ def test_loads_decisions_from_other_phases(self):
         def query_side_effect(model):
             from app.models.brainstorming_phase import BrainstormingPhase
-            from app.models.module import Module
             from app.models.feature import Feature
+            from app.models.module import Module
             from app.models.thread import Thread
 
             query_calls.append(model)
@@ -1039,9 +1021,7 @@ def query_side_effect(model):
 
         mock_db.query.side_effect = query_side_effect
 
-        result = load_cross_project_context(
-            mock_db, project_id, exclude_phase_id=exclude_phase_id
-        )
+        result = load_cross_project_context(mock_db, project_id, exclude_phase_id=exclude_phase_id)
 
         # Should return None since no other phases exist
         assert result is None
@@ -1094,8 +1074,8 @@ def test_loads_project_level_features(self):
         def query_side_effect(model):
             from app.models.brainstorming_phase import BrainstormingPhase
-            from app.models.module import Module
             from app.models.feature import Feature
+            from app.models.module import Module
             from app.models.thread import Thread
 
             if model == BrainstormingPhase:
@@ -1129,9 +1109,7 @@ def test_truncates_long_descriptions(self):
         # Create phase with very long description
         phase_id = uuid4()
         long_description = "A" * 300  # Longer than 200 char limit
-        mock_phase = self._create_mock_phase_with_decisions(
-            phase_id, "Long Phase", long_description, []
-        )
+        mock_phase = self._create_mock_phase_with_decisions(phase_id, "Long Phase", long_description, [])
 
         # Create module and feature with decision
         module_id = uuid4()
@@ -1164,8 +1142,8 @@ def query_side_effect(model):
             from app.models.brainstorming_phase import BrainstormingPhase
-            from app.models.module import Module
             from app.models.feature import Feature
+            from app.models.module import Module
             from app.models.thread import Thread
 
             if model == BrainstormingPhase:
@@ -1196,9 +1174,7 @@ def test_falls_back_to_full_summary_when_short_missing(self):
 
         # Create phase with decisions
         phase_id = uuid4()
-        mock_phase = self._create_mock_phase_with_decisions(
-            phase_id, "Phase 1", "Description", []
-        )
+        mock_phase = self._create_mock_phase_with_decisions(phase_id, "Phase 1", "Description", [])
 
         module_id = uuid4()
         mock_module = self._create_mock_module(module_id, "Module", phase_id)
@@ -1232,8 +1208,8 @@ def test_falls_back_to_full_summary_when_short_missing(self):
         def query_side_effect(model):
             from app.models.brainstorming_phase import BrainstormingPhase
-            from app.models.module import Module
             from app.models.feature import Feature
+            from app.models.module import Module
             from app.models.thread import Thread
 
             if model == BrainstormingPhase:
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_exploration_parser.py b/backend/tests/agents/test_collab_thread_assistant/test_exploration_parser.py
index 4e4a74e..fe86206 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_exploration_parser.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_exploration_parser.py
@@ -1,12 +1,10 @@
 """Tests for exploration_parser module."""
 
-import pytest
-
 from app.agents.collab_thread_assistant.exploration_parser import (
     CodeExplorationRequest,
+    has_exploration_block,
     parse_exploration_request,
     strip_exploration_block,
-    has_exploration_block,
 )
 
 
@@ -39,7 +37,9 @@ def test_parse_exploration_request_without_branch(self):
 
     def test_parse_exploration_request_false(self):
         """Test parsing when wants_code_exploration is false."""
-        response = """[MFBT_EXPLORE]{"wants_code_exploration": false, "code_exploration_prompt": "test"}[/MFBT_EXPLORE]"""
+        response = (
+            """[MFBT_EXPLORE]{"wants_code_exploration": false, "code_exploration_prompt": "test"}[/MFBT_EXPLORE]"""
+        )
 
         result = parse_exploration_request(response)
 
@@ -100,7 +100,9 @@ def test_strip_exploration_block_with_preamble(self):
 
     def test_strip_exploration_block_no_preamble(self):
         """Test stripping exploration block without preamble."""
-        response = """[MFBT_EXPLORE]{"wants_code_exploration": true, "code_exploration_prompt": "test"}[/MFBT_EXPLORE]"""
+        response = (
+            """[MFBT_EXPLORE]{"wants_code_exploration": true, "code_exploration_prompt": "test"}[/MFBT_EXPLORE]"""
+        )
 
         result = strip_exploration_block(response)
         assert result == ""
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_instrumentation.py b/backend/tests/agents/test_collab_thread_assistant/test_instrumentation.py
index 864fcaa..82184fd 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_instrumentation.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_instrumentation.py
@@ -5,16 +5,13 @@
 """
 
 import json
-import logging
 from unittest.mock import patch
 
-import pytest
-
 from app.agents.collab_thread_assistant.instrumentation import (
     CollabThreadAssistantLogger,
     DebugInfo,
-    SummarizationEvent,
     RetryEvent,
+    SummarizationEvent,
     get_assistant_logger,
 )
 
@@ -37,9 +34,7 @@ def test_debug_info_to_dict(self):
         debug_info = DebugInfo()
         debug_info.context_load_duration_ms = 150.5
         debug_info.assistant_call_duration_ms = 200.0
-        debug_info.token_breakdown = {
-            "thread": {"original": 1000, "after_summarization": 500, "summarized": True}
-        }
+        debug_info.token_breakdown = {"thread": {"original": 1000, "after_summarization": 500, "summarized": True}}
         debug_info.summarization_events.append(
             SummarizationEvent(
                 content_type="thread",
@@ -442,9 +437,7 @@ def test_get_debug_info(self):
         logger.debug_info.context_load_duration_ms = 150.0
         logger.debug_info.assistant_call_duration_ms = 300.0
         logger.set_token_breakdown("thread", 1000, 1000, False)
-        logger.debug_info.summarization_events.append(
-            SummarizationEvent("spec", 15000, 1800, 10000)
-        )
+        logger.debug_info.summarization_events.append(SummarizationEvent("spec", 15000, 1800, 10000))
 
         result = logger.get_debug_info()
 
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_orchestrator.py b/backend/tests/agents/test_collab_thread_assistant/test_orchestrator.py
index 8ef9bbc..ba3dcb5 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_orchestrator.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_orchestrator.py
@@ -1,9 +1,17 @@
 """Tests for orchestrator module of the Collab Thread Assistant."""
 
+from unittest.mock import AsyncMock, MagicMock, patch
+from uuid import UUID, uuid4
+
 import pytest
-from uuid import uuid4, UUID
-from unittest.mock import MagicMock, patch, AsyncMock
 
+from app.agents.collab_thread_assistant.config import (
+    MAX_RETRIES,
+    MIN_MESSAGES_FOR_SUMMARY,
+    SUMMARY_MAX_TOKENS,
+    TOKEN_THRESHOLD,
+)
+from app.agents.collab_thread_assistant.context_loader import token_count
 from app.agents.collab_thread_assistant.orchestrator import (
     build_context,
     call_assistant,
@@ -14,13 +22,6 @@
     CollabThreadContext,
     ThreadMessage,
 )
-from app.agents.collab_thread_assistant.config import (
-    TOKEN_THRESHOLD,
-    SUMMARY_MAX_TOKENS,
-    MAX_RETRIES,
-    MIN_MESSAGES_FOR_SUMMARY,
-)
-from app.agents.collab_thread_assistant.context_loader import token_count
 
 
 class TestBuildContext:
@@ -194,7 +195,10 @@ async def test_build_context_no_summarization_below_threshold(
         mock_model_client = MagicMock()
 
         context = await build_context(
-            mock_db, thread_id, feature_id, project_id,
+            mock_db,
+            thread_id,
+            feature_id,
+            project_id,
             model_client=mock_model_client,
         )
 
@@ -220,8 +224,7 @@ async def test_build_context_summarizes_large_thread(
 
         # Create enough messages to trigger summarization
         messages = [
-            ThreadMessage(author="Alice", body=f"Message {i}", created_at="2024-01-15T10:00:00Z")
-            for i in range(100)
+            ThreadMessage(author="Alice", body=f"Message {i}", created_at="2024-01-15T10:00:00Z") for i in range(100)
         ]
 
         mock_load_thread.return_value = {
@@ -252,7 +255,10 @@ async def test_build_context_summarizes_large_thread(
         mock_model_client = MagicMock()
 
         context = await build_context(
-            mock_db, thread_id, feature_id, project_id,
+            mock_db,
+            thread_id,
+            feature_id,
+            project_id,
             model_client=mock_model_client,
         )
 
@@ -307,7 +313,10 @@ async def test_build_context_summarizes_large_spec(
         mock_model_client = MagicMock()
 
         context = await build_context(
-            mock_db, thread_id, feature_id, project_id,
+            mock_db,
+            thread_id,
+            feature_id,
+            project_id,
             model_client=mock_model_client,
         )
 
@@ -527,7 +536,7 @@ async def test_handle_ai_mention_with_model_client(self, mock_call_assistant):
                 "assistant_retry_attempts": 0,
                 "context_partial": False,
                 "errors_encountered": None,
-            }
+            },
         )
         mock_call_assistant.return_value = mock_response
 
@@ -675,9 +684,7 @@ class TestCallAssistantRetry:
     @patch("asyncio.sleep", new_callable=AsyncMock)
    @patch("app.agents.collab_thread_assistant.orchestrator.CollabThreadAssistant")
     @patch("app.agents.collab_thread_assistant.orchestrator.build_context")
-    async def test_call_assistant_retries_context_loading(
-        self, mock_build_context, MockAssistant, mock_sleep
-    ):
+    async def test_call_assistant_retries_context_loading(self, mock_build_context, MockAssistant, mock_sleep):
         """Test that call_assistant retries context loading on failure."""
         thread_id = str(uuid4())
         feature_id = str(uuid4())
@@ -694,12 +701,18 @@ async def test_call_assistant_retries_context_loading(
             notes="Notes",
             grounding_files={},
             summarization_applied={
-                "thread": False, "spec": False, "prompt_plan": False,
-                "notes": False, "grounding": False,
+                "thread": False,
+                "spec": False,
+                "prompt_plan": False,
+                "notes": False,
+                "grounding": False,
             },
             token_counts={
-                "thread": 100, "spec": 50, "prompt_plan": 30,
-                "notes": 20, "grounding": 0,
+                "thread": 100,
+                "spec": 50,
+                "prompt_plan": 30,
+                "notes": 20,
+                "grounding": 0,
             },
         )
         mock_build_context.side_effect = [
@@ -728,9 +741,7 @@ async def test_call_assistant_retries_context_loading(
     @patch("asyncio.sleep", new_callable=AsyncMock)
     @patch("app.agents.collab_thread_assistant.orchestrator.CollabThreadAssistant")
     @patch("app.agents.collab_thread_assistant.orchestrator.build_context")
-    async def test_call_assistant_retries_assistant_response(
-        self, mock_build_context, MockAssistant, mock_sleep
-    ):
+    async def test_call_assistant_retries_assistant_response(self, mock_build_context, MockAssistant, mock_sleep):
         """Test that call_assistant retries assistant response on failure."""
         thread_id = str(uuid4())
         feature_id = str(uuid4())
@@ -746,21 +757,25 @@ async def test_call_assistant_retries_assistant_response(
             notes="Notes",
             grounding_files={},
             summarization_applied={
-                "thread": False, "spec": False, "prompt_plan": False,
-                "notes": False, "grounding": False,
+                "thread": False,
+                "spec": False,
+                "prompt_plan": False,
+                "notes": False,
+                "grounding": False,
             },
             token_counts={
-                "thread": 100, "spec": 50, "prompt_plan": 30,
-                "notes": 20, "grounding": 0,
+                "thread": 100,
+                "spec": 50,
+                "prompt_plan": 30,
+                "notes": 20,
+                "grounding": 0,
             },
         )
         mock_build_context.return_value = mock_context
 
         # First assistant call fails, second succeeds
         mock_assistant = MagicMock()
-        mock_assistant.respond = AsyncMock(
-            side_effect=[RuntimeError("LLM timeout"), "Success response"]
-        )
+        mock_assistant.respond = AsyncMock(side_effect=[RuntimeError("LLM timeout"), "Success response"])
         MockAssistant.return_value = mock_assistant
 
         response = await call_assistant(
@@ -779,9 +794,7 @@ async def test_call_assistant_retries_assistant_response(
     @pytest.mark.asyncio
     @patch("asyncio.sleep", new_callable=AsyncMock)
     @patch("app.agents.collab_thread_assistant.orchestrator.build_context")
-    async def test_call_assistant_graceful_fallback_on_context_failure(
-        self, mock_build_context, mock_sleep
-    ):
+    async def test_call_assistant_graceful_fallback_on_context_failure(self, mock_build_context, mock_sleep):
         """Test graceful fallback when all context loading attempts fail."""
         thread_id = str(uuid4())
         feature_id = str(uuid4())
@@ -827,12 +840,18 @@ async def test_call_assistant_graceful_fallback_on_assistant_failure(
             notes="Notes",
             grounding_files={},
             summarization_applied={
-                "thread": False, "spec": False, "prompt_plan": False,
-                "notes": False, "grounding": False,
+                "thread": False,
+                "spec": False,
+                "prompt_plan": False,
+                "notes": False,
+                "grounding": False,
             },
             token_counts={
-                "thread": 100, "spec": 50, "prompt_plan": 30,
-                "notes": 20, "grounding": 0,
+                "thread": 100,
+                "spec": 50,
+                "prompt_plan": 30,
+                "notes": 20,
+                "grounding": 0,
             },
         )
         mock_build_context.return_value = mock_context
@@ -863,9 +882,7 @@ async def test_call_assistant_graceful_fallback_on_assistant_failure(
     @patch("asyncio.sleep", new_callable=AsyncMock)
     @patch("app.agents.collab_thread_assistant.orchestrator.CollabThreadAssistant")
     @patch("app.agents.collab_thread_assistant.orchestrator.build_context")
-    async def test_call_assistant_metadata_includes_retry_info(
-        self, mock_build_context, MockAssistant, mock_sleep
-    ):
+    async def test_call_assistant_metadata_includes_retry_info(self, mock_build_context, MockAssistant, mock_sleep):
         """Test that metadata includes retry information."""
         thread_id = str(uuid4())
         feature_id = str(uuid4())
@@ -881,12 +898,18 @@ async def test_call_assistant_metadata_includes_retry_info(
             notes="Notes",
             grounding_files={},
             summarization_applied={
-                "thread": False, "spec": False, "prompt_plan": False,
-                "notes": False, "grounding": False,
+                "thread": False,
+                "spec": False,
+                "prompt_plan": False,
+                "notes": False,
+                "grounding": False,
             },
             token_counts={
-                "thread": 100, "spec": 50, "prompt_plan": 30,
-                "notes": 20, "grounding": 0,
+                "thread": 100,
+                "spec": 50,
+                "prompt_plan": 30,
+                "notes": 20,
+                "grounding": 0,
             },
         )
         mock_build_context.return_value = mock_context
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_quality.py b/backend/tests/agents/test_collab_thread_assistant/test_quality.py
index 715bfaf..3576b13 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_quality.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_quality.py
@@ -5,23 +5,21 @@
 and trade-off formatting.
 """
 
-import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
 from uuid import uuid4
-from unittest.mock import MagicMock, AsyncMock, patch
 
-from app.agents.collab_thread_assistant.validators import (
-    ResponseValidator,
-    ValidationResult,
-)
+import pytest
+
 from app.agents.collab_thread_assistant.assistant import (
     CollabThreadAssistant,
-    SYSTEM_PROMPT,
 )
 from app.agents.collab_thread_assistant.types import (
     CollabThreadContext,
-    ThreadMessage,
 )
-
+from app.agents.collab_thread_assistant.validators import (
+    ResponseValidator,
+    ValidationResult,
+)
 
 # ============================================================================
 # Test Fixtures - Mock Responses
 # ============================================================================
@@ -224,8 +222,7 @@ def test_fallback_response_structure(self, validator):
         """Test that fallback responses also follow structure."""
         assistant = CollabThreadAssistant(MagicMock())
         fallback = assistant._generate_fallback_response(
-            user_message="@MFBTAI What should I do?",
-            error="Connection timeout"
+            user_message="@MFBTAI What should I do?", error="Connection timeout"
         )
 
         result = validator.validate(fallback)
@@ -513,16 +510,12 @@ def sample_context(self):
         )
 
     @pytest.mark.asyncio
-    async def test_assistant_response_follows_structure(
-        self, assistant, sample_context, validator
-    ):
+    async def test_assistant_response_follows_structure(self, assistant, sample_context, validator):
         """Test that assistant response follows Universal Output Structure."""
         mock_response = MagicMock()
         mock_response.chat_message.content = VALID_RESPONSE
 
-        with patch.object(
-            assistant, "_create_agent", return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(return_value=mock_response)
 
@@ -533,16 +526,12 @@ async def test_assistant_response_follows_structure(
         assert validation.missing_sections == []
 
     @pytest.mark.asyncio
-    async def test_trade_off_question_gets_table_response(
-        self, assistant, sample_context, validator
-    ):
+    async def test_trade_off_question_gets_table_response(self, assistant, sample_context, validator):
         """Test that trade-off questions produce table format."""
         mock_response = MagicMock()
         mock_response.chat_message.content = TRADE_OFF_RESPONSE
 
-        with patch.object(
-            assistant, "_create_agent", return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(return_value=mock_response)
 
@@ -556,13 +545,9 @@ async def test_trade_off_question_gets_table_response(
         assert validation.trade_off_table_valid is True
 
     @pytest.mark.asyncio
-    async def test_fallback_response_follows_structure(
-        self, assistant, sample_context, validator
-    ):
+    async def test_fallback_response_follows_structure(self, assistant, sample_context, validator):
         """Test that fallback response on error still follows structure."""
-        with patch.object(
-            assistant, "_create_agent", return_value=MagicMock()
-        ) as mock_create:
+        with patch.object(assistant, "_create_agent", return_value=MagicMock()) as mock_create:
             mock_agent = mock_create.return_value
             mock_agent.on_messages = AsyncMock(side_effect=Exception("API Error"))
 
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_retry.py b/backend/tests/agents/test_collab_thread_assistant/test_retry.py
index b5f2e5c..96bae62 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_retry.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_retry.py
@@ -4,18 +4,17 @@
 Phase 5: Comprehensive tests for with_retry() and with_retry_sync().
 """
 
-import asyncio
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
+from app.agents.collab_thread_assistant.config import MAX_RETRIES
 from app.agents.collab_thread_assistant.retry import (
     RetryError,
+    calculate_backoff_delay,
     with_retry,
     with_retry_sync,
-    calculate_backoff_delay,
 )
-from app.agents.collab_thread_assistant.config import MAX_RETRIES, RETRY_BACKOFF_MS
 
 
 class TestRetryError:
@@ -88,9 +87,7 @@ async def test_success_after_one_retry(self):
     @pytest.mark.asyncio
     async def test_success_after_two_retries(self):
         """Test success after two failed attempts."""
-        mock_func = AsyncMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), "success"]
-        )
+        mock_func = AsyncMock(side_effect=[ValueError("fail1"), ValueError("fail2"), "success"])
 
         with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
             result = await with_retry(mock_func, max_attempts=3)
@@ -122,9 +119,7 @@ async def test_failure_after_all_attempts(self):
     @pytest.mark.asyncio
     async def test_backoff_timing(self):
         """Test that backoff delays are applied correctly."""
-        mock_func = AsyncMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), "success"]
-        )
+        mock_func = AsyncMock(side_effect=[ValueError("fail1"), ValueError("fail2"), "success"])
 
         with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
             await with_retry(mock_func, max_attempts=3)
@@ -150,9 +145,7 @@ async def test_custom_backoff_delays(self):
     @pytest.mark.asyncio
     async def test_on_retry_callback_called(self):
         """Test that on_retry callback is called on failures."""
-        mock_func = AsyncMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), "success"]
-        )
+        mock_func = AsyncMock(side_effect=[ValueError("fail1"), ValueError("fail2"), "success"])
         callback_calls = []
 
         def on_retry(attempt, error):
@@ -182,9 +175,7 @@ def bad_callback(attempt, error):
     @pytest.mark.asyncio
     async def test_custom_max_attempts(self):
         """Test with custom max_attempts value."""
-        mock_func = AsyncMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), "success"]
-        )
+        mock_func = AsyncMock(side_effect=[ValueError("fail1"), ValueError("fail2"), "success"])
 
         with patch("asyncio.sleep", new_callable=AsyncMock):
             result = await with_retry(mock_func, max_attempts=5)
@@ -206,9 +197,7 @@ async def test_single_attempt(self):
     @pytest.mark.asyncio
     async def test_no_sleep_after_last_attempt(self):
         """Test that no sleep occurs after the final failed attempt."""
-        mock_func = AsyncMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), ValueError("fail3")]
-        )
+        mock_func = AsyncMock(side_effect=[ValueError("fail1"), ValueError("fail2"), ValueError("fail3")])
 
         with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
             with pytest.raises(RetryError):
@@ -258,9 +247,7 @@ async def test_sync_function_success_after_retries(self):
     @pytest.mark.asyncio
     async def test_sync_function_failure_after_all_attempts(self):
         """Test sync function raises RetryError after all attempts."""
-        mock_func = MagicMock(
-            side_effect=[ValueError("fail1"), ValueError("fail2"), ValueError("fail3")]
-        )
+        mock_func = MagicMock(side_effect=[ValueError("fail1"), ValueError("fail2"), ValueError("fail3")])
 
         with patch("asyncio.sleep", new_callable=AsyncMock):
             with pytest.raises(RetryError) as exc_info:
diff --git a/backend/tests/agents/test_collab_thread_assistant/test_stress.py b/backend/tests/agents/test_collab_thread_assistant/test_stress.py
index 1630d55..29bb90e 100644
--- a/backend/tests/agents/test_collab_thread_assistant/test_stress.py
+++ b/backend/tests/agents/test_collab_thread_assistant/test_stress.py
@@ -7,16 +7,17 @@
 
 import hashlib
 import random
-import pytest
+from datetime import datetime
+from typing import Any, Dict, List
+from unittest.mock import AsyncMock, MagicMock, patch
 from uuid import uuid4
-from datetime import datetime, timezone
-from typing import List, Dict, Any
-from unittest.mock import MagicMock, AsyncMock, patch
+
+import pytest
 
 from app.agents.collab_thread_assistant.config import (
-    TOKEN_THRESHOLD,
-    SUMMARY_MAX_TOKENS,
     RECENT_MESSAGES_COUNT,
+    SUMMARY_MAX_TOKENS,
+    TOKEN_THRESHOLD,
 )
 from app.agents.collab_thread_assistant.context_loader import token_count
 from app.agents.collab_thread_assistant.types import (
@@ -25,7 +26,6 @@
 )
 from app.agents.collab_thread_assistant.validators import ResponseValidator
 
-
 # ============================================================================
 # Synthetic Data Generators
 # ============================================================================
@@ -79,9 +79,7 @@ def generate_synthetic_messages(
         word_count = rng.randint(30, 150)
         body = _generate_message_body(rng, topic, word_count)
 
-        timestamp = datetime(
-            2024, 1, 15, 10, message_idx % 60, message_idx % 60
-        ).isoformat() + "Z"
+        timestamp = datetime(2024, 1, 15, 10, message_idx % 60, message_idx % 60).isoformat() + "Z"
 
         messages.append(
             ThreadMessage(
@@ -112,13 +110,34 @@ def _generate_message_body(rng: random.Random, topic: str, word_count: int) -> s
     ]
 
     filler_words = [
-        "important", "consideration", "implementation", "architecture",
-        "scalability", "maintainability", "performance", "reliability",
-        "consistency", "availability", "latency", "throughput",
-        "complexity", "simplicity", "flexibility", "extensibility",
-        "modularity", "testability", "observability", "security",
-        "authentication", "authorization", "validation", "sanitization",
-        "encryption", "hashing", "caching", "indexing",
+        "important",
+        "consideration",
+        "implementation",
+        "architecture",
+        "scalability",
+        "maintainability",
+        "performance",
+        "reliability",
+        "consistency",
+        "availability",
+        "latency",
+        "throughput",
+        "complexity",
+        "simplicity",
+        "flexibility",
+        "extensibility",
+        "modularity",
+        "testability",
+        "observability",
+        "security",
+        "authentication",
+        "authorization",
+        "validation",
+        "sanitization",
+        "encryption",
+        "hashing",
+        "caching",
+        "indexing",
     ]
 
     starter = rng.choice(sentence_starters)
@@ -169,9 +188,7 @@ def generate_synthetic_spec(target_tokens: int, seed: str = "stress_test_spec")
     return spec
 
 
-def generate_synthetic_prompt_plan(
- ) + result = await assistant.respond(large_context, "@MFBTAI What should we do next?") validation = validator.validate(result) assert validation.valid is True @@ -561,9 +597,7 @@ class TestStressEdgeCases: def test_empty_thread_generation(self): """Test generating context with no thread messages.""" - context_data = generate_synthetic_context( - thread_tokens=0, spec_tokens=1000 - ) + context_data = generate_synthetic_context(thread_tokens=0, spec_tokens=1000) assert len(context_data["messages"]) == 0 assert context_data["thread_text"] == "" @@ -595,6 +629,4 @@ def test_all_context_parts_exceed_threshold(self): assert token_count(context_data["thread_text"]) > TOKEN_THRESHOLD assert token_count(context_data["spec"]) > TOKEN_THRESHOLD assert token_count(context_data["prompt_plan"]) > TOKEN_THRESHOLD - assert token_count( - context_data["grounding_files"]["agents.md"] - ) > TOKEN_THRESHOLD + assert token_count(context_data["grounding_files"]["agents.md"]) > TOKEN_THRESHOLD diff --git a/backend/tests/agents/test_collab_thread_assistant/test_summarizer.py b/backend/tests/agents/test_collab_thread_assistant/test_summarizer.py index eeede3f..883faa6 100644 --- a/backend/tests/agents/test_collab_thread_assistant/test_summarizer.py +++ b/backend/tests/agents/test_collab_thread_assistant/test_summarizer.py @@ -1,11 +1,11 @@ """Tests for the Collab Thread Assistant summarizer.""" +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import MagicMock, AsyncMock, patch from app.agents.collab_thread_assistant.summarizer import SummarizerAgent from app.agents.collab_thread_assistant.types import ThreadMessage -from app.agents.collab_thread_assistant.config import SUMMARY_MAX_TOKENS class TestSummarizerAgent: @@ -101,9 +101,7 @@ async def test_summarize_calls_agent(self, summarizer): mock_response = MagicMock() mock_response.chat_message.content = "This is a summary." - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -119,9 +117,7 @@ async def test_summarize_with_context_type(self, summarizer): mock_response = MagicMock() mock_response.chat_message.content = "Summarized content." - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -138,9 +134,7 @@ async def test_summarize_thread_calls_agent(self, summarizer, sample_messages): mock_response = MagicMock() mock_response.chat_message.content = "Thread summary with decisions." 
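Reviewer aside, not part of the patch: most of the churn in the test_stress.py hunks above comes from two ruff-format behaviors. A minimal sketch, assuming ruff format's Black-compatible "magic trailing comma" rule and a ruff version that formats expressions inside f-strings (the f"{i+1}. " -> f"{i + 1}. " hunk in generate_synthetic_prompt_plan suggests this repo's does):

# Illustration only. A trailing comma after the last element forces
# one-element-per-line; without it, a list that fits stays on one line.
compact = ["important", "consideration"]  # no trailing comma: kept inline
expanded = [
    "implement",
    "configure",  # magic trailing comma: exploded, as in filler_words above
]

# Spacing inside f-string expressions is normalized like any other expression:
for i, word in enumerate(expanded):
    print(f"{i + 1}. {word}")  # ruff rewrites f"{i+1}" to f"{i + 1}"

The word lists are content-identical before and after these hunks; only the layout changes.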
- with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -157,9 +151,7 @@ async def test_summarize_handles_list_response(self, summarizer): mock_response = MagicMock() mock_response.chat_message.content = ["Part 1", "Part 2"] - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -173,9 +165,7 @@ async def test_summarize_thread_handles_list_response(self, summarizer, sample_m mock_response = MagicMock() mock_response.chat_message.content = ["Summary", "part", "two"] - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -186,9 +176,7 @@ async def test_summarize_thread_handles_list_response(self, summarizer, sample_m @pytest.mark.asyncio async def test_summarize_fallback_on_error(self, summarizer): """Test fallback behavior when summarization fails.""" - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(side_effect=Exception("API Error")) @@ -202,9 +190,7 @@ async def test_summarize_fallback_on_error(self, summarizer): @pytest.mark.asyncio async def test_summarize_thread_fallback_on_error(self, summarizer, sample_messages): """Test thread summarize fallback behavior when LLM fails.""" - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(side_effect=Exception("LLM Error")) @@ -219,9 +205,7 @@ async def test_maintain_running_summary_initial(self, summarizer, sample_message mock_response = MagicMock() mock_response.chat_message.content = "Initial running summary." - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -239,9 +223,7 @@ async def test_maintain_running_summary_update(self, summarizer, sample_messages mock_response = MagicMock() mock_response.chat_message.content = "Updated running summary." 
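Reviewer aside, not part of the patch: the pattern these summarizer hunks keep reflowing is easier to read in one piece. A minimal sketch, assuming pytest with pytest-asyncio and that SummarizerAgent exposes the _create_agent factory and a summarize() coroutine the surrounding tests imply:

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.agents.collab_thread_assistant.summarizer import SummarizerAgent


@pytest.mark.asyncio
async def test_mock_agent_pattern():
    # A MagicMock stands in for the model client the constructor expects.
    summarizer = SummarizerAgent(MagicMock())
    mock_response = MagicMock()
    mock_response.chat_message.content = "A short summary."

    with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create:
        # AsyncMock lets the awaited call resolve without a real LLM round-trip.
        mock_create.return_value.on_messages = AsyncMock(return_value=mock_response)
        result = await summarizer.summarize("some long text")

    assert "summary" in result.lower()

Patching the private factory rather than the LLM SDK keeps every test in this file offline, and it is why each of these hunks collapses to the same one-line with patch.object(...) form.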
- with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -259,9 +241,7 @@ async def test_maintain_running_summary_update(self, summarizer, sample_messages @pytest.mark.asyncio async def test_maintain_running_summary_fallback_with_existing(self, summarizer, sample_messages): """Test running summary fallback when update fails and existing summary exists.""" - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(side_effect=Exception("LLM Error")) @@ -284,9 +264,7 @@ async def test_logger_set_agent_called_for_summarize(self, mock_model_client): mock_response = MagicMock() mock_response.chat_message.content = "Summary" - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -306,13 +284,9 @@ async def test_logger_set_agent_called_for_thread_summarize(self, mock_model_cli mock_response = MagicMock() mock_response.chat_message.content = "Thread summary" - messages = [ - ThreadMessage(author="Alice", body="Hello", created_at="2024-01-01T10:00:00Z") - ] + messages = [ThreadMessage(author="Alice", body="Hello", created_at="2024-01-01T10:00:00Z")] - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -331,13 +305,9 @@ async def test_logger_set_agent_called_for_running_summary(self, mock_model_clie mock_response = MagicMock() mock_response.chat_message.content = "Running summary" - messages = [ - ThreadMessage(author="Alice", body="Hello", created_at="2024-01-01T10:00:00Z") - ] + messages = [ThreadMessage(author="Alice", body="Hello", created_at="2024-01-01T10:00:00Z")] - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) @@ -355,9 +325,7 @@ async def test_summarize_without_logger(self, mock_model_client): mock_response = MagicMock() mock_response.chat_message.content = "Summary" - with patch.object( - summarizer, '_create_agent', return_value=MagicMock() - ) as mock_create: + with patch.object(summarizer, "_create_agent", return_value=MagicMock()) as mock_create: mock_agent = mock_create.return_value mock_agent.on_messages = AsyncMock(return_value=mock_response) diff --git a/backend/tests/agents/test_collab_thread_assistant/test_types.py b/backend/tests/agents/test_collab_thread_assistant/test_types.py index 68c9a17..7aae73c 100644 --- a/backend/tests/agents/test_collab_thread_assistant/test_types.py +++ b/backend/tests/agents/test_collab_thread_assistant/test_types.py @@ -1,13 +1,12 @@ """Tests for types module of the Collab Thread Assistant.""" -import pytest from uuid 
import uuid4 from app.agents.collab_thread_assistant.types import ( - ThreadMessage, - CollabThreadContext, AssistantResponse, + CollabThreadContext, ContextLoadResult, + ThreadMessage, ) @@ -108,9 +107,7 @@ class TestAssistantResponse: def test_create_response_minimal(self): """Test creating a response with just reply_text.""" - response = AssistantResponse( - reply_text="Here is my response..." - ) + response = AssistantResponse(reply_text="Here is my response...") assert response.reply_text == "Here is my response..." assert response.metadata == {} @@ -125,7 +122,7 @@ def test_create_response_with_metadata(self): "summarized_parts": ["thread"], "latency_ms": 1234.5, "retries": 0, - } + }, ) assert response.reply_text == "Analysis complete." diff --git a/backend/tests/agents/test_collab_thread_assistant/test_web_search_parser.py b/backend/tests/agents/test_collab_thread_assistant/test_web_search_parser.py index 225004a..32d863f 100644 --- a/backend/tests/agents/test_collab_thread_assistant/test_web_search_parser.py +++ b/backend/tests/agents/test_collab_thread_assistant/test_web_search_parser.py @@ -1,12 +1,10 @@ """Tests for web_search_parser module.""" -import pytest - from app.agents.collab_thread_assistant.web_search_parser import ( WebSearchRequest, + has_web_search_block, parse_web_search_request, strip_web_search_block, - has_web_search_block, ) diff --git a/backend/tests/agents/test_grounding_orchestrator.py b/backend/tests/agents/test_grounding_orchestrator.py index 521aef9..015bd73 100644 --- a/backend/tests/agents/test_grounding_orchestrator.py +++ b/backend/tests/agents/test_grounding_orchestrator.py @@ -3,13 +3,15 @@ Tests the agent orchestrator for grounding file updates, particularly the content_summary parsing and summarize_content method. """ + import json +from unittest.mock import AsyncMock, MagicMock + import pytest -from unittest.mock import MagicMock, AsyncMock from app.agents.grounding.orchestrator import ( - GroundingUpdateOrchestrator, SUMMARIZE_ONLY_PROMPT, + GroundingUpdateOrchestrator, ) from app.agents.grounding.types import GroundingUpdateResult @@ -23,21 +25,26 @@ def test_parse_response_with_content_summary(self): mock_client = MagicMock() orchestrator = GroundingUpdateOrchestrator(model_client=mock_client) - response_text = json.dumps({ - "updated_content": "# Architecture\nMicroservices with API gateway", - "changes": { - "added": ["Microservices pattern"], - "updated": [], - "removed": [], - }, - "summary": "Added microservices pattern documentation", - "content_summary": "This file documents a microservices architecture with API gateway, event-driven communication, and caching patterns.", - }) + response_text = json.dumps( + { + "updated_content": "# Architecture\nMicroservices with API gateway", + "changes": { + "added": ["Microservices pattern"], + "updated": [], + "removed": [], + }, + "summary": "Added microservices pattern documentation", + "content_summary": "This file documents a microservices architecture with API gateway, event-driven communication, and caching patterns.", + } + ) result = orchestrator._parse_response(response_text) assert isinstance(result, GroundingUpdateResult) - assert result.content_summary == "This file documents a microservices architecture with API gateway, event-driven communication, and caching patterns." + assert ( + result.content_summary + == "This file documents a microservices architecture with API gateway, event-driven communication, and caching patterns." 
+ ) assert result.summary == "Added microservices pattern documentation" assert "Microservices with API gateway" in result.updated_content @@ -46,11 +53,13 @@ def test_parse_response_without_content_summary(self): mock_client = MagicMock() orchestrator = GroundingUpdateOrchestrator(model_client=mock_client) - response_text = json.dumps({ - "updated_content": "# Architecture\nSimple patterns", - "changes": {"added": [], "updated": [], "removed": []}, - "summary": "No changes made", - }) + response_text = json.dumps( + { + "updated_content": "# Architecture\nSimple patterns", + "changes": {"added": [], "updated": [], "removed": []}, + "summary": "No changes made", + } + ) result = orchestrator._parse_response(response_text) @@ -100,7 +109,7 @@ async def test_summarize_content_calls_agent(self): # We need to create a fresh agent for summarization, so we'll mock AssistantAgent from unittest.mock import patch - with patch('app.agents.grounding.orchestrator.AssistantAgent') as MockAgent: + with patch("app.agents.grounding.orchestrator.AssistantAgent") as MockAgent: mock_agent_instance = MagicMock() mock_agent_instance.on_messages = AsyncMock(return_value=mock_response) MockAgent.return_value = mock_agent_instance @@ -127,7 +136,7 @@ async def test_summarize_content_strips_whitespace(self): from unittest.mock import patch - with patch('app.agents.grounding.orchestrator.AssistantAgent') as MockAgent: + with patch("app.agents.grounding.orchestrator.AssistantAgent") as MockAgent: mock_agent_instance = MagicMock() mock_agent_instance.on_messages = AsyncMock(return_value=mock_response) MockAgent.return_value = mock_agent_instance @@ -145,12 +154,13 @@ async def test_summarize_content_calls_progress_callback(self): mock_response.chat_message.content = "Summary" progress_calls = [] + def progress_callback(step, percent): progress_calls.append((step, percent)) from unittest.mock import patch - with patch('app.agents.grounding.orchestrator.AssistantAgent') as MockAgent: + with patch("app.agents.grounding.orchestrator.AssistantAgent") as MockAgent: mock_agent_instance = MagicMock() mock_agent_instance.on_messages = AsyncMock(return_value=mock_response) MockAgent.return_value = mock_agent_instance diff --git a/backend/tests/agents/test_llm_client.py b/backend/tests/agents/test_llm_client.py index 03f5923..b595428 100644 --- a/backend/tests/agents/test_llm_client.py +++ b/backend/tests/agents/test_llm_client.py @@ -1,22 +1,22 @@ """Tests for the LiteLLM ChatCompletionClient wrapper.""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch +import pytest +from autogen_core.models import ( + AssistantMessage, + ModelFamily, + RequestUsage, + SystemMessage, + UserMessage, +) + from app.agents.llm_client import ( + DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, LiteLLMChatCompletionClient, create_litellm_client, get_litellm_model_name, get_model_family, - PROVIDER_MODEL_PREFIXES, - DEFAULT_LLM_REQUEST_TIMEOUT_SECONDS, -) -from autogen_core.models import ( - SystemMessage, - UserMessage, - AssistantMessage, - RequestUsage, - ModelFamily, ) @@ -272,9 +272,7 @@ async def test_create_mock_response(self): with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: mock_acompletion.return_value = mock_response - result = await client.create( - messages=[UserMessage(content="Hello", source="user")] - ) + result = await client.create(messages=[UserMessage(content="Hello", source="user")]) assert result.content == "Hello! I'm Claude." 
assert result.finish_reason == "stop" diff --git a/backend/tests/agents/test_project_chat_gating.py b/backend/tests/agents/test_project_chat_gating.py index e5848be..fd8ec23 100644 --- a/backend/tests/agents/test_project_chat_gating.py +++ b/backend/tests/agents/test_project_chat_gating.py @@ -1,13 +1,14 @@ """Tests for the project-chat gating agent.""" +from unittest.mock import AsyncMock, MagicMock + import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from app.agents.project_chat_gating import should_ai_respond, GatingResponse +from app.agents.project_chat_gating import GatingResponse, should_ai_respond from app.agents.project_chat_gating.agent import ( - _parse_gating_response, - _format_messages_for_context, GATING_SYSTEM_PROMPT, + _format_messages_for_context, + _parse_gating_response, ) diff --git a/backend/tests/agents/test_response_parser.py b/backend/tests/agents/test_response_parser.py index 754c733..0609d69 100644 --- a/backend/tests/agents/test_response_parser.py +++ b/backend/tests/agents/test_response_parser.py @@ -6,15 +6,16 @@ """ import json + import pytest from app.agents.response_parser import ( - strip_markdown_json, - strip_markdown_content, - normalize_response_content, extract_json_from_text, + normalize_response_content, parse_json_response, safe_parse_json, + strip_markdown_content, + strip_markdown_json, ) @@ -66,20 +67,20 @@ def test_nested_python_code_block(self): necessary. Regex-based approaches with non-greedy matching would incorrectly stop at the nested ```python fence. """ - text = '''```json + text = """```json {"description": "Use this code:\\n```python\\nprint('hello')\\n```\\nfor setup"} -```''' +```""" result = strip_markdown_json(text) # The nested code block should be preserved in the JSON string - assert '```python' in result + assert "```python" in result assert "print('hello')" in result # Should be valid JSON parsed = json.loads(result) - assert 'description' in parsed + assert "description" in parsed def test_nested_multiple_code_blocks(self): """Multiple nested code blocks should all be preserved.""" - text = '''```json + text = """```json [{ "title": "Config Example", "examples": [ @@ -87,30 +88,30 @@ def test_nested_multiple_code_blocks(self): {"lang": "bash", "code": "```bash\\necho hello\\n```"} ] }] -```''' +```""" result = strip_markdown_json(text) - assert '```python' in result - assert '```bash' in result + assert "```python" in result + assert "```bash" in result # Should be valid JSON parsed = json.loads(result) assert len(parsed) == 1 - assert parsed[0]['title'] == 'Config Example' + assert parsed[0]["title"] == "Config Example" def test_nested_json_code_block(self): """Even nested ```json blocks should be preserved.""" - text = '''```json + text = """```json {"example": "Here's some JSON:\\n```json\\n{\\"nested\\": true}\\n```\\nPretty cool!"} -```''' +```""" result = strip_markdown_json(text) # The nested ```json should be preserved - assert '```json' in result + assert "```json" in result parsed = json.loads(result) - assert 'nested' in parsed['example'] + assert "nested" in parsed["example"] def test_empty_input(self): """Empty input should return empty string.""" - assert strip_markdown_json('') == '' - assert strip_markdown_json(' ') == '' + assert strip_markdown_json("") == "" + assert strip_markdown_json(" ") == "" def test_single_line(self): """Single line input should pass through.""" @@ -128,13 +129,13 @@ class TestStripMarkdownContent: def test_markdown_fence(self): """```markdown fences should be 
stripped.""" - text = '```markdown\n# Hello World\n```' - assert strip_markdown_content(text) == '# Hello World' + text = "```markdown\n# Hello World\n```" + assert strip_markdown_content(text) == "# Hello World" def test_plain_fence(self): """Plain ``` fences should be stripped.""" - text = '```\n# Hello World\n```' - assert strip_markdown_content(text) == '# Hello World' + text = "```\n# Hello World\n```" + assert strip_markdown_content(text) == "# Hello World" class TestNormalizeResponseContent: @@ -142,22 +143,22 @@ class TestNormalizeResponseContent: def test_string_passthrough(self): """Strings should pass through unchanged.""" - assert normalize_response_content('hello') == 'hello' + assert normalize_response_content("hello") == "hello" def test_dict_to_json(self): """Dicts should be converted to JSON strings.""" - result = normalize_response_content({'key': 'value'}) + result = normalize_response_content({"key": "value"}) assert result == '{"key": "value"}' def test_list_to_string(self): """Lists should be joined with spaces.""" - result = normalize_response_content(['hello', 'world']) - assert result == 'hello world' + result = normalize_response_content(["hello", "world"]) + assert result == "hello world" def test_mixed_list(self): """Lists with mixed types should be stringified and joined.""" - result = normalize_response_content(['count:', 42]) - assert result == 'count: 42' + result = normalize_response_content(["count:", 42]) + assert result == "count: 42" class TestExtractJsonFromText: @@ -171,9 +172,9 @@ def test_json_object_extraction(self): def test_json_array_extraction(self): """JSON arrays should be extracted from surrounding text.""" - text = 'The list is: [1, 2, 3]' + text = "The list is: [1, 2, 3]" result = extract_json_from_text(text) - assert result == '[1, 2, 3]' + assert result == "[1, 2, 3]" def test_nested_brackets(self): """Nested brackets should be handled correctly.""" @@ -182,18 +183,18 @@ def test_nested_brackets(self): assert result == '{"outer": {"inner": [1, 2]}}' # Verify it's valid JSON parsed = json.loads(result) - assert parsed['outer']['inner'] == [1, 2] + assert parsed["outer"]["inner"] == [1, 2] def test_brackets_in_strings(self): """Brackets inside strings should not affect matching.""" text = 'Data: {"message": "Use {braces} and [brackets]"}' result = extract_json_from_text(text) parsed = json.loads(result) - assert parsed['message'] == 'Use {braces} and [brackets]' + assert parsed["message"] == "Use {braces} and [brackets]" def test_no_json_found(self): """Should return None when no JSON is found.""" - text = 'Just some plain text without JSON' + text = "Just some plain text without JSON" assert extract_json_from_text(text) is None def test_prefers_object_over_array(self): @@ -206,7 +207,7 @@ def test_prefers_array_when_first(self): """When array comes first, it should be extracted.""" text = '[1, 2] and then {"obj": 1}' result = extract_json_from_text(text) - assert result == '[1, 2]' + assert result == "[1, 2]" class TestParseJsonResponse: @@ -215,38 +216,38 @@ class TestParseJsonResponse: def test_direct_json(self): """Direct JSON should parse immediately.""" result = parse_json_response('{"key": "value"}') - assert result == {'key': 'value'} + assert result == {"key": "value"} def test_markdown_wrapped(self): """Markdown-wrapped JSON should be extracted and parsed.""" text = '```json\n{"key": "value"}\n```' result = parse_json_response(text) - assert result == {'key': 'value'} + assert result == {"key": "value"} def 
test_with_surrounding_text(self): """JSON with surrounding text should be extracted.""" text = 'Here is the data: {"key": "value"}' result = parse_json_response(text) - assert result == {'key': 'value'} + assert result == {"key": "value"} def test_fallback_to_raw(self): """With fallback enabled, invalid JSON returns raw text.""" - text = 'Just some text' + text = "Just some text" result = parse_json_response(text, fallback_to_raw=True) - assert result == 'Just some text' + assert result == "Just some text" def test_raises_without_fallback(self): """Without fallback, invalid JSON raises JSONDecodeError.""" with pytest.raises(json.JSONDecodeError): - parse_json_response('Just some text', fallback_to_raw=False) + parse_json_response("Just some text", fallback_to_raw=False) def test_nested_code_blocks_preserved(self): """Nested code blocks should be preserved when parsing.""" - text = '''```json + text = """```json {"code": "```python\\nprint('hi')\\n```"} -```''' +```""" result = parse_json_response(text) - assert '```python' in result['code'] + assert "```python" in result["code"] class TestSafeParseJson: @@ -255,21 +256,21 @@ class TestSafeParseJson: def test_valid_json(self): """Valid JSON should be parsed.""" result = safe_parse_json('{"key": "value"}') - assert result == {'key': 'value'} + assert result == {"key": "value"} def test_invalid_json_returns_default(self): """Invalid JSON should return default.""" - result = safe_parse_json('not json') + result = safe_parse_json("not json") assert result == {} def test_custom_default(self): """Custom default should be returned on failure.""" - result = safe_parse_json('not json', default={'error': True}) - assert result == {'error': True} + result = safe_parse_json("not json", default={"error": True}) + assert result == {"error": True} def test_array_json(self): """JSON arrays should be parsed.""" - result = safe_parse_json('[1, 2, 3]') + result = safe_parse_json("[1, 2, 3]") assert result == [1, 2, 3] @@ -281,7 +282,7 @@ def test_brainstorm_question_with_code_example(self): Simulate a brainstorm question that includes a code example in one of its MCQ choices. """ - text = '''```json + text = """```json [ { "title": "Database Schema Design", @@ -297,43 +298,43 @@ def test_brainstorm_question_with_code_example(self): } } ] -```''' +```""" result = strip_markdown_json(text) # Should be valid JSON parsed = json.loads(result) assert len(parsed) == 1 - assert parsed[0]['title'] == 'Database Schema Design' + assert parsed[0]["title"] == "Database Schema Design" # The nested code block should be in the choice - assert '```json' in parsed[0]['mcq']['choices'][0]['label'] + assert "```json" in parsed[0]["mcq"]["choices"][0]["label"] def test_project_chat_response_with_exploration_code(self): """ Simulate a project chat response that includes code snippets discovered during exploration. """ - text = '''```json + text = """```json { "reply_text": "I found the authentication code. 
Here's what I discovered:\\n\\n```python\\ndef authenticate(user, password):\\n return verify(user, password)\\n```\\n\\nThis uses bcrypt for hashing.", "ready_to_create_phase": false, "ready_to_create_feature": false } -```''' +```""" result = strip_markdown_json(text) parsed = json.loads(result) - assert 'authenticate' in parsed['reply_text'] - assert '```python' in parsed['reply_text'] + assert "authenticate" in parsed["reply_text"] + assert "```python" in parsed["reply_text"] def test_spec_with_implementation_examples(self): """ Simulate a spec document that includes implementation code examples. """ - text = '''```json + text = """```json { "content": "## Authentication\\n\\nImplement using JWT:\\n\\n```typescript\\nconst token = jwt.sign(payload, secret);\\n```\\n\\n## Storage\\n\\nUse Redis for sessions.", "summary": "JWT authentication with Redis sessions" } -```''' +```""" result = strip_markdown_json(text) parsed = json.loads(result) - assert '```typescript' in parsed['content'] - assert 'jwt.sign' in parsed['content'] + assert "```typescript" in parsed["content"] + assert "jwt.sign" in parsed["content"] diff --git a/backend/tests/agents/test_response_parser_realdata.py b/backend/tests/agents/test_response_parser_realdata.py index 61b8970..49f06e1 100644 --- a/backend/tests/agents/test_response_parser_realdata.py +++ b/backend/tests/agents/test_response_parser_realdata.py @@ -6,17 +6,17 @@ """ import json -import pytest -from app.agents.response_parser import strip_markdown_json, parse_json_response +import pytest +from app.agents.response_parser import parse_json_response, strip_markdown_json # Real data from the database dump - these have nested code blocks REAL_RESPONSES = [ # Row 041ccb8b - Contains nested bash code blocks in the feature description { "id": "041ccb8b-e83a-453c-8e70-95f64abc5e8b", - "content": '''```json + "content": """```json { "reply_text": "Thanks for the exploration! Now I can see the complete picture:\\n\\n**Main Tool Container:** Uses `dumb-init /bin/sh -c` with a dynamically constructed shell script in `runner.py`. The SIGTERM trap needs to be added to that shell script construction.\\n\\n**Sidecar Container:** Uses the Dockerfile's CMD which calls `entrypoint.sh`. The SIGTERM trap needs to be added directly to that bash script.\\n\\nBoth approaches will work at the bash level before Python execution starts. Ready to create the feature?", "mcq_options": [ @@ -34,15 +34,14 @@ "chat_title": "Fix Tool Graceful Shutdown", "wants_code_exploration": false } -```''', +```""", "expected_title": "Add SIGTERM Trap to Tool and Sidecar Entrypoints", "expected_ready_feature": True, }, - # Row 15dc6a42 - Contains nested bash code blocks { "id": "15dc6a42-4e83-48d5-b084-0b58851a2df3", - "content": '''```json + "content": """```json { "reply_text": "Perfect! I have enough context to propose this feature. You want to add SIGTERM trapping to both the sidecar's bash entrypoint (`tool-sidecar/entrypoint.sh`) and the runner's shell script that executes tools. This will acknowledge SIGTERM signals but allow both containers to continue execution until natural completion. 
Click 'Create Feature' when ready, or let me know if you'd like any adjustments to the description.", "mcq_options": null, @@ -57,15 +56,14 @@ "chat_title": "Fix Tool Graceful Shutdown", "wants_code_exploration": false } -```''', +```""", "expected_title": "Add SIGTERM Trap to Tool and Sidecar Entrypoints", "expected_ready_feature": True, }, - # Row 7af4981e - Contains multiple nested code blocks (bash with python inside) { "id": "7af4981e-ebd0-44d2-9cba-e53ff2b9623e", - "content": '''```json + "content": """```json { "reply_text": "Here's the plan to fix the graceful shutdown issue:\\n\\n**The Problem:**\\nCurrently, when SIGTERM is sent to stop tool execution, the Python signal handlers in `entrypoint.py` just log a warning but don't actually allow the process to continue. This causes the entire execution to get stuck.\\n\\n**The Solution:**\\nAdd bash-level SIGTERM traps BEFORE Python starts, so the shell acknowledges the signal but doesn't stop execution.\\n\\n**Two Changes Needed:**\\n\\n**Change 1 - Sidecar (`unstract` repo):**\\nEdit `tool-sidecar/entrypoint.sh` to add the trap before starting Python:\\n```bash\\n#!/bin/bash\\n\\n# NEW: Add SIGTERM trap\\nsigterm_handler() {\\n echo \\"SIGTERM received in shell but ignoring - continuing\\" >> /shared/logs/logs.txt\\n}\\ntrap sigterm_handler TERM\\n\\n# Existing code continues\\n. /app/.venv/bin/activate\\nexec opentelemetry-instrument python -m unstract.tool_sidecar.log_processor\\n```\\n\\n**Change 2 - Tool Runner (`unstract-cloud` repo):**\\nEdit `runner/src/unstract/runner/runner.py` around lines 292-321 where the shell script is constructed. Add the trap to the `shell_script` variable:\\n```python\\nshell_script = f\\"\\"\\"\\nsigterm_handler() {{\\n echo \\"SIGTERM received in shell but ignoring - continuing\\" >> {log_file}\\n}}\\ntrap sigterm_handler TERM\\n\\n{mkdir_cmd} && {run_tool_fn}; {execute_cmd}\\n\\"\\"\\"\\n```\\n\\n**Why This Works:**\\nThe trap runs at the bash level (before Python), so when Kubernetes sends SIGTERM, bash handles it, logs it, but returns control to continue executing the Python tool. Both the main tool and sidecar can finish their work naturally.\\n\\nDoes this explanation make sense?", "mcq_options": [ @@ -83,7 +81,7 @@ "chat_title": "Fix Tool Graceful Shutdown", "wants_code_exploration": false } -```''', +```""", "expected_title": "Add SIGTERM Trap to Tool and Sidecar Entrypoints", "expected_ready_feature": True, }, @@ -109,15 +107,16 @@ def test_strip_markdown_json_preserves_nested_blocks(self, test_case): parsed = json.loads(stripped) except json.JSONDecodeError as e: pytest.fail( - f"Failed to parse JSON for {test_case['id']}: {e}\n" - f"Stripped content preview: {stripped[:500]}..." + f"Failed to parse JSON for {test_case['id']}: {e}\nStripped content preview: {stripped[:500]}..." 
) # Verify expected fields - assert parsed.get("proposed_feature_title") == test_case["expected_title"], \ + assert parsed.get("proposed_feature_title") == test_case["expected_title"], ( f"Title mismatch for {test_case['id']}" - assert parsed.get("ready_to_create_feature") == test_case["expected_ready_feature"], \ + ) + assert parsed.get("ready_to_create_feature") == test_case["expected_ready_feature"], ( f"ready_to_create_feature mismatch for {test_case['id']}" + ) @pytest.mark.parametrize("test_case", REAL_RESPONSES, ids=lambda x: x["id"][:8]) def test_parse_json_response_handles_real_data(self, test_case): @@ -172,13 +171,11 @@ def test_nested_blocks_not_corrupted(self): # Count the number of ``` in the original (after outer fences removed) # The description contains ```bash which should be preserved - assert stripped.count("```bash") >= 1, \ - "Nested ```bash should be preserved (not matched as outer fence)" + assert stripped.count("```bash") >= 1, "Nested ```bash should be preserved (not matched as outer fence)" # The stripped content should still be valid JSON parsed = json.loads(stripped) # And the description should still have the full code example description = parsed["proposed_feature_description"] - assert "trap sigterm_handler TERM" in description, \ - "Code example should not be truncated" + assert "trap sigterm_handler TERM" in description, "Code example should not be truncated" diff --git a/backend/tests/agents/test_retry.py b/backend/tests/agents/test_retry.py index 064ae0e..92c5982 100644 --- a/backend/tests/agents/test_retry.py +++ b/backend/tests/agents/test_retry.py @@ -15,8 +15,8 @@ import pytest from app.agents.retry import ( - RETRYABLE_EXCEPTIONS, NON_RETRYABLE_EXCEPTIONS, + RETRYABLE_EXCEPTIONS, RetryError, create_llm_retry, is_permanent_failure, diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 57b5c78..6552f49 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -3,6 +3,7 @@ This module provides common fixtures used across all test modules. 
""" + import os import tempfile from unittest.mock import AsyncMock, MagicMock @@ -12,27 +13,22 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker +import app.services.kafka_producer as kafka_module from app.auth.utils import create_access_token -from app.database import Base, get_db, get_async_db +from app.database import Base, get_async_db, get_db from app.main import app -import app.services.kafka_producer as kafka_module + # Import all models to ensure they're registered with SQLAlchemy metadata from app.models import ( - Job, - OrgMembership, Organization, - User, Project, - ProjectMembership, ProjectType, - SpecVersion, SpecType, - Thread, - Comment, + User, ) -from app.services.user_service import UserService from app.services.org_service import OrgService from app.services.project_service import ProjectService +from app.services.user_service import UserService @pytest.fixture(autouse=True) @@ -57,6 +53,7 @@ def mock_kafka_producer(): # When start() is called, set producer to a truthy value async def mock_start(): mock_service.producer = MagicMock() + mock_service.start = AsyncMock(side_effect=mock_start) # Patch get_kafka_producer to return our mock @@ -89,6 +86,7 @@ def mock_sync_kafka_producer(): # When start() is called, set _producer to a truthy value def mock_start(): mock_producer._producer = MagicMock() + mock_producer.start = MagicMock(side_effect=mock_start) # Patch get_sync_kafka_producer to return our mock diff --git a/backend/tests/services/test_web_search_service.py b/backend/tests/services/test_web_search_service.py index e74a548..6d0acf1 100644 --- a/backend/tests/services/test_web_search_service.py +++ b/backend/tests/services/test_web_search_service.py @@ -5,8 +5,8 @@ import pytest from app.services.web_search_service import ( - WebSearchResult, WebSearchResponse, + WebSearchResult, WebSearchService, test_tavily_connection, ) @@ -133,9 +133,7 @@ async def test_search_success(self, service): async def test_search_failure(self, service): """Test search failure handling.""" with patch.object(service, "_get_async_client") as mock_client: - mock_client.return_value.search = AsyncMock( - side_effect=Exception("API rate limit exceeded") - ) + mock_client.return_value.search = AsyncMock(side_effect=Exception("API rate limit exceeded")) response = await service.search("test query") @@ -187,9 +185,7 @@ def test_search_sync_success(self, service): def test_search_sync_failure(self, service): """Test sync search failure handling.""" with patch.object(service, "_get_sync_client") as mock_client: - mock_client.return_value.search = MagicMock( - side_effect=Exception("Connection error") - ) + mock_client.return_value.search = MagicMock(side_effect=Exception("Connection error")) response = service.search_sync("test query") @@ -302,9 +298,7 @@ class TestTestTavilyConnection: @pytest.mark.asyncio async def test_connection_success(self): """Test successful Tavily connection.""" - with patch( - "app.services.web_search_service.WebSearchService" - ) as mock_service_class: + with patch("app.services.web_search_service.WebSearchService") as mock_service_class: mock_service = MagicMock() mock_service.search = AsyncMock( return_value=WebSearchResponse( @@ -325,9 +319,7 @@ async def test_connection_success(self): @pytest.mark.asyncio async def test_connection_search_failure(self): """Test Tavily connection when search fails.""" - with patch( - "app.services.web_search_service.WebSearchService" - ) as mock_service_class: + with 
patch("app.services.web_search_service.WebSearchService") as mock_service_class: mock_service = MagicMock() mock_service.search = AsyncMock( return_value=WebSearchResponse( @@ -348,13 +340,9 @@ async def test_connection_search_failure(self): @pytest.mark.asyncio async def test_connection_exception(self): """Test Tavily connection when exception occurs.""" - with patch( - "app.services.web_search_service.WebSearchService" - ) as mock_service_class: + with patch("app.services.web_search_service.WebSearchService") as mock_service_class: mock_service = MagicMock() - mock_service.search = AsyncMock( - side_effect=Exception("Network error") - ) + mock_service.search = AsyncMock(side_effect=Exception("Network error")) mock_service_class.return_value = mock_service success, message = await test_tavily_connection("test-api-key") diff --git a/backend/tests/test_activity_log_model.py b/backend/tests/test_activity_log_model.py index 1c98240..0760b11 100644 --- a/backend/tests/test_activity_log_model.py +++ b/backend/tests/test_activity_log_model.py @@ -2,16 +2,16 @@ Tests the activity logging functionality for tracking entity events. """ + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from datetime import datetime, timezone +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.activity_log import ActivityLog from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.activity_log import ActivityLog +from app.models.user import User from app.services.user_service import UserService @@ -223,11 +223,7 @@ def test_query_activity_by_event_type( test_db_session.commit() # Query for all SPEC_DRAFT_CREATED events - draft_logs = ( - test_db_session.query(ActivityLog) - .filter(ActivityLog.event_type == "SPEC_DRAFT_CREATED") - .all() - ) + draft_logs = test_db_session.query(ActivityLog).filter(ActivityLog.event_type == "SPEC_DRAFT_CREATED").all() assert len(draft_logs) == 2 @@ -300,11 +296,7 @@ def test_brainstorming_phase_events( test_db_session.commit() # Query phase activity - phase_logs = ( - test_db_session.query(ActivityLog) - .filter(ActivityLog.entity_id == phase_id) - .all() - ) + phase_logs = test_db_session.query(ActivityLog).filter(ActivityLog.entity_id == phase_id).all() assert len(phase_logs) == 2 @@ -335,16 +327,8 @@ def test_module_and_feature_events( test_db_session.commit() # Query by entity type - module_logs = ( - test_db_session.query(ActivityLog) - .filter(ActivityLog.entity_type == "module") - .all() - ) - feature_logs = ( - test_db_session.query(ActivityLog) - .filter(ActivityLog.entity_type == "feature") - .all() - ) + module_logs = test_db_session.query(ActivityLog).filter(ActivityLog.entity_type == "module").all() + feature_logs = test_db_session.query(ActivityLog).filter(ActivityLog.entity_type == "feature").all() assert len(module_logs) == 1 assert len(feature_logs) == 1 diff --git a/backend/tests/test_activity_log_service.py b/backend/tests/test_activity_log_service.py index 60966cd..4104936 100644 --- a/backend/tests/test_activity_log_service.py +++ b/backend/tests/test_activity_log_service.py @@ -2,21 +2,20 @@ Tests the service layer for activity log operations. 
""" + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project, ProjectStatus from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance from app.models.feature import Feature, FeatureProvenance, FeatureStatus -from app.models.activity_log import ActivityLog -from app.services.user_service import UserService +from app.models.module import Module, ModuleProvenance +from app.models.organization import Organization +from app.models.project import Project, ProjectStatus +from app.models.user import User from app.services.activity_log_service import ActivityLogService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_activity_log_wiring.py b/backend/tests/test_activity_log_wiring.py index 943f65e..7fce481 100644 --- a/backend/tests/test_activity_log_wiring.py +++ b/backend/tests/test_activity_log_wiring.py @@ -3,18 +3,16 @@ These tests verify that activity logs are created correctly when entities are created, updated, archived, deleted, or restored. """ -import pytest -from uuid import uuid4 from app.models.activity_log import ActivityLog from app.models.brainstorming_phase import BrainstormingPhaseType -from app.models.module import ModuleProvenance, ModuleType from app.models.feature import FeatureProvenance, FeatureType +from app.models.module import ModuleProvenance, ModuleType from app.services.activity_log_service import ActivityEventTypes from app.services.brainstorming_phase_service import BrainstormingPhaseService -from app.services.module_service import ModuleService -from app.services.feature_service import FeatureService from app.services.draft_version_service import DraftVersionService +from app.services.feature_service import FeatureService +from app.services.module_service import ModuleService class TestBrainstormingPhaseActivityLogging: @@ -655,11 +653,7 @@ def test_activity_actor_id_is_set_correctly( assert response.status_code == 201 module_data = response.json() - activity_log = ( - db.query(ActivityLog) - .filter(ActivityLog.entity_id == module_data["id"]) - .first() - ) + activity_log = db.query(ActivityLog).filter(ActivityLog.entity_id == module_data["id"]).first() assert activity_log is not None assert activity_log.actor_id == str(test_user.id) diff --git a/backend/tests/test_agent_api.py b/backend/tests/test_agent_api.py index 4c7f718..ff17a93 100644 --- a/backend/tests/test_agent_api.py +++ b/backend/tests/test_agent_api.py @@ -1,29 +1,26 @@ """Tests for Agent API endpoints exposing final data to coding agents.""" -import pytest from datetime import datetime, timezone from uuid import uuid4 -from sqlalchemy.orm import Session + +import pytest from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from app.models import ( - User, Organization, - Project, ProjectType, - OrgMembership, - OrgRole, + User, ) from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.final_spec import FinalSpec +from app.models.feature import Feature, FeatureProvenance, FeatureStatus from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec from app.models.module import Module, 
ModuleProvenance -from app.models.feature import Feature, FeatureProvenance, FeatureStatus -from app.services.user_service import UserService -from app.services.org_service import OrgService -from app.services.project_service import ProjectService -from app.services.api_key_service import ApiKeyService from app.schemas.api_key import ApiKeyCreate +from app.services.api_key_service import ApiKeyService +from app.services.project_service import ProjectService +from app.services.user_service import UserService class TestAgentAPI: @@ -182,9 +179,7 @@ def setup_project_with_data( } @pytest.fixture - def api_key( - self, db: Session, test_user: User, setup_project_with_data: dict - ) -> str: + def api_key(self, db: Session, test_user: User, setup_project_with_data: dict) -> str: """Create an API key for the test user.""" _, raw_key = ApiKeyService.create_api_key( db=db, @@ -449,4 +444,3 @@ def test_invalid_api_key_rejected( headers={"Authorization": "Bearer invalid-key-12345"}, ) assert response.status_code == 401 - diff --git a/backend/tests/test_analytics_cache.py b/backend/tests/test_analytics_cache.py index f813b9a..17a0b3d 100644 --- a/backend/tests/test_analytics_cache.py +++ b/backend/tests/test_analytics_cache.py @@ -279,9 +279,7 @@ def test_deletes_org_specific_keys(self): result = invalidate_analytics_cache(org_id="abc123") assert result == 1 - mock_redis.scan_iter.assert_called_once_with( - match="analytics:*:abc123*", count=100 - ) + mock_redis.scan_iter.assert_called_once_with(match="analytics:*:abc123*", count=100) def test_handles_empty_key_list(self): """Test handling when no keys match pattern.""" diff --git a/backend/tests/test_analytics_service.py b/backend/tests/test_analytics_service.py index 459ce1f..c92aeec 100644 --- a/backend/tests/test_analytics_service.py +++ b/backend/tests/test_analytics_service.py @@ -8,7 +8,6 @@ from decimal import Decimal from uuid import uuid4 -import pytest from sqlalchemy.orm import Session from app.models.daily_usage_summary import DailyUsageSummary @@ -414,9 +413,7 @@ def test_with_org_filter(self, db: Session): db.commit() # Query for org1 only - result = AnalyticsService.get_efficiency_metrics( - db, TimeRange.WEEKLY, org_id=org1.id - ) + result = AnalyticsService.get_efficiency_metrics(db, TimeRange.WEEKLY, org_id=org1.id) assert result.org_id == org1.id assert result.metrics.total_tokens == 500_000 @@ -977,9 +974,7 @@ class TestGetOrgEfficiencyOverview: def test_empty_organizations(self, db: Session): """Test handling when no organizations exist.""" # Query with a non-existent org_id - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY, org_id=uuid4() - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY, org_id=uuid4()) assert result.organizations == [] @@ -994,9 +989,7 @@ def test_single_org_overview(self, db: Session): db.add(org) db.commit() - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY, org_id=org.id - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY, org_id=org.id) assert len(result.organizations) == 1 org_efficiency = result.organizations[0] @@ -1039,9 +1032,7 @@ def test_includes_top_users(self, db: Session): db.add(summary) db.commit() - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY, org_id=org.id - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY, org_id=org.id) org_efficiency = result.organizations[0] # Users list may or may not include the user 
depending on whether @@ -1092,9 +1083,7 @@ def test_includes_top_projects(self, db: Session): db.add(summary) db.commit() - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY, org_id=org.id - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY, org_id=org.id) org_efficiency = result.organizations[0] assert isinstance(org_efficiency.projects, list) @@ -1145,9 +1134,7 @@ def test_respects_limit_parameters(self, db: Session): db.commit() - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY, org_id=org.id, limit_projects=3 - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY, org_id=org.id, limit_projects=3) org_efficiency = result.organizations[0] # Should be capped at 3 projects @@ -1167,9 +1154,7 @@ def test_multiple_orgs_platform_wide(self, db: Session): orgs.append(org) db.commit() - result = AnalyticsService.get_org_efficiency_overview( - db, TimeRange.MONTHLY - ) + result = AnalyticsService.get_org_efficiency_overview(db, TimeRange.MONTHLY) # Should include all 3 orgs (plus any pre-existing orgs in test DB) assert len(result.organizations) >= 3 diff --git a/backend/tests/test_api_key_encryption.py b/backend/tests/test_api_key_encryption.py index 4eb9d17..ce7c276 100644 --- a/backend/tests/test_api_key_encryption.py +++ b/backend/tests/test_api_key_encryption.py @@ -3,9 +3,10 @@ Tests encryption, decryption, and verification of API keys. """ + import pytest -from app.auth.api_key_utils import generate_api_key, hash_api_key, hash_api_key_sha256, verify_api_key, get_key_preview +from app.auth.api_key_utils import generate_api_key, get_key_preview, hash_api_key, hash_api_key_sha256, verify_api_key @pytest.fixture(autouse=True) @@ -69,7 +70,7 @@ def test_encrypt_api_key_returns_different_string(self): def test_decrypt_api_key_returns_original(self): """Test that decryption returns the original key.""" - from app.auth.encryption_utils import encrypt_api_key, decrypt_api_key + from app.auth.encryption_utils import decrypt_api_key, encrypt_api_key key = generate_api_key() encrypted = encrypt_api_key(key) @@ -130,25 +131,21 @@ class TestApiKeyCreationWithLookupHash: def test_create_api_key_sets_lookup_hash(self, db, test_user): """Test that creating an API key populates key_lookup_hash.""" - from app.services.api_key_service import ApiKeyService from app.schemas.api_key import ApiKeyCreate + from app.services.api_key_service import ApiKeyService - api_key, raw_key = ApiKeyService.create_api_key( - db, user_id=test_user.id, data=ApiKeyCreate(name="test-key") - ) + api_key, raw_key = ApiKeyService.create_api_key(db, user_id=test_user.id, data=ApiKeyCreate(name="test-key")) assert api_key.key_lookup_hash is not None assert api_key.key_lookup_hash == hash_api_key_sha256(raw_key) def test_lookup_hash_matches_sha256_of_raw_key(self, db, test_user): """Test that the stored lookup hash matches SHA-256 of the raw key.""" - from app.services.api_key_service import ApiKeyService - from app.schemas.api_key import ApiKeyCreate from app.models.api_key import ApiKey + from app.schemas.api_key import ApiKeyCreate + from app.services.api_key_service import ApiKeyService - _, raw_key = ApiKeyService.create_api_key( - db, user_id=test_user.id, data=ApiKeyCreate(name="test-key") - ) + _, raw_key = ApiKeyService.create_api_key(db, user_id=test_user.id, data=ApiKeyCreate(name="test-key")) # Query directly by lookup hash expected_hash = hash_api_key_sha256(raw_key) @@ -177,13 +174,11 @@ class 
TestApiKeyDeletion: def test_delete_revoked_key(self, db, test_user): """Test that a revoked key can be permanently deleted.""" - from app.services.api_key_service import ApiKeyService from app.schemas.api_key import ApiKeyCreate + from app.services.api_key_service import ApiKeyService # Create and revoke a key - api_key, _ = ApiKeyService.create_api_key( - db, user_id=test_user.id, data=ApiKeyCreate(name="test-key") - ) + api_key, _ = ApiKeyService.create_api_key(db, user_id=test_user.id, data=ApiKeyCreate(name="test-key")) ApiKeyService.revoke_api_key(db, key_id=api_key.id, user_id=test_user.id) # Delete should succeed @@ -195,14 +190,13 @@ def test_delete_revoked_key(self, db, test_user): def test_delete_active_key_fails(self, db, test_user): """Test that an active (non-revoked) key cannot be deleted.""" - from app.services.api_key_service import ApiKeyService - from app.schemas.api_key import ApiKeyCreate from fastapi import HTTPException + from app.schemas.api_key import ApiKeyCreate + from app.services.api_key_service import ApiKeyService + # Create a key but don't revoke it - api_key, _ = ApiKeyService.create_api_key( - db, user_id=test_user.id, data=ApiKeyCreate(name="test-key") - ) + api_key, _ = ApiKeyService.create_api_key(db, user_id=test_user.id, data=ApiKeyCreate(name="test-key")) # Delete should fail with 400 with pytest.raises(HTTPException) as exc_info: @@ -212,10 +206,12 @@ def test_delete_active_key_fails(self, db, test_user): def test_delete_nonexistent_key_fails(self, db, test_user): """Test that deleting a nonexistent key returns 404.""" - from app.services.api_key_service import ApiKeyService - from fastapi import HTTPException from uuid import uuid4 + from fastapi import HTTPException + + from app.services.api_key_service import ApiKeyService + with pytest.raises(HTTPException) as exc_info: ApiKeyService.delete_api_key(db, key_id=uuid4(), user_id=test_user.id) assert exc_info.value.status_code == 404 diff --git a/backend/tests/test_auth_service.py b/backend/tests/test_auth_service.py index 609a763..cf30641 100644 --- a/backend/tests/test_auth_service.py +++ b/backend/tests/test_auth_service.py @@ -4,14 +4,16 @@ Verifies that the AuthService correctly maps OAuth identities to internal users, handles user creation, and links identities to existing users. 
""" -import pytest + from unittest.mock import patch + +import pytest from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker +from app.auth.service import PROVIDER_CONFIG, AuthService from app.database import Base -from app.auth.service import AuthService, PROVIDER_CONFIG -from app.models import User, IdentityProvider, IdentityProviderType, UserIdentity +from app.models import IdentityProvider, IdentityProviderType, User from app.schemas.oauth import NormalizedUserInfo @@ -111,9 +113,7 @@ class TestGetOrCreateIdentityProvider: def test_create_google_provider(self, test_db_session): """Test creating a new Google identity provider.""" - provider = AuthService.get_or_create_identity_provider( - test_db_session, "google" - ) + provider = AuthService.get_or_create_identity_provider(test_db_session, "google") assert provider.id is not None assert provider.slug == "google" @@ -123,9 +123,7 @@ def test_create_google_provider(self, test_db_session): def test_create_github_provider(self, test_db_session): """Test creating a new GitHub identity provider.""" - provider = AuthService.get_or_create_identity_provider( - test_db_session, "github" - ) + provider = AuthService.get_or_create_identity_provider(test_db_session, "github") assert provider.id is not None assert provider.slug == "github" @@ -136,23 +134,17 @@ def test_create_github_provider(self, test_db_session): def test_get_existing_provider(self, test_db_session): """Test getting an existing provider returns the same instance.""" # Create provider - provider1 = AuthService.get_or_create_identity_provider( - test_db_session, "google" - ) + provider1 = AuthService.get_or_create_identity_provider(test_db_session, "google") # Get again - should return same provider - provider2 = AuthService.get_or_create_identity_provider( - test_db_session, "google" - ) + provider2 = AuthService.get_or_create_identity_provider(test_db_session, "google") assert provider1.id == provider2.id def test_unknown_provider_raises_error(self, test_db_session): """Test that unknown provider slug raises ValueError.""" with pytest.raises(ValueError) as exc_info: - AuthService.get_or_create_identity_provider( - test_db_session, "unknown-provider" - ) + AuthService.get_or_create_identity_provider(test_db_session, "unknown-provider") assert "Unknown provider" in str(exc_info.value) assert "unknown-provider" in str(exc_info.value) @@ -168,9 +160,7 @@ def test_provider_config_completeness(self): class TestUpsertUserFromIdentity: """Tests for AuthService.upsert_user_from_identity.""" - def test_new_identity_new_email_creates_user( - self, test_db_session, google_normalized_info - ): + def test_new_identity_new_email_creates_user(self, test_db_session, google_normalized_info): """Case 1: New provider identity with new email creates new user.""" user, is_new_user = AuthService.upsert_user_from_identity( test_db_session, @@ -191,9 +181,7 @@ def test_new_identity_new_email_creates_user( assert identities[0].subject == "google-12345" assert identities[0].provider.slug == "google" - def test_existing_email_links_identity( - self, test_db_session, sample_user, google_normalized_info - ): + def test_existing_email_links_identity(self, test_db_session, sample_user, google_normalized_info): """Case 2: Existing user with same email, new identity links to existing user.""" # Update normalized info to match existing user's email google_normalized_info.email = sample_user.email @@ -212,9 +200,7 @@ def test_existing_email_links_identity( assert len(identities) == 1 
assert identities[0].subject == "google-12345" - def test_existing_identity_returns_same_user( - self, test_db_session, google_normalized_info - ): + def test_existing_identity_returns_same_user(self, test_db_session, google_normalized_info): """Case 3: Existing identity (same provider, same subject) returns same user.""" # First upsert - creates user user1, is_new_user1 = AuthService.upsert_user_from_identity( @@ -270,9 +256,7 @@ def test_new_identity_no_email_creates_user(self, test_db_session): assert user.display_name == "No Email User" assert user.password_hash is None - def test_user_can_have_multiple_identities( - self, test_db_session, google_normalized_info, github_normalized_info - ): + def test_user_can_have_multiple_identities(self, test_db_session, google_normalized_info, github_normalized_info): """Test that a user can link multiple provider identities.""" # Use same email for both shared_email = "multiauth@example.com" @@ -302,9 +286,7 @@ def test_user_can_have_multiple_identities( providers = {i.provider.slug for i in identities} assert providers == {"google", "github"} - def test_unknown_provider_raises_error( - self, test_db_session, google_normalized_info - ): + def test_unknown_provider_raises_error(self, test_db_session, google_normalized_info): """Test that unknown provider slug raises ValueError.""" with pytest.raises(ValueError) as exc_info: AuthService.upsert_user_from_identity( @@ -315,9 +297,7 @@ def test_unknown_provider_raises_error( assert "Unknown provider" in str(exc_info.value) - def test_identity_updates_on_subsequent_login( - self, test_db_session, google_normalized_info - ): + def test_identity_updates_on_subsequent_login(self, test_db_session, google_normalized_info): """Test that identity profile is updated on subsequent logins.""" # First login user1, _ = AuthService.upsert_user_from_identity( @@ -367,20 +347,14 @@ def test_provider_created_if_not_exists(self, test_db_session, google_normalized class TestGetUserIdentities: """Tests for AuthService.get_user_identities.""" - def test_get_identities_for_user( - self, test_db_session, google_normalized_info, github_normalized_info - ): + def test_get_identities_for_user(self, test_db_session, google_normalized_info, github_normalized_info): """Test getting all identities for a user.""" # Create user with two identities google_normalized_info.email = "multi@example.com" github_normalized_info.email = "multi@example.com" - AuthService.upsert_user_from_identity( - test_db_session, "google", google_normalized_info - ) - user, _ = AuthService.upsert_user_from_identity( - test_db_session, "github", github_normalized_info - ) + AuthService.upsert_user_from_identity(test_db_session, "google", google_normalized_info) + user, _ = AuthService.upsert_user_from_identity(test_db_session, "github", github_normalized_info) identities = AuthService.get_user_identities(test_db_session, user.id) diff --git a/backend/tests/test_auth_utils.py b/backend/tests/test_auth_utils.py index e4347ec..6d91cd3 100644 --- a/backend/tests/test_auth_utils.py +++ b/backend/tests/test_auth_utils.py @@ -3,16 +3,17 @@ Tests password hashing, verification, and JWT token operations. 
""" -from datetime import timedelta, datetime, UTC + +from datetime import UTC, datetime, timedelta import pytest from jose import JWTError from app.auth.utils import ( - hash_password, - verify_password, create_access_token, decode_access_token, + hash_password, + verify_password, ) diff --git a/backend/tests/test_brainstorm_agent.py b/backend/tests/test_brainstorm_agent.py index 348e6cc..d19999a 100644 --- a/backend/tests/test_brainstorm_agent.py +++ b/backend/tests/test_brainstorm_agent.py @@ -1,22 +1,22 @@ """Tests for the brainstorm agent.""" -import pytest + from uuid import uuid4 +from app.agents.brainstorm.generator import strip_markdown_json from app.agents.brainstorm.types import ( + MAX_ASPECTS, + MAX_MCQ_CHOICES, + MAX_QUESTIONS_PER_ASPECT, + MIN_ASPECTS, + MIN_MCQ_CHOICES, + MIN_QUESTIONS_PER_ASPECT, BrainstormContext, BrainstormResult, GeneratedAspect, GeneratedClarificationQuestion, GeneratedMCQ, validate_brainstorm_result, - MIN_ASPECTS, - MAX_ASPECTS, - MIN_QUESTIONS_PER_ASPECT, - MAX_QUESTIONS_PER_ASPECT, - MIN_MCQ_CHOICES, - MAX_MCQ_CHOICES, ) -from app.agents.brainstorm.generator import strip_markdown_json class TestBrainstormTypes: @@ -136,10 +136,7 @@ def _make_valid_result( for j in range(questions_per_aspect): mcq = GeneratedMCQ( question_text=f"Question {j}?", - choices=[ - {"id": chr(ord('a') + k), "label": f"Choice {k}"} - for k in range(mcq_choices) - ], + choices=[{"id": chr(ord("a") + k), "label": f"Choice {k}"} for k in range(mcq_choices)], ) question = GeneratedClarificationQuestion( title=f"Question {i}-{j}", @@ -250,15 +247,15 @@ def test_strips_whitespace(self): def test_complex_json(self): """Test with complex nested JSON.""" - text = '''```json + text = """```json { "aspects": [ {"title": "Test", "questions": []} ], "notes": [] } -```''' +```""" result = strip_markdown_json(text) assert '"aspects"' in result - assert result.startswith('{') - assert result.endswith('}') + assert result.startswith("{") + assert result.endswith("}") diff --git a/backend/tests/test_brainstorm_generation.py b/backend/tests/test_brainstorm_generation.py index 99a03cb..e76e8fb 100644 --- a/backend/tests/test_brainstorm_generation.py +++ b/backend/tests/test_brainstorm_generation.py @@ -1,14 +1,12 @@ """Tests for brainstorm generation (integration tests).""" -import pytest -from unittest.mock import AsyncMock, MagicMock, patch + from uuid import uuid4 +import pytest + from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance -from app.models.feature import Feature, FeatureProvenance -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType from app.models.job import JobType +from app.models.thread import ContextType from app.services.brainstorming_phase_service import BrainstormingPhaseService diff --git a/backend/tests/test_brainstorming_phase_endpoints.py b/backend/tests/test_brainstorming_phase_endpoints.py index 3f0a384..9d40c55 100644 --- a/backend/tests/test_brainstorming_phase_endpoints.py +++ b/backend/tests/test_brainstorming_phase_endpoints.py @@ -1,5 +1,5 @@ """Tests for brainstorming phase REST API endpoints.""" -import pytest + from uuid import uuid4 from app.models.brainstorming_phase import BrainstormingPhaseType diff --git a/backend/tests/test_brainstorming_phase_models.py b/backend/tests/test_brainstorming_phase_models.py index cbb19b1..adeadd8 100644 --- a/backend/tests/test_brainstorming_phase_models.py +++ 
b/backend/tests/test_brainstorming_phase_models.py @@ -2,18 +2,20 @@ Tests the new Phase 7 models for brainstorming workflows. """ + +from datetime import datetime, timezone + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from datetime import datetime, timezone +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project, ProjectStatus from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance from app.models.feature import Feature, FeatureProvenance, FeatureStatus +from app.models.module import Module, ModuleProvenance +from app.models.organization import Organization +from app.models.project import Project, ProjectStatus +from app.models.user import User from app.services.user_service import UserService @@ -126,9 +128,7 @@ def test_create_feature_specific_brainstorming_phase( assert phase.phase_type == BrainstormingPhaseType.FEATURE_SPECIFIC - def test_project_relationship( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_project_relationship(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test the relationship between BrainstormingPhase and Project.""" phase = BrainstormingPhase( project_id=sample_project.id, @@ -143,9 +143,7 @@ def test_project_relationship( assert len(sample_project.brainstorming_phases) == 1 assert sample_project.brainstorming_phases[0].id == phase.id - def test_cascade_delete_on_project( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_cascade_delete_on_project(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test cascade delete when project is deleted.""" phase = BrainstormingPhase( project_id=sample_project.id, @@ -165,9 +163,7 @@ def test_cascade_delete_on_project( deleted_phase = test_db_session.query(BrainstormingPhase).filter_by(id=phase_id).first() assert deleted_phase is None - def test_multiple_phases_per_project( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_multiple_phases_per_project(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test creating multiple brainstorming phases for one project.""" phase1 = BrainstormingPhase( project_id=sample_project.id, @@ -239,9 +235,7 @@ def test_create_system_module( assert module.order_index == 0 assert module.archived_at is None - def test_create_user_module( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_create_user_module(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test creating a user-created module without brainstorming phase.""" module = Module( project_id=sample_project.id, @@ -261,9 +255,7 @@ def test_create_user_module( assert module.provenance == ModuleProvenance.USER assert module.brainstorming_phase_id is None - def test_archive_module( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_archive_module(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test archiving a module.""" module = Module( project_id=sample_project.id, @@ -284,9 +276,7 @@ def test_archive_module( assert module.archived_at is not None - def test_project_relationship( - self, 
test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_project_relationship(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test the relationship between Module and Project.""" module = Module( project_id=sample_project.id, @@ -360,9 +350,7 @@ def test_cascade_delete_on_brainstorming_phase( assert remaining_module is not None assert remaining_module.brainstorming_phase_id is None - def test_order_index( - self, test_db_session: Session, sample_project: Project, sample_user: User - ): + def test_order_index(self, test_db_session: Session, sample_project: Project, sample_user: User): """Test module ordering with order_index.""" module1 = Module( project_id=sample_project.id, @@ -403,9 +391,7 @@ class TestFeature: """Tests for Feature model.""" @pytest.fixture - def sample_module( - self, test_db_session: Session, sample_project: Project, sample_user: User - ) -> Module: + def sample_module(self, test_db_session: Session, sample_project: Project, sample_user: User) -> Module: """Create a sample module for testing.""" module = Module( project_id=sample_project.id, @@ -421,9 +407,7 @@ def sample_module( test_db_session.refresh(module) return module - def test_create_system_feature( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_create_system_feature(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test creating a system-generated feature.""" feature = Feature( module_id=sample_module.id, @@ -447,9 +431,7 @@ def test_create_system_feature( assert feature.status == FeatureStatus.ACTIVE assert feature.archived_at is None - def test_create_user_feature( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_create_user_feature(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test creating a user-created feature.""" feature = Feature( module_id=sample_module.id, @@ -465,9 +447,7 @@ def test_create_user_feature( assert feature.provenance == FeatureProvenance.USER - def test_feature_with_all_fields( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_feature_with_all_fields(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test creating a feature with all fields populated.""" feature = Feature( module_id=sample_module.id, @@ -488,9 +468,7 @@ def test_feature_with_all_fields( assert feature.prompt_plan_text == "Detailed implementation plan." assert feature.implementation_notes == "Notes from the coding agent." 
- def test_archive_feature( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_archive_feature(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test archiving a feature.""" feature = Feature( module_id=sample_module.id, @@ -512,9 +490,7 @@ def test_archive_feature( assert feature.status == FeatureStatus.ARCHIVED assert feature.archived_at is not None - def test_restore_feature( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_restore_feature(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test restoring an archived feature.""" feature = Feature( module_id=sample_module.id, @@ -540,9 +516,7 @@ def test_restore_feature( assert feature.provenance == FeatureProvenance.RESTORED assert feature.archived_at is None - def test_module_relationship( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_module_relationship(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test the relationship between Feature and Module.""" feature = Feature( module_id=sample_module.id, @@ -559,9 +533,7 @@ def test_module_relationship( assert len(sample_module.features) == 1 assert sample_module.features[0].id == feature.id - def test_cascade_delete_on_module( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_cascade_delete_on_module(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test cascade delete when module is deleted.""" feature = Feature( module_id=sample_module.id, @@ -583,9 +555,7 @@ def test_cascade_delete_on_module( deleted_feature = test_db_session.query(Feature).filter_by(id=feature_id).first() assert deleted_feature is None - def test_multiple_features_per_module( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_multiple_features_per_module(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test creating multiple features in one module.""" feature1 = Feature( module_id=sample_module.id, @@ -610,9 +580,7 @@ def test_multiple_features_per_module( assert len(sample_module.features) == 2 - def test_feature_key_uniqueness_query( - self, test_db_session: Session, sample_module: Module, sample_user: User - ): + def test_feature_key_uniqueness_query(self, test_db_session: Session, sample_module: Module, sample_user: User): """Test querying for features by feature_key.""" feature = Feature( module_id=sample_module.id, @@ -626,11 +594,7 @@ def test_feature_key_uniqueness_query( test_db_session.commit() # Query by feature_key - found_feature = ( - test_db_session.query(Feature) - .filter(Feature.feature_key == "TESTPROJ-010") - .first() - ) + found_feature = test_db_session.query(Feature).filter(Feature.feature_key == "TESTPROJ-010").first() assert found_feature is not None assert found_feature.id == feature.id diff --git a/backend/tests/test_brainstorming_phase_service.py b/backend/tests/test_brainstorming_phase_service.py index 7e8bc9f..322c670 100644 --- a/backend/tests/test_brainstorming_phase_service.py +++ b/backend/tests/test_brainstorming_phase_service.py @@ -2,25 +2,27 @@ Tests the service layer for brainstorming phase operations. 
""" + +from uuid import uuid4 + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhaseType +from app.models.feature import Feature, FeatureProvenance, FeatureType +from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.organization import Organization +from app.models.phase_container import PhaseContainer from app.models.project import Project, ProjectStatus -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance, ModuleType -from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType -from app.models.spec_version import SpecVersion, SpecType -from app.models.final_spec import FinalSpec -from app.models.final_prompt_plan import FinalPromptPlan from app.models.prompt_plan_coverage import PromptPlanCoverageReport -from app.services.user_service import UserService +from app.models.spec_version import SpecType, SpecVersion +from app.models.user import User from app.services.brainstorming_phase_service import BrainstormingPhaseService, load_sibling_phases -from app.models.phase_container import PhaseContainer +from app.services.user_service import UserService @pytest.fixture @@ -313,10 +315,13 @@ def test_delete_existing_phase( ) assert result is True - assert BrainstormingPhaseService.get_brainstorming_phase( - db=test_db_session, - phase_id=phase_id, - ) is None + assert ( + BrainstormingPhaseService.get_brainstorming_phase( + db=test_db_session, + phase_id=phase_id, + ) + is None + ) def test_delete_nonexistent_phase( self, @@ -1286,6 +1291,7 @@ def test_ignores_archived_implementation_modules( ): """Test that archived IMPLEMENTATION modules are ignored.""" from datetime import datetime, timezone + from app.services.brainstorming_phase_service import ( _phase_has_implementations_for_analysis, ) @@ -1422,9 +1428,9 @@ def test_aspect_generator_includes_phase_decisions(self): """Aspect generator _build_sibling_phases_section includes decisions in output.""" from app.agents.brainstorm_conversation.aspect_generator import AspectGeneratorAgent from app.agents.brainstorm_conversation.types import ( - SiblingPhasesContext, - SiblingPhaseContext, CrossPhaseDecision, + SiblingPhaseContext, + SiblingPhasesContext, ) generator = AspectGeneratorAgent.__new__(AspectGeneratorAgent) @@ -1467,9 +1473,9 @@ def test_question_generator_filters_by_aspect_relevance(self): """Question generator filters decisions by aspect keyword overlap.""" from app.agents.brainstorm_conversation.question_generator import QuestionGeneratorAgent from app.agents.brainstorm_conversation.types import ( - SiblingPhasesContext, - SiblingPhaseContext, CrossPhaseDecision, + SiblingPhaseContext, + SiblingPhasesContext, ) generator = QuestionGeneratorAgent.__new__(QuestionGeneratorAgent) diff --git a/backend/tests/test_brainstorming_phase_service_summary.py b/backend/tests/test_brainstorming_phase_service_summary.py index b74621b..e518063 100644 --- a/backend/tests/test_brainstorming_phase_service_summary.py +++ b/backend/tests/test_brainstorming_phase_service_summary.py @@ -1,5 +1,5 @@ """Tests for brainstorming phase service summary extraction.""" -import 
pytest + from unittest.mock import MagicMock from app.services.brainstorming_phase_service import _build_spec_summary_from_json diff --git a/backend/tests/test_brainstorming_preflight.py b/backend/tests/test_brainstorming_preflight.py index 34a8a86..096c081 100644 --- a/backend/tests/test_brainstorming_preflight.py +++ b/backend/tests/test_brainstorming_preflight.py @@ -1,22 +1,23 @@ """Tests for brainstorming phase generation preflight checks.""" -import pytest -from uuid import uuid4 + from datetime import datetime, timezone from unittest.mock import patch +from uuid import uuid4 +import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleType, ModuleProvenance -from app.models.feature import Feature, FeatureType, FeatureVisibilityStatus, FeaturePriority, FeatureProvenance -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.project import Project +from app.models.feature import Feature, FeaturePriority, FeatureProvenance, FeatureType, FeatureVisibilityStatus +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.organization import Organization +from app.models.project import Project +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType from app.models.user import User +from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.user_service import UserService @@ -127,9 +128,9 @@ def create_phase_with_questions( module = Module( project_id=project.id, brainstorming_phase_id=phase.id, - title=f"Aspect {a+1}", - description=f"Description for aspect {a+1}", - module_key=f"TEST-{a+1:03d}", + title=f"Aspect {a + 1}", + description=f"Description for aspect {a + 1}", + module_key=f"TEST-{a + 1:03d}", module_key_number=a + 1, order_index=a, module_type=ModuleType.CONVERSATION, @@ -142,8 +143,8 @@ def create_phase_with_questions( for q in range(questions_per_aspect): feature = Feature( module_id=module.id, - title=f"Question {a+1}.{q+1}", - feature_key=f"Q-{a+1:03d}-{q+1:03d}", + title=f"Question {a + 1}.{q + 1}", + feature_key=f"Q-{a + 1:03d}-{q + 1:03d}", feature_key_number=(a * questions_per_aspect) + q + 1, feature_type=FeatureType.CONVERSATION, visibility_status=visibility_status, @@ -257,8 +258,8 @@ def test_at_pending_limit_skips(self, test_db, sample_user, sample_org, sample_p for i in range(100): feature = Feature( module_id=module.id, - title=f"Question {i+1}", - feature_key=f"Q-{i+1:03d}", + title=f"Question {i + 1}", + feature_key=f"Q-{i + 1:03d}", feature_key_number=i + 1, feature_type=FeatureType.CONVERSATION, visibility_status=FeatureVisibilityStatus.PENDING, # PENDING, not ACTIVE @@ -290,13 +291,11 @@ def test_fully_explored_skips(self, test_db, sample_user, sample_org, sample_pro ) # Delete one to get exactly 8 questions - feature_to_delete = test_db.query(Feature).filter( - Feature.visibility_status == FeatureVisibilityStatus.ACTIVE - ).first() + feature_to_delete = ( + test_db.query(Feature).filter(Feature.visibility_status == FeatureVisibilityStatus.ACTIVE).first() + ) if feature_to_delete: - thread = test_db.query(Thread).filter( - 
Thread.context_id == str(feature_to_delete.id) - ).first() + thread = test_db.query(Thread).filter(Thread.context_id == str(feature_to_delete.id)).first() if thread: test_db.query(ThreadItem).filter(ThreadItem.thread_id == thread.id).delete() test_db.delete(thread) @@ -345,7 +344,9 @@ def test_first_generation_does_not_skip(self, test_db, sample_user, sample_org, assert result["should_skip"] is False assert result["skip_code"] is None - def test_sufficient_engagement_does_not_skip(self, test_db, sample_user, sample_org, sample_project, mock_llm_config): + def test_sufficient_engagement_does_not_skip( + self, test_db, sample_user, sample_org, sample_project, mock_llm_config + ): """Phase with >=15% engagement should NOT skip.""" # Create phase with 10 questions, 2 answered (20% engagement) phase = create_phase_with_questions( @@ -448,8 +449,8 @@ def test_counts_answered_mcqs(self, test_db, sample_user, sample_org, sample_pro for i in range(3): feature = Feature( module_id=module.id, - title=f"Question {i+1}", - feature_key=f"Q-{i+1:03d}", + title=f"Question {i + 1}", + feature_key=f"Q-{i + 1:03d}", feature_key_number=i + 1, feature_type=FeatureType.CONVERSATION, visibility_status=FeatureVisibilityStatus.ACTIVE, diff --git a/backend/tests/test_code_explorer_client.py b/backend/tests/test_code_explorer_client.py index 76bfcf5..8c9e558 100644 --- a/backend/tests/test_code_explorer_client.py +++ b/backend/tests/test_code_explorer_client.py @@ -1,12 +1,11 @@ """Tests for CodeExplorerClient service.""" -import pytest from unittest.mock import AsyncMock, patch +import pytest + from app.services.code_explorer_client import ( CodeExplorerClient, - ImplementationAnalysis, - ImplementationAnalysisResult, ) @@ -165,7 +164,10 @@ def test_parse_with_phase_mentions(self): assert len(analyses) == 1 assert analyses[0]["phase_id"] == "phase-1" assert analyses[0]["phase_title"] == "Login Feature" - assert "Login Feature" in analyses[0]["implementation_summary"] or "OAuth2" in analyses[0]["implementation_summary"] + assert ( + "Login Feature" in analyses[0]["implementation_summary"] + or "OAuth2" in analyses[0]["implementation_summary"] + ) def test_parse_truncates_long_output(self): """Test that long output is truncated to 500 chars.""" diff --git a/backend/tests/test_credit_formatter.py b/backend/tests/test_credit_formatter.py index eb3f3d9..a3acf10 100644 --- a/backend/tests/test_credit_formatter.py +++ b/backend/tests/test_credit_formatter.py @@ -2,8 +2,6 @@ Tests for credit_formatter utility functions. """ -import pytest - from app.utils.credit_formatter import ( TOKENS_PER_CREDIT, format_credits, diff --git a/backend/tests/test_daily_aggregation_job.py b/backend/tests/test_daily_aggregation_job.py index c0710b0..7432e77 100644 --- a/backend/tests/test_daily_aggregation_job.py +++ b/backend/tests/test_daily_aggregation_job.py @@ -1,11 +1,11 @@ """ Tests for the daily usage aggregation scheduler job. 
""" -import pytest -from datetime import datetime, timezone, date, timedelta -from uuid import uuid4 + +from datetime import date, datetime, timedelta, timezone from decimal import Decimal -from unittest.mock import patch, MagicMock +from unittest.mock import patch +from uuid import uuid4 from app.models.daily_usage_summary import DailyUsageSummary from app.models.llm_usage_log import LLMUsageLog @@ -31,7 +31,7 @@ def test_aggregate_for_date_creates_summaries(self, db, test_org, test_user, tes model="test/model", prompt_tokens=100 * (i + 1), completion_tokens=50 * (i + 1), - cost_usd=Decimal(f"0.00{i+1}"), + cost_usd=Decimal(f"0.00{i + 1}"), created_at=log_time, ) db.add(log) @@ -46,11 +46,15 @@ def test_aggregate_for_date_creates_summaries(self, db, test_org, test_user, tes assert records == 1 # One summary for this user/project/org combo # Verify the summary - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - DailyUsageSummary.project_id == test_project.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + DailyUsageSummary.project_id == test_project.id, + ) + .first() + ) assert summary is not None # Total: (100+200+300) + (50+100+150) = 600 + 300 = 900 @@ -103,11 +107,15 @@ def test_aggregate_for_date_updates_existing(self, db, test_org, test_user, test assert records == 1 # Verify the summary was updated with new values (replaced, not accumulated) - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - DailyUsageSummary.project_id == test_project.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + DailyUsageSummary.project_id == test_project.id, + ) + .first() + ) assert summary is not None assert summary.total_tokens == 300 # 200 + 100 from the single log @@ -250,16 +258,8 @@ def test_run_aggregation_with_retry_success(self, db): target_date = date.today() - timedelta(days=1) - with patch.object( - DailyUsageSummaryService, - 'aggregate_for_date', - return_value=5 - ): - result = _run_aggregation_with_retry( - db, - target_date, - retry_delays=[1, 2, 3] - ) + with patch.object(DailyUsageSummaryService, "aggregate_for_date", return_value=5): + result = _run_aggregation_with_retry(db, target_date, retry_delays=[1, 2, 3]) assert result is True @@ -271,16 +271,10 @@ def test_run_aggregation_with_retry_retries_on_failure(self, db): # First two calls fail, third succeeds with patch.object( - DailyUsageSummaryService, - 'aggregate_for_date', - side_effect=[Exception("Fail 1"), Exception("Fail 2"), 5] + DailyUsageSummaryService, "aggregate_for_date", side_effect=[Exception("Fail 1"), Exception("Fail 2"), 5] ): - with patch('workers.scheduler.time.sleep') as mock_sleep: - result = _run_aggregation_with_retry( - db, - target_date, - retry_delays=[0.1, 0.2, 0.3] - ) + with patch("workers.scheduler.time.sleep") as mock_sleep: + result = _run_aggregation_with_retry(db, target_date, retry_delays=[0.1, 0.2, 0.3]) assert result is True assert mock_sleep.call_count == 2 @@ -292,16 +286,8 @@ def test_run_aggregation_with_retry_exhausted(self, db): target_date = date.today() - timedelta(days=1) # All calls fail - with patch.object( - DailyUsageSummaryService, - 'aggregate_for_date', - side_effect=Exception("Always fail") - ): 
- with patch('workers.scheduler.time.sleep'): - result = _run_aggregation_with_retry( - db, - target_date, - retry_delays=[0.1, 0.2] - ) + with patch.object(DailyUsageSummaryService, "aggregate_for_date", side_effect=Exception("Always fail")): + with patch("workers.scheduler.time.sleep"): + result = _run_aggregation_with_retry(db, target_date, retry_delays=[0.1, 0.2]) assert result is False diff --git a/backend/tests/test_daily_usage_summary_service.py b/backend/tests/test_daily_usage_summary_service.py index 0f95fed..e1b08db 100644 --- a/backend/tests/test_daily_usage_summary_service.py +++ b/backend/tests/test_daily_usage_summary_service.py @@ -1,13 +1,12 @@ """ Tests for DailyUsageSummaryService. """ -import pytest -from datetime import datetime, timezone, date, timedelta -from uuid import uuid4 + +from datetime import date, datetime, timedelta, timezone from decimal import Decimal +from uuid import uuid4 -from app.models.daily_usage_summary import DailyUsageSummary, SENTINEL_UUID -from app.models.llm_usage_log import LLMUsageLog +from app.models.daily_usage_summary import SENTINEL_UUID, DailyUsageSummary from app.services.daily_usage_summary_service import DailyUsageSummaryService @@ -276,9 +275,7 @@ def test_delete_summaries_before_date(self, db, test_org): assert deleted == 1 # Verify remaining summaries - remaining = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id - ).count() + remaining = db.query(DailyUsageSummary).filter(DailyUsageSummary.org_id == test_org.id).count() assert remaining == 2 def test_empty_date_range_returns_empty_list(self, db, test_org): diff --git a/backend/tests/test_daily_usage_trigger.py b/backend/tests/test_daily_usage_trigger.py index 4b792b6..5fdd487 100644 --- a/backend/tests/test_daily_usage_trigger.py +++ b/backend/tests/test_daily_usage_trigger.py @@ -4,12 +4,11 @@ These tests require PostgreSQL since they test database triggers. When running on SQLite, the tests will be skipped automatically. 
""" -import pytest -from datetime import datetime, timezone, date, timedelta -from uuid import uuid4 + +from datetime import datetime, timedelta, timezone from decimal import Decimal -from sqlalchemy import text +import pytest from app.models.daily_usage_summary import DailyUsageSummary from app.models.llm_usage_log import LLMUsageLog @@ -47,11 +46,15 @@ def test_trigger_creates_summary_on_first_insert(self, db, test_org, test_user, db.commit() # Check that a daily summary was created - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - DailyUsageSummary.project_id == test_project.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + DailyUsageSummary.project_id == test_project.id, + ) + .first() + ) assert summary is not None assert summary.total_tokens == 150 @@ -97,11 +100,15 @@ def test_trigger_updates_summary_on_subsequent_inserts(self, db, test_org, test_ db.commit() # Check that the summary was updated - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - DailyUsageSummary.project_id == test_project.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + DailyUsageSummary.project_id == test_project.id, + ) + .first() + ) assert summary is not None assert summary.total_tokens == 450 # 150 + 300 @@ -131,11 +138,15 @@ def test_trigger_handles_null_user(self, db, test_org, test_project): db.commit() # Check that a daily summary was created with NULL user - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id.is_(None), - DailyUsageSummary.project_id == test_project.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id.is_(None), + DailyUsageSummary.project_id == test_project.id, + ) + .first() + ) assert summary is not None assert summary.user_id is None @@ -163,11 +174,15 @@ def test_trigger_handles_null_project(self, db, test_org, test_user): db.commit() # Check that a daily summary was created with NULL project - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - DailyUsageSummary.project_id.is_(None), - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + DailyUsageSummary.project_id.is_(None), + ) + .first() + ) assert summary is not None assert summary.project_id is None @@ -206,10 +221,14 @@ def test_trigger_separates_by_date(self, db, test_org, test_user, test_project): db.commit() # Check that two separate summaries were created - summaries = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == test_org.id, - DailyUsageSummary.user_id == test_user.id, - ).all() + summaries = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + ) + .all() + ) assert len(summaries) == 2 @@ -237,10 +256,14 @@ def test_trigger_accumulates_cost(self, db, test_org, test_user, test_project): db.commit() # Check accumulated cost - summary = db.query(DailyUsageSummary).filter( - DailyUsageSummary.org_id == 
test_org.id, - DailyUsageSummary.user_id == test_user.id, - ).first() + summary = ( + db.query(DailyUsageSummary) + .filter( + DailyUsageSummary.org_id == test_org.id, + DailyUsageSummary.user_id == test_user.id, + ) + .first() + ) assert summary is not None assert summary.total_cost_usd == Decimal("0.006") diff --git a/backend/tests/test_database.py b/backend/tests/test_database.py index 5db7f1e..c6f47da 100644 --- a/backend/tests/test_database.py +++ b/backend/tests/test_database.py @@ -8,13 +8,14 @@ Note: These tests use an in-memory SQLite database for speed and isolation. """ + import pytest -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker -from app.database import Base, get_db -from app.models.user import User from app.config import settings +from app.database import Base +from app.models.user import User @pytest.fixture diff --git a/backend/tests/test_domain_validation.py b/backend/tests/test_domain_validation.py index c438617..6e61a3c 100644 --- a/backend/tests/test_domain_validation.py +++ b/backend/tests/test_domain_validation.py @@ -1,5 +1,5 @@ """Tests for email domain validation during signup.""" -import pytest + from unittest.mock import patch from app.auth.domain_validation import validate_signup_domain diff --git a/backend/tests/test_draft_endpoints.py b/backend/tests/test_draft_endpoints.py index ec0779f..5993f8c 100644 --- a/backend/tests/test_draft_endpoints.py +++ b/backend/tests/test_draft_endpoints.py @@ -1,9 +1,9 @@ """Tests for draft version and finalization REST API endpoints.""" -import pytest + from uuid import uuid4 +from app.models import OrgMembership, OrgRole, ProjectRole from app.models.brainstorming_phase import BrainstormingPhaseType -from app.models import ProjectRole, OrgMembership, OrgRole from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.draft_version_service import DraftVersionService from app.services.finalization_service import FinalizationService diff --git a/backend/tests/test_draft_version_service.py b/backend/tests/test_draft_version_service.py index ddf5770..836e34c 100644 --- a/backend/tests/test_draft_version_service.py +++ b/backend/tests/test_draft_version_service.py @@ -2,19 +2,21 @@ Tests the service layer for draft version operations. 
""" + +from uuid import uuid4 + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.spec_version import SpecVersion, SpecType -from app.services.user_service import UserService +from app.models.spec_version import SpecType +from app.models.user import User from app.services.draft_version_service import DraftVersionService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_e2e_workflow.py b/backend/tests/test_e2e_workflow.py index 1ddae96..0bcd73f 100644 --- a/backend/tests/test_e2e_workflow.py +++ b/backend/tests/test_e2e_workflow.py @@ -3,21 +3,11 @@ These tests verify complete workflows through the system, ensuring all components work together correctly. """ -import pytest - -from app.models.brainstorming_phase import BrainstormingPhaseType -from app.models.module import ModuleProvenance -from app.models.feature import FeatureProvenance -from app.models import ProjectRole, OrgMembership, OrgRole -from app.services.brainstorming_phase_service import BrainstormingPhaseService -from app.services.draft_version_service import DraftVersionService -from app.services.finalization_service import FinalizationService -from app.services.module_service import ModuleService -from app.services.feature_service import FeatureService -from app.services.thread_service import ThreadService + +from app.models import OrgMembership, OrgRole, ProjectRole +from app.services.activity_log_service import ActivityLogService from app.services.project_service import ProjectService from app.services.user_service import UserService -from app.services.activity_log_service import ActivityLogService class TestCompleteBrainstormingWorkflow: @@ -299,7 +289,11 @@ def test_viewer_can_read_but_not_write( response = client.post( f"/api/v1/projects/{test_project.id}/brainstorming-phases", - json={"phase_type": "initial", "title": "Test Phase", "description": "A test brainstorming phase for permission testing"}, + json={ + "phase_type": "initial", + "title": "Test Phase", + "description": "A test brainstorming phase for permission testing", + }, headers=owner_headers, ) phase_id = response.json()["id"] @@ -375,7 +369,11 @@ def test_member_can_create_but_not_finalize( response = client.post( f"/api/v1/projects/{test_project.id}/brainstorming-phases", - json={"phase_type": "initial", "title": "Test Phase", "description": "A test brainstorming phase for permission testing"}, + json={ + "phase_type": "initial", + "title": "Test Phase", + "description": "A test brainstorming phase for permission testing", + }, headers=owner_headers, ) phase_id = response.json()["id"] @@ -439,7 +437,11 @@ def test_admin_can_finalize_and_archive( response = client.post( f"/api/v1/projects/{test_project.id}/brainstorming-phases", - json={"phase_type": "initial", "title": "Test Phase", "description": "A test brainstorming phase for permission testing"}, + json={ + "phase_type": "initial", + "title": "Test Phase", + "description": "A test brainstorming phase for permission testing", + }, headers=owner_headers, ) phase_id = 
response.json()["id"] diff --git a/backend/tests/test_email_service.py b/backend/tests/test_email_service.py index 1e31e9b..11e5a54 100644 --- a/backend/tests/test_email_service.py +++ b/backend/tests/test_email_service.py @@ -1,9 +1,10 @@ """Tests for email service.""" -import pytest from unittest.mock import AsyncMock, MagicMock, patch -from app.services.email_service import EmailService, EmailSendResult +import pytest + +from app.services.email_service import EmailService class TestEmailService: @@ -16,8 +17,10 @@ async def test_get_email_config_no_connector(self): mock_db = AsyncMock() # Mock PlatformSettingsService and app settings - with patch("app.services.email_service.PlatformSettingsService") as mock_service_class, \ - patch("app.config.settings") as mock_app_settings: + with ( + patch("app.services.email_service.PlatformSettingsService") as mock_service_class, + patch("app.config.settings") as mock_app_settings, + ): mock_service = AsyncMock() mock_service_class.return_value = mock_service @@ -41,8 +44,10 @@ async def test_get_email_config_connector_not_found(self): """Test get_email_config when connector doesn't exist and no env vars.""" mock_db = AsyncMock() - with patch("app.services.email_service.PlatformSettingsService") as mock_service_class, \ - patch("app.config.settings") as mock_app_settings: + with ( + patch("app.services.email_service.PlatformSettingsService") as mock_service_class, + patch("app.config.settings") as mock_app_settings, + ): mock_service = AsyncMock() mock_service_class.return_value = mock_service @@ -67,8 +72,10 @@ async def test_get_email_config_connector_inactive(self): """Test get_email_config when connector is inactive and no env vars.""" mock_db = AsyncMock() - with patch("app.services.email_service.PlatformSettingsService") as mock_service_class, \ - patch("app.config.settings") as mock_app_settings: + with ( + patch("app.services.email_service.PlatformSettingsService") as mock_service_class, + patch("app.config.settings") as mock_app_settings, + ): mock_service = AsyncMock() mock_service_class.return_value = mock_service diff --git a/backend/tests/test_email_verification.py b/backend/tests/test_email_verification.py index f59be7f..a5c453c 100644 --- a/backend/tests/test_email_verification.py +++ b/backend/tests/test_email_verification.py @@ -3,17 +3,19 @@ Tests UserService token methods, auth endpoints, and dependencies. 
""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from datetime import datetime, timezone, timedelta -from sqlalchemy import create_engine, event -from sqlalchemy.orm import sessionmaker, Session from fastapi.testclient import TestClient -from unittest.mock import patch, AsyncMock, MagicMock +from sqlalchemy import create_engine, event +from sqlalchemy.orm import Session, sessionmaker +from app.auth.utils import hash_password from app.database import Base from app.models.user import User from app.services.user_service import UserService -from app.auth.utils import hash_password, verify_password @pytest.fixture @@ -81,9 +83,7 @@ def verified_user(test_db_session): class TestUserServiceVerificationToken: """Tests for UserService email verification token methods.""" - def test_generate_verification_token_creates_token( - self, test_db_session: Session, sample_user - ): + def test_generate_verification_token_creates_token(self, test_db_session: Session, sample_user): """Test that generating a verification token returns a token and stores hash.""" raw_token = UserService.generate_verification_token(test_db_session, sample_user) @@ -108,9 +108,7 @@ def test_generate_verification_token_creates_token( assert expiry > now assert expiry < now + timedelta(hours=25) - def test_verify_email_token_succeeds_with_valid_token( - self, test_db_session: Session, sample_user - ): + def test_verify_email_token_succeeds_with_valid_token(self, test_db_session: Session, sample_user): """Test that verifying a valid token marks the user as verified.""" # Generate token raw_token = UserService.generate_verification_token(test_db_session, sample_user) @@ -124,9 +122,7 @@ def test_verify_email_token_succeeds_with_valid_token( assert verified_user.email_verification_token is None assert verified_user.email_verification_token_expires_at is None - def test_verify_email_token_fails_with_invalid_token( - self, test_db_session: Session, sample_user - ): + def test_verify_email_token_fails_with_invalid_token(self, test_db_session: Session, sample_user): """Test that verifying an invalid token returns None.""" # Generate token UserService.generate_verification_token(test_db_session, sample_user) @@ -140,9 +136,7 @@ def test_verify_email_token_fails_with_invalid_token( test_db_session.refresh(sample_user) assert sample_user.email_verified is False - def test_verify_email_token_fails_with_expired_token( - self, test_db_session: Session, sample_user - ): + def test_verify_email_token_fails_with_expired_token(self, test_db_session: Session, sample_user): """Test that verifying an expired token returns None.""" # Generate token raw_token = UserService.generate_verification_token(test_db_session, sample_user) @@ -156,9 +150,7 @@ def test_verify_email_token_fails_with_expired_token( assert result is None - def test_clear_verification_token( - self, test_db_session: Session, sample_user - ): + def test_clear_verification_token(self, test_db_session: Session, sample_user): """Test that clearing the verification token works.""" # Generate token first UserService.generate_verification_token(test_db_session, sample_user) @@ -172,9 +164,7 @@ def test_clear_verification_token( assert sample_user.email_verification_token is None assert sample_user.email_verification_token_expires_at is None - def test_generate_new_token_overwrites_old( - self, test_db_session: Session, sample_user - ): + def test_generate_new_token_overwrites_old(self, test_db_session: Session, 
sample_user): """Test that generating a new token overwrites the old one.""" # Generate first token token1 = UserService.generate_verification_token(test_db_session, sample_user) @@ -257,13 +247,15 @@ def test_register_skips_verification_when_enabled(self, test_db_session: Session def test_register_endpoint_skips_verification(self): """Integration test: register endpoint returns email_verification_required=False.""" - from app.main import app from app.config import settings + from app.main import app client = TestClient(app) - with patch.object(settings, "dangerously_skip_email_verification", True), \ - patch("app.routers.auth.validate_signup_domain", return_value=(True, None)): + with ( + patch.object(settings, "dangerously_skip_email_verification", True), + patch("app.routers.auth.validate_signup_domain", return_value=(True, None)), + ): response = client.post( "/api/v1/auth/register", json={ @@ -280,18 +272,18 @@ def test_register_endpoint_skips_verification(self): def test_register_endpoint_sends_verification_when_disabled(self): """Integration test: register endpoint sends verification email when skip is disabled.""" - from app.main import app from app.config import settings + from app.main import app client = TestClient(app) - with patch.object(settings, "dangerously_skip_email_verification", False), \ - patch("app.routers.auth.validate_signup_domain", return_value=(True, None)), \ - patch("app.routers.auth.EmailService") as mock_email_cls: + with ( + patch.object(settings, "dangerously_skip_email_verification", False), + patch("app.routers.auth.validate_signup_domain", return_value=(True, None)), + patch("app.routers.auth.EmailService") as mock_email_cls, + ): mock_email_instance = AsyncMock() - mock_email_instance.send_verification_email = AsyncMock( - return_value=MagicMock(success=True) - ) + mock_email_instance.send_verification_email = AsyncMock(return_value=MagicMock(success=True)) mock_email_cls.return_value = mock_email_instance response = client.post( diff --git a/backend/tests/test_feature_endpoints.py b/backend/tests/test_feature_endpoints.py index 3f8232c..e4eb9c7 100644 --- a/backend/tests/test_feature_endpoints.py +++ b/backend/tests/test_feature_endpoints.py @@ -1,13 +1,13 @@ """Tests for feature REST API endpoints.""" -import pytest + from uuid import uuid4 -from app.models.module import ModuleProvenance +from app.models import OrgMembership, OrgRole, ProjectRole from app.models.feature import FeatureProvenance, FeatureType -from app.models import ProjectRole, OrgMembership, OrgRole -from app.services.module_service import ModuleService +from app.models.module import ModuleProvenance from app.services.feature_service import FeatureService from app.services.implementation_service import ImplementationService +from app.services.module_service import ModuleService from app.services.project_service import ProjectService from app.services.user_service import UserService @@ -179,7 +179,7 @@ def test_create_feature_viewer_returns_403( project_id=test_project.id, user_id=viewer.id, role=ProjectRole.VIEWER, - ) + ) from app.auth.utils import create_access_token @@ -573,7 +573,7 @@ def test_update_feature_viewer_returns_403( project_id=test_project.id, user_id=viewer.id, role=ProjectRole.VIEWER, - ) + ) from app.auth.utils import create_access_token @@ -669,7 +669,7 @@ def test_archive_feature_member_returns_403( project_id=test_project.id, user_id=member.id, role=ProjectRole.MEMBER, - ) + ) from app.auth.utils import create_access_token @@ -813,7 +813,7 @@ def 
test_restore_feature_member_returns_403( project_id=test_project.id, user_id=member.id, role=ProjectRole.MEMBER, - ) + ) from app.auth.utils import create_access_token @@ -925,7 +925,7 @@ def test_list_features_includes_unresolved_count( db, ): """Test that feature list endpoint returns unresolved_count from thread.""" - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType # Create a module and feature module = ModuleService.create_module( @@ -1031,7 +1031,7 @@ def test_list_features_zero_unresolved_when_thread_empty( db, ): """Test that feature with thread but empty unresolved_points has count=0.""" - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType # Create a module and feature module = ModuleService.create_module( @@ -1120,7 +1120,8 @@ def test_create_feature_with_thread_existing_module( assert data["thread_id"] is not None # Verify thread was created - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType, Thread + thread = db.query(Thread).filter(Thread.id == data["thread_id"]).first() assert thread is not None assert thread.context_type == ContextType.BRAINSTORM_FEATURE @@ -1155,7 +1156,9 @@ def test_create_feature_with_thread_new_module( # Verify the new module was created from uuid import UUID as PyUUID + from app.models.module import Module + module = db.query(Module).filter(Module.id == PyUUID(data["module_id"])).first() assert module is not None assert module.title == "New Module" @@ -1273,6 +1276,7 @@ def test_create_feature_with_thread_viewer_returns_403( ) from app.auth.utils import create_access_token + token = create_access_token(data={"sub": viewer.email}) headers = {"Authorization": f"Bearer {token}"} @@ -1383,4 +1387,4 @@ def test_clear_project_status_notes( # Verify the implementation is cleared db.refresh(impl) assert impl.is_complete is False - assert impl.implementation_notes is None \ No newline at end of file + assert impl.implementation_notes is None diff --git a/backend/tests/test_feature_import_service.py b/backend/tests/test_feature_import_service.py index c8daa3f..6789a8a 100644 --- a/backend/tests/test_feature_import_service.py +++ b/backend/tests/test_feature_import_service.py @@ -1,14 +1,16 @@ """Tests for feature import service.""" -import pytest + from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 -from app.integrations.base import TicketData, IssueSearchResult -from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType, FeaturePriority +import pytest + +from app.integrations.base import IssueSearchResult, TicketData +from app.models.feature import Feature from app.models.feature_import_comment import FeatureImportComment -from app.models.module import Module from app.models.integration_config import IntegrationConfig +from app.models.module import Module from app.services.feature_import_service import FeatureImportService diff --git a/backend/tests/test_feature_service.py b/backend/tests/test_feature_service.py index f56f9bf..9add631 100644 --- a/backend/tests/test_feature_service.py +++ b/backend/tests/test_feature_service.py @@ -2,21 +2,23 @@ Tests the service layer for feature operations. 
""" -import pytest + from datetime import datetime, timezone -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session from uuid import uuid4 +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + from app.database import Base -from app.models.user import User +from app.models.feature import FeatureProvenance, FeatureStatus, FeatureVisibilityStatus +from app.models.module import Module, ModuleProvenance from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.module import Module, ModuleProvenance -from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureVisibilityStatus -from app.services.user_service import UserService +from app.models.user import User from app.services.feature_service import FeatureService from app.services.module_service import ModuleService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_final_version_models.py b/backend/tests/test_final_version_models.py index 1c00caf..34b0421 100644 --- a/backend/tests/test_final_version_models.py +++ b/backend/tests/test_final_version_models.py @@ -2,19 +2,19 @@ Tests the final version models for the Phase 7 brainstorming workflow. """ + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from datetime import datetime, timezone +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project, ProjectStatus from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.spec_version import SpecVersion, SpecType -from app.models.final_spec import FinalSpec from app.models.final_prompt_plan import FinalPromptPlan +from app.models.final_spec import FinalSpec +from app.models.organization import Organization +from app.models.project import Project, ProjectStatus +from app.models.spec_version import SpecType, SpecVersion +from app.models.user import User from app.services.user_service import UserService diff --git a/backend/tests/test_finalization_service.py b/backend/tests/test_finalization_service.py index f9e2699..5fdad9e 100644 --- a/backend/tests/test_finalization_service.py +++ b/backend/tests/test_finalization_service.py @@ -2,22 +2,23 @@ Tests the service layer for finalizing specs and prompt plans. 
""" + +from uuid import uuid4 + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project, ProjectStatus from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.spec_version import SpecVersion, SpecType from app.models.final_spec import FinalSpec -from app.models.final_prompt_plan import FinalPromptPlan -from app.services.user_service import UserService +from app.models.organization import Organization +from app.models.project import Project, ProjectStatus +from app.models.spec_version import SpecVersion +from app.models.user import User from app.services.draft_version_service import DraftVersionService from app.services.finalization_service import FinalizationService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_form_draft_router.py b/backend/tests/test_form_draft_router.py index baf9cd6..f23bdc5 100644 --- a/backend/tests/test_form_draft_router.py +++ b/backend/tests/test_form_draft_router.py @@ -2,7 +2,6 @@ import os import tempfile -from unittest.mock import patch, MagicMock import pytest from fastapi.testclient import TestClient @@ -13,13 +12,13 @@ from app.database import Base, get_db from app.main import app from app.models import ( - User, Organization, - Project, - ProjectType, - ProjectStatus, OrgMembership, OrgRole, + Project, + ProjectStatus, + ProjectType, + User, ) from app.models.form_draft import FormDraft, FormDraftType from app.services.user_service import UserService @@ -168,9 +167,7 @@ def test_list_drafts_empty(self, client, sample_project, user_headers): assert response.status_code == 200 assert response.json() == [] - def test_list_drafts_returns_user_drafts( - self, client, db, sample_project, sample_user, user_headers - ): + def test_list_drafts_returns_user_drafts(self, client, db, sample_project, sample_user, user_headers): """Test that list returns the user's drafts.""" # Create a draft draft = FormDraft( @@ -193,9 +190,7 @@ def test_list_drafts_returns_user_drafts( assert len(data) == 1 assert data[0]["name"] == "Test Draft" - def test_list_drafts_filter_by_type( - self, client, db, sample_project, sample_user, user_headers - ): + def test_list_drafts_filter_by_type(self, client, db, sample_project, sample_user, user_headers): """Test filtering drafts by type.""" # Create drafts of different types phase_draft = FormDraft( @@ -272,9 +267,7 @@ def test_get_draft_not_found(self, client, user_headers): assert response.status_code == 404 - def test_get_other_users_draft( - self, client, db, sample_project, sample_user, other_user, other_user_headers - ): + def test_get_other_users_draft(self, client, db, sample_project, sample_user, other_user, other_user_headers): """Test that users cannot get other users' drafts.""" # Create draft for sample_user draft = FormDraft( @@ -318,9 +311,7 @@ def test_create_draft(self, client, sample_project, user_headers): assert data["draft_type"] == "brainstorming_phase" assert data["content"]["title"] == "New Phase" - def test_update_draft( - self, client, db, sample_project, sample_user, user_headers - ): + def test_update_draft(self, client, db, sample_project, sample_user, user_headers): """Test updating an existing draft.""" # Create initial draft draft = FormDraft( 
@@ -378,9 +369,7 @@ def test_create_draft_with_context(self, client, sample_project, user_headers):
 
 
 class TestDeleteFormDraft:
     """Tests for DELETE /api/v1/form-drafts/{draft_id}."""
 
-    def test_delete_draft(
-        self, client, db, sample_project, sample_user, user_headers
-    ):
+    def test_delete_draft(self, client, db, sample_project, sample_user, user_headers):
         """Test deleting a draft."""
         draft = FormDraft(
             user_id=sample_user.id,
@@ -418,9 +407,7 @@ def test_delete_draft_not_found(self, client, user_headers):
 
         assert response.status_code == 404
 
-    def test_delete_other_users_draft(
-        self, client, db, sample_project, sample_user, other_user, other_user_headers
-    ):
+    def test_delete_other_users_draft(self, client, db, sample_project, sample_user, other_user, other_user_headers):
         """Test that users cannot delete other users' drafts."""
         draft = FormDraft(
             user_id=sample_user.id,
diff --git a/backend/tests/test_form_draft_service.py b/backend/tests/test_form_draft_service.py
index c0d6505..16c734a 100644
--- a/backend/tests/test_form_draft_service.py
+++ b/backend/tests/test_form_draft_service.py
@@ -4,20 +4,21 @@
 Verifies form draft CRUD operations including list, get, upsert, and delete.
 """
 
+from uuid import uuid4
+
 import pytest
 from sqlalchemy import create_engine, event
 from sqlalchemy.orm import sessionmaker
-from uuid import uuid4
 
 from app.database import Base
 from app.models import (
-    User,
     Organization,
     Project,
-    ProjectType,
     ProjectStatus,
+    ProjectType,
+    User,
 )
-from app.models.form_draft import FormDraft, FormDraftType
+from app.models.form_draft import FormDraftType
 from app.services.form_draft_service import FormDraftService
@@ -301,9 +302,7 @@ def test_delete_draft_not_found(self, test_db, sample_user):
 
 
 class TestFormDraftUserIsolation:
     """Tests for user isolation - ensuring users can't access other users' drafts."""
 
-    def test_cannot_get_other_users_draft(
-        self, test_db, sample_user, other_user, sample_project
-    ):
+    def test_cannot_get_other_users_draft(self, test_db, sample_user, other_user, sample_project):
         """Test that users cannot get other users' drafts."""
         # Create draft as sample_user
         draft = FormDraftService.upsert_draft(
@@ -324,9 +323,7 @@ def test_cannot_get_other_users_draft(
 
         assert result is None
 
-    def test_cannot_update_other_users_draft(
-        self, test_db, sample_user, other_user, sample_project
-    ):
+    def test_cannot_update_other_users_draft(self, test_db, sample_user, other_user, sample_project):
         """Test that users cannot update other users' drafts."""
         # Create draft as sample_user
         draft = FormDraftService.upsert_draft(
@@ -361,9 +358,7 @@ def test_cannot_update_other_users_draft(
         )
         assert original.name == "Original Name"
 
-    def test_cannot_delete_other_users_draft(
-        self, test_db, sample_user, other_user, sample_project
-    ):
+    def test_cannot_delete_other_users_draft(self, test_db, sample_user, other_user, sample_project):
         """Test that users cannot delete other users' drafts."""
         # Create draft as sample_user
         draft = FormDraftService.upsert_draft(
@@ -392,9 +387,7 @@ def test_cannot_delete_other_users_draft(
         )
         assert original is not None
 
-    def test_list_only_shows_own_drafts(
-        self, test_db, sample_user, other_user, sample_project
-    ):
+    def test_list_only_shows_own_drafts(self, test_db, sample_user, other_user, sample_project):
         """Test that list only returns the user's own drafts."""
         # Create draft for each user
         FormDraftService.upsert_draft(
diff --git a/backend/tests/test_github_adapter.py b/backend/tests/test_github_adapter.py
index edb2044..9264e93 100644
--- a/backend/tests/test_github_adapter.py
+++ b/backend/tests/test_github_adapter.py
@@ -1,4 +1,5 @@
 """Tests for GitHub adapter with PAT and GitHub App authentication."""
+
 import time
 from unittest.mock import AsyncMock, MagicMock, patch
 
@@ -183,9 +184,7 @@ async def test_connection_pat_auth_failure(self):
         mock_response = MagicMock()
         mock_response.status_code = 401
 
-        error = httpx.HTTPStatusError(
-            "401 Unauthorized", request=MagicMock(), response=mock_response
-        )
+        error = httpx.HTTPStatusError("401 Unauthorized", request=MagicMock(), response=mock_response)
 
         # Create proper async context manager mock
         mock_client = AsyncMock()
@@ -287,9 +286,7 @@ async def test_list_repos_pat_user(self):
     @pytest.mark.asyncio
     async def test_list_repos_pat_with_org(self):
         """Test listing repos for PAT with org filter."""
-        adapter = GitHubAdapter(
-            token="ghp_test", config={"auth_type": "pat", "org_name": "my-org"}
-        )
+        adapter = GitHubAdapter(token="ghp_test", config={"auth_type": "pat", "org_name": "my-org"})
 
         mock_response = MagicMock()
         mock_response.json.return_value = [
diff --git a/backend/tests/test_grounding_note_endpoints.py b/backend/tests/test_grounding_note_endpoints.py
index 9fe6c32..b361d56 100644
--- a/backend/tests/test_grounding_note_endpoints.py
+++ b/backend/tests/test_grounding_note_endpoints.py
@@ -1,5 +1,5 @@
 """Tests for grounding note REST API endpoints."""
-import pytest
+
 from uuid import uuid4
 
 from app.services.grounding_note_service import GroundingNoteService
@@ -225,9 +225,9 @@ def test_create_grounding_note_requires_member_role(
         db,
     ):
         """Test that creating a note requires MEMBER role on the project."""
-        from app.services.user_service import UserService
-        from app.services.project_share_service import ProjectShareService
         from app.models.project_membership import ProjectRole
+        from app.services.project_share_service import ProjectShareService
+        from app.services.user_service import UserService
 
         # Create another user and give them VIEWER access (not MEMBER)
         other_user = UserService.create_user(
diff --git a/backend/tests/test_grounding_note_service.py b/backend/tests/test_grounding_note_service.py
index e44ea21..7a5eca9 100644
--- a/backend/tests/test_grounding_note_service.py
+++ b/backend/tests/test_grounding_note_service.py
@@ -2,17 +2,17 @@
 
 Tests the service layer for grounding note version operations.
""" + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.grounding_note_version import GroundingNoteVersion -from app.services.user_service import UserService +from app.models.user import User from app.services.grounding_note_service import GroundingNoteService +from app.services.user_service import UserService @pytest.fixture @@ -157,9 +157,7 @@ def test_get_active_version( created_by=sample_user.id, ) - active = GroundingNoteService.get_active_version( - test_db_session, sample_project.id - ) + active = GroundingNoteService.get_active_version(test_db_session, sample_project.id) assert active is not None assert active.version == 2 @@ -172,9 +170,7 @@ def test_get_active_version_returns_none_when_no_notes( sample_project: Project, ): """Test that get_active_version returns None when no notes exist.""" - active = GroundingNoteService.get_active_version( - test_db_session, sample_project.id - ) + active = GroundingNoteService.get_active_version(test_db_session, sample_project.id) assert active is None @@ -205,9 +201,7 @@ def test_list_versions( created_by=sample_user.id, ) - versions = GroundingNoteService.list_versions( - test_db_session, sample_project.id - ) + versions = GroundingNoteService.list_versions(test_db_session, sample_project.id) # Should be in descending order (newest first) assert len(versions) == 3 @@ -221,9 +215,7 @@ def test_list_versions_empty_when_no_notes( sample_project: Project, ): """Test that list_versions returns empty list when no notes exist.""" - versions = GroundingNoteService.list_versions( - test_db_session, sample_project.id - ) + versions = GroundingNoteService.list_versions(test_db_session, sample_project.id) assert versions == [] @@ -253,6 +245,7 @@ def test_get_version_not_found( ): """Test that get_version returns None for non-existent ID.""" import uuid + random_id = uuid.uuid4() retrieved = GroundingNoteService.get_version(test_db_session, random_id) diff --git a/backend/tests/test_grounding_service.py b/backend/tests/test_grounding_service.py index 90dd13c..68e2f23 100644 --- a/backend/tests/test_grounding_service.py +++ b/backend/tests/test_grounding_service.py @@ -3,17 +3,17 @@ Tests the service layer for grounding file operations, particularly the summary field functionality. """ + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.grounding_file import GroundingFile -from app.services.user_service import UserService +from app.models.user import User from app.services.grounding_service import GroundingService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_health.py b/backend/tests/test_health.py index 3b043cd..e60cc6c 100644 --- a/backend/tests/test_health.py +++ b/backend/tests/test_health.py @@ -3,9 +3,10 @@ Following TDD: this test is written first and will fail until the endpoint is implemented. 
""" + from fastapi.testclient import TestClient -from app.main import app +from app.main import app client = TestClient(app) diff --git a/backend/tests/test_identity_models.py b/backend/tests/test_identity_models.py index ccce8d6..cb5c2ff 100644 --- a/backend/tests/test_identity_models.py +++ b/backend/tests/test_identity_models.py @@ -4,16 +4,16 @@ Verifies that OAuth-related models work correctly with proper constraints and relationships. """ + import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from uuid import uuid4 from app.database import Base from app.models import ( - User, IdentityProvider, IdentityProviderType, + User, UserIdentity, ) @@ -55,11 +55,7 @@ def test_db_session(test_engine): @pytest.fixture def sample_user(test_db_session): """Create a sample user for testing.""" - user = User( - email="testuser@example.com", - password_hash="hashed_password", - display_name="Test User" - ) + user = User(email="testuser@example.com", password_hash="hashed_password", display_name="Test User") test_db_session.add(user) test_db_session.commit() test_db_session.refresh(user) @@ -69,11 +65,7 @@ def sample_user(test_db_session): @pytest.fixture def sample_user_oauth_only(test_db_session): """Create a sample user without password (OAuth-only).""" - user = User( - email="oauth@example.com", - password_hash=None, - display_name="OAuth User" - ) + user = User(email="oauth@example.com", password_hash=None, display_name="OAuth User") test_db_session.add(user) test_db_session.commit() test_db_session.refresh(user) @@ -300,7 +292,9 @@ def test_identity_unique_constraint(self, test_db_session, sample_user, google_p with pytest.raises(Exception): test_db_session.commit() - def test_identity_same_subject_different_providers(self, test_db_session, sample_user, google_provider, github_provider): + def test_identity_same_subject_different_providers( + self, test_db_session, sample_user, google_provider, github_provider + ): """Test that same subject can exist across different providers.""" google_identity = UserIdentity( user_id=sample_user.id, @@ -374,7 +368,7 @@ def test_identity_cascade_delete_provider(self, test_db_session, sample_user): def test_identity_with_tokens(self, test_db_session, sample_user, google_provider): """Test storing OAuth tokens in identity.""" - from datetime import datetime, timezone, timedelta + from datetime import datetime, timedelta, timezone expires_at = datetime.now(timezone.utc) + timedelta(hours=1) identity = UserIdentity( diff --git a/backend/tests/test_image_service.py b/backend/tests/test_image_service.py index da609e5..47de060 100644 --- a/backend/tests/test_image_service.py +++ b/backend/tests/test_image_service.py @@ -1,25 +1,18 @@ """Tests for ImageService.""" + import base64 -import json import os -import tempfile from io import BytesIO -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock from uuid import uuid4 import pytest -from PIL import Image -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from PIL import Image -from app.database import Base -from app.models.platform_connector import PlatformConnector -from app.models.platform_settings import PlatformSettings -from app.services.image_service import ImageService, ImageMetadata 
+from app.services.image_service import ImageMetadata, ImageService def get_test_fernet(): @@ -192,9 +185,7 @@ def test_generate_s3_keys_structure(self): org_id = str(uuid4()) project_id = str(uuid4()) - image_key, thumb_key = service._generate_s3_keys( - org_id, project_id, "image/jpeg" - ) + image_key, thumb_key = service._generate_s3_keys(org_id, project_id, "image/jpeg") # Verify structure assert image_key.startswith(f"orgs/{org_id}/projects/{project_id}/images/") @@ -333,9 +324,7 @@ async def test_get_signed_url(self): service = ImageService(mock_db) mock_s3 = MagicMock() - mock_s3.generate_presigned_url = MagicMock( - return_value="https://bucket.s3.amazonaws.com/key?signature=xxx" - ) + mock_s3.generate_presigned_url = MagicMock(return_value="https://bucket.s3.amazonaws.com/key?signature=xxx") async def mock_get_s3(): return mock_s3, "test-bucket" diff --git a/backend/tests/test_image_signing.py b/backend/tests/test_image_signing.py index ed65171..1f9b565 100644 --- a/backend/tests/test_image_signing.py +++ b/backend/tests/test_image_signing.py @@ -163,9 +163,7 @@ def test_roundtrip_generate_and_verify(self, mock_signing_key): class TestImageResizing: """Tests for image resizing functionality.""" - def _create_test_image( - self, width: int, height: int, color: str = "red", format: str = "JPEG" - ) -> bytes: + def _create_test_image(self, width: int, height: int, color: str = "red", format: str = "JPEG") -> bytes: """Create a test image with given dimensions.""" mode = "RGB" if format != "PNG" else "RGBA" img = Image.new(mode, (width, height), color=color) diff --git a/backend/tests/test_images_router.py b/backend/tests/test_images_router.py index fb38043..3c3d226 100644 --- a/backend/tests/test_images_router.py +++ b/backend/tests/test_images_router.py @@ -3,7 +3,6 @@ from io import BytesIO from unittest.mock import MagicMock, patch -import pytest from PIL import Image from app.services.image_service import ImageService @@ -23,15 +22,9 @@ def test_serve_image_with_valid_signature(self, client): """Test serving image with valid signature.""" test_image_bytes = self._create_test_image() - with patch.object( - ImageService, "verify_image_signature", return_value=(True, None) - ): - with patch.object( - ImageService, "get_image_bytes", return_value=test_image_bytes - ): - with patch( - "app.routers.images.SessionLocal" - ) as mock_session_class: + with patch.object(ImageService, "verify_image_signature", return_value=(True, None)): + with patch.object(ImageService, "get_image_bytes", return_value=test_image_bytes): + with patch("app.routers.images.SessionLocal") as mock_session_class: # Set up mock DB session mock_db = MagicMock() mock_session_class.return_value = mock_db @@ -47,13 +40,10 @@ def test_serve_image_with_valid_signature(self, client): } ] } - mock_db.query.return_value.filter.return_value.all.return_value = [ - mock_item - ] + mock_db.query.return_value.filter.return_value.all.return_value = [mock_item] response = client.get( - "/api/v1/images/test-image-id" - "?max_width=800&max_height=600&exp=9999999999&sig=valid" + "/api/v1/images/test-image-id?max_width=800&max_height=600&exp=9999999999&sig=valid" ) assert response.status_code == 200 @@ -67,10 +57,7 @@ def test_serve_image_with_invalid_signature(self, client): "verify_image_signature", return_value=(False, "Invalid signature"), ): - response = client.get( - "/api/v1/images/test-id" - "?max_width=800&max_height=600&exp=1&sig=invalid" - ) + response = 
client.get("/api/v1/images/test-id?max_width=800&max_height=600&exp=1&sig=invalid") assert response.status_code == 403 assert "invalid" in response.json()["detail"].lower() @@ -82,30 +69,22 @@ def test_serve_image_with_expired_url(self, client): "verify_image_signature", return_value=(False, "URL has expired"), ): - response = client.get( - "/api/v1/images/test-id" - "?max_width=800&max_height=600&exp=1&sig=expired" - ) + response = client.get("/api/v1/images/test-id?max_width=800&max_height=600&exp=1&sig=expired") assert response.status_code == 403 assert "expired" in response.json()["detail"].lower() def test_serve_image_not_found(self, client): """Test that missing images return 404.""" - with patch.object( - ImageService, "verify_image_signature", return_value=(True, None) - ): - with patch( - "app.routers.images.SessionLocal" - ) as mock_session_class: + with patch.object(ImageService, "verify_image_signature", return_value=(True, None)): + with patch("app.routers.images.SessionLocal") as mock_session_class: mock_db = MagicMock() mock_session_class.return_value = mock_db # Return empty list - no images found mock_db.query.return_value.filter.return_value.all.return_value = [] response = client.get( - "/api/v1/images/nonexistent" - "?max_width=800&max_height=600&exp=9999999999&sig=valid" + "/api/v1/images/nonexistent?max_width=800&max_height=600&exp=9999999999&sig=valid" ) assert response.status_code == 404 @@ -113,17 +92,13 @@ def test_serve_image_not_found(self, client): def test_serve_image_s3_fetch_fails(self, client): """Test that S3 fetch failures return 500.""" - with patch.object( - ImageService, "verify_image_signature", return_value=(True, None) - ): + with patch.object(ImageService, "verify_image_signature", return_value=(True, None)): with patch.object( ImageService, "get_image_bytes", side_effect=ValueError("S3 not configured"), ): - with patch( - "app.routers.images.SessionLocal" - ) as mock_session_class: + with patch("app.routers.images.SessionLocal") as mock_session_class: mock_db = MagicMock() mock_session_class.return_value = mock_db mock_item = MagicMock() @@ -136,13 +111,10 @@ def test_serve_image_s3_fetch_fails(self, client): } ] } - mock_db.query.return_value.filter.return_value.all.return_value = [ - mock_item - ] + mock_db.query.return_value.filter.return_value.all.return_value = [mock_item] response = client.get( - "/api/v1/images/test-id" - "?max_width=800&max_height=600&exp=9999999999&sig=valid" + "/api/v1/images/test-id?max_width=800&max_height=600&exp=9999999999&sig=valid" ) assert response.status_code == 500 @@ -150,35 +122,25 @@ def test_serve_image_s3_fetch_fails(self, client): def test_serve_image_missing_exp_param(self, client): """Test that missing exp parameter returns 422.""" - response = client.get( - "/api/v1/images/test-id" "?max_width=800&max_height=600&sig=valid" - ) + response = client.get("/api/v1/images/test-id?max_width=800&max_height=600&sig=valid") assert response.status_code == 422 # Validation error def test_serve_image_missing_sig_param(self, client): """Test that missing sig parameter returns 422.""" - response = client.get( - "/api/v1/images/test-id" "?max_width=800&max_height=600&exp=9999999999" - ) + response = client.get("/api/v1/images/test-id?max_width=800&max_height=600&exp=9999999999") assert response.status_code == 422 # Validation error def test_serve_image_max_width_exceeds_limit(self, client): """Test that max_width > 4096 returns 422.""" - response = client.get( - "/api/v1/images/test-id" - 
"?max_width=5000&max_height=600&exp=9999999999&sig=valid" - ) + response = client.get("/api/v1/images/test-id?max_width=5000&max_height=600&exp=9999999999&sig=valid") assert response.status_code == 422 # Validation error def test_serve_image_max_width_zero_or_negative(self, client): """Test that max_width <= 0 returns 422.""" - response = client.get( - "/api/v1/images/test-id" - "?max_width=0&max_height=600&exp=9999999999&sig=valid" - ) + response = client.get("/api/v1/images/test-id?max_width=0&max_height=600&exp=9999999999&sig=valid") assert response.status_code == 422 # Validation error @@ -186,15 +148,9 @@ def test_serve_image_uses_default_dimensions(self, client): """Test that default dimensions (800x600) are used when not specified.""" test_image_bytes = self._create_test_image() - with patch.object( - ImageService, "verify_image_signature", return_value=(True, None) - ) as mock_verify: - with patch.object( - ImageService, "get_image_bytes", return_value=test_image_bytes - ): - with patch( - "app.routers.images.SessionLocal" - ) as mock_session_class: + with patch.object(ImageService, "verify_image_signature", return_value=(True, None)) as mock_verify: + with patch.object(ImageService, "get_image_bytes", return_value=test_image_bytes): + with patch("app.routers.images.SessionLocal") as mock_session_class: mock_db = MagicMock() mock_session_class.return_value = mock_db mock_item = MagicMock() @@ -207,14 +163,10 @@ def test_serve_image_uses_default_dimensions(self, client): } ] } - mock_db.query.return_value.filter.return_value.all.return_value = [ - mock_item - ] + mock_db.query.return_value.filter.return_value.all.return_value = [mock_item] # Don't specify max_width and max_height - response = client.get( - "/api/v1/images/test-id" "?exp=9999999999&sig=valid" - ) + response = client.get("/api/v1/images/test-id?exp=9999999999&sig=valid") # Verify that verify_image_signature was called with defaults mock_verify.assert_called_once() diff --git a/backend/tests/test_inbox_badge_service.py b/backend/tests/test_inbox_badge_service.py index c64dca5..c5c7084 100644 --- a/backend/tests/test_inbox_badge_service.py +++ b/backend/tests/test_inbox_badge_service.py @@ -1,17 +1,18 @@ """Tests for inbox badge service.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_mention import InboxConversationType from app.models.organization import Organization from app.models.project import Project -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType -from app.models.inbox_mention import InboxConversationType +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType +from app.models.user import User from app.services.inbox_badge_service import InboxBadgeService from app.services.inbox_mention_service import InboxMentionService from app.services.inbox_status_service import InboxStatusService diff --git a/backend/tests/test_inbox_broadcast_service.py b/backend/tests/test_inbox_broadcast_service.py index 1c78078..a20eece 100644 --- a/backend/tests/test_inbox_broadcast_service.py +++ b/backend/tests/test_inbox_broadcast_service.py @@ -1,16 +1,17 @@ """Tests for inbox broadcast service.""" -import pytest + from datetime import datetime, timezone from unittest.mock import 
MagicMock, patch from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_follow import InboxThreadType from app.models.organization import Organization from app.models.project import Project -from app.models.inbox_follow import InboxFollow, InboxFollowType, InboxThreadType -from app.services.inbox_broadcast_service import InboxBroadcastService, MAX_PREVIEW_LENGTH +from app.models.user import User +from app.services.inbox_broadcast_service import MAX_PREVIEW_LENGTH, InboxBroadcastService from app.services.inbox_follow_service import InboxFollowService @@ -109,15 +110,11 @@ def test_get_recipients_thread_and_project_combined(self, db: Session, setup_dat InboxFollowService.follow_project(db, user1.id, project.id) # User2 follows the specific thread - InboxFollowService.follow_thread( - db, user2.id, project.id, thread_id, InboxThreadType.FEATURE - ) + InboxFollowService.follow_thread(db, user2.id, project.id, thread_id, InboxThreadType.FEATURE) # User3 follows both (should only appear once) InboxFollowService.follow_project(db, user3.id, project.id) - InboxFollowService.follow_thread( - db, user3.id, project.id, thread_id, InboxThreadType.FEATURE - ) + InboxFollowService.follow_thread(db, user3.id, project.id, thread_id, InboxThreadType.FEATURE) recipients = InboxBroadcastService.get_recipients_for_thread( db=db, @@ -189,9 +186,7 @@ def test_truncate_preview(self): assert truncated.endswith("...") @patch("app.services.kafka_producer.get_sync_kafka_producer") - def test_broadcast_new_message_publishes( - self, mock_kafka, db: Session, setup_data - ): + def test_broadcast_new_message_publishes(self, mock_kafka, db: Session, setup_data): """Test that broadcast_new_message publishes to Kafka.""" data = setup_data user1 = data["user1"] @@ -236,9 +231,7 @@ def test_broadcast_new_message_publishes( assert message["payload"]["author_name"] == "User Two" @patch("app.services.kafka_producer.get_sync_kafka_producer") - def test_broadcast_new_message_no_followers( - self, mock_kafka, db: Session, setup_data - ): + def test_broadcast_new_message_no_followers(self, mock_kafka, db: Session, setup_data): """Test that broadcast returns True but doesn't publish when no followers.""" data = setup_data user2 = data["user2"] @@ -266,9 +259,7 @@ def test_broadcast_new_message_no_followers( mock_producer.publish.assert_not_called() @patch("app.services.kafka_producer.get_sync_kafka_producer") - def test_broadcast_mention_targets_only_mentioned_user( - self, mock_kafka, db: Session, setup_data - ): + def test_broadcast_mention_targets_only_mentioned_user(self, mock_kafka, db: Session, setup_data): """Test that mention broadcast only targets the mentioned user.""" data = setup_data user1 = data["user1"] # mentioned @@ -303,9 +294,7 @@ def test_broadcast_mention_targets_only_mentioned_user( assert str(user1.id) in message["target_user_ids"] @patch("app.services.kafka_producer.get_sync_kafka_producer") - def test_broadcast_read_status_changed_targets_user( - self, mock_kafka, db: Session, setup_data - ): + def test_broadcast_read_status_changed_targets_user(self, mock_kafka, db: Session, setup_data): """Test that read status broadcast targets the specific user.""" data = setup_data user1 = data["user1"] @@ -374,9 +363,7 @@ def test_get_recipients_for_project_chat(self, db: Session, setup_data): InboxFollowService.follow_project(db, user1.id, project.id) # User2 follows the specific chat thread - InboxFollowService.follow_thread( - db, user2.id, 
project.id, str(chat_id), InboxThreadType.PROJECT_CHAT - ) + InboxFollowService.follow_thread(db, user2.id, project.id, str(chat_id), InboxThreadType.PROJECT_CHAT) recipients = InboxBroadcastService.get_recipients_for_project_chat( db=db, @@ -389,9 +376,7 @@ def test_get_recipients_for_project_chat(self, db: Session, setup_data): assert user1.id in recipients assert user2.id in recipients - def test_get_recipients_for_project_chat_excludes_author( - self, db: Session, setup_data - ): + def test_get_recipients_for_project_chat_excludes_author(self, db: Session, setup_data): """Test that author is excluded from project chat recipients.""" data = setup_data user1 = data["user1"] @@ -426,9 +411,7 @@ def test_get_recipients_for_org_scoped_chat(self, db: Session, setup_data): InboxFollowService.follow_project(db, user1.id, project.id) # User2 follows the specific chat thread - InboxFollowService.follow_thread( - db, user2.id, project.id, str(chat_id), InboxThreadType.PROJECT_CHAT - ) + InboxFollowService.follow_thread(db, user2.id, project.id, str(chat_id), InboxThreadType.PROJECT_CHAT) # For org-scoped chat (project_id=None), only thread followers receive recipients = InboxBroadcastService.get_recipients_for_project_chat( diff --git a/backend/tests/test_inbox_conversation_service.py b/backend/tests/test_inbox_conversation_service.py index 82d16f8..785f63c 100644 --- a/backend/tests/test_inbox_conversation_service.py +++ b/backend/tests/test_inbox_conversation_service.py @@ -1,27 +1,28 @@ """Tests for inbox conversation aggregation service.""" -import pytest -from datetime import datetime, timezone, timedelta + +from datetime import datetime, timedelta, timezone from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance, ModuleType from app.models.feature import Feature, FeatureProvenance, FeatureType from app.models.inbox_mention import InboxConversationType -from app.services.inbox_conversation_service import InboxConversationService -from app.services.inbox_status_service import InboxStatusService +from app.models.module import Module, ModuleProvenance, ModuleType +from app.models.organization import Organization +from app.models.project import Project +from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType +from app.models.user import User from app.schemas.inbox_conversation import ( - InboxConversationsRequest, ConversationSortField, + InboxConversationsRequest, SortOrder, ) +from app.services.inbox_conversation_service import InboxConversationService +from app.services.inbox_status_service import InboxStatusService class TestInboxConversationService: @@ -107,9 +108,7 @@ def test_empty_inbox_nonexistent_project(self, db: Session, setup_data): data = setup_data request = InboxConversationsRequest() - response = InboxConversationService.get_user_inbox_conversations( - db, data["user"].id, uuid4(), request - ) + response = 
InboxConversationService.get_user_inbox_conversations(db, data["user"].id, uuid4(), request) assert response.total == 0 assert response.conversations == [] @@ -513,9 +512,7 @@ def test_filter_by_conversation_type(self, db: Session, setup_data): ) # Filter by project chat only - request = InboxConversationsRequest( - conversation_type=InboxConversationType.PROJECT_CHAT - ) + request = InboxConversationsRequest(conversation_type=InboxConversationType.PROJECT_CHAT) response = InboxConversationService.get_user_inbox_conversations( db, data["user"].id, data["project"].id, request ) diff --git a/backend/tests/test_inbox_deep_link.py b/backend/tests/test_inbox_deep_link.py index 5cb6ac9..6416eb1 100644 --- a/backend/tests/test_inbox_deep_link.py +++ b/backend/tests/test_inbox_deep_link.py @@ -1,21 +1,21 @@ """Tests for inbox deep link resolution.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 -from fastapi.testclient import TestClient +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType +from app.models.feature import Feature, FeatureProvenance, FeatureType +from app.models.inbox_mention import InboxConversationType +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.organization import Organization from app.models.project import Project -from app.models.thread import Thread, ContextType -from app.models.feature import Feature, FeatureProvenance, FeatureType -from app.models.module import Module, ModuleType, ModuleProvenance -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType from app.models.project_chat import ProjectChat -from app.models.project_share import ProjectShare, ProjectRole -from app.models.inbox_mention import InboxConversationType +from app.models.project_share import ProjectRole, ProjectShare +from app.models.thread import ContextType, Thread +from app.models.user import User from app.services.inbox_conversation_service import InboxConversationService from app.utils.deep_link import build_inbox_deep_link_url, build_redirect_url, slugify diff --git a/backend/tests/test_inbox_follow_router.py b/backend/tests/test_inbox_follow_router.py index 6ee5c29..51b808e 100644 --- a/backend/tests/test_inbox_follow_router.py +++ b/backend/tests/test_inbox_follow_router.py @@ -1,17 +1,18 @@ """Tests for inbox follow router endpoints.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 +import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_follow import InboxThreadType from app.models.organization import Organization from app.models.project import Project -from app.models.project_share import ProjectShare, ShareSubjectType from app.models.project_membership import ProjectRole -from app.models.inbox_follow import InboxFollowType, InboxThreadType +from app.models.project_share import ProjectShare, ShareSubjectType +from app.models.user import User from app.services.inbox_follow_service import InboxFollowService diff --git a/backend/tests/test_inbox_follow_service.py b/backend/tests/test_inbox_follow_service.py index 2bf2fd0..2eebaa6 100644 --- a/backend/tests/test_inbox_follow_service.py +++ b/backend/tests/test_inbox_follow_service.py @@ -1,14 +1,15 @@ """Tests for inbox follow service.""" -import pytest + from datetime import datetime, timezone from uuid 
import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_follow import InboxFollowType, InboxThreadType from app.models.organization import Organization from app.models.project import Project -from app.models.inbox_follow import InboxFollow, InboxFollowType, InboxThreadType +from app.models.user import User from app.services.inbox_follow_service import InboxFollowService @@ -201,9 +202,7 @@ def test_unfollow_thread(self, db: Session, setup_data): ) assert result is True - assert not InboxFollowService.is_following_thread( - db, user.id, thread_id, InboxThreadType.FEATURE - ) + assert not InboxFollowService.is_following_thread(db, user.id, thread_id, InboxThreadType.FEATURE) def test_is_following_project(self, db: Session, setup_data): """Test checking if following a project.""" @@ -228,9 +227,7 @@ def test_is_following_thread(self, db: Session, setup_data): thread_id = str(uuid4()) # Not following initially - assert not InboxFollowService.is_following_thread( - db, user.id, thread_id, InboxThreadType.FEATURE - ) + assert not InboxFollowService.is_following_thread(db, user.id, thread_id, InboxThreadType.FEATURE) # Follow InboxFollowService.follow_thread( @@ -242,9 +239,7 @@ def test_is_following_thread(self, db: Session, setup_data): ) # Now following - assert InboxFollowService.is_following_thread( - db, user.id, thread_id, InboxThreadType.FEATURE - ) + assert InboxFollowService.is_following_thread(db, user.id, thread_id, InboxThreadType.FEATURE) def test_get_project_follows(self, db: Session, setup_data): """Test getting all project follows for a user.""" @@ -298,9 +293,7 @@ def test_get_thread_follows_in_project(self, db: Session, setup_data): thread_type=InboxThreadType.PHASE, ) - follows = InboxFollowService.get_thread_follows_in_project( - db, user.id, project.id - ) + follows = InboxFollowService.get_thread_follows_in_project(db, user.id, project.id) thread_ids = [f.thread_id for f in follows] assert len(follows) == 2 @@ -348,9 +341,7 @@ def test_get_users_following_thread(self, db: Session, setup_data): thread_type=InboxThreadType.FEATURE, ) - user_ids = InboxFollowService.get_users_following_thread( - db, thread_id, InboxThreadType.FEATURE - ) + user_ids = InboxFollowService.get_users_following_thread(db, thread_id, InboxThreadType.FEATURE) assert len(user_ids) == 2 assert user.id in user_ids @@ -385,12 +376,8 @@ def test_different_thread_types_are_separate(self, db: Session, setup_data): assert follow1.id != follow2.id # Check both exist - assert InboxFollowService.is_following_thread( - db, user.id, thread_id, InboxThreadType.FEATURE - ) - assert InboxFollowService.is_following_thread( - db, user.id, thread_id, InboxThreadType.PHASE - ) + assert InboxFollowService.is_following_thread(db, user.id, thread_id, InboxThreadType.FEATURE) + assert InboxFollowService.is_following_thread(db, user.id, thread_id, InboxThreadType.PHASE) def test_project_chat_thread_type(self, db: Session, setup_data): """Test following a project chat thread.""" @@ -408,6 +395,4 @@ def test_project_chat_thread_type(self, db: Session, setup_data): ) assert follow.thread_type == InboxThreadType.PROJECT_CHAT - assert InboxFollowService.is_following_thread( - db, user.id, chat_id, InboxThreadType.PROJECT_CHAT - ) + assert InboxFollowService.is_following_thread(db, user.id, chat_id, InboxThreadType.PROJECT_CHAT) diff --git a/backend/tests/test_inbox_mention_service.py b/backend/tests/test_inbox_mention_service.py index 13078bd..7a63cc0 100644 --- 
a/backend/tests/test_inbox_mention_service.py +++ b/backend/tests/test_inbox_mention_service.py @@ -1,14 +1,15 @@ """Tests for inbox mention service.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_mention import InboxConversationType from app.models.organization import Organization from app.models.project import Project -from app.models.inbox_mention import InboxMention, InboxConversationType +from app.models.user import User from app.services.inbox_mention_service import InboxMentionService diff --git a/backend/tests/test_inbox_status_service.py b/backend/tests/test_inbox_status_service.py index 7ea36b8..a5fb2d4 100644 --- a/backend/tests/test_inbox_status_service.py +++ b/backend/tests/test_inbox_status_service.py @@ -1,17 +1,18 @@ """Tests for inbox status service (read watermarks).""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_mention import InboxConversationType from app.models.organization import Organization from app.models.project import Project -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType -from app.models.inbox_mention import InboxConversationType +from app.models.thread import ContextType, Thread +from app.models.thread_item import ThreadItem, ThreadItemType +from app.models.user import User from app.services.inbox_status_service import InboxStatusService @@ -290,9 +291,7 @@ def test_get_followed_conversations_by_type(self, db: Session, setup_data): assert len(all_followed) == 2 # Get only feature type - feature_followed = InboxStatusService.get_followed_conversations( - db, user.id, InboxConversationType.FEATURE - ) + feature_followed = InboxStatusService.get_followed_conversations(db, user.id, InboxConversationType.FEATURE) assert len(feature_followed) == 1 assert feature_followed[0].conversation_type == InboxConversationType.FEATURE @@ -327,9 +326,7 @@ def test_get_unread_count_for_thread(self, db: Session, setup_data): db.commit() # No status = all unread - unread = InboxStatusService.get_unread_count_for_thread( - db, user.id, thread.id, InboxConversationType.FEATURE - ) + unread = InboxStatusService.get_unread_count_for_thread(db, user.id, thread.id, InboxConversationType.FEATURE) assert unread == 5 # Read up to 3 @@ -341,9 +338,7 @@ def test_get_unread_count_for_thread(self, db: Session, setup_data): sequence_number=3, ) - unread = InboxStatusService.get_unread_count_for_thread( - db, user.id, thread.id, InboxConversationType.FEATURE - ) + unread = InboxStatusService.get_unread_count_for_thread(db, user.id, thread.id, InboxConversationType.FEATURE) assert unread == 2 # Items 4 and 5 unread def test_get_unread_count_for_project_chat(self, db: Session, setup_data): @@ -376,9 +371,7 @@ def test_get_unread_count_for_project_chat(self, db: Session, setup_data): db.commit() # No status = all unread - unread = InboxStatusService.get_unread_count_for_project_chat( - db, user.id, chat.id - ) + unread = InboxStatusService.get_unread_count_for_project_chat(db, user.id, chat.id) assert unread == 4 # Read up to 2 @@ -390,9 +383,7 @@ def test_get_unread_count_for_project_chat(self, db: Session, setup_data): sequence_number=2, ) - unread = 
InboxStatusService.get_unread_count_for_project_chat( - db, user.id, chat.id - ) + unread = InboxStatusService.get_unread_count_for_project_chat(db, user.id, chat.id) assert unread == 2 # Messages 3 and 4 unread def test_mark_all_as_read_for_thread(self, db: Session, setup_data): @@ -436,9 +427,7 @@ def test_mark_all_as_read_for_thread(self, db: Session, setup_data): assert status.last_read_sequence == 3 # Check unread count - unread = InboxStatusService.get_unread_count_for_thread( - db, user.id, thread.id, InboxConversationType.FEATURE - ) + unread = InboxStatusService.get_unread_count_for_thread(db, user.id, thread.id, InboxConversationType.FEATURE) assert unread == 0 def test_different_conversation_types_are_separate(self, db: Session, setup_data): @@ -469,9 +458,7 @@ def test_different_conversation_types_are_separate(self, db: Session, setup_data feature_seq = InboxStatusService.get_last_read_sequence( db, user.id, InboxConversationType.FEATURE, conversation_id ) - phase_seq = InboxStatusService.get_last_read_sequence( - db, user.id, InboxConversationType.PHASE, conversation_id - ) + phase_seq = InboxStatusService.get_last_read_sequence(db, user.id, InboxConversationType.PHASE, conversation_id) assert feature_seq == 10 assert phase_seq == 5 diff --git a/backend/tests/test_invitation_router.py b/backend/tests/test_invitation_router.py index 5b4e7db..9eb41a5 100644 --- a/backend/tests/test_invitation_router.py +++ b/backend/tests/test_invitation_router.py @@ -1,15 +1,15 @@ """Integration tests for invitation router endpoints.""" -import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch + from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Organization, OrgRole +from app.models import Organization, OrgRole, User from app.models.org_invitation import InvitationStatus from app.services.invitation_service import InvitationService -from app.services.user_group_service import UserGroupService from app.services.org_service import OrgService +from app.services.user_group_service import UserGroupService class TestCreateInvitations: @@ -134,6 +134,7 @@ def test_create_invitation_with_groups( # Verify the invitation has group assignment from uuid import UUID + invitation_id = UUID(data["results"][0]["invitation_id"]) invitation = InvitationService.get_invitation_by_id(db, invitation_id) assert len(invitation.group_assignments) == 1 @@ -613,9 +614,7 @@ def test_cancel_invitation_wrong_org( ): """Test that invitation from different org returns 404.""" # Create another org - other_org, _ = OrgService.create_org_with_owner( - db, "Other Org", test_user.id - ) + other_org, _ = OrgService.create_org_with_owner(db, "Other Org", test_user.id) # Create invitation in other org invitation = InvitationService.create_invitation( diff --git a/backend/tests/test_invitation_service.py b/backend/tests/test_invitation_service.py index 480e1e5..3f3861b 100644 --- a/backend/tests/test_invitation_service.py +++ b/backend/tests/test_invitation_service.py @@ -1,10 +1,10 @@ """Tests for invitation service.""" -import pytest from datetime import datetime, timedelta, timezone + from sqlalchemy.orm import Session -from app.models import User, Organization, OrgRole +from app.models import Organization, OrgRole, User from app.models.org_invitation import InvitationStatus from app.models.provisioning import ProvisioningSource from app.services.invitation_service import InvitationService @@ -38,9 +38,7 @@ 
def test_create_invitation(self, db: Session, test_user: User, test_org: Organiz expires_at = expires_at.replace(tzinfo=timezone.utc) assert expires_at > datetime.now(timezone.utc) - def test_create_invitation_normalizes_email( - self, db: Session, test_user: User, test_org: Organization - ): + def test_create_invitation_normalizes_email(self, db: Session, test_user: User, test_org: Organization): """Test that email is normalized to lowercase and trimmed.""" invitation = InvitationService.create_invitation( db=db, @@ -52,9 +50,7 @@ def test_create_invitation_normalizes_email( assert invitation.email == "invitee@example.com" - def test_create_invitation_with_groups( - self, db: Session, test_user: User, test_org: Organization - ): + def test_create_invitation_with_groups(self, db: Session, test_user: User, test_org: Organization): """Test creating an invitation with group assignments.""" group = UserGroupService.create_group( db=db, @@ -75,9 +71,7 @@ def test_create_invitation_with_groups( assert len(invitation.group_assignments) == 1 assert invitation.group_assignments[0].group_id == group.id - def test_create_invitation_custom_expiry( - self, db: Session, test_user: User, test_org: Organization - ): + def test_create_invitation_custom_expiry(self, db: Session, test_user: User, test_org: Organization): """Test creating an invitation with custom expiry.""" invitation = InvitationService.create_invitation( db=db, @@ -96,9 +90,7 @@ def test_create_invitation_custom_expiry( # Allow 1 minute tolerance assert abs((expires_at - expected_expiry).total_seconds()) < 60 - def test_get_invitation_by_token( - self, db: Session, test_user: User, test_org: Organization - ): + def test_get_invitation_by_token(self, db: Session, test_user: User, test_org: Organization): """Test getting an invitation by token.""" invitation = InvitationService.create_invitation( db=db, @@ -117,9 +109,7 @@ def test_get_invitation_by_token_not_found(self, db: Session): fetched = InvitationService.get_invitation_by_token(db, "nonexistent-token") assert fetched is None - def test_validate_invitation_valid( - self, db: Session, test_user: User, test_org: Organization - ): + def test_validate_invitation_valid(self, db: Session, test_user: User, test_org: Organization): """Test validating a valid invitation.""" invitation = InvitationService.create_invitation( db=db, @@ -140,9 +130,7 @@ def test_validate_invitation_not_found(self, db: Session): assert result is None assert error == "Invitation not found" - def test_validate_invitation_already_accepted( - self, db: Session, test_user: User, test_org: Organization - ): + def test_validate_invitation_already_accepted(self, db: Session, test_user: User, test_org: Organization): """Test validating an already accepted invitation.""" invitation = InvitationService.create_invitation( db=db, @@ -158,9 +146,7 @@ def test_validate_invitation_already_accepted( assert result is not None assert "already been accepted" in error - def test_validate_invitation_cancelled( - self, db: Session, test_user: User, test_org: Organization - ): + def test_validate_invitation_cancelled(self, db: Session, test_user: User, test_org: Organization): """Test validating a cancelled invitation.""" invitation = InvitationService.create_invitation( db=db, @@ -176,9 +162,7 @@ def test_validate_invitation_cancelled( assert result is not None assert "already been cancelled" in error - def test_cancel_invitation( - self, db: Session, test_user: User, test_org: Organization - ): + def test_cancel_invitation(self, db: 
Session, test_user: User, test_org: Organization): """Test cancelling an invitation.""" invitation = InvitationService.create_invitation( db=db, @@ -199,9 +183,7 @@ def test_cancel_nonexistent_invitation(self, db: Session): result = InvitationService.cancel_invitation(db, uuid4()) assert result is None - def test_mark_as_accepted( - self, db: Session, test_user: User, test_org: Organization - ): + def test_mark_as_accepted(self, db: Session, test_user: User, test_org: Organization): """Test marking an invitation as accepted.""" invitation = InvitationService.create_invitation( db=db, @@ -217,9 +199,7 @@ def test_mark_as_accepted( assert result.accepted_by_user_id == test_user.id assert result.accepted_at is not None - def test_mark_as_accepted_with_metadata( - self, db: Session, test_user: User, test_org: Organization - ): + def test_mark_as_accepted_with_metadata(self, db: Session, test_user: User, test_org: Organization): """Test marking an invitation as accepted with metadata.""" invitation = InvitationService.create_invitation( db=db, @@ -230,16 +210,12 @@ def test_mark_as_accepted_with_metadata( ) metadata = {"accepted_user_email": "different@example.com"} - result = InvitationService.mark_as_accepted( - db, invitation.id, test_user.id, metadata - ) + result = InvitationService.mark_as_accepted(db, invitation.id, test_user.id, metadata) assert result.metadata_json is not None assert result.metadata_json["accepted_user_email"] == "different@example.com" - def test_list_org_invitations( - self, db: Session, test_user: User, test_org: Organization - ): + def test_list_org_invitations(self, db: Session, test_user: User, test_org: Organization): """Test listing invitations for an org.""" InvitationService.create_invitation( db=db, @@ -259,9 +235,7 @@ def test_list_org_invitations( invitations = InvitationService.list_org_invitations(db, test_org.id) assert len(invitations) == 2 - def test_list_org_invitations_with_filter( - self, db: Session, test_user: User, test_org: Organization - ): + def test_list_org_invitations_with_filter(self, db: Session, test_user: User, test_org: Organization): """Test listing invitations with status filter.""" inv1 = InvitationService.create_invitation( db=db, @@ -280,15 +254,11 @@ def test_list_org_invitations_with_filter( InvitationService.cancel_invitation(db, inv1.id) - pending = InvitationService.list_org_invitations( - db, test_org.id, InvitationStatus.PENDING - ) + pending = InvitationService.list_org_invitations(db, test_org.id, InvitationStatus.PENDING) assert len(pending) == 1 assert pending[0].email == "user2@example.com" - def test_get_pending_invitation_for_email( - self, db: Session, test_user: User, test_org: Organization - ): + def test_get_pending_invitation_for_email(self, db: Session, test_user: User, test_org: Organization): """Test finding a pending invitation for a specific email.""" InvitationService.create_invitation( db=db, @@ -298,15 +268,11 @@ def test_get_pending_invitation_for_email( invited_by_user_id=test_user.id, ) - result = InvitationService.get_pending_invitation_for_email( - db, test_org.id, "target@example.com" - ) + result = InvitationService.get_pending_invitation_for_email(db, test_org.id, "target@example.com") assert result is not None assert result.email == "target@example.com" - def test_unique_token_generation( - self, db: Session, test_user: User, test_org: Organization - ): + def test_unique_token_generation(self, db: Session, test_user: User, test_org: Organization): """Test that tokens are unique across 
invitations.""" inv1 = InvitationService.create_invitation( db=db, @@ -329,12 +295,10 @@ def test_unique_token_generation( class TestAcceptInvitation: """Tests for InvitationService.accept_invitation().""" - def test_accept_invitation_success( - self, db: Session, test_user: User, test_org: Organization - ): + def test_accept_invitation_success(self, db: Session, test_user: User, test_org: Organization): """Test basic invitation acceptance.""" - from app.services.user_service import UserService from app.services.org_service import OrgService + from app.services.user_service import UserService # Create a new user to be the invitee invitee = UserService.create_user(db, "invitee@example.com", "password123") @@ -347,9 +311,7 @@ def test_accept_invitation_success( invited_by_user_id=test_user.id, ) - result, groups_added, error = InvitationService.accept_invitation( - db, invitation.token, invitee - ) + result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee) assert error is None assert result.status == InvitationStatus.ACCEPTED @@ -363,9 +325,7 @@ def test_accept_invitation_success( assert membership.role == OrgRole.MEMBER assert membership.provisioning_source == ProvisioningSource.INVITE - def test_accept_invitation_with_groups( - self, db: Session, test_user: User, test_org: Organization - ): + def test_accept_invitation_with_groups(self, db: Session, test_user: User, test_org: Organization): """Test invitation acceptance creates group memberships.""" from app.services.user_service import UserService @@ -387,9 +347,7 @@ def test_accept_invitation_with_groups( group_ids=[group.id], ) - result, groups_added, error = InvitationService.accept_invitation( - db, invitation.token, invitee - ) + result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee) assert error is None assert "Engineering" in groups_added @@ -397,15 +355,11 @@ def test_accept_invitation_with_groups( # Verify group membership has correct provisioning source memberships = UserGroupService.get_group_members(db, group.id) - invitee_membership = next( - (m for m in memberships if m.user_id == invitee.id), None - ) + invitee_membership = next((m for m in memberships if m.user_id == invitee.id), None) assert invitee_membership is not None assert invitee_membership.provisioning_source == ProvisioningSource.INVITE - def test_accept_invitation_email_mismatch( - self, db: Session, test_user: User, test_org: Organization - ): + def test_accept_invitation_email_mismatch(self, db: Session, test_user: User, test_org: Organization): """Test acceptance by different email stores metadata.""" from app.services.user_service import UserService @@ -419,9 +373,7 @@ def test_accept_invitation_email_mismatch( invited_by_user_id=test_user.id, ) - result, groups_added, error = InvitationService.accept_invitation( - db, invitation.token, invitee - ) + result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee) assert error is None assert result.status == InvitationStatus.ACCEPTED @@ -430,12 +382,10 @@ def test_accept_invitation_email_mismatch( assert result.metadata_json.get("invited_email") == "invited@example.com" assert result.metadata_json.get("accepted_user_email") == "different@example.com" - def test_accept_invitation_upserts_existing_member( - self, db: Session, test_user: User, test_org: Organization - ): + def test_accept_invitation_upserts_existing_member(self, db: Session, test_user: User, test_org: Organization): """Test that 
-        from app.services.user_service import UserService
         from app.services.org_service import OrgService
+        from app.services.user_service import UserService
 
         invitee = UserService.create_user(db, "invitee@example.com", "password123")
@@ -451,17 +401,13 @@ def test_accept_invitation_upserts_existing_member(
             invited_by_user_id=test_user.id,
         )
 
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, invitation.token, invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee)
 
         assert error is None
         membership = OrgService.get_org_membership(db, test_org.id, invitee.id)
         assert membership.role == OrgRole.ADMIN
 
-    def test_accept_invitation_skips_existing_group_member(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_accept_invitation_skips_existing_group_member(self, db: Session, test_user: User, test_org: Organization):
         """Test that existing group members are not duplicated."""
         from app.services.user_service import UserService
 
@@ -486,9 +432,7 @@ def test_accept_invitation_skips_existing_group_member(
             group_ids=[group.id],
         )
 
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, invitation.token, invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee)
 
         assert error is None
         # Group not in added list since already member
@@ -498,25 +442,19 @@ def test_accept_invitation_skips_existing_group_member(
         invitee_memberships = [m for m in memberships if m.user_id == invitee.id]
         assert len(invitee_memberships) == 1
 
-    def test_accept_invalid_token(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_accept_invalid_token(self, db: Session, test_user: User, test_org: Organization):
         """Test accepting with invalid token returns error."""
         from app.services.user_service import UserService
 
         invitee = UserService.create_user(db, "invitee@example.com", "password123")
 
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, "invalid-token", invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, "invalid-token", invitee)
 
         assert result is None
         assert error == "Invitation not found"
         assert groups_added == []
 
-    def test_accept_already_accepted(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_accept_already_accepted(self, db: Session, test_user: User, test_org: Organization):
         """Test accepting already-accepted invitation fails."""
         from app.services.user_service import UserService
 
@@ -534,16 +472,12 @@ def test_accept_already_accepted(
         InvitationService.accept_invitation(db, invitation.token, invitee)
 
         # Try again
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, invitation.token, invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee)
 
         assert "already been accepted" in error
         assert groups_added == []
 
-    def test_accept_expired_invitation(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_accept_expired_invitation(self, db: Session, test_user: User, test_org: Organization):
         """Test accepting expired invitation fails."""
         from app.services.user_service import UserService
 
@@ -558,16 +492,12 @@ def test_accept_expired_invitation(
             expires_in_days=-1,  # Already expired
         )
 
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, invitation.token, invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee)
 
         assert "expired" in error
         assert groups_added == []
 
-    def test_accept_cancelled_invitation(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_accept_cancelled_invitation(self, db: Session, test_user: User, test_org: Organization):
         """Test accepting cancelled invitation fails."""
         from app.services.user_service import UserService
 
@@ -583,9 +513,7 @@ def test_accept_cancelled_invitation(
 
         InvitationService.cancel_invitation(db, invitation.id)
 
-        result, groups_added, error = InvitationService.accept_invitation(
-            db, invitation.token, invitee
-        )
+        result, groups_added, error = InvitationService.accept_invitation(db, invitation.token, invitee)
 
         assert "already been cancelled" in error
         assert groups_added == []
diff --git a/backend/tests/test_invite_acceptance_router.py b/backend/tests/test_invite_acceptance_router.py
index b0107a3..f35a6eb 100644
--- a/backend/tests/test_invite_acceptance_router.py
+++ b/backend/tests/test_invite_acceptance_router.py
@@ -1,15 +1,13 @@
 """Integration tests for invite acceptance router endpoints."""
 
-import pytest
 from fastapi.testclient import TestClient
 from sqlalchemy.orm import Session
 
-from app.models import User, Organization, OrgRole
-from app.models.org_invitation import InvitationStatus
+from app.models import Organization, OrgRole, User
 from app.services.invitation_service import InvitationService
+from app.services.org_service import OrgService
 from app.services.user_group_service import UserGroupService
 from app.services.user_service import UserService
-from app.services.org_service import OrgService
 
 
 class TestValidateInvite:
@@ -537,10 +535,15 @@ def test_accept_invite_creates_personal_org(
 
         # Verify sample project was created in the personal org
         from app.models.project import Project
+
-        sample_projects = db.query(Project).filter(
-            Project.org_id == personal_org.id,
-            Project.is_sample == True,
-        ).all()
+        sample_projects = (
+            db.query(Project)
+            .filter(
+                Project.org_id == personal_org.id,
+                Project.is_sample == True,
+            )
+            .all()
+        )
         assert len(sample_projects) == 1
 
         # Verify current_org_id is still the inviter's org (not personal org)
diff --git a/backend/tests/test_jobs.py b/backend/tests/test_jobs.py
index 1695691..3ff4852 100644
--- a/backend/tests/test_jobs.py
+++ b/backend/tests/test_jobs.py
@@ -3,19 +3,20 @@
 These tests verify job creation, status updates, and querying.
""" -import pytest -from datetime import datetime, timedelta, UTC + +from datetime import UTC, datetime, timedelta from uuid import uuid4 + +import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.job import Job, JobType, JobStatus +from app.models.job import Job, JobStatus, JobType from app.services.job_service import ( - JobService, - JOB_TIMEOUT_MINUTES, DEFAULT_TIMEOUT_MINUTES, QUEUED_JOB_TIMEOUT_MINUTES, + JobService, ) @@ -195,9 +196,7 @@ def test_list_jobs(test_db_session: Session): assert len(project_jobs) == 1 # Filter by job_type - discovery_jobs = JobService.list_jobs( - test_db_session, job_type=JobType.BRAINSTORM_CONVERSATION_GENERATE - ) + discovery_jobs = JobService.list_jobs(test_db_session, job_type=JobType.BRAINSTORM_CONVERSATION_GENERATE) assert len(discovery_jobs) == 2 # Filter by status diff --git a/backend/tests/test_legacy_migration.py b/backend/tests/test_legacy_migration.py index 45c3397..4e91198 100644 --- a/backend/tests/test_legacy_migration.py +++ b/backend/tests/test_legacy_migration.py @@ -1,17 +1,18 @@ """Tests for legacy migration from NotificationThreadWatch to InboxFollow.""" -import pytest + from datetime import datetime, timezone -from uuid import uuid4 from unittest.mock import patch +from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.user import User +from app.models.inbox_follow import InboxFollow, InboxFollowType, InboxThreadType +from app.models.notification_thread_watch import NotificationThreadWatch from app.models.organization import Organization from app.models.project import Project -from app.models.thread import Thread, ContextType -from app.models.notification_thread_watch import NotificationThreadWatch -from app.models.inbox_follow import InboxFollow, InboxFollowType, InboxThreadType +from app.models.thread import ContextType, Thread +from app.models.user import User from app.services.notification_service import NotificationService @@ -59,14 +60,16 @@ def setup_data(self, db: Session): def test_map_context_type_to_inbox_thread_type(self): """Test mapping of context types to inbox thread types.""" # BRAINSTORM_FEATURE maps to FEATURE - result = NotificationService._map_context_type_to_inbox_thread_type( - ContextType.BRAINSTORM_FEATURE - ) + result = NotificationService._map_context_type_to_inbox_thread_type(ContextType.BRAINSTORM_FEATURE) assert result == InboxThreadType.FEATURE # Other types map to PHASE - for context_type in [ContextType.SPEC, ContextType.GENERAL, - ContextType.SPEC_DRAFT, ContextType.PROMPT_PLAN_DRAFT]: + for context_type in [ + ContextType.SPEC, + ContextType.GENERAL, + ContextType.SPEC_DRAFT, + ContextType.PROMPT_PLAN_DRAFT, + ]: result = NotificationService._map_context_type_to_inbox_thread_type(context_type) assert result == InboxThreadType.PHASE @@ -99,10 +102,14 @@ def test_watch_thread_with_feature_flag_disabled(self, mock_settings, db: Sessio assert watch.thread_id == thread.id # Check that InboxFollow was NOT created - follow = db.query(InboxFollow).filter( - InboxFollow.user_id == user.id, - InboxFollow.thread_id == thread.id, - ).first() + follow = ( + db.query(InboxFollow) + .filter( + InboxFollow.user_id == user.id, + InboxFollow.thread_id == thread.id, + ) + .first() + ) assert follow is None @patch("app.services.notification_service.settings") @@ -132,10 +139,14 @@ def test_watch_thread_with_feature_flag_enabled(self, 
         assert watch is not None
 
         # Check that InboxFollow was created
-        follow = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).first()
+        follow = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .first()
+        )
         assert follow is not None
         assert follow.project_id == project.id
         assert follow.thread_type == InboxThreadType.FEATURE
@@ -166,10 +177,14 @@ def test_watch_thread_maps_phase_context_correctly(self, mock_settings, db: Sess
         NotificationService.watch_thread(db, user.id, thread.id)
 
         # Check that InboxFollow was created with PHASE type
-        follow = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).first()
+        follow = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .first()
+        )
         assert follow is not None
         assert follow.thread_type == InboxThreadType.PHASE
@@ -200,10 +215,14 @@ def test_watch_thread_idempotent(self, mock_settings, db: Session, setup_data):
         assert watch1.id == watch2.id
 
         # Check only one InboxFollow exists
-        follows = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).all()
+        follows = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .all()
+        )
         assert len(follows) == 1
 
     @patch("app.services.notification_service.settings")
@@ -244,17 +263,25 @@ def test_unwatch_thread_with_feature_flag_disabled(self, mock_settings, db: Sess
         assert result is True
 
         # NotificationThreadWatch should be deleted
-        old_watch = db.query(NotificationThreadWatch).filter(
-            NotificationThreadWatch.user_id == user.id,
-            NotificationThreadWatch.thread_id == thread.id,
-        ).first()
+        old_watch = (
+            db.query(NotificationThreadWatch)
+            .filter(
+                NotificationThreadWatch.user_id == user.id,
+                NotificationThreadWatch.thread_id == thread.id,
+            )
+            .first()
+        )
         assert old_watch is None
 
         # InboxFollow should still exist (feature flag disabled)
-        remaining_follow = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).first()
+        remaining_follow = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .first()
+        )
         assert remaining_follow is not None
 
     @patch("app.services.notification_service.settings")
@@ -281,26 +308,32 @@ def test_unwatch_thread_with_feature_flag_enabled(self, mock_settings, db: Sessi
         NotificationService.watch_thread(db, user.id, thread.id)
 
         # Verify both exist
-        assert db.query(NotificationThreadWatch).filter(
-            NotificationThreadWatch.user_id == user.id
-        ).first() is not None
-        assert db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id
-        ).first() is not None
+        assert db.query(NotificationThreadWatch).filter(NotificationThreadWatch.user_id == user.id).first() is not None
+        assert db.query(InboxFollow).filter(InboxFollow.user_id == user.id).first() is not None
 
         # Unwatch
         result = NotificationService.unwatch_thread(db, user.id, thread.id)
         assert result is True
 
         # Both should be deleted
-        assert db.query(NotificationThreadWatch).filter(
-            NotificationThreadWatch.user_id == user.id,
-            NotificationThreadWatch.thread_id == thread.id,
-        ).first() is None
-        assert db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).first() is None
+        assert (
+            db.query(NotificationThreadWatch)
+            .filter(
+                NotificationThreadWatch.user_id == user.id,
+                NotificationThreadWatch.thread_id == thread.id,
+            )
+            .first()
+            is None
+        )
+        assert (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .first()
+            is None
+        )
 
     @patch("app.services.notification_service.settings")
     def test_watch_nonexistent_thread(self, mock_settings, db: Session, setup_data):
@@ -319,10 +352,14 @@ def test_watch_nonexistent_thread(self, mock_settings, db: Session, setup_data):
         assert watch.thread_id == fake_thread_id
 
         # InboxFollow should NOT be created (thread doesn't exist)
-        follow = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == fake_thread_id,
-        ).first()
+        follow = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == fake_thread_id,
+            )
+            .first()
+        )
         assert follow is None
 
     @patch("app.services.notification_service.settings")
@@ -350,18 +387,20 @@ def test_dual_write_backfill_existing_watch(self, mock_settings, db: Session, se
         NotificationService.watch_thread(db, user.id, thread.id)
 
         # Verify no InboxFollow
-        assert db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id
-        ).first() is None
+        assert db.query(InboxFollow).filter(InboxFollow.user_id == user.id).first() is None
 
         # Enable feature flag and re-watch
         mock_settings.feature_flag_use_inbox_follows = True
         NotificationService.watch_thread(db, user.id, thread.id)
 
         # Now InboxFollow should exist
-        follow = db.query(InboxFollow).filter(
-            InboxFollow.user_id == user.id,
-            InboxFollow.thread_id == thread.id,
-        ).first()
+        follow = (
+            db.query(InboxFollow)
+            .filter(
+                InboxFollow.user_id == user.id,
+                InboxFollow.thread_id == thread.id,
+            )
+            .first()
+        )
         assert follow is not None
         assert follow.thread_type == InboxThreadType.FEATURE
diff --git a/backend/tests/test_llm_mock_prompt_plan.py b/backend/tests/test_llm_mock_prompt_plan.py
index 58d7b73..3d47c7d 100644
--- a/backend/tests/test_llm_mock_prompt_plan.py
+++ b/backend/tests/test_llm_mock_prompt_plan.py
@@ -1,7 +1,7 @@
 """Tests for Mock LLM Prompt Plan Generator."""
 
-import pytest
 from uuid import uuid4
+
 from app.services.llm_mock_prompt_plan import generate_prompt_plan_mock
@@ -129,9 +129,7 @@ def test_incorporates_spec_content(self):
         assert markdown.startswith("#")
 
         # Check that phases reference key concepts from spec
-        all_phase_text = " ".join(
-            [p["title"] + p["description"] + p["test_plan"] for p in phases]
-        ).lower()
+        all_phase_text = " ".join([p["title"] + p["description"] + p["test_plan"] for p in phases]).lower()
 
         # Should mention database setup (from Data Model section)
         assert "database" in all_phase_text or "data" in all_phase_text
diff --git a/backend/tests/test_llm_mock_spec.py b/backend/tests/test_llm_mock_spec.py
index 9dd02ab..de06b9a 100644
--- a/backend/tests/test_llm_mock_spec.py
+++ b/backend/tests/test_llm_mock_spec.py
@@ -1,6 +1,5 @@
 """Tests for mock LLM specification generator."""
-import pytest
 from uuid import uuid4
 
 from app.services.llm_mock_spec import generate_specification_mock
@@ -78,13 +77,9 @@ def test_different_inputs_produce_different_specs(self):
         """Test that different inputs produce different specs."""
         project_id = uuid4()
 
-        spec1 = generate_specification_mock(
-            project_id, "Build a REST API", []
-        )
+        spec1 = generate_specification_mock(project_id, "Build a REST API", [])
 
-        spec2 = generate_specification_mock(
-            project_id, "Build a GraphQL API", []
-        )
+        spec2 = generate_specification_mock(project_id, "Build a GraphQL API", [])
 
         assert spec1 != spec2
diff --git a/backend/tests/test_markdown_parser.py b/backend/tests/test_markdown_parser.py
index 00a3940..efacfbf 100644
--- a/backend/tests/test_markdown_parser.py
+++ b/backend/tests/test_markdown_parser.py
@@ -1,7 +1,8 @@
 """Tests for markdown parser utility (TOC extraction and section extraction)."""
 
 import pytest
-from app.mcp.utils.markdown_parser import extract_toc, extract_section, TocEntry
+
+from app.mcp.utils.markdown_parser import extract_section, extract_toc
 
 
 class TestExtractToc:
diff --git a/backend/tests/test_mcp_call_log.py b/backend/tests/test_mcp_call_log.py
index ce0b580..e18c3f6 100644
--- a/backend/tests/test_mcp_call_log.py
+++ b/backend/tests/test_mcp_call_log.py
@@ -1,25 +1,21 @@
 """Tests for MCP Call Log feature."""
 
-import pytest
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
 from uuid import uuid4
+
 from sqlalchemy.orm import Session
 
-from app.models import User, Organization, Project, ProjectType
-from app.models.api_key import ApiKey
-from app.models.mcp_call_log import MCPCallLog
+from app.models import Organization, ProjectType, User
+from app.schemas.api_key import ApiKeyCreate
+from app.services.api_key_service import ApiKeyService
 from app.services.mcp_call_log_service import MCPCallLogService
 from app.services.project_service import ProjectService
-from app.services.api_key_service import ApiKeyService
-from app.schemas.api_key import ApiKeyCreate
 
 
 class TestMCPCallLogService:
     """Tests for MCPCallLogService."""
 
-    def test_create_call_log(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_create_call_log(self, db: Session, test_org: Organization, test_user: User):
         """Test creating an MCP call log entry."""
         # Create project
         project = ProjectService.create_project(
@@ -67,9 +63,7 @@ def test_create_call_log(
         assert log.is_error is False
         assert log.duration_ms == 150
 
-    def test_create_call_log_with_error(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_create_call_log_with_error(self, db: Session, test_org: Organization, test_user: User):
         """Test creating an MCP call log with error response."""
         project = ProjectService.create_project(
             db=db,
@@ -107,9 +101,7 @@ def test_create_call_log_with_error(
         assert log.is_error is True
         assert log.response_error["code"] == -32601
 
-    def test_get_call_log(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_call_log(self, db: Session, test_org: Organization, test_user: User):
         """Test retrieving a specific call log by ID."""
         project = ProjectService.create_project(
             db=db,
@@ -152,9 +144,7 @@ def test_get_call_log_not_found(self, db: Session):
         result = MCPCallLogService.get_call_log(db, uuid4())
         assert result is None
 
-    def test_list_call_logs(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_list_call_logs(self, db: Session, test_org: Organization, test_user: User):
         """Test listing call logs for an organization."""
         project = ProjectService.create_project(
             db=db,
@@ -190,9 +180,7 @@ def test_list_call_logs(
         logs = MCPCallLogService.list_call_logs(db, test_org.id)
         assert len(logs) >= 3
 
-    def test_list_call_logs_with_filters(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_list_call_logs_with_filters(self, db: Session, test_org: Organization, test_user: User):
         """Test listing call logs with filters."""
         project = ProjectService.create_project(
             db=db,
@@ -242,20 +230,14 @@ def test_list_call_logs_with_filters(
         )
 
         # Filter by tool_name
-        logs = MCPCallLogService.list_call_logs(
-            db, test_org.id, tool_name="getContext"
-        )
+        logs = MCPCallLogService.list_call_logs(db, test_org.id, tool_name="getContext")
         assert all(log.tool_name == "getContext" for log in logs)
 
         # Filter by is_error
-        error_logs = MCPCallLogService.list_call_logs(
-            db, test_org.id, is_error=True
-        )
+        error_logs = MCPCallLogService.list_call_logs(db, test_org.id, is_error=True)
         assert all(log.is_error for log in error_logs)
 
-    def test_count_call_logs(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_count_call_logs(self, db: Session, test_org: Organization, test_user: User):
         """Test counting call logs for an organization."""
         project = ProjectService.create_project(
             db=db,
@@ -290,9 +272,7 @@ def test_count_call_logs(
         count = MCPCallLogService.count_call_logs(db, test_org.id)
         assert count >= 5
 
-    def test_create_call_log_with_coding_agent_name(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_create_call_log_with_coding_agent_name(self, db: Session, test_org: Organization, test_user: User):
         """Test creating an MCP call log with coding_agent_name."""
         project = ProjectService.create_project(
             db=db,
@@ -333,9 +313,7 @@ def test_create_call_log_with_coding_agent_name(
         assert log.tool_name == "ls"
         assert log.is_error is False
 
-    def test_create_call_log_without_coding_agent_name(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_create_call_log_without_coding_agent_name(self, db: Session, test_org: Organization, test_user: User):
         """Test creating an MCP call log without coding_agent_name (should be None)."""
         project = ProjectService.create_project(
             db=db,
diff --git a/backend/tests/test_mcp_get_context.py b/backend/tests/test_mcp_get_context.py
index 69802c2..2c03bd6 100644
--- a/backend/tests/test_mcp_get_context.py
+++ b/backend/tests/test_mcp_get_context.py
@@ -1,30 +1,35 @@
 """Tests for getContext MCP tool - Phase-First Architecture."""
-import pytest
 from uuid import uuid4
+
+import pytest
 from sqlalchemy.orm import Session
 
 from app.mcp.tools.get_context import get_context
 from app.models import (
-    Project,
+    Organization,
     ProjectType,
     User,
-    Organization,
 )
 from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType
-from app.models.final_spec import FinalSpec
+from app.models.feature import (
+    Feature,
+    FeatureCompletionStatus,
+    FeaturePriority,
+    FeatureProvenance,
+    FeatureStatus,
+    FeatureType,
+)
 from app.models.final_prompt_plan import FinalPromptPlan
-from app.models.module import Module, ModuleType, ModuleProvenance
-from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureCompletionStatus, FeaturePriority, FeatureProvenance
+from app.models.final_spec import FinalSpec
+from app.models.module import Module, ModuleProvenance, ModuleType
 from app.services.project_service import ProjectService
 
 
 class TestGetContext:
     """Test getContext tool for phase-first architecture."""
 
-    def test_get_context_with_phases_modules_features(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_with_phases_modules_features(self, db: Session, test_org: Organization, test_user: User):
         """Get context returns phases with modules and features."""
         # Create project
         project = ProjectService.create_project(
@@ -180,9 +185,7 @@ def test_get_context_with_phases_modules_features(
         assert context["summary"]["completed_features"] == 1
         assert context["summary"]["overall_progress_percent"] == 50.0
 
-    def test_get_context_by_project_key(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_by_project_key(self, db: Session, test_org: Organization, test_user: User):
         """Get context by project key."""
         project = ProjectService.create_project(
             db=db,
@@ -198,9 +201,7 @@ def test_get_context_by_project_key(
         assert context["project"]["id"] == str(project.id)
         assert context["project"]["key"] == "TESTY"
 
-    def test_get_context_for_project_without_phases(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_for_project_without_phases(self, db: Session, test_org: Organization, test_user: User):
         """Get context for project without phases returns empty phases array."""
         project = ProjectService.create_project(
             db=db,
@@ -217,9 +218,7 @@ def test_get_context_for_project_without_phases(
         assert context["summary"]["total_phases"] == 0
         assert context["summary"]["total_features"] == 0
 
-    def test_get_context_for_phase_without_spec(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_for_phase_without_spec(self, db: Session, test_org: Organization, test_user: User):
         """Get context for phase without final spec shows unavailable."""
         project = ProjectService.create_project(
             db=db,
@@ -414,9 +413,7 @@ def test_get_context_invalid_project_raises_error(self, db: Session):
         with pytest.raises(ValueError, match="Project with ID"):
             get_context(db, project_id="00000000-0000-0000-0000-000000000000")
 
-    def test_get_context_includes_parent_application_key(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_includes_parent_application_key(self, db: Session, test_org: Organization, test_user: User):
         """Get context includes parent application key for child projects."""
         # Create parent application
         parent = ProjectService.create_project(
@@ -442,9 +439,7 @@ def test_get_context_includes_parent_application_key(
         assert context["project"]["type"] == "feature"
         assert context["project"]["parent_application_key"] == "PARNT"
 
-    def test_get_context_excludes_archived_modules(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_context_excludes_archived_modules(self, db: Session, test_org: Organization, test_user: User):
         """Get context excludes archived modules."""
         from datetime import datetime, timezone
diff --git a/backend/tests/test_mcp_get_section.py b/backend/tests/test_mcp_get_section.py
index 124b799..6288965 100644
--- a/backend/tests/test_mcp_get_section.py
+++ b/backend/tests/test_mcp_get_section.py
@@ -4,7 +4,7 @@
 from sqlalchemy.orm import Session
 
 from app.mcp.tools.get_section import get_section
-from app.models import ProjectType, Organization, User, SpecType
+from app.models import Organization, ProjectType, SpecType, User
 from app.services.project_service import ProjectService
 from app.services.spec_service import SpecService
@@ -12,9 +12,7 @@ class TestGetSection:
     """Test getSection tool."""
 
-    def test_get_section_from_spec(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_section_from_spec(self, db: Session, test_org: Organization, test_user: User):
         """Get section from specification."""
         project = ProjectService.create_project(
             db=db,
@@ -42,9 +40,7 @@ def test_get_section_from_spec(
             created_by_user_id=test_user.id,
         )
 
-        result = get_section(
-            db, project_id=str(project.id), target="spec", section_id="sec-data-model"
-        )
+        result = get_section(db, project_id=str(project.id), target="spec", section_id="sec-data-model")
 
         assert result["target"] == "spec"
         assert result["section_id"] == "sec-data-model"
@@ -53,9 +49,7 @@ def test_get_section_from_spec(
         # Should not include next section
         assert "API Endpoints" not in result["markdown"]
 
-    def test_get_section_from_prompt_plan(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_section_from_prompt_plan(self, db: Session, test_org: Organization, test_user: User):
         """Get section from prompt plan."""
         project = ProjectService.create_project(
             db=db,
@@ -93,9 +87,7 @@ def test_get_section_from_prompt_plan(
         # Should not include next phase
         assert "Phase 2" not in result["markdown"]
 
-    def test_get_nonexistent_section_raises_error(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_nonexistent_section_raises_error(self, db: Session, test_org: Organization, test_user: User):
         """Get non-existent section raises ValueError."""
         project = ProjectService.create_project(
             db=db,
@@ -121,9 +113,7 @@ def test_get_nonexistent_section_raises_error(
                 section_id="sec-nonexistent",
             )
 
-    def test_get_first_section(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_first_section(self, db: Session, test_org: Organization, test_user: User):
         """Get the first section."""
         project = ProjectService.create_project(
             db=db,
@@ -148,9 +138,7 @@ def test_get_first_section(
             created_by_user_id=test_user.id,
         )
 
-        result = get_section(
-            db, project_id=str(project.id), target="spec", section_id="sec-overview"
-        )
+        result = get_section(db, project_id=str(project.id), target="spec", section_id="sec-overview")
 
         assert "# Overview" in result["markdown"]
         assert "This is the overview." in result["markdown"]
@@ -158,9 +146,7 @@ def test_get_first_section(
         # Should not include next section
         assert "Data Model" not in result["markdown"]
 
-    def test_get_deeply_nested_section(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_deeply_nested_section(self, db: Session, test_org: Organization, test_user: User):
         """Get deeply nested section."""
         project = ProjectService.create_project(
             db=db,
@@ -194,9 +180,7 @@ def test_get_deeply_nested_section(
             created_by_user_id=test_user.id,
         )
 
-        result = get_section(
-            db, project_id=str(project.id), target="spec", section_id="sec-backend"
-        )
+        result = get_section(db, project_id=str(project.id), target="spec", section_id="sec-backend")
 
         assert "### Backend" in result["markdown"]
         assert "Nested content here." in result["markdown"]
in result["markdown"] @@ -204,9 +188,7 @@ def test_get_deeply_nested_section( # Should not include sibling section assert "Frontend" not in result["markdown"] - def test_get_section_invalid_target_raises_error( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_section_invalid_target_raises_error(self, db: Session, test_org: Organization, test_user: User): """Get section with invalid target raises ValueError.""" project = ProjectService.create_project( db=db, diff --git a/backend/tests/test_mcp_get_toc.py b/backend/tests/test_mcp_get_toc.py index 37ce149..3cd23f7 100644 --- a/backend/tests/test_mcp_get_toc.py +++ b/backend/tests/test_mcp_get_toc.py @@ -5,8 +5,14 @@ from app.mcp.tools.get_toc import get_toc from app.models import ( - ProjectType, Organization, User, SpecType, - BrainstormingPhase, BrainstormingPhaseType, FinalSpec, FinalPromptPlan + BrainstormingPhase, + BrainstormingPhaseType, + FinalPromptPlan, + FinalSpec, + Organization, + ProjectType, + SpecType, + User, ) from app.services.project_service import ProjectService from app.services.spec_service import SpecService @@ -15,9 +21,7 @@ class TestGetToc: """Test getToc tool.""" - def test_get_toc_for_spec( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_toc_for_spec(self, db: Session, test_org: Organization, test_user: User): """Get TOC for specification.""" project = ProjectService.create_project( db=db, @@ -55,9 +59,7 @@ def test_get_toc_for_spec( assert result["toc"][2]["title"] == "Backend" assert result["toc"][2]["level"] == 3 - def test_get_toc_for_prompt_plan( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_toc_for_prompt_plan(self, db: Session, test_org: Organization, test_user: User): """Get TOC for prompt plan.""" project = ProjectService.create_project( db=db, @@ -92,9 +94,7 @@ def test_get_toc_for_prompt_plan( assert result["toc"][1]["title"] == "Backend Setup" assert result["toc"][2]["title"] == "Phase 2: Implementation" - def test_get_toc_for_project_without_spec_returns_empty( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_toc_for_project_without_spec_returns_empty(self, db: Session, test_org: Organization, test_user: User): """Get TOC for project without spec returns empty array.""" project = ProjectService.create_project( db=db, @@ -109,9 +109,7 @@ def test_get_toc_for_project_without_spec_returns_empty( assert result["target"] == "spec" assert result["toc"] == [] - def test_get_toc_with_nested_headings( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_toc_with_nested_headings(self, db: Session, test_org: Organization, test_user: User): """Get TOC preserves nested heading structure.""" project = ProjectService.create_project( db=db, @@ -156,9 +154,7 @@ def test_get_toc_with_nested_headings( assert toc[3]["path"] == ["Project", "Section A", "Subsection A2"] assert toc[4]["path"] == ["Project", "Section B"] - def test_get_toc_invalid_target_raises_error( - self, db: Session, test_org: Organization, test_user: User - ): + def test_get_toc_invalid_target_raises_error(self, db: Session, test_org: Organization, test_user: User): """Get TOC with invalid target raises ValueError.""" project = ProjectService.create_project( db=db, @@ -176,9 +172,7 @@ def test_get_toc_invalid_project_raises_error(self, db: Session): with pytest.raises(ValueError, match="Project with ID"): get_toc(db, project_id="00000000-0000-0000-0000-000000000000", target="spec") - def 
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_toc_from_final_spec_in_brainstorming_phase(self, db: Session, test_org: Organization, test_user: User):
         """Get TOC from FinalSpec in brainstorming phase (phase-first architecture)."""
         project = ProjectService.create_project(
             db=db,
@@ -290,9 +284,7 @@ def test_get_toc_from_final_prompt_plan_in_brainstorming_phase(
         assert result["toc"][3]["title"] == "Phase 2: Core Implementation"
         assert result["toc"][3]["level"] == 1
 
-    def test_get_toc_prefers_phase_level_over_project_level(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_get_toc_prefers_phase_level_over_project_level(self, db: Session, test_org: Organization, test_user: User):
         """Phase-level FinalSpec takes precedence over project-level SpecVersion."""
         project = ProjectService.create_project(
             db=db,
diff --git a/backend/tests/test_mcp_image_upload_api.py b/backend/tests/test_mcp_image_upload_api.py
index a4cf34c..f94da85 100644
--- a/backend/tests/test_mcp_image_upload_api.py
+++ b/backend/tests/test_mcp_image_upload_api.py
@@ -1,29 +1,87 @@
 """Tests for MCP image upload API endpoint."""
 
 import io
-import time
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, timezone
 from uuid import uuid4
 
 import pytest
 
-from app.models import User, Organization, Project
-from app.models.mcp_image_submission import MCPImageSubmission, MCP_IMAGE_SUBMISSION_EXPIRY_HOURS
+from app.models import Organization, Project, User
+from app.models.mcp_image_submission import MCP_IMAGE_SUBMISSION_EXPIRY_HOURS, MCPImageSubmission
 from app.services.image_service import ImageService
 
-
 # Small valid PNG (1x1 pixel, transparent)
-VALID_PNG_BYTES = bytes([
-    0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
-    0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,  # IHDR chunk
-    0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
-    0x08, 0x06, 0x00, 0x00, 0x00, 0x1F, 0x15, 0xC4,
-    0x89, 0x00, 0x00, 0x00, 0x0A, 0x49, 0x44, 0x41,  # IDAT chunk
-    0x54, 0x78, 0x9C, 0x63, 0x00, 0x01, 0x00, 0x00,
-    0x05, 0x00, 0x01, 0x0D, 0x0A, 0x2D, 0xB4, 0x00,
-    0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE,  # IEND chunk
-    0x42, 0x60, 0x82,
-])
+VALID_PNG_BYTES = bytes(
+    [
+        0x89,
+        0x50,
+        0x4E,
+        0x47,
+        0x0D,
+        0x0A,
+        0x1A,
+        0x0A,  # PNG signature
+        0x00,
+        0x00,
+        0x00,
+        0x0D,
+        0x49,
+        0x48,
+        0x44,
+        0x52,  # IHDR chunk
+        0x00,
+        0x00,
+        0x00,
+        0x01,
+        0x00,
+        0x00,
+        0x00,
+        0x01,
+        0x08,
+        0x06,
+        0x00,
+        0x00,
+        0x00,
+        0x1F,
+        0x15,
+        0xC4,
+        0x89,
+        0x00,
+        0x00,
+        0x00,
+        0x0A,
+        0x49,
+        0x44,
+        0x41,  # IDAT chunk
+        0x54,
+        0x78,
+        0x9C,
+        0x63,
+        0x00,
+        0x01,
+        0x00,
+        0x00,
+        0x05,
+        0x00,
+        0x01,
+        0x0D,
+        0x0A,
+        0x2D,
+        0xB4,
+        0x00,
+        0x00,
+        0x00,
+        0x00,
+        0x49,
+        0x45,
+        0x4E,
+        0x44,
+        0xAE,  # IEND chunk
+        0x42,
+        0x60,
+        0x82,
+    ]
+)
 
 # Note: 'client' fixture comes from conftest.py with proper DB override
@@ -155,9 +213,7 @@ def test_upload_valid_png(self, client, db, valid_token):
         assert data["expires_in_hours"] == MCP_IMAGE_SUBMISSION_EXPIRY_HOURS
 
         # Verify submission was created in DB
-        submission = db.query(MCPImageSubmission).filter(
-            MCPImageSubmission.submission_id == data["image_id"]
-        ).first()
+        submission = db.query(MCPImageSubmission).filter(MCPImageSubmission.submission_id == data["image_id"]).first()
         assert submission is not None
 
     def test_upload_requires_authorization(self, client):
diff --git a/backend/tests/test_mcp_permissions.py b/backend/tests/test_mcp_permissions.py
index a813502..4ab252c 100644
--- a/backend/tests/test_mcp_permissions.py
+++ b/backend/tests/test_mcp_permissions.py
@@ -1,27 +1,35 @@
 """Tests for MCP HTTP endpoint permission enforcement."""
-import pytest
 from uuid import uuid4
-from sqlalchemy.orm import Session
+
+import pytest
 from fastapi.testclient import TestClient
+from sqlalchemy.orm import Session
 
 from app.models import (
-    User,
     Organization,
+    OrgMembership,
+    OrgRole,
     Project,
     ProjectRole,
     ProjectType,
-    OrgMembership,
-    OrgRole,
+    User,
 )
-from app.models.project_share import ProjectShare, ShareSubjectType
 from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType
-from app.models.module import Module, ModuleType, ModuleProvenance
-from app.models.feature import Feature, FeatureType, FeatureStatus, FeatureCompletionStatus, FeaturePriority, FeatureProvenance
-from app.services.user_service import UserService
-from app.services.project_service import ProjectService
-from app.services.api_key_service import ApiKeyService
+from app.models.feature import (
+    Feature,
+    FeatureCompletionStatus,
+    FeaturePriority,
+    FeatureProvenance,
+    FeatureStatus,
+    FeatureType,
+)
+from app.models.module import Module, ModuleProvenance, ModuleType
+from app.models.project_share import ProjectShare, ShareSubjectType
 from app.schemas.api_key import ApiKeyCreate
+from app.services.api_key_service import ApiKeyService
+from app.services.project_service import ProjectService
+from app.services.user_service import UserService
 
 
 class TestMCPPermissions:
@@ -174,9 +182,7 @@ def setup_project_with_roles(
         return project
 
     @pytest.fixture
-    def viewer_api_key(
-        self, db: Session, viewer_user: User, setup_project_with_roles: Project
-    ) -> str:
+    def viewer_api_key(self, db: Session, viewer_user: User, setup_project_with_roles: Project) -> str:
         """Create an API key for the viewer user."""
         _, raw_key = ApiKeyService.create_api_key(
             db=db,
@@ -186,9 +192,7 @@ def viewer_api_key(
         return raw_key
 
     @pytest.fixture
-    def member_api_key(
-        self, db: Session, member_user: User, setup_project_with_roles: Project
-    ) -> str:
+    def member_api_key(self, db: Session, member_user: User, setup_project_with_roles: Project) -> str:
         """Create an API key for the member user."""
         _, raw_key = ApiKeyService.create_api_key(
             db=db,
@@ -198,9 +202,7 @@ def member_api_key(
         return raw_key
 
    @pytest.fixture
-    def admin_api_key(
-        self, db: Session, admin_user: User, setup_project_with_roles: Project
-    ) -> str:
+    def admin_api_key(self, db: Session, admin_user: User, setup_project_with_roles: Project) -> str:
         """Create an API key for the admin user."""
         _, raw_key = ApiKeyService.create_api_key(
             db=db,
@@ -209,9 +211,7 @@ def admin_api_key(
         )
         return raw_key
 
-    def _call_mcp_tool(
-        self, client: TestClient, project_id: str, api_key: str, tool_name: str, arguments: dict = None
-    ):
+    def _call_mcp_tool(self, client: TestClient, project_id: str, api_key: str, tool_name: str, arguments: dict = None):
         """Helper to call an MCP tool via HTTP endpoint."""
         return client.post(
             f"/api/v1/projects/{project_id}/mcp",
@@ -283,7 +283,11 @@ def test_viewer_cannot_call_metadata_tools(
             str(setup_project_with_roles.id),
             viewer_api_key,
             "setMetadataValueForKey",
-            {"path": "/phases/test-phase/features/test-module/perm-001-test-feature/", "key": "completion_status", "value": "completed"},
+            {
+                "path": "/phases/test-phase/features/test-module/perm-001-test-feature/",
+                "key": "completion_status",
+                "value": "completed",
+            },
         )
 
         assert response.status_code == 200
@@ -350,7 +354,11 @@ def test_member_can_call_metadata_tools(
             str(setup_project_with_roles.id),
             member_api_key,
             "setMetadataValueForKey",
-            {"path": "/phases/test-phase/features/test-module/perm-001-test-feature/", "key": "notes", "value": "Test note"},
+            {
+                "path": "/phases/test-phase/features/test-module/perm-001-test-feature/",
+                "key": "notes",
+                "value": "Test note",
+            },
         )
 
         assert response.status_code == 200
@@ -414,7 +422,11 @@ def test_admin_can_call_metadata_tools(
             str(setup_project_with_roles.id),
             admin_api_key,
             "setMetadataValueForKey",
-            {"path": "/phases/test-phase/features/test-module/perm-001-test-feature/", "key": "notes", "value": "Admin note"},
+            {
+                "path": "/phases/test-phase/features/test-module/perm-001-test-feature/",
+                "key": "notes",
+                "value": "Admin note",
+            },
        )
 
         assert response.status_code == 200
diff --git a/backend/tests/test_mcp_server.py b/backend/tests/test_mcp_server.py
index db65cda..ce00036 100644
--- a/backend/tests/test_mcp_server.py
+++ b/backend/tests/test_mcp_server.py
@@ -1,14 +1,12 @@
 """Tests for MFBT MCP Server."""
 
 import json
+
 import pytest
-from unittest.mock import Mock
 
 from app.mcp.server import create_mcp_server
-from app.models import ProjectType, Organization, User
+from app.models import Organization, ProjectType, User
 from app.services.project_service import ProjectService
-from app.services.spec_service import SpecService
-from app.models import SpecType
 
 
 class TestMCPServer:
@@ -67,10 +65,7 @@ async def test_call_get_context_tool(self, db, test_org: Organization, test_user
         server = create_mcp_server(db.bind.url)
 
         # Call tool
-        result = await server.call_tool(
-            "getContext",
-            {"project_id": str(project.id)}
-        )
+        result = await server.call_tool("getContext", {"project_id": str(project.id)})
 
         # Result is a list of TextContent objects - extract the JSON
         assert isinstance(result, list)
diff --git a/backend/tests/test_mention_notification_handler.py b/backend/tests/test_mention_notification_handler.py
index 448da79..4e24c4a 100644
--- a/backend/tests/test_mention_notification_handler.py
+++ b/backend/tests/test_mention_notification_handler.py
@@ -1,20 +1,20 @@
 """Tests for mention notification handler."""
-import pytest
 from datetime import datetime, timezone
 from unittest.mock import MagicMock
 from uuid import uuid4
 
+import pytest
 from sqlalchemy.orm import Session
 
 from app.models.feature import FeatureType
+from app.models.thread import ContextType
 from workers.handlers.integration import (
-    mention_notification_handler,
     _build_mention_subject,
     _build_view_url,
     _format_relative_time,
+    mention_notification_handler,
 )
-from app.models.thread import ContextType
 
 
 class TestBuildMentionSubject:
@@ -64,8 +64,8 @@ def test_conversation_feature_no_title(self):
 
 def _create_mock_db_query(platform_settings_base_url: str | None, module_result=None):
     """Helper to create a mock db.query that handles PlatformSettings and Module queries."""
-    from app.models.platform_settings import PlatformSettings
     from app.models.module import Module
+    from app.models.platform_settings import PlatformSettings
 
     def query_side_effect(model_class):
         mock_query = MagicMock()
@@ -122,7 +122,10 @@ def test_conversation_feature_with_phase_returns_phase_url(self, db: Session, mo
         feature.feature_type = FeatureType.CONVERSATION
 
         url = _build_view_url(thread, feature, "project-123", db)
-        assert url == f"https://example.com/projects/project-123/brainstorming/{phase_id}/conversations?feature={feature_id}"
+        assert (
+            url
+            == f"https://example.com/projects/project-123/brainstorming/{phase_id}/conversations?feature={feature_id}"
f"https://example.com/projects/project-123/brainstorming/{phase_id}/conversations?feature={feature_id}" + ) def test_conversation_feature_without_phase_falls_back_to_feature_url(self, db: Session, monkeypatch): """Test URL for CONVERSATION feature without phase falls back to feature page.""" @@ -169,6 +172,7 @@ def test_no_feature_returns_project_url(self, db: Session, monkeypatch): def test_falls_back_to_config_when_platform_settings_base_url_is_none(self, db: Session, monkeypatch): """Test URL falls back to config.settings.base_url when platform settings has no base_url.""" from app import config + monkeypatch.setattr(config.settings, "base_url", "https://fallback.example.com") # Mock db.query to return platform settings with no base_url @@ -193,6 +197,7 @@ def test_just_now(self): def test_minutes_ago(self): """Test minutes ago formatting.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(minutes=5) result = _format_relative_time(time) assert result == "5 minutes ago" @@ -200,6 +205,7 @@ def test_minutes_ago(self): def test_one_minute_ago(self): """Test singular minute.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(minutes=1, seconds=30) result = _format_relative_time(time) assert result == "1 minute ago" @@ -207,6 +213,7 @@ def test_one_minute_ago(self): def test_hours_ago(self): """Test hours ago formatting.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(hours=3) result = _format_relative_time(time) assert result == "3 hours ago" @@ -214,6 +221,7 @@ def test_hours_ago(self): def test_one_hour_ago(self): """Test singular hour.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(hours=1, minutes=30) result = _format_relative_time(time) assert result == "1 hour ago" @@ -221,6 +229,7 @@ def test_one_hour_ago(self): def test_days_ago(self): """Test days ago formatting.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(days=2) result = _format_relative_time(time) assert result == "2 days ago" @@ -228,6 +237,7 @@ def test_days_ago(self): def test_one_day_ago(self): """Test singular day.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(days=1, hours=12) result = _format_relative_time(time) assert result == "1 day ago" @@ -235,6 +245,7 @@ def test_one_day_ago(self): def test_old_date(self): """Test date format for old dates.""" from datetime import timedelta + time = datetime.now(timezone.utc) - timedelta(days=30) result = _format_relative_time(time) # Should be formatted as "Nov 20, 2025" or similar diff --git a/backend/tests/test_mention_utils.py b/backend/tests/test_mention_utils.py index c2974c6..79bf99e 100644 --- a/backend/tests/test_mention_utils.py +++ b/backend/tests/test_mention_utils.py @@ -1,17 +1,16 @@ """Tests for mention parsing utility.""" -import pytest from uuid import UUID from app.services.mention_utils import ( - extract_user_mentions, - has_mfbtai_mention, + MENTION_PATTERN, + PHASE_FEATURE_MENTION_PATTERN, clean_markdown_for_display, clean_markdown_with_entity_mentions, - extract_phase_feature_mentions, extract_phase_feature_identifiers, - MENTION_PATTERN, - PHASE_FEATURE_MENTION_PATTERN, + extract_phase_feature_mentions, + extract_user_mentions, + has_mfbtai_mention, ) @@ -69,10 +68,7 @@ def test_extract_multiple_mentions(self): def test_deduplicate_mentions(self): """Test that duplicate mentions are deduplicated.""" - text = ( - "@[John](123e4567-e89b-12d3-a456-426614174000) 
said hi to " - "@[John](123e4567-e89b-12d3-a456-426614174000)" - ) + text = "@[John](123e4567-e89b-12d3-a456-426614174000) said hi to @[John](123e4567-e89b-12d3-a456-426614174000)" result = extract_user_mentions(text) assert len(result) == 1 @@ -95,10 +91,7 @@ def test_exclude_mfbtai_case_insensitive(self): def test_mixed_mfbtai_and_users(self): """Test extracting users while excluding mfbtai.""" - text = ( - "@[MFBTAI](mfbtai) please help " - "@[John](123e4567-e89b-12d3-a456-426614174000) with this" - ) + text = "@[MFBTAI](mfbtai) please help @[John](123e4567-e89b-12d3-a456-426614174000) with this" result = extract_user_mentions(text) assert len(result) == 1 assert result[0] == UUID("123e4567-e89b-12d3-a456-426614174000") @@ -174,10 +167,7 @@ def test_returns_false_for_user_only_mentions(self): def test_returns_true_for_mixed_mentions(self): """Test that @MFBTAI is detected even with other users mentioned.""" - text = ( - "@[MFBTAI](mfbtai) please help " - "@[John](123e4567-e89b-12d3-a456-426614174000) with this" - ) + text = "@[MFBTAI](mfbtai) please help @[John](123e4567-e89b-12d3-a456-426614174000) with this" assert has_mfbtai_mention(text) is True def test_empty_string(self): diff --git a/backend/tests/test_message_sequence_numbers.py b/backend/tests/test_message_sequence_numbers.py index f46a75a..066e108 100644 --- a/backend/tests/test_message_sequence_numbers.py +++ b/backend/tests/test_message_sequence_numbers.py @@ -1,16 +1,16 @@ """Tests for message sequence numbering on ProjectChatMessage and ThreadItem.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 +import pytest from sqlalchemy.orm import Session -from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.user import User from app.models.organization import Organization from app.models.project import Project +from app.models.project_chat import ProjectChat +from app.models.thread import ContextType, Thread +from app.models.user import User from app.services.project_chat_service import ProjectChatService from app.services.thread_service import ThreadService diff --git a/backend/tests/test_module_endpoints.py b/backend/tests/test_module_endpoints.py index b63a350..584dd9c 100644 --- a/backend/tests/test_module_endpoints.py +++ b/backend/tests/test_module_endpoints.py @@ -1,13 +1,12 @@ """Tests for module REST API endpoints.""" -import pytest -from datetime import datetime, timezone + from uuid import uuid4 -from app.models.module import ModuleProvenance from app.models.feature import FeatureProvenance, FeatureType -from app.services.module_service import ModuleService +from app.models.module import ModuleProvenance from app.services.feature_service import FeatureService from app.services.implementation_service import ImplementationService +from app.services.module_service import ModuleService class TestModuleEndpoints: diff --git a/backend/tests/test_module_feature_agent.py b/backend/tests/test_module_feature_agent.py index 5ba53e7..5db3406 100644 --- a/backend/tests/test_module_feature_agent.py +++ b/backend/tests/test_module_feature_agent.py @@ -1,61 +1,60 @@ """Tests for the module/feature extraction 5-agent system.""" -import pytest + from uuid import uuid4 +import pytest + from app.agents.module_feature.types import ( - # Context - ModuleFeatureContext, - # Agent 1: Spec Analyzer - SpecRequirement, - SpecAnalysis, - # Agent 2: 
-    ImplementationStep,
-    ImplementationPhase,
-    PlanStructure,
-    # Agent 3: Merger
-    FeatureMapping,
-    ModuleMapping,
-    MergedMapping,
-    # Agent 4: Writer
-    FeatureContent,
-    WriterOutput,
+    # UI Metadata
+    AGENT_METADATA,
+    MAX_MODULES,
+    MIN_FEATURES_PER_MODULE,
+    # Constants
+    MIN_MODULES,
+    MIN_PROMPT_PLAN_TEXT_LENGTH,
+    MIN_SPEC_TEXT_LENGTH,
+    WORKFLOW_STEPS,
     # Agent 5: Validator
     CoverageReport,
     # Final Output
     ExtractedFeature,
     ExtractedModule,
     ExtractionResult,
-    # Helpers
-    validate_extraction_result,
-    get_module_by_id,
-    get_feature_by_id,
-    get_requirement_by_id,
-    # Constants
-    MIN_MODULES,
-    MAX_MODULES,
-    MIN_FEATURES_PER_MODULE,
-    MAX_FEATURES_PER_MODULE,
-    MIN_TOTAL_FEATURES,
-    MAX_TOTAL_FEATURES,
-    MIN_SPEC_TEXT_LENGTH,
-    MIN_PROMPT_PLAN_TEXT_LENGTH,
+    FeatureCategoryType,
+    # Agent 4: Writer
+    FeatureContent,
+    # Agent 3: Merger
+    FeatureMapping,
+    FeaturePriorityLevel,
+    ImplementationPhase,
+    # Agent 2: Plan Structurer
+    ImplementationStep,
+    MergedMapping,
     # Enums
     ModuleCategory,
-    FeaturePriorityLevel,
-    FeatureCategoryType,
-    # UI Metadata
-    AGENT_METADATA,
-    WORKFLOW_STEPS,
+    # Context
+    ModuleFeatureContext,
+    ModuleMapping,
+    PlanStructure,
+    SpecAnalysis,
+    # Agent 1: Spec Analyzer
+    SpecRequirement,
+    WriterOutput,
+    get_feature_by_id,
+    get_module_by_id,
+    get_requirement_by_id,
+    # Helpers
+    validate_extraction_result,
 )
 from app.agents.module_feature.utils import (
-    strip_markdown_json,
+    chunk_list,
+    generate_unique_id,
+    normalize_whitespace,
     parse_json_response,
     safe_parse_json,
-    truncate_text,
-    normalize_whitespace,
     slugify,
-    generate_unique_id,
-    chunk_list,
+    strip_markdown_json,
+    truncate_text,
 )
@@ -419,22 +418,26 @@ def _make_valid_result(
         for i in range(num_modules):
             features = []
             for j in range(features_per_module):
-                features.append(ExtractedFeature(
-                    title=f"Feature {i}-{j}",
-                    description="d" * description_length,
-                    spec_text="x" * spec_text_length,
-                    prompt_plan_text="y" * prompt_plan_text_length,
-                    priority="important",
-                    category="API",
-                    order_index=j,
-                ))
-            modules.append(ExtractedModule(
-                title=f"Module {i}",
-                description=f"Description for module {i}",
-                order_index=i,
-                module_category="phase",
-                features=features,
-            ))
+                features.append(
+                    ExtractedFeature(
+                        title=f"Feature {i}-{j}",
+                        description="d" * description_length,
+                        spec_text="x" * spec_text_length,
+                        prompt_plan_text="y" * prompt_plan_text_length,
+                        priority="important",
+                        category="API",
+                        order_index=j,
+                    )
+                )
+            modules.append(
+                ExtractedModule(
+                    title=f"Module {i}",
+                    description=f"Description for module {i}",
+                    order_index=i,
+                    module_category="phase",
+                    features=features,
+                )
+            )
         return ExtractionResult(modules=modules)
 
     def test_valid_result_passes_validation(self):
@@ -654,7 +657,8 @@ def test_parse_json_response_with_markdown(self):
     def test_parse_json_response_invalid_raises(self):
         """Test parse_json_response raises on invalid JSON."""
         import json
-        text = 'not json'
+
+        text = "not json"
         with pytest.raises(json.JSONDecodeError):
             parse_json_response(text)
@@ -665,7 +669,7 @@ def test_safe_parse_json_valid(self):
 
     def test_safe_parse_json_invalid_returns_default(self):
         """Test safe_parse_json returns default on invalid JSON."""
-        result = safe_parse_json('not json', default={"default": True})
+        result = safe_parse_json("not json", default={"default": True})
         assert result == {"default": True}
 
     def test_truncate_text(self):
@@ -745,6 +749,7 @@ def test_repair_truncated_array(self):
 
         # Verify it's now valid JSON
         import json
+
         parsed = json.loads(result)
json.loads(result) assert parsed["items"] == [1, 2, 3] @@ -757,6 +762,7 @@ def test_repair_truncated_nested_object(self): assert result == '{"items": [{"name": "test"}]}' import json + parsed = json.loads(result) assert parsed["items"][0]["name"] == "test" @@ -769,6 +775,7 @@ def test_repair_truncated_string(self): assert result == '{"text": "hello world"}' import json + parsed = json.loads(result) assert parsed["text"] == "hello world" @@ -816,8 +823,9 @@ def test_repair_escaped_quotes(self): assert result.endswith('"}') import json + parsed = json.loads(result) - assert 'hello' in parsed["text"] + assert "hello" in parsed["text"] def test_parse_json_with_repair_missing_colon(self): """Test parse_json_with_repair handles missing colon after key.""" diff --git a/backend/tests/test_module_service.py b/backend/tests/test_module_service.py index 855e6ff..8c172d0 100644 --- a/backend/tests/test_module_service.py +++ b/backend/tests/test_module_service.py @@ -2,19 +2,21 @@ Tests the service layer for module operations. """ + +from uuid import uuid4 + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType +from app.models.module import ModuleProvenance from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleProvenance -from app.services.user_service import UserService +from app.models.user import User from app.services.module_service import ModuleService +from app.services.user_service import UserService @pytest.fixture @@ -396,8 +398,8 @@ def test_archive_system_modules_for_specific_phase( sample_user: User, ): """Test that archive_system_modules with phase_id only archives modules for that phase.""" - from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.models.brainstorming_phase import BrainstormingPhaseType + from app.services.brainstorming_phase_service import BrainstormingPhaseService # Create two phases phase1 = BrainstormingPhaseService.create_brainstorming_phase( diff --git a/backend/tests/test_notification_adapters.py b/backend/tests/test_notification_adapters.py index 56d85e8..ab6f05d 100644 --- a/backend/tests/test_notification_adapters.py +++ b/backend/tests/test_notification_adapters.py @@ -4,6 +4,7 @@ Verifies that notification adapters correctly implement the base interface and can send notifications through their respective channels. 
""" + import pytest from app.services.notification_adapters import ( @@ -25,7 +26,7 @@ async def test_send_email_mock(self): recipient="user@example.com", subject="Test Notification", body="This is a test notification body.", - metadata={"project_id": "123", "thread_id": "456"} + metadata={"project_id": "123", "thread_id": "456"}, ) assert success is True @@ -40,7 +41,7 @@ async def test_send_email_with_channel_config(self): subject="Test Notification", body="This is a test notification body.", metadata={"project_id": "123"}, - channel_config={"smtp_server": "smtp.example.com"} + channel_config={"smtp_server": "smtp.example.com"}, ) assert success is True @@ -56,10 +57,7 @@ async def test_send_email_non_mock(self): adapter = EmailAdapter(mock=False) success = await adapter.send_notification( - recipient="user@example.com", - subject="Test", - body="Test body", - metadata={} + recipient="user@example.com", subject="Test", body="Test body", metadata={} ) assert success is False @@ -77,7 +75,7 @@ async def test_send_slack_mock(self): recipient="#general", subject="Test Notification", body="This is a test notification for Slack.", - metadata={"project_id": "123", "link": "https://example.com"} + metadata={"project_id": "123", "link": "https://example.com"}, ) assert success is True @@ -92,7 +90,7 @@ async def test_send_slack_with_webhook(self): subject="Alert", body="Something happened", metadata={"severity": "high"}, - channel_config={"webhook_url": "https://hooks.slack.com/services/xxx"} + channel_config={"webhook_url": "https://hooks.slack.com/services/xxx"}, ) assert success is True @@ -107,12 +105,7 @@ async def test_send_slack_non_mock(self): """Test that non-mock mode returns False (not implemented).""" adapter = SlackAdapter(mock=False) - success = await adapter.send_notification( - recipient="#test", - subject="Test", - body="Test body", - metadata={} - ) + success = await adapter.send_notification(recipient="#test", subject="Test", body="Test body", metadata={}) assert success is False @@ -129,7 +122,7 @@ async def test_send_teams_mock(self): recipient="team-channel-id", subject="Test Notification", body="This is a test notification for Teams.", - metadata={"project_id": "123", "link": "https://example.com"} + metadata={"project_id": "123", "link": "https://example.com"}, ) assert success is True @@ -144,7 +137,7 @@ async def test_send_teams_with_webhook(self): subject="Alert", body="Something happened", metadata={"severity": "high"}, - channel_config={"webhook_url": "https://outlook.office.com/webhook/xxx"} + channel_config={"webhook_url": "https://outlook.office.com/webhook/xxx"}, ) assert success is True @@ -159,11 +152,6 @@ async def test_send_teams_non_mock(self): """Test that non-mock mode returns False (not implemented).""" adapter = TeamsAdapter(mock=False) - success = await adapter.send_notification( - recipient="test", - subject="Test", - body="Test body", - metadata={} - ) + success = await adapter.send_notification(recipient="test", subject="Test", body="Test body", metadata={}) assert success is False diff --git a/backend/tests/test_notification_models.py b/backend/tests/test_notification_models.py index 2c643af..89645e7 100644 --- a/backend/tests/test_notification_models.py +++ b/backend/tests/test_notification_models.py @@ -4,24 +4,23 @@ Verifies that notification preference, project mute, and thread watch models work correctly with proper constraints and relationships. 
""" + import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from uuid import uuid4 from app.database import Base from app.models import ( - User, + ContextType, + NotificationChannel, + NotificationPreference, + NotificationProjectMute, Organization, Project, - ProjectType, ProjectStatus, + ProjectType, Thread, - ContextType, - NotificationPreference, - NotificationChannel, - NotificationProjectMute, - NotificationThreadWatch, + User, ) @@ -62,11 +61,7 @@ def test_db_session(test_engine): @pytest.fixture def sample_user(test_db_session): """Create a sample user for testing.""" - user = User( - email="testuser@example.com", - password_hash="hashed_password", - display_name="Test User" - ) + user = User(email="testuser@example.com", password_hash="hashed_password", display_name="Test User") test_db_session.add(user) test_db_session.commit() test_db_session.refresh(user) @@ -76,9 +71,7 @@ def sample_user(test_db_session): @pytest.fixture def sample_org(test_db_session): """Create a sample organization for testing.""" - org = Organization( - name="Test Org" - ) + org = Organization(name="Test Org") test_db_session.add(org) test_db_session.commit() test_db_session.refresh(org) @@ -95,7 +88,7 @@ def sample_project(test_db_session, sample_org, sample_user): status=ProjectStatus.DRAFT, org_id=sample_org.id, idea_text="Test idea", - created_by=sample_user.id + created_by=sample_user.id, ) test_db_session.add(project) test_db_session.commit() @@ -116,7 +109,7 @@ def sample_thread(test_db_session, sample_project, sample_user): title="Test Thread", project_id=str(sample_project.id), context_type=ContextType.GENERAL, - created_by=str(sample_user.id) + created_by=str(sample_user.id), ) test_db_session.add(thread) test_db_session.commit() @@ -129,11 +122,7 @@ class TestNotificationPreference: def test_create_notification_preference(self, test_db_session, sample_user): """Test creating a notification preference.""" - pref = NotificationPreference( - user_id=sample_user.id, - channel=NotificationChannel.EMAIL, - enabled=True - ) + pref = NotificationPreference(user_id=sample_user.id, channel=NotificationChannel.EMAIL, enabled=True) test_db_session.add(pref) test_db_session.commit() test_db_session.refresh(pref) @@ -150,10 +139,7 @@ def test_notification_preference_with_config(self, test_db_session, sample_user) """Test creating a notification preference with channel config.""" config = '{"webhook_url": "https://hooks.slack.com/services/xxx"}' pref = NotificationPreference( - user_id=sample_user.id, - channel=NotificationChannel.SLACK, - enabled=True, - channel_config=config + user_id=sample_user.id, channel=NotificationChannel.SLACK, enabled=True, channel_config=config ) test_db_session.add(pref) test_db_session.commit() @@ -164,11 +150,7 @@ def test_notification_preference_with_config(self, test_db_session, sample_user) def test_notification_preference_disabled(self, test_db_session, sample_user): """Test creating a disabled notification preference.""" - pref = NotificationPreference( - user_id=sample_user.id, - channel=NotificationChannel.TEAMS, - enabled=False - ) + pref = NotificationPreference(user_id=sample_user.id, channel=NotificationChannel.TEAMS, enabled=False) test_db_session.add(pref) test_db_session.commit() test_db_session.refresh(pref) @@ -177,19 +159,12 @@ def test_notification_preference_disabled(self, test_db_session, sample_user): def test_notification_preference_cascade_delete(self, test_db_session): """Test that preferences are deleted when user is 
deleted.""" - user = User( - email="delete@example.com", - password_hash="hashed_password" - ) + user = User(email="delete@example.com", password_hash="hashed_password") test_db_session.add(user) test_db_session.commit() test_db_session.refresh(user) - pref = NotificationPreference( - user_id=user.id, - channel=NotificationChannel.EMAIL, - enabled=True - ) + pref = NotificationPreference(user_id=user.id, channel=NotificationChannel.EMAIL, enabled=True) test_db_session.add(pref) test_db_session.commit() @@ -203,11 +178,7 @@ def test_notification_preference_cascade_delete(self, test_db_session): def test_notification_preference_repr(self, test_db_session, sample_user): """Test string representation of NotificationPreference.""" - pref = NotificationPreference( - user_id=sample_user.id, - channel=NotificationChannel.EMAIL, - enabled=True - ) + pref = NotificationPreference(user_id=sample_user.id, channel=NotificationChannel.EMAIL, enabled=True) test_db_session.add(pref) test_db_session.commit() test_db_session.refresh(pref) @@ -223,10 +194,7 @@ class TestNotificationProjectMute: def test_create_project_mute(self, test_db_session, sample_user, sample_project): """Test creating a project mute.""" - mute = NotificationProjectMute( - user_id=sample_user.id, - project_id=sample_project.id - ) + mute = NotificationProjectMute(user_id=sample_user.id, project_id=sample_project.id) test_db_session.add(mute) test_db_session.commit() test_db_session.refresh(mute) @@ -238,18 +206,12 @@ def test_create_project_mute(self, test_db_session, sample_user, sample_project) def test_project_mute_unique_constraint(self, test_db_session, sample_user, sample_project): """Test that user can't mute the same project twice.""" - mute1 = NotificationProjectMute( - user_id=sample_user.id, - project_id=sample_project.id - ) + mute1 = NotificationProjectMute(user_id=sample_user.id, project_id=sample_project.id) test_db_session.add(mute1) test_db_session.commit() # Try to create duplicate - mute2 = NotificationProjectMute( - user_id=sample_user.id, - project_id=sample_project.id - ) + mute2 = NotificationProjectMute(user_id=sample_user.id, project_id=sample_project.id) test_db_session.add(mute2) with pytest.raises(Exception): # SQLAlchemy will raise an IntegrityError @@ -257,18 +219,12 @@ def test_project_mute_unique_constraint(self, test_db_session, sample_user, samp def test_project_mute_cascade_delete_user(self, test_db_session, sample_project): """Test that mutes are deleted when user is deleted.""" - user = User( - email="delete@example.com", - password_hash="hashed_password" - ) + user = User(email="delete@example.com", password_hash="hashed_password") test_db_session.add(user) test_db_session.commit() test_db_session.refresh(user) - mute = NotificationProjectMute( - user_id=user.id, - project_id=sample_project.id - ) + mute = NotificationProjectMute(user_id=user.id, project_id=sample_project.id) test_db_session.add(mute) test_db_session.commit() @@ -289,16 +245,13 @@ def test_project_mute_cascade_delete_project(self, test_db_session, sample_user, status=ProjectStatus.DRAFT, org_id=sample_org.id, idea_text="Delete idea", - created_by=sample_user.id + created_by=sample_user.id, ) test_db_session.add(project) test_db_session.commit() test_db_session.refresh(project) - mute = NotificationProjectMute( - user_id=sample_user.id, - project_id=project.id - ) + mute = NotificationProjectMute(user_id=sample_user.id, project_id=project.id) test_db_session.add(mute) test_db_session.commit() @@ -312,10 +265,7 @@ def 
test_project_mute_cascade_delete_project(self, test_db_session, sample_user, def test_project_mute_repr(self, test_db_session, sample_user, sample_project): """Test string representation of NotificationProjectMute.""" - mute = NotificationProjectMute( - user_id=sample_user.id, - project_id=sample_project.id - ) + mute = NotificationProjectMute(user_id=sample_user.id, project_id=sample_project.id) test_db_session.add(mute) test_db_session.commit() test_db_session.refresh(mute) diff --git a/backend/tests/test_notification_service.py b/backend/tests/test_notification_service.py index b94832a..d682155 100644 --- a/backend/tests/test_notification_service.py +++ b/backend/tests/test_notification_service.py @@ -4,19 +4,19 @@ Verifies notification preference management, project muting, thread watching, and event enqueueing functionality. """ + import pytest from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from uuid import uuid4 from app.database import Base from app.models import ( - User, + NotificationChannel, Organization, Project, - ProjectType, ProjectStatus, - NotificationChannel, + ProjectType, + User, ) from app.services.notification_service import NotificationService @@ -57,10 +57,7 @@ def test_db(test_engine): @pytest.fixture def sample_user(test_db): """Create a sample user.""" - user = User( - email="testuser@example.com", - password_hash="hashed_password" - ) + user = User(email="testuser@example.com", password_hash="hashed_password") test_db.add(user) test_db.commit() test_db.refresh(user) @@ -87,7 +84,7 @@ def sample_project(test_db, sample_org, sample_user): status=ProjectStatus.DRAFT, org_id=sample_org.id, idea_text="Test idea", - created_by=sample_user.id + created_by=sample_user.id, ) test_db.add(project) test_db.commit() @@ -101,10 +98,7 @@ class TestNotificationPreferences: def test_create_preference(self, test_db, sample_user): """Test creating a notification preference.""" pref = NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.EMAIL, - enabled=True + db=test_db, user_id=sample_user.id, channel=NotificationChannel.EMAIL, enabled=True ) assert pref.id is not None @@ -114,33 +108,17 @@ def test_create_preference(self, test_db, sample_user): def test_get_user_preferences(self, test_db, sample_user): """Test getting all preferences for a user.""" - NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.EMAIL - ) - NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.SLACK - ) + NotificationService.create_preference(db=test_db, user_id=sample_user.id, channel=NotificationChannel.EMAIL) + NotificationService.create_preference(db=test_db, user_id=sample_user.id, channel=NotificationChannel.SLACK) prefs = NotificationService.get_user_preferences(test_db, sample_user.id) assert len(prefs) == 2 def test_get_specific_preference(self, test_db, sample_user): """Test getting a specific preference.""" - NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.EMAIL - ) + NotificationService.create_preference(db=test_db, user_id=sample_user.id, channel=NotificationChannel.EMAIL) - pref = NotificationService.get_user_preference( - test_db, - sample_user.id, - NotificationChannel.EMAIL - ) + pref = NotificationService.get_user_preference(test_db, sample_user.id, NotificationChannel.EMAIL) assert pref is not None assert pref.channel == 
NotificationChannel.EMAIL @@ -148,26 +126,17 @@ def test_get_specific_preference(self, test_db, sample_user): def test_update_preference(self, test_db, sample_user): """Test updating a preference.""" pref = NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.EMAIL, - enabled=True + db=test_db, user_id=sample_user.id, channel=NotificationChannel.EMAIL, enabled=True ) - updated = NotificationService.update_preference( - db=test_db, - preference_id=pref.id, - enabled=False - ) + updated = NotificationService.update_preference(db=test_db, preference_id=pref.id, enabled=False) assert updated.enabled is False def test_delete_preference(self, test_db, sample_user): """Test deleting a preference.""" pref = NotificationService.create_preference( - db=test_db, - user_id=sample_user.id, - channel=NotificationChannel.EMAIL + db=test_db, user_id=sample_user.id, channel=NotificationChannel.EMAIL ) result = NotificationService.delete_preference(test_db, pref.id) @@ -183,11 +152,7 @@ class TestProjectMutes: def test_mute_project(self, test_db, sample_user, sample_project): """Test muting a project.""" - mute = NotificationService.mute_project( - db=test_db, - user_id=sample_user.id, - project_id=sample_project.id - ) + mute = NotificationService.mute_project(db=test_db, user_id=sample_user.id, project_id=sample_project.id) assert mute.id is not None assert mute.user_id == sample_user.id @@ -195,44 +160,20 @@ def test_mute_project(self, test_db, sample_user, sample_project): def test_is_project_muted(self, test_db, sample_user, sample_project): """Test checking if a project is muted.""" - assert NotificationService.is_project_muted( - test_db, - sample_user.id, - sample_project.id - ) is False - - NotificationService.mute_project( - test_db, - sample_user.id, - sample_project.id - ) + assert NotificationService.is_project_muted(test_db, sample_user.id, sample_project.id) is False + + NotificationService.mute_project(test_db, sample_user.id, sample_project.id) - assert NotificationService.is_project_muted( - test_db, - sample_user.id, - sample_project.id - ) is True + assert NotificationService.is_project_muted(test_db, sample_user.id, sample_project.id) is True def test_unmute_project(self, test_db, sample_user, sample_project): """Test unmuting a project.""" - NotificationService.mute_project( - test_db, - sample_user.id, - sample_project.id - ) + NotificationService.mute_project(test_db, sample_user.id, sample_project.id) - result = NotificationService.unmute_project( - test_db, - sample_user.id, - sample_project.id - ) + result = NotificationService.unmute_project(test_db, sample_user.id, sample_project.id) assert result is True - assert NotificationService.is_project_muted( - test_db, - sample_user.id, - sample_project.id - ) is False + assert NotificationService.is_project_muted(test_db, sample_user.id, sample_project.id) is False def test_get_user_project_mutes(self, test_db, sample_user, sample_project, sample_org): """Test getting all muted projects for a user.""" @@ -243,7 +184,7 @@ def test_get_user_project_mutes(self, test_db, sample_user, sample_project, samp status=ProjectStatus.DRAFT, org_id=sample_org.id, idea_text="Idea 2", - created_by=sample_user.id + created_by=sample_user.id, ) test_db.add(project2) test_db.commit() @@ -267,7 +208,7 @@ def test_enqueue_notification_event(self, test_db, sample_user, sample_project): project_id=sample_project.id, title="New comment", body="Someone commented on your thread", - 
recipients=[sample_user.id] + recipients=[sample_user.id], ) assert job_id is not None diff --git a/backend/tests/test_oauth_providers.py b/backend/tests/test_oauth_providers.py index e571f4a..3bb0d21 100644 --- a/backend/tests/test_oauth_providers.py +++ b/backend/tests/test_oauth_providers.py @@ -3,9 +3,9 @@ Tests provider initialization, registration, and normalization functions. """ + import pytest -from app.schemas.oauth import NormalizedUserInfo from app.auth.providers import ( KNOWN_PROVIDERS, get_configured_providers, @@ -15,6 +15,7 @@ normalize_google_userinfo, normalize_userinfo, ) +from app.schemas.oauth import NormalizedUserInfo class TestNormalizedUserInfo: diff --git a/backend/tests/test_oauth_routes.py b/backend/tests/test_oauth_routes.py index d10e3e9..61e00fe 100644 --- a/backend/tests/test_oauth_routes.py +++ b/backend/tests/test_oauth_routes.py @@ -6,15 +6,16 @@ - GET /auth/login/{provider} - Initiate OAuth login - GET /auth/callback/{provider} - Handle OAuth callback """ -import pytest + from unittest.mock import AsyncMock, MagicMock, patch + +import pytest from fastapi import status from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models.user import User from app.models.organization import Organization -from app.schemas.oauth import NormalizedUserInfo +from app.models.user import User @pytest.fixture(autouse=True) @@ -28,9 +29,7 @@ def mock_permitted_domains(): class TestListProviders: """Tests for GET /auth/providers endpoint.""" - def test_returns_empty_list_when_no_providers_configured( - self, client: TestClient - ): + def test_returns_empty_list_when_no_providers_configured(self, client: TestClient): """Returns empty list when no OAuth providers are configured.""" # In test environment, no providers are typically configured with patch("app.routers.auth.get_configured_providers", return_value=[]): @@ -108,9 +107,7 @@ def test_redirects_to_provider_when_configured(self, client: TestClient): class TestOAuthCallback: """Tests for GET /auth/callback/{provider_slug} endpoint.""" - def test_error_parameter_redirects_to_frontend_with_error( - self, client: TestClient - ): + def test_error_parameter_redirects_to_frontend_with_error(self, client: TestClient): """Redirects to frontend with error when OAuth provider returns error.""" response = client.get( "/api/v1/auth/callback/google", @@ -132,9 +129,7 @@ def test_unknown_provider_returns_400(self, client: TestClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Unknown OAuth provider" in response.json()["detail"] - def test_successful_callback_creates_user_and_sets_cookie( - self, client: TestClient, db: Session - ): + def test_successful_callback_creates_user_and_sets_cookie(self, client: TestClient, db: Session): """Successful callback creates user, org, and sets session cookie.""" # Mock the OAuth client and token exchange mock_client = MagicMock() @@ -170,11 +165,7 @@ def test_successful_callback_creates_user_and_sets_cookie( assert user.password_hash is None # OAuth users don't have passwords # Verify org was created for new user - org = ( - db.query(Organization) - .filter(Organization.name.like("%New User%")) - .first() - ) + org = db.query(Organization).filter(Organization.name.like("%New User%")).first() assert org is not None def test_successful_callback_links_identity_to_existing_user( @@ -215,9 +206,7 @@ def test_successful_callback_links_identity_to_existing_user( org_count_after = db.query(Organization).count() assert org_count_after == 
org_count_before - def test_github_callback_fetches_user_info( - self, client: TestClient, db: Session - ): + def test_github_callback_fetches_user_info(self, client: TestClient, db: Session): """GitHub callback fetches user info from API.""" mock_client = MagicMock() mock_token = {"access_token": "test_github_token"} @@ -247,9 +236,7 @@ def test_github_callback_fetches_user_info( # Verify GitHub API was called mock_client.get.assert_called() - def test_github_callback_fetches_private_email( - self, client: TestClient, db: Session - ): + def test_github_callback_fetches_private_email(self, client: TestClient, db: Session): """GitHub callback fetches email from /user/emails when primary email is private.""" mock_client = MagicMock() mock_token = {"access_token": "test_github_token"} @@ -292,9 +279,7 @@ def test_github_callback_fetches_private_email( class TestCookieAuthentication: """Tests for cookie-based authentication in get_current_user.""" - def test_auth_me_works_with_bearer_token( - self, client: TestClient, test_user: User, test_org, auth_headers - ): + def test_auth_me_works_with_bearer_token(self, client: TestClient, test_user: User, test_org, auth_headers): """GET /auth/me works with bearer token in Authorization header.""" # test_org ensures user has an organization headers = auth_headers(test_user) @@ -303,9 +288,7 @@ def test_auth_me_works_with_bearer_token( assert response.status_code == status.HTTP_200_OK assert response.json()["email"] == test_user.email - def test_auth_me_works_with_session_cookie( - self, client: TestClient, test_user: User, test_org, db: Session - ): + def test_auth_me_works_with_session_cookie(self, client: TestClient, test_user: User, test_org, db: Session): """GET /auth/me works with session cookie.""" from app.auth.utils import create_access_token @@ -331,13 +314,12 @@ def test_bearer_token_takes_precedence_over_cookie( self, client: TestClient, test_user: User, test_org, auth_headers, db: Session ): """Bearer token takes precedence when both token and cookie are present.""" - from app.auth.utils import create_access_token - from app.services.org_service import OrgService + from app.auth.utils import create_access_token, hash_password # test_org ensures test_user has an organization # Create a different user for the cookie from app.models.user import User - from app.auth.utils import hash_password + from app.services.org_service import OrgService other_user = User( email="other@example.com", @@ -370,9 +352,7 @@ def test_bearer_token_takes_precedence_over_cookie( class TestSessionToken: """Tests for GET /auth/me/token endpoint.""" - def test_returns_token_with_valid_cookie( - self, client: TestClient, test_user: User, test_org, db: Session - ): + def test_returns_token_with_valid_cookie(self, client: TestClient, test_user: User, test_org, db: Session): """GET /auth/me/token returns token when valid session cookie is present.""" from app.auth.utils import create_access_token @@ -422,9 +402,7 @@ def test_fails_with_nonexistent_user(self, client: TestClient, db: Session): assert response.status_code == status.HTTP_401_UNAUTHORIZED assert "User not found" in response.json()["detail"] - def test_ignores_bearer_token( - self, client: TestClient, test_user: User, test_org, auth_headers - ): + def test_ignores_bearer_token(self, client: TestClient, test_user: User, test_org, auth_headers): """GET /auth/me/token only uses cookie, ignores bearer token.""" # Pass only bearer token (no cookie) - should fail headers = auth_headers(test_user) @@ -437,9 +415,7 @@ 
def test_ignores_bearer_token( class TestLogout: """Tests for POST /auth/logout endpoint.""" - def test_logout_clears_session_cookie( - self, client: TestClient, test_user: User, test_org - ): + def test_logout_clears_session_cookie(self, client: TestClient, test_user: User, test_org): """Logout clears the session cookie.""" from app.auth.utils import create_access_token @@ -460,9 +436,7 @@ def test_logout_clears_session_cookie( # Cookie deletion sets value to empty string assert cookie == "" or "mfbt_session" not in response.cookies - def test_logout_works_with_bearer_token( - self, client: TestClient, test_user: User, test_org, auth_headers - ): + def test_logout_works_with_bearer_token(self, client: TestClient, test_user: User, test_org, auth_headers): """Logout works with bearer token authentication.""" headers = auth_headers(test_user) @@ -518,9 +492,7 @@ def test_health_is_public_endpoint(self, client: TestClient): class TestRedirectValidation: """Tests for redirect URI validation in OAuth flows.""" - def test_success_redirect_uses_frontend_url( - self, client: TestClient, db: Session - ): + def test_success_redirect_uses_frontend_url(self, client: TestClient, db: Session): """Successful callback redirects to configured frontend URL.""" mock_client = MagicMock() mock_token = { @@ -563,9 +535,7 @@ def test_error_redirect_includes_error_parameter(self, client: TestClient): class TestStructuredLogging: """Tests for structured auth logging.""" - def test_login_success_logs_event( - self, client: TestClient, db: Session, caplog - ): + def test_login_success_logs_event(self, client: TestClient, db: Session, caplog): """Successful login logs structured event with user_id and provider.""" import logging @@ -641,11 +611,10 @@ def test_login_initiated_logs_event(self, client: TestClient, caplog): assert any("login_initiated" in msg for msg in log_messages) assert any("google" in msg for msg in log_messages) - def test_logout_logs_event( - self, client: TestClient, test_user: User, test_org, caplog - ): + def test_logout_logs_event(self, client: TestClient, test_user: User, test_org, caplog): """Logout logs structured event with user info.""" import logging + from app.auth.utils import create_access_token token = create_access_token(data={"sub": test_user.email}) diff --git a/backend/tests/test_org_chats.py b/backend/tests/test_org_chats.py index 103eb64..7ea7138 100644 --- a/backend/tests/test_org_chats.py +++ b/backend/tests/test_org_chats.py @@ -1,13 +1,13 @@ """Tests for org-scoped chats (project-chat discussions) user isolation.""" + import pytest from sqlalchemy.orm import Session -from app.models.user import User -from app.models.organization import Organization from app.models.org_membership import OrgMembership, OrgRole -from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType -from app.services.user_service import UserService +from app.models.organization import Organization +from app.models.user import User from app.services.project_chat_service import ProjectChatService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_org_endpoints.py b/backend/tests/test_org_endpoints.py index 91e2807..9835a42 100644 --- a/backend/tests/test_org_endpoints.py +++ b/backend/tests/test_org_endpoints.py @@ -4,16 +4,19 @@ Tests the organization API routes including auto-creation on registration and listing user organizations. 
""" + +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock, AsyncMock from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from app.database import Base, get_db, get_async_db +from app.database import Base, get_async_db, get_db from app.main import app + # Import all models to ensure they're registered with SQLAlchemy metadata -from app.models import User, Organization, OrgMembership, Job +from app.models import User @pytest.fixture(autouse=True) @@ -28,8 +31,8 @@ def mock_permitted_domains(): def test_engine(): """Create a file-based SQLite database engine for testing.""" # Use a file-based database to ensure all connections see the same schema - import tempfile import os + import tempfile # Create a temporary file for the database fd, db_path = tempfile.mkstemp(suffix=".db") @@ -123,9 +126,7 @@ def test_registration_creates_personal_org(self, client: TestClient, test_db): token = login_response.json()["access_token"] # List organizations - orgs_response = client.get( - "/api/v1/orgs", headers={"Authorization": f"Bearer {token}"} - ) + orgs_response = client.get("/api/v1/orgs", headers={"Authorization": f"Bearer {token}"}) assert orgs_response.status_code == 200 orgs = orgs_response.json() @@ -167,9 +168,7 @@ def test_registration_without_display_name_uses_email(self, client: TestClient, token = login_response.json()["access_token"] # List orgs - orgs_response = client.get( - "/api/v1/orgs", headers={"Authorization": f"Bearer {token}"} - ) + orgs_response = client.get("/api/v1/orgs", headers={"Authorization": f"Bearer {token}"}) orgs = orgs_response.json() # Org name should use email @@ -186,9 +185,7 @@ def test_list_orgs_requires_authentication(self, client: TestClient): def test_list_orgs_with_invalid_token(self, client: TestClient): """Test that listing orgs with invalid token fails.""" - response = client.get( - "/api/v1/orgs", headers={"Authorization": "Bearer invalid_token"} - ) + response = client.get("/api/v1/orgs", headers={"Authorization": "Bearer invalid_token"}) assert response.status_code == 401 def test_list_orgs_returns_empty_for_user_with_no_orgs(self, client: TestClient, test_db): @@ -223,9 +220,7 @@ def test_list_orgs_returns_empty_for_user_with_no_orgs(self, client: TestClient, token = login_response.json()["access_token"] # List orgs - should have the auto-created org - orgs_response = client.get( - "/api/v1/orgs", headers={"Authorization": f"Bearer {token}"} - ) + orgs_response = client.get("/api/v1/orgs", headers={"Authorization": f"Bearer {token}"}) assert orgs_response.status_code == 200 orgs = orgs_response.json() assert len(orgs) >= 1 # At least the auto-created org @@ -257,9 +252,7 @@ def test_list_orgs_returns_correct_structure(self, client: TestClient, test_db): token = login_response.json()["access_token"] # List orgs - orgs_response = client.get( - "/api/v1/orgs", headers={"Authorization": f"Bearer {token}"} - ) + orgs_response = client.get("/api/v1/orgs", headers={"Authorization": f"Bearer {token}"}) assert orgs_response.status_code == 200 orgs = orgs_response.json() diff --git a/backend/tests/test_org_service.py b/backend/tests/test_org_service.py index 74aa8e9..0b47867 100644 --- a/backend/tests/test_org_service.py +++ b/backend/tests/test_org_service.py @@ -3,13 +3,14 @@ Tests organization creation, membership management, and retrieval operations. 
""" + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.organization import Organization from app.models.org_membership import OrgMembership, OrgRole +from app.models.organization import Organization from app.models.user import User from app.services.org_service import OrgService from app.services.user_service import UserService @@ -85,11 +86,7 @@ def test_create_org_with_owner_persists_to_db(self, test_db_session: Session, te # Query from database to verify persistence db_org = test_db_session.query(Organization).filter(Organization.id == org.id).first() - db_membership = ( - test_db_session.query(OrgMembership) - .filter(OrgMembership.id == membership.id) - .first() - ) + db_membership = test_db_session.query(OrgMembership).filter(OrgMembership.id == membership.id).first() assert db_org is not None assert db_org.name == org_name @@ -103,9 +100,7 @@ class TestOrgServiceGetUserOrgs: def test_get_user_orgs_returns_single_org(self, test_db_session: Session, test_user: User): """Test getting orgs for a user with one organization.""" org_name = "User's Org" - org, _ = OrgService.create_org_with_owner( - db=test_db_session, name=org_name, owner_user_id=test_user.id - ) + org, _ = OrgService.create_org_with_owner(db=test_db_session, name=org_name, owner_user_id=test_user.id) orgs_with_roles = OrgService.get_user_orgs(db=test_db_session, user_id=test_user.id) @@ -118,14 +113,10 @@ def test_get_user_orgs_returns_single_org(self, test_db_session: Session, test_u def test_get_user_orgs_returns_multiple_orgs(self, test_db_session: Session, test_user: User): """Test getting orgs for a user with multiple organizations.""" # Create first org as owner - org1, _ = OrgService.create_org_with_owner( - db=test_db_session, name="First Org", owner_user_id=test_user.id - ) + org1, _ = OrgService.create_org_with_owner(db=test_db_session, name="First Org", owner_user_id=test_user.id) # Create second org as owner - org2, _ = OrgService.create_org_with_owner( - db=test_db_session, name="Second Org", owner_user_id=test_user.id - ) + org2, _ = OrgService.create_org_with_owner(db=test_db_session, name="Second Org", owner_user_id=test_user.id) orgs_with_roles = OrgService.get_user_orgs(db=test_db_session, user_id=test_user.id) @@ -137,9 +128,7 @@ def test_get_user_orgs_returns_multiple_orgs(self, test_db_session: Session, tes assert org2.id in org_ids assert roles == {OrgRole.OWNER} - def test_get_user_orgs_returns_empty_for_user_with_no_orgs( - self, test_db_session: Session, test_user: User - ): + def test_get_user_orgs_returns_empty_for_user_with_no_orgs(self, test_db_session: Session, test_user: User): """Test getting orgs for a user with no organizations.""" # Create another user with an org other_user = UserService.create_user( @@ -147,9 +136,7 @@ def test_get_user_orgs_returns_empty_for_user_with_no_orgs( email="other@example.com", password="password123", ) - OrgService.create_org_with_owner( - db=test_db_session, name="Other's Org", owner_user_id=other_user.id - ) + OrgService.create_org_with_owner(db=test_db_session, name="Other's Org", owner_user_id=other_user.id) # test_user should have no orgs orgs_with_roles = OrgService.get_user_orgs(db=test_db_session, user_id=test_user.id) @@ -166,31 +153,23 @@ def test_get_org_membership_returns_membership(self, test_db_session: Session, t db=test_db_session, name="Test Org", owner_user_id=test_user.id ) - 
retrieved_membership = OrgService.get_org_membership( - db=test_db_session, org_id=org.id, user_id=test_user.id - ) + retrieved_membership = OrgService.get_org_membership(db=test_db_session, org_id=org.id, user_id=test_user.id) assert retrieved_membership is not None assert retrieved_membership.id == membership.id assert retrieved_membership.role == OrgRole.OWNER - def test_get_org_membership_returns_none_for_non_member( - self, test_db_session: Session, test_user: User - ): + def test_get_org_membership_returns_none_for_non_member(self, test_db_session: Session, test_user: User): """Test retrieving membership for a user not in the org.""" other_user = UserService.create_user( db=test_db_session, email="other@example.com", password="password123", ) - org, _ = OrgService.create_org_with_owner( - db=test_db_session, name="Test Org", owner_user_id=test_user.id - ) + org, _ = OrgService.create_org_with_owner(db=test_db_session, name="Test Org", owner_user_id=test_user.id) # other_user is not a member - membership = OrgService.get_org_membership( - db=test_db_session, org_id=org.id, user_id=other_user.id - ) + membership = OrgService.get_org_membership(db=test_db_session, org_id=org.id, user_id=other_user.id) assert membership is None @@ -200,9 +179,7 @@ class TestOrgServiceGetOrgById: def test_get_org_by_id_returns_org(self, test_db_session: Session, test_user: User): """Test retrieving an organization by ID.""" - org, _ = OrgService.create_org_with_owner( - db=test_db_session, name="Test Org", owner_user_id=test_user.id - ) + org, _ = OrgService.create_org_with_owner(db=test_db_session, name="Test Org", owner_user_id=test_user.id) retrieved_org = OrgService.get_org_by_id(db=test_db_session, org_id=org.id) diff --git a/backend/tests/test_permissions.py b/backend/tests/test_permissions.py index da11d78..9a74644 100644 --- a/backend/tests/test_permissions.py +++ b/backend/tests/test_permissions.py @@ -7,11 +7,12 @@ - Permission helper functions - Integration tests for permission-protected endpoints """ + import pytest from fastapi import HTTPException from sqlalchemy.orm import Session -from app.models import OrgRole, Organization, OrgMembership, User +from app.models import OrgMembership, OrgRole from app.permissions import ( OrgContext, get_org_context, @@ -37,9 +38,7 @@ def test_get_org_context_succeeds_for_member(self, db: Session): password="password123", display_name="Test Member", ) - org, membership = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + org, membership = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # Change role to member for this test membership.role = OrgRole.MEMBER db.commit() @@ -89,9 +88,7 @@ def test_get_org_context_raises_404_user_not_member(self, db: Session): password="password123", display_name="Non Member", ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) # Action & Assert: non_member should get 404 (privacy best practice) with pytest.raises(HTTPException) as exc_info: @@ -104,36 +101,20 @@ def test_get_org_context_raises_404_user_not_member(self, db: Session): def test_get_org_context_with_all_roles(self, db: Session): """Test that get_org_context works for all role types.""" # Setup: Create org and users with each role - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, 
owner_membership = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, owner_membership = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) # Create users with other roles - admin = UserService.create_user( - db=db, email="admin@example.com", password="pass", display_name="Admin" - ) - admin_membership = OrgMembership( - org_id=org.id, user_id=admin.id, role=OrgRole.ADMIN - ) + admin = UserService.create_user(db=db, email="admin@example.com", password="pass", display_name="Admin") + admin_membership = OrgMembership(org_id=org.id, user_id=admin.id, role=OrgRole.ADMIN) db.add(admin_membership) - member = UserService.create_user( - db=db, email="member@example.com", password="pass", display_name="Member" - ) - member_membership = OrgMembership( - org_id=org.id, user_id=member.id, role=OrgRole.MEMBER - ) + member = UserService.create_user(db=db, email="member@example.com", password="pass", display_name="Member") + member_membership = OrgMembership(org_id=org.id, user_id=member.id, role=OrgRole.MEMBER) db.add(member_membership) - viewer = UserService.create_user( - db=db, email="viewer@example.com", password="pass", display_name="Viewer" - ) - viewer_membership = OrgMembership( - org_id=org.id, user_id=viewer.id, role=OrgRole.VIEWER - ) + viewer = UserService.create_user(db=db, email="viewer@example.com", password="pass", display_name="Viewer") + viewer_membership = OrgMembership(org_id=org.id, user_id=viewer.id, role=OrgRole.VIEWER) db.add(viewer_membership) db.commit() @@ -169,75 +150,51 @@ def test_role_rank_ordering(self): def test_is_org_admin_or_higher_true_cases(self, db: Session): """Test is_org_admin_or_higher returns True for ADMIN and OWNER.""" - user = UserService.create_user( - db=db, email="test@example.com", password="pass", display_name="Test" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="test@example.com", password="pass", display_name="Test") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # Test OWNER - owner_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.OWNER - ) + owner_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.OWNER) owner_ctx = OrgContext(org=org, membership=owner_membership, user=user) assert is_org_admin_or_higher(owner_ctx) is True # Test ADMIN - admin_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.ADMIN - ) + admin_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.ADMIN) admin_ctx = OrgContext(org=org, membership=admin_membership, user=user) assert is_org_admin_or_higher(admin_ctx) is True # Test MEMBER (should be False) - member_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.MEMBER - ) + member_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.MEMBER) member_ctx = OrgContext(org=org, membership=member_membership, user=user) assert is_org_admin_or_higher(member_ctx) is False # Test VIEWER (should be False) - viewer_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.VIEWER - ) + viewer_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.VIEWER) viewer_ctx = OrgContext(org=org, membership=viewer_membership, user=user) assert 
is_org_admin_or_higher(viewer_ctx) is False def test_is_org_member_or_higher_true_cases(self, db: Session): """Test is_org_member_or_higher returns True for MEMBER, ADMIN, and OWNER.""" - user = UserService.create_user( - db=db, email="test@example.com", password="pass", display_name="Test" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="test@example.com", password="pass", display_name="Test") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # Test OWNER - owner_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.OWNER - ) + owner_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.OWNER) owner_ctx = OrgContext(org=org, membership=owner_membership, user=user) assert is_org_member_or_higher(owner_ctx) is True # Test ADMIN - admin_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.ADMIN - ) + admin_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.ADMIN) admin_ctx = OrgContext(org=org, membership=admin_membership, user=user) assert is_org_member_or_higher(admin_ctx) is True # Test MEMBER - member_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.MEMBER - ) + member_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.MEMBER) member_ctx = OrgContext(org=org, membership=member_membership, user=user) assert is_org_member_or_higher(member_ctx) is True # Test VIEWER (should be False) - viewer_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.VIEWER - ) + viewer_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.VIEWER) viewer_ctx = OrgContext(org=org, membership=viewer_membership, user=user) assert is_org_member_or_higher(viewer_ctx) is False @@ -247,17 +204,11 @@ class TestPermissionHelpers: def test_require_org_role_succeeds_with_exact_role(self, db: Session): """Test that require_org_role succeeds when user has exact role.""" - user = UserService.create_user( - db=db, email="admin@example.com", password="pass", display_name="Admin" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="admin@example.com", password="pass", display_name="Admin") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # User has ADMIN role, require ADMIN - admin_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.ADMIN - ) + admin_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.ADMIN) admin_ctx = OrgContext(org=org, membership=admin_membership, user=user) # Should not raise @@ -265,17 +216,11 @@ def test_require_org_role_succeeds_with_exact_role(self, db: Session): def test_require_org_role_succeeds_with_higher_role(self, db: Session): """Test that require_org_role succeeds when user has higher role.""" - user = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # User has OWNER role, require ADMIN (owner is higher) - owner_membership = OrgMembership( - org_id=org.id, user_id=user.id, 
role=OrgRole.OWNER - ) + owner_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.OWNER) owner_ctx = OrgContext(org=org, membership=owner_membership, user=user) # Should not raise @@ -285,17 +230,11 @@ def test_require_org_role_succeeds_with_higher_role(self, db: Session): def test_require_org_role_raises_403_for_insufficient_role(self, db: Session): """Test that require_org_role raises 403 when user role is insufficient.""" - user = UserService.create_user( - db=db, email="member@example.com", password="pass", display_name="Member" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="member@example.com", password="pass", display_name="Member") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # User has MEMBER role, require ADMIN - member_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.MEMBER - ) + member_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.MEMBER) member_ctx = OrgContext(org=org, membership=member_membership, user=user) # Should raise 403 @@ -307,17 +246,11 @@ def test_require_org_role_raises_403_for_insufficient_role(self, db: Session): def test_require_org_role_raises_403_for_viewer(self, db: Session): """Test that require_org_role raises 403 for viewer when member required.""" - user = UserService.create_user( - db=db, email="viewer@example.com", password="pass", display_name="Viewer" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="viewer@example.com", password="pass", display_name="Viewer") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # User has VIEWER role, require MEMBER - viewer_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.VIEWER - ) + viewer_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.VIEWER) viewer_ctx = OrgContext(org=org, membership=viewer_membership, user=user) # Should raise 403 @@ -328,16 +261,10 @@ def test_require_org_role_raises_403_for_viewer(self, db: Session): def test_require_org_role_error_message_is_clear(self, db: Session): """Test that error messages indicate required role.""" - user = UserService.create_user( - db=db, email="test@example.com", password="pass", display_name="Test" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="test@example.com", password="pass", display_name="Test") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) - viewer_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.VIEWER - ) + viewer_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.VIEWER) viewer_ctx = OrgContext(org=org, membership=viewer_membership, user=user) # Test error message includes required role @@ -350,24 +277,16 @@ def test_require_org_role_error_message_is_clear(self, db: Session): def test_is_org_owner_true_for_owner_only(self, db: Session): """Test that is_org_owner returns True only for OWNER role.""" - user = UserService.create_user( - db=db, email="test@example.com", password="pass", display_name="Test" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="test@example.com", 
password="pass", display_name="Test") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # Test OWNER - owner_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.OWNER - ) + owner_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.OWNER) owner_ctx = OrgContext(org=org, membership=owner_membership, user=user) assert is_org_owner(owner_ctx) is True # Test ADMIN (should be False) - admin_membership = OrgMembership( - org_id=org.id, user_id=user.id, role=OrgRole.ADMIN - ) + admin_membership = OrgMembership(org_id=org.id, user_id=user.id, role=OrgRole.ADMIN) admin_ctx = OrgContext(org=org, membership=admin_membership, user=user) assert is_org_owner(admin_ctx) is False @@ -378,12 +297,8 @@ class TestOrgMembersEndpoint: def test_list_org_members_requires_authentication(self, client, db: Session): """Test that endpoint requires authentication.""" # Setup: Create an org - user = UserService.create_user( - db=db, email="test@example.com", password="pass", display_name="Test" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=user.id - ) + user = UserService.create_user(db=db, email="test@example.com", password="pass", display_name="Test") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=user.id) # Action: Request without auth header response = client.get(f"/api/v1/orgs/{org.id}/members") @@ -391,17 +306,11 @@ def test_list_org_members_requires_authentication(self, client, db: Session): # Assert: Should get 401 assert response.status_code == 401 - def test_list_org_members_requires_org_membership( - self, client, db: Session, auth_headers - ): + def test_list_org_members_requires_org_membership(self, client, db: Session, auth_headers): """Test that endpoint requires user to be org member.""" # Setup: Create two users and an org - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) # Create a different user (not in org) non_member = UserService.create_user( @@ -423,21 +332,13 @@ def test_list_org_members_requires_org_membership( def test_list_org_members_requires_admin_role(self, client, db: Session, auth_headers): """Test that endpoint requires admin role.""" # Setup: Create org and user with MEMBER role - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) # Add a member (not admin) - member = UserService.create_user( - db=db, email="member@example.com", password="pass", display_name="Member" - ) + member = UserService.create_user(db=db, email="member@example.com", password="pass", display_name="Member") member.email_verified = True - member_membership = OrgMembership( - org_id=org.id, user_id=member.id, role=OrgRole.MEMBER - ) + member_membership = OrgMembership(org_id=org.id, user_id=member.id, role=OrgRole.MEMBER) db.add(member_membership) 
db.commit() @@ -453,28 +354,16 @@ def test_list_org_members_requires_admin_role(self, client, db: Session, auth_he def test_list_org_members_succeeds_for_admin(self, client, db: Session, auth_headers): """Test that endpoint succeeds for admin role.""" # Setup: Create org with admin and 3 total members - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, owner_membership = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, owner_membership = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) - admin = UserService.create_user( - db=db, email="admin@example.com", password="pass", display_name="Admin" - ) + admin = UserService.create_user(db=db, email="admin@example.com", password="pass", display_name="Admin") admin.email_verified = True - admin_membership = OrgMembership( - org_id=org.id, user_id=admin.id, role=OrgRole.ADMIN - ) + admin_membership = OrgMembership(org_id=org.id, user_id=admin.id, role=OrgRole.ADMIN) db.add(admin_membership) - member = UserService.create_user( - db=db, email="member@example.com", password="pass", display_name="Member" - ) - member_membership = OrgMembership( - org_id=org.id, user_id=member.id, role=OrgRole.MEMBER - ) + member = UserService.create_user(db=db, email="member@example.com", password="pass", display_name="Member") + member_membership = OrgMembership(org_id=org.id, user_id=member.id, role=OrgRole.MEMBER) db.add(member_membership) db.commit() @@ -495,13 +384,9 @@ def test_list_org_members_succeeds_for_admin(self, client, db: Session, auth_hea def test_list_org_members_succeeds_for_owner(self, client, db: Session, auth_headers): """Test that endpoint succeeds for owner role.""" # Setup: Create org - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") owner.email_verified = True - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) db.commit() headers = auth_headers(owner) @@ -514,18 +399,12 @@ def test_list_org_members_succeeds_for_owner(self, client, db: Session, auth_hea members = response.json() assert len(members) >= 1 # At least the owner - def test_list_org_members_returns_correct_schema( - self, client, db: Session, auth_headers - ): + def test_list_org_members_returns_correct_schema(self, client, db: Session, auth_headers): """Test that response contains all required fields.""" # Setup: Create org with admin - admin = UserService.create_user( - db=db, email="admin@example.com", password="pass", display_name="Admin User" - ) + admin = UserService.create_user(db=db, email="admin@example.com", password="pass", display_name="Admin User") admin.email_verified = True - org, membership = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=admin.id - ) + org, membership = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=admin.id) membership.role = OrgRole.ADMIN db.commit() @@ -557,30 +436,20 @@ def test_list_org_members_returns_correct_schema( def test_list_org_members_includes_all_roles(self, client, db: Session, auth_headers): """Test that all 4 role levels are returned 
correctly.""" # Setup: Create org with members of each role - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") owner.email_verified = True - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) # Add admin - admin = UserService.create_user( - db=db, email="admin@example.com", password="pass", display_name="Admin" - ) + admin = UserService.create_user(db=db, email="admin@example.com", password="pass", display_name="Admin") db.add(OrgMembership(org_id=org.id, user_id=admin.id, role=OrgRole.ADMIN)) # Add member - member = UserService.create_user( - db=db, email="member@example.com", password="pass", display_name="Member" - ) + member = UserService.create_user(db=db, email="member@example.com", password="pass", display_name="Member") db.add(OrgMembership(org_id=org.id, user_id=member.id, role=OrgRole.MEMBER)) # Add viewer - viewer = UserService.create_user( - db=db, email="viewer@example.com", password="pass", display_name="Viewer" - ) + viewer = UserService.create_user(db=db, email="viewer@example.com", password="pass", display_name="Viewer") db.add(OrgMembership(org_id=org.id, user_id=viewer.id, role=OrgRole.VIEWER)) db.commit() @@ -610,16 +479,10 @@ def test_list_org_members_includes_all_roles(self, client, db: Session, auth_hea def test_viewer_cannot_list_members(self, client, db: Session, auth_headers): """Test that viewer role cannot list members.""" # Setup: Create org and viewer - owner = UserService.create_user( - db=db, email="owner@example.com", password="pass", display_name="Owner" - ) - org, _ = OrgService.create_org_with_owner( - db=db, name="Test Org", owner_user_id=owner.id - ) + owner = UserService.create_user(db=db, email="owner@example.com", password="pass", display_name="Owner") + org, _ = OrgService.create_org_with_owner(db=db, name="Test Org", owner_user_id=owner.id) - viewer = UserService.create_user( - db=db, email="viewer@example.com", password="pass", display_name="Viewer" - ) + viewer = UserService.create_user(db=db, email="viewer@example.com", password="pass", display_name="Viewer") db.add(OrgMembership(org_id=org.id, user_id=viewer.id, role=OrgRole.VIEWER)) db.commit() diff --git a/backend/tests/test_phase_container_service.py b/backend/tests/test_phase_container_service.py index 93673b1..1363396 100644 --- a/backend/tests/test_phase_container_service.py +++ b/backend/tests/test_phase_container_service.py @@ -3,28 +3,28 @@ Tests the service layer for phase container operations including CRUD, container assignment, cascade archive, and validation. 
""" + +from uuid import uuid4 + import pytest -from datetime import datetime, timezone from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User -from app.models.organization import Organization -from app.models.project import Project, ProjectStatus -from app.models.phase_container import PhaseContainer from app.models.brainstorming_phase import ( BrainstormingPhase, BrainstormingPhaseType, PhaseSubtype, ) from app.models.final_spec import FinalSpec +from app.models.organization import Organization +from app.models.project import Project, ProjectStatus from app.models.project_chat import ProjectChat, ProjectChatVisibility -from app.services.user_service import UserService -from app.services.phase_container_service import PhaseContainerService +from app.models.user import User from app.services.brainstorming_phase_service import BrainstormingPhaseService +from app.services.phase_container_service import PhaseContainerService from app.services.project_chat_service import ProjectChatService +from app.services.user_service import UserService @pytest.fixture @@ -74,9 +74,7 @@ def sample_org(test_db_session: Session, sample_user: User) -> Organization: @pytest.fixture -def sample_project( - test_db_session: Session, sample_org: Organization, sample_user: User -) -> Project: +def sample_project(test_db_session: Session, sample_org: Organization, sample_user: User) -> Project: """Create a sample project for testing.""" project = Project( org_id=sample_org.id, @@ -166,8 +164,7 @@ def test_url_identifier( # url_identifier should be slug-shortid format assert container.short_id in container.url_identifier - assert "my-test-container" in container.url_identifier.lower() or \ - "my" in container.url_identifier.lower() + assert "my-test-container" in container.url_identifier.lower() or "my" in container.url_identifier.lower() class TestGetContainer: @@ -1355,9 +1352,7 @@ def test_standalone_phase_creates_auto_container( assert phase.container_sequence == 1 # Verify container was created - container = PhaseContainerService.get_container( - test_db_session, phase.container_id - ) + container = PhaseContainerService.get_container(test_db_session, phase.container_id) assert container is not None assert container.title == "Auto Container Phase" assert container.project_id == sample_project.id @@ -1486,9 +1481,7 @@ def test_standalone_auto_container_order_index( user_id=sample_user.id, ) - container = PhaseContainerService.get_container( - test_db_session, phase.container_id - ) + container = PhaseContainerService.get_container(test_db_session, phase.container_id) assert container is not None assert container.order_index == 4 diff --git a/backend/tests/test_phase_containers_router.py b/backend/tests/test_phase_containers_router.py index 02b8ffb..8d07439 100644 --- a/backend/tests/test_phase_containers_router.py +++ b/backend/tests/test_phase_containers_router.py @@ -13,17 +13,17 @@ from app.database import Base, get_db from app.main import app from app.models import ( - User, Organization, - Project, - ProjectType, - ProjectStatus, OrgMembership, OrgRole, + Project, ProjectRole, + ProjectStatus, + ProjectType, + User, ) +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType from app.models.phase_container import PhaseContainer -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType, 
PhaseSubtype from app.models.project_share import ProjectShare, ShareSubjectType from app.services.user_service import UserService @@ -245,9 +245,7 @@ def test_create_container_unauthenticated(self, client, sample_project): assert response.status_code == 401 - def test_create_container_non_member( - self, client, sample_project, other_user, other_user_headers - ): + def test_create_container_non_member(self, client, sample_project, other_user, other_user_headers): """Test that non-members cannot create containers.""" response = client.post( f"/api/v1/projects/{sample_project.id}/phase-containers", @@ -321,9 +319,7 @@ def test_list_containers_includes_archived( assert response.status_code == 200 assert len(response.json()) == 1 - def test_list_containers_ordered( - self, client, db, sample_project, user_headers - ): + def test_list_containers_ordered(self, client, db, sample_project, user_headers): """Test that containers are ordered by order_index.""" container1 = PhaseContainer( project_id=sample_project.id, @@ -394,9 +390,7 @@ def test_get_container_not_found(self, client, user_headers): assert response.status_code == 404 - def test_get_container_non_member( - self, client, sample_container, other_user_headers - ): + def test_get_container_non_member(self, client, sample_container, other_user_headers): """Test that non-members get 404 (privacy).""" response = client.get( f"/api/v1/phase-containers/{sample_container.id}", @@ -446,9 +440,7 @@ def test_update_container_partial(self, client, sample_container, user_headers): assert data["title"] == "New Title Only" assert data["order_index"] == original_order - def test_update_container_non_member( - self, client, sample_container, other_user_headers - ): + def test_update_container_non_member(self, client, sample_container, other_user_headers): """Test that non-members cannot update containers.""" response = client.patch( f"/api/v1/phase-containers/{sample_container.id}", @@ -473,9 +465,7 @@ def test_archive_container(self, client, sample_container, user_headers): data = response.json() assert data["archived_at"] is not None - def test_archive_container_already_archived( - self, client, db, sample_container, sample_user, user_headers - ): + def test_archive_container_already_archived(self, client, db, sample_container, sample_user, user_headers): """Test archiving already archived container returns 400.""" from datetime import datetime, timezone @@ -525,9 +515,7 @@ def test_archive_container_viewer_forbidden( class TestRestorePhaseContainer: """Tests for POST /api/v1/phase-containers/{identifier}/restore.""" - def test_restore_container( - self, client, db, sample_container, sample_user, user_headers - ): + def test_restore_container(self, client, db, sample_container, sample_user, user_headers): """Test restoring an archived container.""" from datetime import datetime, timezone @@ -569,9 +557,7 @@ def test_list_phases_empty(self, client, sample_container, user_headers): assert response.status_code == 200 assert response.json() == [] - def test_list_phases( - self, client, db, sample_container, sample_phase, user_headers - ): + def test_list_phases(self, client, db, sample_container, sample_phase, user_headers): """Test listing phases in a container.""" # Assign phase to container sample_phase.container_id = sample_container.id @@ -592,9 +578,7 @@ def test_list_phases( class TestAssignPhaseToContainer: """Tests for POST /api/v1/phases/{phase_id}/assign-to-container.""" - def test_assign_phase( - self, client, sample_container, sample_phase, 
user_headers - ): + def test_assign_phase(self, client, sample_container, sample_phase, user_headers): """Test assigning a phase to a container.""" response = client.post( f"/api/v1/phases/{sample_phase.id}/assign-to-container", @@ -606,9 +590,7 @@ def test_assign_phase( data = response.json() assert data["container_sequence"] is not None - def test_assign_phase_with_sequence( - self, client, sample_container, sample_phase, user_headers - ): + def test_assign_phase_with_sequence(self, client, sample_container, sample_phase, user_headers): """Test assigning a phase with explicit sequence.""" response = client.post( f"/api/v1/phases/{sample_phase.id}/assign-to-container", @@ -619,9 +601,7 @@ def test_assign_phase_with_sequence( assert response.status_code == 200 assert response.json()["container_sequence"] == 5 - def test_assign_phase_missing_container_id( - self, client, sample_phase, user_headers - ): + def test_assign_phase_missing_container_id(self, client, sample_phase, user_headers): """Test that missing container_id returns 400.""" response = client.post( f"/api/v1/phases/{sample_phase.id}/assign-to-container", @@ -632,9 +612,7 @@ def test_assign_phase_missing_container_id( assert response.status_code == 400 assert "container_id is required" in response.json()["detail"] - def test_assign_phase_container_not_found( - self, client, sample_phase, user_headers - ): + def test_assign_phase_container_not_found(self, client, sample_phase, user_headers): """Test assigning to non-existent container returns 404.""" response = client.post( f"/api/v1/phases/{sample_phase.id}/assign-to-container", @@ -644,9 +622,7 @@ def test_assign_phase_container_not_found( assert response.status_code == 404 - def test_assign_phase_cross_project( - self, client, db, sample_phase, sample_org, sample_user, user_headers - ): + def test_assign_phase_cross_project(self, client, db, sample_phase, sample_org, sample_user, user_headers): """Test assigning phase to container from different project.""" # Create another project other_project = Project( @@ -683,9 +659,7 @@ def test_assign_phase_cross_project( class TestRemovePhaseFromContainer: """Tests for POST /api/v1/phases/{phase_id}/remove-from-container.""" - def test_remove_phase( - self, client, db, sample_container, sample_phase, user_headers - ): + def test_remove_phase(self, client, db, sample_container, sample_phase, user_headers): """Test removing a phase from its container.""" # First assign the phase sample_phase.container_id = sample_container.id @@ -701,9 +675,7 @@ def test_remove_phase( data = response.json() assert data["container_sequence"] is None - def test_remove_phase_not_in_container( - self, client, sample_phase, user_headers - ): + def test_remove_phase_not_in_container(self, client, sample_phase, user_headers): """Test removing a phase that's not in a container returns 400.""" response = client.post( f"/api/v1/phases/{sample_phase.id}/remove-from-container", @@ -742,9 +714,7 @@ def test_get_extension_preview_not_found(self, client, user_headers): assert response.status_code == 404 - def test_get_extension_preview_non_member( - self, client, sample_container, other_user_headers - ): + def test_get_extension_preview_non_member(self, client, sample_container, other_user_headers): """Test that non-members get 404 (privacy).""" response = client.get( f"/api/v1/phase-containers/{sample_container.id}/extension-preview", @@ -800,9 +770,7 @@ def test_create_extension_short_description(self, client, sample_container, user assert response.status_code == 422 
- def test_create_extension_non_member( - self, client, sample_container, other_user_headers - ): + def test_create_extension_non_member(self, client, sample_container, other_user_headers): """Test that non-members get 404 (privacy).""" response = client.post( f"/api/v1/phase-containers/{sample_container.id}/create-extension", @@ -828,9 +796,7 @@ def test_create_extension_nonexistent_container(self, client, user_headers): assert response.status_code == 404 - def test_create_extension_archived_container( - self, client, db, sample_container, sample_user, user_headers - ): + def test_create_extension_archived_container(self, client, db, sample_container, sample_user, user_headers): """Test creating extension in archived container returns 404 (archived = not found).""" from datetime import datetime, timezone diff --git a/backend/tests/test_phase_progress_service.py b/backend/tests/test_phase_progress_service.py index 25d26db..83675b2 100644 --- a/backend/tests/test_phase_progress_service.py +++ b/backend/tests/test_phase_progress_service.py @@ -1,5 +1,5 @@ """Tests for phase progress service and implementation-progress endpoint.""" -import pytest + from datetime import datetime, timezone from uuid import uuid4 @@ -16,11 +16,11 @@ from app.services.brainstorming_phase_service import BrainstormingPhaseService from app.services.phase_progress_service import PhaseProgressService - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _create_phase(db, project, user): return BrainstormingPhaseService.create_brainstorming_phase( db=db, @@ -31,9 +31,17 @@ def _create_phase(db, project, user): ) -def _create_module(db, project, phase, *, title="Module", order_index=0, - module_type=ModuleType.IMPLEMENTATION, module_key_number=1, - archived_at=None): +def _create_module( + db, + project, + phase, + *, + title="Module", + order_index=0, + module_type=ModuleType.IMPLEMENTATION, + module_key_number=1, + archived_at=None, +): module = Module( project_id=project.id, brainstorming_phase_id=phase.id, @@ -52,11 +60,17 @@ def _create_module(db, project, phase, *, title="Module", order_index=0, return module -def _create_feature(db, module, *, title="Feature", feature_key_number=1, - completion_status=FeatureCompletionStatus.PENDING, - priority=FeaturePriority.IMPORTANT, - feature_type=FeatureType.IMPLEMENTATION, - status=FeatureStatus.ACTIVE): +def _create_feature( + db, + module, + *, + title="Feature", + feature_key_number=1, + completion_status=FeatureCompletionStatus.PENDING, + priority=FeaturePriority.IMPORTANT, + feature_type=FeatureType.IMPLEMENTATION, + status=FeatureStatus.ACTIVE, +): feature = Feature( module_id=module.id, title=title, @@ -79,13 +93,12 @@ def _create_feature(db, module, *, title="Feature", feature_key_number=1, # Service: compute_feature_stats # --------------------------------------------------------------------------- + class TestComputeFeatureStats: """Tests for the pure compute_feature_stats function.""" def test_empty_list(self): - total, completed, pending, in_prog, pct, nxt = ( - PhaseProgressService.compute_feature_stats([]) - ) + total, completed, pending, in_prog, pct, nxt = PhaseProgressService.compute_feature_stats([]) assert total == 0 assert completed == 0 assert pending == 0 @@ -99,9 +112,7 @@ def test_all_pending(self, db, test_project, test_user): f1 = _create_feature(db, module, feature_key_number=1) f2 = _create_feature(db, module, 
feature_key_number=2) - total, completed, pending, in_prog, pct, nxt = ( - PhaseProgressService.compute_feature_stats([f1, f2]) - ) + total, completed, pending, in_prog, pct, nxt = PhaseProgressService.compute_feature_stats([f1, f2]) assert total == 2 assert completed == 0 assert pending == 2 @@ -111,16 +122,11 @@ def test_all_pending(self, db, test_project, test_user): def test_mixed_statuses(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - f1 = _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) - f2 = _create_feature(db, module, feature_key_number=2, - completion_status=FeatureCompletionStatus.IN_PROGRESS) - f3 = _create_feature(db, module, feature_key_number=3, - completion_status=FeatureCompletionStatus.PENDING) - - total, completed, pending, in_prog, pct, nxt = ( - PhaseProgressService.compute_feature_stats([f1, f2, f3]) - ) + f1 = _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) + f2 = _create_feature(db, module, feature_key_number=2, completion_status=FeatureCompletionStatus.IN_PROGRESS) + f3 = _create_feature(db, module, feature_key_number=3, completion_status=FeatureCompletionStatus.PENDING) + + total, completed, pending, in_prog, pct, nxt = PhaseProgressService.compute_feature_stats([f1, f2, f3]) assert total == 3 assert completed == 1 assert in_prog == 1 @@ -131,14 +137,10 @@ def test_mixed_statuses(self, db, test_project, test_user): def test_all_completed(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - f1 = _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) - f2 = _create_feature(db, module, feature_key_number=2, - completion_status=FeatureCompletionStatus.COMPLETED) + f1 = _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) + f2 = _create_feature(db, module, feature_key_number=2, completion_status=FeatureCompletionStatus.COMPLETED) - total, completed, pending, in_prog, pct, nxt = ( - PhaseProgressService.compute_feature_stats([f1, f2]) - ) + total, completed, pending, in_prog, pct, nxt = PhaseProgressService.compute_feature_stats([f1, f2]) assert total == 2 assert completed == 2 assert pct == 100.0 @@ -148,14 +150,10 @@ def test_next_feature_sequential_ignores_priority(self, db, test_project, test_u """next_feature returns the first pending feature by order, regardless of priority.""" phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - f1 = _create_feature(db, module, feature_key_number=1, - priority=FeaturePriority.IMPORTANT) - f2 = _create_feature(db, module, feature_key_number=2, - priority=FeaturePriority.MUST_HAVE) + f1 = _create_feature(db, module, feature_key_number=1, priority=FeaturePriority.IMPORTANT) + f2 = _create_feature(db, module, feature_key_number=2, priority=FeaturePriority.MUST_HAVE) - total, completed, pending, in_prog, pct, nxt = ( - PhaseProgressService.compute_feature_stats([f1, f2]) - ) + total, completed, pending, in_prog, pct, nxt = PhaseProgressService.compute_feature_stats([f1, f2]) assert nxt == "TEST-001" @@ -163,6 +161,7 @@ def test_next_feature_sequential_ignores_priority(self, db, test_project, test_u # Service: get_phase_progress # --------------------------------------------------------------------------- + 
class TestGetPhaseProgress: """Tests for get_phase_progress DB queries.""" @@ -192,10 +191,8 @@ def test_single_module_all_pending(self, db, test_project, test_user): def test_mixed_statuses(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) - _create_feature(db, module, feature_key_number=2, - completion_status=FeatureCompletionStatus.IN_PROGRESS) + _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) + _create_feature(db, module, feature_key_number=2, completion_status=FeatureCompletionStatus.IN_PROGRESS) _create_feature(db, module, feature_key_number=3) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -208,8 +205,7 @@ def test_mixed_statuses(self, db, test_project, test_user): def test_all_completed(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) + _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) progress = PhaseProgressService.get_phase_progress(db, phase.id) assert progress.progress_percent == 100.0 @@ -218,26 +214,19 @@ def test_all_completed(self, db, test_project, test_user): def test_next_feature_sequential_across_modules(self, db, test_project, test_user): """next_feature returns first pending feature by module order, regardless of priority.""" phase = _create_phase(db, test_project, test_user) - m1 = _create_module(db, test_project, phase, title="Module 1", - order_index=0, module_key_number=1) - m2 = _create_module(db, test_project, phase, title="Module 2", - order_index=1, module_key_number=2) - _create_feature(db, m1, feature_key_number=1, - priority=FeaturePriority.IMPORTANT) - _create_feature(db, m2, feature_key_number=2, - priority=FeaturePriority.MUST_HAVE) + m1 = _create_module(db, test_project, phase, title="Module 1", order_index=0, module_key_number=1) + m2 = _create_module(db, test_project, phase, title="Module 2", order_index=1, module_key_number=2) + _create_feature(db, m1, feature_key_number=1, priority=FeaturePriority.IMPORTANT) + _create_feature(db, m2, feature_key_number=2, priority=FeaturePriority.MUST_HAVE) progress = PhaseProgressService.get_phase_progress(db, phase.id) assert progress.next_feature == "TEST-001" def test_multiple_modules_aggregate(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) - m1 = _create_module(db, test_project, phase, title="Module 1", - order_index=0, module_key_number=1) - m2 = _create_module(db, test_project, phase, title="Module 2", - order_index=1, module_key_number=2) - _create_feature(db, m1, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) + m1 = _create_module(db, test_project, phase, title="Module 1", order_index=0, module_key_number=1) + m2 = _create_module(db, test_project, phase, title="Module 2", order_index=1, module_key_number=2) + _create_feature(db, m1, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) _create_feature(db, m2, feature_key_number=2) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -249,11 +238,10 @@ def test_multiple_modules_aggregate(self, db, test_project, test_user): def 
test_archived_modules_excluded(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) - _create_module(db, test_project, phase, title="Archived", - archived_at=datetime.now(timezone.utc), - module_key_number=1) - m2 = _create_module(db, test_project, phase, title="Active", - module_key_number=2) + _create_module( + db, test_project, phase, title="Archived", archived_at=datetime.now(timezone.utc), module_key_number=1 + ) + m2 = _create_module(db, test_project, phase, title="Active", module_key_number=2) _create_feature(db, m2, feature_key_number=1) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -262,10 +250,10 @@ def test_archived_modules_excluded(self, db, test_project, test_user): def test_conversation_modules_excluded(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) - _create_module(db, test_project, phase, title="Conversation", - module_type=ModuleType.CONVERSATION, module_key_number=1) - m2 = _create_module(db, test_project, phase, title="Impl", - module_key_number=2) + _create_module( + db, test_project, phase, title="Conversation", module_type=ModuleType.CONVERSATION, module_key_number=1 + ) + m2 = _create_module(db, test_project, phase, title="Impl", module_key_number=2) _create_feature(db, m2, feature_key_number=1) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -274,8 +262,7 @@ def test_conversation_modules_excluded(self, db, test_project, test_user): def test_non_active_features_excluded(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - status=FeatureStatus.ARCHIVED) + _create_feature(db, module, feature_key_number=1, status=FeatureStatus.ARCHIVED) _create_feature(db, module, feature_key_number=2) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -284,8 +271,7 @@ def test_non_active_features_excluded(self, db, test_project, test_user): def test_conversation_features_excluded(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - feature_type=FeatureType.CONVERSATION) + _create_feature(db, module, feature_key_number=1, feature_type=FeatureType.CONVERSATION) _create_feature(db, module, feature_key_number=2) progress = PhaseProgressService.get_phase_progress(db, phase.id) @@ -296,6 +282,7 @@ def test_conversation_features_excluded(self, db, test_project, test_user): # Service: get_module_progress # --------------------------------------------------------------------------- + class TestGetModuleProgress: """Tests for get_module_progress.""" @@ -307,8 +294,7 @@ def test_nonexistent_module(self, db): def test_module_with_features(self, db, test_project, test_user): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) + _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) _create_feature(db, module, feature_key_number=2) progress = PhaseProgressService.get_module_progress(db, module.id) @@ -323,14 +309,14 @@ def test_module_with_features(self, db, test_project, test_user): # Endpoint: GET /brainstorming-phases/{phase_id}/implementation-progress # 
--------------------------------------------------------------------------- + class TestImplementationProgressEndpoint: """Tests for the implementation-progress REST endpoint.""" def test_success_with_data(self, client, db, test_project, test_user, auth_headers): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) + _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) _create_feature(db, module, feature_key_number=2) headers = auth_headers(test_user) @@ -415,14 +401,14 @@ def test_requires_auth(self, client, db, test_project, test_user): # Endpoint: GET /modules/{module_id}/implementation-progress # --------------------------------------------------------------------------- + class TestModuleImplementationProgressEndpoint: """Tests for the module implementation-progress REST endpoint.""" def test_success_with_data(self, client, db, test_project, test_user, auth_headers): phase = _create_phase(db, test_project, test_user) module = _create_module(db, test_project, phase) - _create_feature(db, module, feature_key_number=1, - completion_status=FeatureCompletionStatus.COMPLETED) + _create_feature(db, module, feature_key_number=1, completion_status=FeatureCompletionStatus.COMPLETED) _create_feature(db, module, feature_key_number=2) headers = auth_headers(test_user) diff --git a/backend/tests/test_phase_validators.py b/backend/tests/test_phase_validators.py index 612e058..7e5931e 100644 --- a/backend/tests/test_phase_validators.py +++ b/backend/tests/test_phase_validators.py @@ -2,27 +2,28 @@ Tests the validation logic for phase-container relationships. """ + +from uuid import uuid4 + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhaseType from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.phase_container import PhaseContainer -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType +from app.models.user import User from app.models.validators.phase_validators import ( PhaseContainerValidationError, - _validate_container_sequence, - _validate_container_project_consistency, _validate_container_not_archived, + _validate_container_project_consistency, + _validate_container_sequence, _validate_sequence_uniqueness, ) -from app.services.user_service import UserService -from app.services.phase_container_service import PhaseContainerService from app.services.brainstorming_phase_service import BrainstormingPhaseService +from app.services.phase_container_service import PhaseContainerService +from app.services.user_service import UserService @pytest.fixture @@ -72,9 +73,7 @@ def sample_org(test_db_session: Session, sample_user: User) -> Organization: @pytest.fixture -def sample_project( - test_db_session: Session, sample_org: Organization, sample_user: User -) -> Project: +def sample_project(test_db_session: Session, sample_org: Organization, sample_user: User) -> Project: """Create a sample project for testing.""" project = Project( org_id=sample_org.id, diff --git a/backend/tests/test_plan_recommendation_service.py 
b/backend/tests/test_plan_recommendation_service.py index ec20009..16be8de 100644 --- a/backend/tests/test_plan_recommendation_service.py +++ b/backend/tests/test_plan_recommendation_service.py @@ -5,11 +5,9 @@ plan recommendations. """ -from datetime import date, datetime, timedelta, timezone -from decimal import Decimal +from datetime import datetime, timedelta, timezone from uuid import uuid4 -import pytest from sqlalchemy.orm import Session from app.models.daily_usage_summary import DailyUsageSummary @@ -21,9 +19,7 @@ ) from app.models.user import User from app.services.plan_recommendation_service import ( - DOWNGRADE_EFFICIENCY_THRESHOLD, DOWNGRADE_MIN_CONSECUTIVE_DAYS, - UPGRADE_EFFICIENCY_THRESHOLD, UPGRADE_MIN_CONSECUTIVE_DAYS, PlanRecommendationService, ) @@ -112,9 +108,7 @@ def test_returns_daily_usage(self, db: Session): db.add(summary) db.commit() - usage = PlanRecommendationService.get_org_daily_usage( - db, org.id, yesterday, yesterday - ) + usage = PlanRecommendationService.get_org_daily_usage(db, org.id, yesterday, yesterday) assert yesterday in usage assert usage[yesterday] == 500_000 @@ -128,9 +122,7 @@ def test_empty_usage(self, db: Session): today = datetime.now(timezone.utc).date() yesterday = today - timedelta(days=1) - usage = PlanRecommendationService.get_org_daily_usage( - db, org.id, yesterday, today - ) + usage = PlanRecommendationService.get_org_daily_usage(db, org.id, yesterday, today) assert usage == {} diff --git a/backend/tests/test_plan_service.py b/backend/tests/test_plan_service.py index c8b874a..fed6915 100644 --- a/backend/tests/test_plan_service.py +++ b/backend/tests/test_plan_service.py @@ -4,28 +4,28 @@ Tests organization plan management including token limits, trial plans, freemium plans, and usage tracking. 
""" -from datetime import datetime, timezone, timedelta + +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock -from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from app.auth.trial import is_tokens_exhausted, is_trial_expired from app.models.organization import Organization -from app.models.user import User from app.models.platform_settings import PlatformSettings +from app.models.user import User from app.plugin_registry import PlanPlugin, get_plugin_registry +from app.services.org_service import OrgService from app.services.plan_service import ( - PlanService, - TRIAL_TOKEN_ALLOCATION, - get_freemium_settings_sync, DEFAULT_FREEMIUM_INITIAL_TOKENS, - DEFAULT_FREEMIUM_WEEKLY_TOPUP, DEFAULT_FREEMIUM_MAX_TOKENS, + DEFAULT_FREEMIUM_WEEKLY_TOPUP, + TRIAL_TOKEN_ALLOCATION, + PlanService, + get_freemium_settings_sync, ) -from app.services.org_service import OrgService from app.services.user_service import UserService -from app.auth.trial import is_tokens_exhausted, is_trial_expired def _on_org_created(db, org, user): @@ -399,9 +399,7 @@ def test_create_org_with_trial_user_initializes_plan(self, db: Session, plan_plu assert user.trial_started_at is not None # Create org for user - org, membership = OrgService.create_org_with_owner( - db, "New Org", user.id - ) + org, membership = OrgService.create_org_with_owner(db, "New Org", user.id) # Org should have freemium plan assert org.plan_name == "freemium" @@ -422,9 +420,7 @@ def test_create_org_with_grandfathered_user_no_plan(self, db: Session): db.commit() # Create org for user - org, membership = OrgService.create_org_with_owner( - db, "Grandfathered Org", user.id - ) + org, membership = OrgService.create_org_with_owner(db, "Grandfathered Org", user.id) # Org should not have plan set assert org.plan_name is None diff --git a/backend/tests/test_platform_admin.py b/backend/tests/test_platform_admin.py index 315316a..841270b 100644 --- a/backend/tests/test_platform_admin.py +++ b/backend/tests/test_platform_admin.py @@ -1,6 +1,8 @@ """Tests for platform admin authorization.""" + +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from fastapi import HTTPException from app.auth.platform_admin import is_platform_admin, require_platform_admin diff --git a/backend/tests/test_platform_settings_router.py b/backend/tests/test_platform_settings_router.py index 572b613..5ff2722 100644 --- a/backend/tests/test_platform_settings_router.py +++ b/backend/tests/test_platform_settings_router.py @@ -8,9 +8,10 @@ Tests that work without database access (auth checks, connector tests, signup restrictions) are included here. 
""" + import os import tempfile -from unittest.mock import patch, MagicMock, AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -18,7 +19,7 @@ from sqlalchemy.orm import sessionmaker from app.auth.utils import create_access_token -from app.database import Base, get_db, get_async_db +from app.database import Base, get_async_db, get_db from app.main import app from app.models import User from app.services.user_service import UserService @@ -229,8 +230,10 @@ def test_test_llm_connector_success(self, client, admin_headers): mock_response = MagicMock() mock_response.choices = [MagicMock(message=MagicMock(content="Hi"))] - with patch("app.auth.platform_admin.settings") as mock_settings, \ - patch("litellm.acompletion", new_callable=AsyncMock) as mock_llm: + with ( + patch("app.auth.platform_admin.settings") as mock_settings, + patch("litellm.acompletion", new_callable=AsyncMock) as mock_llm, + ): mock_settings.platform_admin_emails = {"admin@example.com"} mock_llm.return_value = mock_response @@ -276,8 +279,7 @@ def test_test_email_connector_success(self, client, admin_headers): mock_response = MagicMock() mock_response.status_code = 200 - with patch("app.auth.platform_admin.settings") as mock_settings, \ - patch("httpx.AsyncClient") as mock_client: + with patch("app.auth.platform_admin.settings") as mock_settings, patch("httpx.AsyncClient") as mock_client: mock_settings.platform_admin_emails = {"admin@example.com"} mock_instance = AsyncMock() mock_instance.get = AsyncMock(return_value=mock_response) @@ -344,8 +346,7 @@ def test_test_s3_connector_success(self, client, admin_headers): mock_s3_client = MagicMock() mock_s3_client.head_bucket.return_value = {} - with patch("app.auth.platform_admin.settings") as mock_settings, \ - patch("boto3.client") as mock_boto3: + with patch("app.auth.platform_admin.settings") as mock_settings, patch("boto3.client") as mock_boto3: mock_settings.platform_admin_emails = {"admin@example.com"} mock_boto3.return_value = mock_s3_client @@ -355,10 +356,12 @@ def test_test_s3_connector_success(self, client, admin_headers): json={ "connector_type": "object_storage", "provider": "aws-s3", - "credentials": json.dumps({ - "access_key_id": "test-access-key", - "secret_access_key": "test-secret-key", - }), + "credentials": json.dumps( + { + "access_key_id": "test-access-key", + "secret_access_key": "test-secret-key", + } + ), "config_json": {"bucket": "my-bucket", "region": "us-east-1"}, }, ) diff --git a/backend/tests/test_platform_settings_service.py b/backend/tests/test_platform_settings_service.py index d5452ce..230e358 100644 --- a/backend/tests/test_platform_settings_service.py +++ b/backend/tests/test_platform_settings_service.py @@ -4,17 +4,18 @@ We test it through the router endpoints in test_platform_settings_router.py. Here we focus on the synchronous helper functions used by workers. 
""" + +import base64 import os import tempfile -import base64 +from uuid import uuid4 import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC -from uuid import uuid4 +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker from app.database import Base from app.models.platform_connector import PlatformConnector @@ -337,9 +338,7 @@ def test_create_connector(self, sync_db): sync_db.add(connector) sync_db.commit() - fetched = sync_db.query(PlatformConnector).filter( - PlatformConnector.id == connector.id - ).first() + fetched = sync_db.query(PlatformConnector).filter(PlatformConnector.id == connector.id).first() assert fetched is not None assert fetched.connector_type == "llm" diff --git a/backend/tests/test_prefix_service.py b/backend/tests/test_prefix_service.py index 9460774..313f047 100644 --- a/backend/tests/test_prefix_service.py +++ b/backend/tests/test_prefix_service.py @@ -1,11 +1,12 @@ """Tests for PrefixService.""" + +from unittest.mock import AsyncMock, patch + import pytest -from uuid import uuid4 -from unittest.mock import AsyncMock, patch, MagicMock +from app.models import ProjectType from app.services.prefix_service import PrefixService from app.services.project_service import ProjectService -from app.models import ProjectType class TestValidatePrefix: @@ -185,9 +186,7 @@ def test_exclude_project_id(self, db, test_user, test_org): ) # Same prefix should be "available" when excluding self - available = PrefixService.is_prefix_available( - db, test_org.id, "XYZ", exclude_project_id=project.id - ) + available = PrefixService.is_prefix_available(db, test_org.id, "XYZ", exclude_project_id=project.id) assert available is True diff --git a/backend/tests/test_project_chat_list.py b/backend/tests/test_project_chat_list.py index 5b5f3b7..780399a 100644 --- a/backend/tests/test_project_chat_list.py +++ b/backend/tests/test_project_chat_list.py @@ -1,8 +1,10 @@ """Tests for the project chat list endpoint.""" + import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Project + +from app.models import Project, User from app.models.project_chat import ProjectChat from app.services.project_chat_service import ProjectChatService diff --git a/backend/tests/test_project_chat_reactions.py b/backend/tests/test_project_chat_reactions.py index d0de8a9..3a6c3fd 100644 --- a/backend/tests/test_project_chat_reactions.py +++ b/backend/tests/test_project_chat_reactions.py @@ -1,8 +1,10 @@ """Tests for project chat message reactions functionality.""" + import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Project + +from app.models import Project, User from app.models.project_chat import ProjectChat, ProjectChatMessage, ProjectChatMessageType from app.services.project_chat_service import ProjectChatService from app.services.user_service import UserService @@ -60,7 +62,11 @@ class TestProjectChatToggleReaction: """Tests for POST /projects/{project_id}/project-chats/{discussion_id}/messages/{message_id}/reactions endpoint.""" def test_toggle_reaction_requires_auth( - self, client: TestClient, test_project: Project, test_project_chat: ProjectChat, test_user_message: ProjectChatMessage + self, + client: TestClient, + test_project: Project, + test_project_chat: 
ProjectChat, + test_user_message: ProjectChatMessage, ): """Test that toggling reactions requires authentication.""" response = client.post( @@ -154,8 +160,8 @@ def test_toggle_reaction_multiple_users( db.commit() # Add project share for the other user using the service - from app.services.project_share_service import ProjectShareService from app.models.project_membership import ProjectRole + from app.services.project_share_service import ProjectShareService ProjectShareService.create_user_share( db=db, diff --git a/backend/tests/test_project_endpoints.py b/backend/tests/test_project_endpoints.py index 405b800..9e66cd7 100644 --- a/backend/tests/test_project_endpoints.py +++ b/backend/tests/test_project_endpoints.py @@ -1,12 +1,11 @@ """Tests for project endpoints.""" -import pytest + from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Organization, OrgMembership, OrgRole, ProjectType, ProjectStatus -from app.services.user_service import UserService -from app.services.org_service import OrgService +from app.models import Organization, OrgMembership, OrgRole, ProjectType, User from app.services.project_service import ProjectService +from app.services.user_service import UserService class TestProjectEndpoints: @@ -23,7 +22,9 @@ def test_create_project_requires_auth(self, client: TestClient): ) assert response.status_code == 401 - def test_create_project_as_member(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_project_as_member( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a project as an org member.""" headers = auth_headers(test_user) @@ -46,7 +47,9 @@ def test_create_project_as_member(self, client: TestClient, db: Session, test_us assert data["org_id"] == str(test_org.id) assert data["created_by"] == str(test_user.id) - def test_create_project_with_key(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_project_with_key( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a project with a custom key.""" headers = auth_headers(test_user) @@ -64,7 +67,9 @@ def test_create_project_with_key(self, client: TestClient, db: Session, test_use data = response.json() assert data["key"] == "CHKOT" - def test_create_project_feature_type_restricted(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_project_feature_type_restricted( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a feature project is not allowed in Step 6.""" headers = auth_headers(test_user) @@ -81,7 +86,9 @@ def test_create_project_feature_type_restricted(self, client: TestClient, db: Se assert response.status_code == 400 assert "application" in response.json()["detail"].lower() - def test_create_project_as_viewer_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_project_as_viewer_fails( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test that viewers cannot create projects.""" # Create another user who is a viewer viewer = UserService.create_user( @@ -113,7 +120,9 @@ def test_create_project_as_viewer_fails(self, client: TestClient, db: Session, t assert 
response.status_code == 403 - def test_list_org_projects(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_list_org_projects( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test listing projects for an organization.""" # Create some projects ProjectService.create_project( @@ -194,9 +203,7 @@ def test_list_org_projects_excludes_archived( ProjectService.archive_project(db, archived_project.id) headers = auth_headers(test_user) - response = client.get( - f"/api/v1/orgs/{test_org.id}/projects", headers=headers - ) + response = client.get(f"/api/v1/orgs/{test_org.id}/projects", headers=headers) assert response.status_code == 200 data = response.json() @@ -204,7 +211,9 @@ def test_list_org_projects_excludes_archived( assert "Active Project" in names assert "Archived Project" not in names - def test_list_org_projects_with_type_filter(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_list_org_projects_with_type_filter( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test listing projects with type filter.""" ProjectService.create_project( db=db, @@ -225,7 +234,9 @@ def test_list_org_projects_with_type_filter(self, client: TestClient, db: Sessio assert len(data) == 1 assert data[0]["name"] == "App 1" - def test_get_project_by_id(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_get_project_by_id( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test getting a single project by ID.""" project = ProjectService.create_project( db=db, @@ -248,7 +259,9 @@ def test_get_project_by_id(self, client: TestClient, db: Session, test_user: Use assert data["name"] == "Test Project" assert data["idea_text"] == "Some idea" - def test_get_project_nonmember_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_get_project_nonmember_fails( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test that non-org members cannot view projects.""" project = ProjectService.create_project( db=db, @@ -278,7 +291,9 @@ def test_get_project_nonmember_fails(self, client: TestClient, db: Session, test assert response.status_code == 404 # Privacy: return 404, not 403 - def test_update_project(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_update_project( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test updating a project.""" project = ProjectService.create_project( db=db, @@ -303,7 +318,9 @@ def test_update_project(self, client: TestClient, db: Session, test_user: User, assert data["name"] == "Updated Name" assert data["status"] == "discovery" - def test_update_project_non_creator_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_update_project_non_creator_fails( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test that non-creators cannot update projects (for now).""" project = ProjectService.create_project( db=db, @@ -337,7 +354,9 @@ def test_update_project_non_creator_fails(self, client: TestClient, db: Session, assert response.status_code == 403 - def 
test_archive_project(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_archive_project( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test archiving a project.""" project = ProjectService.create_project( db=db, @@ -357,7 +376,9 @@ def test_archive_project(self, client: TestClient, db: Session, test_user: User, data = response.json() assert data["status"] == "archived" - def test_archive_project_non_creator_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_archive_project_non_creator_fails( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test that non-creators cannot archive projects.""" project = ProjectService.create_project( db=db, @@ -394,7 +415,9 @@ def test_archive_project_non_creator_fails(self, client: TestClient, db: Session class TestProjectHierarchyEndpoints: """Test project hierarchy endpoints (features and bugfixes).""" - def test_create_feature_under_application(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_feature_under_application( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a feature under an application.""" # Create parent application app = ProjectService.create_project( @@ -431,7 +454,9 @@ def test_create_feature_requires_auth(self, client: TestClient): ) assert response.status_code == 401 - def test_create_feature_with_nonexistent_parent(self, client: TestClient, db: Session, test_user: User, auth_headers): + def test_create_feature_with_nonexistent_parent( + self, client: TestClient, db: Session, test_user: User, auth_headers + ): """Test creating feature with non-existent parent fails.""" headers = auth_headers(test_user) response = client.post( @@ -441,7 +466,9 @@ def test_create_feature_with_nonexistent_parent(self, client: TestClient, db: Se ) assert response.status_code == 404 - def test_create_bugfix_under_application(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_bugfix_under_application( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a bugfix under an application.""" app = ProjectService.create_project( db=db, @@ -470,7 +497,9 @@ def test_create_bugfix_under_application(self, client: TestClient, db: Session, assert data["external_ticket_id"] == "JIRA-123" assert data["external_system"] == "jira" - def test_create_bugfix_without_ticket_id(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_bugfix_without_ticket_id( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test creating a bugfix without external ticket ID (optional).""" app = ProjectService.create_project( db=db, @@ -491,7 +520,9 @@ def test_create_bugfix_without_ticket_id(self, client: TestClient, db: Session, data = response.json() assert data["external_ticket_id"] is None - def test_create_feature_under_feature_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers): + def test_create_feature_under_feature_fails( + self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers + ): """Test that creating a feature under a 
 feature fails (no child-of-child)."""
         app = ProjectService.create_project(
             db=db,
@@ -519,7 +550,9 @@ def test_create_feature_under_feature_fails(self, client: TestClient, db: Sessio
         assert response.status_code == 400
         assert "APPLICATION" in response.json()["detail"]

-    def test_create_feature_nonmember_fails(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers):
+    def test_create_feature_nonmember_fails(
+        self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
+    ):
         """Test that non-org members cannot create features."""
         app = ProjectService.create_project(
             db=db,
@@ -550,7 +583,9 @@ def test_create_feature_nonmember_fails(self, client: TestClient, db: Session, t
         assert response.status_code == 404  # Privacy: 404, not 403

-    def test_list_project_children(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers):
+    def test_list_project_children(
+        self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
+    ):
         """Test listing children of a project."""
         app = ProjectService.create_project(
             db=db,
@@ -593,7 +628,9 @@ def test_list_project_children(self, client: TestClient, db: Session, test_user:
         assert "Feature 1" in names
         assert "Bugfix 1" in names

-    def test_list_project_children_with_type_filter(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers):
+    def test_list_project_children_with_type_filter(
+        self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
+    ):
         """Test listing children with type filter."""
         app = ProjectService.create_project(
             db=db,
@@ -643,7 +680,9 @@ def test_list_project_children_with_type_filter(self, client: TestClient, db: Se
         assert len(data) == 1
         assert data[0]["type"] == "bugfix"

-    def test_list_project_children_empty(self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers):
+    def test_list_project_children_empty(
+        self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
+    ):
         """Test listing children for project with no children."""
         app = ProjectService.create_project(
             db=db,
@@ -671,7 +710,6 @@ def test_load_sample_project_triggers_batch_generation(
         self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers, monkeypatch
     ):
         """Test that loading a sample project triggers batch generation for all phases."""
-        from app.models.job import JobType, JobStatus

         # Track Kafka publish calls
         publish_calls = []
@@ -710,7 +748,8 @@ def test_load_sample_project_creates_jobs_in_database(
     ):
         """Test that loading a sample project creates job records in the database."""
         from uuid import UUID
-        from app.models.job import Job, JobType, JobStatus
+
+        from app.models.job import Job, JobStatus, JobType

         # Mock Kafka publishing
         monkeypatch.setattr("app.routers.projects.publish_job_to_kafka", lambda *args, **kwargs: True)
@@ -726,10 +765,14 @@ def test_load_sample_project_creates_jobs_in_database(
         project_id = UUID(response.json()["id"])

         # Check that 3 batch generation jobs were created
-        jobs = db.query(Job).filter(
-            Job.project_id == project_id,
-            Job.job_type == JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
-        ).all()
+        jobs = (
+            db.query(Job)
+            .filter(
+                Job.project_id == project_id,
+                Job.job_type == JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
+            )
+            .all()
+        )

         assert len(jobs) == 3
         for job in jobs:
diff --git a/backend/tests/test_project_membership_service.py b/backend/tests/test_project_membership_service.py
index c850a5c..2ce592a 100644
--- a/backend/tests/test_project_membership_service.py
+++ b/backend/tests/test_project_membership_service.py
@@ -1,10 +1,11 @@
 """Tests for ProjectService membership methods."""
-import pytest
+
 from uuid import uuid4

+import pytest
+
 from app.models import ProjectRole, ProjectType
 from app.services.project_service import ProjectService
-from app.services.org_service import OrgService
 from app.services.user_service import UserService
@@ -22,6 +23,7 @@ def test_add_project_member_succeeds(self, db, test_user, test_org):
         )
         # Add other user to the org
         from app.models import OrgMembership, OrgRole
+
         org_membership = OrgMembership(
             org_id=test_org.id,
             user_id=other_user.id,
@@ -115,6 +117,7 @@ def test_remove_project_member_succeeds(self, db, test_user, test_org):
         )
         # Add to org
         from app.models import OrgMembership, OrgRole
+
         org_membership = OrgMembership(
             org_id=test_org.id,
             user_id=other_user.id,
@@ -239,6 +242,7 @@ def test_list_project_members(self, db, test_user, test_org):
         )
         # Add to org
         from app.models import OrgMembership, OrgRole
+
         org_membership = OrgMembership(
             org_id=test_org.id,
             user_id=other_user.id,
diff --git a/backend/tests/test_project_resolver.py b/backend/tests/test_project_resolver.py
index e832079..fde30ac 100644
--- a/backend/tests/test_project_resolver.py
+++ b/backend/tests/test_project_resolver.py
@@ -1,11 +1,12 @@
 """Tests for project resolver utility."""
-import pytest
 from uuid import uuid4
+
+import pytest
 from sqlalchemy.orm import Session

 from app.mcp.utils.project_resolver import resolve_project
-from app.models import Project, ProjectType, ProjectStatus, Organization, User
+from app.models import Organization, Project, ProjectStatus, ProjectType, User


 class TestResolveProject:
@@ -53,9 +54,7 @@ def test_resolve_by_nonexistent_key_raises_error(self, db: Session):
         with pytest.raises(ValueError, match="Project with key 'NONEXISTENT' not found"):
             resolve_project(db, project_key="NONEXISTENT")

-    def test_resolve_prefers_id_when_both_provided(
-        self, db: Session, test_org: Organization, test_user: User
-    ):
+    def test_resolve_prefers_id_when_both_provided(self, db: Session, test_org: Organization, test_user: User):
         """When both ID and key provided, prefer ID."""
         # Create two projects
         project1 = Project(
@@ -82,9 +81,7 @@ def test_resolve_prefers_id_when_both_provided(
         db.refresh(project2)

         # Resolve with project1 ID but project2 key
-        resolved = resolve_project(
-            db, project_id=str(project1.id), project_key="APP:key2"
-        )
+        resolved = resolve_project(db, project_id=str(project1.id), project_key="APP:key2")

         # Should resolve to project1 (ID takes precedence)
         assert resolved.id == project1.id
diff --git a/backend/tests/test_project_service.py b/backend/tests/test_project_service.py
index 8b10e0c..31a2d66 100644
--- a/backend/tests/test_project_service.py
+++ b/backend/tests/test_project_service.py
@@ -1,8 +1,10 @@
 """Tests for ProjectService."""
-import pytest
+
 from uuid import uuid4

-from app.models import Project, ProjectType, ProjectStatus
+import pytest
+
+from app.models import ProjectStatus, ProjectType
 from app.services.project_service import ProjectService
@@ -116,16 +118,12 @@ def test_list_org_projects_with_type_filter(self, db, test_user, test_org):
         )

         # Filter by application type
-        projects = ProjectService.list_org_projects(
-            db, test_org.id, type_filter=ProjectType.APPLICATION
-        )
+        projects = ProjectService.list_org_projects(db, test_org.id, type_filter=ProjectType.APPLICATION)

         assert len(projects) == 1
         assert projects[0].id == app_project.id

         # Filter by feature type (should be empty)
-        projects = ProjectService.list_org_projects(
-            db, test_org.id, type_filter=ProjectType.FEATURE
-        )
+        projects = ProjectService.list_org_projects(db, test_org.id, type_filter=ProjectType.FEATURE)

         assert len(projects) == 0

     def test_list_org_projects_empty(self, db, test_org):
@@ -301,9 +299,7 @@ def test_create_child_with_non_application_parent(self, db, test_user, test_org)
         )

         # Try to create child of child (should fail)
-        with pytest.raises(
-            ValueError, match="Parent project must be of type APPLICATION"
-        ):
+        with pytest.raises(ValueError, match="Parent project must be of type APPLICATION"):
             ProjectService.create_child_project(
                 db=db,
                 parent_project_id=feature.id,
@@ -401,16 +397,12 @@ def test_list_project_children_type_filter(self, db, test_user, test_org):
         )

         # Filter by FEATURE
-        features = ProjectService.list_project_children(
-            db, app.id, type_filter=ProjectType.FEATURE
-        )
+        features = ProjectService.list_project_children(db, app.id, type_filter=ProjectType.FEATURE)
         assert len(features) == 1
         assert features[0].id == feature.id

         # Filter by BUGFIX
-        bugfixes = ProjectService.list_project_children(
-            db, app.id, type_filter=ProjectType.BUGFIX
-        )
+        bugfixes = ProjectService.list_project_children(db, app.id, type_filter=ProjectType.BUGFIX)
         assert len(bugfixes) == 1
         assert bugfixes[0].id == bugfix.id
@@ -686,12 +678,10 @@ def test_update_project_key_duplicate_in_same_org(self, db, test_user, test_org)
             key="KY1",
         )

-    def test_update_project_key_same_key_in_different_org_allowed(
-        self, db, test_user, test_org
-    ):
+    def test_update_project_key_same_key_in_different_org_allowed(self, db, test_user, test_org):
         """Test that same key in different orgs is allowed."""
-        from app.services.user_service import UserService
         from app.services.org_service import OrgService
+        from app.services.user_service import UserService

         # Create second org and user
         user2 = UserService.create_user(
@@ -779,9 +769,9 @@ def test_clone_project_with_phases(self, db, test_user, test_org):
         """Test cloning with phases, modules, features, and content versions."""
         from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType
-        from app.models.module import Module, ModuleProvenance, ModuleType
         from app.models.feature import Feature, FeatureProvenance, FeatureType
-        from app.models.feature_content_version import FeatureContentVersion, FeatureContentType
+        from app.models.feature_content_version import FeatureContentType, FeatureContentVersion
+        from app.models.module import Module, ModuleProvenance, ModuleType

         # Create source project
         original = ProjectService.create_project(
@@ -855,35 +845,25 @@ def test_clone_project_with_phases(self, db, test_user, test_org):
         )

         # Verify cloned phases
-        cloned_phases = (
-            db.query(BrainstormingPhase)
-            .filter(BrainstormingPhase.project_id == cloned.id)
-            .all()
-        )
+        cloned_phases = db.query(BrainstormingPhase).filter(BrainstormingPhase.project_id == cloned.id).all()
         assert len(cloned_phases) == 1
         assert cloned_phases[0].id != phase.id
         assert cloned_phases[0].title == phase.title

         # Verify cloned modules
-        cloned_modules = (
-            db.query(Module).filter(Module.project_id == cloned.id).all()
-        )
+        cloned_modules = db.query(Module).filter(Module.project_id == cloned.id).all()
         assert len(cloned_modules) == 1
         assert cloned_modules[0].brainstorming_phase_id == cloned_phases[0].id
         assert cloned_modules[0].title == module.title

         # Verify cloned features
-        cloned_features = (
-            db.query(Feature).filter(Feature.module_id == cloned_modules[0].id).all()
-        )
+        cloned_features = db.query(Feature).filter(Feature.module_id == cloned_modules[0].id).all()
         assert len(cloned_features) == 1
         assert cloned_features[0].title == feature.title

         # Verify content versions
         cloned_cvs = (
-            db.query(FeatureContentVersion)
-            .filter(FeatureContentVersion.feature_id == cloned_features[0].id)
-            .all()
+            db.query(FeatureContentVersion).filter(FeatureContentVersion.feature_id == cloned_features[0].id).all()
         )
         assert len(cloned_cvs) == 1
         assert cloned_cvs[0].content_markdown == content_version.content_markdown
@@ -891,7 +871,7 @@ def test_clone_project_threads_summary_only(self, db, test_user, test_org):
         """Test cloning threads with decision summaries only."""
-        from app.models.thread import Thread, ContextType
+        from app.models.thread import ContextType, Thread
         from app.models.thread_item import ThreadItem

         # Create source project
@@ -933,24 +913,18 @@ def test_clone_project_threads_summary_only(self, db, test_user, test_org):
         )

         # Verify thread cloned
-        cloned_threads = (
-            db.query(Thread).filter(Thread.project_id == str(cloned.id)).all()
-        )
+        cloned_threads = db.query(Thread).filter(Thread.project_id == str(cloned.id)).all()
         assert len(cloned_threads) == 1
         assert cloned_threads[0].decision_summary == thread.decision_summary
         assert cloned_threads[0].decision_summary_short == thread.decision_summary_short

         # Verify items NOT cloned
-        cloned_items = (
-            db.query(ThreadItem)
-            .filter(ThreadItem.thread_id == cloned_threads[0].id)
-            .all()
-        )
+        cloned_items = db.query(ThreadItem).filter(ThreadItem.thread_id == cloned_threads[0].id).all()
         assert len(cloned_items) == 0

     def test_clone_project_threads_full(self, db, test_user, test_org):
         """Test cloning threads with all items."""
-        from app.models.thread import Thread, ContextType
+        from app.models.thread import ContextType, Thread
         from app.models.thread_item import ThreadItem

         # Create source project
@@ -991,17 +965,11 @@ def test_clone_project_threads_full(self, db, test_user, test_org):
         )

         # Verify thread cloned
-        cloned_threads = (
-            db.query(Thread).filter(Thread.project_id == str(cloned.id)).all()
-        )
+        cloned_threads = db.query(Thread).filter(Thread.project_id == str(cloned.id)).all()
         assert len(cloned_threads) == 1

         # Verify items cloned
-        cloned_items = (
-            db.query(ThreadItem)
-            .filter(ThreadItem.thread_id == cloned_threads[0].id)
-            .all()
-        )
+        cloned_items = db.query(ThreadItem).filter(ThreadItem.thread_id == cloned_threads[0].id).all()
         assert len(cloned_items) == 1
         assert cloned_items[0].content_data == item.content_data
@@ -1054,12 +1022,7 @@ def test_clone_project_preserves_order_index(self, db, test_user, test_org):
         )

         # Verify order indices preserved
-        cloned_modules = (
-            db.query(Module)
-            .filter(Module.project_id == cloned.id)
-            .order_by(Module.order_index)
-            .all()
-        )
+        cloned_modules = db.query(Module).filter(Module.project_id == cloned.id).order_by(Module.order_index).all()
         assert len(cloned_modules) == 3
         assert [m.order_index for m in cloned_modules] == [0, 1, 2]
         assert cloned_modules[0].title == "Module 0"
diff --git a/backend/tests/test_project_share_service.py b/backend/tests/test_project_share_service.py
index e6cb469..1facc3d 100644
--- a/backend/tests/test_project_share_service.py
+++ b/backend/tests/test_project_share_service.py
@@ -1,23 +1,20 @@
 """Tests for project share service."""
-import pytest
 from sqlalchemy.orm import Session

-from app.models import User, Organization, Project, ProjectRole, ProjectType
+from app.models import Organization, Project, ProjectRole, ProjectType, User
 from app.models.project_share import ShareSubjectType
-from app.services.project_share_service import ProjectShareService
+from app.services.org_service import OrgService
 from app.services.project_service import ProjectService
+from app.services.project_share_service import ProjectShareService
 from app.services.user_group_service import UserGroupService
 from app.services.user_service import UserService
-from app.services.org_service import OrgService


 class TestProjectShareService:
     """Tests for ProjectShareService CRUD operations."""

-    def test_create_user_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_create_user_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test creating a user share."""
         share = ProjectShareService.create_user_share(
             db=db,
@@ -34,9 +31,7 @@ def test_create_user_share(
         assert share.role == ProjectRole.MEMBER
         assert share.created_by_user_id == test_user.id

-    def test_create_group_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_create_group_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test creating a group share."""
         group = UserGroupService.create_group(
             db=db,
@@ -80,9 +75,7 @@ def test_create_share_upserts_role(
         assert share2.id == share1.id
         assert share2.role == ProjectRole.ADMIN

-    def test_get_share_by_id(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_get_share_by_id(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test getting a share by ID."""
         share = ProjectShareService.create_user_share(
             db=db,
@@ -95,9 +88,7 @@ def test_get_share_by_id(
         assert fetched is not None
         assert fetched.id == share.id

-    def test_remove_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_remove_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test removing a share by ID."""
         share = ProjectShareService.create_user_share(
             db=db,
@@ -112,9 +103,7 @@ def test_remove_share(
         fetched = ProjectShareService.get_share_by_id(db, share.id)
         assert fetched is None

-    def test_remove_user_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_remove_user_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test removing a user's share from a project."""
         ProjectShareService.create_user_share(
             db=db,
@@ -123,44 +112,30 @@ def test_remove_user_share(
             role=ProjectRole.MEMBER,
         )

-        result = ProjectShareService.remove_user_share(
-            db, test_project.id, test_user.id
-        )
+        result = ProjectShareService.remove_user_share(db, test_project.id, test_user.id)
         assert result is True

-        share = ProjectShareService.get_user_direct_share(
-            db, test_project.id, test_user.id
-        )
+        share = ProjectShareService.get_user_direct_share(db, test_project.id, test_user.id)
         assert share is None

-    def test_list_project_shares(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_list_project_shares(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test listing all shares for a project."""
         # Create another user
         user2 = UserService.create_user(db, "user2@test.com", "password")
         OrgService.add_org_member(db, test_org.id, user2.id)

         # Note: test_project already has an owner share for test_user from creation
-        ProjectShareService.create_user_share(
-            db, test_project.id, user2.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, user2.id, ProjectRole.VIEWER)

         shares = ProjectShareService.list_project_shares(db, test_project.id)

         # 2 shares: auto-created owner share + viewer share
         assert len(shares) == 2

-    def test_get_user_direct_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_get_user_direct_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test getting a user's direct share."""
-        ProjectShareService.create_user_share(
-            db, test_project.id, test_user.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_user_share(db, test_project.id, test_user.id, ProjectRole.ADMIN)

-        share = ProjectShareService.get_user_direct_share(
-            db, test_project.id, test_user.id
-        )
+        share = ProjectShareService.get_user_direct_share(db, test_project.id, test_user.id)
         assert share is not None
         assert share.role == ProjectRole.ADMIN
@@ -172,31 +147,21 @@ def test_get_user_direct_share_not_found(
         other_user = UserService.create_user(db, "other@test.com", "password")
         OrgService.add_org_member(db, test_org.id, other_user.id)

-        share = ProjectShareService.get_user_direct_share(
-            db, test_project.id, other_user.id
-        )
+        share = ProjectShareService.get_user_direct_share(db, test_project.id, other_user.id)
         assert share is None


 class TestProjectShareGroupAccess:
     """Tests for group-based project access."""

-    def test_get_user_group_shares(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_get_user_group_shares(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test getting group shares that apply to a user."""
-        group = UserGroupService.create_group(
-            db, test_org.id, "Engineering", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Engineering", test_user.id)
         UserGroupService.add_member(db, group.id, test_user.id)

-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.MEMBER)

-        shares = ProjectShareService.get_user_group_shares(
-            db, test_project.id, test_user.id
-        )
+        shares = ProjectShareService.get_user_group_shares(db, test_project.id, test_user.id)
         assert len(shares) == 1
         assert shares[0].subject_type == ShareSubjectType.GROUP
         assert shares[0].subject_id == group.id
@@ -205,13 +170,9 @@ def test_get_user_effective_share_direct_only(
         self, db: Session, test_user: User, test_org: Organization, test_project: Project
     ):
         """Test effective share when user only has direct access."""
-        ProjectShareService.create_user_share(
-            db, test_project.id, test_user.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_user_share(db, test_project.id, test_user.id, ProjectRole.ADMIN)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, test_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, test_user.id)
         assert share is not None
         assert share.role == ProjectRole.ADMIN
         assert share.subject_type == ShareSubjectType.USER
@@ -224,17 +185,11 @@ def test_get_user_effective_share_group_only(
         group_user = UserService.create_user(db, "groupuser@test.com", "password")
         OrgService.add_org_member(db, test_org.id, group_user.id)

-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
         UserGroupService.add_member(db, group.id, group_user.id)

-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, group_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, group_user.id)
         assert share is not None
         assert share.role == ProjectRole.VIEWER
         assert share.subject_type == ShareSubjectType.GROUP
@@ -243,23 +198,15 @@ def test_get_user_effective_share_prefers_higher_role(
         self, db: Session, test_user: User, test_org: Organization, test_project: Project
     ):
         """Test effective share picks highest role when user has multiple."""
-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
         UserGroupService.add_member(db, group.id, test_user.id)

         # Group gives VIEWER
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

         # Direct gives ADMIN
-        ProjectShareService.create_user_share(
-            db, test_project.id, test_user.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_user_share(db, test_project.id, test_user.id, ProjectRole.ADMIN)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, test_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, test_user.id)
         assert share.role == ProjectRole.ADMIN

     def test_get_user_effective_share_prefers_direct_on_tie(
@@ -276,35 +223,21 @@ def test_get_user_effective_share_prefers_direct_on_tie(
         member_user = UserService.create_user(db, "member@example.com", "testpassword")
         OrgService.add_org_member(db, test_org.id, member_user.id, OrgRole.MEMBER)

-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
         UserGroupService.add_member(db, group.id, member_user.id)

         # Both give MEMBER
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.MEMBER
-        )
-        ProjectShareService.create_user_share(
-            db, test_project.id, member_user.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.MEMBER)
+        ProjectShareService.create_user_share(db, test_project.id, member_user.id, ProjectRole.MEMBER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, member_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, member_user.id)
         assert share.subject_type == ShareSubjectType.USER

-    def test_get_user_effective_role(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_get_user_effective_role(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test getting effective role."""
-        ProjectShareService.create_user_share(
-            db, test_project.id, test_user.id, ProjectRole.OWNER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, test_user.id, ProjectRole.OWNER)

-        role = ProjectShareService.get_user_effective_role(
-            db, test_project.id, test_user.id
-        )
+        role = ProjectShareService.get_user_effective_role(db, test_project.id, test_user.id)
         assert role == ProjectRole.OWNER

     def test_get_user_effective_role_no_access(
@@ -315,39 +248,27 @@
         no_access_user = UserService.create_user(db, "noaccess@test.com", "password")
         OrgService.add_org_member(db, test_org.id, no_access_user.id)

-        role = ProjectShareService.get_user_effective_role(
-            db, test_project.id, no_access_user.id
-        )
+        role = ProjectShareService.get_user_effective_role(db, test_project.id, no_access_user.id)
         assert role is None

     def test_user_has_project_access_direct(
         self, db: Session, test_user: User, test_org: Organization, test_project: Project
     ):
         """Test access check with direct share."""
-        ProjectShareService.create_user_share(
-            db, test_project.id, test_user.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, test_user.id, ProjectRole.VIEWER)

-        has_access = ProjectShareService.user_has_project_access(
-            db, test_project.id, test_user.id
-        )
+        has_access = ProjectShareService.user_has_project_access(db, test_project.id, test_user.id)
         assert has_access is True

     def test_user_has_project_access_via_group(
         self, db: Session, test_user: User, test_org: Organization, test_project: Project
     ):
         """Test access check with group-based share."""
-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
         UserGroupService.add_member(db, group.id, test_user.id)

-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

-        has_access = ProjectShareService.user_has_project_access(
-            db, test_project.id, test_user.id
-        )
+        has_access = ProjectShareService.user_has_project_access(db, test_project.id, test_user.id)
         assert has_access is True

     def test_user_has_project_access_no_access(
@@ -358,9 +279,7 @@
         no_access_user = UserService.create_user(db, "noaccess2@test.com", "password")
         OrgService.add_org_member(db, test_org.id, no_access_user.id)

-        has_access = ProjectShareService.user_has_project_access(
-            db, test_project.id, no_access_user.id
-        )
+        has_access = ProjectShareService.user_has_project_access(db, test_project.id, no_access_user.id)
         assert has_access is False

     def test_user_has_project_access_group_not_member(
@@ -371,17 +290,11 @@
         non_member_user = UserService.create_user(db, "nonmember@test.com", "password")
         OrgService.add_org_member(db, test_org.id, non_member_user.id)

-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
         # Note: NOT adding non_member_user to group
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

-        has_access = ProjectShareService.user_has_project_access(
-            db, test_project.id, non_member_user.id
-        )
+        has_access = ProjectShareService.user_has_project_access(db, test_project.id, non_member_user.id)
         assert has_access is False

     def test_multiple_group_memberships(
@@ -392,34 +305,22 @@
         multi_group_user = UserService.create_user(db, "multigroup@test.com", "password")
         OrgService.add_org_member(db, test_org.id, multi_group_user.id)

-        group1 = UserGroupService.create_group(
-            db, test_org.id, "Team1", test_user.id
-        )
-        group2 = UserGroupService.create_group(
-            db, test_org.id, "Team2", test_user.id
-        )
+        group1 = UserGroupService.create_group(db, test_org.id, "Team1", test_user.id)
+        group2 = UserGroupService.create_group(db, test_org.id, "Team2", test_user.id)
         UserGroupService.add_member(db, group1.id, multi_group_user.id)
         UserGroupService.add_member(db, group2.id, multi_group_user.id)

-        ProjectShareService.create_group_share(
-            db, test_project.id, group1.id, ProjectRole.VIEWER
-        )
-        ProjectShareService.create_group_share(
-            db, test_project.id, group2.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group1.id, ProjectRole.VIEWER)
+        ProjectShareService.create_group_share(db, test_project.id, group2.id, ProjectRole.ADMIN)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, multi_group_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, multi_group_user.id)
         assert share.role == ProjectRole.ADMIN


 class TestProjectShareOrgAccess:
     """Tests for org-level project access."""

-    def test_create_org_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_create_org_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test creating an org-level share."""
         share = ProjectShareService.create_org_share(
             db=db,
@@ -434,22 +335,16 @@ def test_create_org_share(
         assert share.subject_id == test_org.id
         assert share.role == ProjectRole.VIEWER

-    def test_get_org_share(
-        self, db: Session, test_user: User, test_org: Organization, test_project: Project
-    ):
+    def test_get_org_share(self, db: Session, test_user: User, test_org: Organization, test_project: Project):
         """Test getting org share for a project."""
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.MEMBER)

         share = ProjectShareService.get_org_share(db, test_project.id)
         assert share is not None
         assert share.subject_type == ShareSubjectType.ORG
         assert share.role == ProjectRole.MEMBER

-    def test_get_org_share_not_found(
-        self, db: Session, test_project: Project
-    ):
+    def test_get_org_share_not_found(self, db: Session, test_project: Project):
         """Test getting org share returns None when none exists."""
         share = ProjectShareService.get_org_share(db, test_project.id)
         assert share is None
@@ -463,19 +358,13 @@ def test_user_has_access_via_org_share(
         OrgService.add_org_member(db, test_org.id, new_user.id)

         # Initially no access
-        assert not ProjectShareService.user_has_project_access(
-            db, test_project.id, new_user.id
-        )
+        assert not ProjectShareService.user_has_project_access(db, test_project.id, new_user.id)

         # Create org share
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.VIEWER)

         # Now has access
-        assert ProjectShareService.user_has_project_access(
-            db, test_project.id, new_user.id
-        )
+        assert ProjectShareService.user_has_project_access(db, test_project.id, new_user.id)

     def test_get_effective_share_org_only(
         self, db: Session, test_user: User, test_org: Organization, test_project: Project
@@ -486,13 +375,9 @@
         OrgService.add_org_member(db, test_org.id, org_user.id)

         # Create org share
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.MEMBER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, org_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, org_user.id)
         assert share is not None
         assert share.role == ProjectRole.MEMBER
         assert share.subject_type == ShareSubjectType.ORG
@@ -506,18 +391,12 @@
         OrgService.add_org_member(db, test_org.id, mixed_user.id)

         # Create org share as MEMBER
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.MEMBER)

         # Create direct share also as MEMBER
-        ProjectShareService.create_user_share(
-            db, test_project.id, mixed_user.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, mixed_user.id, ProjectRole.MEMBER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, mixed_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, mixed_user.id)
         assert share.subject_type == ShareSubjectType.USER  # Direct takes precedence

     def test_effective_share_precedence_group_over_org(
@@ -532,16 +411,10 @@
         UserGroupService.add_member(db, group.id, group_user.id)

         # Create both shares at same role
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.VIEWER
-        )
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.VIEWER)
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, group_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, group_user.id)
         assert share.subject_type == ShareSubjectType.GROUP  # Group takes precedence

     def test_effective_share_higher_role_wins(
@@ -553,18 +426,12 @@
         OrgService.add_org_member(db, test_org.id, mixed_user.id)

         # Org share gives ADMIN
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.ADMIN)

         # Direct share only gives VIEWER
-        ProjectShareService.create_user_share(
-            db, test_project.id, mixed_user.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, mixed_user.id, ProjectRole.VIEWER)

-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, mixed_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, mixed_user.id)
         assert share.role == ProjectRole.ADMIN  # Higher role wins
         assert share.subject_type == ShareSubjectType.ORG
@@ -577,20 +444,14 @@
         OrgService.add_org_member(db, test_org.id, org_user.id)

         # Initially no access
-        project_ids = ProjectShareService.list_user_accessible_project_ids(
-            db, org_user.id, test_org.id
-        )
+        project_ids = ProjectShareService.list_user_accessible_project_ids(db, org_user.id, test_org.id)
         assert test_project.id not in project_ids

         # Create org share
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.VIEWER)

         # Now included
-        project_ids = ProjectShareService.list_user_accessible_project_ids(
-            db, org_user.id, test_org.id
-        )
+        project_ids = ProjectShareService.list_user_accessible_project_ids(db, org_user.id, test_org.id)
         assert test_project.id in project_ids

     def test_non_org_member_no_access_via_org_share(
@@ -603,22 +464,16 @@
         other_org, _ = OrgService.create_org_with_owner(db, "Other Org", outside_user.id)

         # Create org share
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.ADMIN
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.ADMIN)

         # Still no access (not an org member)
-        assert not ProjectShareService.user_has_project_access(
-            db, test_project.id, outside_user.id
-        )
+        assert not ProjectShareService.user_has_project_access(db, test_project.id, outside_user.id)


 class TestOrgMemberNoImplicitAccess:
     """Tests that org members (including admins/owners) require explicit shares."""

-    def test_org_owner_no_implicit_project_access(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_org_owner_no_implicit_project_access(self, db: Session, test_user: User, test_org: Organization):
         """Test that org owners don't have implicit access to projects."""
         from app.models import OrgRole
@@ -632,9 +487,7 @@
         )

         # test_user is org owner but didn't create this project - no access
-        share = ProjectShareService.get_user_effective_share(
-            db, project.id, test_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, project.id, test_user.id)

         assert share is None
@@ -649,9 +502,7 @@
         OrgService.add_org_member(db, test_org.id, admin_user.id, OrgRole.ADMIN)

         # Admin has no explicit share - no access
-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, admin_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, admin_user.id)

         assert share is None
@@ -666,34 +517,24 @@
         OrgService.add_org_member(db, test_org.id, member_user.id, OrgRole.MEMBER)

         # No explicit share, no implicit access
-        share = ProjectShareService.get_user_effective_share(
-            db, test_project.id, member_user.id
-        )
+        share = ProjectShareService.get_user_effective_share(db, test_project.id, member_user.id)

         assert share is None

-    def test_org_admin_cannot_access_projects_without_share(
-        self, db: Session, test_user: User, test_org: Organization
-    ):
+    def test_org_admin_cannot_access_projects_without_share(self, db: Session, test_user: User, test_org: Organization):
         """Test that org admins cannot access projects via list_user_accessible_project_ids without explicit shares."""
         from app.models import OrgRole

         # Create multiple projects (test_user creates them, gets explicit OWNER shares)
-        project1 = ProjectService.create_project(
-            db, test_org.id, test_user.id, ProjectType.APPLICATION, "Project 1"
-        )
-        project2 = ProjectService.create_project(
-            db, test_org.id, test_user.id, ProjectType.APPLICATION, "Project 2"
-        )
+        project1 = ProjectService.create_project(db, test_org.id, test_user.id, ProjectType.APPLICATION, "Project 1")
+        project2 = ProjectService.create_project(db, test_org.id, test_user.id, ProjectType.APPLICATION, "Project 2")

         # Create an admin user (no explicit shares to projects)
         admin_user = UserService.create_user(db, "admin2@example.com", "testpassword")
         OrgService.add_org_member(db, test_org.id, admin_user.id, OrgRole.ADMIN)

         # Admin should NOT see projects without explicit shares
-        project_ids = ProjectShareService.list_user_accessible_project_ids(
-            db, admin_user.id, test_org.id
-        )
+        project_ids = ProjectShareService.list_user_accessible_project_ids(db, admin_user.id, test_org.id)
         assert project1.id not in project_ids
         assert project2.id not in project_ids
@@ -711,9 +552,7 @@ def test_list_includes_direct_shares(
         # Create a member user with direct share
         member_user = UserService.create_user(db, "direct@example.com", "testpassword")
         OrgService.add_org_member(db, test_org.id, member_user.id, OrgRole.MEMBER)
-        ProjectShareService.create_user_share(
-            db, test_project.id, member_user.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, member_user.id, ProjectRole.MEMBER)

         users = ProjectShareService.list_project_accessible_users(db, test_project.id)
         user_ids = [u.id for u in users]
@@ -749,9 +588,7 @@ def test_list_expands_group_shares(
         group = UserGroupService.create_group(db, test_org.id, "DevTeam", test_user.id)
         UserGroupService.add_member(db, group.id, group_member.id)

-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.MEMBER)

         users = ProjectShareService.list_project_accessible_users(db, test_project.id)
         user_ids = [u.id for u in users]
@@ -774,9 +611,7 @@ def test_list_includes_all_org_members_when_org_share_exists(
         assert org_member.id not in user_ids_before

         # Create org share
-        ProjectShareService.create_org_share(
-            db, test_project.id, test_org.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_org_share(db, test_project.id, test_org.id, ProjectRole.VIEWER)

         # Now member is included
         users_after = ProjectShareService.list_project_accessible_users(db, test_project.id)
diff --git a/backend/tests/test_project_shares_router.py b/backend/tests/test_project_shares_router.py
index a612541..5f96dbb 100644
--- a/backend/tests/test_project_shares_router.py
+++ b/backend/tests/test_project_shares_router.py
@@ -1,11 +1,9 @@
 """Tests for project shares router endpoints."""
-import pytest
 from fastapi.testclient import TestClient
 from sqlalchemy.orm import Session

-from app.models import User, Organization, Project, ProjectRole, OrgRole
-from app.models.project_share import ShareSubjectType
+from app.models import Organization, Project, ProjectRole, User
 from app.services.org_service import OrgService
 from app.services.project_share_service import ProjectShareService
 from app.services.user_group_service import UserGroupService
@@ -62,9 +60,7 @@ def test_create_group_share_success(
         headers = auth_headers(test_user)

         # Create a group
-        group = UserGroupService.create_group(
-            db, test_org.id, "Engineering", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Engineering", test_user.id)

         response = client.post(
             f"/api/v1/projects/{test_project.id}/shares",
@@ -168,12 +164,8 @@ def test_create_share_group_not_in_org(
         headers = auth_headers(test_user)

         # Create another org and group
-        other_org, _ = OrgService.create_org_with_owner(
-            db, "Other Org", test_user.id
-        )
-        other_group = UserGroupService.create_group(
-            db, other_org.id, "Other Group", test_user.id
-        )
+        other_org, _ = OrgService.create_org_with_owner(db, "Other Org", test_user.id)
+        other_group = UserGroupService.create_group(db, other_org.id, "Other Group", test_user.id)

         response = client.post(
             f"/api/v1/projects/{test_project.id}/shares",
@@ -201,9 +193,7 @@ def test_create_share_requires_admin_role(
         # Create a viewer user
         viewer = UserService.create_user(db, "viewer@test.com", "password")
         OrgService.add_org_member(db, test_org.id, viewer.id)
-        ProjectShareService.create_user_share(
-            db, test_project.id, viewer.id, ProjectRole.VIEWER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, viewer.id, ProjectRole.VIEWER)
         headers = auth_headers(viewer)

         # Try to create a share as viewer
@@ -238,16 +228,10 @@ def test_list_shares_success(
         # Create additional shares
         other_user = UserService.create_user(db, "other@test.com", "password")
         OrgService.add_org_member(db, test_org.id, other_user.id)
-        ProjectShareService.create_user_share(
-            db, test_project.id, other_user.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_user_share(db, test_project.id, other_user.id, ProjectRole.MEMBER)

-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

         response = client.get(
             f"/api/v1/projects/{test_project.id}/shares",
@@ -273,12 +257,8 @@ def test_list_shares_filter_by_user_type(
         headers = auth_headers(test_user)

         # Create a group share
-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

         # Filter for users only
         response = client.get(
@@ -305,12 +285,8 @@ def test_list_shares_filter_by_group_type(
         headers = auth_headers(test_user)

         # Create a group share
-        group = UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.VIEWER
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Team", test_user.id)
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.VIEWER)

         # Filter for groups only
         response = client.get(
@@ -336,13 +312,9 @@ def test_list_shares_includes_enriched_data(
         headers = auth_headers(test_user)

         # Create a group with members
-        group = UserGroupService.create_group(
-            db, test_org.id, "Engineering", test_user.id, description="The eng team"
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Engineering", test_user.id, description="The eng team")
         UserGroupService.add_member(db, group.id, test_user.id)
-        ProjectShareService.create_group_share(
-            db, test_project.id, group.id, ProjectRole.MEMBER
-        )
+        ProjectShareService.create_group_share(db, test_project.id, group.id, ProjectRole.MEMBER)

         response = client.get(
             f"/api/v1/projects/{test_project.id}/shares?subject_type=group",
@@ -376,9 +348,7 @@ def test_update_share_role_success(
         # Create a user share
         other_user = UserService.create_user(db, "other@test.com", "password")
         OrgService.add_org_member(db, test_org.id, other_user.id)
-        share = ProjectShareService.create_user_share(
-            db, test_project.id, other_user.id, ProjectRole.VIEWER
-        )
+        share = ProjectShareService.create_user_share(db, test_project.id, other_user.id, ProjectRole.VIEWER)

         response = client.patch(
             f"/api/v1/projects/{test_project.id}/shares/{share.id}",
@@ -423,9 +393,7 @@ def test_update_share_requires_admin(
         # Create a viewer user
         viewer = UserService.create_user(db, "viewer@test.com", "password")
         OrgService.add_org_member(db, test_org.id, viewer.id)
-        share = ProjectShareService.create_user_share(
-            db, test_project.id, viewer.id, ProjectRole.VIEWER
-        )
+        share = ProjectShareService.create_user_share(db, test_project.id, viewer.id, ProjectRole.VIEWER)
         headers = auth_headers(viewer)

         response = client.patch(
@@ -455,9 +423,7 @@ def test_delete_share_success(
         # Create a user share
         other_user = UserService.create_user(db, "other@test.com", "password")
         OrgService.add_org_member(db, test_org.id, other_user.id)
-        share = ProjectShareService.create_user_share(
-            db, test_project.id, other_user.id, ProjectRole.MEMBER
-        )
+        share = ProjectShareService.create_user_share(db, test_project.id, other_user.id, ProjectRole.MEMBER)

         response = client.delete(
             f"/api/v1/projects/{test_project.id}/shares/{share.id}",
@@ -508,9 +474,7 @@ def test_delete_owner_allowed_if_multiple_owners(
         # Add another owner
         other_user = UserService.create_user(db, "other@test.com", "password")
         OrgService.add_org_member(db, test_org.id, other_user.id)
-        other_owner_share = ProjectShareService.create_user_share(
-            db, test_project.id, other_user.id, ProjectRole.OWNER
-        )
+        other_owner_share = ProjectShareService.create_user_share(db, test_project.id, other_user.id, ProjectRole.OWNER)

         # Now we can delete the other owner
         response = client.delete(
@@ -555,9 +519,7 @@ def test_search_returns_users_and_groups(
         headers = auth_headers(test_user)

         # Create a group
-        UserGroupService.create_group(
-            db, test_org.id, "Engineering", test_user.id
-        )
+        UserGroupService.create_group(db, test_org.id, "Engineering", test_user.id)

         response = client.get(
             f"/api/v1/orgs/{test_org.id}/shareable-subjects",
@@ -616,9 +578,7 @@ def test_search_filter_by_user_type(
         headers = auth_headers(test_user)

         # Create a group
-        UserGroupService.create_group(
-            db, test_org.id, "Team", test_user.id
-        )
+        UserGroupService.create_group(db, test_org.id, "Team", test_user.id)

         response = client.get(
             f"/api/v1/orgs/{test_org.id}/shareable-subjects?type=user",
@@ -697,9 +657,7 @@ def test_search_group_includes_member_count(
         headers = auth_headers(test_user)

         # Create a group and add members
-        group = UserGroupService.create_group(
-            db, test_org.id, "Big Team", test_user.id
-        )
+        group = UserGroupService.create_group(db, test_org.id, "Big Team", test_user.id)
         UserGroupService.add_member(db, group.id, test_user.id)

         other_user = UserService.create_user(db, "other@test.com", "password")
@@ -714,7 +672,5 @@
         assert response.status_code == 200
         data = response.json()

-        group_subject = next(
-            s for s in data["subjects"] if s["name"] == "Big Team"
-        )
+        group_subject = next(s for s in data["subjects"] if s["name"] == "Big Team")
         assert group_subject["member_count"] == 2
diff --git a/backend/tests/test_short_id.py b/backend/tests/test_short_id.py
index d6fed35..c7b4615 100644
--- a/backend/tests/test_short_id.py
+++ b/backend/tests/test_short_id.py
@@ -1,15 +1,14 @@
 """Tests for short_id utilities."""
-import pytest

 from app.utils.short_id import (
-    generate_short_id,
-    slugify,
+    SHORT_ID_LENGTH,
     build_url_identifier,
+    extract_short_id,
+    generate_short_id,
     is_uuid,
     is_valid_short_id,
     parse_identifier,
-    extract_short_id,
-    SHORT_ID_LENGTH,
+    slugify,
 )
diff --git a/backend/tests/test_slack_models.py b/backend/tests/test_slack_models.py
index 9e302f0..98bb6a6 100644
--- a/backend/tests/test_slack_models.py
+++ b/backend/tests/test_slack_models.py
@@ -2,9 +2,6 @@
 Tests for Slack bot database models: SlackChannelProjectLink, SlackUserMapping.
 """

-import uuid
-from datetime import datetime, timezone
-
 import pytest
 from sqlalchemy.exc import IntegrityError
diff --git a/backend/tests/test_spec_service.py b/backend/tests/test_spec_service.py
index 2b67cfc..28d80c7 100644
--- a/backend/tests/test_spec_service.py
+++ b/backend/tests/test_spec_service.py
@@ -1,10 +1,10 @@
 """Tests for SpecService."""
-import pytest
-from sqlalchemy.orm import Session
+
 from uuid import uuid4

-from app.models import SpecVersion, SpecType, Project, User, Organization, OrgMembership, OrgRole
+from sqlalchemy.orm import Session
+
+from app.models import SpecType, SpecVersion
 from app.services.spec_service import SpecService
@@ -136,9 +138,7 @@ def test_get_active_spec_found(self, db: Session, test_project, test_user):
         )

         # Retrieve it
-        active_spec = SpecService.get_active_spec(
-            db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION
-        )
+        active_spec = SpecService.get_active_spec(db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION)

         assert active_spec is not None
         assert active_spec.id == created_spec.id
@@ -146,9 +144,7 @@ def test_get_active_spec_found(self, db: Session, test_project, test_user):

     def test_get_active_spec_not_found(self, db: Session, test_project):
         """Test retrieving active spec when none exists."""
-        active_spec = SpecService.get_active_spec(
-            db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION
-        )
+        active_spec = SpecService.get_active_spec(db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION)

         assert active_spec is None
@@ -201,9 +197,7 @@ def test_list_spec_versions_ordered(self, db: Session, test_project, test_user):
         )

         # List all versions
-        versions = SpecService.list_spec_versions(
-            db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION
-        )
+        versions = SpecService.list_spec_versions(db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION)

         assert len(versions) == 3
         # Should be ordered by version descending (newest first)
@@ -216,9 +210,7 @@ def test_list_spec_versions_ordered(self, db: Session, test_project, test_user):

     def test_list_spec_versions_empty(self, db: Session, test_project):
         """Test listing spec versions when none exist."""
-        versions = SpecService.list_spec_versions(
-            db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION
-        )
+        versions = SpecService.list_spec_versions(db=db, project_id=test_project.id, spec_type=SpecType.SPECIFICATION)

         assert versions == []
diff --git a/backend/tests/test_team_roles.py b/backend/tests/test_team_roles.py
index ef1e47d..1e1cf53 100644
--- a/backend/tests/test_team_roles.py
+++ b/backend/tests/test_team_roles.py
@@ -1,25 +1,24 @@
 """Tests for team roles API endpoints."""

 import pytest
-from uuid import uuid4
-from sqlalchemy.orm import Session
 from fastapi.testclient import TestClient
+from sqlalchemy.orm import Session

 from app.models import (
-    User,
     Organization,
+    OrgMembership,
+    OrgRole,
     Project,
     ProjectRole,
+    ProjectTeamAssignment,
     ProjectType,
-    OrgMembership,
-    OrgRole,
     TeamRoleDefinition,
-    ProjectTeamAssignment,
+    User,
 )
 from app.models.project_share import ProjectShare, ShareSubjectType
-from app.services.user_service import UserService
 from app.services.project_service import ProjectService
 from app.services.team_role_service import TeamRoleService
+from app.services.user_service import UserService


 class TestOrgTeamRoleDefinitions:
@@ -65,9 +64,7 @@ def test_update_team_role_definition(
         """Test updating a team role's title and description."""
         # First create a role
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Engineers", "Write code"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Engineers", "Write code")

         response = client.patch(
             f"/api/v1/orgs/{test_org.id}/team-role-definitions/{role.id}",
@@ -87,9 +84,7 @@ def test_update_role_definition_partial(
         self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
     ):
         """Test updating only the title of a role."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Architects", "Design system architecture"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Architects", "Design system architecture")

         response = client.patch(
             f"/api/v1/orgs/{test_org.id}/team-role-definitions/{role.id}",
@@ -107,9 +102,7 @@ def test_delete_team_role_definition(
         self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
     ):
         """Test deleting a team role definition."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Temporary Role", "To be deleted"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Temporary Role", "To be deleted")

         response = client.delete(
             f"/api/v1/orgs/{test_org.id}/team-role-definitions/{role.id}",
@@ -151,9 +144,7 @@ def test_reset_non_default_role_fails(
         self, client: TestClient, db: Session, test_user: User, test_org: Organization, auth_headers
     ):
         """Test that resetting a non-default role fails."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Custom Role", "Custom description"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Custom Role", "Custom description")

         response = client.post(
             f"/api/v1/orgs/{test_org.id}/team-role-definitions/{role.id}/reset",
@@ -162,9 +153,7 @@

         assert response.status_code == 400

-    def test_viewer_can_list_definitions(
-        self, client: TestClient, db: Session, test_org: Organization, auth_headers
-    ):
+    def test_viewer_can_list_definitions(self, client: TestClient, db: Session, test_org: Organization, auth_headers):
         """Test that VIEWER can list role definitions."""
         viewer = UserService.create_user(
             db=db,
@@ -270,7 +259,13 @@ def architect_role(self, db: Session, test_org: Organization) -> TeamRoleDefinit
         )

     def test_get_project_team_empty(
-        self, client: TestClient, db: Session, test_user: User, team_project: Project, engineer_role: TeamRoleDefinition, auth_headers
+        self,
+        client: TestClient,
+        db: Session,
+        test_user: User,
+        team_project: Project,
+        engineer_role: TeamRoleDefinition,
+        auth_headers,
     ):
         """Test getting project team with no assignments."""
         response = client.get(
@@ -515,9 +510,7 @@ class TestTeamRoleService:

     def test_create_role_definition(self, db: Session, test_org: Organization):
         """Test creating a role definition."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "DevOps Team", "Manage infrastructure"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "DevOps Team", "Manage infrastructure")

         assert role.org_id == test_org.id
         assert role.title == "DevOps Team"
@@ -528,30 +521,23 @@ def test_create_role_definition(self, db: Session, test_org: Organization):
     def test_create_role_definition_with_custom_key(self, db: Session, test_org: Organization):
         """Test creating a role with a custom key."""
         role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Senior Engineers", "Senior developers",
-            role_key="senior_devs"
+            db, test_org.id, "Senior Engineers", "Senior developers", role_key="senior_devs"
         )

         assert role.role_key == "senior_devs"

     def test_update_role_definition(self, db: Session, test_org: Organization):
         """Test updating a role definition."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Engineers", "Write code"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Engineers", "Write code")

-        updated = TeamRoleService.update_role_definition(
-            db, role.id, "Dev Team", "Developers"
-        )
+        updated = TeamRoleService.update_role_definition(db, role.id, "Dev Team", "Developers")

         assert updated.title == "Dev Team"
         assert updated.description == "Developers"

     def test_delete_role_definition(self, db: Session, test_org: Organization):
         """Test deleting a role definition."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Temp Role", "Temporary"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Temp Role", "Temporary")
         role_id = role.id

         TeamRoleService.delete_role_definition(db, role_id)
@@ -571,40 +557,24 @@ def test_reset_role_to_default(self, db: Session, test_org: Organization):
         assert reset.title == "Product Owners"
         assert reset.description == "Define product vision and prioritize features"

-    def test_assign_user_to_role(
-        self, db: Session, test_user: User, test_project: Project, test_org: Organization
-    ):
+    def test_assign_user_to_role(self, db: Session, test_user: User, test_project: Project, test_org: Organization):
         """Test assigning a user to a role."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Architects", "Design systems"
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Architects", "Design systems")

-        assignment = TeamRoleService.assign_user_to_role(
-            db, test_project.id, test_user.id, role.id, test_user.id
-        )
+        assignment = TeamRoleService.assign_user_to_role(db, test_project.id, test_user.id, role.id, test_user.id)

         assert assignment.project_id == test_project.id
         assert assignment.user_id == test_user.id
         assert assignment.role_definition_id == role.id
         assert assignment.assigned_by == test_user.id

-    def test_get_project_team(
-        self, db: Session, test_user: User, test_project: Project, test_org: Organization
-    ):
+    def test_get_project_team(self, db: Session, test_user: User, test_project: Project, test_org: Organization):
         """Test getting project team with assignments."""
-        role1 = TeamRoleService.create_role_definition(
-            db, test_org.id, "Engineers", "Write code"
-        )
-        role2 = TeamRoleService.create_role_definition(
-            db, test_org.id, "Architects", "Design systems"
-        )
+        role1 = TeamRoleService.create_role_definition(db, test_org.id, "Engineers", "Write code")
+        role2 = TeamRoleService.create_role_definition(db, test_org.id, "Architects", "Design systems")

-        TeamRoleService.assign_user_to_role(
-            db, test_project.id, test_user.id, role1.id, test_user.id
-        )
-        TeamRoleService.assign_user_to_role(
-            db, test_project.id, test_user.id, role2.id, test_user.id
-        )
+        TeamRoleService.assign_user_to_role(db, test_project.id, test_user.id, role1.id, test_user.id)
+        TeamRoleService.assign_user_to_role(db, test_project.id, test_user.id, role2.id, test_user.id)

         team = TeamRoleService.get_project_team(db, test_project.id)

@@ -614,16 +584,10 @@ def test_get_project_team(
         assert len(team[role2.id]) == 1
         assert team[role1.id][0].user_id == test_user.id

-    def test_remove_assignment_by_id(
-        self, db: Session, test_user: User, test_project: Project, test_org: Organization
-    ):
+    def test_remove_assignment_by_id(self, db: Session, test_user: User, test_project: Project, test_org: Organization):
         """Test removing a team assignment by ID."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "QA Engineers", "Test things"
-        )
-        assignment = TeamRoleService.assign_user_to_role(
-            db, test_project.id, test_user.id, role.id, test_user.id
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "QA Engineers", "Test things")
+        assignment = TeamRoleService.assign_user_to_role(db, test_project.id, test_user.id, role.id, test_user.id)

         TeamRoleService.remove_assignment_by_id(db, assignment.id)

@@ -632,14 +596,10 @@ def test_remove_assignment_by_id(

     def test_generate_unique_role_key(self, db: Session, test_org: Organization):
         """Test generating unique role keys."""
-        TeamRoleService.create_role_definition(
-            db, test_org.id, "Data Scientists", "Analyze data"
-        )
+        TeamRoleService.create_role_definition(db, test_org.id, "Data Scientists", "Analyze data")

         # Create another role with the same title - should get unique key
-        role2 = TeamRoleService.create_role_definition(
-            db, test_org.id, "Data Scientists", "More data scientists"
-        )
+        role2 = TeamRoleService.create_role_definition(db, test_org.id, "Data Scientists", "More data scientists")

         assert role2.role_key == "data_scientists_2"

@@ -647,12 +607,8 @@ def test_delete_role_cascades_assignments(
         self, db: Session, test_user: User, test_project: Project, test_org: Organization
     ):
         """Test that deleting a role cascades to assignments."""
-        role = TeamRoleService.create_role_definition(
-            db, test_org.id, "Temp Role", "Temporary"
-        )
-        assignment = TeamRoleService.assign_user_to_role(
-            db, test_project.id, test_user.id, role.id, test_user.id
-        )
+        role = TeamRoleService.create_role_definition(db, test_org.id, "Temp Role", "Temporary")
+        assignment = TeamRoleService.assign_user_to_role(db, test_project.id, test_user.id, role.id, test_user.id)
         assignment_id = assignment.id

         TeamRoleService.delete_role_definition(db, role.id)
@@ -700,14 +656,14 @@ def test_seed_default_roles_skips_existing(self, db: Session):
         """Test that seed_default_roles skips existing roles with matching keys."""
         # Create a raw org without using create_org_with_owner (which auto-seeds)
         from app.models import Organization
+
         raw_org = Organization(name="Raw Test Org")
         db.add(raw_org)
         db.flush()

         # Create one of the default roles manually with a custom title
         custom_role = TeamRoleService.create_role_definition(
-            db, raw_org.id, "Custom Engineers Title", "Custom description",
-            role_key="engineers", is_default=False
+            db, raw_org.id, "Custom Engineers Title", "Custom description", role_key="engineers", is_default=False
         )

         # Seed defaults
diff --git a/backend/tests/test_thread_decision_summary.py b/backend/tests/test_thread_decision_summary.py
index 50927a4..685978c 100644
--- a/backend/tests/test_thread_decision_summary.py
+++ b/backend/tests/test_thread_decision_summary.py
@@ -1,12 +1,10 @@
 """Tests for thread decision summary functionality."""
-import pytest
-from unittest.mock import patch, MagicMock
+
 from datetime import datetime, timezone
 from uuid import uuid4

-from app.models.thread import Thread, ContextType
+from app.models.thread import ContextType, Thread
 from app.models.thread_item import ThreadItem, ThreadItemType
-from app.models.job import Job, JobType, JobStatus
 from app.services.thread_service import ThreadService
diff --git a/backend/tests/test_thread_endpoints.py b/backend/tests/test_thread_endpoints.py
index d5199e1..75b32c4 100644
--- a/backend/tests/test_thread_endpoints.py
+++ b/backend/tests/test_thread_endpoints.py
@@ -1,8 +1,9 @@
 """Tests for thread and comment endpoints."""
-import pytest
+
 from fastapi.testclient import TestClient
 from sqlalchemy.orm import Session
-from app.models import User, Project, ContextType
+
+from app.models import ContextType, Project, User
 from app.services.thread_service import ThreadService
@@ -14,13 +15,15 @@ def test_list_project_threads_requires_auth(self, client: TestClient, test_proje
         response = client.get(f"/api/v1/projects/{test_project.id}/threads")
         assert response.status_code == 401

-    def test_list_project_threads_requires_membership(self, client: TestClient, auth_headers, db: Session, test_user: User):
+    def test_list_project_threads_requires_membership(
+        self, client: TestClient, auth_headers, db: Session, test_user: User
+    ):
         """Test that listing threads requires project membership (404 for non-members)."""
         # Create a different user and project
-        from app.services.user_service import UserService
+        from app.models import ProjectType
         from app.services.org_service import OrgService
         from app.services.project_service import ProjectService
-        from app.models import ProjectType
+        from app.services.user_service import UserService

         other_user = UserService.create_user(db, "other@example.com", "password", "Other User")
         other_org, _ = OrgService.create_org_with_owner(db, "Other Org", other_user.id)
@@ -33,7 +36,9 @@ def test_list_project_threads_requires_membership(self, client: TestClient, auth
         response = client.get(f"/api/v1/projects/{other_project.id}/threads", headers=headers)
         assert response.status_code == 404

-    def test_list_project_threads_success(self, client: TestClient, auth_headers, test_user: User, test_project: Project, db: Session):
+    def test_list_project_threads_success(
+        self, client: TestClient, auth_headers, test_user: User, test_project: Project, db: Session
+    ):
         """Test successfully listing threads for a project."""
         # Create some threads
         ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread 1", str(test_user.id))
@@ -45,7 +50,9 @@ def test_list_project_threads_success(self, client: TestClient, auth_headers, te
         data = response.json()
         assert len(data) == 2

-    def test_list_project_threads_filter_by_context_type(self, client: TestClient, auth_headers, test_user: User, test_project: Project, db: Session):
+    def test_list_project_threads_filter_by_context_type(
+        self, client: TestClient, auth_headers, test_user: User, test_project: Project, db: Session
+    ):
         """Test filtering threads by context type."""
         ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "General", str(test_user.id))
         ThreadService.create_thread(db, str(test_project.id), ContextType.SPEC, "spec-1", "Spec", str(test_user.id))
@@ -60,18 +67,17 @@ def test_list_project_threads_filter_by_context_type(self, client: TestClient, a
     def test_create_thread_requires_auth(self, client: TestClient, test_project: Project):
         """Test that creating a thread requires authentication."""
         response = client.post(
-            f"/api/v1/projects/{test_project.id}/threads",
-            json={"context_type": "general", "title": "Test"}
+            f"/api/v1/projects/{test_project.id}/threads", json={"context_type": "general", "title": "Test"}
         )
         assert response.status_code == 401

     def test_create_thread_requires_membership(self, client: TestClient, auth_headers, db: Session, test_user: User):
         """Test that creating a thread requires project membership (404 for non-members)."""
         # Create a different user and project
-        from app.services.user_service import UserService
+        from app.models import ProjectType
         from app.services.org_service import OrgService
         from app.services.project_service import ProjectService
-        from app.models import ProjectType
+        from app.services.user_service import UserService

         other_user = UserService.create_user(db, "other2@example.com", "password", "Other User 2")
         other_org, _ = OrgService.create_org_with_owner(db, "Other Org 2", other_user.id)
@@ -82,7 +88,7 @@ def test_create_thread_requires_membership(self, client: TestClient, auth_header
         response = client.post(
             f"/api/v1/projects/{other_project.id}/threads",
             json={"context_type": "general", "title": "Test"},
-            headers=auth_headers(test_user)
+            headers=auth_headers(test_user),
         )
         assert response.status_code == 404

@@ -91,7 +97,7 @@ def test_create_thread_success(self, client: TestClient, auth_headers, test_user
         response = client.post(
             f"/api/v1/projects/{test_project.id}/threads",
             json={"context_type": "general", "title": "New Thread"},
-            headers=auth_headers(test_user)
+            headers=auth_headers(test_user),
         )
         assert response.status_code == 201
         data = response.json()
@@ -103,8 +109,12 @@ def test_create_thread_with_context(self, client: TestClient, auth_headers, test
         """Test creating a thread with context_id."""
         response = client.post(
             f"/api/v1/projects/{test_project.id}/threads",
-            json={"context_type": "brainstorm_feature", "context_id": "feature-123", "title": "Feature discussion thread"},
-            headers=auth_headers(test_user)
+            json={
+                "context_type": "brainstorm_feature",
+                "context_id": "feature-123",
+                "title": "Feature discussion thread",
+            },
+            headers=auth_headers(test_user),
         )
         assert response.status_code == 201
         data = response.json()
@@ -113,13 +123,19 @@ def test_create_thread_with_context(self, client: TestClient, auth_headers, test

     def test_get_thread_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project):
         """Test that getting a thread requires authentication."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)
+        )
         response = client.get(f"/api/v1/threads/{thread.id}")
         assert response.status_code == 401

-    def test_get_thread_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project):
+    def test_get_thread_success(
+        self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project
+    ):
         """Test successfully getting a thread with its comments."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)
+        )
         ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment 1")
         ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment 2")
@@ -136,14 +152,22 @@ def test_get_thread_not_found(self, client: TestClient, auth_headers, test_user:

     def test_update_thread_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project):
         """Test that updating a thread requires authentication."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)
+        )
         response = client.patch(f"/api/v1/threads/{thread.id}", json={"title": "Updated"})
         assert response.status_code == 401

-    def test_update_thread_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project):
+    def test_update_thread_success(
+        self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project
+    ):
         """Test successfully updating a thread title."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Old Title", str(test_user.id))
-        response = client.patch(f"/api/v1/threads/{thread.id}", json={"title": "New Title"}, headers=auth_headers(test_user))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Old Title", str(test_user.id)
+        )
+        response = client.patch(
+            f"/api/v1/threads/{thread.id}", json={"title": "New Title"}, headers=auth_headers(test_user)
+        )
         assert response.status_code == 200
         data = response.json()
         assert data["title"] == "New Title"
@@ -152,19 +176,27 @@ def test_update_thread_success(self, client: TestClient, auth_headers, db: Sessi

 class TestCommentEndpoints:
     """Test comment REST endpoints."""

-    def test_create_comment_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project):
+    def test_create_comment_requires_auth(
+        self, client: TestClient, db: Session, test_user: User, test_project: Project
+    ):
         """Test that creating a comment requires authentication."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)
+        )
         response = client.post(f"/api/v1/threads/{thread.id}/comments", json={"body_markdown": "Comment"})
         assert response.status_code == 401

-    def test_create_comment_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project):
+    def test_create_comment_success(
+        self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project
+    ):
         """Test successfully creating a comment."""
-        thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id))
+        thread = ThreadService.create_thread(
+            db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)
+        )
         response = client.post(
             f"/api/v1/threads/{thread.id}/comments",
             json={"body_markdown": "This is a comment."},
-            headers=auth_headers(test_user)
+            headers=auth_headers(test_user),
         )
         assert response.status_code == 201
         data = response.json()
@@ -174,42 +206,54 @@ def test_create_comment_success(self, client: TestClient, auth_headers, db: Sess def test_create_comment_invalid_thread(self, client: TestClient, auth_headers, test_user: User): """Test creating a comment on a non-existent thread returns 404.""" response = client.post( - "/api/v1/threads/nonexistent/comments", - json={"body_markdown": "Comment"}, - headers=auth_headers(test_user) + "/api/v1/threads/nonexistent/comments", json={"body_markdown": "Comment"}, headers=auth_headers(test_user) ) assert response.status_code == 404 - def test_update_comment_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project): + def test_update_comment_requires_auth( + self, client: TestClient, db: Session, test_user: User, test_project: Project + ): """Test that updating a comment requires authentication.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Original") response = client.patch(f"/api/v1/comments/{comment.id}", json={"body_markdown": "Updated"}) assert response.status_code == 401 - def test_update_comment_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project): + def test_update_comment_success( + self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project + ): """Test successfully updating a comment.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Original") response = client.patch( - f"/api/v1/comments/{comment.id}", - json={"body_markdown": "Updated text"}, - headers=auth_headers(test_user) + f"/api/v1/comments/{comment.id}", json={"body_markdown": "Updated text"}, headers=auth_headers(test_user) ) assert response.status_code == 200 data = response.json() assert data["body_markdown"] == "Updated text" - def test_delete_comment_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project): + def test_delete_comment_requires_auth( + self, client: TestClient, db: Session, test_user: User, test_project: Project + ): """Test that deleting a comment requires authentication.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment") response = client.delete(f"/api/v1/comments/{comment.id}") assert response.status_code == 401 - def test_delete_comment_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project): + def test_delete_comment_success( + self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project + ): """Test successfully deleting a comment.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), 
ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment") response = client.delete(f"/api/v1/comments/{comment.id}", headers=auth_headers(test_user)) assert response.status_code == 204 @@ -228,52 +272,62 @@ def test_delete_comment_not_found(self, client: TestClient, auth_headers, test_u class TestTypingIndicatorEndpoint: """Test typing indicator endpoint.""" - def test_typing_indicator_requires_auth(self, client: TestClient, db: Session, test_user: User, test_project: Project): + def test_typing_indicator_requires_auth( + self, client: TestClient, db: Session, test_user: User, test_project: Project + ): """Test that sending typing indicator requires authentication.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) response = client.post(f"/api/v1/threads/{thread.id}/typing", json={"typing": True}) assert response.status_code == 401 def test_typing_indicator_thread_not_found(self, client: TestClient, auth_headers, test_user: User): """Test typing indicator returns 404 for non-existent thread.""" response = client.post( - "/api/v1/threads/nonexistent-thread-id/typing", - json={"typing": True}, - headers=auth_headers(test_user) + "/api/v1/threads/nonexistent-thread-id/typing", json={"typing": True}, headers=auth_headers(test_user) ) assert response.status_code == 404 - def test_typing_indicator_requires_project_membership(self, client: TestClient, auth_headers, db: Session, test_user: User): + def test_typing_indicator_requires_project_membership( + self, client: TestClient, auth_headers, db: Session, test_user: User + ): """Test that typing indicator requires project membership (404 for non-members).""" - from app.services.user_service import UserService + from app.models import ProjectType from app.services.org_service import OrgService from app.services.project_service import ProjectService - from app.models import ProjectType + from app.services.user_service import UserService other_user = UserService.create_user(db, "typing-other@example.com", "password", "Other User") other_org, _ = OrgService.create_org_with_owner(db, "Typing Other Org", other_user.id) other_project = ProjectService.create_project( db, other_org.id, other_user.id, ProjectType.APPLICATION, "Typing Other Project", None, None ) - thread = ThreadService.create_thread(db, str(other_project.id), ContextType.GENERAL, None, "Thread", str(other_user.id)) + thread = ThreadService.create_thread( + db, str(other_project.id), ContextType.GENERAL, None, "Thread", str(other_user.id) + ) # Try to access with test_user's credentials (not a member) response = client.post( - f"/api/v1/threads/{thread.id}/typing", - json={"typing": True}, - headers=auth_headers(test_user) + f"/api/v1/threads/{thread.id}/typing", json={"typing": True}, headers=auth_headers(test_user) ) assert response.status_code == 404 - def test_typing_indicator_start_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project): + def test_typing_indicator_start_success( + self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project + ): """Test successfully sending typing start indicator.""" from unittest.mock import patch - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, 
"Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) # Mock the Redis client to avoid needing a real Redis connection - with patch("app.services.typing_indicator_service.get_redis_client") as mock_redis, \ - patch("app.services.typing_indicator_service.get_sync_kafka_producer") as mock_kafka: + with ( + patch("app.services.typing_indicator_service.get_redis_client") as mock_redis, + patch("app.services.typing_indicator_service.get_sync_kafka_producer") as mock_kafka, + ): mock_client = mock_redis.return_value mock_client.scan_iter.return_value = [] @@ -281,22 +335,26 @@ def test_typing_indicator_start_success(self, client: TestClient, auth_headers, mock_producer.publish.return_value = True response = client.post( - f"/api/v1/threads/{thread.id}/typing", - json={"typing": True}, - headers=auth_headers(test_user) + f"/api/v1/threads/{thread.id}/typing", json={"typing": True}, headers=auth_headers(test_user) ) assert response.status_code == 204 - def test_typing_indicator_stop_success(self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project): + def test_typing_indicator_stop_success( + self, client: TestClient, auth_headers, db: Session, test_user: User, test_project: Project + ): """Test successfully sending typing stop indicator.""" from unittest.mock import patch - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) # Mock the Redis client - with patch("app.services.typing_indicator_service.get_redis_client") as mock_redis, \ - patch("app.services.typing_indicator_service.get_sync_kafka_producer") as mock_kafka: + with ( + patch("app.services.typing_indicator_service.get_redis_client") as mock_redis, + patch("app.services.typing_indicator_service.get_sync_kafka_producer") as mock_kafka, + ): mock_client = mock_redis.return_value mock_client.delete.return_value = 0 mock_client.scan_iter.return_value = [] @@ -305,9 +363,7 @@ def test_typing_indicator_stop_success(self, client: TestClient, auth_headers, d mock_producer.publish.return_value = True response = client.post( - f"/api/v1/threads/{thread.id}/typing", - json={"typing": False}, - headers=auth_headers(test_user) + f"/api/v1/threads/{thread.id}/typing", json={"typing": False}, headers=auth_headers(test_user) ) assert response.status_code == 204 diff --git a/backend/tests/test_thread_item_deletion.py b/backend/tests/test_thread_item_deletion.py index a52f766..0e69523 100644 --- a/backend/tests/test_thread_item_deletion.py +++ b/backend/tests/test_thread_item_deletion.py @@ -1,8 +1,10 @@ """Tests for thread item deletion and start-over functionality.""" + import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Project, Thread, ContextType + +from app.models import ContextType, Project, Thread, User from app.models.thread_item import ThreadItem, ThreadItemType from app.services.thread_service import ThreadService @@ -85,8 +87,13 @@ def test_delete_success( assert deleted is None def test_delete_resets_decision_summary( - self, client: TestClient, auth_headers, db: Session, test_user: User, - test_thread: Thread, test_thread_items: list[ThreadItem] + self, + client: TestClient, + auth_headers, + db: Session, + test_user: User, + test_thread: Thread, + 
test_thread_items: list[ThreadItem], ): """Test that deleting an item resets decision summary fields.""" # Set decision summary on thread @@ -140,8 +147,13 @@ def test_start_over_author_only( assert response.status_code == 404 # 404 for privacy (not member of project) def test_start_over_deletes_item_and_all_after( - self, client: TestClient, auth_headers, db: Session, test_user: User, - test_thread: Thread, test_thread_items: list[ThreadItem] + self, + client: TestClient, + auth_headers, + db: Session, + test_user: User, + test_thread: Thread, + test_thread_items: list[ThreadItem], ): """Test that start-over deletes the target item and all items after it.""" headers = auth_headers(test_user) @@ -179,8 +191,13 @@ def test_start_over_returns_deleted_ids( assert len(data["deleted_item_ids"]) == 3 def test_start_over_resets_decision_summary( - self, client: TestClient, auth_headers, db: Session, test_user: User, - test_thread: Thread, test_thread_items: list[ThreadItem] + self, + client: TestClient, + auth_headers, + db: Session, + test_user: User, + test_thread: Thread, + test_thread_items: list[ThreadItem], ): """Test that start-over resets all decision summary fields.""" # Set decision summary on thread @@ -204,9 +221,7 @@ def test_start_over_resets_decision_summary( class TestStartOverService: """Tests for ThreadService.start_over_from_item().""" - def test_start_over_permission_error( - self, db: Session, test_thread: Thread, test_thread_items: list[ThreadItem] - ): + def test_start_over_permission_error(self, db: Session, test_thread: Thread, test_thread_items: list[ThreadItem]): """Test that PermissionError is raised if user is not the author.""" from app.services.user_service import UserService @@ -262,6 +277,7 @@ def answered_mcq_with_downstream( ) -> tuple[ThreadItem, list[ThreadItem]]: """Create an answered MCQ with downstream items.""" import time + from sqlalchemy.orm.attributes import flag_modified # Answer the MCQ @@ -291,9 +307,7 @@ def answered_mcq_with_downstream( return mcq_item, downstream - def test_get_downstream_count_with_no_downstream( - self, db: Session, mcq_item: ThreadItem - ): + def test_get_downstream_count_with_no_downstream(self, db: Session, mcq_item: ThreadItem): """Test that get_downstream_items_count returns 0 for MCQ with no downstream.""" result = ThreadService.get_downstream_items_count(db, str(mcq_item.id)) assert result["has_downstream"] is False @@ -309,25 +323,23 @@ def test_get_downstream_count_with_downstream( assert result["downstream_count"] == 2 def test_downstream_count_endpoint( - self, client: TestClient, auth_headers, test_user: User, - answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] + self, + client: TestClient, + auth_headers, + test_user: User, + answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]], ): """Test GET /thread-items/{item_id}/downstream-count endpoint.""" mcq_item, _ = answered_mcq_with_downstream headers = auth_headers(test_user) - response = client.get( - f"/api/v1/thread-items/{mcq_item.id}/downstream-count", - headers=headers - ) + response = client.get(f"/api/v1/thread-items/{mcq_item.id}/downstream-count", headers=headers) assert response.status_code == 200 data = response.json() assert data["has_downstream"] is True assert data["downstream_count"] == 2 - def test_answer_mcq_first_time_service( - self, db: Session, mcq_item: ThreadItem - ): + def test_answer_mcq_first_time_service(self, db: Session, mcq_item: ThreadItem): """Test answering an MCQ for the first time via service.""" result = 
ThreadService.answer_mcq_item( db=db, @@ -339,9 +351,7 @@ def test_answer_mcq_first_time_service( assert result.content_data["selected_option_id"] == "choice_1" assert result.content_data["free_text"] == "Some details" - def test_answer_mcq_no_downstream_skips_mcq_answer_item( - self, db: Session, mcq_item: ThreadItem - ): + def test_answer_mcq_no_downstream_skips_mcq_answer_item(self, db: Session, mcq_item: ThreadItem): """Test that MCQ_ANSWER item is NOT created when there are no downstream items.""" from app.models.thread_item import ThreadItemType @@ -352,9 +362,7 @@ def test_answer_mcq_no_downstream_skips_mcq_answer_item( ) # Should only have the original MCQ item - no MCQ_ANSWER created - items = db.query(ThreadItem).filter( - ThreadItem.thread_id == mcq_item.thread_id - ).all() + items = db.query(ThreadItem).filter(ThreadItem.thread_id == mcq_item.thread_id).all() assert len(items) == 1 assert items[0].item_type == ThreadItemType.MCQ_FOLLOWUP @@ -381,18 +389,19 @@ def test_answer_mcq_with_downstream_creates_mcq_answer_item( ) # Should have MCQ + comment + MCQ_ANSWER - items = db.query(ThreadItem).filter( - ThreadItem.thread_id == mcq_item.thread_id - ).order_by(ThreadItem.created_at).all() + items = ( + db.query(ThreadItem) + .filter(ThreadItem.thread_id == mcq_item.thread_id) + .order_by(ThreadItem.created_at) + .all() + ) assert len(items) == 3 assert items[0].item_type == ThreadItemType.MCQ_FOLLOWUP assert items[1].item_type == ThreadItemType.COMMENT assert items[2].item_type == ThreadItemType.MCQ_ANSWER assert items[2].content_data["selected_option_id"] == "choice_1" - def test_change_answer_no_downstream_no_force_service( - self, db: Session, mcq_item: ThreadItem - ): + def test_change_answer_no_downstream_no_force_service(self, db: Session, mcq_item: ThreadItem): """Test changing MCQ answer when no downstream items - no force_change needed.""" # First answer the MCQ mcq_item.content_data["selected_option_id"] = "choice_1" @@ -408,8 +417,7 @@ def test_change_answer_no_downstream_no_force_service( assert result.content_data["selected_option_id"] == "choice_2" def test_change_answer_with_downstream_no_force_raises_409( - self, db: Session, - answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] + self, db: Session, answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] ): """Test that changing answer with downstream items and no force_change raises HTTPException.""" from fastapi import HTTPException @@ -429,14 +437,11 @@ def test_change_answer_with_downstream_no_force_raises_409( assert exc_info.value.detail["downstream_count"] == 2 # Verify downstream items are NOT deleted - remaining = db.query(ThreadItem).filter( - ThreadItem.thread_id == mcq_item.thread_id - ).count() + remaining = db.query(ThreadItem).filter(ThreadItem.thread_id == mcq_item.thread_id).count() assert remaining == 3 # MCQ + 2 downstream def test_change_answer_with_downstream_force_deletes( - self, db: Session, test_thread: Thread, - answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] + self, db: Session, test_thread: Thread, answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] ): """Test that changing answer with force_change deletes downstream items.""" from app.models.thread_item import ThreadItemType @@ -457,16 +462,18 @@ def test_change_answer_with_downstream_force_deletes( # Verify downstream items are deleted - only MCQ remains # No MCQ_ANSWER is created because there are no downstream items after deletion - remaining = db.query(ThreadItem).filter( - 
ThreadItem.thread_id == mcq_item.thread_id - ).order_by(ThreadItem.created_at).all() + remaining = ( + db.query(ThreadItem) + .filter(ThreadItem.thread_id == mcq_item.thread_id) + .order_by(ThreadItem.created_at) + .all() + ) assert len(remaining) == 1 assert remaining[0].id == mcq_item.id assert remaining[0].item_type == ThreadItemType.MCQ_FOLLOWUP def test_change_answer_with_force_resets_decision_summary( - self, db: Session, test_thread: Thread, - answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] + self, db: Session, test_thread: Thread, answered_mcq_with_downstream: tuple[ThreadItem, list[ThreadItem]] ): """Test that force_change resets decision summary fields.""" mcq_item, _ = answered_mcq_with_downstream @@ -493,9 +500,7 @@ class TestSnapshotBasedDeletion: """Tests for efficient deletion using per-item summary snapshots.""" @pytest.fixture - def thread_with_snapshots( - self, db: Session, test_thread: Thread, test_user: User - ) -> list[ThreadItem]: + def thread_with_snapshots(self, db: Session, test_thread: Thread, test_user: User) -> list[ThreadItem]: """Create thread items with summary snapshots set.""" import time @@ -526,8 +531,13 @@ def thread_with_snapshots( return items def test_delete_middle_item_restores_from_snapshot( - self, db: Session, test_thread: Thread, test_user: User, - thread_with_snapshots: list[ThreadItem], client: TestClient, auth_headers + self, + db: Session, + test_thread: Thread, + test_user: User, + thread_with_snapshots: list[ThreadItem], + client: TestClient, + auth_headers, ): """Test that deleting a middle item restores summary from previous item's snapshot.""" items = thread_with_snapshots @@ -544,8 +554,13 @@ def test_delete_middle_item_restores_from_snapshot( assert test_thread.last_summarized_item_id == str(items[0].id) def test_delete_last_item_restores_from_snapshot( - self, db: Session, test_thread: Thread, test_user: User, - thread_with_snapshots: list[ThreadItem], client: TestClient, auth_headers + self, + db: Session, + test_thread: Thread, + test_user: User, + thread_with_snapshots: list[ThreadItem], + client: TestClient, + auth_headers, ): """Test that deleting the last item restores from second-to-last snapshot.""" items = thread_with_snapshots @@ -562,8 +577,13 @@ def test_delete_last_item_restores_from_snapshot( assert test_thread.last_summarized_item_id == str(items[1].id) def test_delete_first_item_falls_back_to_reset( - self, db: Session, test_thread: Thread, test_user: User, - thread_with_snapshots: list[ThreadItem], client: TestClient, auth_headers + self, + db: Session, + test_thread: Thread, + test_user: User, + thread_with_snapshots: list[ThreadItem], + client: TestClient, + auth_headers, ): """Test that deleting the first item (no previous) resets summary to None.""" items = thread_with_snapshots @@ -580,8 +600,13 @@ def test_delete_first_item_falls_back_to_reset( assert test_thread.last_summarized_item_id is None def test_start_over_with_snapshot_restores_correctly( - self, db: Session, test_thread: Thread, test_user: User, - thread_with_snapshots: list[ThreadItem], client: TestClient, auth_headers + self, + db: Session, + test_thread: Thread, + test_user: User, + thread_with_snapshots: list[ThreadItem], + client: TestClient, + auth_headers, ): """Test that start-over restores from previous item's snapshot.""" items = thread_with_snapshots @@ -601,8 +626,13 @@ def test_start_over_with_snapshot_restores_correctly( assert test_thread.last_summarized_item_id == str(items[0].id) def 
test_start_over_from_first_item_resets( - self, db: Session, test_thread: Thread, test_user: User, - thread_with_snapshots: list[ThreadItem], client: TestClient, auth_headers + self, + db: Session, + test_thread: Thread, + test_user: User, + thread_with_snapshots: list[ThreadItem], + client: TestClient, + auth_headers, ): """Test that start-over from first item (no previous) resets summary.""" items = thread_with_snapshots @@ -622,8 +652,7 @@ def test_start_over_from_first_item_resets( assert test_thread.last_summarized_item_id is None def test_start_over_resets_button_visibility_and_suggested_name( - self, db: Session, test_thread: Thread, test_user: User, - client: TestClient, auth_headers + self, db: Session, test_thread: Thread, test_user: User, client: TestClient, auth_headers ): """Test that start-over correctly resets button visibility and suggested name. @@ -691,15 +720,13 @@ class TestStartOverWithImplementations: """Tests for start-over functionality that also deletes implementations.""" @pytest.fixture - def feature_with_thread_and_implementations( - self, db: Session, test_project: Project, test_user: User - ): + def feature_with_thread_and_implementations(self, db: Session, test_project: Project, test_user: User): """Create a feature with a thread containing implementation markers.""" import time - from app.models.feature import Feature - from app.models.module import Module, ModuleType, ModuleProvenance - from app.models.feature import FeatureProvenance + + from app.models.feature import Feature, FeatureProvenance from app.models.implementation import Implementation + from app.models.module import Module, ModuleProvenance, ModuleType # Create module module = Module( @@ -828,11 +855,9 @@ def feature_with_thread_and_implementations( } def test_start_over_deletes_implementation_in_range( - self, db: Session, test_user: User, - feature_with_thread_and_implementations, client: TestClient, auth_headers + self, db: Session, test_user: User, feature_with_thread_and_implementations, client: TestClient, auth_headers ): """Test that implementation is deleted when its marker is in range.""" - from uuid import UUID as PyUUID from app.models.implementation import Implementation data = feature_with_thread_and_implementations @@ -845,10 +870,7 @@ def test_start_over_deletes_implementation_in_range( comment2_id = str(data["comment2"].id) # Start over from impl2_marker (should delete impl2_marker, comment2, and impl2) - response = client.post( - f"/api/v1/thread-items/{impl2_marker_id}/start-over", - headers=headers - ) + response = client.post(f"/api/v1/thread-items/{impl2_marker_id}/start-over", headers=headers) assert response.status_code == 200 result = response.json() @@ -859,20 +881,15 @@ def test_start_over_deletes_implementation_in_range( assert str(impl2_id) in result["deleted_implementation_ids"] # Verify impl2 is deleted from database - impl2_check = db.query(Implementation).filter( - Implementation.id == impl2_id - ).first() + impl2_check = db.query(Implementation).filter(Implementation.id == impl2_id).first() assert impl2_check is None # Verify impl1 still exists - impl1_check = db.query(Implementation).filter( - Implementation.id == impl1_id - ).first() + impl1_check = db.query(Implementation).filter(Implementation.id == impl1_id).first() assert impl1_check is not None def test_start_over_returns_deleted_implementation_ids( - self, db: Session, test_user: User, - feature_with_thread_and_implementations, client: TestClient, auth_headers + self, db: Session, test_user: User, 
feature_with_thread_and_implementations, client: TestClient, auth_headers ): """Test that response includes deleted implementation IDs.""" data = feature_with_thread_and_implementations @@ -882,10 +899,7 @@ def test_start_over_returns_deleted_implementation_ids( impl2_id = str(data["impl2"].id) impl2_marker_id = str(data["impl2_marker"].id) - response = client.post( - f"/api/v1/thread-items/{impl2_marker_id}/start-over", - headers=headers - ) + response = client.post(f"/api/v1/thread-items/{impl2_marker_id}/start-over", headers=headers) assert response.status_code == 200 result = response.json() @@ -895,8 +909,7 @@ def test_start_over_returns_deleted_implementation_ids( assert impl2_id in result["deleted_implementation_ids"] def test_start_over_blocked_when_would_delete_all_implementations( - self, db: Session, test_user: User, - feature_with_thread_and_implementations, client: TestClient, auth_headers + self, db: Session, test_user: User, feature_with_thread_and_implementations, client: TestClient, auth_headers ): """Test that start-over is blocked if it would delete all implementations.""" from app.models.implementation import Implementation @@ -909,22 +922,16 @@ def test_start_over_blocked_when_would_delete_all_implementations( feature_id = data["feature"].id # Keep as UUID # Start over from impl1_marker (would delete both implementations) - response = client.post( - f"/api/v1/thread-items/{impl1_marker_id}/start-over", - headers=headers - ) + response = client.post(f"/api/v1/thread-items/{impl1_marker_id}/start-over", headers=headers) assert response.status_code == 400 assert "would delete all implementations" in response.json()["detail"] # Verify both implementations still exist - impls = db.query(Implementation).filter( - Implementation.feature_id == feature_id - ).all() + impls = db.query(Implementation).filter(Implementation.feature_id == feature_id).all() assert len(impls) == 2 def test_start_over_preserves_implementations_not_in_range( - self, db: Session, test_user: User, - feature_with_thread_and_implementations, client: TestClient, auth_headers + self, db: Session, test_user: User, feature_with_thread_and_implementations, client: TestClient, auth_headers ): """Test that implementations not in deleted range are preserved.""" from app.models.implementation import Implementation @@ -937,10 +944,7 @@ def test_start_over_preserves_implementations_not_in_range( feature_id = data["feature"].id # Keep as UUID # Start over from comment2 (only deletes comment2, no implementations) - response = client.post( - f"/api/v1/thread-items/{comment2_id}/start-over", - headers=headers - ) + response = client.post(f"/api/v1/thread-items/{comment2_id}/start-over", headers=headers) assert response.status_code == 200 result = response.json() @@ -948,14 +952,11 @@ def test_start_over_preserves_implementations_not_in_range( assert result["deleted_implementation_ids"] == [] # Both implementations should still exist - impls = db.query(Implementation).filter( - Implementation.feature_id == feature_id - ).all() + impls = db.query(Implementation).filter(Implementation.feature_id == feature_id).all() assert len(impls) == 2 def test_start_over_on_non_feature_thread_ignores_implementation_logic( - self, db: Session, test_project: Project, test_user: User, - client: TestClient, auth_headers + self, db: Session, test_project: Project, test_user: User, client: TestClient, auth_headers ): """Test that start-over on non-feature threads doesn't error.""" # Create a general thread (not associated with a feature) @@ 
-970,6 +971,7 @@ def test_start_over_on_non_feature_thread_ignores_implementation_logic( # Add some comments import time + items = [] for i in range(3): item = ThreadItem( @@ -985,10 +987,7 @@ def test_start_over_on_non_feature_thread_ignores_implementation_logic( time.sleep(0.01) headers = auth_headers(test_user) - response = client.post( - f"/api/v1/thread-items/{items[1].id}/start-over", - headers=headers - ) + response = client.post(f"/api/v1/thread-items/{items[1].id}/start-over", headers=headers) assert response.status_code == 200 result = response.json() diff --git a/backend/tests/test_thread_item_reactions.py b/backend/tests/test_thread_item_reactions.py index 2137bad..1bb8b6f 100644 --- a/backend/tests/test_thread_item_reactions.py +++ b/backend/tests/test_thread_item_reactions.py @@ -1,8 +1,10 @@ """Tests for thread item reactions functionality.""" + import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Project, Thread, ContextType + +from app.models import ContextType, Project, Thread, User from app.models.thread_item import ThreadItem, ThreadItemType from app.services.thread_service import ThreadService from app.services.user_service import UserService @@ -54,9 +56,7 @@ def test_mcq_item(db: Session, test_thread: Thread, test_user: User) -> ThreadIt class TestToggleReaction: """Tests for POST /thread-items/{item_id}/reactions endpoint.""" - def test_toggle_reaction_requires_auth( - self, client: TestClient, test_comment_item: ThreadItem - ): + def test_toggle_reaction_requires_auth(self, client: TestClient, test_comment_item: ThreadItem): """Test that toggling reactions requires authentication.""" response = client.post( f"/api/v1/thread-items/{test_comment_item.id}/reactions", @@ -64,9 +64,7 @@ def test_toggle_reaction_requires_auth( ) assert response.status_code == 401 - def test_toggle_reaction_item_not_found( - self, client: TestClient, auth_headers, test_user: User - ): + def test_toggle_reaction_item_not_found(self, client: TestClient, auth_headers, test_user: User): """Test 404 for non-existent item.""" headers = auth_headers(test_user) response = client.post( diff --git a/backend/tests/test_thread_service.py b/backend/tests/test_thread_service.py index fead7ba..7ed673a 100644 --- a/backend/tests/test_thread_service.py +++ b/backend/tests/test_thread_service.py @@ -1,9 +1,11 @@ """Tests for ThreadService.""" + import pytest from sqlalchemy.orm import Session + +from app.models import ContextType, Project, User +from app.models.thread_item import ThreadItemType from app.services.thread_service import ThreadService -from app.models import Thread, Comment, ContextType, User, Project -from app.models.thread_item import ThreadItem, ThreadItemType class TestThreadCRUD: @@ -17,7 +19,7 @@ def test_create_thread_general(self, db: Session, test_user: User, test_project: context_type=ContextType.GENERAL, context_id=None, title="General discussion", - created_by=str(test_user.id) + created_by=str(test_user.id), ) assert thread.id is not None assert thread.project_id == str(test_project.id) @@ -35,7 +37,7 @@ def test_create_thread_with_context(self, db: Session, test_user: User, test_pro context_type=ContextType.BRAINSTORM_FEATURE, context_id=context_id, title="Question about architecture", - created_by=str(test_user.id) + created_by=str(test_user.id), ) assert thread.context_type == ContextType.BRAINSTORM_FEATURE assert thread.context_id == context_id @@ -48,7 +50,7 @@ def test_get_thread_by_id(self, db: Session, 
test_user: User, test_project: Proj context_type=ContextType.GENERAL, context_id=None, title="Test thread", - created_by=str(test_user.id) + created_by=str(test_user.id), ) retrieved = ThreadService.get_thread_by_id(db, thread.id) assert retrieved is not None @@ -65,7 +67,9 @@ def test_list_project_threads(self, db: Session, test_user: User, test_project: # Create multiple threads ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread 1", str(test_user.id)) ThreadService.create_thread(db, str(test_project.id), ContextType.SPEC, "spec-1", "Thread 2", str(test_user.id)) - ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, "phase-1", "Thread 3", str(test_user.id)) + ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, "phase-1", "Thread 3", str(test_user.id) + ) threads = ThreadService.list_project_threads(db, str(test_project.id)) assert len(threads) == 3 @@ -82,9 +86,15 @@ def test_list_project_threads_filtered_by_context_type(self, db: Session, test_u def test_list_project_threads_filtered_by_context_id(self, db: Session, test_user: User, test_project: Project): """Test listing threads filtered by context ID.""" - ThreadService.create_thread(db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q1", "Q1", str(test_user.id)) - ThreadService.create_thread(db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q2", "Q2", str(test_user.id)) - ThreadService.create_thread(db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q1", "Q1 again", str(test_user.id)) + ThreadService.create_thread( + db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q1", "Q1", str(test_user.id) + ) + ThreadService.create_thread( + db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q2", "Q2", str(test_user.id) + ) + ThreadService.create_thread( + db, str(test_project.id), ContextType.BRAINSTORM_FEATURE, "q1", "Q1 again", str(test_user.id) + ) q1_threads = ThreadService.list_project_threads(db, str(test_project.id), context_id="q1") assert len(q1_threads) == 2 @@ -92,7 +102,9 @@ def test_list_project_threads_filtered_by_context_id(self, db: Session, test_use def test_update_thread_title(self, db: Session, test_user: User, test_project: Project): """Test updating a thread title.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Old Title", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Old Title", str(test_user.id) + ) updated = ThreadService.update_thread(db, thread.id, title="New Title") assert updated.title == "New Title" @@ -107,12 +119,11 @@ class TestCommentCRUD: def test_create_comment(self, db: Session, test_user: User, test_project: Project): """Test creating a comment on a thread.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment( - db=db, - thread_id=thread.id, - author_id=str(test_user.id), - body_markdown="This is a comment." + db=db, thread_id=thread.id, author_id=str(test_user.id), body_markdown="This is a comment." 
) assert comment.id is not None assert comment.thread_id == thread.id @@ -126,7 +137,9 @@ def test_create_comment_invalid_thread(self, db: Session, test_user: User): def test_list_thread_comments(self, db: Session, test_user: User, test_project: Project): """Test listing comments on a thread.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment 1") ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment 2") ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment 3") @@ -136,7 +149,9 @@ def test_list_thread_comments(self, db: Session, test_user: User, test_project: def test_update_comment(self, db: Session, test_user: User, test_project: Project): """Test updating a comment.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Original text") updated = ThreadService.update_comment(db, comment.id, body_markdown="Updated text") assert updated.body_markdown == "Updated text" @@ -149,7 +164,9 @@ def test_update_comment_not_found(self, db: Session): def test_delete_comment(self, db: Session, test_user: User, test_project: Project): """Test deleting a comment.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "Thread", str(test_user.id) + ) comment = ThreadService.create_comment(db, thread.id, str(test_user.id), "Comment") ThreadService.delete_comment(db, comment.id) @@ -168,12 +185,16 @@ class TestContextValidation: def test_context_type_general_no_context_id(self, db: Session, test_user: User, test_project: Project): """Test that general threads can have no context_id.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, None, "General", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, None, "General", str(test_user.id) + ) assert thread.context_id is None def test_context_type_general_with_context_id(self, db: Session, test_user: User, test_project: Project): """Test that general threads can optionally have context_id.""" - thread = ThreadService.create_thread(db, str(test_project.id), ContextType.GENERAL, "some-id", "General", str(test_user.id)) + thread = ThreadService.create_thread( + db, str(test_project.id), ContextType.GENERAL, "some-id", "General", str(test_user.id) + ) assert thread.context_id == "some-id" def test_brainstorm_feature_context(self, db: Session, test_user: User, test_project: Project): @@ -427,6 +448,7 @@ def test_touch_linked_feature_does_nothing_for_nonexistent_feature( ): """Test that _touch_linked_feature handles nonexistent feature gracefully.""" import uuid + # Create a BRAINSTORM_FEATURE thread with valid UUID that doesn't exist thread = ThreadService.create_thread( db=db, @@ -461,9 +483,7 @@ def test_skips_decision_summary_when_only_mfbtai_mentioned( created_by=str(test_user.id), ) - with patch.object( - ThreadService, 
"trigger_decision_summary" - ) as mock_trigger: + with patch.object(ThreadService, "trigger_decision_summary") as mock_trigger: ThreadService.create_comment_item( db=db, thread_id=thread.id, @@ -488,16 +508,13 @@ def test_triggers_decision_summary_when_mfbtai_and_users_mentioned( created_by=str(test_user.id), ) - with patch.object( - ThreadService, "trigger_decision_summary" - ) as mock_trigger: + with patch.object(ThreadService, "trigger_decision_summary") as mock_trigger: ThreadService.create_comment_item( db=db, thread_id=thread.id, author_id=str(test_user.id), body_markdown=( - "@[MFBTAI](mfbtai) and @[John](11111111-1111-1111-1111-111111111111) " - "please review this" + "@[MFBTAI](mfbtai) and @[John](11111111-1111-1111-1111-111111111111) please review this" ), ) # Should trigger decision summary when other users are also mentioned @@ -518,9 +535,7 @@ def test_triggers_decision_summary_when_only_users_mentioned( created_by=str(test_user.id), ) - with patch.object( - ThreadService, "trigger_decision_summary" - ) as mock_trigger: + with patch.object(ThreadService, "trigger_decision_summary") as mock_trigger: ThreadService.create_comment_item( db=db, thread_id=thread.id, @@ -530,9 +545,7 @@ def test_triggers_decision_summary_when_only_users_mentioned( # Should trigger decision summary for regular user mentions mock_trigger.assert_called_once() - def test_triggers_decision_summary_when_no_mentions( - self, db: Session, test_project: Project, test_user: User - ): + def test_triggers_decision_summary_when_no_mentions(self, db: Session, test_project: Project, test_user: User): """Test that decision summarizer runs when there are no mentions.""" from unittest.mock import patch @@ -545,9 +558,7 @@ def test_triggers_decision_summary_when_no_mentions( created_by=str(test_user.id), ) - with patch.object( - ThreadService, "trigger_decision_summary" - ) as mock_trigger: + with patch.object(ThreadService, "trigger_decision_summary") as mock_trigger: ThreadService.create_comment_item( db=db, thread_id=thread.id, diff --git a/backend/tests/test_thread_service_version.py b/backend/tests/test_thread_service_version.py index ac14f47..e3db22a 100644 --- a/backend/tests/test_thread_service_version.py +++ b/backend/tests/test_thread_service_version.py @@ -2,21 +2,21 @@ Tests the new methods for creating and querying threads anchored to draft versions. 
""" + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from uuid import uuid4 +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.spec_version import SpecVersion, SpecType -from app.models.thread import Thread, ContextType -from app.services.user_service import UserService +from app.models.spec_version import SpecVersion +from app.models.thread import ContextType +from app.models.user import User from app.services.draft_version_service import DraftVersionService from app.services.thread_service import ThreadService +from app.services.user_service import UserService @pytest.fixture diff --git a/backend/tests/test_thread_version_anchoring.py b/backend/tests/test_thread_version_anchoring.py index 1cf5cc7..7f7a3b0 100644 --- a/backend/tests/test_thread_version_anchoring.py +++ b/backend/tests/test_thread_version_anchoring.py @@ -2,18 +2,18 @@ Tests the enhanced Thread model with version_id and block_id for inline comments. """ + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session -from datetime import datetime, timezone +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User +from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType from app.models.organization import Organization from app.models.project import Project, ProjectStatus -from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.spec_version import SpecVersion, SpecType -from app.models.thread import Thread, ContextType +from app.models.spec_version import SpecType, SpecVersion +from app.models.thread import ContextType, Thread +from app.models.user import User from app.services.user_service import UserService @@ -225,11 +225,7 @@ def test_query_threads_by_version_id( test_db_session.commit() # Query threads by version_id - version_threads = ( - test_db_session.query(Thread) - .filter(Thread.version_id == str(sample_spec_version.id)) - .all() - ) + version_threads = test_db_session.query(Thread).filter(Thread.version_id == str(sample_spec_version.id)).all() assert len(version_threads) == 2 assert all(t.version_id == str(sample_spec_version.id) for t in version_threads) diff --git a/backend/tests/test_typing_indicator_service.py b/backend/tests/test_typing_indicator_service.py index dce51d5..5cef1f8 100644 --- a/backend/tests/test_typing_indicator_service.py +++ b/backend/tests/test_typing_indicator_service.py @@ -3,9 +3,8 @@ Tests the Redis-based typing indicator state management. 
""" -import pytest -from unittest.mock import patch, MagicMock import json +from unittest.mock import MagicMock, patch from app.services.typing_indicator_service import TypingIndicatorService @@ -31,10 +30,12 @@ def test_set_typing_success(self, mock_kafka, mock_redis): mock_client = MagicMock() mock_redis.return_value = mock_client mock_client.scan_iter.return_value = ["typing:thread-123:user-456"] - mock_client.get.return_value = json.dumps({ - "user_id": "user-456", - "user_name": "Test User", - }) + mock_client.get.return_value = json.dumps( + { + "user_id": "user-456", + "user_name": "Test User", + } + ) mock_producer = MagicMock() mock_kafka.return_value = mock_producer diff --git a/backend/tests/test_user_group_service.py b/backend/tests/test_user_group_service.py index 11e89cc..5166eb8 100644 --- a/backend/tests/test_user_group_service.py +++ b/backend/tests/test_user_group_service.py @@ -1,14 +1,14 @@ """Tests for user group service.""" import pytest -from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session -from app.models import User, Organization, OrgMembership, OrgRole +from app.models import Organization, OrgRole, User from app.models.provisioning import ProvisioningSource +from app.services.org_service import OrgService from app.services.user_group_service import UserGroupService from app.services.user_service import UserService -from app.services.org_service import OrgService class TestUserGroupService: @@ -31,9 +31,7 @@ def test_create_group(self, db: Session, test_user: User, test_org: Organization assert group.created_by_user_id == test_user.id assert group.provisioning_source == ProvisioningSource.MANUAL - def test_create_group_with_custom_provisioning( - self, db: Session, test_user: User, test_org: Organization - ): + def test_create_group_with_custom_provisioning(self, db: Session, test_user: User, test_org: Organization): """Test creating a group with custom provisioning source.""" group = UserGroupService.create_group( db=db, @@ -45,9 +43,7 @@ def test_create_group_with_custom_provisioning( assert group.provisioning_source == ProvisioningSource.SSO_SCIM - def test_create_duplicate_group_name_fails( - self, db: Session, test_user: User, test_org: Organization - ): + def test_create_duplicate_group_name_fails(self, db: Session, test_user: User, test_org: Organization): """Test that duplicate group names within an org fail.""" UserGroupService.create_group( db=db, @@ -100,12 +96,8 @@ def test_get_group_by_name(self, db: Session, test_user: User, test_org: Organiz def test_list_org_groups(self, db: Session, test_user: User, test_org: Organization): """Test listing all groups in an org.""" - UserGroupService.create_group( - db=db, org_id=test_org.id, name="Alpha", created_by_user_id=test_user.id - ) - UserGroupService.create_group( - db=db, org_id=test_org.id, name="Beta", created_by_user_id=test_user.id - ) + UserGroupService.create_group(db=db, org_id=test_org.id, name="Alpha", created_by_user_id=test_user.id) + UserGroupService.create_group(db=db, org_id=test_org.id, name="Beta", created_by_user_id=test_user.id) groups = UserGroupService.list_org_groups(db, test_org.id) assert len(groups) == 2 @@ -168,9 +160,7 @@ def test_add_member(self, db: Session, test_user: User, test_org: Organization): assert membership.user_id == test_user.id assert membership.provisioning_source == ProvisioningSource.MANUAL - def test_add_member_with_invite_provisioning( - self, db: Session, test_user: User, test_org: Organization - ): 
+ def test_add_member_with_invite_provisioning(self, db: Session, test_user: User, test_org: Organization): """Test adding a user with INVITE provisioning source.""" group = UserGroupService.create_group( db=db, @@ -179,15 +169,11 @@ def test_add_member_with_invite_provisioning( created_by_user_id=test_user.id, ) - membership = UserGroupService.add_member( - db, group.id, test_user.id, ProvisioningSource.INVITE - ) + membership = UserGroupService.add_member(db, group.id, test_user.id, ProvisioningSource.INVITE) assert membership.provisioning_source == ProvisioningSource.INVITE - def test_add_duplicate_member_fails( - self, db: Session, test_user: User, test_org: Organization - ): + def test_add_duplicate_member_fails(self, db: Session, test_user: User, test_org: Organization): """Test that adding the same user twice fails.""" group = UserGroupService.create_group( db=db, @@ -216,9 +202,7 @@ def test_remove_member(self, db: Session, test_user: User, test_org: Organizatio assert UserGroupService.is_user_in_group(db, group.id, test_user.id) is False - def test_remove_nonexistent_member( - self, db: Session, test_user: User, test_org: Organization - ): + def test_remove_nonexistent_member(self, db: Session, test_user: User, test_org: Organization): """Test removing a non-member returns False.""" group = UserGroupService.create_group( db=db, @@ -250,12 +234,8 @@ def test_get_group_members(self, db: Session, test_user: User, test_org: Organiz def test_get_user_groups(self, db: Session, test_user: User, test_org: Organization): """Test getting all groups a user belongs to.""" - group1 = UserGroupService.create_group( - db=db, org_id=test_org.id, name="Alpha", created_by_user_id=test_user.id - ) - group2 = UserGroupService.create_group( - db=db, org_id=test_org.id, name="Beta", created_by_user_id=test_user.id - ) + group1 = UserGroupService.create_group(db=db, org_id=test_org.id, name="Alpha", created_by_user_id=test_user.id) + group2 = UserGroupService.create_group(db=db, org_id=test_org.id, name="Beta", created_by_user_id=test_user.id) UserGroupService.add_member(db, group1.id, test_user.id) UserGroupService.add_member(db, group2.id, test_user.id) @@ -298,9 +278,7 @@ def test_is_user_in_group(self, db: Session, test_user: User, test_org: Organiza assert UserGroupService.is_user_in_group(db, group.id, test_user.id) is True - def test_cascade_delete_memberships( - self, db: Session, test_user: User, test_org: Organization - ): + def test_cascade_delete_memberships(self, db: Session, test_user: User, test_org: Organization): """Test that deleting a group cascades to memberships.""" group = UserGroupService.create_group( db=db, diff --git a/backend/tests/test_user_groups_router.py b/backend/tests/test_user_groups_router.py index 078b741..ed5dba4 100644 --- a/backend/tests/test_user_groups_router.py +++ b/backend/tests/test_user_groups_router.py @@ -1,10 +1,9 @@ """Integration tests for user groups router endpoints.""" -import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from app.models import User, Organization, OrgRole +from app.models import Organization, OrgRole, User from app.services.org_service import OrgService from app.services.user_group_service import UserGroupService from app.services.user_service import UserService diff --git a/backend/tests/test_user_service.py b/backend/tests/test_user_service.py index fa1ec2e..043ff64 100644 --- a/backend/tests/test_user_service.py +++ b/backend/tests/test_user_service.py @@ -3,12 +3,12 @@ Tests user creation, 
authentication, and retrieval operations. """ + import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, Session +from sqlalchemy.orm import Session, sessionmaker from app.database import Base -from app.models.user import User from app.services.user_service import UserService @@ -89,9 +89,7 @@ def test_create_user_with_duplicate_email_raises_error(self, test_db_session: Se class TestUserServiceAuthenticateUser: """Tests for UserService.authenticate_user().""" - def test_authenticate_user_succeeds_with_correct_credentials( - self, test_db_session: Session - ): + def test_authenticate_user_succeeds_with_correct_credentials(self, test_db_session: Session): """Test that authentication succeeds with correct email and password.""" email = "test@example.com" password = "securepassword123" @@ -132,9 +130,7 @@ def test_authenticate_user_fails_with_wrong_password(self, test_db_session: Sess assert authenticated_user is None - def test_authenticate_user_fails_with_nonexistent_email( - self, test_db_session: Session - ): + def test_authenticate_user_fails_with_nonexistent_email(self, test_db_session: Session): """Test that authentication fails with non-existent email.""" authenticated_user = UserService.authenticate_user( db=test_db_session, @@ -165,13 +161,9 @@ def test_get_user_by_email_returns_user(self, test_db_session: Session): assert retrieved_user.id == created_user.id assert retrieved_user.email == email - def test_get_user_by_email_returns_none_for_nonexistent( - self, test_db_session: Session - ): + def test_get_user_by_email_returns_none_for_nonexistent(self, test_db_session: Session): """Test that getting a non-existent user by email returns None.""" - retrieved_user = UserService.get_user_by_email( - db=test_db_session, email="nonexistent@example.com" - ) + retrieved_user = UserService.get_user_by_email(db=test_db_session, email="nonexistent@example.com") assert retrieved_user is None @@ -186,17 +178,13 @@ def test_get_user_by_id_returns_user(self, test_db_session: Session): password=password, ) - retrieved_user = UserService.get_user_by_id( - db=test_db_session, user_id=created_user.id - ) + retrieved_user = UserService.get_user_by_id(db=test_db_session, user_id=created_user.id) assert retrieved_user is not None assert retrieved_user.id == created_user.id assert retrieved_user.email == email - def test_get_user_by_id_returns_none_for_nonexistent( - self, test_db_session: Session - ): + def test_get_user_by_id_returns_none_for_nonexistent(self, test_db_session: Session): """Test that getting a non-existent user by ID returns None.""" from uuid import uuid4 diff --git a/backend/tests/test_vfs_content.py b/backend/tests/test_vfs_content.py index f073db7..f06f54b 100644 --- a/backend/tests/test_vfs_content.py +++ b/backend/tests/test_vfs_content.py @@ -1,33 +1,33 @@ """Tests for VFS content generation, particularly prompt plan wrapping.""" -import pytest +from unittest.mock import MagicMock from uuid import uuid4 + from sqlalchemy.orm import Session -from unittest.mock import MagicMock from app.mcp.vfs.content import ( - _wrap_prompt_plan_with_instructions, - _get_feature_file_content, + MFBT_USAGE_GUIDE_CONTENT, _get_conversations_file_content, - _list_conversations_dir, - _transform_mentions_for_vfs, _get_extension_from_content_type, - _list_mfbt_usage_guide_dir, + _get_feature_file_content, _get_mfbt_usage_guide_file_content, - MFBT_USAGE_GUIDE_CONTENT, + _list_conversations_dir, + _list_mfbt_usage_guide_dir, + _transform_mentions_for_vfs, + 
_wrap_prompt_plan_with_instructions, ) -from app.mcp.vfs.path_resolver import ResolvedPath, NodeType -from app.models import ProjectType, User, Organization +from app.mcp.vfs.path_resolver import NodeType, ResolvedPath +from app.models import Organization, ProjectType, User from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleType, ModuleProvenance from app.models.feature import ( Feature, - FeatureType, - FeatureStatus, FeatureCompletionStatus, FeaturePriority, FeatureProvenance, + FeatureStatus, + FeatureType, ) +from app.models.module import Module, ModuleProvenance, ModuleType from app.services.project_service import ProjectService @@ -94,9 +94,7 @@ def test_feature_key_included_in_postscript(self): class TestGetFeatureFileContentPromptPlan: """Integration tests for prompt_plan.md retrieval with wrapping.""" - def test_prompt_plan_with_content_is_wrapped( - self, db: Session, test_org: Organization, test_user: User - ): + def test_prompt_plan_with_content_is_wrapped(self, db: Session, test_org: Organization, test_user: User): """Test that prompt_plan.md with content includes preamble/postscript.""" # Create project, phase, module, feature project = ProjectService.create_project( @@ -179,9 +177,7 @@ def test_prompt_plan_with_content_is_wrapped( assert "# Steps" in result["content"] assert "1. Do this" in result["content"] - def test_prompt_plan_without_content_not_wrapped( - self, db: Session, test_org: Organization, test_user: User - ): + def test_prompt_plan_without_content_not_wrapped(self, db: Session, test_org: Organization, test_user: User): """Test that prompt_plan.md without content shows placeholder without wrapper.""" project = ProjectService.create_project( db=db, diff --git a/backend/tests/test_vfs_path_resolver.py b/backend/tests/test_vfs_path_resolver.py index 5edf944..fae10ca 100644 --- a/backend/tests/test_vfs_path_resolver.py +++ b/backend/tests/test_vfs_path_resolver.py @@ -1,20 +1,21 @@ """Tests for VFS path resolution.""" -import pytest from uuid import uuid4 +import pytest + +from app.mcp.vfs.errors import PathNotFoundError from app.mcp.vfs.path_resolver import ( - slugify, + NodeType, + _normalize_path, feature_dir_name, module_dir_name, resolve_path, - NodeType, - _normalize_path, + slugify, ) -from app.mcp.vfs.errors import PathNotFoundError, InvalidPathError from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleType, ModuleProvenance -from app.models.feature import Feature, FeatureType, FeatureProvenance, FeatureStatus +from app.models.feature import Feature, FeatureProvenance, FeatureStatus, FeatureType +from app.models.module import Module, ModuleProvenance, ModuleType class TestSlugify: @@ -488,7 +489,9 @@ def test_invalid_feature_file(self, db, test_project, test_user): with pytest.raises(PathNotFoundError) as exc_info: resolve_path( - db, test_project.id, "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/invalid.md" + db, + test_project.id, + "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/invalid.md", ) assert "spec.md" in str(exc_info.value.available) @@ -546,8 +549,9 @@ def test_resolve_conversations_dir(self, db, test_project, test_user): phase, module, feature = self._create_feature_fixture(db, test_project, test_user) resolved = resolve_path( - db, test_project.id, - 
"/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations" + db, + test_project.id, + "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations", ) assert resolved.node_type == NodeType.CONVERSATIONS_DIR assert resolved.is_directory is True @@ -559,8 +563,9 @@ def test_resolve_conversations_file(self, db, test_project, test_user): phase, module, feature = self._create_feature_fixture(db, test_project, test_user) resolved = resolve_path( - db, test_project.id, - "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations/conversations.md" + db, + test_project.id, + "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations/conversations.md", ) assert resolved.node_type == NodeType.CONVERSATIONS_FILE assert resolved.is_directory is False @@ -578,14 +583,15 @@ def test_invalid_conversations_path(self, db, test_project, test_user): with pytest.raises(PathNotFoundError) as exc_info: resolve_path( - db, test_project.id, - "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations/invalid" + db, + test_project.id, + "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations/invalid", ) assert "conversations.md" in str(exc_info.value.available) def test_conversations_with_thread(self, db, test_project, test_user): """Test resolving conversations dir when a thread exists.""" - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType, Thread phase, module, feature = self._create_feature_fixture(db, test_project, test_user) @@ -601,8 +607,9 @@ def test_conversations_with_thread(self, db, test_project, test_user): db.commit() resolved = resolve_path( - db, test_project.id, - "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations" + db, + test_project.id, + "/phases/system-generated/initial/features/MTEST-001-auth/TEST-001-jwt-login/conversations", ) assert resolved.node_type == NodeType.CONVERSATIONS_DIR assert resolved.thread_id == thread.id diff --git a/backend/tests/test_vfs_sed.py b/backend/tests/test_vfs_sed.py index 8be9dfc..82d91f9 100644 --- a/backend/tests/test_vfs_sed.py +++ b/backend/tests/test_vfs_sed.py @@ -1,21 +1,22 @@ """Tests for VFS sed tool.""" -import pytest from uuid import uuid4 + +import pytest from sqlalchemy.orm import Session from app.mcp.tools.vfs_sed import vfs_sed -from app.models import ProjectType, User, Organization +from app.models import Organization, ProjectType, User from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleType, ModuleProvenance from app.models.feature import ( Feature, - FeatureType, - FeatureStatus, FeatureCompletionStatus, FeaturePriority, FeatureProvenance, + FeatureStatus, + FeatureType, ) +from app.models.module import Module, ModuleProvenance, ModuleType from app.services.project_service import ProjectService @@ -91,7 +92,7 @@ class TestVfsSed: def test_basic_substitution(self, db: Session, feature_with_notes): """Basic regex substitution on notes.md.""" data = feature_with_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -115,7 +116,7 @@ def test_basic_substitution(self, db: Session, 
feature_with_notes): def test_global_flag(self, db: Session, feature_with_notes): """Global flag replaces all matches.""" data = feature_with_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -134,7 +135,7 @@ def test_global_flag(self, db: Session, feature_with_notes): def test_single_replacement_no_global(self, db: Session, feature_with_notes): """Without 'g' flag, only first match is replaced.""" data = feature_with_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -158,7 +159,7 @@ def test_case_insensitive_flag(self, db: Session, feature_with_notes): data["feature"].implementation_notes = "Hello World, HELLO universe" db.commit() - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -178,7 +179,7 @@ def test_no_matches(self, db: Session, feature_with_notes): """No matches returns 0 and doesn't modify file.""" data = feature_with_notes original_notes = data["feature"].implementation_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -197,7 +198,7 @@ def test_no_matches(self, db: Session, feature_with_notes): def test_readonly_path_error(self, db: Session, feature_with_notes): """Read-only paths return error.""" data = feature_with_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/spec.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/spec.md" result = vfs_sed( db=db, @@ -214,7 +215,7 @@ def test_readonly_path_error(self, db: Session, feature_with_notes): def test_invalid_regex_error(self, db: Session, feature_with_notes): """Invalid regex returns error.""" data = feature_with_notes - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, @@ -266,7 +267,7 @@ def test_multiline_flag(self, db: Session, feature_with_notes): data["feature"].implementation_notes = "line1\nline2\nline3" db.commit() - path = f"/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" + path = "/phases/system-generated/initial-planning/features/MTEST-001-auth-module/TEST-001-test-feature/notes.md" result = vfs_sed( db=db, diff --git a/backend/tests/test_vfs_write_conversations.py b/backend/tests/test_vfs_write_conversations.py index e1b7453..9e57f7d 100644 --- a/backend/tests/test_vfs_write_conversations.py +++ b/backend/tests/test_vfs_write_conversations.py @@ -1,32 +1,31 @@ """Tests for VFS write to conversations.md - creating thread comments via MCP.""" import json -import pytest +from unittest.mock import MagicMock, 
patch from uuid import uuid4 -from unittest.mock import patch, MagicMock + +import pytest from sqlalchemy.orm import Session from app.mcp.tools.vfs_write import ( - vfs_write, - _write_conversation_comment, _get_or_create_feature_thread, + _write_conversation_comment, ) -from app.mcp.vfs.path_resolver import ResolvedPath, NodeType -from app.models import ProjectType, User, Organization +from app.mcp.vfs.path_resolver import NodeType, ResolvedPath +from app.models import Organization, ProjectType, User from app.models.brainstorming_phase import BrainstormingPhase, BrainstormingPhaseType -from app.models.module import Module, ModuleType, ModuleProvenance from app.models.feature import ( Feature, - FeatureType, - FeatureStatus, FeatureCompletionStatus, FeaturePriority, FeatureProvenance, + FeatureStatus, + FeatureType, ) -from app.models.thread import Thread, ContextType -from app.models.thread_item import ThreadItem, ThreadItemType -from app.models.project_share import ProjectShare, ShareSubjectType +from app.models.module import Module, ModuleProvenance, ModuleType from app.models.project_membership import ProjectRole +from app.models.project_share import ProjectShare, ShareSubjectType +from app.models.thread import ContextType, Thread from app.services.project_service import ProjectService @@ -140,9 +139,7 @@ def non_member_user(db: Session): class TestWriteConversationCommentValidation: """Tests for _write_conversation_comment validation logic.""" - def test_rejects_invalid_json( - self, db: Session, test_user: User, project_with_feature - ): + def test_rejects_invalid_json(self, db: Session, test_user: User, project_with_feature): """Test that invalid JSON is rejected.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -167,9 +164,7 @@ def test_rejects_invalid_json( assert "error" in result assert "Invalid JSON" in result["error"] - def test_rejects_unknown_action( - self, db: Session, test_user: User, project_with_feature - ): + def test_rejects_unknown_action(self, db: Session, test_user: User, project_with_feature): """Test that unknown actions are rejected.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -182,10 +177,12 @@ def test_rejects_unknown_action( thread_id=None, ) - content = json.dumps({ - "action": "delete_comment", - "body_markdown": "test", - }) + content = json.dumps( + { + "action": "delete_comment", + "body_markdown": "test", + } + ) result = _write_conversation_comment( db=db, @@ -199,9 +196,7 @@ def test_rejects_unknown_action( assert "error" in result assert "Unknown action" in result["error"] - def test_rejects_empty_body_markdown( - self, db: Session, test_user: User, project_with_feature - ): + def test_rejects_empty_body_markdown(self, db: Session, test_user: User, project_with_feature): """Test that empty body_markdown is rejected.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -214,10 +209,12 @@ def test_rejects_empty_body_markdown( thread_id=None, ) - content = json.dumps({ - "action": "add_comment", - "body_markdown": "", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": "", + } + ) result = _write_conversation_comment( db=db, @@ -231,9 +228,7 @@ def test_rejects_empty_body_markdown( assert "error" in result assert "body_markdown is required" in result["error"] - def test_rejects_missing_body_markdown( - self, db: Session, test_user: User, project_with_feature - ): + def 
test_rejects_missing_body_markdown(self, db: Session, test_user: User, project_with_feature): """Test that missing body_markdown is rejected.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -246,9 +241,11 @@ def test_rejects_missing_body_markdown( thread_id=None, ) - content = json.dumps({ - "action": "add_comment", - }) + content = json.dumps( + { + "action": "add_comment", + } + ) result = _write_conversation_comment( db=db, @@ -262,9 +259,7 @@ def test_rejects_missing_body_markdown( assert "error" in result assert "body_markdown is required" in result["error"] - def test_rejects_missing_feature_id( - self, db: Session, test_user: User, project_with_feature - ): + def test_rejects_missing_feature_id(self, db: Session, test_user: User, project_with_feature): """Test that missing feature_id returns an error.""" project = project_with_feature["project"] @@ -276,10 +271,12 @@ def test_rejects_missing_feature_id( thread_id=None, ) - content = json.dumps({ - "action": "add_comment", - "body_markdown": "test", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": "test", + } + ) result = _write_conversation_comment( db=db, @@ -313,10 +310,12 @@ def test_rejects_mention_of_non_project_member( ) # Mention a user who is NOT a project member - content = json.dumps({ - "action": "add_comment", - "body_markdown": f"Hey @[Outsider]({non_member_user.id}) check this out!", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": f"Hey @[Outsider]({non_member_user.id}) check this out!", + } + ) result = _write_conversation_comment( db=db, @@ -331,9 +330,7 @@ def test_rejects_mention_of_non_project_member( assert "not project members" in result["error"] assert "invalid_mentions" in result - def test_allows_mention_of_project_member( - self, db: Session, test_user: User, project_with_feature, member_user - ): + def test_allows_mention_of_project_member(self, db: Session, test_user: User, project_with_feature, member_user): """Test that mentioning project members passes validation.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -347,10 +344,12 @@ def test_allows_mention_of_project_member( ) # Mention a user who IS a project member - content = json.dumps({ - "action": "add_comment", - "body_markdown": f"Hey @[Member]({member_user.id}) please review!", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": f"Hey @[Member]({member_user.id}) please review!", + } + ) # Mock ThreadService.create_comment_item to avoid SQLite/UUID issues mock_item = MagicMock() @@ -361,9 +360,7 @@ def test_allows_mention_of_project_member( mock_create.return_value = mock_item # Also mock _get_or_create_feature_thread - with patch( - "app.mcp.tools.vfs_write._get_or_create_feature_thread" - ) as mock_get_thread: + with patch("app.mcp.tools.vfs_write._get_or_create_feature_thread") as mock_get_thread: mock_thread = MagicMock() mock_thread.id = str(uuid4()) mock_get_thread.return_value = mock_thread @@ -384,9 +381,7 @@ def test_allows_mention_of_project_member( class TestWriteConversationCommentSuccess: """Tests for successful comment creation with mocked ThreadService.""" - def test_creates_comment_with_mcp_tracking( - self, db: Session, test_user: User, project_with_feature - ): + def test_creates_comment_with_mcp_tracking(self, db: Session, test_user: User, project_with_feature): """Test that comment is created with MCP tracking metadata.""" feature = 
project_with_feature["feature"] project = project_with_feature["project"] @@ -399,11 +394,13 @@ def test_creates_comment_with_mcp_tracking( thread_id=None, ) - content = json.dumps({ - "action": "add_comment", - "body_markdown": "This is a test comment", - "coding_agent_name": "claude_code", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": "This is a test comment", + "coding_agent_name": "claude_code", + } + ) # Mock ThreadService.create_comment_item mock_item = MagicMock() @@ -413,9 +410,7 @@ def test_creates_comment_with_mcp_tracking( with patch("app.services.thread_service.ThreadService.create_comment_item") as mock_create: mock_create.return_value = mock_item - with patch( - "app.mcp.tools.vfs_write._get_or_create_feature_thread" - ) as mock_get_thread: + with patch("app.mcp.tools.vfs_write._get_or_create_feature_thread") as mock_get_thread: mock_thread = MagicMock() mock_thread.id = str(uuid4()) mock_get_thread.return_value = mock_thread @@ -439,9 +434,7 @@ def test_creates_comment_with_mcp_tracking( assert mock_item.content_data.get("created_via_mcp") is True assert mock_item.content_data.get("coding_agent_name") == "claude_code" - def test_coding_agent_name_from_json_overrides_param( - self, db: Session, test_user: User, project_with_feature - ): + def test_coding_agent_name_from_json_overrides_param(self, db: Session, test_user: User, project_with_feature): """Test that coding_agent_name in JSON overrides the function parameter.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -454,11 +447,13 @@ def test_coding_agent_name_from_json_overrides_param( thread_id=None, ) - content = json.dumps({ - "action": "add_comment", - "body_markdown": "Test message", - "coding_agent_name": "cursor", # This should be used - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": "Test message", + "coding_agent_name": "cursor", # This should be used + } + ) mock_item = MagicMock() mock_item.id = uuid4() @@ -467,9 +462,7 @@ def test_coding_agent_name_from_json_overrides_param( with patch("app.services.thread_service.ThreadService.create_comment_item") as mock_create: mock_create.return_value = mock_item - with patch( - "app.mcp.tools.vfs_write._get_or_create_feature_thread" - ) as mock_get_thread: + with patch("app.mcp.tools.vfs_write._get_or_create_feature_thread") as mock_get_thread: mock_thread = MagicMock() mock_thread.id = str(uuid4()) mock_get_thread.return_value = mock_thread @@ -486,9 +479,7 @@ def test_coding_agent_name_from_json_overrides_param( # JSON value should override param assert mock_item.content_data.get("coding_agent_name") == "cursor" - def test_uses_existing_thread_when_available( - self, db: Session, test_user: User, project_with_feature - ): + def test_uses_existing_thread_when_available(self, db: Session, test_user: User, project_with_feature): """Test that existing thread is reused.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -514,10 +505,12 @@ def test_uses_existing_thread_when_available( thread_id=thread.id, # Pass existing thread ID ) - content = json.dumps({ - "action": "add_comment", - "body_markdown": "Adding to existing thread", - }) + content = json.dumps( + { + "action": "add_comment", + "body_markdown": "Adding to existing thread", + } + ) mock_item = MagicMock() mock_item.id = uuid4() @@ -542,9 +535,7 @@ def test_uses_existing_thread_when_available( class TestGetOrCreateFeatureThread: """Tests for _get_or_create_feature_thread 
helper function.""" - def test_creates_new_thread_for_feature( - self, db: Session, test_user: User, project_with_feature - ): + def test_creates_new_thread_for_feature(self, db: Session, test_user: User, project_with_feature): """Test that a new thread is created when none exists.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -561,9 +552,7 @@ def test_creates_new_thread_for_feature( assert thread.context_id == str(feature.id) assert thread.title == feature.title - def test_returns_existing_thread( - self, db: Session, test_user: User, project_with_feature - ): + def test_returns_existing_thread(self, db: Session, test_user: User, project_with_feature): """Test that existing thread is returned instead of creating new one.""" feature = project_with_feature["feature"] project = project_with_feature["project"] @@ -590,9 +579,7 @@ def test_returns_existing_thread( assert thread.id == existing_thread.id assert thread.title == "Pre-existing Thread" - def test_returns_none_for_invalid_feature( - self, db: Session, test_user: User, project_with_feature - ): + def test_returns_none_for_invalid_feature(self, db: Session, test_user: User, project_with_feature): """Test that None is returned for non-existent feature.""" project = project_with_feature["project"] diff --git a/backend/tests/test_worker_graceful_shutdown.py b/backend/tests/test_worker_graceful_shutdown.py index a743048..11771aa 100644 --- a/backend/tests/test_worker_graceful_shutdown.py +++ b/backend/tests/test_worker_graceful_shutdown.py @@ -11,7 +11,7 @@ import pytest -from workers.core.worker import JobWorker, DEFAULT_SHUTDOWN_TIMEOUT +from workers.core.worker import DEFAULT_SHUTDOWN_TIMEOUT, JobWorker class TestJobWorkerGracefulShutdown: @@ -147,9 +147,7 @@ async def slow_handler(payload, db): # Start processing a job job_id = "test-job-123" - process_task = asyncio.create_task( - worker._process_message({"job_id": job_id}) - ) + process_task = asyncio.create_task(worker._process_message({"job_id": job_id})) # Give it time to add to in-flight set await asyncio.sleep(0.05) diff --git a/backend/tests/workers/test_helpers.py b/backend/tests/workers/test_helpers.py index 04a9443..ca36227 100644 --- a/backend/tests/workers/test_helpers.py +++ b/backend/tests/workers/test_helpers.py @@ -3,6 +3,7 @@ Tests the utility functions in workers/core/helpers.py, including the advisory lock ID generators. 
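[Editor's note: the lone `+` blank lines added after the closing `"""` here and in `test_user_service.py` and `workers/core/__init__.py` are the formatter enforcing Black's rule of exactly one empty line between a module docstring and the first statement. A minimal illustrative sketch; the second half shows the same module after formatting:]

```python
# Before: the first import sits directly under the module docstring.
"""Tests for worker helper functions."""
from uuid import UUID, uuid4

# After `ruff format`: exactly one blank line separates docstring and code.
"""Tests for worker helper functions."""

from uuid import UUID, uuid4
```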
""" + from uuid import UUID, uuid4 from workers.core.helpers import get_advisory_lock_id, get_grounding_lock_id diff --git a/backend/workers/consumer.py b/backend/workers/consumer.py index 52a11ae..5a3b5a2 100644 --- a/backend/workers/consumer.py +++ b/backend/workers/consumer.py @@ -22,37 +22,37 @@ from app.models.job import JobType from workers.core import JobWorker, KafkaConnectionError from workers.handlers import ( + brainstorm_conversation_batch_generate_handler, + brainstorm_conversation_generate_handler, # Brainstorming brainstorm_generate_handler, - brainstorm_conversation_generate_handler, - brainstorm_conversation_batch_generate_handler, - brainstorm_spec_generate_handler, brainstorm_prompt_plan_generate_handler, - # Grounding - grounding_update_handler, - grounding_summarize_handler, - grounding_branch_summarize_handler, - grounding_merge_handler, - grounding_pull_handler, + brainstorm_spec_generate_handler, + # Integration + bugsync_handler, + # Code exploration + code_explorer_explore_handler, + code_explorer_grounding_generate_handler, # Collaboration collab_thread_ai_mention_handler, collab_thread_decision_summarize_handler, + feature_content_generate_handler, + grounding_branch_summarize_handler, + grounding_merge_handler, + grounding_pull_handler, + grounding_summarize_handler, + # Grounding + grounding_update_handler, + # Image annotation + image_annotate_handler, + mention_notification_handler, # Generation module_feature_generate_handler, - feature_content_generate_handler, - user_initiated_question_generate_handler, - # Integration - bugsync_handler, notification_fanout_handler, - mention_notification_handler, project_chat_mention_notification_handler, # Project chat project_chat_respond_handler, - # Image annotation - image_annotate_handler, - # Code exploration - code_explorer_explore_handler, - code_explorer_grounding_generate_handler, + user_initiated_question_generate_handler, # Web search web_search_execute_handler, ) @@ -73,7 +73,9 @@ async def main(): worker.register_handler(JobType.MODULE_FEATURE_GENERATE, module_feature_generate_handler) worker.register_handler(JobType.BRAINSTORM_GENERATE, brainstorm_generate_handler) # Legacy worker.register_handler(JobType.BRAINSTORM_CONVERSATION_GENERATE, brainstorm_conversation_generate_handler) - worker.register_handler(JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE, brainstorm_conversation_batch_generate_handler) + worker.register_handler( + JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE, brainstorm_conversation_batch_generate_handler + ) worker.register_handler(JobType.BRAINSTORM_SPEC_GENERATE, brainstorm_spec_generate_handler) worker.register_handler(JobType.BRAINSTORM_PROMPT_PLAN_GENERATE, brainstorm_prompt_plan_generate_handler) worker.register_handler(JobType.BUG_SYNC, bugsync_handler) diff --git a/backend/workers/core/__init__.py b/backend/workers/core/__init__.py index 4fdcc0e..db03e98 100644 --- a/backend/workers/core/__init__.py +++ b/backend/workers/core/__init__.py @@ -4,8 +4,8 @@ Provides the JobWorker class and shared helper utilities. 
""" -from workers.core.worker import JobWorker, JobHandler, KafkaConnectionError -from workers.core.helpers import publish_job_to_kafka, broadcast_pending_count_update +from workers.core.helpers import broadcast_pending_count_update, publish_job_to_kafka +from workers.core.worker import JobHandler, JobWorker, KafkaConnectionError __all__ = [ "JobWorker", diff --git a/backend/workers/core/helpers.py b/backend/workers/core/helpers.py index fa5a607..fd4a83a 100644 --- a/backend/workers/core/helpers.py +++ b/backend/workers/core/helpers.py @@ -88,9 +88,7 @@ def publish_job_to_kafka( ) if success: - logger.info( - f"Published job to Kafka: job_id={job_id}, topic={topic}, key={key_value}" - ) + logger.info(f"Published job to Kafka: job_id={job_id}, topic={topic}, key={key_value}") return success @@ -106,9 +104,7 @@ def broadcast_pending_count_update(db: Session, brainstorming_phase_id: UUID): counts = BrainstormingPhaseService.get_pending_questions_count(db, brainstorming_phase_id) # Get org_id and project_id from a module in this phase - module = db.query(Module).filter( - Module.brainstorming_phase_id == brainstorming_phase_id - ).first() + module = db.query(Module).filter(Module.brainstorming_phase_id == brainstorming_phase_id).first() if not module: logger.warning(f"No module found for phase {brainstorming_phase_id}, cannot broadcast") @@ -136,9 +132,7 @@ def broadcast_pending_count_update(db: Session, brainstorming_phase_id: UUID): ) if success: - logger.info( - f"Broadcasted pending_questions_updated via Kafka: phase_id={brainstorming_phase_id}" - ) + logger.info(f"Broadcasted pending_questions_updated via Kafka: phase_id={brainstorming_phase_id}") except Exception as e: logger.error(f"Failed to broadcast pending count update: {e}", exc_info=True) @@ -159,11 +153,7 @@ def clear_phase_generation_flag(db: Session, phase_id: UUID, flag_name: str): from app.models.brainstorming_phase import BrainstormingPhase try: - phase = ( - db.query(BrainstormingPhase) - .filter(BrainstormingPhase.id == phase_id) - .first() - ) + phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first() if not phase: logger.warning(f"Phase {phase_id} not found when clearing {flag_name}") @@ -203,16 +193,12 @@ def clear_implementation_generation_flag(db: Session, implementation_id: UUID, f implementation_id: Implementation ID flag_name: Name of the flag to clear (e.g., 'is_generating_spec') """ - from app.models.implementation import Implementation from app.models.feature import Feature + from app.models.implementation import Implementation from app.models.module import Module try: - implementation = ( - db.query(Implementation) - .filter(Implementation.id == implementation_id) - .first() - ) + implementation = db.query(Implementation).filter(Implementation.id == implementation_id).first() if not implementation: logger.warning(f"Implementation {implementation_id} not found when clearing {flag_name}") @@ -235,13 +221,11 @@ def clear_implementation_generation_flag(db: Session, implementation_id: UUID, f return # Find the thread for this feature (thread is linked via context_type/context_id) - from app.models.thread import Thread, ContextType + from app.models.thread import ContextType, Thread + thread = ( db.query(Thread) - .filter( - Thread.context_type == ContextType.BRAINSTORM_FEATURE, - Thread.context_id == str(feature.id) - ) + .filter(Thread.context_type == ContextType.BRAINSTORM_FEATURE, Thread.context_id == str(feature.id)) .first() ) @@ -282,11 +266,7 @@ def 
clear_thread_generation_flag(db: Session, thread_id: str, flag_name: str): from app.models.thread import Thread try: - thread = ( - db.query(Thread) - .filter(Thread.id == thread_id) - .first() - ) + thread = db.query(Thread).filter(Thread.id == thread_id).first() if not thread: logger.warning(f"Thread {thread_id} not found when clearing {flag_name}") diff --git a/backend/workers/core/worker.py b/backend/workers/core/worker.py index 0847ea1..fb4d381 100644 --- a/backend/workers/core/worker.py +++ b/backend/workers/core/worker.py @@ -181,8 +181,7 @@ async def start(self, topics: list[str]): f"Kafka brokers: {self.bootstrap_servers}" ) raise KafkaConnectionError( - f"Unable to connect to Kafka at {self.bootstrap_servers} " - f"after {self.max_connection_retries} retries" + f"Unable to connect to Kafka at {self.bootstrap_servers} after {self.max_connection_retries} retries" ) logger.info(f"Worker started, subscribed to: {topics}") @@ -288,9 +287,7 @@ async def _monitor_stuck_jobs(self): try: # Try to acquire advisory lock (transaction-level, auto-releases on commit) # This ensures only one worker runs cleanup at a time - acquired = db.execute( - text(f"SELECT pg_try_advisory_xact_lock({lock_id})") - ).scalar() + acquired = db.execute(text(f"SELECT pg_try_advisory_xact_lock({lock_id})")).scalar() if not acquired: logger.debug("Job monitor: another worker is handling cleanup") @@ -304,6 +301,7 @@ async def _monitor_stuck_jobs(self): # Clean up expired MCP image submissions from app.services.mcp_image_service import MCPImageService + expired_count = MCPImageService.cleanup_expired_submissions(db) db.commit() @@ -418,9 +416,7 @@ async def _process_message(self, payload: dict): # Extract LLM usage info if present in result # Handlers can include a "_llm_usage" key with model/token/cost info - llm_usage = ( - result.pop("_llm_usage", None) if isinstance(result, dict) else None - ) + llm_usage = result.pop("_llm_usage", None) if isinstance(result, dict) else None # Mark job as succeeded with usage info update_kwargs = { @@ -433,9 +429,7 @@ async def _process_message(self, payload: dict): if llm_usage: update_kwargs["model_used"] = llm_usage.get("model") update_kwargs["total_prompt_tokens"] = llm_usage.get("prompt_tokens") - update_kwargs["total_completion_tokens"] = llm_usage.get( - "completion_tokens" - ) + update_kwargs["total_completion_tokens"] = llm_usage.get("completion_tokens") update_kwargs["total_cost_usd"] = llm_usage.get("cost_usd") logger.info( f"Job {job_id} LLM usage: model={llm_usage.get('model')}, " diff --git a/backend/workers/handlers/__init__.py b/backend/workers/handlers/__init__.py index 7a10e24..1866518 100644 --- a/backend/workers/handlers/__init__.py +++ b/backend/workers/handlers/__init__.py @@ -13,20 +13,17 @@ # Brainstorming handlers from workers.handlers.brainstorming import ( - brainstorm_generate_handler, - brainstorm_conversation_generate_handler, brainstorm_conversation_batch_generate_handler, - brainstorm_spec_generate_handler, + brainstorm_conversation_generate_handler, + brainstorm_generate_handler, brainstorm_prompt_plan_generate_handler, + brainstorm_spec_generate_handler, ) -# Grounding handlers -from workers.handlers.grounding import ( - grounding_update_handler, - grounding_summarize_handler, - grounding_branch_summarize_handler, - grounding_merge_handler, - grounding_pull_handler, +# Code exploration handlers +from workers.handlers.code_explorer import ( + code_explorer_explore_handler, + code_explorer_grounding_generate_handler, ) # Collaboration handlers @@ 
-37,16 +34,30 @@ # Generation handlers from workers.handlers.generation import ( - module_feature_generate_handler, feature_content_generate_handler, + module_feature_generate_handler, user_initiated_question_generate_handler, ) +# Grounding handlers +from workers.handlers.grounding import ( + grounding_branch_summarize_handler, + grounding_merge_handler, + grounding_pull_handler, + grounding_summarize_handler, + grounding_update_handler, +) + +# Image annotation handlers +from workers.handlers.image_annotator import ( + image_annotate_handler, +) + # Integration handlers from workers.handlers.integration import ( bugsync_handler, - notification_fanout_handler, mention_notification_handler, + notification_fanout_handler, project_chat_mention_notification_handler, ) @@ -55,17 +66,6 @@ project_chat_respond_handler, ) -# Image annotation handlers -from workers.handlers.image_annotator import ( - image_annotate_handler, -) - -# Code exploration handlers -from workers.handlers.code_explorer import ( - code_explorer_explore_handler, - code_explorer_grounding_generate_handler, -) - # Web search handlers from workers.handlers.web_search import ( web_search_execute_handler, diff --git a/backend/workers/handlers/brainstorming.py b/backend/workers/handlers/brainstorming.py index b6174b7..c698fd6 100644 --- a/backend/workers/handlers/brainstorming.py +++ b/backend/workers/handlers/brainstorming.py @@ -10,12 +10,12 @@ from sqlalchemy.orm import Session -from app.models.job import JobType from app.agents.brainstorm_spec import JobCancelledException +from app.models.job import JobType from workers.core.helpers import ( - publish_job_to_kafka, broadcast_pending_count_update, clear_phase_generation_flag, + publish_job_to_kafka, ) logger = logging.getLogger(__name__) @@ -114,15 +114,11 @@ async def brainstorm_conversation_generate_handler(payload: dict, db: Session) - project_id = UUID(project_id_str) brainstorming_phase_id = UUID(brainstorming_phase_id_str) job_id = UUID(job_id_str) if job_id_str else None - created_by_user_id = ( - UUID(created_by_user_id_str) if created_by_user_id_str else None - ) + created_by_user_id = UUID(created_by_user_id_str) if created_by_user_id_str else None except ValueError as e: raise ValueError(f"Invalid UUID format: {e}") - logger.info( - f"Generating brainstorm conversations for phase {brainstorming_phase_id} using 6-agent pipeline" - ) + logger.info(f"Generating brainstorm conversations for phase {brainstorming_phase_id} using 6-agent pipeline") # Update job status to running if job_id: @@ -139,9 +135,7 @@ def progress_callback(progress_data: dict): # Update job's result field with progress (triggers WebSocket broadcast) if job_id: try: - JobService.update_job_status( - db, job_id, JobStatus.RUNNING, result={"progress": latest_progress} - ) + JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}) logger.info( f"Progress update: {progress_data.get('workflow_step')} - " @@ -184,13 +178,17 @@ def progress_callback(progress_data: dict): # After initial generation, immediately queue a second "proactive" run # This generates PENDING questions for the user to review - if is_initial_generation and result.get('aspects_created', 0) > 0: + if is_initial_generation and result.get("aspects_created", 0) > 0: from app.models.brainstorming_phase import BrainstormingPhase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id, - BrainstormingPhase.archived_at.is_(None), - ).first() + phase = ( + 
db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == brainstorming_phase_id, + BrainstormingPhase.archived_at.is_(None), + ) + .first() + ) if phase: # Create a batch generation job for the second run @@ -227,10 +225,15 @@ def progress_callback(progress_data: dict): # Check for rerun request (set when a new trigger came in while this job was running) from app.models.brainstorming_phase import BrainstormingPhase as BP - phase = db.query(BP).filter( - BP.id == brainstorming_phase_id, - BP.archived_at.is_(None), - ).first() + + phase = ( + db.query(BP) + .filter( + BP.id == brainstorming_phase_id, + BP.archived_at.is_(None), + ) + .first() + ) if phase and phase.conversation_rerun_requested: phase.conversation_rerun_requested = False db.commit() @@ -269,13 +272,9 @@ def progress_callback(progress_data: dict): return result except Exception as e: - logger.error( - f"Brainstorm conversation generation failed for phase {brainstorming_phase_id}: {e}" - ) + logger.error(f"Brainstorm conversation generation failed for phase {brainstorming_phase_id}: {e}") if job_id: - JobService.update_job_status( - db, job_id, JobStatus.FAILED, result={"error": str(e)} - ) + JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)}) raise finally: @@ -323,10 +322,10 @@ async def brainstorm_conversation_batch_generate_handler(payload: dict, db: Sess Raises: ValueError: If phase not found, at limit, or LLM config missing """ - from app.services.brainstorming_phase_service import BrainstormingPhaseService - from app.services.job_service import JobService from app.models.feature import FeatureVisibilityStatus from app.models.job import JobStatus + from app.services.brainstorming_phase_service import BrainstormingPhaseService + from app.services.job_service import JobService project_id_str = payload.get("project_id") brainstorming_phase_id_str = payload.get("brainstorming_phase_id") @@ -353,10 +352,15 @@ async def brainstorm_conversation_batch_generate_handler(payload: dict, db: Sess # Load the phase to get creator ID (for created_by field on modules/features) from app.models.brainstorming_phase import BrainstormingPhase - phase = db.query(BrainstormingPhase).filter( - BrainstormingPhase.id == brainstorming_phase_id, - BrainstormingPhase.archived_at.is_(None), - ).first() + + phase = ( + db.query(BrainstormingPhase) + .filter( + BrainstormingPhase.id == brainstorming_phase_id, + BrainstormingPhase.archived_at.is_(None), + ) + .first() + ) if not phase: raise ValueError(f"Phase {brainstorming_phase_id} not found or is archived") @@ -392,12 +396,7 @@ def progress_callback(progress_data: dict): # Update job's result field with progress (triggers WebSocket broadcast) if job_id: try: - JobService.update_job_status( - db, - job_id, - JobStatus.RUNNING, - result={"progress": latest_progress} - ) + JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}) logger.info( f"Progress update: {progress_data.get('workflow_step')} - " f"{progress_data.get('progress_percentage')}% - " @@ -477,12 +476,7 @@ def progress_callback(progress_data: dict): except Exception as e: logger.error(f"Proactive batch generation failed for phase {brainstorming_phase_id}: {e}") if job_id: - JobService.update_job_status( - db, - job_id, - JobStatus.FAILED, - result={"error": str(e)} - ) + JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)}) raise finally: @@ -549,15 +543,11 @@ async def brainstorm_spec_generate_handler(payload: dict, db: 
Session) -> dict: project_id = UUID(project_id_str) brainstorming_phase_id = UUID(brainstorming_phase_id_str) job_id = UUID(job_id_str) if job_id_str else None - created_by_user_id = ( - UUID(created_by_user_id_str) if created_by_user_id_str else None - ) + created_by_user_id = UUID(created_by_user_id_str) if created_by_user_id_str else None except ValueError as e: raise ValueError(f"Invalid UUID format: {e}") - logger.info( - f"Generating brainstorm spec for phase {brainstorming_phase_id} using 4-agent pipeline" - ) + logger.info(f"Generating brainstorm spec for phase {brainstorming_phase_id} using 4-agent pipeline") # Update job status to running if job_id: @@ -574,9 +564,7 @@ def progress_callback(progress_data: dict): # Update job's result field with progress (triggers WebSocket broadcast) if job_id: try: - JobService.update_job_status( - db, job_id, JobStatus.RUNNING, result={"progress": latest_progress} - ) + JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}) logger.info( f"Progress update: {progress_data.get('workflow_step')} - " @@ -624,24 +612,18 @@ def progress_callback(progress_data: dict): return result except JobCancelledException: - logger.info( - f"Brainstorm spec generation cancelled for phase {brainstorming_phase_id}" - ) + logger.info(f"Brainstorm spec generation cancelled for phase {brainstorming_phase_id}") return { "cancelled": True, "brainstorming_phase_id": str(brainstorming_phase_id), } except Exception as e: - logger.error( - f"Brainstorm spec generation failed for phase {brainstorming_phase_id}: {e}" - ) + logger.error(f"Brainstorm spec generation failed for phase {brainstorming_phase_id}: {e}") # Rollback session in case of IntegrityError or other DB errors db.rollback() if job_id: - JobService.update_job_status( - db, job_id, JobStatus.FAILED, result={"error": str(e)} - ) + JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)}) raise finally: @@ -706,15 +688,11 @@ async def brainstorm_prompt_plan_generate_handler(payload: dict, db: Session) -> project_id = UUID(project_id_str) brainstorming_phase_id = UUID(brainstorming_phase_id_str) job_id = UUID(job_id_str) if job_id_str else None - created_by_user_id = ( - UUID(created_by_user_id_str) if created_by_user_id_str else None - ) + created_by_user_id = UUID(created_by_user_id_str) if created_by_user_id_str else None except ValueError as e: raise ValueError(f"Invalid UUID format: {e}") - logger.info( - f"Generating brainstorm prompt plan for phase {brainstorming_phase_id} using 4-agent pipeline" - ) + logger.info(f"Generating brainstorm prompt plan for phase {brainstorming_phase_id} using 4-agent pipeline") # Update job status to running if job_id: @@ -731,9 +709,7 @@ def progress_callback(progress_data: dict): # Update job's result field with progress (triggers WebSocket broadcast) if job_id: try: - JobService.update_job_status( - db, job_id, JobStatus.RUNNING, result={"progress": latest_progress} - ) + JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}) logger.info( f"Progress update: {progress_data.get('workflow_step')} - " f"{progress_data.get('progress_percentage')}% - " @@ -780,22 +756,16 @@ def progress_callback(progress_data: dict): return result except JobCancelledException: - logger.info( - f"Brainstorm prompt plan generation cancelled for phase {brainstorming_phase_id}" - ) + logger.info(f"Brainstorm prompt plan generation cancelled for phase {brainstorming_phase_id}") return { 
"cancelled": True, "brainstorming_phase_id": str(brainstorming_phase_id), } except Exception as e: - logger.error( - f"Brainstorm prompt plan generation failed for phase {brainstorming_phase_id}: {e}" - ) + logger.error(f"Brainstorm prompt plan generation failed for phase {brainstorming_phase_id}: {e}") if job_id: - JobService.update_job_status( - db, job_id, JobStatus.FAILED, result={"error": str(e)} - ) + JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)}) raise finally: diff --git a/backend/workers/handlers/code_explorer.py b/backend/workers/handlers/code_explorer.py index f01a72b..f4ef63e 100644 --- a/backend/workers/handlers/code_explorer.py +++ b/backend/workers/handlers/code_explorer.py @@ -12,18 +12,17 @@ from app.models import ( CodeExplorationResult, + IntegrationConfig, Job, - JobStatus, JobType, PlatformConnector, PlatformSettings, Project, - IntegrationConfig, ) from app.models.llm_call_log import LLMCallLog from app.services.code_explorer_client import code_explorer_client -from app.services.platform_settings_service import PlatformSettingsService from app.services.kafka_producer import SyncKafkaProducer +from app.services.platform_settings_service import PlatformSettingsService logger = logging.getLogger(__name__) @@ -31,9 +30,7 @@ CODE_EXPLORER_MODEL = "claude-sonnet-4-5" -def calculate_claude_cost( - model: str, prompt_tokens: int, completion_tokens: int -) -> Optional[float]: +def calculate_claude_cost(model: str, prompt_tokens: int, completion_tokens: int) -> Optional[float]: """Calculate cost for Claude API usage using LiteLLM's pricing.""" try: mock_response = ModelResponse( @@ -58,11 +55,7 @@ def get_code_explorer_api_key(db: Session) -> str | None: if not settings or not settings.code_explorer_connector_id: return None - connector = ( - db.query(PlatformConnector) - .filter(PlatformConnector.id == settings.code_explorer_connector_id) - .first() - ) + connector = db.query(PlatformConnector).filter(PlatformConnector.id == settings.code_explorer_connector_id).first() if not connector or not connector.encrypted_credentials: return None @@ -89,11 +82,7 @@ async def get_github_token_for_org( # First check if a specific GitHub connector is configured if github_integration_config_id: - config = ( - db.query(IntegrationConfig) - .filter(IntegrationConfig.id == github_integration_config_id) - .first() - ) + config = db.query(IntegrationConfig).filter(IntegrationConfig.id == github_integration_config_id).first() if not config: # Fallback: use first available GitHub connector for the org @@ -201,32 +190,28 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: # Get Anthropic API key from Platform Settings anthropic_key = get_code_explorer_api_key(db) if not anthropic_key: - raise ValueError( - "Code Explorer Anthropic API key not configured in Platform Settings" - ) + raise ValueError("Code Explorer Anthropic API key not configured in Platform Settings") # Build repos list from ALL project repositories # This enables multi-repo exploration where Claude can see all connected repos all_repos = project.repositories if not all_repos: - raise ValueError( - "No repository configured. Set a GitHub repository in Project Settings." - ) + raise ValueError("No repository configured. 
Set a GitHub repository in Project Settings.") repos = [] for repo in all_repos: github_token = None if repo.github_integration_config_id: - github_token = await get_github_token_for_org( - db, project.org_id, repo.github_integration_config_id - ) - repos.append({ - "slug": repo.slug, - "repo_url": repo.repo_url, - "branch": repo.default_branch or "main", - "github_token": github_token, - "user_remarks": repo.user_remarks, - }) + github_token = await get_github_token_for_org(db, project.org_id, repo.github_integration_config_id) + repos.append( + { + "slug": repo.slug, + "repo_url": repo.repo_url, + "branch": repo.default_branch or "main", + "github_token": github_token, + "user_remarks": repo.user_remarks, + } + ) logger.info(f"Exploring {len(repos)} repos: {[r['slug'] for r in repos]}") @@ -262,9 +247,7 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: prompt_tokens=result.get("prompt_tokens"), completion_tokens=result.get("completion_tokens"), execution_time_seconds=( - Decimal(str(result["execution_time_seconds"])) - if result.get("execution_time_seconds") - else None + Decimal(str(result["execution_time_seconds"])) if result.get("execution_time_seconds") else None ), ) db.add(exploration) @@ -297,7 +280,11 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: created_at=finished_at, ) db.add(llm_call_log) - logger.info(f"Created LLMCallLog for code exploration job {job.id}, cost=${cost_usd:.6f}" if cost_usd else f"Created LLMCallLog for code exploration job {job.id}") + logger.info( + f"Created LLMCallLog for code exploration job {job.id}, cost=${cost_usd:.6f}" + if cost_usd + else f"Created LLMCallLog for code exploration job {job.id}" + ) db.commit() db.refresh(exploration) @@ -312,9 +299,7 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: from app.models.project_chat import ProjectChat from app.services.project_chat_service import ProjectChatService - discussion = db.query(ProjectChat).filter( - ProjectChat.id == UUID(project_chat_id) - ).first() + discussion = db.query(ProjectChat).filter(ProjectChat.id == UUID(project_chat_id)).first() if discussion: discussion.is_exploring_code = False discussion.exploring_code_prompt = None @@ -344,7 +329,7 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: # Only trigger follow-up if we have actual output if has_output: - logger.info(f"Triggering follow-up project-chat discussion for exploration results") + logger.info("Triggering follow-up project-chat discussion for exploration results") # Create a CODE_EXPLORATION message to show the results in the conversation from app.models.project_chat import ProjectChatMessage, ProjectChatMessageType @@ -358,7 +343,9 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: "prompt": exploration.prompt, "branch": exploration.branch, "repo_url": exploration.repo_url, - "execution_time_seconds": float(exploration.execution_time_seconds) if exploration.execution_time_seconds else None, + "execution_time_seconds": float(exploration.execution_time_seconds) + if exploration.execution_time_seconds + else None, "prompt_tokens": exploration.prompt_tokens, "completion_tokens": exploration.completion_tokens, }, @@ -450,13 +437,13 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict: if has_output: # Check if job was cancelled while exploration was running from app.services.job_service import JobService as JS + if job and JS.is_job_cancelled(db, job.id): 
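[Editor's note: two more autofix behaviors are visible in this file. First, `repos.append({...})` here, and the `json.dumps({...})` calls in the test hunks above, are re-indented because ruff format, like Black, does not "hug" a sole dict argument: the dict is exploded onto its own indentation level inside the call, and the magic trailing comma keeps it multi-line. Second, `logger.info(f"Triggering follow-up ...")` loses its `f` prefix via Ruff rule F541, which `--fix` applies to f-strings with no placeholders; the `path = f"/phases/..."` constants in `test_vfs_sed.py` got the same fix. A minimal sketch of both, with placeholder values:]

```python
import json
import logging

logger = logging.getLogger(__name__)

# Sole dict argument: the trailing comma keeps the dict exploded, and the
# formatter indents it inside the call rather than hugging the parentheses.
payload = json.dumps(
    {
        "action": "add_comment",
        "body_markdown": "test",
    }
)

# F541: no placeholders, so `ruff check --fix` strips the `f` prefix.
logger.info("Triggering follow-up project-chat discussion for exploration results")
```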
                 logger.info(
-                    f"Job {job.id} was cancelled during exploration, "
-                    f"skipping follow-up for thread {thread_id}"
+                    f"Job {job.id} was cancelled during exploration, skipping follow-up for thread {thread_id}"
                 )
             else:
-                logger.info(f"Triggering follow-up MFBTAI for thread exploration results")
+                logger.info("Triggering follow-up MFBTAI for thread exploration results")
 
                 # Create CODE_EXPLORATION thread item to show results in conversation
                 ThreadService.create_code_exploration_item(
@@ -569,9 +556,7 @@ async def code_explorer_explore_handler(payload: dict, db: Session) -> dict:
                 from app.models.project_chat import ProjectChat
                 from app.services.project_chat_service import ProjectChatService
 
-                discussion = db.query(ProjectChat).filter(
-                    ProjectChat.id == UUID(project_chat_id)
-                ).first()
+                discussion = db.query(ProjectChat).filter(ProjectChat.id == UUID(project_chat_id)).first()
                 if discussion:
                     discussion.is_exploring_code = False
                     discussion.exploring_code_prompt = None
@@ -653,15 +638,11 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
     # Get all repositories for the project
     all_repos = project.repositories
     if not all_repos:
-        raise ValueError(
-            "Project must have at least one GitHub repository configured for grounding generation"
-        )
+        raise ValueError("Project must have at least one GitHub repository configured for grounding generation")
 
     primary_repo = project.get_primary_repository()
     if not primary_repo or not primary_repo.github_integration_config_id:
-        raise ValueError(
-            "Primary repository must have a GitHub connector configured"
-        )
+        raise ValueError("Primary repository must have a GitHub connector configured")
 
     logger.info(f"Starting grounding generation for project {project_id} with {len(all_repos)} repos")
 
@@ -681,25 +662,23 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
     # Get Anthropic API key from Platform Settings
     anthropic_key = get_code_explorer_api_key(db)
     if not anthropic_key:
-        raise ValueError(
-            "Code Explorer Anthropic API key not configured in Platform Settings"
-        )
+        raise ValueError("Code Explorer Anthropic API key not configured in Platform Settings")
 
     # Build repos list with GitHub tokens and user_remarks for all repositories
     repos = []
     for repo in all_repos:
         github_token = None
         if repo.github_integration_config_id:
-            github_token = await get_github_token_for_org(
-                db, project.org_id, repo.github_integration_config_id
-            )
-        repos.append({
-            "slug": repo.slug,
-            "repo_url": repo.repo_url,
-            "branch": repo.default_branch or "main",
-            "github_token": github_token,
-            "user_remarks": repo.user_remarks,
-        })
+            github_token = await get_github_token_for_org(db, project.org_id, repo.github_integration_config_id)
+        repos.append(
+            {
+                "slug": repo.slug,
+                "repo_url": repo.repo_url,
+                "branch": repo.default_branch or "main",
+                "github_token": github_token,
+                "user_remarks": repo.user_remarks,
+            }
+        )
 
     # Call code-explorer service in grounding_generate mode
     result = await code_explorer_client.explore(
@@ -744,7 +723,8 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
         db.add(llm_call_log)
         logger.info(
             f"Created LLMCallLog for grounding generation job {job.id}, cost=${cost_usd:.6f}"
-            if cost_usd else f"Created LLMCallLog for grounding generation job {job.id}"
+            if cost_usd
+            else f"Created LLMCallLog for grounding generation job {job.id}"
         )
 
     # Initialize summary_usage before conditional block
@@ -752,13 +732,12 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
     if has_output:
         # Update agents.md with generated content
-        from app.services.grounding_service import GroundingService
-
         # Generate summary using grounding orchestrator for consistency
         from app.agents.grounding import create_orchestrator
         from app.agents.llm_client import LLMCallLogger
-        from app.services.platform_settings_service import require_llm_config_sync
         from app.database import SessionLocal
+        from app.services.grounding_service import GroundingService
+        from app.services.platform_settings_service import require_llm_config_sync
 
         summary = None
         try:
@@ -770,9 +749,7 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
 
             # Create LLM call logger for tracking
             llm_call_logger = (
-                LLMCallLogger(db_session_factory=SessionLocal, job_id=UUID(job_id))
-                if job_id
-                else None
+                LLMCallLogger(db_session_factory=SessionLocal, job_id=UUID(job_id)) if job_id else None
             )
 
             # Create orchestrator and generate summary
@@ -785,7 +762,11 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
             )
 
             summary = await orchestrator.summarize_content(result["output"])
-            logger.info(f"Generated summary for agents.md: {summary[:100]}..." if len(summary) > 100 else f"Generated summary for agents.md: {summary}")
+            logger.info(
+                f"Generated summary for agents.md: {summary[:100]}..."
+                if len(summary) > 100
+                else f"Generated summary for agents.md: {summary}"
+            )
 
             # Capture usage stats from the summarization call
             summary_usage = orchestrator.model_client.get_usage_stats()
@@ -808,9 +789,7 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
         )
 
         # Broadcast grounding update via WebSocket
-        GroundingService._broadcast_grounding_update(
-            db, project_id, grounding_file, "generated"
-        )
+        GroundingService._broadcast_grounding_update(db, project_id, grounding_file, "generated")
 
         logger.info(f"Updated agents.md for project {project_id} with generated grounding content")
     else:
@@ -852,14 +831,10 @@ async def code_explorer_grounding_generate_handler(payload: dict, db: Session) -> dict:
         try:
             from app.services.grounding_service import GroundingService
 
-            grounding_file = GroundingService.clear_generating_flag(
-                db, project_id, "agents.md"
-            )
+            grounding_file = GroundingService.clear_generating_flag(db, project_id, "agents.md")
             if grounding_file:
                 # Broadcast update to notify UI that generation finished
-                GroundingService._broadcast_grounding_update(
-                    db, project_id, grounding_file, "generating_finished"
-                )
+                GroundingService._broadcast_grounding_update(db, project_id, grounding_file, "generating_finished")
                 logger.info(f"Cleared is_generating flag for project {project_id}")
         except Exception as cleanup_error:
             logger.error(f"Failed to clear is_generating flag: {cleanup_error}")
diff --git a/backend/workers/handlers/collaboration.py b/backend/workers/handlers/collaboration.py
index 7f3b878..1daee23 100644
--- a/backend/workers/handlers/collaboration.py
+++ b/backend/workers/handlers/collaboration.py
@@ -17,8 +17,8 @@ from workers.core.helpers import publish_job_to_kafka
 
 if TYPE_CHECKING:
-    from app.models.thread import Thread
     from app.models.project import Project
+    from app.models.thread import Thread
 
 logger = logging.getLogger(__name__)
@@ -64,11 +64,11 @@ async def collab_thread_ai_mention_handler(payload: dict, db: Session) -> dict:
     """
     from app.agents.collab_thread_assistant import handle_ai_mention
     from app.agents.collab_thread_assistant.spec_draft_handler import handle_spec_draft_ai_mention
-    from app.services.job_service import JobService
-    from app.services.thread_service import ThreadService
-    from app.services.agent_utils import get_or_create_agent_user
     from app.models.job import JobStatus
     from app.models.thread_item import ThreadItem
+    from app.services.agent_utils import get_or_create_agent_user
+    from app.services.job_service import JobService
+    from app.services.thread_service import ThreadService
 
     thread_id = payload.get("thread_id")
     message_text = payload.get("message_text")
@@ -105,6 +105,7 @@ async def collab_thread_ai_mention_handler(payload: dict, db: Session) -> dict:
 
     # Get thread early for retry callbacks
     from app.models.thread import Thread
+
     thread = db.query(Thread).filter(Thread.id == thread_id).first()
     if not thread:
         raise ValueError(f"Thread {thread_id} not found")
@@ -130,10 +131,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
                 job_result["block_id"] = block_id
             else:
                 job_result["feature_id"] = feature_id
-            JobService.update_job_status(
-                db, job_id, JobStatus.RUNNING,
-                result=job_result
-            )
+            JobService.update_job_status(db, job_id, JobStatus.RUNNING, result=job_result)
 
     # Check if this is an MCQ-triggered mention
     trigger_type = payload.get("trigger_type", "mention")
@@ -227,13 +225,12 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
         # Check if job was cancelled while LLM was generating
         if job_id and JobService.is_job_cancelled(db, job_id):
             logger.info(
-                f"Job {job_id} was cancelled during LLM generation, "
-                f"skipping message creation for thread {thread_id}"
+                f"Job {job_id} was cancelled during LLM generation, skipping message creation for thread {thread_id}"
             )
             # Extract LLM usage before returning
             llm_usage = {}
             model_client = result.get("model_client")
-            if model_client and hasattr(model_client, 'get_usage_stats'):
+            if model_client and hasattr(model_client, "get_usage_stats"):
                 usage_stats = model_client.get_usage_stats()
                 llm_usage = {
                     "model": usage_stats.get("model"),
@@ -249,9 +246,9 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
 
         # Check for code exploration request first
         from app.agents.collab_thread_assistant.exploration_parser import (
+            has_exploration_block,
             parse_exploration_request,
             strip_exploration_block,
-            has_exploration_block,
         )
 
         reply_text = result["reply_text"]
@@ -295,7 +292,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
                 # Extract LLM usage stats
                 llm_usage = None
                 model_client = result.get("model_client")
-                if model_client and hasattr(model_client, 'get_usage_stats'):
+                if model_client and hasattr(model_client, "get_usage_stats"):
                     usage_stats = model_client.get_usage_stats()
                     llm_usage = {
                         "model": usage_stats.get("model"),
@@ -312,17 +309,21 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
                 if llm_usage:
                     result_dict["_llm_usage"] = llm_usage
 
-                logger.info(f"Code exploration triggered for thread {thread_id}, prompt: {exploration_req.prompt[:50]}...")
+                logger.info(
+                    f"Code exploration triggered for thread {thread_id}, prompt: {exploration_req.prompt[:50]}..."
+                )
 
                 triggered_async_followup = True
                 return result_dict
             else:
-                logger.warning(f"Project {thread.project_id} has no GitHub repo configured, skipping exploration")
+                logger.warning(
+                    f"Project {thread.project_id} has no GitHub repo configured, skipping exploration"
+                )
 
         # Check for web search request
         from app.agents.collab_thread_assistant.web_search_parser import (
+            has_web_search_block,
             parse_web_search_request,
             strip_web_search_block,
-            has_web_search_block,
         )
 
         if has_web_search_block(reply_text):
@@ -354,7 +355,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
                 # Extract LLM usage stats
                 llm_usage = None
                 model_client = result.get("model_client")
-                if model_client and hasattr(model_client, 'get_usage_stats'):
+                if model_client and hasattr(model_client, "get_usage_stats"):
                     usage_stats = model_client.get_usage_stats()
                     llm_usage = {
                         "model": usage_stats.get("model"),
@@ -410,6 +411,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
                     # Store mcq_depth for loop prevention
                     mcq_item.content_data["mcq_depth"] = mcq_depth
                     from sqlalchemy.orm.attributes import flag_modified
+
                     flag_modified(mcq_item, "content_data")
                     db.commit()
                 created_items.append(mcq_item)
@@ -442,7 +444,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
         # Extract LLM usage stats from model client
         llm_usage = None
         model_client = result.get("model_client")
-        if model_client and hasattr(model_client, 'get_usage_stats'):
+        if model_client and hasattr(model_client, "get_usage_stats"):
             usage_stats = model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -483,11 +485,7 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
 
         if job_id:
             JobService.update_job_status(
-                db,
-                job_id,
-                JobStatus.FAILED,
-                error_message=str(e),
-                result={"error": str(e), "thread_id": thread_id}
+                db, job_id, JobStatus.FAILED, error_message=str(e), result={"error": str(e), "thread_id": thread_id}
             )
         raise
@@ -496,11 +494,13 @@ async def _on_assistant_retry(attempt: int, max_attempts: int) -> None:
         # or web search - in those cases the follow-up job will clear the flag when complete
         if not triggered_async_followup:
             from workers.core.helpers import clear_thread_generation_flag
+
             clear_thread_generation_flag(db, thread_id, "is_generating_ai_response")
 
         # Always clear retry_status on completion (success or failure)
         # Re-query the thread to avoid stale object issues after multiple commits
         from app.models.thread import Thread
+
         fresh_thread = db.query(Thread).filter(Thread.id == thread_id).first()
         if fresh_thread and fresh_thread.retry_status is not None:
             fresh_thread.retry_status = None
@@ -526,6 +526,7 @@ def _check_simple_single_mcq_case(thread: "Thread", db: Session) -> dict | None:
         Dict with summary info if simple case, None if LLM needed
     """
     from sqlalchemy.orm import joinedload
+
     from app.models.thread import Thread
     from app.models.thread_item import ThreadItem, ThreadItemType
     from app.services.agent_utils import AGENT_EMAIL
@@ -551,7 +552,7 @@ def _check_simple_single_mcq_case(thread: "Thread", db: Session) -> dict | None:
     for item in thread_with_items.items:
         item_type = item.item_type
         # Handle both string and enum values
-        if hasattr(item_type, 'value'):
+        if hasattr(item_type, "value"):
             item_type = item_type.value
 
         if item_type == ThreadItemType.MCQ_FOLLOWUP.value:
@@ -647,16 +648,14 @@ async def collab_thread_decision_summarize_handler(payload: dict, db: Session) -> dict:
         ValueError: If thread not found or LLM config missing
     """
     from app.agents.collab_thread_decision_summarizer import create_orchestrator
+    from app.models.feature import Feature, FeatureType
+    from app.models.job import Job, JobStatus
+    from app.models.module import Module
+    from app.models.thread import ContextType, Thread
+    from app.models.thread_item import ThreadItem
     from app.services.job_service import JobService
     from app.services.platform_settings_service import require_llm_config_sync
     from app.services.thread_service import ThreadService
-    from app.agents.llm_client import LLMCallLogger
-    from app.database import SessionLocal
-    from app.models.job import Job, JobStatus
-    from app.models.thread import Thread, ContextType
-    from app.models.thread_item import ThreadItem
-    from app.models.feature import Feature, FeatureType
-    from app.models.module import Module
 
     thread_id = payload.get("thread_id")
     job_id_str = payload.get("job_id")
@@ -701,13 +700,15 @@ async def collab_thread_decision_summarize_handler(payload: dict, db: Session) -> dict:
             thread.decision_summary = decision_summary
             thread.decision_summary_short = decision_summary_short
             thread.unresolved_points = []
-            thread.last_summarized_item_id = simple_mcq['mcq_item_id']
-            thread.suggested_implementation_name = simple_mcq.get('suggested_implementation_name')
+            thread.last_summarized_item_id = simple_mcq["mcq_item_id"]
+            thread.suggested_implementation_name = simple_mcq.get("suggested_implementation_name")
             # Evaluate button visibility after summarization
-            thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button(db, thread_id)
+            thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button(
+                db, thread_id
+            )
 
             # Store snapshot on the MCQ item for efficient deletion handling
-            mcq_item = db.query(ThreadItem).filter(ThreadItem.id == simple_mcq['mcq_item_id']).first()
+            mcq_item = db.query(ThreadItem).filter(ThreadItem.id == simple_mcq["mcq_item_id"]).first()
             if mcq_item:
                 mcq_item.summary_snapshot = decision_summary
                 mcq_item.summary_short_snapshot = decision_summary_short
@@ -744,19 +745,27 @@ class SimpleSummaryResult:
                     if module and module.brainstorming_phase_id:
                         from app.models.brainstorming_phase import BrainstormingPhase
 
-                        phase = db.query(BrainstormingPhase).filter(
-                            BrainstormingPhase.id == module.brainstorming_phase_id
-                        ).first()
+                        phase = (
+                            db.query(BrainstormingPhase)
+                            .filter(BrainstormingPhase.id == module.brainstorming_phase_id)
+                            .first()
+                        )
 
                         if phase:
-                            existing_job = db.query(Job).filter(
-                                Job.project_id == phase.project_id,
-                                Job.job_type.in_([
-                                    JobType.BRAINSTORM_CONVERSATION_GENERATE,
-                                    JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
-                                ]),
-                                Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]),
-                            ).first()
+                            existing_job = (
+                                db.query(Job)
+                                .filter(
+                                    Job.project_id == phase.project_id,
+                                    Job.job_type.in_(
+                                        [
+                                            JobType.BRAINSTORM_CONVERSATION_GENERATE,
+                                            JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
+                                        ]
+                                    ),
+                                    Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]),
+                                )
+                                .first()
+                            )
 
                             if existing_job:
                                 # Set rerun flag instead of skipping
@@ -796,7 +805,9 @@ class SimpleSummaryResult:
                                     f"(simple_mcq=True, first={is_first_summarization})"
                                 )
                             else:
-                                logger.warning(f"Failed to publish DECISIONS_RESOLVED job for phase {module.brainstorming_phase_id}")
+                                logger.warning(
+                                    f"Failed to publish DECISIONS_RESOLVED job for phase {module.brainstorming_phase_id}"
+                                )
                 except Exception as e:
                     logger.warning(f"Failed to queue conversation generation: {e}")
@@ -833,6 +844,7 @@ class SimpleSummaryResult:
             # Clear generation flag on completion (simple_mcq path)
             from workers.core.helpers import clear_thread_generation_flag
+
             clear_thread_generation_flag(db, thread_id, "is_generating_decision_summary")
 
             return result_dict
@@ -855,10 +867,7 @@ def progress_callback(progress_data: dict):
         if job_id:
             try:
                 JobService.update_job_status(
-                    db,
-                    job_id,
-                    JobStatus.RUNNING,
-                    result={"progress": latest_progress, "thread_id": thread_id}
+                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress, "thread_id": thread_id}
                 )
                 logger.info(
                     f"Progress update: {progress_data.get('workflow_step')} - "
@@ -895,16 +904,15 @@ async def _on_retry(attempt: int, max_attempts: int) -> None:
         db.commit()
 
         # Evaluate and set button visibility after summarization
-        thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button(db, thread_id)
+        thread.show_create_implementation_button = ThreadService._should_show_create_implementation_button(
+            db, thread_id
+        )
         db.commit()
 
         # Check if we should create a proactive conversation trigger
         # Trigger if: (1) first summarization OR (2) unresolved count decreased
         new_unresolved_count = len(result.final_unresolved_points)
-        should_trigger = (
-            is_first_summarization or
-            new_unresolved_count < previous_unresolved_count
-        )
+        should_trigger = is_first_summarization or new_unresolved_count < previous_unresolved_count
 
         if should_trigger:
             # Only trigger for BRAINSTORM_FEATURE threads
@@ -921,20 +929,28 @@ async def _on_retry(attempt: int, max_attempts: int) -> None:
                 if module and module.brainstorming_phase_id:
                     from app.models.brainstorming_phase import BrainstormingPhase
 
-                    phase = db.query(BrainstormingPhase).filter(
-                        BrainstormingPhase.id == module.brainstorming_phase_id
-                    ).first()
+                    phase = (
+                        db.query(BrainstormingPhase)
+                        .filter(BrainstormingPhase.id == module.brainstorming_phase_id)
+                        .first()
+                    )
 
                     if phase:
                         # Check for existing running job for this phase
-                        existing_job = db.query(Job).filter(
-                            Job.project_id == phase.project_id,
-                            Job.job_type.in_([
-                                JobType.BRAINSTORM_CONVERSATION_GENERATE,
-                                JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
-                            ]),
-                            Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]),
-                        ).first()
+                        existing_job = (
+                            db.query(Job)
+                            .filter(
+                                Job.project_id == phase.project_id,
+                                Job.job_type.in_(
+                                    [
+                                        JobType.BRAINSTORM_CONVERSATION_GENERATE,
+                                        JobType.BRAINSTORM_CONVERSATION_BATCH_GENERATE,
+                                    ]
+                                ),
+                                Job.status.in_([JobStatus.QUEUED, JobStatus.RUNNING]),
+                            )
+                            .first()
+                        )
 
                         if existing_job:
                             # Set rerun flag instead of skipping
@@ -975,15 +991,15 @@ async def _on_retry(attempt: int, max_attempts: int) -> None:
                                 f"(first={is_first_summarization}, unresolved: {previous_unresolved_count} -> {new_unresolved_count})"
                             )
                         else:
-                            logger.warning(f"Failed to publish DECISIONS_RESOLVED job for phase {module.brainstorming_phase_id}")
+                            logger.warning(
+                                f"Failed to publish DECISIONS_RESOLVED job for phase {module.brainstorming_phase_id}"
+                            )
             except Exception as e:
                 logger.warning(f"Failed to queue conversation generation: {e}")
 
         # Extract LLM usage stats
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -1026,14 +1042,13 @@ async def _on_retry(attempt: int, max_attempts: int) -> None:
     except Exception as e:
         logger.error(f"Decision summary failed for thread {thread_id}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
 
     finally:
         # Always clear generation flag when job completes (success or failure)
         from workers.core.helpers import clear_thread_generation_flag
+
         clear_thread_generation_flag(db, thread_id, "is_generating_decision_summary")
 
         # Always clear retry_status on completion (success or failure)
@@ -1069,7 +1084,6 @@ async def _trigger_thread_code_exploration(
         original_message: The original user message (for retry)
     """
     from app.models.platform_settings import PlatformSettings
-    from app.models.job import Job
     from app.services.job_service import JobService
 
     if not exploration_prompt:
@@ -1130,8 +1144,7 @@ async def _trigger_thread_code_exploration(
 
     if success:
         logger.info(
-            f"Triggered code exploration for thread {thread_id}: "
-            f"job_id={job.id}, prompt={exploration_prompt[:50]}..."
+            f"Triggered code exploration for thread {thread_id}: job_id={job.id}, prompt={exploration_prompt[:50]}..."
         )
     else:
         logger.error(f"Failed to publish code exploration job for thread {thread_id}")
@@ -1161,8 +1174,8 @@ async def _trigger_thread_web_search(
         user_id: User who triggered the mention
         original_message: The original user message (for followup)
     """
-    from app.models.thread import Thread
     from app.models.project import Project
+    from app.models.thread import Thread
     from app.services.job_service import JobService
     from app.services.platform_settings_service import is_web_search_available_sync
 
@@ -1219,10 +1232,7 @@ async def _trigger_thread_web_search(
     )
 
     if success:
-        logger.info(
-            f"Triggered web search for thread {thread_id}: "
-            f"job_id={job.id}, query={search_query[:50]}..."
-        )
+        logger.info(f"Triggered web search for thread {thread_id}: job_id={job.id}, query={search_query[:50]}...")
     else:
         logger.error(f"Failed to publish web search job for thread {thread_id}")
         # Clear searching state if job failed to publish
diff --git a/backend/workers/handlers/generation.py b/backend/workers/handlers/generation.py
index 37c8ee8..61b60c7 100644
--- a/backend/workers/handlers/generation.py
+++ b/backend/workers/handlers/generation.py
@@ -10,8 +10,8 @@
 
 from sqlalchemy.orm import Session
 
-from app.models.job import JobStatus
 from app.agents.brainstorm_spec import JobCancelledException
+from app.models.job import JobStatus
 from workers.core.helpers import clear_phase_generation_flag
 
 logger = logging.getLogger(__name__)
@@ -62,15 +62,11 @@ async def module_feature_generate_handler(payload: dict, db: Session) -> dict:
         project_id = UUID(project_id_str)
         brainstorming_phase_id = UUID(brainstorming_phase_id_str)
         job_id = UUID(job_id_str) if job_id_str else None
-        created_by_user_id = (
-            UUID(created_by_user_id_str) if created_by_user_id_str else None
-        )
+        created_by_user_id = UUID(created_by_user_id_str) if created_by_user_id_str else None
     except ValueError as e:
         raise ValueError(f"Invalid UUID format: {e}")
 
-    logger.info(
-        f"Generating modules/features for phase {brainstorming_phase_id} in project {project_id}"
-    )
+    logger.info(f"Generating modules/features for phase {brainstorming_phase_id} in project {project_id}")
 
     # Update job status to running
     if job_id:
@@ -90,12 +86,8 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Update job's result field with progress
         if job_id:
             try:
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                )
-                logger.info(
-                    f"Progress update: {workflow_step} - {progress_percentage}%"
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
+                logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
             except Exception as e:
                 logger.warning(f"Failed to update job progress: {e}")
@@ -141,13 +133,9 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         return {"cancelled": True, "brainstorming_phase_id": str(brainstorming_phase_id)}
 
     except Exception as e:
-        logger.error(
-            f"Module/feature generation failed for phase {brainstorming_phase_id}: {e}"
-        )
+        logger.error(f"Module/feature generation failed for phase {brainstorming_phase_id}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
 
     finally:
@@ -187,18 +175,18 @@ async def feature_content_generate_handler(payload: dict, db: Session) -> dict:
         ValueError: If feature not found, no thread, or LLM config missing
     """
     from app.agents.feature_content import (
-        FeatureContentOrchestrator,
         ContentType,
+        FeatureContentOrchestrator,
         load_feature_context,
     )
     from app.agents.llm_client import LLMCallLogger
+    from app.database import SessionLocal
     from app.models.feature_content_version import FeatureContentType
     from app.services.feature_content_version_service import (
         FeatureContentVersionService,
     )
     from app.services.job_service import JobService
     from app.services.platform_settings_service import require_llm_config_sync
-    from app.database import SessionLocal
 
     feature_id_str = payload.get("feature_id")
     content_type_str = payload.get("content_type")
@@ -214,12 +202,8 @@ async def feature_content_generate_handler(payload: dict, db: Session) -> dict:
     try:
         feature_id = UUID(feature_id_str)
         job_id = UUID(job_id_str) if job_id_str else None
-        created_by_user_id = (
-            UUID(created_by_user_id_str) if created_by_user_id_str else None
-        )
-        implementation_id = (
-            UUID(implementation_id_str) if implementation_id_str else None
-        )
+        created_by_user_id = UUID(created_by_user_id_str) if created_by_user_id_str else None
+        implementation_id = UUID(implementation_id_str) if implementation_id_str else None
 
         # Map content type string to agent enum
         content_type_map = {
@@ -228,9 +212,7 @@ async def feature_content_generate_handler(payload: dict, db: Session) -> dict:
         }
         content_type = content_type_map.get(content_type_str)
         if not content_type:
-            raise ValueError(
-                f"Invalid content_type: {content_type_str}. Must be 'spec' or 'prompt_plan'"
-            )
+            raise ValueError(f"Invalid content_type: {content_type_str}. Must be 'spec' or 'prompt_plan'")
 
         # Also map to service enum for version creation
         service_content_type_map = {
@@ -273,9 +255,7 @@ def progress_callback(progress_data: dict):
                 # Include implementation_id if generating for a specific implementation
                 if implementation_id:
                     result_data["implementation_id"] = str(implementation_id)
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result=result_data
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result=result_data)
                 logger.info(
                     f"Progress update: {progress_data.get('workflow_step')} - "
                     f"{progress_data.get('progress_percentage')}%"
@@ -290,6 +270,7 @@ def progress_callback(progress_data: dict):
         # Look up implementation name if generating for a specific implementation
         if implementation_id:
             from app.models.implementation import Implementation
+
             impl = db.query(Implementation).filter(Implementation.id == implementation_id).first()
             if impl:
                 context.implementation_name = impl.name
@@ -336,10 +317,12 @@ def progress_callback(progress_data: dict):
         }
 
         # Report progress: saving content
-        progress_callback({
-            "workflow_step": "saving_content",
-            "progress_percentage": 90,
-        })
+        progress_callback(
+            {
+                "workflow_step": "saving_content",
+                "progress_percentage": 90,
+            }
+        )
 
         # Store content in Implementation or FeatureContentVersion
         if implementation_id:
@@ -375,13 +358,12 @@ def progress_callback(progress_data: dict):
 
         if chain_prompt_plan and content_type_str == "spec":
             # Queue prompt_plan generation job
-            from app.models.job import JobType
-            from app.services.kafka_producer import get_sync_kafka_producer
-
             # Get feature, module, and project for chaining
             from app.models.feature import Feature
+            from app.models.job import JobType
             from app.models.module import Module
             from app.models.project import Project
+            from app.services.kafka_producer import get_sync_kafka_producer
 
             feature = db.query(Feature).filter(Feature.id == feature_id).first()
             if not feature:
@@ -436,12 +418,9 @@ def progress_callback(progress_data: dict):
                     db.commit()
                 else:
                     # Mark the chained job as failed if Kafka publish fails
-                    logger.error(
-                        f"Failed to publish chained prompt_plan job {chained_job.id} to Kafka"
-                    )
+                    logger.error(f"Failed to publish chained prompt_plan job {chained_job.id} to Kafka")
                     JobService.update_job_status(
-                        db, chained_job.id, JobStatus.FAILED,
-                        error_message="Failed to publish job to Kafka"
+                        db, chained_job.id, JobStatus.FAILED, error_message="Failed to publish job to Kafka"
                     )
 
         # Build result for implementation
@@ -474,10 +453,12 @@ def progress_callback(progress_data: dict):
         }
 
         # Report progress: complete
-        progress_callback({
-            "workflow_step": "complete",
-            "progress_percentage": 100,
-        })
+        progress_callback(
+            {
+                "workflow_step": "complete",
+                "progress_percentage": 100,
+            }
+        )
 
         # Include LLM usage stats
         usage_stats = orchestrator.get_usage_stats()
@@ -517,15 +498,14 @@ def progress_callback(progress_data: dict):
     except Exception as e:
         logger.error(f"Feature content generation failed for feature {feature_id}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
 
     finally:
         # Always clear generation flag when job completes (success or failure)
         if implementation_id:
             from workers.core.helpers import clear_implementation_generation_flag
+
             if content_type_str == "spec":
                 clear_implementation_generation_flag(db, implementation_id, "is_generating_spec")
             elif content_type_str == "prompt_plan":
@@ -554,23 +534,22 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
             "progress": {...}
         }
     """
-    from app.services.user_question_session_service import UserQuestionSessionService
-    from app.services.brainstorming_phase_service import (
-        BrainstormingPhaseService,
-        _build_existing_conversation_context,
-    )
-    from app.services.job_service import JobService
-    from app.services.platform_settings_service import require_llm_config_sync
+    from app.agents.brainstorm_conversation.input_validator import InputValidatorAgent
     from app.agents.brainstorm_conversation.orchestrator import BrainstormConversationOrchestrator
     from app.agents.brainstorm_conversation.types import (
         BrainstormConversationContext,
-        UserInitiatedContext,
         PhaseType,
+        UserInitiatedContext,
     )
-    from app.agents.llm_client import LiteLLMChatCompletionClient, LLMCallLogger, create_litellm_client
-    from app.agents.brainstorm_conversation.input_validator import InputValidatorAgent
+    from app.agents.llm_client import create_litellm_client
     from app.models.brainstorming_phase import BrainstormingPhase
     from app.models.project import Project
+    from app.services.brainstorming_phase_service import (
+        _build_existing_conversation_context,
+    )
+    from app.services.job_service import JobService
+    from app.services.platform_settings_service import require_llm_config_sync
+    from app.services.user_question_session_service import UserQuestionSessionService
 
     session_id_str = payload.get("session_id")
     user_prompt = payload.get("user_prompt", "")
@@ -588,10 +567,7 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
     except ValueError as e:
         raise ValueError(f"Invalid UUID format: {e}")
 
-    logger.info(
-        f"Generating user-initiated questions for session {session_id}, "
-        f"prompt: {user_prompt[:50]}..."
-    )
+    logger.info(f"Generating user-initiated questions for session {session_id}, prompt: {user_prompt[:50]}...")
 
     # Update job status to running
     if job_id:
@@ -611,9 +587,7 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
         project_id = UUID(session_context["project_id"])
 
         # Get phase and project
-        phase = db.query(BrainstormingPhase).filter(
-            BrainstormingPhase.id == phase_id
-        ).first()
+        phase = db.query(BrainstormingPhase).filter(BrainstormingPhase.id == phase_id).first()
         if not phase:
             raise ValueError(f"Phase {phase_id} not found")
@@ -659,9 +633,7 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
         )
 
         if not validation_result.is_valid:
-            logger.info(
-                f"Input validation failed for session {session_id}: {validation_result.reason}"
-            )
+            logger.info(f"Input validation failed for session {session_id}: {validation_result.reason}")
 
             # Store clarification message as assistant response (appears in chat UI)
             UserQuestionSessionService.add_assistant_message(
@@ -674,7 +646,7 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
 
             # Extract validation LLM usage
             validation_usage = None
-            if hasattr(validation_client, 'get_usage_stats'):
+            if hasattr(validation_client, "get_usage_stats"):
                 usage_stats = validation_client.get_usage_stats()
                 validation_usage = {
                     "model": usage_stats.get("model"),
@@ -743,21 +715,20 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
         generated_questions = []
         for aspect in result.aspects:
             for question in aspect.clarification_questions:
-                generated_questions.append({
-                    "temp_id": str(uuid4()),
-                    "aspect_title": aspect.title,
-                    "title": question.title,
-                    "description": question.description,
-                    "priority": question.priority.value,
-                    "mcq": {
-                        "question_text": question.initial_mcq.question_text,
-                        "choices": [
-                            {"id": c.id, "label": c.label}
-                            for c in question.initial_mcq.choices
-                        ],
-                        "explanation": question.initial_mcq.explanation,
-                    },
-                })
+                generated_questions.append(
+                    {
+                        "temp_id": str(uuid4()),
+                        "aspect_title": aspect.title,
+                        "title": question.title,
+                        "description": question.description,
+                        "priority": question.priority.value,
+                        "mcq": {
+                            "question_text": question.initial_mcq.question_text,
+                            "choices": [{"id": c.id, "label": c.label} for c in question.initial_mcq.choices],
+                            "explanation": question.initial_mcq.explanation,
+                        },
+                    }
+                )
 
         # Store assistant message with generated questions
         assistant_content = f"Generated {len(generated_questions)} question(s) based on your request."
@@ -774,7 +745,11 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
 
         # Extract LLM usage stats from orchestrator's model client
         llm_usage = None
-        if hasattr(orchestrator, 'model_client') and orchestrator.model_client and hasattr(orchestrator.model_client, 'get_usage_stats'):
+        if (
+            hasattr(orchestrator, "model_client")
+            and orchestrator.model_client
+            and hasattr(orchestrator.model_client, "get_usage_stats")
+        ):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -819,6 +794,6 @@ async def user_initiated_question_generate_handler(payload: dict, db: Session) -> dict:
                 job_id,
                 JobStatus.FAILED,
                 error_message=str(e),
-                result={"error": str(e), "session_id": str(session_id)}
+                result={"error": str(e), "session_id": str(session_id)},
             )
         raise
diff --git a/backend/workers/handlers/grounding.py b/backend/workers/handlers/grounding.py
index 6638a4b..2363476 100644
--- a/backend/workers/handlers/grounding.py
+++ b/backend/workers/handlers/grounding.py
@@ -22,6 +22,7 @@ def _get_grounding_branch_lock_id(project_id: UUID, user_id: UUID, branch_name: str) -> int:
     Combines project, user, and branch to create a unique lock key.
     """
     import zlib
+
     key = f"grounding_branch:{project_id}:{user_id}:{branch_name}"
     return zlib.crc32(key.encode()) & 0x7FFFFFFF
@@ -98,9 +99,7 @@ async def grounding_update_handler(payload: dict, db: Session) -> dict:
             f"branch {branch_name}, triggered by feature {feature_id}"
         )
     else:
-        logger.info(
-            f"Updating grounding file for project {project_id} triggered by feature {feature_id}"
-        )
+        logger.info(f"Updating grounding file for project {project_id} triggered by feature {feature_id}")
 
     # Acquire blocking advisory lock to ensure sequential execution.
     # For branch updates, use a branch-specific lock. For global, use project lock.
@@ -135,9 +134,7 @@ async def grounding_update_handler(payload: dict, db: Session) -> dict:
 
         # Get current agents.md content (branch-specific or global)
         if is_branch_update:
-            agents_file = GroundingService.get_or_create_branch_file(
-                db, project_id, user_id, branch_name, repo_path
-            )
+            agents_file = GroundingService.get_or_create_branch_file(db, project_id, user_id, branch_name, repo_path)
         else:
             agents_file = GroundingService.get_or_create_agents_md(db, project_id)
         current_agents_md = agents_file.content
@@ -145,6 +142,7 @@ async def grounding_update_handler(payload: dict, db: Session) -> dict:
         # Determine which notes to use - implementation-level takes precedence
         if implementation_id:
             from app.models.implementation import Implementation
+
             impl = db.query(Implementation).filter(Implementation.id == implementation_id).first()
             feature_notes = impl.implementation_notes if impl else ""
             logger.info(f"Using implementation notes for grounding (impl_id={implementation_id})")
@@ -176,12 +174,8 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
             if job_id:
                 try:
-                    JobService.update_job_status(
-                        db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                    )
-                    logger.info(
-                        f"Progress update: {workflow_step} - {progress_percentage}%"
-                    )
+                    JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
+                    logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
                 except Exception as e:
                     logger.warning(f"Failed to update job progress: {e}")
@@ -189,11 +183,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Create LLM call logger for this job
         from app.database import SessionLocal
 
-        llm_call_logger = (
-            LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id)
-            if job_id
-            else None
-        )
+        llm_call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id) if job_id else None
 
         # Create orchestrator
         orchestrator = await create_orchestrator(
@@ -222,9 +212,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
                 repo_path=repo_path,
             )
             # Broadcast branch file update to WebSocket clients
-            GroundingService._broadcast_branch_grounding_update(
-                db, project_id, grounding_file, "written"
-            )
+            GroundingService._broadcast_branch_grounding_update(db, project_id, grounding_file, "written")
             logger.info(f"Updated branch agents.md ({branch_name}): {result.summary}")
         else:
             grounding_file = GroundingService.update_file(
@@ -236,9 +224,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
                 summary=result.content_summary,
             )
             # Broadcast grounding file update to WebSocket clients
-            GroundingService._broadcast_grounding_update(
-                db, project_id, grounding_file, "written"
-            )
+            GroundingService._broadcast_grounding_update(db, project_id, grounding_file, "written")
             logger.info(f"Updated agents.md: {result.summary}")
     else:
         if is_branch_update:
@@ -251,26 +237,18 @@ def progress_callback(workflow_step: str, progress_percentage: int):
             from app.models.implementation import Implementation
             from app.services.implementation_service import ImplementationService
 
-            impl = db.query(Implementation).filter(
-                Implementation.id == implementation_id
-            ).first()
+            impl = db.query(Implementation).filter(Implementation.id == implementation_id).first()
             if impl:
                 impl.completion_summary = result.completion_summary
                 db.commit()
                 db.refresh(impl)
-                logger.info(
-                    f"Updated completion_summary for implementation {implementation_id}"
-                )
+                logger.info(f"Updated completion_summary for implementation {implementation_id}")
                 # Broadcast implementation update for real-time UI refresh
-                ImplementationService.broadcast_implementation_updated(
-                    db, impl, "completion_summary"
-                )
+                ImplementationService.broadcast_implementation_updated(db, impl, "completion_summary")
 
         # Extract LLM usage stats before closing orchestrator
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -302,9 +280,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         branch_info = f", branch {branch_name}" if is_branch_update else ""
         logger.error(f"Grounding update failed for project {project_id}{branch_info}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
@@ -392,9 +368,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
         if job_id:
             try:
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
                 logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
             except Exception as e:
                 logger.warning(f"Failed to update job progress: {e}")
@@ -403,11 +377,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Create LLM call logger for this job
         from app.database import SessionLocal
 
-        llm_call_logger = (
-            LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id)
-            if job_id
-            else None
-        )
+        llm_call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id) if job_id else None
 
         # Create orchestrator
         orchestrator = await create_orchestrator(
@@ -422,22 +392,16 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         summary = await orchestrator.summarize_content(content, progress_callback)
 
         # Update the summary only (not content)
-        grounding_file = GroundingService.update_summary(
-            db, project_id, "agents.md", summary
-        )
+        grounding_file = GroundingService.update_summary(db, project_id, "agents.md", summary)
 
         # Broadcast the update
-        GroundingService._broadcast_grounding_update(
-            db, project_id, grounding_file, "summary_updated"
-        )
+        GroundingService._broadcast_grounding_update(db, project_id, grounding_file, "summary_updated")
 
         logger.info(f"Updated agents.md summary ({len(summary)} chars)")
 
         # Extract LLM usage stats
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -461,9 +425,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
     except Exception as e:
         logger.error(f"Grounding summarization failed for project {project_id}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
@@ -520,10 +482,7 @@ async def grounding_branch_summarize_handler(payload: dict, db: Session) -> dict:
     except ValueError as e:
         raise ValueError(f"Invalid UUID format: {e}")
 
-    logger.info(
-        f"Summarizing branch agents.md for project {project_id}, "
-        f"user {user_id}, branch {branch_name}"
-    )
+    logger.info(f"Summarizing branch agents.md for project {project_id}, user {user_id}, branch {branch_name}")
 
     # Acquire branch-specific advisory lock
     lock_id = _get_grounding_branch_lock_id(project_id, user_id, branch_name)
@@ -564,9 +523,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
         if job_id:
             try:
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
                 logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
             except Exception as e:
                 logger.warning(f"Failed to update job progress: {e}")
@@ -575,11 +532,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Create LLM call logger for this job
         from app.database import SessionLocal
 
-        llm_call_logger = (
-            LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id)
-            if job_id
-            else None
-        )
+        llm_call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id) if job_id else None
 
         # Create orchestrator
         orchestrator = await create_orchestrator(
@@ -594,22 +547,16 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         summary = await orchestrator.summarize_content(content, progress_callback)
 
         # Update the branch summary only (not content)
-        grounding_file = GroundingService.update_branch_summary(
-            db, project_id, user_id, branch_name, summary
-        )
+        grounding_file = GroundingService.update_branch_summary(db, project_id, user_id, branch_name, summary)
 
         # Broadcast the update
-        GroundingService._broadcast_branch_grounding_update(
-            db, project_id, grounding_file, "summary_updated"
-        )
+        GroundingService._broadcast_branch_grounding_update(db, project_id, grounding_file, "summary_updated")
 
         logger.info(f"Updated branch agents.md summary ({len(summary)} chars)")
 
         # Extract LLM usage stats
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -632,14 +579,9 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         return result_dict
 
     except Exception as e:
-        logger.error(
-            f"Branch grounding summarization failed for project {project_id}, "
-            f"branch {branch_name}: {e}"
-        )
+        logger.error(f"Branch grounding summarization failed for project {project_id}, branch {branch_name}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
@@ -698,10 +640,7 @@ async def grounding_merge_handler(payload: dict, db: Session) -> dict:
     except ValueError as e:
         raise ValueError(f"Invalid UUID format: {e}")
 
-    logger.info(
-        f"Merging branch agents.md into global for project {project_id}, "
-        f"branch {branch_name}"
-    )
+    logger.info(f"Merging branch agents.md into global for project {project_id}, branch {branch_name}")
 
     # Acquire GLOBAL grounding lock (we're updating global file)
     lock_id = get_grounding_lock_id(project_id)
@@ -744,9 +683,7 @@ async def grounding_merge_handler(payload: dict, db: Session) -> dict:
             if branch_file.is_merging:
                 branch_file.is_merging = False
                 db.commit()
-                GroundingService._broadcast_branch_grounding_update(
-                    db, project_id, branch_file, "merging_completed"
-                )
+                GroundingService._broadcast_branch_grounding_update(db, project_id, branch_file, "merging_completed")
             return {
                 "merged": False,
                 "summary": "Branch already merged with no new changes",
@@ -769,9 +706,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
         if job_id:
             try:
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
                 logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
             except Exception as e:
                 logger.warning(f"Failed to update job progress: {e}")
@@ -780,11 +715,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Create LLM call logger for this job
         from app.database import SessionLocal
 
-        llm_call_logger = (
-            LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id)
-            if job_id
-            else None
-        )
+        llm_call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id) if job_id else None
 
         # Create merge orchestrator
        orchestrator = await create_merge_orchestrator(
@@ -815,9 +746,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
                 summary=result.content_summary,
             )
             # Broadcast global file update
-            GroundingService._broadcast_grounding_update(
-                db, project_id, global_file, "written"
-            )
+            GroundingService._broadcast_grounding_update(db, project_id, global_file, "written")
             logger.info(f"Updated global agents.md with merged content: {result.summary}")
             # Use the updated global's content_updated_at for sync tracking
             global_sync_time = global_file.content_updated_at
@@ -829,20 +758,19 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
         # Mark branch as merged with the global's content_updated_at
         merged_branch = GroundingService.mark_branch_merged(
-            db, project_id, user_id, branch_name,
+            db,
+            project_id,
+            user_id,
+            branch_name,
             global_content_updated_at=global_sync_time,
         )
         # Broadcast branch status update
-        GroundingService._broadcast_branch_grounding_update(
-            db, project_id, merged_branch, "merged"
-        )
+        GroundingService._broadcast_branch_grounding_update(db, project_id, merged_branch, "merged")
         logger.info(f"Marked branch {branch_name} as merged")
 
         # Extract LLM usage stats
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -870,34 +798,24 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         return result_dict
 
     except Exception as e:
-        logger.error(
-            f"Grounding merge failed for project {project_id}, branch {branch_name}: {e}"
-        )
+        logger.error(f"Grounding merge failed for project {project_id}, branch {branch_name}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
 
     finally:
         # Always clear is_merging flag when job completes (success or failure)
         try:
             # Re-fetch branch file to ensure we have fresh state
-            branch_file = GroundingService.get_branch_file(
-                db, project_id, user_id, branch_name
-            )
+            branch_file = GroundingService.get_branch_file(db, project_id, user_id, branch_name)
             if branch_file and branch_file.is_merging:
                 branch_file.is_merging = False
                 db.commit()
                 # Broadcast the is_merging flag change
-                GroundingService._broadcast_branch_grounding_update(
-                    db, project_id, branch_file, "merging_completed"
-                )
+                GroundingService._broadcast_branch_grounding_update(db, project_id, branch_file, "merging_completed")
                 logger.info(f"Cleared is_merging flag for branch {branch_name}")
         except Exception as cleanup_error:
-            logger.error(
-                f"Failed to clear is_merging flag for branch {branch_name}: {cleanup_error}"
-            )
+            logger.error(f"Failed to clear is_merging flag for branch {branch_name}: {cleanup_error}")
 
 
 async def grounding_pull_handler(payload: dict, db: Session) -> dict:
@@ -954,10 +872,7 @@ async def grounding_pull_handler(payload: dict, db: Session) -> dict:
     except ValueError as e:
         raise ValueError(f"Invalid UUID format: {e}")
 
-    logger.info(
-        f"Pulling global agents.md into branch for project {project_id}, "
-        f"branch {branch_name}"
-    )
+    logger.info(f"Pulling global agents.md into branch for project {project_id}, branch {branch_name}")
 
     # Acquire BRANCH-specific lock (we're updating the branch file, not global)
     lock_id = _get_grounding_branch_lock_id(project_id, user_id, branch_name)
@@ -1002,9 +917,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
 
         if job_id:
             try:
-                JobService.update_job_status(
-                    db, job_id, JobStatus.RUNNING, result={"progress": latest_progress}
-                )
+                JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"progress": latest_progress})
                 logger.info(f"Progress update: {workflow_step} - {progress_percentage}%")
             except Exception as e:
                 logger.warning(f"Failed to update job progress: {e}")
@@ -1013,11 +926,7 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         # Create LLM call logger for this job
         from app.database import SessionLocal
 
-        llm_call_logger = (
-            LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id)
-            if job_id
-            else None
-        )
+        llm_call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id) if job_id else None
 
         # Create merge orchestrator (reuse for pull)
         orchestrator = await create_merge_orchestrator(
@@ -1051,26 +960,18 @@ def progress_callback(workflow_step: str, progress_percentage: int):
             logger.info(f"Updated branch agents.md with pulled content: {result.summary}")
         else:
             # Content didn't change, but we still need to get the branch file
-            branch_file = GroundingService.get_branch_file(
-                db, project_id, user_id, branch_name
-            )
+            branch_file = GroundingService.get_branch_file(db, project_id, user_id, branch_name)
             logger.info(f"No changes after pull for branch {branch_name}")
 
         # Always mark branch as synced with global after pull (even if no content changed)
         if branch_file:
-            GroundingService.mark_branch_synced_with_global(
-                db, branch_file, global_file.content_updated_at
-            )
+            GroundingService.mark_branch_synced_with_global(db, branch_file, global_file.content_updated_at)
             # Broadcast branch file update
-            GroundingService._broadcast_branch_grounding_update(
-                db, project_id, branch_file, "pulled"
-            )
+            GroundingService._broadcast_branch_grounding_update(db, project_id, branch_file, "pulled")
 
         # Extract LLM usage stats
         llm_usage = None
-        if hasattr(orchestrator, "model_client") and hasattr(
-            orchestrator.model_client, "get_usage_stats"
-        ):
+        if hasattr(orchestrator, "model_client") and hasattr(orchestrator.model_client, "get_usage_stats"):
             usage_stats = orchestrator.model_client.get_usage_stats()
             llm_usage = {
                 "model": usage_stats.get("model"),
@@ -1098,31 +999,21 @@ def progress_callback(workflow_step: str, progress_percentage: int):
         return result_dict
 
     except Exception as e:
-        logger.error(
-            f"Grounding pull failed for project {project_id}, branch {branch_name}: {e}"
-        )
+        logger.error(f"Grounding pull failed for project {project_id}, branch {branch_name}: {e}")
         if job_id:
-            JobService.update_job_status(
-                db, job_id, JobStatus.FAILED, result={"error": str(e)}
-            )
+            JobService.update_job_status(db, job_id, JobStatus.FAILED, result={"error": str(e)})
         raise
 
     finally:
         # Always clear is_merging flag when job completes (success or failure)
         try:
             # Re-fetch branch file to ensure we have fresh state
-            branch_file = GroundingService.get_branch_file(
-                db, project_id, user_id, branch_name
-            )
+            branch_file = GroundingService.get_branch_file(db, project_id, user_id, branch_name)
             if branch_file and branch_file.is_merging:
                 branch_file.is_merging = False
                 db.commit()
                 # Broadcast the is_merging flag change
-                GroundingService._broadcast_branch_grounding_update(
-                    db, project_id, branch_file, "pulling_completed"
-                )
+                GroundingService._broadcast_branch_grounding_update(db, project_id, branch_file, "pulling_completed")
                 logger.info(f"Cleared is_merging flag for branch {branch_name}")
         except Exception as cleanup_error:
-            logger.error(
-                f"Failed to clear is_merging flag for branch {branch_name}: {cleanup_error}"
-            )
\ No newline at end of file
+            logger.error(f"Failed to clear is_merging flag for branch {branch_name}: {cleanup_error}")
diff --git a/backend/workers/handlers/image_annotator.py b/backend/workers/handlers/image_annotator.py
index 11c6369..7ceb43d 100644
--- a/backend/workers/handlers/image_annotator.py
+++ b/backend/workers/handlers/image_annotator.py
@@ -6,8 +6,6 @@
 
 import base64
 import logging
-from datetime import datetime, timezone
-from typing import Dict, Any
 from uuid import UUID
 
 from sqlalchemy.orm import Session
@@ -47,11 +45,11 @@ async def image_annotate_handler(payload: dict, db: Session) -> dict:
         ValueError: If required fields are missing or LLM config missing
     """
     from app.agents.image_annotator import handle_image_annotation
-    from app.agents.llm_client import create_litellm_client, LLMCallLogger
+    from app.agents.llm_client import LLMCallLogger, create_litellm_client
+    from app.database import SessionLocal
     from app.services.image_service import ImageService
     from app.services.job_service import JobService
     from app.services.platform_settings_service import require_llm_config_sync
-    from app.database import SessionLocal
 
     # Extract payload fields
     image_id = payload.get("image_id")
@@ -79,10 +77,7 @@ async def image_annotate_handler(payload: dict, db: Session) -> dict:
 
     # Update job status to running
     if job_id:
-        JobService.update_job_status(
-            db, job_id, JobStatus.RUNNING,
-            result={"image_id": image_id}
-        )
+        JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"image_id": image_id})
 
     try:
         # Fetch image bytes from S3
@@ -142,8 +137,7 @@ async def image_annotate_handler(payload: dict, db: Session) -> dict:
         )
 
         logger.info(
-            f"Image annotation generated: image_id={image_id}, "
-            f"annotation_length={len(result.get('annotation', ''))}"
+            f"Image annotation generated: image_id={image_id}, annotation_length={len(result.get('annotation', ''))}"
         )
 
         return result
diff --git a/backend/workers/handlers/integration.py b/backend/workers/handlers/integration.py
index 62aac02..6e5229e 100644
--- a/backend/workers/handlers/integration.py
+++ b/backend/workers/handlers/integration.py
@@ -52,9 +52,7 @@ def bugsync_handler(payload: dict, db: Session) -> dict:
 
     triggered_by = payload.get("triggered_by", "system")
 
-    logger.info(
-        f"Syncing bug ticket for project {project_id}, triggered by {triggered_by}"
-    )
+    logger.info(f"Syncing bug ticket for project {project_id}, triggered by {triggered_by}")
 
     # We need to use async context for the bug sync service
     # Since this handler is synchronous but uses async services
@@ -128,9 +126,7 @@ def notification_fanout_handler(payload: dict, db: Session) -> dict:
         raise ValueError("Missing required fields in payload")
 
     project_id = UUID(project_id_str)
-    logger.info(
-        f"Fanning out notifications for event {event_type} in project {project_id}"
-    )
+    logger.info(f"Fanning out notifications for event {event_type} in project {project_id}")
 
     # Initialize adapters (mock mode for now)
     adapters = {
@@ -220,12 +216,11 @@ async def mention_notification_handler(payload: dict, db: Session) -> dict:
             "skipped_reasons": dict
         }
     """
-    from datetime import timezone
     from app.database import get_async_session_context
-    from app.models import Thread, User, Project, NotificationChannel
+    from app.models import NotificationChannel, Project, Thread, User
+    from app.models.feature import Feature
     from app.models.thread import ContextType
     from app.models.thread_item import ThreadItem
-    from app.models.feature import Feature, FeatureType
     from app.services.notification_service import NotificationService
 
     thread_id = payload.get("thread_id")
@@ -237,10 +232,7 @@ async def mention_notification_handler(payload: dict, db: Session) -> dict:
     if not all([thread_id, item_id, author_id, project_id]):
         raise ValueError("Missing required fields in payload")
 
-    logger.info(
-        f"Processing mention notifications for thread {thread_id}, "
-        f"{len(mentioned_user_ids)} mentions"
-    )
+    logger.info(f"Processing mention notifications for thread {thread_id}, {len(mentioned_user_ids)} mentions")
 
     # Load thread and related data
     thread = db.query(Thread).filter(Thread.id == thread_id).first()
@@ -293,11 +285,7 @@ async def mention_notification_handler(payload: dict, db: Session) -> dict:
     recent_messages = []
     for item in recent_items:
         item_author = db.query(User).filter(User.id == item.author_id).first()
-        author_display = (
-            item_author.display_name or item_author.email.split("@")[0]
-            if item_author
-            else "Unknown"
-        )
+        author_display = item_author.display_name or item_author.email.split("@")[0] if item_author else "Unknown"
 
         # Get body from content_data and clean markdown for display
         from app.services.mention_utils import clean_markdown_for_display
@@ -309,11 +297,13 @@ async def mention_notification_handler(payload: dict, db: Session) -> dict:
         # Format timestamp as relative time
         timestamp = _format_relative_time(item.created_at)
 
-        recent_messages.append({
-            "author": author_display,
-            "body_preview": body,
-            "timestamp": timestamp,
-        })
+        recent_messages.append(
+            {
+                "author": author_display,
+                "body_preview": body,
+                "timestamp": timestamp,
+            }
+        )
 
     # Build view URL with item_id for deep linking
     view_url = _build_view_url(thread, feature, project_id, db, item_id=item_id)
@@ -438,16 +428,18 @@ def _build_view_url(
         item_id: Optional thread item ID for deep linking to specific message
     """
     from app.config import settings
-    from app.models.thread import ContextType
+    from app.models.brainstorming_phase import BrainstormingPhase
     from app.models.feature import FeatureType
     from app.models.module import Module
-    from app.models.project import Project
-    from app.models.brainstorming_phase import BrainstormingPhase
     from app.models.platform_settings import PlatformSettings
+    from app.models.project import Project
+    from app.models.thread import ContextType
 
     # Get base URL from platform settings (UI-configured) with fallback to frontend_url config
     platform_settings = db.query(PlatformSettings).first()
-    base_url = (platform_settings.base_url if platform_settings and platform_settings.base_url else settings.frontend_url).rstrip("/")
+    base_url = (
+        platform_settings.base_url if platform_settings and platform_settings.base_url else settings.frontend_url
+    ).rstrip("/")
 
     # Get project for URL identifier
     project = db.query(Project).filter(Project.id == project_id).first()
@@ -459,9 +451,9 @@ def _build_view_url(
         # Load module to get phase for URL
         module = db.query(Module).filter(Module.id == feature.module_id).first()
         if module and module.brainstorming_phase_id:
-            phase = db.query(BrainstormingPhase).filter(
-                BrainstormingPhase.id == module.brainstorming_phase_id
-            ).first()
+            phase = (
+                db.query(BrainstormingPhase).filter(BrainstormingPhase.id == module.brainstorming_phase_id).first()
+            )
             if phase:
                 url = f"{base_url}/projects/{project_url_id}/brainstorming/{phase.url_identifier}/conversations?feature={feature.url_identifier}"
                 if item_id:
@@ -526,13 +518,12 @@ async def project_chat_mention_notification_handler(payload: dict, db: Session) -> dict:
             "skipped_reasons": dict
         }
     """
-    from datetime import timezone
+    from app.config import settings
     from app.database import get_async_session_context
-    from app.models import User, Project, NotificationChannel
+    from app.models import NotificationChannel, Project, User
+    from app.models.platform_settings import PlatformSettings
     from app.models.project_chat import ProjectChat, ProjectChatMessage
     from app.services.notification_service import NotificationService
-    from app.config import settings
-    from app.models.platform_settings import PlatformSettings
 
     project_chat_id = payload.get("project_chat_id")
     message_id = payload.get("message_id")
@@ -549,16 +540,12 @@ async def project_chat_mention_notification_handler(payload: dict, db: Session) -> dict:
     )
 
     # Load discussion
-    discussion = db.query(ProjectChat).filter(
-        ProjectChat.id == project_chat_id
-    ).first()
+    discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first()
     if not discussion:
         raise ValueError(f"ProjectChat {project_chat_id} not found")
 
     # Load the triggering message
-    trigger_message = db.query(ProjectChatMessage).filter(
-        ProjectChatMessage.id == message_id
-    ).first()
+    trigger_message = db.query(ProjectChatMessage).filter(ProjectChatMessage.id == message_id).first()
     if not trigger_message:
         raise ValueError(f"ProjectChatMessage {message_id} not found")
@@ -597,11 +584,7 @@ async def project_chat_mention_notification_handler(payload: dict, db: Session) -> dict:
     for msg in recent_messages_list:
         if msg.message_type.value == "user":
             msg_author = db.query(User).filter(User.id == msg.created_by).first()
-            author_display = (
-                msg_author.display_name or msg_author.email.split("@")[0]
-                if msg_author
-                else "Unknown"
-            )
+            author_display = msg_author.display_name or msg_author.email.split("@")[0] if msg_author else "Unknown"
         else:
             author_display = "MFBTAI"
@@ -614,15 +597,19 @@ async def project_chat_mention_notification_handler(payload: dict, db: Session) -> dict:
 
         # Format timestamp as relative time
         timestamp = _format_relative_time(msg.created_at)
 
-        recent_messages.append({
-            "author": author_display,
-            "body_preview": body,
-            "timestamp": timestamp,
-        })
+
recent_messages.append( + { + "author": author_display, + "body_preview": body, + "timestamp": timestamp, + } + ) # Build view URL with message_id for deep linking platform_settings = db.query(PlatformSettings).first() - base_url = (platform_settings.base_url if platform_settings and platform_settings.base_url else settings.frontend_url).rstrip("/") + base_url = ( + platform_settings.base_url if platform_settings and platform_settings.base_url else settings.frontend_url + ).rstrip("/") view_url = f"{base_url}/projects/{project.url_identifier}/project-chat/{discussion.url_identifier}" if message_id: view_url += f"?item={message_id}" diff --git a/backend/workers/handlers/project_chat.py b/backend/workers/handlers/project_chat.py index 6eaf6c8..ab4a535 100644 --- a/backend/workers/handlers/project_chat.py +++ b/backend/workers/handlers/project_chat.py @@ -6,8 +6,7 @@ import logging import re -from datetime import datetime, timezone -from typing import Dict, Any, Optional +from typing import Optional from uuid import UUID from sqlalchemy.orm import Session @@ -45,14 +44,13 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: Raises: ValueError: If discussion not found or LLM config missing """ + from app.agents.llm_client import LLMCallLogger, create_litellm_client from app.agents.project_chat_assistant import handle_user_message - from app.agents.llm_client import create_litellm_client, LLMCallLogger - from app.services.project_chat_service import ProjectChatService + from app.database import SessionLocal + from app.models.project_chat import ProjectChat from app.services.job_service import JobService from app.services.platform_settings_service import require_llm_config_sync - from app.models.project_chat import ProjectChat - from app.models.project import Project - from app.database import SessionLocal + from app.services.project_chat_service import ProjectChatService project_chat_id_str = payload.get("project_chat_id") user_message = payload.get("user_message") @@ -90,9 +88,7 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: ) # Get discussion - discussion = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not discussion: raise ValueError(f"Discussion {project_chat_id} not found") @@ -110,19 +106,16 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: # Update job status to running if job_id: - JobService.update_job_status( - db, job_id, JobStatus.RUNNING, - result={"project_chat_id": str(project_chat_id)} - ) + JobService.update_job_status(db, job_id, JobStatus.RUNNING, result={"project_chat_id": str(project_chat_id)}) # Run gating agent to decide if AI should respond # Skip gating for MCQ answers (user explicitly clicked an option), exploration follow-ups, and web search follow-ups is_mcq_answer = payload.get("is_mcq_answer", False) if not is_mcq_answer and not is_exploration_followup and not is_web_search_followup: + from app.agents.llm_client import LLMCallLogger, create_litellm_client from app.agents.project_chat_gating import should_ai_respond - from app.agents.llm_client import create_litellm_client, LLMCallLogger - from app.services.platform_settings_service import require_llm_config_sync from app.database import SessionLocal + from app.services.platform_settings_service import require_llm_config_sync platform_llm_config = require_llm_config_sync(db) @@ -147,9 +140,7 @@ async def 
project_chat_respond_handler(payload: dict, db: Session) -> dict: ) # Get recent messages for context (last 3 messages) - recent_messages = ProjectChatService.get_conversation_history( - db, project_chat_id, limit=3 - ) + recent_messages = ProjectChatService.get_conversation_history(db, project_chat_id, limit=3) # Append current user message (not yet saved to DB) recent_messages.append({"type": "user", "content": user_message}) @@ -187,11 +178,7 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: # Create LLM client with call logger # For org-scoped discussions, project_id may be None - call_logger = LLMCallLogger( - db_session_factory=SessionLocal, - job_id=job_id, - project_id=effective_project_id - ) + call_logger = LLMCallLogger(db_session_factory=SessionLocal, job_id=job_id, project_id=effective_project_id) model_client = create_litellm_client( provider=llm_provider, model=llm_model, @@ -204,9 +191,9 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: if is_exploration_followup: from app.models.code_exploration_result import CodeExplorationResult - exploration = db.query(CodeExplorationResult).filter( - CodeExplorationResult.id == exploration_result_id - ).first() + exploration = ( + db.query(CodeExplorationResult).filter(CodeExplorationResult.id == exploration_result_id).first() + ) if not exploration: raise ValueError(f"Exploration result {exploration_result_id} not found") @@ -217,7 +204,9 @@ async def project_chat_respond_handler(payload: dict, db: Session) -> dict: effective_user_message = ( "Based on the code exploration results above, please provide your response to my original question." ) - logger.info(f"Processing exploration follow-up - exploration results are in conversation history ({len(exploration.output or '')} chars)") + logger.info( + f"Processing exploration follow-up - exploration results are in conversation history ({len(exploration.output or '')} chars)" + ) elif is_web_search_followup: # Web search results are in the conversation history as a WEB_SEARCH message # Use the original message from payload for context, or a generic prompt @@ -299,6 +288,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: reply_text_content = result["reply_text"] if isinstance(reply_text_content, str): import json + stripped = reply_text_content.strip() inner_json_str = None @@ -308,7 +298,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: # Check for markdown-wrapped JSON (starts with ```) elif stripped.startswith("```"): # Extract JSON from markdown code block - inner_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```\s*$', stripped) + inner_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```\s*$", stripped) if inner_match: inner_json_str = inner_match.group(1).strip() @@ -336,10 +326,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: # Defensive: Strip any lingering [MFBT_WEB_SEARCH] blocks that might still appear # (LLM confusion / backward compatibility) reply_text_content = re.sub( - r"\[MFBT_WEB_SEARCH\].*?\[/MFBT_WEB_SEARCH\]", - "", - reply_text_content, - flags=re.DOTALL + r"\[MFBT_WEB_SEARCH\].*?\[/MFBT_WEB_SEARCH\]", "", reply_text_content, flags=re.DOTALL ).strip() # Strip any code exploration blocks from reply text @@ -385,8 +372,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: # Allow exploration if prompt is different or if it's not a follow-up # This prevents true infinite loops while allowing re-exploration with different topics prompts_are_similar = 
( - last_prompt and new_prompt and - last_prompt.lower().strip() == new_prompt.lower().strip() + last_prompt and new_prompt and last_prompt.lower().strip() == new_prompt.lower().strip() ) if prompts_are_similar and is_exploration_followup: @@ -396,9 +382,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: ) else: if is_exploration_followup: - logger.info( - f"Allowing code exploration in follow-up with different prompt: {new_prompt[:100]}" - ) + logger.info(f"Allowing code exploration in follow-up with different prompt: {new_prompt[:100]}") await _trigger_code_exploration( db=db, discussion=discussion, @@ -412,8 +396,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: # Prevent infinite loops - don't search if this is already a web search follow-up with same query if is_web_search_followup: logger.warning( - f"Ignoring web search request during web search follow-up - " - f"query: {web_search_query[:100]}" + f"Ignoring web search request during web search follow-up - query: {web_search_query[:100]}" ) else: await _trigger_project_chat_web_search( @@ -469,9 +452,7 @@ async def _on_retry(attempt: int, max_attempts: int) -> None: finally: # Always clear retry_status on completion (success or failure) # Re-query the discussion to avoid stale object issues after multiple commits - fresh_discussion = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + fresh_discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if fresh_discussion and fresh_discussion.retry_status is not None: fresh_discussion.retry_status = None db.commit() @@ -489,9 +470,9 @@ async def _trigger_code_exploration( Publishes a code exploration job and updates discussion state. """ - from app.models.project import Project - from app.models.platform_settings import PlatformSettings from app.models.job import Job, JobType + from app.models.platform_settings import PlatformSettings + from app.models.project import Project from app.services.project_chat_service import ProjectChatService if not discussion.project_id: @@ -597,8 +578,8 @@ async def _trigger_project_chat_web_search( differently - as a ProjectChatMessage with type WEB_SEARCH. """ from app.models.job import Job, JobType - from app.services.project_chat_service import ProjectChatService from app.services.platform_settings_service import is_web_search_available_sync + from app.services.project_chat_service import ProjectChatService if not search_query: logger.warning("Cannot trigger web search without query") @@ -661,10 +642,7 @@ async def _trigger_project_chat_web_search( key=str(discussion.id), # Partition by discussion ) - logger.info( - f"Triggered web search for discussion {discussion.id}: " - f"query={search_query[:50]}..." - ) + logger.info(f"Triggered web search for discussion {discussion.id}: query={search_query[:50]}...") # Broadcast updated discussion state db.refresh(discussion) diff --git a/backend/workers/handlers/web_search.py b/backend/workers/handlers/web_search.py index e478be4..6b297d2 100644 --- a/backend/workers/handlers/web_search.py +++ b/backend/workers/handlers/web_search.py @@ -4,7 +4,6 @@ Executes web searches via Tavily and creates thread items or project-chat messages with results. 
""" -import asyncio import logging from typing import Any, Dict, Optional from uuid import UUID @@ -14,8 +13,8 @@ from app.models.job import JobType from app.models.thread import Thread from app.services.job_service import JobService -from app.services.thread_service import ThreadService from app.services.platform_settings_service import get_web_search_config_sync +from app.services.thread_service import ThreadService from app.services.web_search_service import WebSearchService logger = logging.getLogger(__name__) @@ -172,9 +171,7 @@ async def _handle_project_chat_web_search( logger.info(f"Executing web search for project-chat discussion {project_chat_id}: {search_query}") # Get the discussion - discussion = db.query(ProjectChat).filter( - ProjectChat.id == project_chat_id - ).first() + discussion = db.query(ProjectChat).filter(ProjectChat.id == project_chat_id).first() if not discussion: raise ValueError(f"Pre-phase discussion {project_chat_id} not found") @@ -284,10 +281,10 @@ def _format_web_search_for_conversation( lines.append("\n**Sources:**") for i, result in enumerate(results[:5], 1): lines.append(f"{i}. [{result['title']}]({result['url']})") - if result.get('content'): + if result.get("content"): # Truncate content to first 600 chars - content = result['content'][:600] - if len(result['content']) > 600: + content = result["content"][:600] + if len(result["content"]) > 600: content += "..." lines.append(f" > {content}") @@ -330,6 +327,7 @@ async def _trigger_thread_followup_response( # Broadcast thread_generation_status_changed so frontend shows AI spinner from app.services.kafka_producer import get_sync_kafka_producer + message = { "type": "thread_generation_status_changed", "org_id": str(project.org_id), diff --git a/backend/workers/scheduler.py b/backend/workers/scheduler.py index f02b55a..eb18342 100644 --- a/backend/workers/scheduler.py +++ b/backend/workers/scheduler.py @@ -13,10 +13,11 @@ - Weekly freemium top-up: Every Monday at 00:00 UTC - Daily usage aggregation: Every day at 00:05 UTC """ + import logging import sys import time -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, timezone from apscheduler.schedulers.blocking import BlockingScheduler from apscheduler.triggers.cron import CronTrigger @@ -65,9 +66,7 @@ def run_weekly_topup(): db = SessionLocal() try: # Try to acquire advisory lock to prevent duplicate runs across instances - acquired = db.execute( - text(f"SELECT pg_try_advisory_xact_lock({lock_id})") - ).scalar() + acquired = db.execute(text(f"SELECT pg_try_advisory_xact_lock({lock_id})")).scalar() if not acquired: logger.info("Weekly top-up: another instance is already running, skipping") @@ -79,16 +78,11 @@ def run_weekly_topup(): # Get freemium settings settings = get_freemium_settings_sync(db) logger.info( - f"Freemium settings: weekly_topup={settings['weekly_topup_tokens']:,}, " - f"max={settings['max_tokens']:,}" + f"Freemium settings: weekly_topup={settings['weekly_topup_tokens']:,}, max={settings['max_tokens']:,}" ) # Get all freemium organizations - freemium_orgs = ( - db.query(Organization) - .filter(Organization.plan_name == "freemium") - .all() - ) + freemium_orgs = db.query(Organization).filter(Organization.plan_name == "freemium").all() logger.info(f"Found {len(freemium_orgs)} freemium organizations") topped_up = 0 @@ -107,10 +101,7 @@ def run_weekly_topup(): logger.debug(f"Skipped org {org.id} ({org.name}): already topped up") elapsed = (datetime.now(timezone.utc) - start_time).total_seconds() - 
logger.info( - f"Weekly top-up complete: {topped_up} topped up, {skipped} skipped " - f"(elapsed: {elapsed:.2f}s)" - ) + logger.info(f"Weekly top-up complete: {topped_up} topped up, {skipped} skipped (elapsed: {elapsed:.2f}s)") except Exception as e: logger.exception(f"Error during weekly top-up: {e}") @@ -144,9 +135,7 @@ def run_daily_aggregation(): db = SessionLocal() try: # Try to acquire advisory lock to prevent duplicate runs - acquired = db.execute( - text(f"SELECT pg_try_advisory_xact_lock({lock_id})") - ).scalar() + acquired = db.execute(text(f"SELECT pg_try_advisory_xact_lock({lock_id})")).scalar() if not acquired: logger.info("Daily aggregation: another instance is already running, skipping") @@ -201,15 +190,10 @@ def _run_aggregation_with_retry(db, target_date, retry_delays): attempt += 1 if attempt < max_attempts: delay = retry_delays[attempt - 1] - logger.warning( - f"Aggregation attempt {attempt} failed: {e}. " - f"Retrying in {delay}s..." - ) + logger.warning(f"Aggregation attempt {attempt} failed: {e}. Retrying in {delay}s...") time.sleep(delay) else: - logger.error( - f"Aggregation failed after {max_attempts} attempts: {e}" - ) + logger.error(f"Aggregation failed after {max_attempts} attempts: {e}") return False return False @@ -237,9 +221,7 @@ def run_daily_recommendation_evaluation(): db = SessionLocal() try: # Try to acquire advisory lock to prevent duplicate runs - acquired = db.execute( - text(f"SELECT pg_try_advisory_xact_lock({lock_id})") - ).scalar() + acquired = db.execute(text(f"SELECT pg_try_advisory_xact_lock({lock_id})")).scalar() if not acquired: logger.info("Recommendation evaluation: another instance is already running, skipping") diff --git a/code-explorer/app/routes/explore.py b/code-explorer/app/routes/explore.py index 58df553..a9f9db2 100644 --- a/code-explorer/app/routes/explore.py +++ b/code-explorer/app/routes/explore.py @@ -4,7 +4,7 @@ import time from typing import Literal -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter from pydantic import BaseModel from app.config import settings @@ -97,13 +97,15 @@ async def explore(request: ExploreRequest) -> ExploreResponse: elif request.repo_url: # Legacy single-repo format - convert to repos list slug = request.repo_url.rstrip("/").split("/")[-1].replace(".git", "") - repos_list = [{ - "slug": slug, - "repo_url": request.repo_url, - "branch": request.branch, - "github_token": request.github_token, - "user_remarks": None, - }] + repos_list = [ + { + "slug": slug, + "repo_url": request.repo_url, + "branch": request.branch, + "github_token": request.github_token, + "user_remarks": None, + } + ] logger.info(f"Legacy single-repo request: {request.repo_url}") else: return ExploreResponse( @@ -127,6 +129,7 @@ async def explore(request: ExploreRequest) -> ExploreResponse: # For multi-repo: use the project's repos directory (parent of all repos) # For single-repo: use the repo's worktree directly from pathlib import Path + if len(repos_list) == 1: exploration_root = list(worktree_paths.values())[0] else: @@ -140,7 +143,7 @@ async def explore(request: ExploreRequest) -> ExploreResponse: timeout = request.timeout_seconds or settings.grounding_timeout_seconds max_turns = request.max_turns or settings.grounding_max_turns max_output_lines = request.max_output_lines or settings.grounding_max_output_lines - logger.info(f"Running Claude Code grounding generation...") + logger.info("Running Claude Code grounding generation...") else: # Standard exploration mode timeout = request.timeout_seconds or 
settings.claude_timeout_seconds @@ -165,7 +168,9 @@ async def explore(request: ExploreRequest) -> ExploreResponse: ) execution_time = time.time() - start_time - logger.info(f"{'Grounding generation' if request.mode == 'grounding_generate' else 'Exploration'} completed in {execution_time:.2f}s") + logger.info( + f"{'Grounding generation' if request.mode == 'grounding_generate' else 'Exploration'} completed in {execution_time:.2f}s" + ) return ExploreResponse( success=True, diff --git a/code-explorer/app/services/claude_runner.py b/code-explorer/app/services/claude_runner.py index f53a948..e73de1f 100644 --- a/code-explorer/app/services/claude_runner.py +++ b/code-explorer/app/services/claude_runner.py @@ -4,7 +4,6 @@ import json import logging import os -import re from pathlib import Path from typing import Literal @@ -211,7 +210,11 @@ async def run_exploration( else: # For exploration: include paths so Claude knows where to look repo_context = self._build_repo_context(repos, include_paths=True) - multi_repo_note = "\nExplore all repositories as needed. Be specific about which repo files come from.\n" if is_multi_repo else "" + multi_repo_note = ( + "\nExplore all repositories as needed. Be specific about which repo files come from.\n" + if is_multi_repo + else "" + ) wrapped_prompt = f"""You are exploring a codebase to answer a question. {repo_context}{multi_repo_note}IMPORTANT: Keep your final response concise - maximum 3000 characters. Focus on the most relevant findings and be specific about file paths and code snippets. @@ -292,21 +295,21 @@ async def run_exploration( "Anthropic API authentication failed. Please check the API key configuration." ) if "rate limit" in error_msg.lower(): - raise self.ClaudeAPIError( - "Anthropic API rate limit exceeded. Please try again later." - ) + raise self.ClaudeAPIError("Anthropic API rate limit exceeded. 
Please try again later.") raise self.ClaudeError(f"Claude Code execution failed: {error_msg}") # Parse the JSON output logger.info("Parsing Claude Code output...") result = self._parse_output(stdout_str) - logger.info(f"Parsed result: output_length={len(result.get('output') or '')}, " - f"raw_output_length={len(result.get('raw_output') or '')}, " - f"prompt_tokens={result.get('prompt_tokens')}, " - f"completion_tokens={result.get('completion_tokens')}") + logger.info( + f"Parsed result: output_length={len(result.get('output') or '')}, " + f"raw_output_length={len(result.get('raw_output') or '')}, " + f"prompt_tokens={result.get('prompt_tokens')}, " + f"completion_tokens={result.get('completion_tokens')}" + ) # Log warning if output is empty but we have tokens - if not result.get('output') and result.get('completion_tokens', 0) > 0: + if not result.get("output") and result.get("completion_tokens", 0) > 0: logger.warning(f"Empty output despite {result.get('completion_tokens')} completion tokens!") logger.debug(f"Raw stdout (first 2000 chars): {stdout_str[:2000]}") return result @@ -339,17 +342,17 @@ def _parse_output(self, output: str) -> dict: if data.get("type") == "result": result_text = data.get("result", "") subtype = data.get("subtype") - logger.debug(f"Found result type message, subtype={subtype}, result field: '{result_text[:200] if result_text else '(empty)'}...'") + logger.debug( + f"Found result type message, subtype={subtype}, result field: '{result_text[:200] if result_text else '(empty)'}...'" + ) logger.debug(f"Result JSON keys: {list(data.keys())}") if subtype == "error_max_turns": - logger.warning(f"Claude Code hit max_turns limit (num_turns={data.get('num_turns')}). Consider increasing max_turns.") - if "usage" in data: - usage_info["prompt_tokens"] = data["usage"].get( - "input_tokens", 0 - ) - usage_info["completion_tokens"] = data["usage"].get( - "output_tokens", 0 + logger.warning( + f"Claude Code hit max_turns limit (num_turns={data.get('num_turns')}). Consider increasing max_turns." 
) + if "usage" in data: + usage_info["prompt_tokens"] = data["usage"].get("input_tokens", 0) + usage_info["completion_tokens"] = data["usage"].get("output_tokens", 0) break except json.JSONDecodeError: diff --git a/code-explorer/app/services/worktree.py b/code-explorer/app/services/worktree.py index 0e9becd..ed58f1d 100644 --- a/code-explorer/app/services/worktree.py +++ b/code-explorer/app/services/worktree.py @@ -44,9 +44,7 @@ def _sanitize_branch_name(self, branch: str) -> str: # Replace / with - and remove other problematic characters return re.sub(r"[^a-zA-Z0-9_-]", "-", branch) - def _get_authenticated_url( - self, repo_url: str, github_token: str | None - ) -> str: + def _get_authenticated_url(self, repo_url: str, github_token: str | None) -> str: """Get the authenticated URL for cloning.""" if not github_token: return repo_url @@ -85,9 +83,7 @@ async def _run_git_command( ) try: - stdout, stderr = await asyncio.wait_for( - proc.communicate(), timeout=timeout - ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) except asyncio.TimeoutError: proc.kill() raise self.GitError(f"Git command timed out: {' '.join(args)}") @@ -170,9 +166,7 @@ async def _fetch(self, project_dir: Path) -> None: timeout=120, ) - async def _create_worktree( - self, project_dir: Path, worktree_dir: Path, branch: str - ) -> None: + async def _create_worktree(self, project_dir: Path, worktree_dir: Path, branch: str) -> None: """Create a new worktree for a branch.""" logger.info(f"Creating worktree for branch {branch} at {worktree_dir}") git_dir = project_dir / ".git" diff --git a/frontend/app/admin/analytics/AdminAnalyticsClient.tsx b/frontend/app/admin/analytics/AdminAnalyticsClient.tsx index 2b51ce6..bfee037 100644 --- a/frontend/app/admin/analytics/AdminAnalyticsClient.tsx +++ b/frontend/app/admin/analytics/AdminAnalyticsClient.tsx @@ -39,7 +39,7 @@ function AdminAnalyticsContent() { const [timeRange, setTimeRange] = useState( urlTimeRange && ["daily", "weekly", "monthly", "yearly", "all"].includes(urlTimeRange) ? urlTimeRange - : "monthly" + : "monthly", ); const [isLoading, setIsLoading] = useState(true); @@ -58,12 +58,15 @@ function AdminAnalyticsContent() { }, [isPlatformAdmin, adminLoading, router]); // Update URL when time range changes - const handleTimeRangeChange = useCallback((newRange: TimeRange) => { - setTimeRange(newRange); - const params = new URLSearchParams(searchParams.toString()); - params.set("range", newRange); - router.replace(`/admin/analytics?${params.toString()}`, { scroll: false }); - }, [router, searchParams]); + const handleTimeRangeChange = useCallback( + (newRange: TimeRange) => { + setTimeRange(newRange); + const params = new URLSearchParams(searchParams.toString()); + params.set("range", newRange); + router.replace(`/admin/analytics?${params.toString()}`, { scroll: false }); + }, + [router, searchParams], + ); // Fetch data when time range changes useEffect(() => { @@ -122,7 +125,7 @@ function AdminAnalyticsContent() { if (adminLoading) { return ( -
+

Loading...

); @@ -130,7 +133,7 @@ function AdminAnalyticsContent() { if (!isPlatformAdmin) { return ( -
+

Access denied

); @@ -140,7 +143,7 @@ function AdminAnalyticsContent() {

Analytics Dashboard

-

+

Platform usage analytics and plan recommendations

@@ -154,22 +157,21 @@ function AdminAnalyticsContent() { onClick={handleFlushCache} disabled={isFlushing || isLoading} > - + {isFlushing ? "Refreshing..." : "Refresh Data"}
{topUsers && ( -
- Total platform usage: {formatCredits(topUsers.total_platform_tokens)} credits +
+ Total platform usage:{" "} + + {formatCredits(topUsers.total_platform_tokens)} credits +
)}
- {error && ( -
- {error} -
- )} + {error &&
{error}
} {isLoading ? ( @@ -192,7 +194,7 @@ function AdminAnalyticsContent() { {topUsers && topUsers.users.length > 0 ? ( ) : ( -
+
No data available for this time range
)} @@ -202,7 +204,7 @@ function AdminAnalyticsContent() { {topProjects && topProjects.projects.length > 0 ? ( ) : ( -
+
No data available for this time range
)} diff --git a/frontend/app/agent-log/AgentLogClient.tsx b/frontend/app/agent-log/AgentLogClient.tsx index 342029e..d77df75 100644 --- a/frontend/app/agent-log/AgentLogClient.tsx +++ b/frontend/app/agent-log/AgentLogClient.tsx @@ -26,12 +26,8 @@ export default function AgentLogClient() { const [total, setTotal] = useState(0); // Filter state - default to current user - const [selectedUserId, setSelectedUserId] = useState( - user?.id || null - ); - const [selectedProjectId, setSelectedProjectId] = useState( - null - ); + const [selectedUserId, setSelectedUserId] = useState(user?.id || null); + const [selectedProjectId, setSelectedProjectId] = useState(null); // Update selectedUserId when user loads useEffect(() => { @@ -65,7 +61,7 @@ export default function AgentLogClient() { } const response = await apiClient.get( - `/api/v1/platform/agent-logs?${params.toString()}` + `/api/v1/platform/agent-logs?${params.toString()}`, ); setJobs(response.items); setTotal(response.total); @@ -98,15 +94,15 @@ export default function AgentLogClient() { // Show loading while checking trial status or not exempt if (trialLoading || !trialStatus?.is_strictly_exempt) { return ( -
- +
+
); } if (error) { return ( -
+

{error}

); @@ -116,7 +112,7 @@ export default function AgentLogClient() {

Agent Log

-

+

View LLM call details for debugging, troubleshooting, and analysis

@@ -131,16 +127,16 @@ export default function AgentLogClient() { /> {loading ? ( -
- +
+
) : jobs.length === 0 ? (
-
- +
+

No agent logs found

-

+

{selectedUserId || selectedProjectId ? "Try adjusting your filters to see more results." : "LLM call logs will appear here after running generation tasks like brainstorming, spec generation, or feature extraction."} diff --git a/frontend/app/dashboard/page.tsx b/frontend/app/dashboard/page.tsx index d1374e8..6936159 100644 --- a/frontend/app/dashboard/page.tsx +++ b/frontend/app/dashboard/page.tsx @@ -46,8 +46,7 @@ export default function DashboardPage() { prevStats.llm_usage_this_month.total_prompt_tokens + usageLog.prompt_tokens, total_completion_tokens: prevStats.llm_usage_this_month.total_completion_tokens + usageLog.completion_tokens, - total_tokens: - prevStats.llm_usage_this_month.total_tokens + newTokens, + total_tokens: prevStats.llm_usage_this_month.total_tokens + newTokens, total_cost_usd: usageLog.cost_usd !== null ? (prevStats.llm_usage_this_month.total_cost_usd || 0) + usageLog.cost_usd @@ -85,7 +84,7 @@ export default function DashboardPage() { try { const data = await apiClient.get( - `/api/v1/dashboard/orgs/${user.org_id}/stats` + `/api/v1/dashboard/orgs/${user.org_id}/stats`, ); setStats(data); } catch (err) { @@ -105,7 +104,7 @@ export default function DashboardPage() { if (authLoading) { return (

- +
); } @@ -119,7 +118,7 @@ export default function DashboardPage() { if (loading) { return (
- +
); } @@ -130,7 +129,7 @@ export default function DashboardPage() {

Dashboard

-

{error}

+

{error}

); @@ -141,7 +140,7 @@ export default function DashboardPage() {

Dashboard

-

+

Organization overview and LLM usage metrics

@@ -149,7 +148,7 @@ export default function DashboardPage() { {stats && ( <> {isEnterprise ? ( -
+
@@ -168,9 +167,7 @@ export default function DashboardPage() { /> )} - {isEnterprise && ( - - )} + {isEnterprise && } diff --git a/frontend/app/email-templates/EmailTemplatesClient.tsx b/frontend/app/email-templates/EmailTemplatesClient.tsx index 4689d02..a595418 100644 --- a/frontend/app/email-templates/EmailTemplatesClient.tsx +++ b/frontend/app/email-templates/EmailTemplatesClient.tsx @@ -90,7 +90,7 @@ function EmailTemplatesContent() { if (adminLoading) { return ( -
+

Loading...

); @@ -98,33 +98,33 @@ function EmailTemplatesContent() { if (!isPlatformAdmin) { return ( -
+

Access denied

); } return ( -
+
{/* Header */} -
+

Email Templates

-

+

Manage email templates sent by the platform

{/* Error display */} {error && ( -
-

{error}

+
+

{error}

)} {/* Main content area */}
{/* Sidebar */} -
+
) : ( -
+

Select a template from the sidebar to edit

)} diff --git a/frontend/app/globals.css b/frontend/app/globals.css index f8a8d56..c52876a 100644 --- a/frontend/app/globals.css +++ b/frontend/app/globals.css @@ -64,43 +64,43 @@ .dark { --background: oklch(0.2303 0.0125 264.2926); --foreground: oklch(0.9219 0 0); - --card: oklch(0.3210 0.0078 223.6661); + --card: oklch(0.321 0.0078 223.6661); --card-foreground: oklch(0.9219 0 0); - --popover: oklch(0.3210 0.0078 223.6661); + --popover: oklch(0.321 0.0078 223.6661); --popover-foreground: oklch(0.9219 0 0); --primary: oklch(0.5676 0.2021 283.0838); - --primary-foreground: oklch(1.0000 0 0); - --secondary: oklch(0.3390 0.1793 301.6848); + --primary-foreground: oklch(1 0 0); + --secondary: oklch(0.339 0.1793 301.6848); --secondary-foreground: oklch(0.9219 0 0); --muted: oklch(0.3867 0 0); --muted-foreground: oklch(0.7155 0 0); --accent: oklch(0.45 0 0); --accent-foreground: oklch(0.9219 0 0); --destructive: oklch(0.6368 0.2078 25.3313); - --destructive-foreground: oklch(1.0000 0 0); + --destructive-foreground: oklch(1 0 0); --border: oklch(0.3867 0 0); --input: oklch(0.3867 0 0); --ring: oklch(0.5676 0.2021 283.0838); --chart-1: oklch(0.5676 0.2021 283.0838); --chart-2: oklch(0.5261 0.1705 314.6534); - --chart-3: oklch(0.3390 0.1793 301.6848); - --chart-4: oklch(0.6746 0.1414 261.3380); - --chart-5: oklch(0.5880 0.0993 245.7394); + --chart-3: oklch(0.339 0.1793 301.6848); + --chart-4: oklch(0.6746 0.1414 261.338); + --chart-5: oklch(0.588 0.0993 245.7394); --sidebar: oklch(0.2303 0.0125 264.2926); --sidebar-foreground: oklch(0.9219 0 0); --sidebar-primary: oklch(0.5676 0.2021 283.0838); - --sidebar-primary-foreground: oklch(1.0000 0 0); + --sidebar-primary-foreground: oklch(1 0 0); --sidebar-accent: oklch(0.45 0 0); --sidebar-accent-foreground: oklch(0.9219 0 0); --sidebar-border: oklch(0.3867 0 0); --sidebar-ring: oklch(0.5676 0.2021 283.0838); --shadow-2xs: 0px 5px 10px -2px hsl(0 0% 0% / 0.05); --shadow-xs: 0px 5px 10px -2px hsl(0 0% 0% / 0.05); - --shadow-sm: 0px 5px 10px -2px hsl(0 0% 0% / 0.10), 0px 1px 2px -3px hsl(0 0% 0% / 0.10); - --shadow: 0px 5px 10px -2px hsl(0 0% 0% / 0.10), 0px 1px 2px -3px hsl(0 0% 0% / 0.10); - --shadow-md: 0px 5px 10px -2px hsl(0 0% 0% / 0.10), 0px 2px 4px -3px hsl(0 0% 0% / 0.10); - --shadow-lg: 0px 5px 10px -2px hsl(0 0% 0% / 0.10), 0px 4px 6px -3px hsl(0 0% 0% / 0.10); - --shadow-xl: 0px 5px 10px -2px hsl(0 0% 0% / 0.10), 0px 8px 10px -3px hsl(0 0% 0% / 0.10); + --shadow-sm: 0px 5px 10px -2px hsl(0 0% 0% / 0.1), 0px 1px 2px -3px hsl(0 0% 0% / 0.1); + --shadow: 0px 5px 10px -2px hsl(0 0% 0% / 0.1), 0px 1px 2px -3px hsl(0 0% 0% / 0.1); + --shadow-md: 0px 5px 10px -2px hsl(0 0% 0% / 0.1), 0px 2px 4px -3px hsl(0 0% 0% / 0.1); + --shadow-lg: 0px 5px 10px -2px hsl(0 0% 0% / 0.1), 0px 4px 6px -3px hsl(0 0% 0% / 0.1); + --shadow-xl: 0px 5px 10px -2px hsl(0 0% 0% / 0.1), 0px 8px 10px -3px hsl(0 0% 0% / 0.1); --shadow-2xl: 0px 5px 10px -2px hsl(0 0% 0% / 0.25); } @@ -158,7 +158,6 @@ --shadow-2xl: var(--shadow-2xl); } - /* Base styles with lower specificity than utilities */ @layer base { *, @@ -350,7 +349,8 @@ body { /* Highlight animation for deep-linked items (email notification links) */ @keyframes highlight-pulse { - 0%, 100% { + 0%, + 100% { background-color: rgb(245 158 11 / 0.15); /* amber-500 with low opacity */ } 50% { diff --git a/frontend/app/inbox/[...path]/page.tsx b/frontend/app/inbox/[...path]/page.tsx index ba4427f..9f005d2 100644 --- a/frontend/app/inbox/[...path]/page.tsx +++ b/frontend/app/inbox/[...path]/page.tsx @@ -26,16 +26,13 @@ export 
default function InboxDeepLinkPage() {
   const conversationType = pathSegments?.[0];
   const conversationId = pathSegments?.[1];
   const messageSequenceStr = pathSegments?.[2];
-  const messageSequence = messageSequenceStr
-    ? parseInt(messageSequenceStr, 10)
-    : undefined;
+  const messageSequence = messageSequenceStr ? parseInt(messageSequenceStr, 10) : undefined;

-  const { isLoading, error, redirectUrl, projectName, conversationTitle } =
-    useInboxDeepLink(
-      conversationType,
-      conversationId,
-      isNaN(messageSequence as number) ? undefined : messageSequence
-    );
+  const { isLoading, error, redirectUrl, projectName, conversationTitle } = useInboxDeepLink(
+    conversationType,
+    conversationId,
+    isNaN(messageSequence as number) ? undefined : messageSequence,
+  );

   // Redirect unauthenticated users to login
   useEffect(() => {
@@ -57,10 +54,8 @@ export default function InboxDeepLinkPage() {
   if (authLoading) {
     return (
- -

- Checking authentication... -

+ +

Checking authentication...

); } @@ -74,10 +69,8 @@ export default function InboxDeepLinkPage() { if (isLoading) { return (
- -

- Resolving link... -

+ +

Resolving link...

); } @@ -87,14 +80,14 @@ export default function InboxDeepLinkPage() { return (
-
- +
+

Unable to Open Link

-

{error}

+

{error}

-
+
@@ -109,13 +102,11 @@ export default function InboxDeepLinkPage() { if (redirectUrl) { return (
- -

+ +

Redirecting to {conversationTitle || "conversation"}...

- {projectName && ( -

in {projectName}

- )} + {projectName &&

in {projectName}

}
); } diff --git a/frontend/app/inbox/page.tsx b/frontend/app/inbox/page.tsx index 96bf6f8..f86246e 100644 --- a/frontend/app/inbox/page.tsx +++ b/frontend/app/inbox/page.tsx @@ -18,7 +18,7 @@ export default function InboxPage() { // Show loading state while redirecting return (
- +
); } diff --git a/frontend/app/invite/accept/InviteRedirectClient.tsx b/frontend/app/invite/accept/InviteRedirectClient.tsx index 1a05724..4bdc534 100644 --- a/frontend/app/invite/accept/InviteRedirectClient.tsx +++ b/frontend/app/invite/accept/InviteRedirectClient.tsx @@ -23,9 +23,9 @@ export default function InviteRedirectClient() { }, [searchParams, router]); return ( -
+
- +

Redirecting...

diff --git a/frontend/app/invites/[token]/InviteClient.tsx b/frontend/app/invites/[token]/InviteClient.tsx index e52abaf..9b67380 100644 --- a/frontend/app/invites/[token]/InviteClient.tsx +++ b/frontend/app/invites/[token]/InviteClient.tsx @@ -11,7 +11,14 @@ import { apiClient } from "@/lib/api/client"; import { useAuth } from "@/lib/auth/AuthContext"; import { InviteValidationResponse } from "@/lib/api/types"; -type PageState = "loading" | "invalid" | "not_logged_in" | "ready" | "accepting" | "success" | "error"; +type PageState = + | "loading" + | "invalid" + | "not_logged_in" + | "ready" + | "accepting" + | "success" + | "error"; export default function InviteClient() { const params = useParams(); @@ -88,10 +95,10 @@ export default function InviteClient() { // Loading state if (state === "loading" || authLoading) { return ( -
+
- +

Validating invitation...

@@ -102,12 +109,12 @@ export default function InviteClient() { // Invalid/Expired state if (state === "invalid") { return ( -
+
- -

Invalid Invitation

-

+ +

Invalid Invitation

+

{error || "This invitation link is invalid or has expired."}

))}
@@ -219,7 +211,7 @@ export default function LoginClient() {
- or + or
)} @@ -247,13 +239,9 @@ export default function LoginClient() { required />
- {error && ( -

{error}

- )} + {error &&

{error}

} @@ -270,14 +258,14 @@ export default function LoginClient() { ) : ( Sign up now. diff --git a/frontend/app/mcp-explorer/McpExplorerClient.tsx b/frontend/app/mcp-explorer/McpExplorerClient.tsx index a449b26..d65bf13 100644 --- a/frontend/app/mcp-explorer/McpExplorerClient.tsx +++ b/frontend/app/mcp-explorer/McpExplorerClient.tsx @@ -3,12 +3,7 @@ import { useEffect, useState } from "react"; import { useRouter } from "next/navigation"; import { FolderTree, Loader2 } from "lucide-react"; -import { - Card, - CardHeader, - CardTitle, - CardDescription, -} from "@/components/ui/card"; +import { Card, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; import { EmptyState } from "@/components/EmptyState"; import { ProtectedRoute } from "@/components/ProtectedRoute"; import { useTrial } from "@/lib/auth/TrialContext"; @@ -58,15 +53,15 @@ function MCPExplorerSelectorContent() { // Show loading while checking trial status or not exempt if (trialLoading || !trialStatus?.is_strictly_exempt) { return ( -
- +
+
); } if (isLoading) { return ( -
+

Loading projects...

); @@ -76,7 +71,7 @@ function MCPExplorerSelectorContent() {

MCP Explorer

-

+

Browse your project's virtual filesystem exposed by the MCP server.

@@ -94,7 +89,7 @@ function MCPExplorerSelectorContent() { {projects.map((project) => ( handleProjectClick(project.id)} > diff --git a/frontend/app/mcp-log/McpLogClient.tsx b/frontend/app/mcp-log/McpLogClient.tsx index 8547b63..6f469c3 100644 --- a/frontend/app/mcp-log/McpLogClient.tsx +++ b/frontend/app/mcp-log/McpLogClient.tsx @@ -31,7 +31,7 @@ export default function McpLogClient() { try { setLoading(true); const response = await apiClient.get( - `/api/v1/orgs/${user.org_id}/mcp-logs` + `/api/v1/orgs/${user.org_id}/mcp-logs`, ); setLogs(response.items); setError(null); @@ -49,23 +49,23 @@ export default function McpLogClient() { // Show loading while checking trial status or not exempt if (trialLoading || !trialStatus?.is_strictly_exempt) { return ( -
- +
+
); } if (loading) { return ( -
- +
+
); } if (error) { return ( -
+

{error}

); @@ -75,20 +75,20 @@ export default function McpLogClient() {

MCP Log

-

+

View MCP tool call history for debugging and monitoring

{logs.length === 0 ? (
-
- +
+

No MCP logs yet

-

- MCP call logs will appear here after agents access your project via - the MCP HTTP endpoint. +

+ MCP call logs will appear here after agents access your project via the MCP HTTP + endpoint.

) : ( diff --git a/frontend/app/members-groups/MembersGroupsClient.tsx b/frontend/app/members-groups/MembersGroupsClient.tsx index 7d8d73a..f5244fb 100644 --- a/frontend/app/members-groups/MembersGroupsClient.tsx +++ b/frontend/app/members-groups/MembersGroupsClient.tsx @@ -20,12 +20,12 @@ function MembersGroupsContent() { const orgId = user?.org_id; // Get user's role in the current org - const currentOrgMembership = userOrgs.find(m => m.org_id === orgId); + const currentOrgMembership = userOrgs.find((m) => m.org_id === orgId); const userRole = currentOrgMembership?.role || "member"; if (isLoading) { return ( -
+

Loading...

); @@ -33,7 +33,7 @@ function MembersGroupsContent() { if (!orgId) { return ( -
+

No organization found

); @@ -43,7 +43,7 @@ function MembersGroupsContent() {

Members & Groups

-

+

Manage organization members, invitations, and user groups

diff --git a/frontend/app/oauth/authorize/OAuthAuthorizeClient.tsx b/frontend/app/oauth/authorize/OAuthAuthorizeClient.tsx index f89f04a..bc420ca 100644 --- a/frontend/app/oauth/authorize/OAuthAuthorizeClient.tsx +++ b/frontend/app/oauth/authorize/OAuthAuthorizeClient.tsx @@ -59,9 +59,11 @@ export default function OAuthAuthorizeClient() { const [isLoadingProjects, setIsLoadingProjects] = useState(false); // Computed resource URL - either from params or constructed from selected project - const resource = resourceParam || (selectedProjectId - ? `${process.env.NEXT_PUBLIC_API_URL}/api/v1/projects/${selectedProjectId}/mcp` - : null); + const resource = + resourceParam || + (selectedProjectId + ? `${process.env.NEXT_PUBLIC_API_URL}/api/v1/projects/${selectedProjectId}/mcp` + : null); const [authInfo, setAuthInfo] = useState(null); const [isLoadingInfo, setIsLoadingInfo] = useState(false); @@ -132,21 +134,18 @@ export default function OAuthAuthorizeClient() { resource: resource, }); - const response = await fetch( - `${apiUrl}/api/v1/oauth/authorize/info?${params}`, - { - headers: { - Authorization: `Bearer ${token}`, - }, - } - ); + const response = await fetch(`${apiUrl}/api/v1/oauth/authorize/info?${params}`, { + headers: { + Authorization: `Bearer ${token}`, + }, + }); if (!response.ok) { const errorData = await response.json(); throw new Error( errorData.detail?.error_description || errorData.detail || - "Failed to fetch authorization info" + "Failed to fetch authorization info", ); } @@ -196,9 +195,7 @@ export default function OAuthAuthorizeClient() { if (!response.ok) { const errorData = await response.json(); throw new Error( - errorData.detail?.error_description || - errorData.detail || - "Authorization failed" + errorData.detail?.error_description || errorData.detail || "Authorization failed", ); } @@ -213,7 +210,7 @@ export default function OAuthAuthorizeClient() { // Show loading while checking auth if (authLoading || !isAuthenticated) { return ( -
+

Checking authentication...

@@ -223,13 +220,13 @@ export default function OAuthAuthorizeClient() { // Validate required params (except resource which can be selected) if (!clientId || !redirectUri || !codeChallenge) { return ( -
+
Invalid Request - Missing required OAuth parameters. Please try the authorization - flow again from your MCP client. + Missing required OAuth parameters. Please try the authorization flow again from your + MCP client. @@ -241,13 +238,13 @@ export default function OAuthAuthorizeClient() { const needsProjectSelection = !resourceParam && !selectedProjectId; return ( -
+
{/* MFBT Branding */}
-

+

mfbt.

-

+

move fast and build things

@@ -255,7 +252,7 @@ export default function OAuthAuthorizeClient() { {/* Authorization Card */} -
+
Authorization Request
@@ -277,14 +274,14 @@ export default function OAuthAuthorizeClient() {
) : projects.length === 0 ? ( -
+

You don't have access to any projects. Please create a project first.

) : (
-
+
MCP Client is requesting access @@ -319,13 +316,13 @@ export default function OAuthAuthorizeClient() {
) : error && !authInfo ? ( -
+

{error}

) : authInfo ? ( <> {/* Client Info */} -
+
@@ -334,16 +331,12 @@ export default function OAuthAuthorizeClient() {

{authInfo.client_name}

-

- MCP Client -

+

MCP Client

-
-

- Wants to access: -

+
+

Wants to access:

{authInfo.project_name}

@@ -383,7 +376,7 @@ export default function OAuthAuthorizeClient() {
{error && ( -
+

{error}

)} @@ -400,9 +393,9 @@ export default function OAuthAuthorizeClient() { disabled={isSubmitting} > {isSubmitting ? ( - + ) : ( - + )} Deny @@ -412,9 +405,9 @@ export default function OAuthAuthorizeClient() { disabled={isSubmitting} > {isSubmitting ? ( - + ) : ( - + )} Authorize @@ -423,9 +416,9 @@ export default function OAuthAuthorizeClient() { {/* Security note */} -

- Only authorize applications you trust. This access can be revoked at any - time from your account settings. +

+ Only authorize applications you trust. This access can be revoked at any time from your + account settings.

); diff --git a/frontend/app/page.tsx b/frontend/app/page.tsx index 84cbeb7..be30506 100644 --- a/frontend/app/page.tsx +++ b/frontend/app/page.tsx @@ -22,7 +22,7 @@ export default function Home() { // Show loading state while checking auth and redirecting return (
- +
); } diff --git a/frontend/app/platform-settings/PlatformSettingsClient.tsx b/frontend/app/platform-settings/PlatformSettingsClient.tsx index 23f3475..1873f25 100644 --- a/frontend/app/platform-settings/PlatformSettingsClient.tsx +++ b/frontend/app/platform-settings/PlatformSettingsClient.tsx @@ -39,7 +39,7 @@ function PlatformSettingsContent() { if (isLoading) { return ( -
+

Loading...

); @@ -47,7 +47,7 @@ function PlatformSettingsContent() { if (!isPlatformAdmin) { return ( -
+

Access denied

); @@ -57,13 +57,13 @@ function PlatformSettingsContent() {

Platform Settings

-

+

Manage platform-wide configurations for LLM providers, email, and storage

- + LLM Connectors LLM Settings Platform Connectors diff --git a/frontend/app/projects/ProjectsClient.tsx b/frontend/app/projects/ProjectsClient.tsx index eb52e39..2b230b5 100644 --- a/frontend/app/projects/ProjectsClient.tsx +++ b/frontend/app/projects/ProjectsClient.tsx @@ -1,7 +1,17 @@ "use client"; import { useEffect, useState } from "react"; -import { FolderPlus, Plus, Search, Trash2, Copy, PlusCircle, ChevronDown, Gamepad2, Pencil } from "lucide-react"; +import { + FolderPlus, + Plus, + Search, + Trash2, + Copy, + PlusCircle, + ChevronDown, + Gamepad2, + Pencil, +} from "lucide-react"; import { useRouter } from "next/navigation"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -113,7 +123,7 @@ function ProjectsPageContent() { (project) => project.name.toLowerCase().includes(query) || (project.short_description && project.short_description.toLowerCase().includes(query)) || - (project.owner?.email && project.owner.email.toLowerCase().includes(query)) + (project.owner?.email && project.owner.email.toLowerCase().includes(query)), ); setFilteredProjects(filtered); }, [searchQuery, projects]); @@ -154,7 +164,7 @@ function ProjectsPageContent() { // Check if project has collaborators (shares beyond owner) try { const shares = await apiClient.get<{ id: string; role: string }[]>( - `/api/v1/projects/${project.id}/shares` + `/api/v1/projects/${project.id}/shares`, ); // Has collaborators if more than 1 share (owner) exists setHasCollaborators(shares.length > 1); @@ -233,7 +243,7 @@ function ProjectsPageContent() { if (isLoading) { return ( -
+

Loading projects...

); @@ -241,12 +251,10 @@ function ProjectsPageContent() { return (
-
+

Projects

-

- Create and manage all your projects. -

+

Create and manage all your projects.

@@ -282,7 +290,7 @@ function ProjectsPageContent() { ) : (
- + {filteredProjects.length === 0 ? ( - + No projects found matching your search. @@ -346,7 +354,7 @@ function ProjectsPageContent() { onClick={(e) => handleDeleteClick(e, project)} title="Delete project" > - +
@@ -365,7 +373,8 @@ function ProjectsPageContent() { Delete Project? - Are you sure you want to delete {projectToDelete?.name}? This action cannot be undone. + Are you sure you want to delete {projectToDelete?.name}? This action + cannot be undone. @@ -400,10 +409,7 @@ function ProjectsPageContent() { setCloneOptions({ ...cloneOptions, includePhases: checked as boolean }) } /> -
@@ -417,10 +423,7 @@ function ProjectsPageContent() { setCloneOptions({ ...cloneOptions, includeThreads: checked as boolean }) } /> -
@@ -437,7 +440,7 @@ function ProjectsPageContent() { /> @@ -465,9 +468,7 @@ function ProjectsPageContent() { Rename Project - - Enter a new name for this project. - + Enter a new name for this project.
@@ -488,12 +489,10 @@ function ProjectsPageContent() { } }} /> -

- {renameValue.length}/255 characters -

+

{renameValue.length}/255 characters

{renameError && ( -
+
{renameError}
)} diff --git a/frontend/app/projects/[projectId]/activity/page.tsx b/frontend/app/projects/[projectId]/activity/page.tsx index 156e87f..eaa468d 100644 --- a/frontend/app/projects/[projectId]/activity/page.tsx +++ b/frontend/app/projects/[projectId]/activity/page.tsx @@ -37,7 +37,7 @@ export default function ProjectActivityPage() { setIsLoading(false); } }, - [projectId] + [projectId], ); useEffect(() => { @@ -52,9 +52,7 @@ export default function ProjectActivityPage() {

Activity

-

- Recent activity across this project -

+

Recent activity across this project

0 - ? (tabs.find((tab) => pathname.startsWith(tab.href))?.href || tabs[0].href) - : ""; + const activeTab = + tabs.length > 0 ? tabs.find((tab) => pathname.startsWith(tab.href))?.href || tabs[0].href : ""; if (isLoading || tabs.length === 0) { return (
- +
); } @@ -153,7 +164,7 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay return (
{/* Breadcrumb */} -
+
- - {phase?.title || "Phase"} - + {phase?.title || "Phase"}
{/* Phase header */} @@ -172,10 +181,10 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay

{phase?.title}

@@ -188,14 +197,14 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay setShowRenameDialog(true)}> - + Rename Phase setShowArchiveDialog(true)} className="text-destructive focus:text-destructive focus:bg-destructive/10 dark:focus:bg-destructive/20" > - + Archive Phase @@ -218,7 +227,7 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay "flex items-center gap-2 border-b-2 py-3 text-sm font-medium transition-colors", isActive ? "border-primary text-foreground" - : "border-transparent text-muted-foreground hover:text-foreground hover:border-border" + : "text-muted-foreground hover:text-foreground hover:border-border border-transparent", )} > @@ -239,8 +248,8 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay Archive Phase? Are you sure you want to archive the phase{" "} - "{phase?.title}"? - + "{phase?.title}"? + Archived phases can be restored later. @@ -254,7 +263,7 @@ export default function BrainstormingLayoutClient({ children }: BrainstormingLay > {isArchiving ? ( <> - + Archiving... ) : ( diff --git a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/activity/page.tsx b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/activity/page.tsx index 7f61556..88b9af8 100644 --- a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/activity/page.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/activity/page.tsx @@ -21,7 +21,11 @@ export default function BrainstormingPhaseActivityPage() { async (currentOffset: number) => { setIsLoading(true); try { - const response = await apiClient.getBrainstormingPhaseActivity(phaseId, PAGE_SIZE, currentOffset); + const response = await apiClient.getBrainstormingPhaseActivity( + phaseId, + PAGE_SIZE, + currentOffset, + ); if (currentOffset === 0) { setLogs(response.items); @@ -37,7 +41,7 @@ export default function BrainstormingPhaseActivityPage() { setIsLoading(false); } }, - [phaseId] + [phaseId], ); useEffect(() => { @@ -50,9 +54,7 @@ export default function BrainstormingPhaseActivityPage() { return (
-

- Activity history for this brainstorming phase -

+

Activity history for this brainstorming phase

= { - must_have: "bg-red-100 text-red-800 border-red-200 dark:bg-red-900/40 dark:text-red-400 dark:border-red-800", - important: "bg-orange-100 text-orange-800 border-orange-200 dark:bg-orange-900/40 dark:text-orange-400 dark:border-orange-800", - optional: "bg-blue-100 text-blue-800 border-blue-200 dark:bg-blue-900/40 dark:text-blue-400 dark:border-blue-800", + must_have: + "bg-red-100 text-red-800 border-red-200 dark:bg-red-900/40 dark:text-red-400 dark:border-red-800", + important: + "bg-orange-100 text-orange-800 border-orange-200 dark:bg-orange-900/40 dark:text-orange-400 dark:border-orange-800", + optional: + "bg-blue-100 text-blue-800 border-blue-200 dark:bg-blue-900/40 dark:text-blue-400 dark:border-blue-800", }; const priorityLabels: Record = { @@ -162,7 +165,7 @@ export default function BrainstormingConversationsPage() { } } }, - [fetchPhase] + [fetchPhase], ); useWebSocketSubscription({ @@ -176,7 +179,11 @@ export default function BrainstormingConversationsPage() { // WebSocket connection for real-time job updates // Use currentProject?.id (UUID) for filtering, not projectId (short URL identifier) - const wsJobs = useJobWebSocket({ orgId: orgId || "", projectId: currentProject?.id, enabled: !!orgId && !!currentProject?.id }); + const wsJobs = useJobWebSocket({ + orgId: orgId || "", + projectId: currentProject?.id, + enabled: !!orgId && !!currentProject?.id, + }); // Initial running job fetched on mount (used before WebSocket delivers updates) const [initialRunningJob, setInitialRunningJob] = useState(null); @@ -225,7 +232,7 @@ export default function BrainstormingConversationsPage() { (j) => j.job_type === "brainstorm_conversation_generate" || j.job_type === "brainstorm_conversation_batch_generate" || - j.job_type === "brainstorm_generate" + j.job_type === "brainstorm_generate", ); if (hasWsJob) { setInitialRunningJob(null); @@ -253,14 +260,14 @@ export default function BrainstormingConversationsPage() { // Fetch conversation-type modules for this phase (not implementation modules) const modulesData = await apiClient.get( - `/api/v1/projects/${projectId}/modules?brainstorming_phase_id=${phaseId}&module_type=conversation` + `/api/v1/projects/${projectId}/modules?brainstorming_phase_id=${phaseId}&module_type=conversation`, ); setModules(modulesData); // Fetch conversation-type features for these modules (not implementation features) if (modulesData.length > 0) { const featuresResponse = await apiClient.get<{ items: Feature[]; total: number }>( - `/api/v1/projects/${projectId}/features?brainstorming_phase_id=${phaseId}&feature_type=conversation` + `/api/v1/projects/${projectId}/features?brainstorming_phase_id=${phaseId}&feature_type=conversation`, ); setFeatures(featuresResponse.items); } else { @@ -298,7 +305,7 @@ export default function BrainstormingConversationsPage() { // Fall back to auto-selecting first feature // Find the first module that has features (matches display order) const firstModuleWithFeatures = modules.find((m) => - features.some((f) => f.module_id === m.id) + features.some((f) => f.module_id === m.id), ); if (firstModuleWithFeatures) { const firstFeature = features.find((f) => f.module_id === firstModuleWithFeatures.id); @@ -338,7 +345,9 @@ export default function BrainstormingConversationsPage() { // Use requestAnimationFrame to ensure DOM is ready const rafId = requestAnimationFrame(() => { // Find the module card element - const moduleCard = document.querySelector(`[data-module-id="${selectedFeature.module_id}"]`); + const moduleCard = 
document.querySelector( + `[data-module-id="${selectedFeature.module_id}"]`, + ); if (moduleCard && leftColumnRef.current) { const containerRect = leftColumnRef.current.getBoundingClientRect(); const moduleRect = moduleCard.getBoundingClientRect(); @@ -357,7 +366,7 @@ export default function BrainstormingConversationsPage() { (job.job_type === "brainstorm_conversation_generate" || job.job_type === "brainstorm_conversation_batch_generate" || job.job_type === "brainstorm_generate") && - job.status === "succeeded" + job.status === "succeeded", ); if (completedJob && !processedJobIds.has(completedJob.id)) { @@ -462,8 +471,8 @@ export default function BrainstormingConversationsPage() { // Loading state (but not during generation) if (isLoading && !isGenerating && !runningJob) { return ( -
-
+
+
); } @@ -475,7 +484,7 @@ export default function BrainstormingConversationsPage() {

Conversations

- +

{error}

- - - handleDeleteModuleClick(module, e)} - className="text-destructive focus:text-destructive" - > - - Delete - - - -
-
-

{module.title}

- - {moduleFeatures.filter((f) => f.is_answered).length} / {moduleFeatures.length} - -
- {module.description && ( -

{module.description}

- )} - - {moduleFeatures.length > 0 ? ( -
- {moduleFeatures.map((feature) => ( -
handleFeatureClick(feature.id)} - className={cn( - "flex items-start gap-3 p-2 rounded-md cursor-pointer transition-colors group/feature relative", - selectedFeatureId === feature.id && "ring-2 ring-primary", - feature.is_answered - ? "border border-green-200 bg-green-50/30 hover:bg-green-100/50 dark:border-green-800 dark:bg-green-950/40 dark:hover:bg-green-900/50" - : "bg-muted/50 hover:bg-muted" - )} - > - {/* Kebab menu for question - visible on hover */} -
- - - - - - handleDeleteFeatureClick(feature, e)} - className="text-destructive focus:text-destructive" +
+ {modules + .filter((module) => features.some((f) => f.module_id === module.id)) + .map((module) => { + const moduleFeatures = features.filter((f) => f.module_id === module.id); + const hasSelectedFeature = moduleFeatures.some((f) => f.id === selectedFeatureId); + + return ( + + {/* Kebab menu for aspect - visible on hover */} +
+ + + + + + handleDeleteModuleClick(module, e)} + className="text-destructive focus:text-destructive" + > + + Delete + + + +
+
+

{module.title}

+ + {moduleFeatures.filter((f) => f.is_answered).length} / {moduleFeatures.length} + +
+ {module.description && ( +

{module.description}

+ )} + + {moduleFeatures.length > 0 ? ( +
+ {moduleFeatures.map((feature) => ( +
handleFeatureClick(feature.id)} + className={cn( + "group/feature relative flex cursor-pointer items-start gap-3 rounded-md p-2 transition-colors", + selectedFeatureId === feature.id && "ring-primary ring-2", + feature.is_answered + ? "border border-green-200 bg-green-50/30 hover:bg-green-100/50 dark:border-green-800 dark:bg-green-950/40 dark:hover:bg-green-900/50" + : "bg-muted/50 hover:bg-muted", + )} + > + {/* Kebab menu for question - visible on hover */} +
+ + + + + + handleDeleteFeatureClick(feature, e)} + className="text-destructive focus:text-destructive" + > + + Delete + + + +
+ {feature.is_answered ? ( + + ) : ( + + )} +
+

{feature.title}

+ {/* Priority and Category badges */} +
+ - - Delete - - - -
- {feature.is_answered ? ( - - ) : ( - - )} -
-

{feature.title}

- {/* Priority and Category badges */} -
- - {priorityLabels[feature.priority]} - - {feature.category && ( - - {feature.category} + {priorityLabels[feature.priority]} - )} - {feature.unresolved_count !== undefined && feature.unresolved_count > 0 && ( - + {feature.category && ( + + {feature.category} + + )} + {feature.unresolved_count !== undefined && + feature.unresolved_count > 0 && ( + + )} +
+ {feature.spec_text && ( +

+ {feature.spec_text} +

)}
- {feature.spec_text && ( -

- {feature.spec_text} -

- )}
-
- ))} -
- ) : ( -

No questions in this aspect yet

- )} -
- ); - })} + ))} +
+ ) : ( +

No questions in this aspect yet

+ )} + + ); + })}
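The map above renders only modules that own at least one conversation feature, via modules.filter combined with features.some. A self-contained sketch of that selection step under assumed minimal shapes; ModuleRef and FeatureRef are illustrative names, not the project's types:

    interface FeatureRef {
      id: string;
      module_id: string;
    }

    interface ModuleRef {
      id: string;
      title: string;
    }

    // Same result as modules.filter((m) => features.some((f) => f.module_id === m.id)),
    // but with one pass over features and O(1) membership checks per module.
    function modulesWithFeatures(modules: ModuleRef[], features: FeatureRef[]): ModuleRef[] {
      const populated = new Set(features.map((f) => f.module_id));
      return modules.filter((m) => populated.has(m.id));
    }

The Set variant only pays off for large lists; the page's some-based check reads fine at this scale.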
{/* Right column: Thread panel - positioned based on selected question */} {selectedFeature && orgId && (
- + {/* Delete Aspect Confirmation Dialog */} - !open && setModuleToDelete(null)}> + !open && setModuleToDelete(null)} + > Delete Aspect? Are you sure you want to delete the aspect{" "} - "{moduleToDelete?.title}"? - + "{moduleToDelete?.title}"? + This will also delete all clarification questions within this aspect. @@ -792,14 +817,17 @@ export default function BrainstormingConversationsPage() { {/* Delete Question Confirmation Dialog */} - !open && setFeatureToDelete(null)}> + !open && setFeatureToDelete(null)} + > Delete Question? Are you sure you want to delete the clarification question{" "} - "{featureToDelete?.title}"? - + "{featureToDelete?.title}"? + Any associated conversation and comments will also be removed. diff --git a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/description/page.tsx b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/description/page.tsx index 6f0f981..ecb54ac 100644 --- a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/description/page.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/description/page.tsx @@ -7,7 +7,17 @@ import { apiClient } from "@/lib/api/client"; import { BrainstormingPhase } from "@/lib/api/types"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Progress } from "@/components/ui/progress"; -import { Loader2, MessageSquare, FileText, ClipboardList, Boxes, CheckCircle2, Circle, ArrowRight, AlertTriangle } from "lucide-react"; +import { + Loader2, + MessageSquare, + FileText, + ClipboardList, + Boxes, + CheckCircle2, + Circle, + ArrowRight, + AlertTriangle, +} from "lucide-react"; import { MarkdownArticle } from "@/components/MarkdownArticle"; import { DescriptionImageGallery } from "@/components/DescriptionImageGallery"; import { useNav } from "@/lib/contexts/NavContext"; @@ -18,7 +28,6 @@ import { buildPhaseFeaturesUrl, } from "@/lib/url"; - interface Feature { id: string; title: string; @@ -51,7 +60,7 @@ export default function BrainstormingDescriptionPage() { // Fetch conversation-type features for stats (only ACTIVE features are returned) const featuresData = await apiClient.get<{ items: Feature[] }>( - `/api/v1/projects/${projectId}/features?brainstorming_phase_id=${phaseId}&feature_type=conversation` + `/api/v1/projects/${projectId}/features?brainstorming_phase_id=${phaseId}&feature_type=conversation`, ); setFeatures(featuresData.items); } catch (err) { @@ -67,8 +76,8 @@ export default function BrainstormingDescriptionPage() { if (isLoading) { return ( -
- +
+
); } @@ -89,11 +98,15 @@ export default function BrainstormingDescriptionPage() { const totalQuestions = features.length; const answeredQuestions = features.filter((f) => f.is_answered).length; const unansweredQuestions = totalQuestions - answeredQuestions; - const completionPercentage = totalQuestions > 0 ? Math.round((answeredQuestions / totalQuestions) * 100) : 0; + const completionPercentage = + totalQuestions > 0 ? Math.round((answeredQuestions / totalQuestions) * 100) : 0; const totalUnresolved = features.reduce((sum, f) => sum + (f.unresolved_count || 0), 0); // Build URLs using URL builders if project and phase data are available - const buildQuickLinkUrl = (urlBuilder: (project: any, phase: any) => string, fallbackPath: string) => { + const buildQuickLinkUrl = ( + urlBuilder: (project: any, phase: any) => string, + fallbackPath: string, + ) => { if (currentProject && phase) { const projectInfo = { name: currentProject.name, short_id: currentProject.short_id }; const phaseInfo = { title: phase.title, short_id: phase.short_id }; @@ -131,7 +144,7 @@ export default function BrainstormingDescriptionPage() { ]; return ( -
+
{/* Main Content - Description rendered as markdown */}
@@ -145,9 +158,10 @@ export default function BrainstormingDescriptionPage() {

No description provided.

)} {/* Display description images if any */} - {phase.description_image_attachments && phase.description_image_attachments.length > 0 && ( - - )} + {phase.description_image_attachments && + phase.description_image_attachments.length > 0 && ( + + )}
@@ -171,35 +185,41 @@ export default function BrainstormingDescriptionPage() { {/* Stats Grid */}
-
+
{totalAspects}
-
Aspects
+
Aspects
-
+
{totalQuestions}
-
Questions
+
Questions
-
+
- {answeredQuestions} + + {answeredQuestions} +
-
Answered
+
Answered
-
+
- {unansweredQuestions} + + {unansweredQuestions} +
-
Pending
+
Pending
{totalUnresolved > 0 && ( -
+
- {totalUnresolved} + + {totalUnresolved} +
-
Unresolved Points
+
Unresolved Points
)}
@@ -218,16 +238,16 @@ export default function BrainstormingDescriptionPage() { -
- +
+
-
+
{link.name}
-
{link.description}
+
{link.description}
- + ); })} diff --git a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/features/page.tsx b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/features/page.tsx index 593612d..56c02eb 100644 --- a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/features/page.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/features/page.tsx @@ -36,7 +36,16 @@ import { AlertDialogHeader, AlertDialogTitle, } from "@/components/ui/alert-dialog"; -import { Boxes, Loader2, Sparkles, RotateCcw, CheckCircle2, Circle, CircleDot, X } from "lucide-react"; +import { + Boxes, + Loader2, + Sparkles, + RotateCcw, + CheckCircle2, + Circle, + CircleDot, + X, +} from "lucide-react"; import { useNav } from "@/lib/contexts/NavContext"; import { useToast } from "@/hooks/use-toast"; import { buildFeatureUrl } from "@/lib/url"; @@ -63,8 +72,8 @@ export default function BrainstormingFeaturesPage() { // UI state const [showArchived, setShowArchived] = useState(false); - const [expandedModules, setExpandedModules] = useState>( - () => expandModuleId ? new Set([expandModuleId]) : new Set() + const [expandedModules, setExpandedModules] = useState>(() => + expandModuleId ? new Set([expandModuleId]) : new Set(), ); const [isGeneratingLocal, setIsGeneratingLocal] = useState(false); const [processedJobIds, setProcessedJobIds] = useState>(new Set()); @@ -73,7 +82,8 @@ export default function BrainstormingFeaturesPage() { const [phase, setPhase] = useState(null); // Generation status state - const [generationStatus, setGenerationStatus] = useState(null); + const [generationStatus, setGenerationStatus] = + useState(null); const [isFetchingStatus, setIsFetchingStatus] = useState(false); const [isCancellingGeneration, setIsCancellingGeneration] = useState(false); @@ -117,11 +127,13 @@ export default function BrainstormingFeaturesPage() { // Check if there's a running generation job const runningJob = wsJobs.find( - (job) => job.job_type === "module_feature_generate" && job.status === "running" + (job) => job.job_type === "module_feature_generate" && job.status === "running", ); // Get progress from running job - const progress = runningJob?.result?.progress as { workflow_step?: string; progress_percentage?: number } | undefined; + const progress = runningJob?.result?.progress as + | { workflow_step?: string; progress_percentage?: number } + | undefined; // Fetch data const fetchData = useCallback(async () => { @@ -135,7 +147,11 @@ export default function BrainstormingFeaturesPage() { apiClient.listModules(projectId, showArchived, "implementation"), // Only show implementation features (not conversation questions) // Use max limit (500) to fetch all features - apiClient.listFeatures(projectId, { include_archived: showArchived, feature_type: "implementation", limit: 500 }), + apiClient.listFeatures(projectId, { + include_archived: showArchived, + feature_type: "implementation", + limit: 500, + }), ]); setModules(modulesData); @@ -189,7 +205,7 @@ export default function BrainstormingFeaturesPage() { } } }, - [fetchData, fetchPhase] + [fetchData, fetchPhase], ); useWebSocketSubscription({ @@ -201,7 +217,7 @@ export default function BrainstormingFeaturesPage() { // Refresh when generation job completes useEffect(() => { const completedJob = wsJobs.find( - (job) => job.job_type === "module_feature_generate" && job.status === "succeeded" + (job) => job.job_type === "module_feature_generate" && job.status === "succeeded", ); if (completedJob && !processedJobIds.has(completedJob.id)) { @@ -251,18 +267,31 
@@ export default function BrainstormingFeaturesPage() {
   // Compute stats for active features (not archived)
   const stats = useMemo(() => {
     // Get all features for this phase's modules
-    const phaseModuleIds = new Set(phaseModules.map(m => m.id));
-    const phaseFeatures = mergedFeatures.filter(f => phaseModuleIds.has(f.module_id) && f.status === "active");
+    const phaseModuleIds = new Set(phaseModules.map((m) => m.id));
+    const phaseFeatures = mergedFeatures.filter(
+      (f) => phaseModuleIds.has(f.module_id) && f.status === "active",
+    );
     const total = phaseFeatures.length;
-    const completed = phaseFeatures.filter(f => f.completion_status === "completed").length;
-    const inProgress = phaseFeatures.filter(f => f.completion_status === "in_progress").length;
-    const pending = phaseFeatures.filter(f => f.completion_status === "pending" || !f.completion_status).length;
-    const mustHave = phaseFeatures.filter(f => f.priority === "must_have").length;
-    const important = phaseFeatures.filter(f => f.priority === "important").length;
-    const optional = phaseFeatures.filter(f => f.priority === "optional").length;
+    const completed = phaseFeatures.filter((f) => f.completion_status === "completed").length;
+    const inProgress = phaseFeatures.filter((f) => f.completion_status === "in_progress").length;
+    const pending = phaseFeatures.filter(
+      (f) => f.completion_status === "pending" || !f.completion_status,
+    ).length;
+    const mustHave = phaseFeatures.filter((f) => f.priority === "must_have").length;
+    const important = phaseFeatures.filter((f) => f.priority === "important").length;
+    const optional = phaseFeatures.filter((f) => f.priority === "optional").length;
     const completionPercent = total > 0 ? Math.round((completed / total) * 100) : 0;
-    return { total, completed, inProgress, pending, mustHave, important, optional, completionPercent };
+    return {
+      total,
+      completed,
+      inProgress,
+      pending,
+      mustHave,
+      important,
+      optional,
+      completionPercent,
+    };
   }, [mergedFeatures, phaseModules]);

   // Handlers
@@ -414,7 +443,7 @@ export default function BrainstormingFeaturesPage() {
   if (isLoading) {
     return (
- +
); } @@ -428,7 +457,7 @@ export default function BrainstormingFeaturesPage() {

Modules & Features

-

+

Organize implementation into modules and features

@@ -439,9 +468,11 @@ export default function BrainstormingFeaturesPage() { size="sm" onClick={() => setShowRestoreFeature(true)} disabled={!canArchive} - title={!canArchive ? "You need admin access or higher to restore features" : undefined} + title={ + !canArchive ? "You need admin access or higher to restore features" : undefined + } > - + Restore )} @@ -449,12 +480,14 @@ export default function BrainstormingFeaturesPage() { size="sm" onClick={handleGenerate} disabled={isFetchingStatus || showProgress || !canCreateModules} - title={!canCreateModules ? "You need member access or higher to generate modules" : undefined} + title={ + !canCreateModules ? "You need member access or higher to generate modules" : undefined + } > {isFetchingStatus ? ( - + ) : ( - + )} {phaseModules.length > 0 ? "Regenerate from Spec" : "Generate from Spec"} @@ -482,16 +515,16 @@ export default function BrainstormingFeaturesPage() { size="sm" onClick={handleCancelGeneration} disabled={isCancellingGeneration} - className="h-7 px-2 text-muted-foreground hover:text-red-600 dark:hover:text-red-400 hover:bg-red-100 dark:hover:bg-red-900/30" + className="text-muted-foreground h-7 px-2 hover:bg-red-100 hover:text-red-600 dark:hover:bg-red-900/30 dark:hover:text-red-400" > {isCancellingGeneration ? ( <> - + Cancelling... ) : ( <> - + Cancel )} @@ -506,19 +539,21 @@ export default function BrainstormingFeaturesPage() { {/* Empty state during generation */} {showProgress && phaseModules.length === 0 && ( - + -
+
- - + +
-
-

+
+

Building your modules and features...

-

- AI agents are analyzing your specification and prompt plan to create a structured breakdown of implementation modules and features. This typically takes 1-2 minutes. +

+ AI agents are analyzing your specification and prompt plan to create a structured + breakdown of implementation modules and features. This typically takes 1-2 + minutes.

@@ -532,10 +567,10 @@ export default function BrainstormingFeaturesPage() {
{/* Completion Progress */} -
-
+
+
Completion - + {stats.completed}/{stats.total} ({stats.completionPercent}%)
@@ -553,7 +588,7 @@ export default function BrainstormingFeaturesPage() { {stats.inProgress} in progress
- + {stats.pending} pending
@@ -561,19 +596,17 @@ export default function BrainstormingFeaturesPage() { {/* Priority Breakdown */}
{stats.mustHave > 0 && ( - + {stats.mustHave} P0 )} {stats.important > 0 && ( - + {stats.important} P1 )} {stats.optional > 0 && ( - - {stats.optional} P2 - + {stats.optional} P2 )}
@@ -589,10 +622,7 @@ export default function BrainstormingFeaturesPage() { checked={showArchived} onCheckedChange={(checked) => setShowArchived(checked === true)} /> -
@@ -600,7 +630,7 @@ export default function BrainstormingFeaturesPage() { {/* Error state */} {error && ( -
+
{error}
)} @@ -609,21 +639,26 @@ export default function BrainstormingFeaturesPage() { {phaseModules.length === 0 && !showProgress && ( - -

No Features Generated Yet

-

- Generate modules and features from your Final Specification and Prompt Plan with MFBT AI. + +

No Features Generated Yet

+

+ Generate modules and features from your Final Specification and Prompt Plan with MFBT + AI.

@@ -713,15 +748,13 @@ export default function BrainstormingFeaturesPage() { Archive Module - Are you sure you want to archive "{selectedModule?.title}"? - This will hide the module and all its features from the active view. + Are you sure you want to archive "{selectedModule?.title}"? This will hide the module + and all its features from the active view. Cancel - - Archive - + Archive @@ -732,15 +765,13 @@ export default function BrainstormingFeaturesPage() { Archive Feature - Are you sure you want to archive "{selectedFeature?.feature_key}: {selectedFeature?.title}"? - You can restore it later if needed. + Are you sure you want to archive "{selectedFeature?.feature_key}:{" "} + {selectedFeature?.title}"? You can restore it later if needed. Cancel - - Archive - + Archive diff --git a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/layout.tsx b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/layout.tsx index 0a6a078..8be8b8a 100644 --- a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/layout.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/layout.tsx @@ -38,18 +38,13 @@ export async function generateMetadata({ return { title: "Brainstorming - MFBT" }; } - const [project, phase] = await Promise.all([ - projectRes.json(), - phaseRes.json(), - ]); + const [project, phase] = await Promise.all([projectRes.json(), phaseRes.json()]); return { title: `${phase.title} | ${project.name} - MFBT` }; } catch { return { title: "Brainstorming - MFBT" }; } } -export default function BrainstormingPhaseLayout({ - children, -}: BrainstormingPhaseLayoutProps) { +export default function BrainstormingPhaseLayout({ children }: BrainstormingPhaseLayoutProps) { return {children}; } diff --git a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/prompt-plan/page.tsx b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/prompt-plan/page.tsx index b94ef3a..2e808a2 100644 --- a/frontend/app/projects/[projectId]/brainstorming/[phaseId]/prompt-plan/page.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/[phaseId]/prompt-plan/page.tsx @@ -10,10 +10,6 @@ export default function BrainstormingPromptPlanPage() { const phaseId = params.phaseId as string; return ( - + ); } diff --git a/frontend/app/projects/[projectId]/brainstorming/page.tsx b/frontend/app/projects/[projectId]/brainstorming/page.tsx index 4575ccd..95b4629 100644 --- a/frontend/app/projects/[projectId]/brainstorming/page.tsx +++ b/frontend/app/projects/[projectId]/brainstorming/page.tsx @@ -51,7 +51,7 @@ import { CreateChoiceDialog } from "@/components/CreateChoiceDialog"; function BrainstormingPageSkeleton() { return (
- +
  );
}
@@ -148,7 +148,7 @@ function BrainstormingPhasesPageContent() {
   // Sorted containers by order_index
   const sortedContainers = useMemo(
     () => [...containers].sort((a, b) => a.order_index - b.order_index),
-    [containers]
+    [containers],
   );

   const handleNewPhase = () => {
@@ -159,13 +159,11 @@ function BrainstormingPhasesPageContent() {
     e.stopPropagation();
     const containerMention = `#[${container.title}](${container.url_identifier})`;
     const message = encodeURIComponent(
-      `I'd like to add an extension to ${containerMention}. Here's what I want to explore:`
+      `I'd like to add an extension to ${containerMention}. Here's what I want to explore:`,
     );
-    const basePath = currentProject
-      ? buildProjectUrl(currentProject)
-      : `/projects/${projectId}`;
+    const basePath = currentProject ? buildProjectUrl(currentProject) : `/projects/${projectId}`;
     router.push(
-      `${basePath}/project-chat?new=true&initialMessage=${message}&targetContainerId=${container.id}`
+      `${basePath}/project-chat?new=true&initialMessage=${message}&targetContainerId=${container.id}`,
     );
   };
@@ -190,7 +188,7 @@ function BrainstormingPhasesPageContent() {
   const handleRenameSuccess = (updatedPhase: BrainstormingPhase) => {
     // Merge updated fields with existing phase data to preserve computed fields
     setPhases((prev) =>
-      prev.map((p) => (p.id === updatedPhase.id ? { ...p, ...updatedPhase } : p))
+      prev.map((p) => (p.id === updatedPhase.id ? { ...p, ...updatedPhase } : p)),
     );
   };
@@ -234,7 +232,7 @@ function BrainstormingPhasesPageContent() {
   if (isLoading) {
     return (
- +
); } @@ -260,12 +258,12 @@ function BrainstormingPhasesPageContent() {

Brainstorming

-

+

Collaborate on ideas and generate specifications

@@ -274,14 +272,14 @@ function BrainstormingPhasesPageContent() { {!hasContent ? ( - -

No Brainstorming Phases

-

- Start by creating a brainstorming phase to collaborate on ideas, - generate specifications, and plan features. + +

No Brainstorming Phases

+

+ Start by creating a brainstorming phase to collaborate on ideas, generate + specifications, and plan features.

@@ -355,14 +353,20 @@ function BrainstormingPhasesPageContent() { )} {/* Archive Phase Confirmation Dialog */} - !open && setPhaseToArchive(null)}> + !open && setPhaseToArchive(null)} + > Archive Phase? Are you sure you want to archive the phase{" "} - "{phaseToArchive?.title}"? - + + "{phaseToArchive?.title}" + + ? + Archived phases can be restored later. @@ -376,7 +380,7 @@ function BrainstormingPhasesPageContent() { > {isArchiving ? ( <> - + Archiving... ) : ( @@ -388,15 +392,22 @@ function BrainstormingPhasesPageContent() { {/* Archive Container Confirmation Dialog */} - !open && setContainerToArchive(null)}> + !open && setContainerToArchive(null)} + > Archive Container? Are you sure you want to archive the container{" "} - "{containerToArchive?.title}"? - - All phases inside this container will also be archived. Archived containers can be restored later. + + "{containerToArchive?.title}" + + ? + + All phases inside this container will also be archived. Archived containers can be + restored later. @@ -409,7 +420,7 @@ function BrainstormingPhasesPageContent() { > {isArchiving ? ( <> - + Archiving... ) : ( @@ -435,7 +446,6 @@ function BrainstormingPhasesPageContent() { projectId={projectId} currentProject={currentProject} /> -
); } @@ -468,15 +478,15 @@ function ContainerGroup({ return ( <> {/* Container header row */} - +
- + {container.title} - + {phases.length} {phases.length === 1 ? "phase" : "phases"}
@@ -492,14 +502,14 @@ function ContainerGroup({ onAddExtension(container, e)}> - + Add Extension onArchiveContainer(container, e)} className="text-destructive focus:text-destructive focus:bg-destructive/10 dark:focus:bg-destructive/20" > - + Delete Container @@ -547,13 +557,10 @@ function PhaseRow({ const label = indented ? getPhaseLabel(phase) : phase.title; return ( - +
-
+
{label} {onRename && ( )}
{phase.description && ( -
+
{phase.description}
)} {phase.archived_at && ( - + Archived )} @@ -601,9 +611,9 @@ function PhaseRow({
{phase.created_by_name && ( - {phase.created_by_name} + {phase.created_by_name} )} - + {formatDistanceToNow(new Date(phase.created_at), { addSuffix: true })}
@@ -619,26 +629,28 @@ function PhaseRow({ {container && onAddExtension && ( onAddExtension(container, e)}> - + Add Extension )} {onRename && ( - + Rename Phase )} container && onArchiveContainer ? onArchiveContainer(container, e) : onArchive(e)} + onClick={(e) => + container && onArchiveContainer ? onArchiveContainer(container, e) : onArchive(e) + } className="text-destructive focus:text-destructive focus:bg-destructive/10 dark:focus:bg-destructive/20" > - + Archive - +
diff --git a/frontend/app/projects/[projectId]/features/[featureId]/FeatureDetailClient.tsx b/frontend/app/projects/[projectId]/features/[featureId]/FeatureDetailClient.tsx index ace48ab..9b3d5f2 100644 --- a/frontend/app/projects/[projectId]/features/[featureId]/FeatureDetailClient.tsx +++ b/frontend/app/projects/[projectId]/features/[featureId]/FeatureDetailClient.tsx @@ -82,33 +82,35 @@ import { buildProjectUrl, buildPhaseUrl } from "@/lib/url"; const PROVENANCE_BADGES: Record = { system: { label: "System", - className: "bg-purple-100 text-purple-700 border-purple-200 dark:bg-purple-900/30 dark:text-purple-400 dark:border-purple-800", + className: + "bg-purple-100 text-purple-700 border-purple-200 dark:bg-purple-900/30 dark:text-purple-400 dark:border-purple-800", }, user: { label: "User", - className: "bg-slate-100 text-slate-700 border-slate-200 dark:bg-slate-900/30 dark:text-slate-400 dark:border-slate-800", + className: + "bg-slate-100 text-slate-700 border-slate-200 dark:bg-slate-900/30 dark:text-slate-400 dark:border-slate-800", }, restored: { label: "Restored", - className: "bg-blue-100 text-blue-700 border-blue-200 dark:bg-blue-900/30 dark:text-blue-400 dark:border-blue-800", + className: + "bg-blue-100 text-blue-700 border-blue-200 dark:bg-blue-900/30 dark:text-blue-400 dark:border-blue-800", }, }; const STATUS_BADGES: Record = { active: { label: "Active", - className: "bg-green-100 text-green-700 border-green-200 dark:bg-green-900/30 dark:text-green-400 dark:border-green-800", + className: + "bg-green-100 text-green-700 border-green-200 dark:bg-green-900/30 dark:text-green-400 dark:border-green-800", }, archived: { label: "Archived", - className: "bg-slate-100 text-slate-700 border-slate-200 dark:bg-slate-900/30 dark:text-slate-400 dark:border-slate-800", + className: + "bg-slate-100 text-slate-700 border-slate-200 dark:bg-slate-900/30 dark:text-slate-400 dark:border-slate-800", }, }; -const PRIORITY_CONFIG: Record< - FeaturePriority, - { label: string; className: string } -> = { +const PRIORITY_CONFIG: Record = { must_have: { label: "Must Have", className: "bg-red-100 text-red-800 dark:bg-red-900/30 dark:text-red-400", @@ -209,7 +211,9 @@ export default function FeatureDetailClient() { const [hasConversation, setHasConversation] = useState(false); const [threadId, setThreadId] = useState(null); const [latestThreadItemCreatedAt, setLatestThreadItemCreatedAt] = useState(null); - const [suggestedImplementationName, setSuggestedImplementationName] = useState(null); + const [suggestedImplementationName, setSuggestedImplementationName] = useState( + null, + ); const [showCreateImplementationButton, setShowCreateImplementationButton] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [isCancellingGeneration, setIsCancellingGeneration] = useState(false); @@ -225,7 +229,11 @@ export default function FeatureDetailClient() { // WebSocket for implementation created/deleted/updated events // Use feature?.id (UUID), not featureId (short URL identifier) - const { updates: implementationUpdates, deletions: implementationDeletions, contentUpdates: implementationContentUpdates } = useImplementationWebSocket({ + const { + updates: implementationUpdates, + deletions: implementationDeletions, + contentUpdates: implementationContentUpdates, + } = useImplementationWebSocket({ orgId: user?.org_id || "", projectId: currentProject?.id, featureId: feature?.id, @@ -235,7 +243,9 @@ export default function FeatureDetailClient() { // Unresolved points warning state const 
[sidebarData, setSidebarData] = useState(null); const [showUnresolvedWarning, setShowUnresolvedWarning] = useState(false); - const [pendingGenerationType, setPendingGenerationType] = useState<"spec" | "prompt_plan" | null>(null); + const [pendingGenerationType, setPendingGenerationType] = useState<"spec" | "prompt_plan" | null>( + null, + ); // Implementation state const [implementations, setImplementations] = useState([]); @@ -261,7 +271,8 @@ export default function FeatureDetailClient() { setImplementations(implementationsData); // Auto-select primary implementation if none selected if (implementationsData.length > 0 && !selectedImplementationId) { - const primary = implementationsData.find(impl => impl.is_primary) || implementationsData[0]; + const primary = + implementationsData.find((impl) => impl.is_primary) || implementationsData[0]; setSelectedImplementationId(primary.id); } @@ -441,7 +452,7 @@ export default function FeatureDetailClient() { decision_summary: latest.decision_summary, unresolved_points: latest.unresolved_points, } - : prev + : prev, ); } }, [decisionUpdates, threadId]); @@ -452,7 +463,7 @@ export default function FeatureDetailClient() { (job) => job.job_type === "feature_content_generate" && job.status === "running" && - job.result?.feature_id === feature?.id + job.result?.feature_id === feature?.id, ); // Find running decision summarizer job for this thread @@ -460,7 +471,7 @@ export default function FeatureDetailClient() { (job) => job.job_type === "collab_thread_decision_summarize" && job.status === "running" && - job.result?.thread_id === threadId + job.result?.thread_id === threadId, ); // Clear local isGenerating when WebSocket picks up the job @@ -480,7 +491,7 @@ export default function FeatureDetailClient() { job.job_type === "feature_content_generate" && job.status === "succeeded" && job.result?.feature_id === feature.id && - !processedJobIds.current.has(job.id) + !processedJobIds.current.has(job.id), ); if (completedJob) { @@ -504,9 +515,7 @@ export default function FeatureDetailClient() { apiClient.initializeAuth(); await apiClient.archiveFeature(featureId); setShowArchiveConfirm(false); - const basePath = currentProject - ? buildProjectUrl(currentProject) - : `/projects/${projectId}`; + const basePath = currentProject ? buildProjectUrl(currentProject) : `/projects/${projectId}`; router.push(`${basePath}/features`); } catch (err) { setError(err instanceof Error ? 
err.message : "Failed to archive feature"); @@ -532,56 +541,71 @@ export default function FeatureDetailClient() { }, []); // Handle thread loaded - track conversation state and button visibility - const handleThreadLoaded = useCallback((hasItems: boolean, loadedThreadId: string | null, latestItemCreatedAt: string | null, suggestedImplName: string | null, showButton: boolean) => { - setHasConversation(hasItems); - setThreadId(loadedThreadId); - setLatestThreadItemCreatedAt(latestItemCreatedAt); - setSuggestedImplementationName(suggestedImplName); - setShowCreateImplementationButton(showButton); - }, []); + const handleThreadLoaded = useCallback( + ( + hasItems: boolean, + loadedThreadId: string | null, + latestItemCreatedAt: string | null, + suggestedImplName: string | null, + showButton: boolean, + ) => { + setHasConversation(hasItems); + setThreadId(loadedThreadId); + setLatestThreadItemCreatedAt(latestItemCreatedAt); + setSuggestedImplementationName(suggestedImplName); + setShowCreateImplementationButton(showButton); + }, + [], + ); // Actually perform the content generation - const performGeneration = useCallback(async (contentType: "spec" | "prompt_plan") => { - if (!feature) return; + const performGeneration = useCallback( + async (contentType: "spec" | "prompt_plan") => { + if (!feature) return; - // Require an implementation to be selected - if (!selectedImplementationId) { - setError("Please select an implementation first"); - return; - } + // Require an implementation to be selected + if (!selectedImplementationId) { + setError("Please select an implementation first"); + return; + } - setIsGenerating(true); - setError(null); - try { - apiClient.initializeAuth(); - await apiClient.generateFeatureContent(feature.id, contentType, selectedImplementationId); - // Job started successfully - WebSocket will track progress and refresh on completion - // Keep isGenerating true until WebSocket picks up the running job - } catch (err) { - console.error("Failed to generate content:", err); - setError(err instanceof Error ? err.message : "Failed to start generation"); - setIsGenerating(false); - } - }, [feature, selectedImplementationId]); + setIsGenerating(true); + setError(null); + try { + apiClient.initializeAuth(); + await apiClient.generateFeatureContent(feature.id, contentType, selectedImplementationId); + // Job started successfully - WebSocket will track progress and refresh on completion + // Keep isGenerating true until WebSocket picks up the running job + } catch (err) { + console.error("Failed to generate content:", err); + setError(err instanceof Error ? 
err.message : "Failed to start generation"); + setIsGenerating(false); + } + }, + [feature, selectedImplementationId], + ); // Handle generate content - check for unresolved points first - const handleGenerateContent = useCallback(async (contentType: "spec" | "prompt_plan") => { - if (!hasConversation || !feature) { - return; - } + const handleGenerateContent = useCallback( + async (contentType: "spec" | "prompt_plan") => { + if (!hasConversation || !feature) { + return; + } - // Check if there are unresolved points - const unresolvedCount = sidebarData?.unresolved_points?.length || 0; - if (unresolvedCount > 0) { - // Show warning dialog and store the pending generation type - setPendingGenerationType(contentType); - setShowUnresolvedWarning(true); - return; - } + // Check if there are unresolved points + const unresolvedCount = sidebarData?.unresolved_points?.length || 0; + if (unresolvedCount > 0) { + // Show warning dialog and store the pending generation type + setPendingGenerationType(contentType); + setShowUnresolvedWarning(true); + return; + } - // No unresolved points, proceed directly - await performGeneration(contentType); - }, [hasConversation, feature, sidebarData, performGeneration]); + // No unresolved points, proceed directly + await performGeneration(contentType); + }, + [hasConversation, feature, sidebarData, performGeneration], + ); // Handle confirmation to proceed despite unresolved points const handleProceedWithGeneration = useCallback(async () => { @@ -629,7 +653,7 @@ export default function FeatureDetailClient() { if (isLoading) { return (
- +
); } @@ -645,12 +669,12 @@ export default function FeatureDetailClient() {
- + Back to Features -
+
{error}
@@ -666,47 +690,42 @@ export default function FeatureDetailClient() { {/* Back link */} - + Back to Features {/* Archived banner */} {isArchived && ( -
- +
+ Archived Feature — read-only
)} {/* Header */}
-
+

- + : {feature.title}

-
+
{module && } {module?.brainstorming_phase_id ? ( {module.title} ) : ( - - {module?.title || "Unknown"} - + {module?.title || "Unknown"} )} {feature.category && ( @@ -715,21 +734,19 @@ export default function FeatureDetailClient() { {module?.brainstorming_phase_id ? ( {feature.category} ) : ( - - {feature.category} - + {feature.category} )} )}
-
+
{/* Left: Badges */} -
+
{/* Completion Status */} {(() => { const completionConfig = COMPLETION_CONFIG[feature.completion_status || "pending"]; @@ -737,7 +754,7 @@ export default function FeatureDetailClient() { return ( {completionConfig.label} @@ -754,13 +771,19 @@ export default function FeatureDetailClient() { {/* Provenance */} - + {PROVENANCE_BADGES[feature.provenance].label} {/* Status (only show if archived) */} {isArchived && ( - + {STATUS_BADGES[feature.status].label} )} @@ -770,9 +793,21 @@ export default function FeatureDetailClient() { {/* When an implementation is selected, use its content status instead of feature-level */} @@ -780,21 +815,21 @@ export default function FeatureDetailClient() { {/* Import Source Info */} {feature.external_provider && ( -
-
+
+
Imported from {feature.external_provider === "github" ? "GitHub" : "Jira"}
-
+
Issue:{" "} {feature.external_id} @@ -803,13 +838,17 @@ export default function FeatureDetailClient() { {feature.external_author && (
Author:{" "} - {feature.external_author} + + {feature.external_author} +
)} {feature.external_status && (
Status:{" "} - {feature.external_status} + + {feature.external_status} +
)} {feature.external_imported_at && ( @@ -824,7 +863,11 @@ export default function FeatureDetailClient() { {feature.external_labels && feature.external_labels.length > 0 && (
{feature.external_labels.map((label) => ( - + {label} ))} @@ -845,18 +888,18 @@ export default function FeatureDetailClient() { {!isArchived && ( <> setShowClearStatusNotesDialog(true)}> - + Clear Status & Notes setShowArchiveConfirm(true)}> - + Archive )} {isArchived && ( setShowRestoreDialog(true)}> - + Restore )} @@ -866,7 +909,7 @@ export default function FeatureDetailClient() { {/* Error message */} {error && ( -
+
{error}
)} @@ -910,7 +953,7 @@ export default function FeatureDetailClient() { )} {/* Two-column layout: Tab content on left, Completion Summary on right */} -
+
{/* Left column: Tab content */}
@@ -943,7 +986,7 @@ export default function FeatureDetailClient() { // If summarizer is running, show status instead of button if (isSummarizerRunning) { return ( -
+
); @@ -953,7 +996,7 @@ export default function FeatureDetailClient() { if (!showCreateImplementationButton) return undefined; return ( -
+
0} @@ -970,7 +1013,7 @@ export default function FeatureDetailClient() { handleGenerateContent("spec")} onCancelClick={handleCancelGeneration} @@ -990,7 +1037,7 @@ export default function FeatureDetailClient() { handleGenerateContent("prompt_plan")} onCancelClick={handleCancelGeneration} @@ -1010,7 +1063,7 @@ export default function FeatureDetailClient() { {/* Right column: Sidebar (sticky) */} -
+
{/* Implementation Selector Card */} -
+
+
- - Completion Summary - + Completion Summary
-

+

{selectedImplementation.completion_summary}

{selectedImplementation.completed_at && ( -

+

{new Date(selectedImplementation.completed_at).toLocaleString(undefined, { dateStyle: "medium", @@ -1088,8 +1139,8 @@ export default function FeatureDetailClient() { Archive Feature - Are you sure you want to archive "{feature.feature_key}: {feature.title}"? - You can restore it later if needed. + Are you sure you want to archive "{feature.feature_key}: {feature.title}"? You can + restore it later if needed. @@ -1139,21 +1190,24 @@ export default function FeatureDetailClient() {

This feature has {sidebarData?.unresolved_points?.length || 0} unresolved{" "} - {(sidebarData?.unresolved_points?.length || 0) === 1 ? "point" : "points"} in the conversation. + {(sidebarData?.unresolved_points?.length || 0) === 1 ? "point" : "points"} in the + conversation.

- Generating the {pendingGenerationType === "prompt_plan" ? "prompt plan" : "specification"} now may result - in incomplete or less accurate content. Consider resolving these points first for better results. + Generating the{" "} + {pendingGenerationType === "prompt_plan" ? "prompt plan" : "specification"} now + may result in incomplete or less accurate content. Consider resolving these points + first for better results.

-
-
    +
    +
      {sidebarData?.unresolved_points?.slice(0, 3).map((point, index) => (
    • • {point.question}
    • ))} {(sidebarData?.unresolved_points?.length || 0) > 3 && ( -
    • +
    • ...and {(sidebarData?.unresolved_points?.length || 0) - 3} more
    • )} diff --git a/frontend/app/projects/[projectId]/features/[featureId]/page.tsx b/frontend/app/projects/[projectId]/features/[featureId]/page.tsx index 39c30d7..eed6b73 100644 --- a/frontend/app/projects/[projectId]/features/[featureId]/page.tsx +++ b/frontend/app/projects/[projectId]/features/[featureId]/page.tsx @@ -6,9 +6,7 @@ interface FeatureDetailPageProps { params: Promise<{ projectId: string; featureId: string }>; } -export async function generateMetadata({ - params, -}: FeatureDetailPageProps): Promise { +export async function generateMetadata({ params }: FeatureDetailPageProps): Promise { const { projectId, featureId } = await params; try { @@ -37,10 +35,7 @@ export async function generateMetadata({ return { title: "Feature - MFBT" }; } - const [project, feature] = await Promise.all([ - projectRes.json(), - featureRes.json(), - ]); + const [project, feature] = await Promise.all([projectRes.json(), featureRes.json()]); return { title: `${feature.feature_key}: ${feature.title} | ${project.name} - MFBT`, }; diff --git a/frontend/app/projects/[projectId]/features/page.tsx b/frontend/app/projects/[projectId]/features/page.tsx index ad7e8f4..2c7cc1d 100644 --- a/frontend/app/projects/[projectId]/features/page.tsx +++ b/frontend/app/projects/[projectId]/features/page.tsx @@ -4,15 +4,15 @@ import { Suspense, useEffect, useState, useMemo, useCallback } from "react"; import { useParams, useRouter } from "next/navigation"; import { apiClient } from "@/lib/api/client"; import { useAuth } from "@/lib/auth/AuthContext"; +import { useFeatureWebSocket, useMergedFeatures } from "@/lib/hooks/useFeatureWebSocket"; +import { useFeaturePreferences, FeatureFilters } from "@/lib/hooks/useFeaturePreferences"; import { - useFeatureWebSocket, - useMergedFeatures, -} from "@/lib/hooks/useFeatureWebSocket"; -import { - useFeaturePreferences, - FeatureFilters, -} from "@/lib/hooks/useFeaturePreferences"; -import { Module, Feature, FeatureSortField, SortOrder, ModuleArchiveResponse } from "@/lib/api/types"; + Module, + Feature, + FeatureSortField, + SortOrder, + ModuleArchiveResponse, +} from "@/lib/api/types"; import { FeatureTable, ImportIssueModal, @@ -63,7 +63,7 @@ function FeaturesPageSkeleton() {
      {/* Header skeleton */}
      - +
      @@ -76,17 +76,14 @@ function FeaturesPageSkeleton() { {/* Table skeleton */} -
      +
      {[1, 2, 3, 4, 5].map((i) => ( -
      +
      @@ -138,9 +135,7 @@ function ProjectFeaturesPageContent() { const [paginationTotal, setPaginationTotal] = useState(0); // Track expanded modules for grouped view - const [expandedModules, setExpandedModules] = useState>( - new Set(), - ); + const [expandedModules, setExpandedModules] = useState>(new Set()); // Modal state const [isImportModalOpen, setIsImportModalOpen] = useState(false); @@ -188,21 +183,15 @@ function ProjectFeaturesPageContent() { include_archived: preferences.showArchived, feature_type: "implementation", module_id: - preferences.selectedModuleId !== "all" - ? preferences.selectedModuleId - : undefined, + preferences.selectedModuleId !== "all" ? preferences.selectedModuleId : undefined, priority: - preferences.filters.priority.length > 0 - ? preferences.filters.priority - : undefined, + preferences.filters.priority.length > 0 ? preferences.filters.priority : undefined, completion_status: preferences.filters.completion_status.length > 0 ? preferences.filters.completion_status : undefined, provenance: - preferences.filters.provenance.length > 0 - ? preferences.filters.provenance - : undefined, + preferences.filters.provenance.length > 0 ? preferences.filters.provenance : undefined, has_spec: preferences.filters.has_spec, has_notes: preferences.filters.has_notes, external_provider: preferences.filters.external_provider, @@ -282,26 +271,15 @@ function ProjectFeaturesPageContent() { const stats = useMemo(() => { const activeFeatures = mergedFeatures.filter((f) => f.status === "active"); const total = activeFeatures.length; - const completed = activeFeatures.filter( - (f) => f.completion_status === "completed", - ).length; - const inProgress = activeFeatures.filter( - (f) => f.completion_status === "in_progress", - ).length; + const completed = activeFeatures.filter((f) => f.completion_status === "completed").length; + const inProgress = activeFeatures.filter((f) => f.completion_status === "in_progress").length; const pending = activeFeatures.filter( (f) => f.completion_status === "pending" || !f.completion_status, ).length; - const mustHave = activeFeatures.filter( - (f) => f.priority === "must_have", - ).length; - const important = activeFeatures.filter( - (f) => f.priority === "important", - ).length; - const optional = activeFeatures.filter( - (f) => f.priority === "optional", - ).length; - const completionPercent = - total > 0 ? Math.round((completed / total) * 100) : 0; + const mustHave = activeFeatures.filter((f) => f.priority === "must_have").length; + const important = activeFeatures.filter((f) => f.priority === "important").length; + const optional = activeFeatures.filter((f) => f.priority === "optional").length; + const completionPercent = total > 0 ? 
Math.round((completed / total) * 100) : 0;

     return {
       total,
@@ -349,9 +327,12 @@ function ProjectFeaturesPageContent() {
   };

   // Get active feature count for a module
-  const getActiveFeatureCount = useCallback((moduleId: string): number => {
-    return (groupedFeatures.get(moduleId) || []).filter(f => f.status === "active").length;
-  }, [groupedFeatures]);
+  const getActiveFeatureCount = useCallback(
+    (moduleId: string): number => {
+      return (groupedFeatures.get(moduleId) || []).filter((f) => f.status === "active").length;
+    },
+    [groupedFeatures],
+  );

   // Handle archive module - opens confirmation dialog
   const handleArchiveModule = (module: Module) => {
@@ -372,7 +353,7 @@ function ProjectFeaturesPageContent() {
       if (result.archived_features_count > 0) {
         toast({
           title: "Module archived",
-          description: `Archived "${result.title}" and ${result.archived_features_count} feature${result.archived_features_count === 1 ? '' : 's'}`,
+          description: `Archived "${result.title}" and ${result.archived_features_count} feature${result.archived_features_count === 1 ? "" : "s"}`,
         });
       } else {
         toast({
@@ -463,20 +444,18 @@ function ProjectFeaturesPageContent() {
       {/* Header */}
      -

      - All Features -

      -

      +

      All Features

      +

      User-defined Features and Features across all Brainstorming Phases

      @@ -487,7 +466,7 @@ function ProjectFeaturesPageContent() { - + Clear Status & Notes @@ -501,10 +480,10 @@ function ProjectFeaturesPageContent() {
      {/* Completion Progress */} -
      -
      +
      +
      Completion - + {stats.completed}/{stats.total} ({stats.completionPercent}%)
      @@ -519,12 +498,10 @@ function ProjectFeaturesPageContent() {
      - - {stats.inProgress} in progress - + {stats.inProgress} in progress
      - + {stats.pending} pending
      @@ -532,19 +509,17 @@ function ProjectFeaturesPageContent() { {/* Priority Breakdown */}
      {stats.mustHave > 0 && ( - + {stats.mustHave} P0 )} {stats.important > 0 && ( - + {stats.important} P1 )} {stats.optional > 0 && ( - - {stats.optional} P2 - + {stats.optional} P2 )}
      @@ -570,17 +545,17 @@ function ProjectFeaturesPageContent() { {/* Error state */} {error && ( -
      +
      {error}
      )} {/* Empty state */} {mergedFeatures.length === 0 && !error && ( -
      - -

      No Features Found

      -

      +

      + +

      No Features Found

      +

      {preferences.selectedModuleId !== "all" ? 'No features in the selected module. Try selecting a different module or "All Modules".' : "Add a new feature here or create a Brainstorming Phase so mfbt can generate them for you."} @@ -671,11 +646,17 @@ function ProjectFeaturesPageContent() { <> Are you sure you want to archive "{moduleToArchive.title}" {getActiveFeatureCount(moduleToArchive.id) > 0 && ( - <> and its {getActiveFeatureCount(moduleToArchive.id)} active feature{getActiveFeatureCount(moduleToArchive.id) === 1 ? '' : 's'} - )}? -

      - This will hide the module and all its features from the active view. - You can restore them later if needed. + <> + {" "} + and its {getActiveFeatureCount(moduleToArchive.id)} active + feature{getActiveFeatureCount(moduleToArchive.id) === 1 ? "" : "s"} + + )} + ? +
      +
+                This will hide the module and all its features from the active view. You can
+                restore them later if needed.
              )}
diff --git a/frontend/app/projects/[projectId]/jobs/page.tsx b/frontend/app/projects/[projectId]/jobs/page.tsx
index aac5610..3f875f4 100644
--- a/frontend/app/projects/[projectId]/jobs/page.tsx
+++ b/frontend/app/projects/[projectId]/jobs/page.tsx
@@ -28,14 +28,12 @@ export default function JobsPage() {
       apiClient.initializeAuth();

       // Fetch project to get org_id
-      const projectData = await apiClient.get(
-        `/api/v1/projects/${projectId}`
-      );
+      const projectData = await apiClient.get(`/api/v1/projects/${projectId}`);
       setOrgId(projectData.org_id);

       // Fetch jobs for this project
       const jobsData = await apiClient.get(
-        `/api/v1/orgs/${projectData.org_id}/jobs?project_id=${projectId}`
+        `/api/v1/orgs/${projectData.org_id}/jobs?project_id=${projectId}`,
       );
       setJobs(jobsData);
     } catch (err) {
@@ -52,7 +50,11 @@

   // WebSocket connection for real-time job updates
   // Use currentProject?.id (UUID) for filtering, not projectId (short URL identifier)
-  const wsJobs = useJobWebSocket({ orgId: orgId || "", projectId: currentProject?.id, enabled: !!orgId && !!currentProject?.id });
+  const wsJobs = useJobWebSocket({
+    orgId: orgId || "",
+    projectId: currentProject?.id,
+    enabled: !!orgId && !!currentProject?.id,
+  });

   // Merge API jobs with WebSocket jobs
   const mergedJobs = (() => {
@@ -70,7 +72,7 @@

     // Convert back to array and sort by created_at descending
     return Array.from(jobMap.values()).sort(
-      (a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
+      (a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime(),
     );
   })();

@@ -79,12 +81,10 @@
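The mergedJobs block above folds REST-fetched jobs and live WebSocket job updates into one list through a Map keyed by job id, sorted newest-first. A standalone sketch of that merge, assuming a minimal Job shape (the real type comes from lib/api/types) and assuming WebSocket entries should overwrite API copies when ids collide:

    interface Job {
      id: string;
      status: string;
      created_at: string; // ISO timestamp
    }

    function mergeJobs(apiJobs: Job[], wsJobs: Job[]): Job[] {
      const jobMap = new Map<string, Job>();
      for (const job of apiJobs) jobMap.set(job.id, job);
      // Later writes win, so fresher WebSocket copies replace stale API ones.
      for (const job of wsJobs) jobMap.set(job.id, job);
      return Array.from(jobMap.values()).sort(
        (a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime(),
      );
    }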

      Jobs

      -

      - Background job status and history. -

      +

      Background job status and history.

      -
      - +
      +
      ); @@ -95,11 +95,9 @@ export default function JobsPage() {

      Jobs

      -

      - Background job status and history. -

      +

      Background job status and history.

      -
      {error}
      +
      {error}
      ); } @@ -114,14 +112,14 @@ export default function JobsPage() {
      {mergedJobs.length === 0 ? ( -
      -
      - +
      +
      +
      -

      No jobs yet

      -

      - Jobs will appear here when you run background tasks like generating - discovery questions or specifications. +

      No jobs yet

      +

      + Jobs will appear here when you run background tasks like generating discovery questions + or specifications.

      ) : ( diff --git a/frontend/app/projects/[projectId]/layout.tsx b/frontend/app/projects/[projectId]/layout.tsx index 2a0e838..6e4c1be 100644 --- a/frontend/app/projects/[projectId]/layout.tsx +++ b/frontend/app/projects/[projectId]/layout.tsx @@ -7,9 +7,7 @@ interface ProjectLayoutProps { children: React.ReactNode; } -export async function generateMetadata({ - params, -}: ProjectLayoutProps): Promise { +export async function generateMetadata({ params }: ProjectLayoutProps): Promise { const { projectId } = await params; try { @@ -39,13 +37,8 @@ export async function generateMetadata({ } } -export default async function ProjectLayout({ - params, - children, -}: ProjectLayoutProps) { +export default async function ProjectLayout({ params, children }: ProjectLayoutProps) { const { projectId } = await params; - return ( - {children} - ); + return {children}; } diff --git a/frontend/app/projects/[projectId]/project-chat/[discussionId]/page.tsx b/frontend/app/projects/[projectId]/project-chat/[discussionId]/page.tsx index 0fc3fea..32d8cb6 100644 --- a/frontend/app/projects/[projectId]/project-chat/[discussionId]/page.tsx +++ b/frontend/app/projects/[projectId]/project-chat/[discussionId]/page.tsx @@ -82,9 +82,7 @@ export default function ProjectChatPage() { const [isCreatingFeature, setIsCreatingFeature] = useState(false); const [isCancelling, setIsCancelling] = useState(false); const [isUpdatingVisibility, setIsUpdatingVisibility] = useState(false); - const [createdFeatures, setCreatedFeatures] = useState( - [], - ); + const [createdFeatures, setCreatedFeatures] = useState([]); const [pendingImages, setPendingImages] = useState([]); // Pre-fill state for quick action buttons @@ -108,9 +106,7 @@ export default function ProjectChatPage() { } | null>(null); // Sidebar state - const [discussions, setProjectChats] = useState( - [], - ); + const [discussions, setProjectChats] = useState([]); const [discussionsTotal, setProjectChatsTotal] = useState(0); const [discussionsHasMore, setProjectChatsHasMore] = useState(false); const [discussionsOffset, setProjectChatsOffset] = useState(0); @@ -190,10 +186,7 @@ export default function ProjectChatPage() { const loadCreatedFeatures = useCallback(async () => { if (!projectId || !projectChatId) return; try { - const features = await apiClient.getCreatedFeatures( - projectId, - projectChatId, - ); + const features = await apiClient.getCreatedFeatures(projectId, projectChatId); setCreatedFeatures(features); } catch (err) { console.error("Failed to load created features:", err); @@ -211,11 +204,7 @@ export default function ProjectChatPage() { } try { const offset = reset ? 0 : discussionsOffset; - const result = await apiClient.listProjectChats( - projectId, - 20, - offset, - ); + const result = await apiClient.listProjectChats(projectId, 20, offset); if (reset) { setProjectChats(result.project_chats); setProjectChatsOffset(20); @@ -293,18 +282,10 @@ export default function ProjectChatPage() { ); } else if (update.event_type === "project_chat_message_updated" && update.message) { // Update the message in place (e.g., reactions changed) - setMessages((prev) => - prev.map((m) => - m.id === update.message!.id ? update.message! : m, - ), - ); - } else if ( - update.event_type === "project_chat_updated" && - update.project_chat - ) { + setMessages((prev) => prev.map((m) => (m.id === update.message!.id ? update.message! 
: m))); + } else if (update.event_type === "project_chat_updated" && update.project_chat) { const prevFeatureCount = discussion?.created_feature_ids?.length ?? 0; - const newFeatureCount = - update.project_chat.created_feature_ids?.length ?? 0; + const newFeatureCount = update.project_chat.created_feature_ids?.length ?? 0; setProjectChat(update.project_chat); // Sync AI processing state with project_chat is_generating flag @@ -339,12 +320,13 @@ export default function ProjectChatPage() { // Update sidebar message count setProjectChats((prev) => prev.map((d) => - d.id === discussion?.id - ? { ...d, message_count: Math.max(0, d.message_count - 1) } - : d, + d.id === discussion?.id ? { ...d, message_count: Math.max(0, d.message_count - 1) } : d, ), ); - } else if (update.event_type === "project_chat_messages_bulk_deleted" && update.deleted_message_ids) { + } else if ( + update.event_type === "project_chat_messages_bulk_deleted" && + update.deleted_message_ids + ) { // Remove bulk deleted messages const deletedIds = new Set(update.deleted_message_ids); setMessages((prev) => prev.filter((m) => !deletedIds.has(m.id))); @@ -358,7 +340,13 @@ export default function ProjectChatPage() { ); } } - }, [wsUpdates, discussion?.id, discussion?.created_feature_ids?.length, loadCreatedFeatures, user?.id]); + }, [ + wsUpdates, + discussion?.id, + discussion?.created_feature_ids?.length, + loadCreatedFeatures, + user?.id, + ]); // Load project and discussion useEffect(() => { @@ -369,16 +357,11 @@ export default function ProjectChatPage() { apiClient.initializeAuth(); // Fetch project - const projectData = await apiClient.get( - `/api/v1/projects/${projectId}`, - ); + const projectData = await apiClient.get(`/api/v1/projects/${projectId}`); setProject(projectData); // Fetch discussion by ID - const discussionData = await apiClient.getProjectChatById( - projectId, - projectChatId, - ); + const discussionData = await apiClient.getProjectChatById(projectId, projectChatId); if (discussionData) { setProjectChat(discussionData.project_chat); setMessages(discussionData.messages); @@ -425,7 +408,7 @@ export default function ProjectChatPage() { if (highlightedItemId) return; // Skip auto-scroll when deep linking if (scrollContainerRef.current) { // Small delay when proposal panels appear to allow DOM to render - const delay = (readyToCreatePhase || readyToCreateFeature) ? 100 : 0; + const delay = readyToCreatePhase || readyToCreateFeature ? 100 : 0; setTimeout(() => { scrollContainerRef.current?.scrollTo({ top: scrollContainerRef.current?.scrollHeight ?? 
0, @@ -433,7 +416,15 @@ export default function ProjectChatPage() { }); }, delay); } - }, [messages.length, isAIProcessing, isExploringCode, isSearchingWeb, readyToCreatePhase, readyToCreateFeature, highlightedItemId]); + }, [ + messages.length, + isAIProcessing, + isExploringCode, + isSearchingWeb, + readyToCreatePhase, + readyToCreateFeature, + highlightedItemId, + ]); // Scroll to bottom on initial load (after loading completes) // But NOT when there's a highlighted item (deep link) @@ -456,7 +447,7 @@ export default function ProjectChatPage() { if (!highlightedItemId || !messages.length) return; // Check if the highlighted message exists - const messageExists = messages.some(msg => msg.id === highlightedItemId); + const messageExists = messages.some((msg) => msg.id === highlightedItemId); if (!messageExists) return; // Poll for the element to be rendered (up to 3 seconds) @@ -487,17 +478,10 @@ export default function ProjectChatPage() { const previewUrl = URL.createObjectURL(file); // Add pending image with uploading state - setPendingImages((prev) => [ - ...prev, - { id, file, previewUrl, uploading: true }, - ]); + setPendingImages((prev) => [...prev, { id, file, previewUrl, uploading: true }]); try { - const result = await apiClient.uploadDiscussionImage( - file, - projectId, - discussion.id, - ); + const result = await apiClient.uploadDiscussionImage(file, projectId, discussion.id); // Remove from pending images - we're opening the annotation modal setPendingImages((prev) => prev.filter((img) => img.id !== id)); @@ -515,9 +499,7 @@ export default function ProjectChatPage() { console.error("Failed to upload image:", err); setPendingImages((prev) => prev.map((img) => - img.id === id - ? { ...img, uploading: false, error: "Upload failed" } - : img, + img.id === id ? { ...img, uploading: false, error: "Upload failed" } : img, ), ); } @@ -560,13 +542,8 @@ export default function ProjectChatPage() { }, [discussion, projectId, isCancelling]); const handleSendMessage = useCallback( - async ( - content: string, - imagesToSend?: ImageAttachment[], - mcqAnswer?: MCQAnswerContext, - ) => { - if (!discussion || !content.trim() || isSending || discussion.is_readonly) - return; + async (content: string, imagesToSend?: ImageAttachment[], mcqAnswer?: MCQAnswerContext) => { + if (!discussion || !content.trim() || isSending || discussion.is_readonly) return; const messageContent = content.trim(); setIsSending(true); @@ -659,21 +636,15 @@ export default function ProjectChatPage() { } catch (err) { console.error("Failed to send message:", err); // Remove optimistic message on error - setMessages((prev) => - prev.filter((m) => m.id !== optimisticMessage.id), - ); + setMessages((prev) => prev.filter((m) => m.id !== optimisticMessage.id)); // Revert sidebar message count setProjectChats((prev) => prev.map((d) => - d.id === projectChatId - ? { ...d, message_count: Math.max(0, d.message_count - 1) } - : d, + d.id === projectChatId ? { ...d, message_count: Math.max(0, d.message_count - 1) } : d, ), ); setIsAIProcessing(false); - setAIError( - err instanceof Error ? err.message : "Failed to send message", - ); + setAIError(err instanceof Error ? 
err.message : "Failed to send message"); } finally { setIsSending(false); } @@ -711,9 +682,7 @@ export default function ProjectChatPage() { const handleMCQClick = useCallback( (option: ProjectChatMCQOption) => { // Find the last bot message with MCQ options to get context - const lastBotMsg = messages - .filter((m) => m.message_type === "bot") - .slice(-1)[0]; + const lastBotMsg = messages.filter((m) => m.message_type === "bot").slice(-1)[0]; const mcqOpts = lastBotMsg?.response_data?.mcq_options; if (!lastBotMsg || !mcqOpts) { @@ -756,10 +725,7 @@ export default function ProjectChatPage() { setIsCreatingPhase(true); try { - const result = await apiClient.createPhaseFromDiscussion( - projectId, - discussion.id, - ); + const result = await apiClient.createPhaseFromDiscussion(projectId, discussion.id); // Navigate to the new phase - use short URL if phase info is returned if (project && result.phase_short_id && result.phase_title) { const phaseInfo = { @@ -768,9 +734,7 @@ export default function ProjectChatPage() { }; router.push(buildPhaseConversationsUrl(project, phaseInfo)); } else { - router.push( - `/projects/${projectId}/brainstorming/${result.phase_id}/conversations`, - ); + router.push(`/projects/${projectId}/brainstorming/${result.phase_id}/conversations`); } } catch (err) { console.error("Failed to create phase:", err); @@ -789,7 +753,7 @@ export default function ProjectChatPage() { const updated = await apiClient.updateProjectChatVisibility( projectId, discussion.id, - newVisibility + newVisibility, ); setProjectChat(updated); } catch (err) { @@ -798,7 +762,7 @@ export default function ProjectChatPage() { setIsUpdatingVisibility(false); } }, - [discussion, projectId] + [discussion, projectId], ); const handleCreateFeature = useCallback(async () => { @@ -806,10 +770,7 @@ export default function ProjectChatPage() { setIsCreatingFeature(true); try { - const result = await apiClient.createFeatureFromDiscussion( - projectId, - discussion.id, - ); + const result = await apiClient.createFeatureFromDiscussion(projectId, discussion.id); // Navigate to the newly created feature page - use short URL if feature info is returned if (project && result.feature_short_id && result.feature_key) { const featureInfo = { @@ -874,9 +835,7 @@ export default function ProjectChatPage() { } } catch (err) { console.error("Failed to delete discussion:", err); - setError( - err instanceof Error ? err.message : "Failed to delete discussion", - ); + setError(err instanceof Error ? 
err.message : "Failed to delete discussion"); } }, [projectId, discussion?.id, router, loadDiscussions, project], @@ -885,11 +844,7 @@ export default function ProjectChatPage() { const handleGoToPhase = useCallback(() => { if (discussion?.created_phase_id) { // Use short URL if we have phase info - if ( - project && - discussion.created_phase_short_id && - discussion.created_phase_title - ) { + if (project && discussion.created_phase_short_id && discussion.created_phase_title) { const phaseInfo = { title: discussion.created_phase_title, short_id: discussion.created_phase_short_id, @@ -922,9 +877,7 @@ export default function ProjectChatPage() { const userNames = { ...(existingReaction.user_names || {}) }; if (existingReaction.user_ids.includes(userId)) { // Remove user from reaction - existingReaction.user_ids = existingReaction.user_ids.filter( - (id) => id !== userId, - ); + existingReaction.user_ids = existingReaction.user_ids.filter((id) => id !== userId); delete userNames[userId]; existingReaction.user_names = userNames; existingReaction.count -= 1; @@ -987,9 +940,7 @@ export default function ProjectChatPage() { // WebSocket will handle the UI update } catch (err) { console.error("Failed to delete message:", err); - setError( - err instanceof Error ? err.message : "Failed to delete message", - ); + setError(err instanceof Error ? err.message : "Failed to delete message"); } }, [deleteDialogState.messageId, discussion, projectId]); @@ -1012,9 +963,7 @@ export default function ProjectChatPage() { // WebSocket will handle the UI update } catch (err) { console.error("Failed to start over:", err); - setError( - err instanceof Error ? err.message : "Failed to start over", - ); + setError(err instanceof Error ? err.message : "Failed to start over"); } }, [deleteDialogState.messageId, discussion, projectId, messages]); @@ -1033,10 +982,10 @@ export default function ProjectChatPage() { if (isLoading) { return ( -
[markup-only hunk: loading-state wrappers re-indented; visible text: "Loading..."]
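(All hunks in this file are formatting-only. For orientation, here is a sketch of Prettier options consistent with what the hunks show; the repo's actual frontend/.prettierrc.json, added earlier in this patch, is authoritative, and the values below are inferred from the diff, not confirmed:)

    // prettier.config.sketch.mjs (illustrative only; inferred from the hunks in this patch)
    export default {
      printWidth: 100, // calls like getCreatedFeatures(projectId, projectChatId) now fit on one line
      trailingComma: "all", // e.g. `newVisibility` gained a trailing comma
      arrowParens: "always", // `msg => ...` rewritten to `(msg) => ...`
      plugins: ["prettier-plugin-tailwindcss"], // class lists reordered later in this patch
    };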
      ); @@ -1044,14 +993,10 @@ export default function ProjectChatPage() { if (error) { return ( -
[markup-only hunk: error-state wrappers re-indented; visible text: {error}]
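(The dependency-array hunks in this file follow the same width rule: once a list no longer fits on one line, Prettier breaks it one element per line with a trailing comma. Before and after, quoted from the useEffect hunk above:)

    // Before
    }, [wsUpdates, discussion?.id, discussion?.created_feature_ids?.length, loadCreatedFeatures, user?.id]);
    // After
    }, [
      wsUpdates,
      discussion?.id,
      discussion?.created_feature_ids?.length,
      loadCreatedFeatures,
      user?.id,
    ]);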
      @@ -1060,9 +1005,7 @@ export default function ProjectChatPage() { } // Get the last bot message to check for MCQ options - const lastBotMessage = messages - .filter((m) => m.message_type === "bot") - .slice(-1)[0]; + const lastBotMessage = messages.filter((m) => m.message_type === "bot").slice(-1)[0]; const mcqOptions = lastBotMessage?.response_data?.mcq_options; const showMCQOptions = mcqOptions && @@ -1077,7 +1020,8 @@ export default function ProjectChatPage() { const isReadonly = discussion?.is_readonly ?? false; // Check if the repo banner is showing (same condition as AppTopNav) - const showRepoNotSetBanner = project && (!project.repositories || project.repositories.length === 0); + const showRepoNotSetBanner = + project && (!project.repositories || project.repositories.length === 0); // Navbar is h-16 (64px) + banner height (~32px) when shown const topOffsetClass = showRepoNotSetBanner ? "top-24" : "top-16"; @@ -1085,7 +1029,9 @@ export default function ProjectChatPage() { // Use fixed positioning to fill viewport below header if (!hasMessages) { return ( -
[markup-only hunk: {/* Main area with sidebar and content */} and {/* Sidebar */} wrappers re-indented]
@@ -1102,7 +1048,7 @@
[markup-only hunk: {/* Main content */} wrapper re-indented]
      {/* Header with sidebar controls and visibility toggle */} 0} + hasCreatedContent={ + !!discussion?.created_phase_id || (discussion?.created_feature_ids?.length ?? 0) > 0 + } onNewChat={handleNewChat} onToggleSidebar={() => setIsSidebarExpanded(!isSidebarExpanded)} isSidebarExpanded={isSidebarExpanded} /> {/* Centered empty state */} -
[markup-only hunks: centered empty state and {/* Welcome Message */} block re-wrapped; visible text: "Create a Brainstorming Phase or Feature to Collaborate" / "Tell me about what you want to build. I'll help you articulate your idea and create a Brainstorming Phase or feature when you're ready."]
      @@ -1154,35 +1101,28 @@ export default function ProjectChatPage() { )} {/* Create Phase Banner - above input when ready */} - {discussion?.ready_to_create_phase && - !discussion?.created_phase_id && ( -
[markup-only hunk: the Create Phase banner guard collapsed to one line: {discussion?.ready_to_create_phase && !discussion?.created_phase_id && ( ... )}]
[markup-only hunk: {/* Create Feature Banner - above input when ready */} {discussion?.ready_to_create_feature && ( ... )} re-wrapped the same way]
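(The banner hunks apply the same rule to JSX: a guard that fits within the print width stays on one line before the opening paren. A sketch of the transformation; the component element is elided in this copy of the patch, so <CreatePhaseBanner /> is a stand-in, not the actual element or props:)

    // Before
    {discussion?.ready_to_create_phase &&
      !discussion?.created_phase_id && (
        <CreatePhaseBanner /> // stand-in tag; real props not recoverable
      )}
    // After
    {discussion?.ready_to_create_phase && !discussion?.created_phase_id && (
      <CreatePhaseBanner />
    )}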
[markup-only hunk: {/* Main area with sidebar and content */} and {/* Sidebar */} wrappers re-indented]
@@ -1299,7 +1247,7 @@
[markup-only hunk: {/* Main content */} wrapper re-indented]
      {/* Header with sidebar controls and visibility toggle */} 0} + hasCreatedContent={ + !!discussion?.created_phase_id || (discussion?.created_feature_ids?.length ?? 0) > 0 + } onNewChat={handleNewChat} onToggleSidebar={() => setIsSidebarExpanded(!isSidebarExpanded)} isSidebarExpanded={isSidebarExpanded} /> {/* Chat Messages - constrained width for readability */} -
[markup-only hunk: chat-messages wrappers re-indented]
{/* Message List */}
      {messages.map((message) => { @@ -1329,7 +1276,7 @@ export default function ProjectChatPage() { ref={(el) => handleMessageRef(message.id, el)} className={cn( "transition-all duration-300", - isHighlighted && "animate-highlight-fade rounded-md" + isHighlighted && "animate-highlight-fade rounded-md", )} > - -
[markup-only hunk: AI-processing indicator block re-wrapped: {isAIProcessing && !discussion?.is_exploring_code && !discussion?.is_searching_web && ( ... )}]
{/* Input Area - sticky at bottom, constrained width */}
[markup-only hunk: input wrapper re-indented]
      {/* Phase Created Banner - show when phase already created */} {discussion?.created_phase_id && ( )} {/* Create Phase Banner - above input when ready */} - {discussion?.ready_to_create_phase && - !discussion?.created_phase_id && ( - - )} + {discussion?.ready_to_create_phase && !discussion?.created_phase_id && ( + + )} {/* Create Feature Banner - above input when ready */} {discussion?.ready_to_create_feature && ( 0 && ( - + )} {/* AI Error panel with retry */} @@ -1466,7 +1405,13 @@ export default function ProjectChatPage() { ? "Tag @MFBTAI to chat (switch chat visibility to Team to mention teammates)" : "Tag @MFBTAI or mention teammates. What would you like to build today?" } - disabled={isSending || isAIProcessing || !!aiError || !!discussion?.is_exploring_code || !!discussion?.is_searching_web} + disabled={ + isSending || + isAIProcessing || + !!aiError || + !!discussion?.is_exploring_code || + !!discussion?.is_searching_web + } isSubmitting={isSending} autoFocus pendingImages={pendingImages} @@ -1476,9 +1421,7 @@ export default function ProjectChatPage() { initialContent={prefillContent} contentKey={prefillKey} restrictMentionsToMfbtai={discussion?.visibility === "private"} - onTypingChange={(isTyping) => - isTyping ? sendTypingStart() : sendTypingStop() - } + onTypingChange={(isTyping) => (isTyping ? sendTypingStart() : sendTypingStop())} /> @@ -1486,9 +1429,8 @@ export default function ProjectChatPage() { {/* Readonly message */} {isReadonly && ( -
-              This conversation is read-only. A phase has been created from
-              this discussion.
+              This conversation is read-only. A phase has been created from this discussion.
          )}
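(The disabled={...} hunk above is the reverse case: a boolean chain longer than the print width gets one operand per line. As it appears in the diff:)

    disabled={
      isSending ||
      isAIProcessing ||
      !!aiError ||
      !!discussion?.is_exploring_code ||
      !!discussion?.is_searching_web
    }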
      @@ -1517,11 +1459,7 @@ export default function ProjectChatPage() { @@ -1551,10 +1489,7 @@ interface MessageCardProps { onToggleReaction?: (emoji: string, emojiNative: string) => Promise; } -function getDisplayName(author: { - display_name: string; - email: string; -}): string { +function getDisplayName(author: { display_name: string; email: string }): string { return author.display_name || author.email; } @@ -1608,11 +1543,11 @@ function MessageCard({ // Render system messages with subtle centered styling if (isSystem) { return ( -
[markup-only hunk: system-message wrappers re-indented around {message.content}]
      ); @@ -1620,7 +1555,7 @@ function MessageCard({ if (isBot) { return ( -
[markup-only hunk: bot card wrapper re-indented]
@@ -1628,10 +1563,10 @@
[markup-only hunk: "MFBT AI" header re-wrapped; {timeAgo} span re-indented]
      @@ -1644,37 +1579,30 @@ function MessageCard({ } // Get user display name from author field - const displayName = message.author - ? getDisplayName(message.author) - : "Unknown User"; + const displayName = message.author ? getDisplayName(message.author) : "Unknown User"; return ( -
[markup-only hunks: user-message card re-wrapped; avatar {getInitials(displayName)}, {displayName}, and {timeAgo} re-indented; attached-images and reactions blocks (@@ -1688,12 +1616,12 @@) and {/* Hover actions - top right corner */} wrappers re-indented]
      {/* Add reaction button */} - {canReact && ( - - )} + {canReact && } {/* Dropdown menu for delete/start over - only for own user messages */} {canDelete && ( @@ -1712,14 +1640,14 @@ function MessageCard({ onClick={() => onOpenDeleteDialog(message.id, "delete")} className="text-destructive" > - + Delete onOpenDeleteDialog(message.id, "start-over")} className="text-destructive" > - + Start over from here diff --git a/frontend/app/projects/[projectId]/project-chat/page.tsx b/frontend/app/projects/[projectId]/project-chat/page.tsx index cf03b67..7a9c3a0 100644 --- a/frontend/app/projects/[projectId]/project-chat/page.tsx +++ b/frontend/app/projects/[projectId]/project-chat/page.tsx @@ -3,15 +3,15 @@ import { useEffect, useState, useCallback, useRef, useLayoutEffect } from "react"; import { useParams, useRouter, useSearchParams } from "next/navigation"; import { apiClient } from "@/lib/api/client"; -import { - Project, - ProjectChatListItem, - ImageAttachment, -} from "@/lib/api/types"; +import { Project, ProjectChatListItem, ImageAttachment } from "@/lib/api/types"; import { ProjectChatSidebar } from "@/components/ProjectChatSidebar"; import { ProjectChatHeader } from "@/components/ProjectChatHeader"; import { ChatInput, type PendingImage } from "@/components/editor/ChatInput"; -import { QuickActionButtons, QUICK_ACTIONS, type QuickActionType } from "@/components/chat/QuickActionButtons"; +import { + QuickActionButtons, + QUICK_ACTIONS, + type QuickActionType, +} from "@/components/chat/QuickActionButtons"; import { ImageAnnotationModal } from "@/components/ImageAnnotationModal"; import { Bot, Loader2 } from "lucide-react"; import { buildProjectUrl, buildDiscussionUrl } from "@/lib/url"; @@ -85,9 +85,7 @@ export default function ProjectChatPage() { } | null>(null); // Sidebar state for empty state view - const [discussions, setProjectChats] = useState( - [], - ); + const [discussions, setProjectChats] = useState([]); const [discussionsHasMore, setProjectChatsHasMore] = useState(false); const [isLoadingDiscussions, setIsLoadingDiscussions] = useState(false); const [isSidebarExpanded, setIsSidebarExpanded] = useState(() => { @@ -109,17 +107,11 @@ export default function ProjectChatPage() { apiClient.initializeAuth(); // Fetch project first - const projectData = await apiClient.get( - `/api/v1/projects/${projectId}`, - ); + const projectData = await apiClient.get(`/api/v1/projects/${projectId}`); setProject(projectData); // Try to get the list of discussions - const result = await apiClient.listProjectChats( - projectId, - 20, - 0, - ); + const result = await apiClient.listProjectChats(projectId, 20, 0); if (result.project_chats.length > 0 && !isNewChat) { // Redirect to the most recent project chat (unless explicitly starting new chat) @@ -134,9 +126,7 @@ export default function ProjectChatPage() { } } catch (err) { console.error("Failed to load discussions:", err); - setError( - err instanceof Error ? err.message : "Failed to load discussions", - ); + setError(err instanceof Error ? 
err.message : "Failed to load discussions"); setIsLoading(false); } }; @@ -163,16 +153,10 @@ export default function ProjectChatPage() { try { // Create discussion AND send the first message - const newDiscussion = - await apiClient.createProjectChat(projectId, { - target_container_id: targetContainerId || undefined, - }); - await apiClient.sendProjectChatMessage( - projectId, - newDiscussion.id, - messageContent, - images, - ); + const newDiscussion = await apiClient.createProjectChat(projectId, { + target_container_id: targetContainerId || undefined, + }); + await apiClient.sendProjectChatMessage(projectId, newDiscussion.id, messageContent, images); // Navigate to the new discussion router.push(buildDiscussionUrl(project, newDiscussion)); @@ -204,24 +188,16 @@ export default function ProjectChatPage() { const previewUrl = URL.createObjectURL(file); // Add pending image with uploading state - setPendingImages((prev) => [ - ...prev, - { id, file, previewUrl, uploading: true }, - ]); + setPendingImages((prev) => [...prev, { id, file, previewUrl, uploading: true }]); try { // Create discussion first (lazily) for image upload - const newDiscussion = - await apiClient.createProjectChat(projectId, { - target_container_id: targetContainerId || undefined, - }); + const newDiscussion = await apiClient.createProjectChat(projectId, { + target_container_id: targetContainerId || undefined, + }); // Upload image to the new discussion - const result = await apiClient.uploadDiscussionImage( - file, - projectId, - newDiscussion.id, - ); + const result = await apiClient.uploadDiscussionImage(file, projectId, newDiscussion.id); // Remove from pending images - we're opening the annotation modal setPendingImages((prev) => prev.filter((img) => img.id !== id)); @@ -240,9 +216,7 @@ export default function ProjectChatPage() { console.error("Failed to upload image:", err); setPendingImages((prev) => prev.map((img) => - img.id === id - ? { ...img, uploading: false, error: "Upload failed" } - : img, + img.id === id ? { ...img, uploading: false, error: "Upload failed" } : img, ), ); } @@ -279,31 +253,21 @@ export default function ProjectChatPage() { const messageContent = metadata.user_remark || "Shared an image"; try { - await apiClient.sendProjectChatMessage( - projectId, - projectChatId, - messageContent, - [metadata], - ); + await apiClient.sendProjectChatMessage(projectId, projectChatId, messageContent, [ + metadata, + ]); // Navigate to the discussion // We need to fetch the discussion to get its short_id for the URL - const discussionData = await apiClient.getProjectChatById( - projectId, - projectChatId, - ); + const discussionData = await apiClient.getProjectChatById(projectId, projectChatId); if (discussionData) { router.push(buildDiscussionUrl(project, discussionData.project_chat)); } else { - router.push( - `${buildProjectUrl(project)}/project-chat/${projectChatId}`, - ); + router.push(`${buildProjectUrl(project)}/project-chat/${projectChatId}`); } } catch (err) { console.error("Failed to send image message:", err); - setError( - err instanceof Error ? err.message : "Failed to send image message", - ); + setError(err instanceof Error ? err.message : "Failed to send image message"); } }, [annotationModalData, projectId, project, router], @@ -392,12 +356,12 @@ export default function ProjectChatPage() { if (error) { return ( -
[markup-only hunk: error-state wrappers re-indented; visible text: {error}]
      @@ -408,22 +372,25 @@ export default function ProjectChatPage() { if (isLoading || !showEmptyState) { return ( -
[markup-only hunk: loading-state wrappers re-indented; visible text: "Loading..."]
      ); } // Check if the repo banner is showing (same condition as AppTopNav) - const showRepoNotSetBanner = project && (!project.repositories || project.repositories.length === 0); + const showRepoNotSetBanner = + project && (!project.repositories || project.repositories.length === 0); const topOffsetClass = showRepoNotSetBanner ? "top-24" : "top-16"; // Empty state - no discussions yet return ( -
[markup-only hunk: {/* Main area with sidebar and content */} and {/* Sidebar */} wrappers re-indented]
@@ -440,7 +407,7 @@
[markup-only hunks: {/* Main content */}, {/* Header with sidebar controls */}, and {/* Content area */} wrappers re-indented]
[markup-only hunks: {/* Welcome Message */} block re-wrapped for each branch; visible text:]
{selectedActionType === "phase" ? (
  "Create a Brainstorming Phase" / "Tell me about what you want to explore. I'll help you articulate your idea and create a Brainstorming Phase when you're ready."
) : selectedActionType === "feature" ? (
  "Create a Feature" / "Tell me about what you want to build. I'll help you articulate your idea and create a Feature when you're ready."
) : (
  "Create a Brainstorming Phase or Feature to Collaborate" / "Tell me about what you want to build. I'll help you articulate your idea and create a Brainstorming Phase or feature when you're ready."
      )} @@ -493,15 +457,12 @@ export default function ProjectChatPage() { {/* Quick Action Buttons - only show when no action selected */} {!selectedActionType && (
[markup-only hunk: QuickActionButtons wrapper re-indented]
)}
{/* Input - container is flex-1, measured by ResizeObserver */}
{/* Switch link - show when action is selected */}
[markup-only hunks: surrounding wrappers re-indented]
      Already have an account?{" "} +

[markup-only hunks in settings layout: loading state "Loading settings..." / (@@ -68,7 +68,7 @@) "No organization found" / (@@ -78,7 +78,7 @@) "Settings" heading and "Manage your organization settings and integrations" re-wrapped]
      @@ -89,9 +89,7 @@ function SettingsPageContent() { API Keys Team Roles Testing & Debugging - {isEnterprise && ( - Slack Bot - )} + {isEnterprise && Slack Bot} diff --git a/frontend/app/settings/page.tsx b/frontend/app/settings/page.tsx index e6051c0..5a29176 100644 --- a/frontend/app/settings/page.tsx +++ b/frontend/app/settings/page.tsx @@ -8,7 +8,13 @@ export const metadata: Metadata = { export default function SettingsPage() { return ( -

[markup-only hunk: Suspense fallback ("Loading settings...") expanded from one line to a multi-line fallback={...} block]
      + } + > ); diff --git a/frontend/app/trial-expired/page.tsx b/frontend/app/trial-expired/page.tsx index 71f62e9..8c05589 100644 --- a/frontend/app/trial-expired/page.tsx +++ b/frontend/app/trial-expired/page.tsx @@ -3,13 +3,7 @@ import { Clock, Mail, LogOut, Building2, Zap, Eye, Edit, CalendarClock } from "lucide-react"; import Link from "next/link"; import { Button } from "@/components/ui/button"; -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle, -} from "@/components/ui/card"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { useAuth } from "@/lib/auth/AuthContext"; import { useTrial } from "@/lib/auth/TrialContext"; @@ -22,11 +16,12 @@ export default function TrialExpiredPage() { const hasOtherOrgs = otherOrgs.length > 0; // Determine if this is a token exhaustion (freemium) or trial expiration (legacy) - const isFreemiumTokensExhausted = isTokensExhausted && trialStatus?.plan?.plan_name === "freemium"; + const isFreemiumTokensExhausted = + isTokensExhausted && trialStatus?.plan?.plan_name === "freemium"; return ( -
[markup-only hunk: page container re-indented]
      {isFreemiumTokensExhausted ? ( @@ -38,16 +33,16 @@ export default function TrialExpiredPage() { {isFreemiumTokensExhausted ? "Token Allocation Exhausted" : "Your Trial Has Ended"} - + {isFreemiumTokensExhausted ? ( <> - You've used all your available tokens. Your tokens will be - replenished on Monday. + You've used all your available tokens. Your tokens will be replenished on + Monday. ) : ( <> - Your 14-day free trial has expired. Upgrade to continue using MFBT - and unlock the full power of collaborative AI-assisted development. + Your 14-day free trial has expired. Upgrade to continue using MFBT and unlock the + full power of collaborative AI-assisted development. )} @@ -56,8 +51,8 @@ export default function TrialExpiredPage() { {/* For freemium users: show what they can still do */} {isFreemiumTokensExhausted && ( <> -
[markup-only hunks: "You can still:" card with bullet list (@@ -70,34 +65,32 @@) and the note "AI generation features will be available again when your tokens are replenished on Monday." re-wrapped]
      {/* Go to Dashboard button for freemium users */} )} {/* Note about other orgs */} {hasOtherOrgs && ( -
[icon markup elided]
-              You have access to {otherOrgs.length} other organization{otherOrgs.length > 1 ? "s" : ""}.
-              You can continue using MFBT by switching to an active organization
-              from the user menu in the top navigation.
+              You have access to {otherOrgs.length} other organization
+              {otherOrgs.length > 1 ? "s" : ""}. You can continue using MFBT by switching to an
+              active organization from the user menu in the top navigation.
      @@ -106,9 +99,9 @@ export default function TrialExpiredPage() { {/* For legacy trial users: show contact sales */} {!isFreemiumTokensExhausted && ( <> -

-                Ready to continue? Contact us to discuss subscription options
-                tailored to your team's needs.
+                Ready to continue? Contact us to discuss subscription options tailored to your
+                team's needs.
{/* Contact Sales button */}
@@ -127,7 +120,9 @@
[markup-only hunk: "or" divider re-wrapped]
      diff --git a/frontend/app/verify-email/VerifyEmailClient.tsx b/frontend/app/verify-email/VerifyEmailClient.tsx index aee268f..68fd6fb 100644 --- a/frontend/app/verify-email/VerifyEmailClient.tsx +++ b/frontend/app/verify-email/VerifyEmailClient.tsx @@ -83,17 +83,21 @@ export default function VerifyEmailClient() { // Verifying state - show spinner if (pageState === "verifying") { return ( -
[markup-only hunk: verifying state re-wrapped; visible text: "mfbt." / "move fast and build things" / "Verifying your email..." / "Please wait while we verify your email address."]
      @@ -106,17 +110,21 @@ export default function VerifyEmailClient() { // Success state if (pageState === "success") { return ( -
[markup-only hunk: success state re-wrapped; visible text: "mfbt." / "move fast and build things" / "Email Verified!" / "Your email has been verified successfully. Redirecting you to your projects..."]
      @@ -129,23 +137,27 @@ export default function VerifyEmailClient() { // Error state (from token verification) if (pageState === "error" && token) { return ( -
[markup-only hunks: error state re-wrapped; visible text: "mfbt." / "move fast and build things" / "Verification Failed" / {error || "The verification link is invalid or has expired."} / "Enter your email address below to request a new verification link." / (@@ -175,12 +187,12 @@) "A new verification link has been sent to your email." / "Back to Sign In"]
      @@ -192,24 +204,26 @@ export default function VerifyEmailClient() { // Pending state - waiting for user to check email return ( -
[markup-only hunks: pending state re-wrapped; visible text: "mfbt." / "Move fast and build things" / "Check Your Email" / "We've sent a verification link to {email || "your email"}" / "Please verify your email address to continue" / "We couldn't send the verification email automatically. Please request a new verification link below." / "Click the link in your email to verify your account and start using MFBT." / "The link will expire in 24 hours." / {error} / "A new verification link has been sent to your email." / "Didn't receive the email?"]
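(The verify-email hunks are Prettier re-flowing JSX text nodes; no copy changes. One case quoted from the pending state, where a fragment's text moves to its own indented line:)

    // Before
    <>We've sent a verification link to {email || "your email"}</>
    // After
    <>
      We've sent a verification link to {email || "your email"}
    </>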
      Back to Sign In diff --git a/frontend/components/AICommentCard.tsx b/frontend/components/AICommentCard.tsx index 3e5781b..5cde143 100644 --- a/frontend/components/AICommentCard.tsx +++ b/frontend/components/AICommentCard.tsx @@ -14,7 +14,7 @@ export function AICommentCard({ item }: AICommentCardProps) { const timeAgo = useRelativeTime(item.created_at); return ( -
[markup-only hunk: comment card wrapper re-indented]
@@ -22,10 +22,10 @@
[markup-only hunk: "MFBT AI" header re-wrapped; {timeAgo} span re-indented]
      diff --git a/frontend/components/AIErrorPanel.tsx b/frontend/components/AIErrorPanel.tsx index 766f14d..876b50b 100644 --- a/frontend/components/AIErrorPanel.tsx +++ b/frontend/components/AIErrorPanel.tsx @@ -24,18 +24,14 @@ export function AIErrorPanel({ isCancelling = false, }: AIErrorPanelProps) { return ( -
[markup-only hunks: AIErrorPanel re-wrapped; visible text: "Something went wrong" / {error}; {onCancel && ( ... )} cancel button and closing wrappers re-indented]
      diff --git a/frontend/components/ActivityLogCard.tsx b/frontend/components/ActivityLogCard.tsx index 4133bc3..e521074 100644 --- a/frontend/components/ActivityLogCard.tsx +++ b/frontend/components/ActivityLogCard.tsx @@ -20,28 +20,23 @@ export function ActivityLogCard({ activity }: ActivityLogCardProps) { const label = getEventTypeLabel(activity.event_type); const variant = getEventTypeVariant(activity.event_type); const relativeTime = useRelativeTime(activity.created_at); - const metadataText = formatEventMetadata( - activity.event_type, - activity.event_metadata - ); + const metadataText = formatEventMetadata(activity.event_type, activity.event_metadata); const entityLabel = getEntityTypeLabel(activity.entity_type); return ( -
[markup-only hunks: activity card wrappers re-indented; {label} and {entityLabel} badges re-wrapped; {metadataText && ( ... )} collapsed to one line; {relativeTime} re-indented]
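(The {metadataText} hunk shows Prettier collapsing a logical-AND render with a single short child onto one line. A sketch, with the child element assumed to be a simple <p> since the actual tag is elided in this copy of the patch:)

    // Before
    {metadataText && (
      <p>{metadataText}</p> // assumed tag
    )}
    // After
    {metadataText && <p>{metadataText}</p>}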
      ); diff --git a/frontend/components/ActivityLogView.tsx b/frontend/components/ActivityLogView.tsx index ccdab4d..4fd49e8 100644 --- a/frontend/components/ActivityLogView.tsx +++ b/frontend/components/ActivityLogView.tsx @@ -43,10 +43,10 @@ export function ActivityLogView({
      ) : logs.length === 0 ? (
[markup-only hunk: empty-state icon and {emptyMessage} wrappers re-indented]
) : (
@@ -58,12 +58,7 @@
{hasMore && onLoadMore && (
[markup-only hunk: load-more button collapsed]
      diff --git a/frontend/components/AddBugTrackerConnectorModal.tsx b/frontend/components/AddBugTrackerConnectorModal.tsx index 955f1c5..eac0ea3 100644 --- a/frontend/components/AddBugTrackerConnectorModal.tsx +++ b/frontend/components/AddBugTrackerConnectorModal.tsx @@ -36,7 +36,7 @@ const DEFAULT_CONNECTOR_NAMES: Record = { "github-oauth": "My OAuth GitHub Connector", "github-pat": "My GitHub PAT Connector", "github-github_app": "My GitHub App Connector", - "jira": "My Jira Connector", + jira: "My Jira Connector", }; // Generate a unique name by appending timestamp if needed @@ -205,7 +205,12 @@ export function AddBugTrackerConnectorModal({ useEffect(() => { const handleMessage = (event: MessageEvent) => { // Validate the message has the expected shape - const data = event.data as { success?: boolean; config_id?: string; display_name?: string; error?: string }; + const data = event.data as { + success?: boolean; + config_id?: string; + display_name?: string; + error?: string; + }; if (!data || typeof data.success !== "boolean") { return; } @@ -243,10 +248,7 @@ export function AddBugTrackerConnectorModal({ }, [oauthPopup]); // Clear test result when sensitive fields change - const handleSensitiveFieldChange = ( - setter: (value: string) => void, - value: string - ) => { + const handleSensitiveFieldChange = (setter: (value: string) => void, value: string) => { setter(value); setTestResult(null); setError(""); @@ -388,7 +390,7 @@ export function AddBugTrackerConnectorModal({ const popup = window.open( response.authorization_url, "github_oauth", - `width=${width},height=${height},left=${left},top=${top},toolbar=no,menubar=no` + `width=${width},height=${height},left=${left},top=${top},toolbar=no,menubar=no`, ); if (!popup) { @@ -474,9 +476,7 @@ export function AddBugTrackerConnectorModal({ onSuccess(); handleClose(); } catch (err: unknown) { - setError( - err instanceof Error ? err.message : "Failed to create connector" - ); + setError(err instanceof Error ? err.message : "Failed to create connector"); } finally { setIsSubmitting(false); } @@ -546,9 +546,7 @@ export function AddBugTrackerConnectorModal({ - - Make available to everyone in the org - + Make available to everyone in the org Private to me (and to explicitly shared users & groups) @@ -576,20 +574,20 @@ export function AddBugTrackerConnectorModal({ {isOAuthAvailable && (
[markup-only hunks: OAuth connect button and divider re-wrapped]
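(One non-whitespace change in this file comes from Prettier's default quoteProps: "as-needed": object keys that are valid identifiers lose their quotes, while keys like "github-oauth" keep them because of the hyphen. The type annotation below is assumed; the generic parameters are elided in this copy of the patch:)

    const DEFAULT_CONNECTOR_NAMES: Record<string, string> = {
      "github-oauth": "My OAuth GitHub Connector", // hyphen: quotes required
      jira: "My Jira Connector", // plain identifier: quotes dropped
    };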
      @@ -606,7 +604,7 @@ export function AddBugTrackerConnectorModal({ onChange={(e) => setDisplayName(e.target.value)} required /> -

[markup-only hunk: helper text "A friendly name to identify this connector" re-wrapped]
      @@ -627,7 +625,7 @@ export function AddBugTrackerConnectorModal({ )} Connect with GitHub -

[markup-only hunk: helper text "A popup will open for you to authorize access to your GitHub repositories." re-wrapped]
      @@ -653,18 +651,19 @@ export function AddBugTrackerConnectorModal({ type="button" variant="ghost" size="sm" - className="absolute right-0 top-0 h-full px-3 py-2 hover:bg-transparent" + className="absolute top-0 right-0 h-full px-3 py-2 hover:bg-transparent" onClick={() => setShowApiToken(!showApiToken)} > {showApiToken ? ( - + ) : ( - + )}
-                  Generate from GitHub Settings > Developer settings > Personal access tokens
+                  Generate from GitHub Settings > Developer settings > Personal access
+                  tokens
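(Several hunks in this file reorder class strings rather than re-wrap them: prettier-plugin-tailwindcss sorts classes into its canonical order. One example quoted from the token-visibility buttons in this file:)

    // Before
    className="absolute right-0 top-0 h-full px-3 py-2 hover:bg-transparent"
    // After
    className="absolute top-0 right-0 h-full px-3 py-2 hover:bg-transparent"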
      )} @@ -678,12 +677,10 @@ export function AddBugTrackerConnectorModal({ id="appId" placeholder="123456" value={appId} - onChange={(e) => - handleSensitiveFieldChange(setAppId, e.target.value) - } + onChange={(e) => handleSensitiveFieldChange(setAppId, e.target.value)} required /> -

[markup-only hunk: helper text "Found in your GitHub App settings" re-wrapped]
      @@ -699,7 +696,7 @@ export function AddBugTrackerConnectorModal({ } required /> -

[markup-only hunk: helper text "Found in the installation URL: github.com/settings/installations/[ID]" re-wrapped]
      @@ -715,7 +712,7 @@ export function AddBugTrackerConnectorModal({ handleSensitiveFieldChange(setPrivateKey, e.target.value) } rows={4} - className={`font-mono text-xs ${!showPrivateKey ? "text-transparent select-none caret-foreground" : ""}`} + className={`font-mono text-xs ${!showPrivateKey ? "caret-foreground text-transparent select-none" : ""}`} style={!showPrivateKey ? { textShadow: "0 0 8px rgba(0,0,0,0.5)" } : {}} required /> @@ -723,17 +720,17 @@ export function AddBugTrackerConnectorModal({ type="button" variant="ghost" size="sm" - className="absolute right-2 top-2 h-8 px-2 hover:bg-muted" + className="hover:bg-muted absolute top-2 right-2 h-8 px-2" onClick={() => setShowPrivateKey(!showPrivateKey)} > {showPrivateKey ? ( - + ) : ( - + )}
[markup-only hunk: helper text "Download from GitHub App settings. Keep this secure." re-wrapped]
      @@ -750,7 +747,7 @@ export function AddBugTrackerConnectorModal({ value={orgName} onChange={(e) => setOrgName(e.target.value)} /> -

[markup-only hunk: helper text "Optional. For PAT mode, limits repository visibility to this organization." re-wrapped]
      @@ -768,14 +765,12 @@ export function AddBugTrackerConnectorModal({ onClick={handleFetchRepos} disabled={isLoadingRepos} > - {isLoadingRepos ? ( - - ) : null} + {isLoadingRepos ? : null} {availableRepos.length > 0 ? "Refresh" : "Load Repositories"}
      {availableRepos.length > 0 && ( -
      +
      {availableRepos.map((repo) => (
      r !== repo.full_name)); + setSelectedRepos( + selectedRepos.filter((r) => r !== repo.full_name), + ); } }} /> -
      ))}
      )} {availableRepos.length === 0 && !isLoadingRepos && ( -

-                      Click "Load Repositories" to select specific repositories (optional).
+                      Click "Load Repositories" to select specific repositories
+                      (optional).

      )}
      @@ -821,7 +820,7 @@ export function AddBugTrackerConnectorModal({ onChange={(e) => setDisplayName(e.target.value)} required /> -

[markup-only hunk: helper text "A friendly name to identify this connector" re-wrapped]
      @@ -832,34 +831,25 @@ export function AddBugTrackerConnectorModal({ id="baseUrl" placeholder="https://your-domain.atlassian.net" value={baseUrl} - onChange={(e) => - handleSensitiveFieldChange(setBaseUrl, e.target.value) - } + onChange={(e) => handleSensitiveFieldChange(setBaseUrl, e.target.value)} required /> -

[markup-only hunk: helper text "Your Jira Cloud instance URL" re-wrapped]

      - + - handleSensitiveFieldChange( - setServiceAccountEmail, - e.target.value - ) + handleSensitiveFieldChange(setServiceAccountEmail, e.target.value) } required /> -

[markup-only hunk: helper text "The email associated with the API token" re-wrapped]
      @@ -872,9 +862,7 @@ export function AddBugTrackerConnectorModal({ type={showApiToken ? "text" : "password"} placeholder="Enter your Jira API token" value={apiToken} - onChange={(e) => - handleSensitiveFieldChange(setApiToken, e.target.value) - } + onChange={(e) => handleSensitiveFieldChange(setApiToken, e.target.value)} className="pr-10" required /> @@ -882,19 +870,18 @@ export function AddBugTrackerConnectorModal({ type="button" variant="ghost" size="sm" - className="absolute right-0 top-0 h-full px-3 py-2 hover:bg-transparent" + className="absolute top-0 right-0 h-full px-3 py-2 hover:bg-transparent" onClick={() => setShowApiToken(!showApiToken)} > {showApiToken ? ( - + ) : ( - + )}
-                  Generate from Atlassian Account Settings > Security >
-                  API Tokens
+                  Generate from Atlassian Account Settings > Security > API Tokens
      @@ -904,11 +891,9 @@ export function AddBugTrackerConnectorModal({ id="projectKeys" placeholder="PROJ, TEAM, BACKEND" value={projectKeys} - onChange={(e) => - handleSensitiveFieldChange(setProjectKeys, e.target.value) - } + onChange={(e) => handleSensitiveFieldChange(setProjectKeys, e.target.value)} /> -

[markup-only hunk: helper text "Comma-separated list of Jira project keys (optional)" re-wrapped]
      @@ -920,8 +905,8 @@ export function AddBugTrackerConnectorModal({
      {testResult.message} @@ -930,7 +915,7 @@ export function AddBugTrackerConnectorModal({ {/* Error display */} {error && ( -
      +
      {error}
      )} @@ -946,25 +931,16 @@ export function AddBugTrackerConnectorModal({ type="button" variant="outline" onClick={handleTestConnection} - disabled={ - isTestingConnection || - !baseUrl || - !serviceAccountEmail || - !apiToken - } + disabled={isTestingConnection || !baseUrl || !serviceAccountEmail || !apiToken} > - {isTestingConnection && ( - - )} + {isTestingConnection && } Test Connection @@ -977,18 +953,14 @@ export function AddBugTrackerConnectorModal({ onClick={handleTestConnection} disabled={isTestingConnection || !canTestGitHub} > - {isTestingConnection && ( - - )} + {isTestingConnection && } Test Connection diff --git a/frontend/components/AddLLMConnectorModal.tsx b/frontend/components/AddLLMConnectorModal.tsx index 7bfd4c6..0126d0e 100644 --- a/frontend/components/AddLLMConnectorModal.tsx +++ b/frontend/components/AddLLMConnectorModal.tsx @@ -230,7 +230,7 @@ export function AddLLMConnectorModal({ }} rows={3} /> -

[markup-only hunk: helper text "Optional JSON configuration for model settings" re-wrapped]
      @@ -239,8 +239,8 @@ export function AddLLMConnectorModal({
      {testResult.message} @@ -248,7 +248,7 @@ export function AddLLMConnectorModal({ )} {error && ( -
      +
      {error}
      )} @@ -261,15 +261,10 @@ export function AddLLMConnectorModal({ onClick={handleTestConnection} disabled={isTestingConnection || !provider || !apiKey} > - {isTestingConnection && ( - - )} + {isTestingConnection && } Test Connection - diff --git a/frontend/components/AddReactionButton.tsx b/frontend/components/AddReactionButton.tsx index 97d02ee..7e64cca 100644 --- a/frontend/components/AddReactionButton.tsx +++ b/frontend/components/AddReactionButton.tsx @@ -12,10 +12,7 @@ interface AddReactionButtonProps { disabled?: boolean; } -export function AddReactionButton({ - onSelectEmoji, - disabled = false, -}: AddReactionButtonProps) { +export function AddReactionButton({ onSelectEmoji, disabled = false }: AddReactionButtonProps) { const [isPickerOpen, setIsPickerOpen] = useState(false); const [isSubmitting, setIsSubmitting] = useState(false); const { resolvedTheme } = useTheme(); @@ -45,11 +42,7 @@ export function AddReactionButton({ - + -