diff --git a/ai-platform-backend/alembic/versions/b0c1d2e3f4a5_merge_heads.py b/ai-platform-backend/alembic/versions/b0c1d2e3f4a5_merge_heads.py new file mode 100644 index 0000000..4544daa --- /dev/null +++ b/ai-platform-backend/alembic/versions/b0c1d2e3f4a5_merge_heads.py @@ -0,0 +1,23 @@ +"""Merge heads: trello fields + subscription status fields + +Revision ID: b0c1d2e3f4a5 +Revises: a9b0c1d2e3f4, b2c3d4e5f6a7 +Create Date: 2026-04-09 00:00:00.000000 + +""" + +from collections.abc import Sequence + +# revision identifiers, used by Alembic. +revision: str = "b0c1d2e3f4a5" +down_revision = ("a9b0c1d2e3f4", "b2c3d4e5f6a7") +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/alembic/versions/a9b0c1d2e3f4_add_trello_fields_to_user_tokens.py b/alembic/versions/a9b0c1d2e3f4_add_trello_fields_to_user_tokens.py new file mode 100644 index 0000000..4eccafc --- /dev/null +++ b/alembic/versions/a9b0c1d2e3f4_add_trello_fields_to_user_tokens.py @@ -0,0 +1,31 @@ +"""Add trello_workspace_id and trello_board_id to user_tokens + +Revision ID: a9b0c1d2e3f4 +Revises: f8a9b0c1d2e3 +Create Date: 2026-04-09 00:00:00.000000 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "a9b0c1d2e3f4" +down_revision: str | Sequence[str] | None = "f8a9b0c1d2e3" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Add Trello workspace and board ID columns to user_tokens table.""" + op.add_column("user_tokens", sa.Column("trello_workspace_id", sa.String(), nullable=True)) + op.add_column("user_tokens", sa.Column("trello_board_id", sa.String(), nullable=True)) + + +def downgrade() -> None: + """Remove Trello columns from user_tokens table.""" + op.drop_column("user_tokens", "trello_board_id") + op.drop_column("user_tokens", "trello_workspace_id") diff --git a/alembic/versions/b0c1d2e3f4a5_merge_heads.py b/alembic/versions/b0c1d2e3f4a5_merge_heads.py new file mode 100644 index 0000000..4544daa --- /dev/null +++ b/alembic/versions/b0c1d2e3f4a5_merge_heads.py @@ -0,0 +1,23 @@ +"""Merge heads: trello fields + subscription status fields + +Revision ID: b0c1d2e3f4a5 +Revises: a9b0c1d2e3f4, b2c3d4e5f6a7 +Create Date: 2026-04-09 00:00:00.000000 + +""" + +from collections.abc import Sequence + +# revision identifiers, used by Alembic. 
+revision: str = "b0c1d2e3f4a5" +down_revision = ("a9b0c1d2e3f4", "b2c3d4e5f6a7") +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/apply_github_migration.py b/apply_github_migration.py deleted file mode 100644 index 4d5a64a..0000000 --- a/apply_github_migration.py +++ /dev/null @@ -1,79 +0,0 @@ -import os - -from dotenv import load_dotenv -from sqlalchemy import create_engine, text - -# Load environment variables -load_dotenv() - -DATABASE_URL = os.getenv("DATABASE_URL") -if not DATABASE_URL: - print("DATABASE_URL not found in .env") - exit(1) - -# Ensure the URL is compatible with SQLAlchemy -if DATABASE_URL.startswith("postgres://"): - DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://", 1) - -engine = create_engine(DATABASE_URL) - - -def apply_migrations(): - with engine.connect() as conn: - print("Checking for missing columns in 'users' table...") - - # Check if github_access_token exists - result = conn.execute( - text(""" - SELECT column_name - FROM information_schema.columns - WHERE table_name='users' AND column_name='github_access_token'; - """) - ).fetchone() - - if not result: - print("Adding 'github_access_token' column to 'users' table...") - conn.execute(text("ALTER TABLE users ADD COLUMN github_access_token VARCHAR;")) - conn.commit() - print("Successfully added 'github_access_token'.") - else: - print("'github_access_token' column already exists.") - - print("Checking for 'github_webhooks' table...") - result = conn.execute( - text(""" - SELECT table_name - FROM information_schema.tables - WHERE table_name='github_webhooks'; - """) - ).fetchone() - - if not result: - print("Creating 'github_webhooks' table...") - conn.execute( - text(""" - CREATE TABLE github_webhooks ( - id SERIAL PRIMARY KEY, - user_id INTEGER REFERENCES users(id), - repo_id VARCHAR, - repo_full_name VARCHAR, - webhook_id VARCHAR, - 
is_active BOOLEAN DEFAULT TRUE, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - CREATE INDEX ix_github_webhooks_id ON github_webhooks (id); - CREATE INDEX ix_github_webhooks_user_id ON github_webhooks (user_id); - """) - ) - conn.commit() - print("Successfully created 'github_webhooks' table.") - else: - print("'github_webhooks' table already exists.") - - -if __name__ == "__main__": - try: - apply_migrations() - print("Migrations applied successfully!") - except Exception as e: - print(f"Error applying migrations: {e}") diff --git a/apps/agents/agent_server/config.py b/apps/agents/agent_server/config.py deleted file mode 100644 index 19c452b..0000000 --- a/apps/agents/agent_server/config.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -from pathlib import Path -from typing import Optional - -from dotenv import load_dotenv - -from apps.agents.agent_server.src.common.jira_client import JiraClient -from apps.agents.agent_server.src.server.context import current_user_token - -# Load .env from project root -current_file = Path(__file__).resolve() -project_root = current_file.parent.parent.parent.parent -env_path = project_root / ".env" - -if env_path.exists(): - load_dotenv(env_path) -else: - load_dotenv() - - -def get_jira( - access_token: str | None = None, - refresh_token: str | None = None, - cloud_id: str | None = None, -): - """ - Creates a JiraClient instance with user-specific OAuth tokens. - - If tokens are not provided explicitly, attempts to retrieve them from - the current request context (set via middleware). 
- - Args: - access_token: User's Jira access token - refresh_token: User's Jira refresh token - cloud_id: Optional cloud_id (will auto-detect if not provided) - - Returns: - JiraClient: Authenticated Jira client for this user - """ - client_id = os.getenv("JIRA_CLIENT_ID") - client_secret = os.getenv("JIRA_CLIENT_SECRET") - - if not client_id or not client_secret: - raise ValueError( - "Server configuration error: JIRA_CLIENT_ID/SECRET missing from environment variables." - ) - - # Try to get from context if not provided - if not access_token: - ctx_token = current_user_token.get() - if ctx_token: - access_token = ctx_token.get("access_token") - if not refresh_token: - refresh_token = ctx_token.get("refresh_token") - if not cloud_id: - cloud_id = ctx_token.get("cloud_id") - - if not access_token: - # Fallback for headless/startup checks (might fail if truly no token) - # But we don't want to crash on import - pass - - # Create JiraClient with user tokens - # Callback to persist refreshed tokens. - # libs/common/jira_client.py calls this without await, so we expose a sync wrapper - # that schedules the async DB work onto the running event loop. - async def _do_save_token(new_tokens, email: str): - try: - import asyncio - - from sqlalchemy import select - - from libs.common.db_utils import managed_async_db_session - from libs.common.models import User, UserToken - - async with managed_async_db_session() as db: - result = await db.execute(select(User).where(User.email == email)) - user = result.scalar_one_or_none() - if not user: - return - - if new_tokens is None: - # TOKEN INVALIDATION SIGNAL - print( - f"DEBUG: Token invalidation signal received for {email}. Deleting Jira connection." 
- ) - if user.jira_token: - await db.delete(user.jira_token) - await db.commit() - return - - if user.jira_token: - user.jira_token.access_token = new_tokens["access_token"] - user.jira_token.refresh_token = new_tokens["refresh_token"] - if new_tokens.get("cloud_id"): - user.jira_token.cloud_id = new_tokens.get("cloud_id") - else: - token = UserToken( - user_id=user.id, - access_token=new_tokens["access_token"], - refresh_token=new_tokens["refresh_token"], - cloud_id=new_tokens.get("cloud_id"), - ) - db.add(token) - await db.commit() - except Exception as e: - print(f"Error saving tokens: {e}") - - def save_token_callback(new_tokens): - # We need the JWT to authorize the save - ctx = current_user_token.get() - if not ctx or "email" not in ctx: - print("Warning: No user email found in context, cannot save refreshed tokens.") - return - - import asyncio - - email = ctx["email"] - try: - loop = asyncio.get_running_loop() - loop.create_task(_do_save_token(new_tokens, email)) - except RuntimeError: - # No running loop — fall back to asyncio.run (should not happen in production) - asyncio.run(_do_save_token(new_tokens, email)) - - return JiraClient( - client_id=client_id, - client_secret=client_secret, - site_id=cloud_id, - access_token=access_token, - refresh_token=refresh_token, - token_saver=save_token_callback, - ) diff --git a/apps/agents/agent_server/main.py b/apps/agents/agent_server/main.py index 36fd344..8b070d7 100644 --- a/apps/agents/agent_server/main.py +++ b/apps/agents/agent_server/main.py @@ -2,7 +2,7 @@ import sys # Import mcp instance -from apps.agents.agent_server.src.server.tools import mcp +from apps.agents.agent_server.tools import mcp # Import submodules to register tools diff --git a/apps/agents/agent_server/response_utils.py b/apps/agents/agent_server/response_utils.py index ab2aef0..fb4d7c0 100644 --- a/apps/agents/agent_server/response_utils.py +++ b/apps/agents/agent_server/response_utils.py @@ -4,7 +4,7 @@ import json import re -from 
apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( +from apps.agents.agent_server.tools.jira_task_proposal_tools import ( get_proposed_delete_tasks, get_proposed_tasks, ) diff --git a/apps/agents/agent_server/routes/chat.py b/apps/agents/agent_server/routes/chat.py index d7080ca..38b1312 100644 --- a/apps/agents/agent_server/routes/chat.py +++ b/apps/agents/agent_server/routes/chat.py @@ -1,7 +1,6 @@ """Chat endpoint: SSE streaming, HITL approval loop, and Agent mode orchestration.""" import asyncio -import inspect import json import logging import traceback @@ -20,13 +19,19 @@ is_valid_approval, ) from apps.agents.agent_server.schemas import ChatRequest +from apps.agents.agent_server.services.agent_service import run_agent_mode +from apps.agents.agent_server.services.tool_executor import ( + SENSITIVE_TOOLS, + build_decision_request, + execute_tool_calls, +) from apps.agents.agent_server.session_manager import session_manager from apps.agents.agent_server.src.cli.config import FAST_MODEL, SMART_MODEL from apps.agents.agent_server.src.cli.modes import load_prompt from apps.agents.agent_server.src.common.redis_client import RedisManager from apps.agents.agent_server.src.server.config import get_jira from apps.agents.agent_server.src.server.context_builder import get_genius_context -from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( +from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_delete_tasks, clear_proposed_tasks, ) @@ -36,215 +41,78 @@ from libs.common.subscription_limits import FEATURE_CHAT_LIMIT, get_tier_limits logger = logging.getLogger(__name__) - router = APIRouter() -SENSITIVE_TOOLS = { - "create_issue", - "update_issue", - "delete_issue", - "bulk_delete_issues", - "transition_issue", - "bulk_transition_issues", - "bulk_update_issues", - "add_comment", - "log_work", -} - - -async def run_agent_mode(session_id: str, user_msg: str, user_id: int | None = None): - 
"""Executes the Agent Mode (Planner -> Executor) logic.""" - try: - effective_project = await _get_chat_project_key(user_id, None) - project_context = "" - try: - jira_client = get_jira() - project_context = await get_genius_context(jira_client, effective_project) - except Exception as e: - logger.warning(f"[Agent] Failed to fetch genius context: {e}") - - planner_sys = load_prompt("planner.md") - sys_prompt = session_manager.get_system_prompt(project_context) - planner_sys = sys_prompt + "\n\nPLANNER SPECIFIC:\n" + planner_sys - - user_ctx = await get_user_context() - user_name = user_ctx["name"] - user_language = user_ctx["language"] - - context_injection = f"\n\nSYSTEM CONTEXT:\n- User Name: {user_name}\n- User Language Preference: {user_language}\n- Current Project: {effective_project}\n- Current Date: {date.today().isoformat()}" - language_instruction = f"\n\nCRITICAL INSTRUCTION: You MUST reply in {user_language}. Do not change language based on user input language. Your response must be exclusively in {user_language}." - - planner_sys += context_injection + language_instruction - - plan_response = await session_manager.client.aio.models.generate_content( - model=SMART_MODEL, - contents=[types.Content(role="user", parts=[types.Part(text=user_msg)])], - config=types.GenerateContentConfig(system_instruction=planner_sys), - ) - - if ( - not plan_response.candidates - or not plan_response.candidates[0].content - or not plan_response.candidates[0].content.parts - ): - return build_chat_response( - "Planning failed. Model returned no content.", - session_id, - msg_id=None, - ) - plan_text = plan_response.candidates[0].content.parts[0].text - - try: - clean_json = plan_text.replace("```json", "").replace("```", "").strip() - steps = json.loads(clean_json) - if not isinstance(steps, list): - raise ValueError("Not a list") - except Exception: - return build_chat_response( - f"Planning failed. 
Response: {plan_text}", - session_id, - msg_id=None, - ) - - executor_sys = session_manager.get_system_prompt(project_context) - executor_sys += context_injection + language_instruction - agent_session_id = f"{session_id}_agent" - - if not session_manager.get_history(agent_session_id): - await session_manager.update_history( - agent_session_id, - types.Content(role="user", parts=[types.Part(text=f"Objective: {user_msg}")]), - ) - - accumulated_output = [f"**Objective:** {user_msg}\n\n**Plan:**\n"] - for i, step in enumerate(steps): - accumulated_output.append(f"{i + 1}. {step}") - accumulated_output.append("\n---\n") - - for i, step in enumerate(steps): - logger.info(f"[Agent] Executing Step {i + 1}/{len(steps)}: {step}") - await session_manager.update_history( - agent_session_id, - types.Content( - role="user", - parts=[types.Part(text=f"Execute Step {i + 1}: {step}")], - ), - ) - - step_output = "" - for turn in range(3): - agent_history = session_manager.get_history(agent_session_id) - try: - resp = await session_manager.client.aio.models.generate_content( - model=FAST_MODEL, - contents=agent_history, - config=types.GenerateContentConfig( - system_instruction=executor_sys, - tools=session_manager.tools, - ), - ) - except Exception as e: - logger.error(f"Step {i + 1} turn {turn + 1} generation failed: {e}") - step_output = f"Execution error: {str(e)}" - break - if not resp.candidates: - break - candidate = resp.candidates[0] - if not candidate.content or not candidate.content.parts: - break +def _build_system_instruction( + project_context: str, + mode: str, + user_ctx: dict, + effective_project: str, +) -> tuple[str, str]: + """Build system instruction and model_id for the given mode.""" + system_instruction = session_manager.get_system_prompt(project_context) + model_id = FAST_MODEL - parts = candidate.content.parts - function_calls = [p.function_call for p in parts if p.function_call] + if mode == "Expert": + system_instruction += "\n\nEXPERT SPECIALIST:\n" + 
load_prompt("expert.md") + model_id = SMART_MODEL + elif mode == "Architect": + system_instruction += "\n\nARCHITECT SPECIALIST:\n" + load_prompt("architect.md") + model_id = SMART_MODEL - if function_calls: - await session_manager.update_history( - agent_session_id, - types.Content(role="model", parts=parts), - ) - outputs = [] - for fc in function_calls: - func = next( - (t for t in session_manager.tools if t.__name__ == fc.name), - None, - ) - if func: - try: - sig = inspect.signature(func) - valid_args = { - k: v for k, v in fc.args.items() if k in sig.parameters - } - if len(valid_args) < len(fc.args): - invalid_keys = set(fc.args.keys()) - set(valid_args.keys()) - logger.warning( - f"Filtered out invalid parameters for {fc.name}: {invalid_keys}" - ) - - if inspect.iscoroutinefunction(func): - res = await func(**valid_args) - else: - res = func(**valid_args) - - outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"result": str(res)} - ) - ) - ) - except Exception as e: - logger.error(f"Tool {fc.name} failed: {e}") - outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"error": str(e)} - ) - ) - ) - else: - outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, - response={"error": f"Tool {fc.name} not found"}, - ) - ) - ) + bot_name = user_ctx.get("bot_name", "Kwillo Agent") + user_name = user_ctx.get("name", "User") + user_language = user_ctx.get("language", "EN") - if outputs: - await session_manager.update_history( - agent_session_id, - types.Content(role="user", parts=outputs), - ) - continue + system_instruction += ( + f"\n\nIDENTITY: You are {bot_name}, an AI assistant." + f"\n\nSYSTEM CONTEXT:\n- User Name: {user_name}" + f"\n- User Language: {user_language}" + f"\n- Project: {effective_project}" + f"\n- Current Date: {date.today().isoformat()}" + f"\n\nCRITICAL: Reply in {user_language}." 
+ ) + return system_instruction, model_id - text = "".join([p.text for p in parts if p.text]) - if text: - step_output = text - await session_manager.update_history( - agent_session_id, - types.Content(role="model", parts=parts), - ) - break - accumulated_output.append(f"### Step {i + 1}: {step}\n{step_output}\n") +async def _check_quota(user_id: int, tier: str, session_id: str) -> dict | None: + """Check monthly chat quota. Returns error response dict if exceeded, else None.""" + import datetime - final_response = "\n".join(accumulated_output) + limits = get_tier_limits(tier) + max_msgs = limits.get(FEATURE_CHAT_LIMIT, 30) + if max_msgs == -1: + return None - await session_manager.update_history( - session_id, - types.Content(role="user", parts=[types.Part(text=user_msg)]), - ) - await session_manager.update_history( - session_id, - types.Content(role="model", parts=[types.Part(text=final_response)]), - ) + now = datetime.datetime.now() + month_key = f"chat_usage:{user_id}:{now.strftime('%Y-%m')}" + try: + current_usage = await RedisManager.get(month_key) + if int(current_usage or 0) >= max_msgs: + return build_chat_response( + f"You have reached your monthly limit of {max_msgs} messages for the {tier} plan.", + session_id, + None, + ) + await RedisManager.incr(month_key) + await RedisManager.expire(month_key, 60 * 60 * 24 * 30) + except Exception as e: + logger.error(f"Redis quota check failed: {e}") + return None - return build_chat_response(final_response, session_id, msg_id=None) +async def _ensure_session_exists(session_id: str, user_id: int, user_msg: str) -> None: + """Persist chat session to DB if it doesn't exist yet.""" + try: + async with managed_async_db_session() as db: + result = await db.execute(select(ChatSession).where(ChatSession.id == session_id)) + if not result.scalar_one_or_none(): + title = user_msg[:50] + ("..." 
if len(user_msg) > 50 else "") + db.add(ChatSession(id=session_id, user_id=user_id, title=title)) + await db.commit() except Exception as e: - traceback.print_exc() - return {"type": "error", "data": f"Agent Error: {str(e)}"} + logger.error(f"Failed to record session: {e}") async def _chat_stream(request: ChatRequest, user_ctx: dict): @@ -257,54 +125,25 @@ def _sse(data: dict) -> str: user_msg = request.message mode = request.mode user_id = user_ctx.get("id") - user_name = user_ctx.get("name", "User") - user_language = user_ctx.get("language", "EN") tier = user_ctx.get("subscription_tier", "free") - # Agent mode: no streaming if mode == "Agent": result = await run_agent_mode(session_id, user_msg, user_id) yield _sse(result) yield "data: [DONE]\n\n" return - # Quota check if user_id: - import datetime - - limits = get_tier_limits(tier) - max_msgs = limits.get(FEATURE_CHAT_LIMIT, 30) - if max_msgs != -1: - now = datetime.datetime.now() - month_key = f"chat_usage:{user_id}:{now.strftime('%Y-%m')}" - try: - current_usage = await RedisManager.get(month_key) - if int(current_usage or 0) >= max_msgs: - text = f"You have reached your monthly limit of {max_msgs} messages for the {tier} plan." - yield _sse(build_chat_response(text, session_id, None)) - yield "data: [DONE]\n\n" - return - await RedisManager.incr(month_key) - await RedisManager.expire(month_key, 60 * 60 * 24 * 30) - except Exception as e: - logger.error(f"Redis quota check failed: {e}") - - # Session persistence - if user_id: - try: - async with managed_async_db_session() as db: - result = await db.execute(select(ChatSession).where(ChatSession.id == session_id)) - if not result.scalar_one_or_none(): - title = user_msg[:50] + ("..." 
if len(user_msg) > 50 else "") - db.add(ChatSession(id=session_id, user_id=user_id, title=title)) - await db.commit() - except Exception as e: - logger.error(f"Failed to record session: {e}") + quota_err = await _check_quota(user_id, tier, session_id) + if quota_err: + yield _sse(quota_err) + yield "data: [DONE]\n\n" + return + await _ensure_session_exists(session_id, user_id, user_msg) clear_proposed_tasks() clear_proposed_delete_tasks() - model_id = FAST_MODEL effective_project = await _get_chat_project_key(user_id, request.project_key) project_context = "" try: @@ -312,27 +151,15 @@ def _sse(data: dict) -> str: except Exception as e: logger.warning(f"Failed to fetch genius context: {e}") - system_instruction = session_manager.get_system_prompt(project_context) - if mode == "Expert": - system_instruction += "\n\nEXPERT SPECIALIST:\n" + load_prompt("expert.md") - model_id = SMART_MODEL - elif mode == "Architect": - system_instruction += "\n\nARCHITECT SPECIALIST:\n" + load_prompt("architect.md") - model_id = SMART_MODEL - - bot_name = user_ctx.get("bot_name", "Kwillo Agent") - system_instruction += ( - f"\n\nIDENTITY: You are {bot_name}, an AI assistant." - f"\n\nSYSTEM CONTEXT:\n- User Name: {user_name}\n- User Language: {user_language}\n- Project: {effective_project}\n- Current Date: {date.today().isoformat()}" - f"\n\nCRITICAL: Reply in {user_language}." 
+ system_instruction, model_id = _build_system_instruction( + project_context, mode, user_ctx, effective_project ) # HITL pending tool check — fall back to non-streaming for this rare path history = session_manager.get_history(session_id) if history and history[-1].role == "model": last_parts = history[-1].parts - pending_calls = [p.function_call for p in last_parts if p.function_call] - if pending_calls: + if any(p.function_call for p in last_parts): result = await chat_endpoint(request, user_ctx) if isinstance(result, dict): yield _sse(result) @@ -359,7 +186,6 @@ def _sse(data: dict) -> str: try: max_turns = 15 _consecutive_empty = 0 - _last_tool_names: list[str] = [] for _turn_count in range(1, max_turns + 1): current_history = session_manager.get_history(session_id) full_text = "" @@ -383,45 +209,35 @@ def _sse(data: dict) -> str: yield _sse({"type": "token", "token": chunk.text}) if not function_calls_found: - if not full_text.strip(): - # Mid-task empty response: nudge the model to continue instead of stopping. - if _turn_count > 1: - _consecutive_empty += 1 - if _consecutive_empty >= 3: - # Model is stuck in a loop — abort gracefully. - logger.warning( - f"Model stuck after {_consecutive_empty} consecutive empty responses, aborting." - ) - full_text = "Action completed successfully." - msg_id = await session_manager.update_history( - session_id, - types.Content(role="model", parts=[types.Part(text=full_text)]), - ) - yield _sse(build_chat_response(full_text, session_id, msg_id)) - yield "data: [DONE]\n\n" - return - - # After first empty: gentle nudge. After second: directive nudge. - if _consecutive_empty == 1: - nudge_text = ( - "Continue. Use the tool results above to complete the task." - ) - else: - nudge_text = ( - "IMPORTANT: You already have all the lookup results you need. " - "Do NOT call any lookup tools again (search_users, list_sprints, etc.). " - "Proceed immediately to the final action (create_issue, update_issue, etc.) 
" - "using the data already retrieved." - ) - logger.info( - f"Empty response mid-task (turn={_turn_count}, consecutive={_consecutive_empty}), injecting nudge." + if not full_text.strip() and _turn_count > 1: + _consecutive_empty += 1 + if _consecutive_empty >= 3: + logger.warning("Model stuck, aborting after 3 empty responses.") + full_text = "Action completed successfully." + msg_id = await session_manager.update_history( + session_id, + types.Content(role="model", parts=[types.Part(text=full_text)]), ) - # Only update in-memory — nudge messages must NOT be persisted to DB - # or they'll appear as user messages in chat history on reload. - session_manager.histories.setdefault(session_id, []).append( - types.Content(role="user", parts=[types.Part(text=nudge_text)]) + yield _sse(build_chat_response(full_text, session_id, msg_id)) + yield "data: [DONE]\n\n" + return + + nudge_text = ( + "Continue. Use the tool results above to complete the task." + if _consecutive_empty == 1 + else ( + "IMPORTANT: You already have all the lookup results you need. " + "Do NOT call any lookup tools again. " + "Proceed immediately to the final action using the data already retrieved." ) - continue + ) + logger.info(f"Empty response mid-task (turn={_turn_count}), injecting nudge.") + session_manager.histories.setdefault(session_id, []).append( + types.Content(role="user", parts=[types.Part(text=nudge_text)]) + ) + continue + + if not full_text.strip(): full_text = "Action completed successfully." 
_consecutive_empty = 0 clean_text, chart_configs = extract_chart_configs(full_text) @@ -442,59 +258,16 @@ def _sse(data: dict) -> str: pending_sensitive = [fc for fc in function_calls_found if fc.name in SENSITIVE_TOOLS] if pending_sensitive: - fc = pending_sensitive[0] - args = dict(fc.args) if fc.args else {} - yield _sse( - { - "type": "decision_request", - "session_id": session_id, - "data": { - "title": "Approval Required", - "reason": f"Agent wants to call: {fc.name}", - "policy": "confirm", - "actions": [{"label": "Approve", "value": "yes"}], - "pending_action": {"type": "tool_call", "name": fc.name, "args": args}, - "pending_actions": [ - {"type": "tool_call", "name": f.name, "args": dict(f.args)} - for f in pending_sensitive - ], - }, - } - ) + decision = build_decision_request(session_id, pending_sensitive) + # Override type/reason for stream variant (simpler format) + decision["data"]["reason"] = f"Agent wants to call: {pending_sensitive[0].name}" + decision["data"]["policy"] = "confirm" + yield _sse(decision) yield "data: [DONE]\n\n" return - tool_outputs = [] - for fc in function_calls_found: - func = next((t for t in session_manager.tools if t.__name__ == fc.name), None) - if func: - try: - sig = inspect.signature(func) - valid_args = {k: v for k, v in fc.args.items() if k in sig.parameters} - res = ( - await func(**valid_args) - if inspect.iscoroutinefunction(func) - else func(**valid_args) - ) - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"result": str(res)} - ) - ) - ) - except Exception as e: - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"error": str(e)} - ) - ) - ) - - # Tool calls were made — reset the empty-response counter + tool_outputs = await execute_tool_calls(function_calls_found, session_manager.tools) _consecutive_empty = 0 - if tool_outputs: await session_manager.update_history( session_id, 
types.Content(role="user", parts=tool_outputs) @@ -521,10 +294,7 @@ async def chat_endpoint(request: ChatRequest, user_ctx: dict = Depends(get_user_ session_id = request.session_id or str(uuid.uuid4()) user_msg = request.message mode = request.mode - user_id = user_ctx.get("id") - user_name = user_ctx.get("name", "User") - user_language = user_ctx.get("language", "EN") tier = user_ctx.get("subscription_tier", "free") if not user_id: @@ -533,87 +303,32 @@ async def chat_endpoint(request: ChatRequest, user_ctx: dict = Depends(get_user_ detail="Authentication required. Please log in to use the chat.", ) - logger.info(f"Chat request for session {session_id}, user {user_id} ({user_name})") - - # Quota check - if user_id: - limits = get_tier_limits(tier) - max_msgs = limits.get(FEATURE_CHAT_LIMIT, 30) + logger.info(f"Chat request for session {session_id}, user {user_id}") - if max_msgs != -1: - import datetime + quota_err = await _check_quota(user_id, tier, session_id) + if quota_err: + return quota_err - now = datetime.datetime.now() - month_key = f"chat_usage:{user_id}:{now.strftime('%Y-%m')}" - - try: - current_usage = await RedisManager.get(month_key) - current_count = int(current_usage) if current_usage else 0 - - if current_count >= max_msgs: - return build_chat_response( - f"You have reached your monthly limit of {max_msgs} messages for the {tier} plan. Please upgrade to continue.", - session_id, - msg_id=None, - ) - - await RedisManager.incr(month_key) - await RedisManager.expire(month_key, 60 * 60 * 24 * 30) - except Exception as e: - logger.error(f"Redis quota check failed: {e}") - - # Session persistence - if user_id: - - async def _save_sess(): - try: - async with managed_async_db_session() as db: - result = await db.execute( - select(ChatSession).where(ChatSession.id == session_id) - ) - session_record = result.scalar_one_or_none() - if not session_record: - title = user_msg[:50] + ("..." 
if len(user_msg) > 50 else "") - db.add(ChatSession(id=session_id, user_id=user_id, title=title)) - await db.commit() - except Exception as e: - logger.error(f"Failed to record session in DB: {e}") - - await _save_sess() + await _ensure_session_exists(session_id, user_id, user_msg) clear_proposed_tasks() clear_proposed_delete_tasks() - model_id = FAST_MODEL - effective_project = await _get_chat_project_key(user_id, request.project_key) + if mode == "Agent": + return await run_agent_mode(session_id, user_msg, user_id) + effective_project = await _get_chat_project_key(user_id, request.project_key) project_context = "" try: - jira_client = get_jira() - project_context = await get_genius_context(jira_client, effective_project) + project_context = await get_genius_context(get_jira(), effective_project) except Exception as e: logger.warning(f"Failed to fetch genius context: {e}") - system_instruction = session_manager.get_system_prompt(project_context) - - if mode == "Expert": - system_instruction += "\n\nEXPERT SPECIALIST:\n" + load_prompt("expert.md") - model_id = SMART_MODEL - elif mode == "Architect": - system_instruction += "\n\nARCHITECT SPECIALIST:\n" + load_prompt("architect.md") - model_id = SMART_MODEL - elif mode == "Agent": - return await run_agent_mode(session_id, user_msg, user_id) - - bot_name = user_ctx.get("bot_name", "Kwillo Agent") - identity_instruction = f"\n\nIDENTITY: You are {bot_name}, an AI assistant." - context_injection = f"\n\nSYSTEM CONTEXT:\n- User Name: {user_name}\n- User Language Preference: {user_language}\n- Current Project: {effective_project}\n- Current Date: {date.today().isoformat()}" - language_instruction = f"\n\nCRITICAL INSTRUCTION: You MUST reply in {user_language}." 
- - system_instruction += identity_instruction + context_injection + language_instruction + system_instruction, model_id = _build_system_instruction( + project_context, mode, user_ctx, effective_project + ) # HITL: Check for pending function calls in history - # Load from DB first to handle cases where in-memory cache is empty (e.g. after restart) history = session_manager.get_history(session_id) if not history: history = await session_manager.load_history_from_db(session_id, user_id) @@ -626,69 +341,17 @@ async def _save_sess(): if pending_calls: logger.info(f"Found {len(pending_calls)} pending tool calls for session {session_id}") is_approval = is_valid_approval(user_msg) - tool_outputs = [] if is_approval: logger.info("User approved pending action. Executing tools...") - for fc in pending_calls: - func = next((t for t in session_manager.tools if t.__name__ == fc.name), None) - if func: - try: - args = dict(fc.args) - if request.edited_args and len(pending_calls) == 1: - args.update(request.edited_args) - logger.info( - f"Using edited args for {fc.name}: {request.edited_args}" - ) - sig = inspect.signature(func) - valid_args = {k: v for k, v in args.items() if k in sig.parameters} - if len(valid_args) < len(args): - invalid_keys = set(args.keys()) - set(valid_args.keys()) - logger.warning( - f"Filtered out invalid parameters for {fc.name}: {invalid_keys}" - ) - res = ( - await func(**valid_args) - if inspect.iscoroutinefunction(func) - else func(**valid_args) - ) - - response_data = {"result": str(res)} - if isinstance(res, dict) and any(k.lower() == "error" for k in res): - response_data = {"error": str(res.get("Error", res.get("error")))} - - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response=response_data - ) - ) - ) - except Exception as e: - logger.error(f"Error executing tool {fc.name}: {e}") - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, 
response={"error": str(e)} - ) - ) - ) - else: - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, - response={"error": f"Tool {fc.name} not found"}, - ) - ) - ) - + edited_args = request.edited_args if len(pending_calls) == 1 else None + tool_outputs = await execute_tool_calls( + pending_calls, session_manager.tools, edited_args + ) if tool_outputs: await session_manager.update_history( - session_id, - types.Content(role="user", parts=tool_outputs), + session_id, types.Content(role="user", parts=tool_outputs) ) - response_texts = [] for t_out in tool_outputs: f_res = t_out.function_response @@ -698,7 +361,6 @@ async def _save_sess(): ) else: response_texts.append(f"Successfully executed '{f_res.name}'.") - final_text = "\n".join(response_texts) msg_id = await session_manager.update_history( session_id, @@ -707,43 +369,41 @@ async def _save_sess(): return build_chat_response(final_text, session_id, msg_id) else: logger.info("User rejected/interrupted tool execution.") - for fc in pending_calls: - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, - response={"error": f"User rejected action. Input was: {user_msg}"}, - ) + rejection_outputs = [ + types.Part( + function_response=types.FunctionResponse( + name=fc.name, + response={"error": f"User rejected action. 
Input was: {user_msg}"}, ) ) + for fc in pending_calls + ] await session_manager.update_history( - session_id, - types.Content(role="user", parts=tool_outputs), + session_id, types.Content(role="user", parts=rejection_outputs) ) pending_tool_execution = True if not pending_tool_execution: parts = [types.Part(text=user_msg)] - if request.attachments: attachment_parts = await _handle_chat_attachments( request.attachments, session_manager, session_id ) parts.extend(attachment_parts) - await session_manager.update_history(session_id, types.Content(role="user", parts=parts)) try: max_turns = 15 turn_count = 0 consecutive_empty = 0 - max_retries = 5 + chart_configs = [] while turn_count < max_turns: turn_count += 1 current_history = session_manager.get_history(session_id) - for attempt in range(max_retries + 1): + response = None + for attempt in range(6): try: response = await session_manager.client.aio.models.generate_content( model=model_id, @@ -758,12 +418,9 @@ async def _save_sess(): ), ) except Exception as e: - logger.error(f"Generation failed: {e}") - if attempt == max_retries: - return { - "type": "error", - "data": f"Model generation error: {str(e)}", - } + logger.error(f"Generation failed (attempt {attempt + 1}): {e}") + if attempt == 5: + return {"type": "error", "data": f"Model generation error: {str(e)}"} continue if not response.candidates: @@ -776,9 +433,9 @@ async def _save_sess(): if candidate.finish_reason and candidate.finish_reason != "STOP": if "MALFORMED_FUNCTION_CALL" in str(candidate.finish_reason): - if attempt < max_retries: + if attempt < 5: logger.warning( - f"Malformed function call detected (Attempt {attempt + 1}/{max_retries}). Retrying..." + f"Malformed function call (attempt {attempt + 1}), retrying..." ) await session_manager.update_history( session_id, @@ -786,24 +443,23 @@ async def _save_sess(): role="user", parts=[ types.Part( - text="SYSTEM ERROR: Your previous function call was MALFORMED. Common causes:\n" - "1. 
Wrong argument types (e.g., string instead of list/object)\n" - "2. Missing required arguments\n" - "3. Invalid JSON in arguments\n" - "Please review the tool schema carefully and try again with valid arguments." + text=( + "SYSTEM ERROR: Your previous function call was MALFORMED. " + "Common causes:\n1. Wrong argument types\n2. Missing required arguments\n" + "3. Invalid JSON in arguments\n" + "Please review the tool schema carefully and retry with valid arguments." + ) ) ], ), ) current_history = session_manager.get_history(session_id) continue - else: - return build_chat_response( - f"Model failed after retries. Reason: {candidate.finish_reason}. Content: {candidate.content}", - session_id, - msg_id=None, - ) - + return build_chat_response( + f"Model failed after retries. Reason: {candidate.finish_reason}", + session_id, + msg_id=None, + ) return build_chat_response( f"Model stopped generating. Reason: {candidate.finish_reason}", session_id, @@ -812,23 +468,18 @@ async def _save_sess(): if candidate.finish_reason == "STOP": if not candidate.content or not candidate.content.parts: - if attempt < max_retries: - import time - + if attempt < 5: wait_time = 2 * (attempt + 1) logger.warning( - f"Empty content or STOP reason without content (Attempt {attempt + 1}/{max_retries}). Retrying in {wait_time}s..." + f"Empty STOP content (attempt {attempt + 1}), retrying in {wait_time}s..." ) await asyncio.sleep(wait_time) continue - else: - logger.error("Model returned empty content after retries.") - return build_chat_response( - "I'm having trouble generating a response right now. Please try again.", - session_id, - msg_id=None, - ) - + return build_chat_response( + "I'm having trouble generating a response right now. Please try again.", + session_id, + msg_id=None, + ) break if not response: @@ -837,19 +488,12 @@ async def _save_sess(): candidate = response.candidates[0] if not candidate.content or not candidate.content.parts: - logger.warning(f"Empty content parts. 
Finish reason: {candidate.finish_reason}") - print(f"DEBUG: Empty Content Candidate: {candidate}") - if candidate.safety_ratings: - logger.warning(f"Safety ratings: {candidate.safety_ratings}") - - if turn_count > 1 or pending_tool_execution: - text = "Action completed successfully." - else: - text = "I'm having trouble processing your request. Please try rephrasing or ask again." - - logger.info( - f"Empty content fallback: turn={turn_count}, pending={pending_tool_execution}, text='{text}'" + text = ( + "Action completed successfully." + if (turn_count > 1 or pending_tool_execution) + else "I'm having trouble processing your request. Please try rephrasing or ask again." ) + logger.info(f"Empty content fallback: turn={turn_count}, text='{text}'") msg_id = await session_manager.update_history( session_id, types.Content(role="model", parts=[types.Part(text=text)]), @@ -861,67 +505,44 @@ async def _save_sess(): if not function_calls: text = "".join([p.text for p in parts if p.text]) - print( - f"DEBUG: No FC. Text='{text}', pending={pending_tool_execution}, turn={turn_count}", - flush=True, - ) - - if not text.strip(): - # Mid-task empty response: model called a tool but returned nothing after. - # Nudge it to continue instead of treating as done. - if turn_count > 1: - consecutive_empty += 1 - if consecutive_empty >= 3: - logger.warning( - f"Model stuck after {consecutive_empty} consecutive empty responses, aborting." - ) - text = "Action completed successfully." - msg_id = await session_manager.update_history( - session_id, - types.Content(role="model", parts=[types.Part(text=text)]), - ) - return build_chat_response(text, session_id, msg_id) - if consecutive_empty == 1: - nudge_text = ( - "Continue. Use the tool results above to complete the task." - ) - else: - nudge_text = ( - "IMPORTANT: You already have all the lookup results you need. " - "Do NOT call any lookup tools again (search_users, list_sprints, etc.). 
" - "Proceed immediately to the final action (create_issue, update_issue, etc.) " - "using the data already retrieved." - ) - logger.info( - f"Empty response mid-task (turn={turn_count}, consecutive={consecutive_empty}), injecting nudge." - ) - await session_manager.update_history( + if not text.strip() and turn_count > 1: + consecutive_empty += 1 + if consecutive_empty >= 3: + logger.warning("Model stuck, aborting.") + text = "Action completed successfully." + msg_id = await session_manager.update_history( session_id, - types.Content( - role="user", - parts=[types.Part(text=nudge_text)], - ), + types.Content(role="model", parts=[types.Part(text=text)]), + ) + return build_chat_response(text, session_id, msg_id) + + nudge_text = ( + "Continue. Use the tool results above to complete the task." + if consecutive_empty == 1 + else ( + "IMPORTANT: You already have all the lookup results you need. " + "Do NOT call any lookup tools again. " + "Proceed immediately to the final action using the data already retrieved." ) - continue - - print("DEBUG: Triggering fallback!", flush=True) - logger.info("Empty response received. Injecting default success message.") - text = "Action completed successfully." - msg_id = await session_manager.update_history( - session_id, - types.Content(role="model", parts=[types.Part(text=text)]), ) - else: - text, chart_configs = extract_chart_configs(text) - print(f"DEBUG: chart_configs={chart_configs}", flush=True) - msg_id = await session_manager.update_history( + logger.info(f"Empty mid-task (turn={turn_count}), injecting nudge.") + await session_manager.update_history( session_id, - types.Content(role="model", parts=[types.Part(text=text)]), - charts=chart_configs, + types.Content(role="user", parts=[types.Part(text=nudge_text)]), ) + continue + if not text.strip(): + text = "Action completed successfully." 
+ + text, chart_configs = extract_chart_configs(text) consecutive_empty = 0 + msg_id = await session_manager.update_history( + session_id, + types.Content(role="model", parts=[types.Part(text=text)]), + charts=chart_configs, + ) return build_chat_response(text, session_id, msg_id, charts=chart_configs) consecutive_empty = 0 @@ -929,168 +550,20 @@ async def _save_sess(): session_id, types.Content(role="model", parts=parts) ) - SENSITIVE = [ - "create_issue", - "update_issue", - "delete_issue", - "bulk_delete_issues", - "transition_issue", - "bulk_transition_issues", - "bulk_update_issues", - "add_comment", - "log_work", - ] - is_approval = is_valid_approval(user_msg) - pending_sensitive = [ - fc for fc in function_calls if fc.name in SENSITIVE and not is_approval + fc for fc in function_calls if fc.name in SENSITIVE_TOOLS and not is_approval ] - for fc in function_calls: - print(f"DEBUG: AI calling tool -> {fc.name}") - print(f"DEBUG: Checking sensitivity for {fc.name}. Is Approval: {is_approval}") - if pending_sensitive: - tool_names = [fc.name for fc in pending_sensitive] - if len(tool_names) == 1: - title = "Approval Required" - if "create" in tool_names[0]: - title = "Create Jira Task" - elif "update" in tool_names[0]: - title = "Update Jira Task" - elif "delete" in tool_names[0]: - title = "Delete Jira Task" - elif "transition" in tool_names[0]: - title = "Transition Jira Task" - else: - title = "Multiple Actions Require Approval" - - if len(pending_sensitive) == 1: - fc = pending_sensitive[0] - args = dict(fc.args) if fc.args else {} - if fc.name == "create_issue": - summary = args.get("summary", "") - issue_type = args.get("issue_type", "Task") - project_key = args.get("project_key", "") - reason = ( - f"Create a new {issue_type} in {project_key}: {summary}" - if project_key - else f"Create a new {issue_type}: {summary}" - ) - elif fc.name in ("update_issue", "bulk_update_issues"): - issue_key = args.get("issue_key") or (args.get("issue_keys") or [""])[0] - 
reason = f"Update issue {issue_key}" if issue_key else "Update a Jira issue" - elif fc.name == "delete_issue": - issue_key = args.get("issue_key", "") - reason = ( - f"Permanently delete issue {issue_key}" - if issue_key - else "Delete a Jira issue" - ) - elif fc.name in ("transition_issue", "bulk_transition_issues"): - issue_key = args.get("issue_key") or (args.get("issue_keys") or [""])[0] - status = args.get("status", "") - reason = ( - f'Move {issue_key} to "{status}"' - if issue_key and status - else "Transition a Jira issue" - ) - else: - reason = f"Execute: {fc.name}" - else: - action_labels = [] - for pfc in pending_sensitive: - pargs = dict(pfc.args) if pfc.args else {} - if pfc.name == "create_issue": - s = pargs.get("summary", "new issue") - action_labels.append(f'Create "{s}"') - elif pfc.name == "delete_issue": - k = pargs.get("issue_key", "?") - action_labels.append(f"Delete {k}") - elif pfc.name == "bulk_delete_issues": - keys = pargs.get("issue_keys", []) - action_labels.append(f"Delete {', '.join(keys) if keys else 'issues'}") - elif pfc.name in ("update_issue", "bulk_update_issues"): - k = pargs.get("issue_key") or ", ".join(pargs.get("issue_keys", ["?"])) - action_labels.append(f"Update {k}") - elif pfc.name in ("transition_issue", "bulk_transition_issues"): - k = pargs.get("issue_key") or ", ".join(pargs.get("issue_keys", ["?"])) - st = pargs.get("status", "new status") - action_labels.append(f'Move {k} → "{st}"') - else: - action_labels.append(pfc.name) - reason = "\n".join(f"• {a}" for a in action_labels) - - return { - "type": "decision_request", - "session_id": session_id, - "data": { - "title": title, - "policy": "User confirmation needed", - "reason": reason, - "actions": [{"label": "Approve", "value": "yes"}], - "pending_action": { - "type": "tool_call", - "name": pending_sensitive[0].name, - "args": dict(pending_sensitive[0].args) - if pending_sensitive[0].args - else {}, - }, - "pending_actions": [ - { - "type": "tool_call", - "name": 
pfc.name, - "args": dict(pfc.args) if pfc.args else {}, - } - for pfc in pending_sensitive - ], - }, - } - - tool_outputs = [] - for fc in function_calls: - func = next((t for t in session_manager.tools if t.__name__ == fc.name), None) - if func: - try: - print(f"Executing {fc.name} with {fc.args}") - - sig = inspect.signature(func) - valid_args = {k: v for k, v in fc.args.items() if k in sig.parameters} - if len(valid_args) < len(fc.args): - invalid_keys = set(fc.args.keys()) - set(valid_args.keys()) - logger.warning( - f"Filtered out invalid parameters for {fc.name}: {invalid_keys}" - ) - - if inspect.iscoroutinefunction(func): - result = await func(**valid_args) - else: - result = func(**valid_args) - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"result": str(result)} - ) - ) - ) - except Exception as e: - tool_outputs.append( - types.Part( - function_response=types.FunctionResponse( - name=fc.name, response={"error": str(e)} - ) - ) - ) + return build_decision_request(session_id, pending_sensitive) + tool_outputs = await execute_tool_calls(function_calls, session_manager.tools) if tool_outputs: await session_manager.update_history( - session_id, - types.Content(role="user", parts=tool_outputs), + session_id, types.Content(role="user", parts=tool_outputs) ) - continue - return build_chat_response( f"Task stopped after {max_turns} steps. 
It may be too complex or stuck in a loop.", session_id, diff --git a/apps/agents/agent_server/routes/jira.py b/apps/agents/agent_server/routes/jira.py index cf4c449..8adee23 100644 --- a/apps/agents/agent_server/routes/jira.py +++ b/apps/agents/agent_server/routes/jira.py @@ -17,7 +17,7 @@ ) from apps.agents.agent_server.session_manager import session_manager from apps.agents.agent_server.src.server.config import get_jira -from apps.agents.agent_server.src.server.tools import issue_tools, workflow_tools +from apps.agents.agent_server.tools import issue_tools, workflow_tools from apps.agents.agent_server.user_context import get_user_context from libs.common.db_utils import managed_async_db_session from libs.common.jira_client import JiraSessionExpired diff --git a/apps/agents/agent_server/schemas.py b/apps/agents/agent_server/schemas/__init__.py similarity index 100% rename from apps/agents/agent_server/schemas.py rename to apps/agents/agent_server/schemas/__init__.py diff --git a/apps/agents/agent_server/services/__init__.py b/apps/agents/agent_server/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/agents/agent_server/services/agent_service.py b/apps/agents/agent_server/services/agent_service.py new file mode 100644 index 0000000..ee8c9b6 --- /dev/null +++ b/apps/agents/agent_server/services/agent_service.py @@ -0,0 +1,173 @@ +"""Agent Mode (Planner → Executor) orchestration.""" + +import inspect +import json +import logging +import traceback +from datetime import date + +from google.genai import types + +from apps.agents.agent_server.response_utils import build_chat_response +from apps.agents.agent_server.session_manager import session_manager +from apps.agents.agent_server.src.cli.config import FAST_MODEL, SMART_MODEL +from apps.agents.agent_server.src.cli.modes import load_prompt +from apps.agents.agent_server.src.server.config import get_jira +from apps.agents.agent_server.src.server.context_builder import get_genius_context 
+from apps.agents.agent_server.user_context import _get_chat_project_key, get_user_context + +logger = logging.getLogger(__name__) + + +async def run_agent_mode(session_id: str, user_msg: str, user_id: int | None = None): + """Executes the Agent Mode (Planner -> Executor) logic.""" + try: + effective_project = await _get_chat_project_key(user_id, None) + project_context = "" + try: + jira_client = get_jira() + project_context = await get_genius_context(jira_client, effective_project) + except Exception as e: + logger.warning(f"[Agent] Failed to fetch genius context: {e}") + + planner_sys = load_prompt("planner.md") + sys_prompt = session_manager.get_system_prompt(project_context) + planner_sys = sys_prompt + "\n\nPLANNER SPECIFIC:\n" + planner_sys + + user_ctx = await get_user_context() + user_name = user_ctx["name"] + user_language = user_ctx["language"] + + context_injection = ( + f"\n\nSYSTEM CONTEXT:\n- User Name: {user_name}" + f"\n- User Language Preference: {user_language}" + f"\n- Current Project: {effective_project}" + f"\n- Current Date: {date.today().isoformat()}" + ) + language_instruction = ( + f"\n\nCRITICAL INSTRUCTION: You MUST reply in {user_language}. " + "Do not change language based on user input language. " + f"Your response must be exclusively in {user_language}." + ) + + planner_sys += context_injection + language_instruction + + plan_response = await session_manager.client.aio.models.generate_content( + model=SMART_MODEL, + contents=[types.Content(role="user", parts=[types.Part(text=user_msg)])], + config=types.GenerateContentConfig(system_instruction=planner_sys), + ) + + if ( + not plan_response.candidates + or not plan_response.candidates[0].content + or not plan_response.candidates[0].content.parts + ): + return build_chat_response( + "Planning failed. 
Model returned no content.", session_id, msg_id=None + ) + plan_text = plan_response.candidates[0].content.parts[0].text + + try: + clean_json = plan_text.replace("```json", "").replace("```", "").strip() + steps = json.loads(clean_json) + if not isinstance(steps, list): + raise ValueError("Not a list") + except Exception: + return build_chat_response( + f"Planning failed. Response: {plan_text}", session_id, msg_id=None + ) + + executor_sys = session_manager.get_system_prompt(project_context) + executor_sys += context_injection + language_instruction + agent_session_id = f"{session_id}_agent" + + if not session_manager.get_history(agent_session_id): + await session_manager.update_history( + agent_session_id, + types.Content(role="user", parts=[types.Part(text=f"Objective: {user_msg}")]), + ) + + accumulated_output = [f"**Objective:** {user_msg}\n\n**Plan:**\n"] + for i, step in enumerate(steps): + accumulated_output.append(f"{i + 1}. {step}") + accumulated_output.append("\n---\n") + + for i, step in enumerate(steps): + logger.info(f"[Agent] Executing Step {i + 1}/{len(steps)}: {step}") + await session_manager.update_history( + agent_session_id, + types.Content( + role="user", + parts=[types.Part(text=f"Execute Step {i + 1}: {step}")], + ), + ) + + step_output = "" + for turn in range(3): + agent_history = session_manager.get_history(agent_session_id) + try: + resp = await session_manager.client.aio.models.generate_content( + model=FAST_MODEL, + contents=agent_history, + config=types.GenerateContentConfig( + system_instruction=executor_sys, + tools=session_manager.tools, + ), + ) + except Exception as e: + logger.error(f"Step {i + 1} turn {turn + 1} generation failed: {e}") + step_output = f"Execution error: {str(e)}" + break + + if not resp.candidates: + break + candidate = resp.candidates[0] + if not candidate.content or not candidate.content.parts: + break + + parts = candidate.content.parts + function_calls = [p.function_call for p in parts if p.function_call] + 
+ if function_calls: + await session_manager.update_history( + agent_session_id, + types.Content(role="model", parts=parts), + ) + from apps.agents.agent_server.services.tool_executor import execute_tool_calls + + outputs = await execute_tool_calls(function_calls, session_manager.tools) + if outputs: + await session_manager.update_history( + agent_session_id, + types.Content(role="user", parts=outputs), + ) + continue + + text = "".join([p.text for p in parts if p.text]) + if text: + step_output = text + await session_manager.update_history( + agent_session_id, + types.Content(role="model", parts=parts), + ) + break + + accumulated_output.append(f"### Step {i + 1}: {step}\n{step_output}\n") + + final_response = "\n".join(accumulated_output) + + await session_manager.update_history( + session_id, + types.Content(role="user", parts=[types.Part(text=user_msg)]), + ) + await session_manager.update_history( + session_id, + types.Content(role="model", parts=[types.Part(text=final_response)]), + ) + + return build_chat_response(final_response, session_id, msg_id=None) + + except Exception as e: + traceback.print_exc() + return {"type": "error", "data": f"Agent Error: {str(e)}"} diff --git a/apps/agents/agent_server/services/tool_executor.py b/apps/agents/agent_server/services/tool_executor.py new file mode 100644 index 0000000..6adb360 --- /dev/null +++ b/apps/agents/agent_server/services/tool_executor.py @@ -0,0 +1,174 @@ +""" +Tool execution helpers shared by chat_endpoint and _chat_stream. +Handles safe invocation, parameter filtering, and HITL decision request building. 
+""" + +import inspect +import logging + +from google.genai import types + +logger = logging.getLogger(__name__) + +SENSITIVE_TOOLS = { + "create_issue", + "update_issue", + "delete_issue", + "bulk_delete_issues", + "transition_issue", + "bulk_transition_issues", + "bulk_update_issues", + "add_comment", + "log_work", +} + + +async def execute_tool_call( + fc, + tools: list, + edited_args: dict | None = None, +) -> types.Part: + """ + Execute a single function call safely. + Returns a FunctionResponse Part (either result or error). + """ + func = next((t for t in tools if t.__name__ == fc.name), None) + + if not func: + return types.Part( + function_response=types.FunctionResponse( + name=fc.name, response={"error": f"Tool {fc.name} not found"} + ) + ) + + try: + args = dict(fc.args) if fc.args else {} + if edited_args: + args.update(edited_args) + + sig = inspect.signature(func) + valid_args = {k: v for k, v in args.items() if k in sig.parameters} + if len(valid_args) < len(args): + invalid_keys = set(args.keys()) - set(valid_args.keys()) + logger.warning(f"Filtered out invalid parameters for {fc.name}: {invalid_keys}") + + res = await func(**valid_args) if inspect.iscoroutinefunction(func) else func(**valid_args) + + response_data: dict = {"result": str(res)} + if isinstance(res, dict) and any(k.lower() == "error" for k in res): + response_data = {"error": str(res.get("Error", res.get("error")))} + + return types.Part( + function_response=types.FunctionResponse(name=fc.name, response=response_data) + ) + except Exception as e: + logger.error(f"Error executing tool {fc.name}: {e}") + return types.Part( + function_response=types.FunctionResponse(name=fc.name, response={"error": str(e)}) + ) + + +async def execute_tool_calls( + function_calls: list, + tools: list, + edited_args: dict | None = None, +) -> list[types.Part]: + """Execute a list of function calls and return all FunctionResponse Parts.""" + return [await execute_tool_call(fc, tools, edited_args) for fc in 
function_calls] + + +def build_decision_request(session_id: str, pending_sensitive: list) -> dict: + """ + Build the HITL decision_request payload for the frontend. + pending_sensitive: list of function_call objects that require approval. + """ + tool_names = [fc.name for fc in pending_sensitive] + + if len(tool_names) == 1: + title = "Approval Required" + if "create" in tool_names[0]: + title = "Create Jira Task" + elif "update" in tool_names[0]: + title = "Update Jira Task" + elif "delete" in tool_names[0]: + title = "Delete Jira Task" + elif "transition" in tool_names[0]: + title = "Transition Jira Task" + else: + title = "Multiple Actions Require Approval" + + if len(pending_sensitive) == 1: + fc = pending_sensitive[0] + args = dict(fc.args) if fc.args else {} + reason = _build_single_reason(fc.name, args) + else: + action_labels = [_build_action_label(fc) for fc in pending_sensitive] + reason = "\n".join(f"• {a}" for a in action_labels) + + return { + "type": "decision_request", + "session_id": session_id, + "data": { + "title": title, + "policy": "User confirmation needed", + "reason": reason, + "actions": [{"label": "Approve", "value": "yes"}], + "pending_action": { + "type": "tool_call", + "name": pending_sensitive[0].name, + "args": dict(pending_sensitive[0].args) if pending_sensitive[0].args else {}, + }, + "pending_actions": [ + { + "type": "tool_call", + "name": pfc.name, + "args": dict(pfc.args) if pfc.args else {}, + } + for pfc in pending_sensitive + ], + }, + } + + +def _build_single_reason(tool_name: str, args: dict) -> str: + if tool_name == "create_issue": + summary = args.get("summary", "") + issue_type = args.get("issue_type", "Task") + project_key = args.get("project_key", "") + return ( + f"Create a new {issue_type} in {project_key}: {summary}" + if project_key + else f"Create a new {issue_type}: {summary}" + ) + if tool_name in ("update_issue", "bulk_update_issues"): + issue_key = args.get("issue_key") or (args.get("issue_keys") or [""])[0] 
+ return f"Update issue {issue_key}" if issue_key else "Update a Jira issue" + if tool_name == "delete_issue": + issue_key = args.get("issue_key", "") + return f"Permanently delete issue {issue_key}" if issue_key else "Delete a Jira issue" + if tool_name in ("transition_issue", "bulk_transition_issues"): + issue_key = args.get("issue_key") or (args.get("issue_keys") or [""])[0] + status = args.get("status", "") + return ( + f'Move {issue_key} to "{status}"' if issue_key and status else "Transition a Jira issue" + ) + return f"Execute: {tool_name}" + + +def _build_action_label(fc) -> str: + args = dict(fc.args) if fc.args else {} + if fc.name == "create_issue": + return f'Create "{args.get("summary", "new issue")}"' + if fc.name == "delete_issue": + return f"Delete {args.get('issue_key', '?')}" + if fc.name == "bulk_delete_issues": + keys = args.get("issue_keys", []) + return f"Delete {', '.join(keys) if keys else 'issues'}" + if fc.name in ("update_issue", "bulk_update_issues"): + k = args.get("issue_key") or ", ".join(args.get("issue_keys", ["?"])) + return f"Update {k}" + if fc.name in ("transition_issue", "bulk_transition_issues"): + k = args.get("issue_key") or ", ".join(args.get("issue_keys", ["?"])) + st = args.get("status", "new status") + return f'Move {k} → "{st}"' + return fc.name diff --git a/apps/agents/agent_server/session_manager.py b/apps/agents/agent_server/session_manager.py index 29c6f42..e11723d 100644 --- a/apps/agents/agent_server/session_manager.py +++ b/apps/agents/agent_server/session_manager.py @@ -8,7 +8,7 @@ from sqlalchemy import select from apps.agents.agent_server.response_utils import sanitize_for_json -from apps.agents.agent_server.src.server.tools import ( +from apps.agents.agent_server.tools import ( issue_tools, jira_task_proposal_tools, proposal_tools, @@ -16,7 +16,7 @@ user_tools, workflow_tools, ) -from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( +from 
apps.agents.agent_server.tools.jira_task_proposal_tools import ( get_proposed_delete_tasks, get_proposed_tasks, ) diff --git a/apps/agents/agent_server/src/server/main.py b/apps/agents/agent_server/src/server/main.py index 36fd344..8b070d7 100644 --- a/apps/agents/agent_server/src/server/main.py +++ b/apps/agents/agent_server/src/server/main.py @@ -2,7 +2,7 @@ import sys # Import mcp instance -from apps.agents.agent_server.src.server.tools import mcp +from apps.agents.agent_server.tools import mcp # Import submodules to register tools diff --git a/apps/agents/agent_server/tests/test_chat_persistence.py b/apps/agents/agent_server/tests/test_chat_persistence.py index 6c09a06..8640622 100644 --- a/apps/agents/agent_server/tests/test_chat_persistence.py +++ b/apps/agents/agent_server/tests/test_chat_persistence.py @@ -7,7 +7,7 @@ import pytest from fastapi import HTTPException -from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( +from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, diff --git a/apps/agents/agent_server/tests/test_labels.py b/apps/agents/agent_server/tests/test_labels.py index 52e9fed..572429a 100644 --- a/apps/agents/agent_server/tests/test_labels.py +++ b/apps/agents/agent_server/tests/test_labels.py @@ -11,13 +11,13 @@ class TestLabelsOnCreate: @pytest.mark.asyncio async def test_labels_forwarded_to_jira_create(self): """Labels in CreateTaskRequest must reach jira.create_issue().""" - from apps.agents.agent_server.src.server.tools import issue_tools + from apps.agents.agent_server.tools import issue_tools mock_jira = AsyncMock() mock_jira.create_issue = AsyncMock(return_value={"key": "PROJ-1", "id": "10001"}) with patch( - "apps.agents.agent_server.src.server.tools.issue_tools.get_jira", + "apps.agents.agent_server.tools.issue_tools.get_jira", return_value=mock_jira, ): await issue_tools.create_issue( @@ -35,13 +35,13 @@ async def 
test_labels_forwarded_to_jira_create(self): @pytest.mark.asyncio async def test_labels_none_by_default(self): """Calling create_issue without labels still works.""" - from apps.agents.agent_server.src.server.tools import issue_tools + from apps.agents.agent_server.tools import issue_tools mock_jira = AsyncMock() mock_jira.create_issue = AsyncMock(return_value={"key": "PROJ-1", "id": "10001"}) with patch( - "apps.agents.agent_server.src.server.tools.issue_tools.get_jira", + "apps.agents.agent_server.tools.issue_tools.get_jira", return_value=mock_jira, ): await issue_tools.create_issue( diff --git a/apps/agents/agent_server/tests/test_proposal_tools.py b/apps/agents/agent_server/tests/test_proposal_tools.py index ea3b4ac..e0a29ca 100644 --- a/apps/agents/agent_server/tests/test_proposal_tools.py +++ b/apps/agents/agent_server/tests/test_proposal_tools.py @@ -12,7 +12,7 @@ class TestProposalTools: async def test_get_pending_proposals_no_context(self): """Test getting proposals when user context is missing.""" from apps.agents.agent_server.src.server.context import current_user_token - from apps.agents.agent_server.src.server.tools.proposal_tools import get_pending_proposals + from apps.agents.agent_server.tools.proposal_tools import get_pending_proposals token = current_user_token.set(None) try: @@ -25,16 +25,16 @@ async def test_get_pending_proposals_no_context(self): async def test_get_pending_proposals_no_proposals(self): """Test getting proposals when none exist.""" from apps.agents.agent_server.src.server.context import current_user_token - from apps.agents.agent_server.src.server.tools.proposal_tools import get_pending_proposals + from apps.agents.agent_server.tools.proposal_tools import get_pending_proposals token = current_user_token.set({"user_id": 123}) try: with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.get_proposals_for_user", + 
"apps.agents.agent_server.tools.proposal_tools.proposal_service.get_proposals_for_user", return_value=[], ): with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" + "apps.agents.agent_server.tools.proposal_tools.managed_async_db_session" ): result = await get_pending_proposals() assert "No pending proposals found" in result @@ -45,7 +45,7 @@ async def test_get_pending_proposals_no_proposals(self): async def test_get_pending_proposals_success(self): """Test getting proposals successfully.""" from apps.agents.agent_server.src.server.context import current_user_token - from apps.agents.agent_server.src.server.tools.proposal_tools import get_pending_proposals + from apps.agents.agent_server.tools.proposal_tools import get_pending_proposals mock_proposal = MagicMock() mock_proposal.id = "prop-123" @@ -62,11 +62,11 @@ async def test_get_pending_proposals_success(self): token = current_user_token.set({"user_id": 123}) try: with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.get_proposals_for_user", + "apps.agents.agent_server.tools.proposal_tools.proposal_service.get_proposals_for_user", return_value=[mock_proposal], ): with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" + "apps.agents.agent_server.tools.proposal_tools.managed_async_db_session" ): result = await get_pending_proposals() assert "PROJ-456" in result @@ -79,7 +79,7 @@ async def test_get_pending_proposals_success(self): async def test_create_proposal_success(self): """Test creating a proposal.""" from apps.agents.agent_server.src.server.context import current_user_token - from apps.agents.agent_server.src.server.tools.proposal_tools import create_proposal + from apps.agents.agent_server.tools.proposal_tools import create_proposal mock_proposal = MagicMock() mock_proposal.id = "new-prop-id" @@ -87,11 +87,11 @@ async def test_create_proposal_success(self): token = 
current_user_token.set({"user_id": 123}) try: with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.create_proposal", + "apps.agents.agent_server.tools.proposal_tools.proposal_service.create_proposal", return_value=mock_proposal, ): with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" + "apps.agents.agent_server.tools.proposal_tools.managed_async_db_session" ): result = await create_proposal( project_key="PROJ", title="New Idea", description="Description here" @@ -104,36 +104,34 @@ async def test_create_proposal_success(self): @pytest.mark.asyncio async def test_confirm_proposal_not_found(self): """Test confirming a nonexistent proposal.""" - from apps.agents.agent_server.src.server.tools.proposal_tools import confirm_proposal + from apps.agents.agent_server.tools.proposal_tools import confirm_proposal with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.get_proposal_by_id", + "apps.agents.agent_server.tools.proposal_tools.proposal_service.get_proposal_by_id", return_value=None, ): - with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" - ): + with patch("apps.agents.agent_server.tools.proposal_tools.managed_async_db_session"): result = await confirm_proposal("missing", True) assert "not found" in result @pytest.mark.asyncio async def test_confirm_proposal_reject(self): """Test rejecting a proposal.""" - from apps.agents.agent_server.src.server.tools.proposal_tools import confirm_proposal + from apps.agents.agent_server.tools.proposal_tools import confirm_proposal mock_proposal = MagicMock() mock_proposal.status = "pending" mock_proposal.project_key = "PROJ-123" with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.get_proposal_by_id", + "apps.agents.agent_server.tools.proposal_tools.proposal_service.get_proposal_by_id", return_value=mock_proposal, ): with patch( - 
"apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.update_proposal_status" + "apps.agents.agent_server.tools.proposal_tools.proposal_service.update_proposal_status" ) as mock_update: with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" + "apps.agents.agent_server.tools.proposal_tools.managed_async_db_session" ): result = await confirm_proposal("prop-123", False) assert "Proposal rejected" in result @@ -142,7 +140,7 @@ async def test_confirm_proposal_reject(self): @pytest.mark.asyncio async def test_confirm_proposal_approve(self): """Test approving and executing a proposal.""" - from apps.agents.agent_server.src.server.tools.proposal_tools import confirm_proposal + from apps.agents.agent_server.tools.proposal_tools import confirm_proposal mock_proposal = MagicMock() mock_proposal.status = "pending" @@ -155,18 +153,18 @@ async def test_confirm_proposal_approve(self): mock_jira.update_status = AsyncMock(return_value={"status": "Done"}) with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.get_proposal_by_id", + "apps.agents.agent_server.tools.proposal_tools.proposal_service.get_proposal_by_id", return_value=mock_proposal, ): with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.get_jira", + "apps.agents.agent_server.tools.proposal_tools.get_jira", return_value=mock_jira, ): with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.proposal_service.update_proposal_status" + "apps.agents.agent_server.tools.proposal_tools.proposal_service.update_proposal_status" ) as mock_update: with patch( - "apps.agents.agent_server.src.server.tools.proposal_tools.managed_async_db_session" + "apps.agents.agent_server.tools.proposal_tools.managed_async_db_session" ): result = await confirm_proposal("prop-123", True) assert "Proposal executed" in result diff --git a/apps/agents/agent_server/src/server/tools/__init__.py 
b/apps/agents/agent_server/tools/__init__.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/__init__.py rename to apps/agents/agent_server/tools/__init__.py diff --git a/apps/agents/agent_server/src/server/tools/issue_tools.py b/apps/agents/agent_server/tools/issue_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/issue_tools.py rename to apps/agents/agent_server/tools/issue_tools.py diff --git a/apps/agents/agent_server/src/server/tools/jira_task_proposal_tools.py b/apps/agents/agent_server/tools/jira_task_proposal_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/jira_task_proposal_tools.py rename to apps/agents/agent_server/tools/jira_task_proposal_tools.py diff --git a/apps/agents/agent_server/src/server/tools/proposal_tools.py b/apps/agents/agent_server/tools/proposal_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/proposal_tools.py rename to apps/agents/agent_server/tools/proposal_tools.py diff --git a/apps/agents/agent_server/src/server/tools/thinking_tools.py b/apps/agents/agent_server/tools/thinking_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/thinking_tools.py rename to apps/agents/agent_server/tools/thinking_tools.py diff --git a/apps/agents/agent_server/src/server/tools/user_tools.py b/apps/agents/agent_server/tools/user_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/user_tools.py rename to apps/agents/agent_server/tools/user_tools.py diff --git a/apps/agents/agent_server/src/server/tools/workflow_tools.py b/apps/agents/agent_server/tools/workflow_tools.py similarity index 100% rename from apps/agents/agent_server/src/server/tools/workflow_tools.py rename to apps/agents/agent_server/tools/workflow_tools.py diff --git a/apps/api_server/main.py b/apps/api_server/main.py index d0d58c3..4050bf5 100644 --- a/apps/api_server/main.py +++ 
b/apps/api_server/main.py @@ -12,7 +12,7 @@ from apps.api_server import support as support_router from apps.api_server.middleware import JiraAuthMiddleware from apps.api_server.routers import health as health_router -from apps.integrations.github.app.routes import webhook as github_webhook +from apps.integrations.github.routes import webhook as github_webhook from apps.integrations.jira.routes import api as jira_api from apps.integrations.jira.routes import auth as jira_auth from apps.integrations.jira.routes import payment as jira_payment diff --git a/apps/integrations/github/app/__init__.py b/apps/integrations/github/app/__init__.py deleted file mode 100644 index b099acb..0000000 --- a/apps/integrations/github/app/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# GitHub Webhook Service for AI PM-Bot diff --git a/apps/integrations/github/app/ai/__init__.py b/apps/integrations/github/app/ai/__init__.py deleted file mode 100644 index e7042fa..0000000 --- a/apps/integrations/github/app/ai/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# AI layer for Jira action proposals diff --git a/apps/integrations/github/app/github/__init__.py b/apps/integrations/github/app/github/__init__.py deleted file mode 100644 index 6aa01bc..0000000 --- a/apps/integrations/github/app/github/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# GitHub webhook handling diff --git a/apps/integrations/github/app/jira/__init__.py b/apps/integrations/github/app/jira/__init__.py deleted file mode 100644 index b6ef58a..0000000 --- a/apps/integrations/github/app/jira/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Jira HTTP client diff --git a/apps/integrations/github/app/routes/__init__.py b/apps/integrations/github/app/routes/__init__.py deleted file mode 100644 index 6bb2aec..0000000 --- a/apps/integrations/github/app/routes/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Routes diff --git a/apps/integrations/github/app/store/__init__.py b/apps/integrations/github/app/store/__init__.py deleted file mode 100644 index 
14dd613..0000000 --- a/apps/integrations/github/app/store/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Store (Redis idempotency) diff --git a/apps/integrations/github/app/utils/__init__.py b/apps/integrations/github/app/utils/__init__.py deleted file mode 100644 index 285c9e8..0000000 --- a/apps/integrations/github/app/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Utils diff --git a/apps/integrations/github/app/app_factory.py b/apps/integrations/github/app_factory.py similarity index 90% rename from apps/integrations/github/app/app_factory.py rename to apps/integrations/github/app_factory.py index 981eb63..4df5023 100644 --- a/apps/integrations/github/app/app_factory.py +++ b/apps/integrations/github/app_factory.py @@ -8,9 +8,9 @@ from fastapi import FastAPI -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.routes import webhook -from apps.integrations.github.app.utils.logging import setup_logging +from apps.integrations.github.config import get_settings +from apps.integrations.github.routes import webhook +from apps.integrations.github.utils.logging import setup_logging logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/app/config.py b/apps/integrations/github/config.py similarity index 100% rename from apps/integrations/github/app/config.py rename to apps/integrations/github/config.py diff --git a/apps/integrations/github/app/main.py b/apps/integrations/github/main.py similarity index 75% rename from apps/integrations/github/app/main.py rename to apps/integrations/github/main.py index e14dfdf..ce5e155 100644 --- a/apps/integrations/github/app/main.py +++ b/apps/integrations/github/main.py @@ -4,7 +4,7 @@ Uvicorn loads the 'app' object from this file. 
""" -from apps.integrations.github.app.app_factory import create_app +from apps.integrations.github.app_factory import create_app # Create application instance app = create_app() diff --git a/apps/integrations/github/routes/__init__.py b/apps/integrations/github/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/routes/webhook.py b/apps/integrations/github/routes/webhook.py similarity index 89% rename from apps/integrations/github/app/routes/webhook.py rename to apps/integrations/github/routes/webhook.py index af3cccb..9cac09d 100644 --- a/apps/integrations/github/app/routes/webhook.py +++ b/apps/integrations/github/routes/webhook.py @@ -5,10 +5,10 @@ from fastapi import APIRouter, Header, HTTPException, Request, Response -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.github.handler import handle_event -from apps.integrations.github.app.github.verify import verify_signature -from apps.integrations.github.app.store.idempotency import is_processed, mark_processed +from apps.integrations.github.config import get_settings +from apps.integrations.github.services.github.handler import handle_event +from apps.integrations.github.services.github.verify import verify_signature +from apps.integrations.github.services.store.idempotency import is_processed, mark_processed logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/scripts/check_secret.py b/apps/integrations/github/scripts/check_secret.py index 0cb76df..355a96e 100644 --- a/apps/integrations/github/scripts/check_secret.py +++ b/apps/integrations/github/scripts/check_secret.py @@ -9,7 +9,7 @@ ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) -from apps.integrations.github.app.config import get_settings # noqa: E402 +from apps.integrations.github.config import get_settings # noqa: E402 env_secret = os.environ.get("GITHUB_WEBHOOK_SECRET", "") config_secret = 
get_settings().github_webhook_secret diff --git a/apps/integrations/github/services/__init__.py b/apps/integrations/github/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/services/ai/__init__.py b/apps/integrations/github/services/ai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/ai/commit_analyzer.py b/apps/integrations/github/services/ai/commit_analyzer.py similarity index 98% rename from apps/integrations/github/app/ai/commit_analyzer.py rename to apps/integrations/github/services/ai/commit_analyzer.py index 4fd3239..145b852 100644 --- a/apps/integrations/github/app/ai/commit_analyzer.py +++ b/apps/integrations/github/services/ai/commit_analyzer.py @@ -7,7 +7,7 @@ import google.genai as genai from google.genai import types -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/app/ai/diff_analyzer.py b/apps/integrations/github/services/ai/diff_analyzer.py similarity index 98% rename from apps/integrations/github/app/ai/diff_analyzer.py rename to apps/integrations/github/services/ai/diff_analyzer.py index e3ea86e..0aa4d40 100644 --- a/apps/integrations/github/app/ai/diff_analyzer.py +++ b/apps/integrations/github/services/ai/diff_analyzer.py @@ -11,7 +11,7 @@ from google.genai import types from pydantic import BaseModel, Field -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/app/ai/gemini.py b/apps/integrations/github/services/ai/gemini.py similarity index 99% rename from apps/integrations/github/app/ai/gemini.py rename to apps/integrations/github/services/ai/gemini.py index 8f27084..c09d5cb 100644 --- a/apps/integrations/github/app/ai/gemini.py +++ 
b/apps/integrations/github/services/ai/gemini.py @@ -8,7 +8,7 @@ from google.genai import types from pydantic import BaseModel, Field -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/app/ai/risk_analysis.py b/apps/integrations/github/services/ai/risk_analysis.py similarity index 98% rename from apps/integrations/github/app/ai/risk_analysis.py rename to apps/integrations/github/services/ai/risk_analysis.py index 2263b76..aa3e364 100644 --- a/apps/integrations/github/app/ai/risk_analysis.py +++ b/apps/integrations/github/services/ai/risk_analysis.py @@ -8,7 +8,7 @@ from google.genai import types from pydantic import BaseModel, Field -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/services/github/__init__.py b/apps/integrations/github/services/github/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/github/ai_logic.py b/apps/integrations/github/services/github/ai_logic.py similarity index 95% rename from apps/integrations/github/app/github/ai_logic.py rename to apps/integrations/github/services/github/ai_logic.py index 928b201..a801b4e 100644 --- a/apps/integrations/github/app/github/ai_logic.py +++ b/apps/integrations/github/services/github/ai_logic.py @@ -4,16 +4,16 @@ from typing import Any from apps.agents.agent_server.src.common.redis_client import RedisManager -from apps.integrations.github.app.ai.diff_analyzer import analyze_diff_vs_tasks -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.github.client import GitHubClient -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.config import get_settings +from 
apps.integrations.github.services.ai.diff_analyzer import analyze_diff_vs_tasks +from apps.integrations.github.services.github.client import GitHubClient +from apps.integrations.github.services.github.utils import ( get_link_pr, get_link_push, get_repo_name, ) -from apps.integrations.github.app.jira.client import get_user_active_tasks, get_user_for_repo -from apps.integrations.github.app.store.proposal_store import store_proposal +from apps.integrations.github.services.jira.client import get_user_active_tasks, get_user_for_repo +from apps.integrations.github.services.store.proposal_store import store_proposal from libs.common.subscription_limits import FEATURE_GITHUB_DIFFS, get_tier_limits logger = logging.getLogger("GitHubAI") diff --git a/apps/integrations/github/app/github/ci_status.py b/apps/integrations/github/services/github/ci_status.py similarity index 98% rename from apps/integrations/github/app/github/ci_status.py rename to apps/integrations/github/services/github/ci_status.py index d995722..bfca690 100644 --- a/apps/integrations/github/app/github/ci_status.py +++ b/apps/integrations/github/services/github/ci_status.py @@ -90,7 +90,7 @@ async def get_ci_status( if full_name and sha: owner, repo_name = full_name.split("/") - from apps.integrations.github.app.github.client import GitHubClient + from apps.integrations.github.services.github.client import GitHubClient client = GitHubClient(github_token) diff --git a/apps/integrations/github/app/github/client.py b/apps/integrations/github/services/github/client.py similarity index 100% rename from apps/integrations/github/app/github/client.py rename to apps/integrations/github/services/github/client.py diff --git a/apps/integrations/github/app/github/events.py b/apps/integrations/github/services/github/events.py similarity index 94% rename from apps/integrations/github/app/github/events.py rename to apps/integrations/github/services/github/events.py index 8ea919d..8571450 100644 --- 
a/apps/integrations/github/app/github/events.py +++ b/apps/integrations/github/services/github/events.py @@ -1,11 +1,14 @@ import logging from typing import Any -from apps.integrations.github.app.ai.gemini import ai_decide_transition -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.github.ai_logic import handle_ai_decision -from apps.integrations.github.app.github.ci_status import get_ci_status, should_block_transition -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.config import get_settings +from apps.integrations.github.services.ai.gemini import ai_decide_transition +from apps.integrations.github.services.github.ai_logic import handle_ai_decision +from apps.integrations.github.services.github.ci_status import ( + get_ci_status, + should_block_transition, +) +from apps.integrations.github.services.github.utils import ( build_comment_text_pr, build_comment_text_push, build_comment_text_review, @@ -17,8 +20,8 @@ get_link_push, get_repo_name, ) -from apps.integrations.github.app.jira.client import get_user_for_repo -from apps.integrations.github.app.jira.context import get_issue_context, get_project_context +from apps.integrations.github.services.jira.client import get_user_for_repo +from apps.integrations.github.services.jira.context import get_issue_context, get_project_context logger = logging.getLogger("GitHubEvents") diff --git a/apps/integrations/github/app/github/handler.py b/apps/integrations/github/services/github/handler.py similarity index 81% rename from apps/integrations/github/app/github/handler.py rename to apps/integrations/github/services/github/handler.py index e82e8d9..d00ecf5 100644 --- a/apps/integrations/github/app/github/handler.py +++ b/apps/integrations/github/services/github/handler.py @@ -2,46 +2,49 @@ import logging from typing import Any -from apps.integrations.github.app.ai.gemini import ai_decide_transition -from 
apps.integrations.github.app.ai.risk_analysis import analyze_pr_risk -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.github.ai_logic import ( +from apps.integrations.github.config import get_settings +from apps.integrations.github.services.ai.gemini import ai_decide_transition +from apps.integrations.github.services.ai.risk_analysis import analyze_pr_risk +from apps.integrations.github.services.github.ai_logic import ( handle_ai_decision as _handle_ai_decision, ) -from apps.integrations.github.app.github.ai_logic import ( +from apps.integrations.github.services.github.ai_logic import ( run_semantic_analysis as _run_semantic_analysis, ) -from apps.integrations.github.app.github.ci_status import get_ci_status, should_block_transition +from apps.integrations.github.services.github.ci_status import ( + get_ci_status, + should_block_transition, +) # Forward-facing modules for events -from apps.integrations.github.app.github.events import build_actions as _build_actions -from apps.integrations.github.app.github.parser import extract_keys_for_event -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.events import build_actions as _build_actions +from apps.integrations.github.services.github.parser import extract_keys_for_event +from apps.integrations.github.services.github.utils import ( deduplicate_actions as _deduplicate_actions, ) # Re-exports for backward compatibility (used in tests) -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.utils import ( get_actor as _actor, ) -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.utils import ( get_link_issue as _link_issue, ) -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.utils import ( get_link_pr as _link_pr, ) -from apps.integrations.github.app.github.utils 
import ( +from apps.integrations.github.services.github.utils import ( get_link_push as _link_push, ) -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.utils import ( get_repo_name as _repo, ) -from apps.integrations.github.app.github.utils import ( +from apps.integrations.github.services.github.utils import ( get_user_email_from_payload, ) -from apps.integrations.github.app.jira.client import send_actions -from apps.integrations.github.app.jira.context import get_issue_context, get_project_context -from apps.integrations.github.app.store.proposal_store import ( +from apps.integrations.github.services.jira.client import send_actions +from apps.integrations.github.services.jira.context import get_issue_context, get_project_context +from apps.integrations.github.services.store.proposal_store import ( publish_new_proposal, store_proposal, ) @@ -158,7 +161,7 @@ async def handle_event(event: str, payload: dict[str, Any]) -> None: if settings.enable_ai_bool: try: # Note: propose_actions is usually from gemini module, imported here if needed - from apps.integrations.github.app.ai.gemini import propose_actions + from apps.integrations.github.services.ai.gemini import propose_actions ai_actions = await propose_actions(event, payload, keys, actions) logger.info("handle_event: AI proposed %d actions", len(ai_actions)) diff --git a/apps/integrations/github/app/github/parser.py b/apps/integrations/github/services/github/parser.py similarity index 100% rename from apps/integrations/github/app/github/parser.py rename to apps/integrations/github/services/github/parser.py diff --git a/apps/integrations/github/app/github/utils.py b/apps/integrations/github/services/github/utils.py similarity index 98% rename from apps/integrations/github/app/github/utils.py rename to apps/integrations/github/services/github/utils.py index 029ec63..c269bd6 100644 --- a/apps/integrations/github/app/github/utils.py +++ 
b/apps/integrations/github/services/github/utils.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any -from apps.integrations.github.app.github.client import GitHubClient +from apps.integrations.github.services.github.client import GitHubClient logger = logging.getLogger("GitHubUtils") diff --git a/apps/integrations/github/app/github/verify.py b/apps/integrations/github/services/github/verify.py similarity index 100% rename from apps/integrations/github/app/github/verify.py rename to apps/integrations/github/services/github/verify.py diff --git a/apps/integrations/github/services/jira/__init__.py b/apps/integrations/github/services/jira/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/jira/client.py b/apps/integrations/github/services/jira/client.py similarity index 99% rename from apps/integrations/github/app/jira/client.py rename to apps/integrations/github/services/jira/client.py index c4eef46..fe99c45 100644 --- a/apps/integrations/github/app/jira/client.py +++ b/apps/integrations/github/services/jira/client.py @@ -163,7 +163,7 @@ async def get_user_for_repo(repo_full_name: str) -> "User | None": async def get_user_active_tasks(email: str) -> list[dict[str, Any]]: """Fetch active Jira tasks for a user identified by their email address.""" - from apps.integrations.github.app.config import get_settings + from apps.integrations.github.config import get_settings if not email: return [] diff --git a/apps/integrations/github/app/jira/context.py b/apps/integrations/github/services/jira/context.py similarity index 100% rename from apps/integrations/github/app/jira/context.py rename to apps/integrations/github/services/jira/context.py diff --git a/apps/integrations/github/services/store/__init__.py b/apps/integrations/github/services/store/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/store/idempotency.py 
b/apps/integrations/github/services/store/idempotency.py similarity index 96% rename from apps/integrations/github/app/store/idempotency.py rename to apps/integrations/github/services/store/idempotency.py index c4bbe8b..bbf4f34 100644 --- a/apps/integrations/github/app/store/idempotency.py +++ b/apps/integrations/github/services/store/idempotency.py @@ -4,7 +4,7 @@ from redis.asyncio import Redis -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings logger = logging.getLogger(__name__) diff --git a/apps/integrations/github/app/store/proposal_store.py b/apps/integrations/github/services/store/proposal_store.py similarity index 99% rename from apps/integrations/github/app/store/proposal_store.py rename to apps/integrations/github/services/store/proposal_store.py index 26d94ff..24df80f 100644 --- a/apps/integrations/github/app/store/proposal_store.py +++ b/apps/integrations/github/services/store/proposal_store.py @@ -11,7 +11,7 @@ import redis.asyncio as redis -from apps.integrations.github.app.config import get_settings +from apps.integrations.github.config import get_settings from libs.common.database import AsyncSessionLocal from libs.common.proposal_service import ( create_proposal, diff --git a/apps/integrations/github/tests/test_ai_decision.py b/apps/integrations/github/tests/test_ai_decision.py index 7da343f..c4c6a30 100644 --- a/apps/integrations/github/tests/test_ai_decision.py +++ b/apps/integrations/github/tests/test_ai_decision.py @@ -5,7 +5,7 @@ import pytest -from apps.integrations.github.app.ai.gemini import ( +from apps.integrations.github.services.ai.gemini import ( _build_tech_lead_prompt, ai_decide_transition, ) @@ -55,8 +55,8 @@ async def test_ai_decide_transition_should_complete(): } with ( - patch("apps.integrations.github.app.ai.gemini.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.gemini.genai.Client") as mock_client_class, + 
patch("apps.integrations.github.services.ai.gemini.get_settings") as mock_settings, + patch("apps.integrations.github.services.ai.gemini.genai.Client") as mock_client_class, ): mock_settings.return_value.enable_ai_bool = True mock_settings.return_value.gemini_api_key = "fake_key" @@ -93,8 +93,8 @@ async def test_ai_decide_transition_wip_no_transition(): } with ( - patch("apps.integrations.github.app.ai.gemini.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.gemini.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.gemini.get_settings") as mock_settings, + patch("apps.integrations.github.services.ai.gemini.genai.Client") as mock_client_class, ): mock_settings.return_value.enable_ai_bool = True mock_settings.return_value.gemini_api_key = "fake_key" @@ -118,7 +118,7 @@ async def test_ai_decide_transition_ai_disabled(): payload = {"commits": []} # MUST patch target in gemini.py where it is used - with patch("apps.integrations.github.app.ai.gemini.get_settings") as mock_settings: + with patch("apps.integrations.github.services.ai.gemini.get_settings") as mock_settings: mock_settings.return_value.enable_ai_bool = False mock_settings.return_value.gemini_api_key = None @@ -140,8 +140,8 @@ async def test_ai_decide_transition_api_failure(): payload = {"commits": [{"message": "Fix bug"}]} with ( - patch("apps.integrations.github.app.ai.gemini.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.gemini.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.gemini.get_settings") as mock_settings, + patch("apps.integrations.github.services.ai.gemini.genai.Client") as mock_client_class, ): mock_client = mock_client_class.return_value mock_client.aio.models.generate_content = AsyncMock(side_effect=Exception("API timeout")) @@ -184,8 +184,8 @@ async def test_ai_decide_transition_pr_merged(): } with ( - patch("apps.integrations.github.app.ai.gemini.get_settings") as 
mock_settings, - patch("apps.integrations.github.app.ai.gemini.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.gemini.get_settings") as mock_settings, + patch("apps.integrations.github.services.ai.gemini.genai.Client") as mock_client_class, ): mock_settings.return_value.enable_ai_bool = True mock_settings.return_value.gemini_api_key = "fake_key" diff --git a/apps/integrations/github/tests/test_client.py b/apps/integrations/github/tests/test_client.py index 0a65f4c..fcc8f4c 100644 --- a/apps/integrations/github/tests/test_client.py +++ b/apps/integrations/github/tests/test_client.py @@ -4,7 +4,7 @@ import pytest -from apps.integrations.github.app.github.client import GitHubClient +from apps.integrations.github.services.github.client import GitHubClient class TestGitHubClient: diff --git a/apps/integrations/github/tests/test_deduplication.py b/apps/integrations/github/tests/test_deduplication.py index 5c57396..ebe64a8 100644 --- a/apps/integrations/github/tests/test_deduplication.py +++ b/apps/integrations/github/tests/test_deduplication.py @@ -1,6 +1,6 @@ """Unit tests for _deduplicate_actions() in the GitHub event handler.""" -from apps.integrations.github.app.github.handler import _deduplicate_actions +from apps.integrations.github.services.github.handler import _deduplicate_actions def test_dedup_comment_same_text_same_key(): diff --git a/apps/integrations/github/tests/test_diff_analyzer.py b/apps/integrations/github/tests/test_diff_analyzer.py index 80bcf85..01c6aa9 100644 --- a/apps/integrations/github/tests/test_diff_analyzer.py +++ b/apps/integrations/github/tests/test_diff_analyzer.py @@ -12,7 +12,7 @@ class TestDiffAnalyzer: @pytest.mark.asyncio async def test_analyze_diff_empty_inputs(self): """Test with empty diff or tasks.""" - from apps.integrations.github.app.ai.diff_analyzer import analyze_diff_vs_tasks + from apps.integrations.github.services.ai.diff_analyzer import analyze_diff_vs_tasks # Empty diff result = await 
analyze_diff_vs_tasks("", [{"key": "PROJ-1", "summary": "Test"}]) @@ -25,9 +25,11 @@ async def test_analyze_diff_empty_inputs(self): @pytest.mark.asyncio async def test_analyze_diff_no_api_key(self): """Test when GEMINI_API_KEY is not configured.""" - from apps.integrations.github.app.ai.diff_analyzer import analyze_diff_vs_tasks + from apps.integrations.github.services.ai.diff_analyzer import analyze_diff_vs_tasks - with patch("apps.integrations.github.app.ai.diff_analyzer.get_settings") as mock_settings: + with patch( + "apps.integrations.github.services.ai.diff_analyzer.get_settings" + ) as mock_settings: mock_settings.return_value.gemini_api_key = "" result = await analyze_diff_vs_tasks( @@ -39,11 +41,13 @@ async def test_analyze_diff_no_api_key(self): @pytest.mark.asyncio async def test_analyze_diff_success(self): """Test successful diff analysis.""" - from apps.integrations.github.app.ai.diff_analyzer import analyze_diff_vs_tasks + from apps.integrations.github.services.ai.diff_analyzer import analyze_diff_vs_tasks with ( - patch("apps.integrations.github.app.ai.diff_analyzer.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.diff_analyzer.genai.Client") as mock_genai, + patch( + "apps.integrations.github.services.ai.diff_analyzer.get_settings" + ) as mock_settings, + patch("apps.integrations.github.services.ai.diff_analyzer.genai.Client") as mock_genai, ): mock_settings.return_value.gemini_api_key = "test-key" @@ -87,11 +91,13 @@ async def test_analyze_diff_success(self): @pytest.mark.asyncio async def test_analyze_diff_with_pydantic_response(self): """Test analysis with the new structured output response.""" - from apps.integrations.github.app.ai.diff_analyzer import analyze_diff_vs_tasks + from apps.integrations.github.services.ai.diff_analyzer import analyze_diff_vs_tasks with ( - patch("apps.integrations.github.app.ai.diff_analyzer.get_settings") as mock_settings, - 
patch("apps.integrations.github.app.ai.diff_analyzer.genai.Client") as mock_genai, + patch( + "apps.integrations.github.services.ai.diff_analyzer.get_settings" + ) as mock_settings, + patch("apps.integrations.github.services.ai.diff_analyzer.genai.Client") as mock_genai, ): mock_settings.return_value.gemini_api_key = "test-key" @@ -117,7 +123,7 @@ async def test_analyze_diff_with_pydantic_response(self): @pytest.mark.asyncio async def test_truncate_diff(self): """Test diff truncation for large diffs.""" - from apps.integrations.github.app.ai.diff_analyzer import _truncate_diff + from apps.integrations.github.services.ai.diff_analyzer import _truncate_diff # Short diff - no truncation short_diff = "a" * 1000 diff --git a/apps/integrations/github/tests/test_handler_integration.py b/apps/integrations/github/tests/test_handler_integration.py index 5891444..a3e2713 100644 --- a/apps/integrations/github/tests/test_handler_integration.py +++ b/apps/integrations/github/tests/test_handler_integration.py @@ -4,7 +4,7 @@ import pytest -from apps.integrations.github.app.github.handler import _build_actions +from apps.integrations.github.services.github.handler import _build_actions # Reusable mock user — no DB connection needed in unit tests. 
@@ -46,30 +46,30 @@ async def test_build_actions_push_event_ai_decides_done(): } with ( - patch("apps.integrations.github.app.github.events.get_settings") as mock_settings, + patch("apps.integrations.github.services.github.events.get_settings") as mock_settings, patch( - "apps.integrations.github.app.github.events.get_user_for_repo", + "apps.integrations.github.services.github.events.get_user_for_repo", new_callable=AsyncMock, return_value=_mock_user(), ), patch( - "apps.integrations.github.app.github.events.get_issue_context", + "apps.integrations.github.services.github.events.get_issue_context", new_callable=AsyncMock, return_value=mock_jira_context, ), patch( - "apps.integrations.github.app.github.events.get_project_context", + "apps.integrations.github.services.github.events.get_project_context", new_callable=AsyncMock, return_value={"summary": "Test project"}, ), patch( - "apps.integrations.github.app.github.events.ai_decide_transition", + "apps.integrations.github.services.github.events.ai_decide_transition", new_callable=AsyncMock, return_value=mock_ai_decision, ), # Proposal storage must be mocked — no DB in unit tests. 
patch( - "apps.integrations.github.app.github.ai_logic.store_proposal", + "apps.integrations.github.services.github.ai_logic.store_proposal", new_callable=AsyncMock, return_value="proposal-id-123", ), @@ -110,24 +110,24 @@ async def test_build_actions_push_event_ai_decides_comment_only(): } with ( - patch("apps.integrations.github.app.github.events.get_settings") as mock_settings, + patch("apps.integrations.github.services.github.events.get_settings") as mock_settings, patch( - "apps.integrations.github.app.github.events.get_user_for_repo", + "apps.integrations.github.services.github.events.get_user_for_repo", new_callable=AsyncMock, return_value=_mock_user(), ), patch( - "apps.integrations.github.app.github.events.get_issue_context", + "apps.integrations.github.services.github.events.get_issue_context", new_callable=AsyncMock, return_value=mock_jira_context, ), patch( - "apps.integrations.github.app.github.events.get_project_context", + "apps.integrations.github.services.github.events.get_project_context", new_callable=AsyncMock, return_value={"summary": "Test"}, ), patch( - "apps.integrations.github.app.github.events.ai_decide_transition", + "apps.integrations.github.services.github.events.ai_decide_transition", new_callable=AsyncMock, return_value=mock_ai_decision, ), @@ -157,29 +157,30 @@ async def test_build_actions_pr_merged_with_ci_failure(): keys = ["SCRUM-123"] with ( - patch("apps.integrations.github.app.github.events.get_settings") as mock_settings, + patch("apps.integrations.github.services.github.events.get_settings") as mock_settings, patch( - "apps.integrations.github.app.github.events.get_user_for_repo", + "apps.integrations.github.services.github.events.get_user_for_repo", new_callable=AsyncMock, return_value=_mock_user(), ), patch( - "apps.integrations.github.app.github.events.get_project_context", + "apps.integrations.github.services.github.events.get_project_context", new_callable=AsyncMock, return_value={"summary": "test"}, ), patch( - 
"apps.integrations.github.app.github.events.get_issue_context", + "apps.integrations.github.services.github.events.get_issue_context", new_callable=AsyncMock, return_value={}, ), patch( - "apps.integrations.github.app.github.events.get_ci_status", + "apps.integrations.github.services.github.events.get_ci_status", new_callable=AsyncMock, return_value=("failed", "Tests failed"), ), patch( - "apps.integrations.github.app.github.events.should_block_transition", return_value=True + "apps.integrations.github.services.github.events.should_block_transition", + return_value=True, ), ): mock_settings.return_value.enable_ci_awareness_bool = True @@ -222,24 +223,24 @@ async def test_build_actions_pr_opened_ai_decides_review(): } with ( - patch("apps.integrations.github.app.github.events.get_settings") as mock_settings, + patch("apps.integrations.github.services.github.events.get_settings") as mock_settings, patch( - "apps.integrations.github.app.github.events.get_user_for_repo", + "apps.integrations.github.services.github.events.get_user_for_repo", new_callable=AsyncMock, return_value=_mock_user(), ), patch( - "apps.integrations.github.app.github.events.get_issue_context", + "apps.integrations.github.services.github.events.get_issue_context", new_callable=AsyncMock, return_value=mock_jira_context, ), patch( - "apps.integrations.github.app.github.events.get_project_context", + "apps.integrations.github.services.github.events.get_project_context", new_callable=AsyncMock, return_value={"summary": "test"}, ), patch( - "apps.integrations.github.app.github.events.ai_decide_transition", + "apps.integrations.github.services.github.events.ai_decide_transition", new_callable=AsyncMock, return_value=mock_ai_decision, ), diff --git a/apps/integrations/github/tests/test_interactive_flow.py b/apps/integrations/github/tests/test_interactive_flow.py index ef9a043..8f243ba 100644 --- a/apps/integrations/github/tests/test_interactive_flow.py +++ 
b/apps/integrations/github/tests/test_interactive_flow.py @@ -2,7 +2,7 @@ import pytest -from apps.integrations.github.app.github.handler import _handle_ai_decision +from apps.integrations.github.services.github.handler import _handle_ai_decision @pytest.mark.asyncio @@ -25,10 +25,11 @@ async def test_interactive_flow_proposal_creation(): # We mock store_proposal and publish_new_proposal to verify they are called with ( patch( - "apps.integrations.github.app.github.ai_logic.store_proposal", new_callable=AsyncMock + "apps.integrations.github.services.github.ai_logic.store_proposal", + new_callable=AsyncMock, ) as mock_store, patch( - "apps.integrations.github.app.github.ai_logic.get_repo_name", + "apps.integrations.github.services.github.ai_logic.get_repo_name", return_value="owner/repo", ), ): @@ -74,7 +75,7 @@ async def test_interactive_flow_no_done_transition(): link = "http://link" with patch( - "apps.integrations.github.app.github.ai_logic.store_proposal", new_callable=AsyncMock + "apps.integrations.github.services.github.ai_logic.store_proposal", new_callable=AsyncMock ) as mock_store: actions = await _handle_ai_decision( decision, issue_key, user_email, payload, link, "default" @@ -105,7 +106,7 @@ async def test_interactive_flow_no_email(): link = "http://link" with patch( - "apps.integrations.github.app.github.ai_logic.store_proposal", new_callable=AsyncMock + "apps.integrations.github.services.github.ai_logic.store_proposal", new_callable=AsyncMock ) as mock_store: actions = await _handle_ai_decision( decision, issue_key, user_email, payload, link, "default" diff --git a/apps/integrations/github/tests/test_jira_context.py b/apps/integrations/github/tests/test_jira_context.py index 4707857..ac17c2c 100644 --- a/apps/integrations/github/tests/test_jira_context.py +++ b/apps/integrations/github/tests/test_jira_context.py @@ -4,7 +4,7 @@ import pytest -from apps.integrations.github.app.jira.context import ( +from apps.integrations.github.services.jira.context 
import ( get_issue_context, get_project_context, ) @@ -30,7 +30,7 @@ async def test_get_issue_context_success(): mock_jira_client.get_status_transitions = AsyncMock(return_value=mock_transitions_response) with patch( - "apps.integrations.github.app.jira.context._get_system_jira_client", + "apps.integrations.github.services.jira.context._get_system_jira_client", return_value=mock_jira_client, ): result = await get_issue_context("SCRUM-123") @@ -46,7 +46,7 @@ async def test_get_issue_context_success(): async def test_get_issue_context_failure_fallback(): """Test fallback when Jira service is unavailable.""" with patch( - "apps.integrations.github.app.jira.context._get_system_jira_client", return_value=None + "apps.integrations.github.services.jira.context._get_system_jira_client", return_value=None ): result = await get_issue_context("SCRUM-123") @@ -65,7 +65,7 @@ async def test_get_project_context_success(): mock_jira_client.get_project_summary_text = AsyncMock(return_value=mock_summary) with patch( - "apps.integrations.github.app.jira.context._get_system_jira_client", + "apps.integrations.github.services.jira.context._get_system_jira_client", return_value=mock_jira_client, ): result = await get_project_context() @@ -81,7 +81,7 @@ async def test_get_project_context_empty(): mock_jira_client.get_project_summary_text = AsyncMock(return_value="") with patch( - "apps.integrations.github.app.jira.context._get_system_jira_client", + "apps.integrations.github.services.jira.context._get_system_jira_client", return_value=mock_jira_client, ): result = await get_project_context() diff --git a/apps/integrations/github/tests/test_parser.py b/apps/integrations/github/tests/test_parser.py index ddb9990..f626934 100644 --- a/apps/integrations/github/tests/test_parser.py +++ b/apps/integrations/github/tests/test_parser.py @@ -1,6 +1,6 @@ """Unit tests for Jira key extraction from GitHub payloads.""" -from apps.integrations.github.app.github.parser import ( +from 
apps.integrations.github.services.github.parser import ( extract_jira_keys, extract_keys_for_event, extract_keys_from_pull_request, diff --git a/apps/integrations/github/tests/test_proposal_store.py b/apps/integrations/github/tests/test_proposal_store.py index a7dbdb6..1fac489 100644 --- a/apps/integrations/github/tests/test_proposal_store.py +++ b/apps/integrations/github/tests/test_proposal_store.py @@ -12,7 +12,7 @@ class TestProposalStore: @pytest.mark.asyncio async def test_store_proposal(self): """Test storing a proposal in PostgreSQL and notifying Redis.""" - from apps.integrations.github.app.store.proposal_store import store_proposal + from apps.integrations.github.services.store.proposal_store import store_proposal mock_redis = AsyncMock() mock_redis.publish = AsyncMock() @@ -21,18 +21,18 @@ async def test_store_proposal(self): # Mock DB helpers and service with ( patch( - "apps.integrations.github.app.store.proposal_store._get_user_id_by_email", + "apps.integrations.github.services.store.proposal_store._get_user_id_by_email", return_value=123, ), patch( - "apps.integrations.github.app.store.proposal_store._get_user_project_key", + "apps.integrations.github.services.store.proposal_store._get_user_project_key", return_value="PROJ", ), patch( - "apps.integrations.github.app.store.proposal_store.store_proposal_in_db", + "apps.integrations.github.services.store.proposal_store.store_proposal_in_db", return_value="test-uuid", ), - patch("apps.integrations.github.app.store.proposal_store.RedisManager") as mock_rm, + patch("apps.integrations.github.services.store.proposal_store.RedisManager") as mock_rm, ): mock_rm.get_client.return_value = mock_redis @@ -50,7 +50,7 @@ async def test_store_proposal(self): @pytest.mark.asyncio async def test_get_proposal(self): """Test retrieving a proposal from PostgreSQL.""" - from apps.integrations.github.app.store.proposal_store import get_proposal + from apps.integrations.github.services.store.proposal_store import 
get_proposal mock_proposal = MagicMock() mock_proposal.id = "test-id" @@ -62,10 +62,10 @@ async def test_get_proposal(self): with ( patch( - "apps.integrations.github.app.store.proposal_store.get_proposal_by_id", + "apps.integrations.github.services.store.proposal_store.get_proposal_by_id", return_value=mock_proposal, ), - patch("apps.integrations.github.app.store.proposal_store.AsyncSessionLocal"), + patch("apps.integrations.github.services.store.proposal_store.AsyncSessionLocal"), ): result = await get_proposal("test-id") @@ -77,14 +77,14 @@ async def test_get_proposal(self): @pytest.mark.asyncio async def test_get_proposal_not_found(self): """Test getting non-existent proposal.""" - from apps.integrations.github.app.store.proposal_store import get_proposal + from apps.integrations.github.services.store.proposal_store import get_proposal with ( patch( - "apps.integrations.github.app.store.proposal_store.get_proposal_by_id", + "apps.integrations.github.services.store.proposal_store.get_proposal_by_id", return_value=None, ), - patch("apps.integrations.github.app.store.proposal_store.AsyncSessionLocal"), + patch("apps.integrations.github.services.store.proposal_store.AsyncSessionLocal"), ): result = await get_proposal("nonexistent") assert result is None @@ -92,7 +92,7 @@ async def test_get_proposal_not_found(self): @pytest.mark.asyncio async def test_get_user_proposals(self): """Test getting all proposals for a user.""" - from apps.integrations.github.app.store.proposal_store import get_user_proposals + from apps.integrations.github.services.store.proposal_store import get_user_proposals mock_proposal = MagicMock() mock_proposal.id = "id1" @@ -103,14 +103,14 @@ async def test_get_user_proposals(self): with ( patch( - "apps.integrations.github.app.store.proposal_store._get_user_id_by_email", + "apps.integrations.github.services.store.proposal_store._get_user_id_by_email", return_value=123, ), patch( - 
"apps.integrations.github.app.store.proposal_store.get_proposals_for_user", + "apps.integrations.github.services.store.proposal_store.get_proposals_for_user", return_value=[mock_proposal], ), - patch("apps.integrations.github.app.store.proposal_store.AsyncSessionLocal"), + patch("apps.integrations.github.services.store.proposal_store.AsyncSessionLocal"), ): result = await get_user_proposals("user@example.com") @@ -121,14 +121,14 @@ async def test_get_user_proposals(self): @pytest.mark.asyncio async def test_update_proposal_status(self): """Test updating proposal status.""" - from apps.integrations.github.app.store.proposal_store import update_proposal_status + from apps.integrations.github.services.store.proposal_store import update_proposal_status with ( patch( - "apps.integrations.github.app.store.proposal_store.update_db_status", + "apps.integrations.github.services.store.proposal_store.update_db_status", return_value=MagicMock(), ), - patch("apps.integrations.github.app.store.proposal_store.AsyncSessionLocal"), + patch("apps.integrations.github.services.store.proposal_store.AsyncSessionLocal"), ): result = await update_proposal_status("test-id", "confirmed") assert result is True @@ -136,12 +136,14 @@ async def test_update_proposal_status(self): @pytest.mark.asyncio async def test_publish_new_proposal(self): """Test publishing proposal notification.""" - from apps.integrations.github.app.store.proposal_store import publish_new_proposal + from apps.integrations.github.services.store.proposal_store import publish_new_proposal mock_redis = AsyncMock() mock_redis.publish = AsyncMock() - with patch("apps.integrations.github.app.store.proposal_store.RedisManager") as mock_rm: + with patch( + "apps.integrations.github.services.store.proposal_store.RedisManager" + ) as mock_rm: mock_rm.get_client.return_value = mock_redis await publish_new_proposal("user@example.com", "test-proposal-id") diff --git a/apps/integrations/github/tests/test_proposal_store_db.py 
b/apps/integrations/github/tests/test_proposal_store_db.py index e349503..8feb18c 100644 --- a/apps/integrations/github/tests/test_proposal_store_db.py +++ b/apps/integrations/github/tests/test_proposal_store_db.py @@ -11,7 +11,7 @@ class TestStoreProposalInDb: @pytest.mark.asyncio async def test_store_proposal_in_db_success(self): """Test successful storage in DB.""" - from apps.integrations.github.app.store.proposal_store import store_proposal_in_db + from apps.integrations.github.services.store.proposal_store import store_proposal_in_db mock_proposal = MagicMock() mock_proposal.id = "db-prop-123" @@ -23,11 +23,11 @@ async def test_store_proposal_in_db_success(self): with ( patch( - "apps.integrations.github.app.store.proposal_store.AsyncSessionLocal", + "apps.integrations.github.services.store.proposal_store.AsyncSessionLocal", return_value=mock_cm, ), patch( - "apps.integrations.github.app.store.proposal_store.create_proposal", + "apps.integrations.github.services.store.proposal_store.create_proposal", return_value=mock_proposal, ) as mock_create, ): @@ -41,7 +41,7 @@ async def test_store_proposal_in_db_success(self): @pytest.mark.asyncio async def test_store_proposal_in_db_failure(self): """Test db failure propagation.""" - from apps.integrations.github.app.store.proposal_store import store_proposal_in_db + from apps.integrations.github.services.store.proposal_store import store_proposal_in_db mock_session = AsyncMock() mock_cm = MagicMock() @@ -50,11 +50,11 @@ async def test_store_proposal_in_db_failure(self): with ( patch( - "apps.integrations.github.app.store.proposal_store.AsyncSessionLocal", + "apps.integrations.github.services.store.proposal_store.AsyncSessionLocal", return_value=mock_cm, ), patch( - "apps.integrations.github.app.store.proposal_store.create_proposal", + "apps.integrations.github.services.store.proposal_store.create_proposal", side_effect=Exception("DB Error"), ), ): @@ -64,25 +64,25 @@ async def test_store_proposal_in_db_failure(self): 
@pytest.mark.asyncio async def test_call_site_swallows_db_error(self): """Test that ai_logic swallows DB errors when creating proposals.""" - from apps.integrations.github.app.github.ai_logic import create_transition_proposal + from apps.integrations.github.services.github.ai_logic import create_transition_proposal mock_redis = AsyncMock() mock_redis.publish = AsyncMock() with ( patch( - "apps.integrations.github.app.store.proposal_store._get_user_id_by_email", + "apps.integrations.github.services.store.proposal_store._get_user_id_by_email", return_value=1, ), patch( - "apps.integrations.github.app.store.proposal_store._get_user_project_key", + "apps.integrations.github.services.store.proposal_store._get_user_project_key", return_value="PROJ", ), patch( - "apps.integrations.github.app.store.proposal_store.store_proposal_in_db", + "apps.integrations.github.services.store.proposal_store.store_proposal_in_db", side_effect=Exception("DB DOWN"), ), - patch("apps.integrations.github.app.store.proposal_store.RedisManager") as mock_rm, + patch("apps.integrations.github.services.store.proposal_store.RedisManager") as mock_rm, ): mock_rm.get_client.return_value = mock_redis diff --git a/apps/integrations/github/tests/test_risk_analysis.py b/apps/integrations/github/tests/test_risk_analysis.py index 4c72760..326cfa0 100644 --- a/apps/integrations/github/tests/test_risk_analysis.py +++ b/apps/integrations/github/tests/test_risk_analysis.py @@ -5,7 +5,7 @@ import pytest -from apps.integrations.github.app.ai.risk_analysis import ( +from apps.integrations.github.services.ai.risk_analysis import ( _build_risk_context, analyze_pr_risk, ) @@ -14,7 +14,7 @@ @pytest.mark.asyncio async def test_analyze_pr_risk_disabled(): """Returns None when enable_ai_risk_check_bool is False.""" - with patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings: + with patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings: 
mock_settings.return_value.enable_ai_risk_check_bool = False mock_settings.return_value.gemini_api_key = "fake_key" @@ -26,7 +26,7 @@ async def test_analyze_pr_risk_disabled(): @pytest.mark.asyncio async def test_analyze_pr_risk_no_api_key(): """Returns None when gemini_api_key is falsy.""" - with patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings: + with patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings: mock_settings.return_value.enable_ai_risk_check_bool = True mock_settings.return_value.gemini_api_key = None @@ -60,8 +60,10 @@ async def test_analyze_pr_risk_auth_files_high_risk(): } with ( - patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.risk_analysis.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings, + patch( + "apps.integrations.github.services.ai.risk_analysis.genai.Client" + ) as mock_client_class, ): mock_settings.return_value.enable_ai_risk_check_bool = True mock_settings.return_value.gemini_api_key = "fake_key" @@ -104,8 +106,10 @@ async def test_analyze_pr_risk_readme_low_risk(): } with ( - patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.risk_analysis.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings, + patch( + "apps.integrations.github.services.ai.risk_analysis.genai.Client" + ) as mock_client_class, ): mock_settings.return_value.enable_ai_risk_check_bool = True mock_settings.return_value.gemini_api_key = "fake_key" @@ -132,8 +136,10 @@ async def test_analyze_pr_risk_gemini_exception(): } with ( - patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.risk_analysis.genai.Client") as 
mock_client_class, + patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings, + patch( + "apps.integrations.github.services.ai.risk_analysis.genai.Client" + ) as mock_client_class, ): mock_settings.return_value.enable_ai_risk_check_bool = True mock_settings.return_value.gemini_api_key = "fake_key" @@ -155,8 +161,10 @@ async def test_analyze_pr_risk_invalid_json(): } with ( - patch("apps.integrations.github.app.ai.risk_analysis.get_settings") as mock_settings, - patch("apps.integrations.github.app.ai.risk_analysis.genai.Client") as mock_client_class, + patch("apps.integrations.github.services.ai.risk_analysis.get_settings") as mock_settings, + patch( + "apps.integrations.github.services.ai.risk_analysis.genai.Client" + ) as mock_client_class, ): mock_settings.return_value.enable_ai_risk_check_bool = True mock_settings.return_value.gemini_api_key = "fake_key" diff --git a/apps/integrations/github/tests/test_user_mapping.py b/apps/integrations/github/tests/test_user_mapping.py index b3a3976..ce931ff 100644 --- a/apps/integrations/github/tests/test_user_mapping.py +++ b/apps/integrations/github/tests/test_user_mapping.py @@ -10,7 +10,7 @@ class TestUserMapping: def test_set_and_get_user_mapping(self): """Test setting and retrieving user mappings.""" - from apps.integrations.github.app.utils.user_mapping import ( + from apps.integrations.github.utils.user_mapping import ( get_user_mapping, set_user_mapping, ) @@ -28,14 +28,14 @@ def test_set_and_get_user_mapping(self): def test_get_user_mapping_no_mapping(self): """Test getting mapping when none exists - returns same email.""" - from apps.integrations.github.app.utils.user_mapping import get_user_mapping + from apps.integrations.github.utils.user_mapping import get_user_mapping result = get_user_mapping("new_user@example.com") assert result == "new_user@example.com" def test_extract_email_from_payload_pusher(self): """Test extracting email from pusher in payload.""" - from 
apps.integrations.github.app.utils.user_mapping import extract_email_from_payload + from apps.integrations.github.utils.user_mapping import extract_email_from_payload payload = {"pusher": {"name": "user", "email": "pusher@example.com"}} @@ -44,7 +44,7 @@ def test_extract_email_from_payload_pusher(self): def test_extract_email_from_payload_commits(self): """Test extracting email from commits when pusher has no email.""" - from apps.integrations.github.app.utils.user_mapping import extract_email_from_payload + from apps.integrations.github.utils.user_mapping import extract_email_from_payload payload = { "pusher": {"name": "user"}, @@ -56,7 +56,7 @@ def test_extract_email_from_payload_commits(self): def test_extract_username_from_payload_sender(self): """Test extracting username from sender.""" - from apps.integrations.github.app.utils.user_mapping import extract_username_from_payload + from apps.integrations.github.utils.user_mapping import extract_username_from_payload payload = {"sender": {"login": "github_user"}} @@ -65,7 +65,7 @@ def test_extract_username_from_payload_sender(self): def test_extract_username_from_payload_pusher(self): """Test extracting username from pusher when no sender.""" - from apps.integrations.github.app.utils.user_mapping import extract_username_from_payload + from apps.integrations.github.utils.user_mapping import extract_username_from_payload payload = {"pusher": {"name": "pusher_user"}} @@ -75,7 +75,7 @@ def test_extract_username_from_payload_pusher(self): @pytest.mark.asyncio async def test_get_jira_email_for_github_user_with_email(self): """Test getting Jira email when GitHub email is provided.""" - from apps.integrations.github.app.utils.user_mapping import get_jira_email_for_github_user + from apps.integrations.github.utils.user_mapping import get_jira_email_for_github_user result = await get_jira_email_for_github_user("testuser", github_email="test@example.com") @@ -84,15 +84,15 @@ async def 
test_get_jira_email_for_github_user_with_email(self): @pytest.mark.asyncio async def test_get_jira_email_for_github_user_fetch_from_api(self): """Test getting Jira email by fetching from GitHub API.""" - from apps.integrations.github.app.utils.user_mapping import get_jira_email_for_github_user + from apps.integrations.github.utils.user_mapping import get_jira_email_for_github_user mock_client = MagicMock() mock_client.get_user_email = AsyncMock(return_value="api_email@example.com") with ( - patch("apps.integrations.github.app.utils.user_mapping.get_settings") as mock_settings, + patch("apps.integrations.github.utils.user_mapping.get_settings") as mock_settings, patch( - "apps.integrations.github.app.utils.user_mapping.GitHubClient", + "apps.integrations.github.utils.user_mapping.GitHubClient", return_value=mock_client, ), ): diff --git a/apps/integrations/github/tests/test_verify.py b/apps/integrations/github/tests/test_verify.py index 8083248..83c193c 100644 --- a/apps/integrations/github/tests/test_verify.py +++ b/apps/integrations/github/tests/test_verify.py @@ -3,7 +3,7 @@ import hashlib import hmac -from apps.integrations.github.app.github.verify import verify_signature +from apps.integrations.github.services.github.verify import verify_signature def _hmac_sha256_hex(secret: str, body: bytes) -> str: diff --git a/apps/integrations/github/tests/test_webhook_endpoint.py b/apps/integrations/github/tests/test_webhook_endpoint.py index 1f8c4c4..b449490 100644 --- a/apps/integrations/github/tests/test_webhook_endpoint.py +++ b/apps/integrations/github/tests/test_webhook_endpoint.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from apps.integrations.github.app.routes.webhook import router +from apps.integrations.github.routes.webhook import router # Build a minimal test app using only the webhook router. 
app = FastAPI() @@ -26,13 +26,13 @@ def test_webhook_valid_push_event(): """Valid signature + push event + not a duplicate → 200 accepted.""" with ( - patch("apps.integrations.github.app.routes.webhook.get_settings") as mock_settings, - patch("apps.integrations.github.app.routes.webhook.verify_signature", return_value=True), + patch("apps.integrations.github.routes.webhook.get_settings") as mock_settings, + patch("apps.integrations.github.routes.webhook.verify_signature", return_value=True), patch( - "apps.integrations.github.app.routes.webhook.mark_processed", + "apps.integrations.github.routes.webhook.mark_processed", new=AsyncMock(return_value=True), ), - patch("apps.integrations.github.app.routes.webhook.handle_event", new=AsyncMock()), + patch("apps.integrations.github.routes.webhook.handle_event", new=AsyncMock()), ): mock_settings.return_value.github_webhook_secret = "secret" @@ -54,8 +54,8 @@ def test_webhook_valid_push_event(): def test_webhook_invalid_signature(): """Failing signature verification → 403 Forbidden.""" with ( - patch("apps.integrations.github.app.routes.webhook.get_settings") as mock_settings, - patch("apps.integrations.github.app.routes.webhook.verify_signature", return_value=False), + patch("apps.integrations.github.routes.webhook.get_settings") as mock_settings, + patch("apps.integrations.github.routes.webhook.verify_signature", return_value=False), ): mock_settings.return_value.github_webhook_secret = "secret" @@ -76,13 +76,13 @@ def test_webhook_invalid_signature(): def test_webhook_duplicate_delivery(): """A delivery ID that was already processed → 200 ignored/duplicate.""" with ( - patch("apps.integrations.github.app.routes.webhook.get_settings") as mock_settings, - patch("apps.integrations.github.app.routes.webhook.verify_signature", return_value=True), + patch("apps.integrations.github.routes.webhook.get_settings") as mock_settings, + patch("apps.integrations.github.routes.webhook.verify_signature", return_value=True), patch( - 
"apps.integrations.github.app.routes.webhook.mark_processed", + "apps.integrations.github.routes.webhook.mark_processed", new=AsyncMock(return_value=False), ), - patch("apps.integrations.github.app.routes.webhook.handle_event", new=AsyncMock()), + patch("apps.integrations.github.routes.webhook.handle_event", new=AsyncMock()), ): mock_settings.return_value.github_webhook_secret = "secret" @@ -106,12 +106,12 @@ def test_webhook_ping_event(): ping_payload = {"zen": "Keep it logically awesome.", "hook_id": 123} with ( - patch("apps.integrations.github.app.routes.webhook.get_settings") as mock_settings, + patch("apps.integrations.github.routes.webhook.get_settings") as mock_settings, patch( - "apps.integrations.github.app.routes.webhook.mark_processed", + "apps.integrations.github.routes.webhook.mark_processed", new=AsyncMock(return_value=True), ), - patch("apps.integrations.github.app.routes.webhook.handle_event", new=AsyncMock()), + patch("apps.integrations.github.routes.webhook.handle_event", new=AsyncMock()), ): mock_settings.return_value.github_webhook_secret = None @@ -132,13 +132,13 @@ def test_webhook_ping_event(): def test_webhook_no_webhook_secret_bypasses_signature(): """When github_webhook_secret is None/empty, signature is not checked → 200 accepted.""" with ( - patch("apps.integrations.github.app.routes.webhook.get_settings") as mock_settings, - patch("apps.integrations.github.app.routes.webhook.verify_signature") as mock_verify, + patch("apps.integrations.github.routes.webhook.get_settings") as mock_settings, + patch("apps.integrations.github.routes.webhook.verify_signature") as mock_verify, patch( - "apps.integrations.github.app.routes.webhook.mark_processed", + "apps.integrations.github.routes.webhook.mark_processed", new=AsyncMock(return_value=True), ), - patch("apps.integrations.github.app.routes.webhook.handle_event", new=AsyncMock()), + patch("apps.integrations.github.routes.webhook.handle_event", new=AsyncMock()), ): 
mock_settings.return_value.github_webhook_secret = None diff --git a/apps/integrations/github/utils/__init__.py b/apps/integrations/github/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/github/app/utils/logging.py b/apps/integrations/github/utils/logging.py similarity index 100% rename from apps/integrations/github/app/utils/logging.py rename to apps/integrations/github/utils/logging.py diff --git a/apps/integrations/github/app/utils/user_mapping.py b/apps/integrations/github/utils/user_mapping.py similarity index 96% rename from apps/integrations/github/app/utils/user_mapping.py rename to apps/integrations/github/utils/user_mapping.py index f578589..b70cbfe 100644 --- a/apps/integrations/github/app/utils/user_mapping.py +++ b/apps/integrations/github/utils/user_mapping.py @@ -6,8 +6,8 @@ import logging from typing import Optional -from apps.integrations.github.app.config import get_settings -from apps.integrations.github.app.github.client import GitHubClient +from apps.integrations.github.config import get_settings +from apps.integrations.github.services.github.client import GitHubClient logger = logging.getLogger(__name__) diff --git a/apps/integrations/jira/config.py b/apps/integrations/jira/config.py index 7bafe91..21b4f3b 100644 --- a/apps/integrations/jira/config.py +++ b/apps/integrations/jira/config.py @@ -36,8 +36,12 @@ GITHUB_USERINFO_URL = "https://api.github.com/user" -# Scopes required for the app -SCOPES = "read:jira-work write:jira-work read:jira-user read:me offline_access read:board-scope:jira-software read:sprint:jira-software write:sprint:jira-software read:project:jira" +# Scopes required for the app (includes Trello scopes — same Atlassian OAuth token covers both) +SCOPES = ( + "read:jira-work write:jira-work read:jira-user read:me offline_access" + " read:board-scope:jira-software read:sprint:jira-software write:sprint:jira-software read:project:jira" + " read:trello-user read:board:trello 
write:board:trello manage:board:trello" +) # Auth URL AUTH_URL = "https://auth.atlassian.com/authorize" diff --git a/apps/integrations/jira/routes/api_routes/__init__.py b/apps/integrations/jira/routes/api_routes/__init__.py index 4d4418a..48bcdc7 100644 --- a/apps/integrations/jira/routes/api_routes/__init__.py +++ b/apps/integrations/jira/routes/api_routes/__init__.py @@ -9,6 +9,7 @@ metadata, projects, tasks, + trello, ) router = APIRouter() @@ -21,3 +22,4 @@ router.include_router(dashboard.router, tags=["dashboard"]) router.include_router(metadata.router, prefix="/metadata", tags=["metadata"]) router.include_router(debug.router, tags=["debug"]) +router.include_router(trello.router, tags=["trello"]) diff --git a/apps/integrations/jira/routes/api_routes/trello.py b/apps/integrations/jira/routes/api_routes/trello.py new file mode 100644 index 0000000..7133658 --- /dev/null +++ b/apps/integrations/jira/routes/api_routes/trello.py @@ -0,0 +1,375 @@ +import logging +from datetime import UTC, datetime, timedelta + +from fastapi import APIRouter, Depends, Header, HTTPException +from jose import jwt +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from apps.integrations.jira.routes.api_routes.shared import ALGORITHM, get_secret_key +from apps.integrations.jira.trello_client import TrelloAuthError, TrelloClient +from libs.common.database import get_async_db +from libs.common.models import User + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# List name keywords → board column category +_TODO_KEYWORDS = {"backlog", "to do", "todo", "to-do", "open", "new", "queue"} +_IN_PROGRESS_KEYWORDS = { + "doing", + "in progress", + "in-progress", + "wip", + "active", + "working", + "development", +} +_DONE_KEYWORDS = {"done", "complete", "completed", "finished", "closed", "released", "deployed"} + + +def _classify_list(list_name: str, position: int, total: 
int) -> str: + """Map a Trello list name to a board column category.""" + name = list_name.lower().strip() + if name in _DONE_KEYWORDS or any(kw in name for kw in _DONE_KEYWORDS): + return "done" + if name in _IN_PROGRESS_KEYWORDS or any(kw in name for kw in _IN_PROGRESS_KEYWORDS): + return "inProgress" + if name in _TODO_KEYWORDS or any(kw in name for kw in _TODO_KEYWORDS): + return "todo" + # Positional fallback: first → todo, last → done, middle → inProgress + if total == 1: + return "todo" + if position == 0: + return "todo" + if position == total - 1: + return "done" + return "inProgress" + + +def _card_to_board_issue(card: dict, list_name: str, status_category: str) -> dict: + """Map a Trello card to the BoardIssue format used by the frontend.""" + members = card.get("members") or [] + assignee = members[0].get("fullName", "Unassigned") if members else "Unassigned" + labels = card.get("labels") or [] + # Map first label color to a priority approximation + label_priority_map = {"red": "High", "orange": "High", "yellow": "Medium", "green": "Low"} + priority = "Medium" + if labels: + priority = label_priority_map.get(labels[0].get("color", ""), "Medium") + + return { + "key": card["id"], + "summary": card.get("name", ""), + "status": list_name, + "statusCategory": status_category, + "assignee": assignee, + "assigneeAvatar": members[0].get("avatarUrl") if members else None, + "priority": priority, + "issueType": "Card", + "duedate": card.get("due"), + "description": card.get("desc", ""), + "labels": [lbl.get("name", "") for lbl in labels], + "updated": card.get("dateLastActivity", ""), + "url": card.get("url"), + } + + +async def _get_trello_client_core(authorization: str, db: AsyncSession) -> TrelloClient: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Unauthorized - No bearer token provided") + + token = authorization.split(" ")[1] + try: + secret_key = get_secret_key() + payload = 
jwt.decode(token, secret_key, algorithms=[ALGORITHM]) + email = payload.get("sub") + if not email: + raise HTTPException(status_code=401, detail="Invalid token") + + result = await db.execute( + select(User).where(User.email == email).options(selectinload(User.jira_token)) + ) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=401, detail="User not found") + + if not user.jira_token: + raise HTTPException(status_code=404, detail="Jira not connected") + + if not user.jira_token.trello_workspace_id: + raise HTTPException( + status_code=404, + detail="Trello not connected. Reconnect Jira to grant Trello access.", + ) + + return TrelloClient( + access_token=user.jira_token.access_token, + workspace_id=user.jira_token.trello_workspace_id, + ) + except HTTPException: + raise + except jwt.JWTError as e: + raise HTTPException(status_code=401, detail="Could not validate credentials") from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +async def get_trello_client( + authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) +) -> TrelloClient: + return await _get_trello_client_core(authorization, db) + + +async def _get_current_user(authorization: str, db: AsyncSession) -> User: + token = authorization.split(" ")[1] + secret_key = get_secret_key() + payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) + email = payload.get("sub") + result = await db.execute( + select(User).where(User.email == email).options(selectinload(User.jira_token)) + ) + return result.scalar_one_or_none() + + +# ─── Pydantic models ────────────────────────────────────────────────────────── + + +class CreateCardRequest(BaseModel): + list_id: str + name: str + desc: str = "" + due: str | None = None + member_ids: list[str] | None = None + + +class UpdateCardRequest(BaseModel): + name: str | None = None + desc: str | None = None + due: str | None = None + idList: str | None = None + + +class 
MoveCardRequest(BaseModel): + list_id: str + + +# ─── Endpoints ──────────────────────────────────────────────────────────────── + + +@router.get("/trello/boards") +async def list_trello_boards(client: TrelloClient = Depends(get_trello_client)): + """Return all open boards for the authenticated Trello user.""" + try: + boards = await client.get_boards() + return {"boards": [{"id": b["id"], "name": b["name"]} for b in boards]} + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + logger.exception("Error fetching Trello boards") + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/trello/boards/{board_id}/select") +async def select_trello_board( + board_id: str, + authorization: str = Header(None), + db: AsyncSession = Depends(get_async_db), +): + """Persist the user's selected Trello board.""" + user = await _get_current_user(authorization, db) + if not user or not user.jira_token: + raise HTTPException(status_code=404, detail="Trello not connected") + user.jira_token.trello_board_id = board_id + await db.commit() + return {"status": "ok", "board_id": board_id} + + +@router.get("/trello/stats") +async def get_trello_stats( + board_id: str | None = None, + authorization: str = Header(None), + db: AsyncSession = Depends(get_async_db), + client: TrelloClient = Depends(get_trello_client), +): + """Return dashboard stats in the same format as /api/jira/stats.""" + try: + user = await _get_current_user(authorization, db) + resolved_board_id = board_id or ( + user.jira_token.trello_board_id if user and user.jira_token else None + ) + if not resolved_board_id: + raise HTTPException(status_code=400, detail="No board selected") + + board_data = await client.get_board_with_cards(resolved_board_id) + lists = board_data.get("lists", []) + cards = board_data.get("cards", []) + + # Build list_id → (name, category) mapping + total = len(lists) + list_meta: dict[str, dict] = {} + for i, lst 
in enumerate(lists): + cat = _classify_list(lst["name"], i, total) + list_meta[lst["id"]] = {"name": lst["name"], "category": cat} + + board: dict = {"todo": [], "inProgress": [], "done": []} + now = datetime.now(UTC) + week_ago = now - timedelta(days=7) + overdue_count = 0 + high_priority_count = 0 + + for card in cards: + if card.get("closed"): + continue + meta = list_meta.get(card.get("idList", ""), {"name": "Unknown", "category": "todo"}) + issue = _card_to_board_issue(card, meta["name"], meta["category"]) + board[meta["category"]].append(issue) + + if issue["priority"] == "High": + high_priority_count += 1 + due = card.get("due") + if due and meta["category"] != "done": + try: + due_dt = datetime.fromisoformat(due.rstrip("Z")).replace(tzinfo=UTC) + if due_dt < now: + overdue_count += 1 + except ValueError: + pass + + active_count = len(board["inProgress"]) + done_this_week = sum( + 1 + for c in cards + if list_meta.get(c.get("idList", ""), {}).get("category") == "done" + and c.get("dateLastActivity") + and datetime.fromisoformat(c["dateLastActivity"].rstrip("Z")).replace(tzinfo=UTC) + >= week_ago + ) + + # Activities: last 10 modified cards + recent = sorted(cards, key=lambda c: c.get("dateLastActivity", ""), reverse=True)[:10] + activities = [ + { + "text": f"Updated card: {c['name']}", + "time": c.get("dateLastActivity", ""), + "issue": c["id"], + "status": list_meta.get(c.get("idList", ""), {}).get("name", ""), + } + for c in recent + ] + + return { + "active": str(active_count), + "completedThisWeek": str(done_this_week), + "overdue": str(overdue_count), + "highPriority": str(high_priority_count), + "board": board, + "filters": { + "priorities": ["High", "Medium", "Low"], + "types": ["Card"], + }, + "processed": str(active_count), + "extracted": str(active_count), + "approved": str(done_this_week), + "activities": activities, + "all_activities": activities, + "distributions": { + "priorities": {"High": high_priority_count, "Medium": 0, "Low": 0}, + 
"highPriorityDone": 0, + "totalDone": len(board["done"]), + }, + } + except HTTPException: + raise + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + logger.exception("Error fetching Trello stats") + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.post("/trello/cards") +async def create_trello_card( + req: CreateCardRequest, + client: TrelloClient = Depends(get_trello_client), +): + """Create a new Trello card.""" + try: + card = await client.create_card( + list_id=req.list_id, + name=req.name, + desc=req.desc, + due=req.due, + member_ids=req.member_ids, + ) + return card + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put("/trello/cards/{card_id}") +async def update_trello_card( + card_id: str, + req: UpdateCardRequest, + client: TrelloClient = Depends(get_trello_client), +): + """Update an existing Trello card.""" + try: + fields = req.model_dump(exclude_none=True) + if not fields: + raise HTTPException(status_code=400, detail="No fields to update") + return await client.update_card(card_id, **fields) + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.delete("/trello/cards/{card_id}") +async def delete_trello_card( + card_id: str, + client: TrelloClient = Depends(get_trello_client), +): + """Delete a Trello card permanently.""" + try: + await client.delete_card(card_id) + return {"status": "ok"} + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.put("/trello/cards/{card_id}/move") +async def move_trello_card( + card_id: str, + req: MoveCardRequest, + client: 
TrelloClient = Depends(get_trello_client), +): + """Move a card to a different list (status change).""" + try: + return await client.move_card(card_id, req.list_id) + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + +@router.get("/trello/boards/{board_id}/members") +async def get_trello_board_members( + board_id: str, + client: TrelloClient = Depends(get_trello_client), +): + """Return members of a Trello board (for assignee selection).""" + try: + members = await client.get_board_members(board_id) + return {"members": members} + except TrelloAuthError as e: + raise HTTPException(status_code=401, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/apps/integrations/jira/routes/auth_routes/oauth_jira.py b/apps/integrations/jira/routes/auth_routes/oauth_jira.py index 876767f..562fc0b 100644 --- a/apps/integrations/jira/routes/auth_routes/oauth_jira.py +++ b/apps/integrations/jira/routes/auth_routes/oauth_jira.py @@ -119,7 +119,10 @@ async def callback( if not resources: raise HTTPException(status_code=400, detail="No accessible resources found") - cloud_id = resources[0]["id"] + jira_resources = [r for r in resources if "atlassian.net" in r.get("url", "")] + trello_resources = [r for r in resources if "trello.com" in r.get("url", "")] + cloud_id = jira_resources[0]["id"] if jira_resources else resources[0]["id"] + trello_workspace_id = trello_resources[0]["id"] if trello_resources else None user_resp = await client.get( "https://api.atlassian.com/me", @@ -160,6 +163,8 @@ async def callback( token_entry.access_token = access_token token_entry.refresh_token = refresh_token token_entry.cloud_id = cloud_id + if trello_workspace_id: + token_entry.trello_workspace_id = trello_workspace_id await db.commit() frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") @@ 
-187,6 +192,8 @@ async def callback( existing_token.access_token = access_token existing_token.refresh_token = refresh_token existing_token.cloud_id = cloud_id + if trello_workspace_id: + existing_token.trello_workspace_id = trello_workspace_id await db.commit() return RedirectResponse( f"{frontend_url}/auth/callback?token={existing_jwt}&jira_connected=true" diff --git a/apps/integrations/jira/routes/auth_routes/profile.py b/apps/integrations/jira/routes/auth_routes/profile.py index 549ef3d..25a1416 100644 --- a/apps/integrations/jira/routes/auth_routes/profile.py +++ b/apps/integrations/jira/routes/auth_routes/profile.py @@ -40,6 +40,7 @@ async def get_me(db: AsyncSession = Depends(get_async_db), authorization: str = raise HTTPException(status_code=401, detail="User not found") jira_connected = user.jira_token is not None + trello_connected = bool(user.jira_token and user.jira_token.trello_workspace_id) return { "id": user.id, @@ -51,6 +52,7 @@ async def get_me(db: AsyncSession = Depends(get_async_db), authorization: str = "jiraConnected": jira_connected, "githubConnected": user.github_access_token is not None, "notionConnected": user.notion_token is not None, + "trelloConnected": trello_connected, "github_id": user.github_id, "subscription_tier": user.subscription_tier, "timezone": user.timezone or "Auto (Browser)", diff --git a/apps/integrations/jira/routes/github_repos.py b/apps/integrations/jira/routes/github_repos.py index c30e767..c842425 100644 --- a/apps/integrations/jira/routes/github_repos.py +++ b/apps/integrations/jira/routes/github_repos.py @@ -1,4 +1,3 @@ -import asyncio import os import httpx @@ -8,6 +7,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from apps.integrations.jira.services.github_integration_service import ( + confirm_installation, + create_webhook, + delete_webhook, + fetch_commits_with_stats, +) from libs.common.database import get_async_db from libs.common.models import GitHubInstallation, 
GitHubWebhook, User @@ -30,12 +35,10 @@ async def get_current_user_from_token( email = payload.get("sub") if not email: raise HTTPException(status_code=401, detail="Invalid token - missing sub") - result = await db.execute(select(User).where(User.email == email)) user = result.scalar_one_or_none() if not user: raise HTTPException(status_code=401, detail="User not found") - return user except jwt.JWTError: raise HTTPException(status_code=401, detail="Could not validate credentials") from None @@ -44,104 +47,38 @@ async def get_current_user_from_token( @router.get("/github/install") def github_install(): """Redirect to GitHub App installation page.""" - # Note: user will need to provide the app name/slug later if dynamic, - # but for now we assume it's part of the GITHUB_CLIENT_ID or known. - # Standard flow: https://github.com/apps//installations/new - # We use GITHUB_APP_SLUG from env, fallback to 'kwillo-ai' app_slug = os.getenv("GITHUB_APP_SLUG", "kwillo-ai") - url = f"https://github.com/apps/{app_slug}/installations/new" - return RedirectResponse(url) + return RedirectResponse(f"https://github.com/apps/{app_slug}/installations/new") @router.get("/github/setup") async def github_setup(installation_id: int, setup_action: str = "install"): - """ - Handle redirect from GitHub App installation. - Redirects back to frontend with the installation ID. - """ + """Handle redirect from GitHub App installation.""" frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") - # Redirect to frontend with installation info return RedirectResponse( f"{frontend_url}/github?installation_id={installation_id}&setup_action={setup_action}" ) @router.post("/github/installations/confirm") -async def confirm_installation( +async def confirm_installation_endpoint( installation_id: int, user: User = Depends(get_current_user_from_token), db: AsyncSession = Depends(get_async_db), ): - """ - Confirm and record a new GitHub App installation. 
- """ + """Confirm and record a new GitHub App installation.""" if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - - # Fetch installation details from GitHub to verify and get account info - async with httpx.AsyncClient() as client: - # We use the user's token because they just authorized/installed it. - # GitHub allows the user who installed the app to see it via their token. - resp = await client.get( - "https://api.github.com/user/installations", - headers={ - "Authorization": f"Bearer {user.github_access_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) - - if resp.status_code != 200: - raise HTTPException(status_code=400, detail="Failed to fetch installations from GitHub") - - installations = resp.json().get("installations", []) - inst_data = next((i for i in installations if i["id"] == installation_id), None) - - if not inst_data: - raise HTTPException( - status_code=404, - detail=f"Installation {installation_id} not found for this user", - ) - - account = inst_data.get("account", {}) - # Upsert into database - _existing_result = await db.execute( - select(GitHubInstallation).where(GitHubInstallation.id == installation_id) - ) - existing = _existing_result.scalar_one_or_none() - - if existing: - existing.user_id = user.id - existing.account_login = account.get("login") - existing.account_id = account.get("id") - existing.account_type = account.get("type") - existing.repository_selection = inst_data.get("repository_selection", "all") - else: - new_inst = GitHubInstallation( - id=installation_id, - user_id=user.id, - account_login=account.get("login"), - account_id=account.get("id"), - account_type=account.get("type"), - repository_selection=inst_data.get("repository_selection", "all"), - ) - db.add(new_inst) - - await db.commit() - return {"status": "success", "installation_id": installation_id} + return await confirm_installation(installation_id, user, db) @router.get("/github/repos") async def 
list_github_repos(user: User = Depends(get_current_user_from_token)): """List all repositories for the connected GitHub user.""" if not user.github_access_token: - print(f"DEBUG: list_github_repos failed - User {user.email} has no GitHub token") raise HTTPException(status_code=400, detail="GitHub not connected - missing access token") async with httpx.AsyncClient() as client: - # V3: If user has installations, we could list repos from those installations. - # But for now, we keep the user-token flow as fallback/primary for the picker. - # The key change is that the UI will now mostly drive through installations. - response = await client.get( "https://api.github.com/user/repos", params={ @@ -154,34 +91,25 @@ async def list_github_repos(user: User = Depends(get_current_user_from_token)): "Accept": "application/vnd.github.v3+json", }, ) + if response.status_code == 401: + raise HTTPException( + status_code=401, detail="GitHub token expired or revoked. Please reconnect GitHub." + ) if response.status_code != 200: - print(f"ERROR: GitHub API failed: {response.text}") - if response.status_code == 401: - raise HTTPException( - status_code=401, - detail="GitHub token expired or revoked. 
Please reconnect GitHub.", - ) - if response.status_code == 400: - print(f"ERROR: GitHub API returned 400: {response.text}") - raise HTTPException( - status_code=400, detail=f"GitHub API returned 400: {response.text}" - ) raise HTTPException( status_code=400, detail=f"Failed to fetch repositories from GitHub: {response.status_code}", ) - return response.json() @router.get("/github/orgs") async def list_github_orgs(user: User = Depends(get_current_user_from_token)): - """List GitHub organizations using memberships endpoint for better visibility.""" + """List GitHub organizations the user belongs to.""" if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") async with httpx.AsyncClient() as client: - # memberships endpoint shows more roles than /user/orgs response = await client.get( "https://api.github.com/user/memberships/orgs", headers={ @@ -191,21 +119,16 @@ async def list_github_orgs(user: User = Depends(get_current_user_from_token)): ) if response.status_code != 200: return [] - - memberships = response.json() - orgs = [] - for m in memberships: - if m.get("state") == "active": - org_data = m.get("organization", {}) - orgs.append( - { - "login": org_data.get("login"), - "avatar_url": org_data.get("avatar_url"), - "id": org_data.get("id"), - "description": org_data.get("description"), - } - ) - return orgs + return [ + { + "login": m["organization"]["login"], + "avatar_url": m["organization"]["avatar_url"], + "id": m["organization"]["id"], + "description": m["organization"].get("description"), + } + for m in response.json() + if m.get("state") == "active" + ] @router.get("/github/installations") @@ -213,11 +136,10 @@ async def list_installations( user: User = Depends(get_current_user_from_token), db: AsyncSession = Depends(get_async_db) ): """List GitHub App installations for the current user.""" - _result = await db.execute( + result = await db.execute( select(GitHubInstallation).where(GitHubInstallation.user_id == 
user.id) ) - installations = _result.scalars().all() - return installations + return result.scalars().all() @router.get("/github/installations/{installation_id}/repos") @@ -231,11 +153,6 @@ async def list_installation_repos( raise HTTPException(status_code=400, detail="GitHub not connected") async with httpx.AsyncClient() as client: - # Note: In a true GitHub App, we would use an Installation Access Token (IAT). - # However, since the user is also OAuth-connected, we can use their user token - # to list repos they can see in that installation context, OR - # use the /installation/repositories endpoint if the token allows. - response = await client.get( f"https://api.github.com/user/installations/{installation_id}/repositories", headers={ @@ -244,27 +161,17 @@ async def list_installation_repos( }, ) if response.status_code != 200: - print( - f"DEBUG: Failed to fetch installation repos for {installation_id}: {response.status_code} - {response.text}" - ) - # Fallback to org repos if it's an org installation - _inst_result = await db.execute( + inst_result = await db.execute( select(GitHubInstallation).where(GitHubInstallation.id == installation_id) ) - inst = _inst_result.scalar_one_or_none() + inst = inst_result.scalar_one_or_none() if inst and inst.account_type == "Organization": - print(f"DEBUG: Falling back to org repos for {inst.account_login}") return await list_org_repos(inst.account_login, user) - raise HTTPException( status_code=response.status_code, detail=f"Failed to fetch installation repositories: {response.text}", ) - - data = response.json() - repos = data.get("repositories", []) - print(f"DEBUG: Successfully fetched {len(repos)} repos for installation {installation_id}") - return repos + return response.json().get("repositories", []) @router.get("/github/orgs/{org}/repos") @@ -284,8 +191,7 @@ async def list_org_repos(org: str, user: User = Depends(get_current_user_from_to ) if response.status_code != 200: raise HTTPException( - status_code=400, - 
detail=f"Failed to fetch repositories for organization {org}", + status_code=400, detail=f"Failed to fetch repositories for organization {org}" ) return response.json() @@ -294,14 +200,16 @@ async def list_org_repos(org: str, user: User = Depends(get_current_user_from_to async def get_connected_repos( user: User = Depends(get_current_user_from_token), db: AsyncSession = Depends(get_async_db) ): - """Get list of repos that have active webhooks configured for this user.""" - _result = await db.execute(select(GitHubWebhook).where(GitHubWebhook.user_id == user.id)) - webhooks = _result.scalars().all() - return [{"repo_full_name": w.repo_full_name, "webhook_id": w.webhook_id} for w in webhooks] + """Get repos that have active webhooks configured for this user.""" + result = await db.execute(select(GitHubWebhook).where(GitHubWebhook.user_id == user.id)) + return [ + {"repo_full_name": w.repo_full_name, "webhook_id": w.webhook_id} + for w in result.scalars().all() + ] @router.post("/github/repos/{owner}/{repo}/webhook") -async def create_webhook( +async def create_webhook_endpoint( owner: str, repo: str, user: User = Depends(get_current_user_from_token), @@ -310,126 +218,11 @@ async def create_webhook( """Create a webhook on the specified repository.""" if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - - # Check if we already have one - _existing_result = await db.execute( - select(GitHubWebhook).where( - GitHubWebhook.user_id == user.id, - GitHubWebhook.repo_full_name == f"{owner}/{repo}", - ) - ) - existing = _existing_result.scalar_one_or_none() - - if existing: - return { - "status": "success", - "message": "Webhook already exists", - "webhook_id": existing.webhook_id, - } - - # Prepare webhook config - # GITHUB_WEBHOOK_PUBLIC_URL must be set to the publicly accessible URL of the API server. - # E.g. 
https://kwilloai.abrdns.com - url_base = os.getenv("GITHUB_WEBHOOK_PUBLIC_URL") or os.getenv("JIRA_SERVICE_PUBLIC_URL") - if not url_base or "localhost" in url_base: - raise HTTPException( - status_code=503, - detail="Webhook creation is not available in local development. " - "Set GITHUB_WEBHOOK_PUBLIC_URL to your public server URL.", - ) - # Ensure no double slashes - webhook_url = f"{url_base.rstrip('/')}/api/webhooks/github" - webhook_secret = os.getenv("GITHUB_WEBHOOK_SECRET", "") - - if not webhook_secret: - print("WARNING: GITHUB_WEBHOOK_SECRET is not set. Webhooks will not be secured.") - - payload = { - "name": "web", - "active": True, - "events": [ - "push", - "pull_request", - "issues", - "pull_request_review", - "workflow_run", - ], - "config": { - "url": webhook_url, - "content_type": "json", - "insecure_ssl": "0", # Always require SSL if possible, though local envs might fail - }, - } - - if webhook_secret: - payload["config"]["secret"] = webhook_secret - - print(f"DEBUG: Creating GitHub webhook for {owner}/{repo} at {webhook_url}") - - async with httpx.AsyncClient() as client: - response = await client.post( - f"https://api.github.com/repos/{owner}/{repo}/hooks", - json=payload, - headers={ - "Authorization": f"Bearer {user.github_access_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) - - if response.status_code == 201: - data = response.json() - # Save to database - webhook = GitHubWebhook( - user_id=user.id, repo_full_name=f"{owner}/{repo}", webhook_id=data["id"] - ) - db.add(webhook) - await db.commit() - return {"status": "success", "webhook_id": data["id"]} - else: - print(f"ERROR: GitHub webhook creation failed: {response.status_code} {response.text}") - error_msg = response.json().get("message", "Unknown error") - if "already exists" in error_msg.lower(): - # Handle case where user manually added it but app doesn't know - # Try to list hooks to find the ID and save it - return await sync_existing_webhook(owner, repo, user, db, 
webhook_url) - - raise HTTPException(status_code=400, detail=f"Failed to create webhook: {error_msg}") - - -async def sync_existing_webhook( - owner: str, repo: str, user: User, db: AsyncSession, target_url: str -): - """Helper to find an existing webhook and sync it to the database""" - async with httpx.AsyncClient() as client: - response = await client.get( - f"https://api.github.com/repos/{owner}/{repo}/hooks", - headers={ - "Authorization": f"Bearer {user.github_access_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) - if response.status_code == 200: - hooks = response.json() - for hook in hooks: - if hook.get("config", {}).get("url") == target_url: - webhook = GitHubWebhook( - user_id=user.id, - repo_full_name=f"{owner}/{repo}", - webhook_id=hook["id"], - ) - db.add(webhook) - await db.commit() - return { - "status": "success", - "message": "Synced existing webhook", - "webhook_id": hook["id"], - } - - raise HTTPException(status_code=400, detail="Webhook already exists but could not be synced") + return await create_webhook(owner, repo, user, db) @router.delete("/github/repos/{owner}/{repo}/webhook") -async def delete_webhook( +async def delete_webhook_endpoint( owner: str, repo: str, user: User = Depends(get_current_user_from_token), @@ -438,78 +231,7 @@ async def delete_webhook( """Delete the webhook from the specified repository.""" if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - - _existing_result = await db.execute( - select(GitHubWebhook).where( - GitHubWebhook.user_id == user.id, - GitHubWebhook.repo_full_name == f"{owner}/{repo}", - ) - ) - existing = _existing_result.scalar_one_or_none() - - if not existing: - return {"status": "success", "message": "Webhook not found in database"} - - print(f"DEBUG: Deleting GitHub webhook {existing.webhook_id} for {owner}/{repo}") - - async with httpx.AsyncClient() as client: - response = await client.delete( - 
f"https://api.github.com/repos/{owner}/{repo}/hooks/{existing.webhook_id}", - headers={ - "Authorization": f"Bearer {user.github_access_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) - - # 204 No Content is success for delete - if response.status_code == 204 or response.status_code == 404: - # 404 might mean user already deleted it manually, we still remove from DB - await db.delete(existing) - await db.commit() - return {"status": "success", "message": "Webhook deleted"} - else: - print(f"ERROR: GitHub webhook deletion failed: {response.text}") - raise HTTPException(status_code=400, detail="Failed to delete webhook on GitHub API") - - -async def _fetch_commit_stats( - client: httpx.AsyncClient, - owner: str, - repo: str, - sha: str, - headers: dict, -) -> dict: - """Fetch stats for a single commit (additions, deletions, files changed).""" - try: - resp = await client.get( - f"https://api.github.com/repos/{owner}/{repo}/commits/{sha}", - headers=headers, - timeout=15.0, - ) - if resp.status_code != 200: - return {"additions": 0, "deletions": 0, "filesChanged": 0, "files": []} - data = resp.json() - stats = data.get("stats", {}) - files_raw = data.get("files", []) - - files = [ - { - "path": f.get("filename"), - "additions": f.get("additions", 0), - "deletions": f.get("deletions", 0), - "patch": f.get("patch", ""), - } - for f in files_raw - ] - - return { - "additions": stats.get("additions", 0), - "deletions": stats.get("deletions", 0), - "filesChanged": len(files_raw), - "files": files, - } - except Exception: - return {"additions": 0, "deletions": 0, "filesChanged": 0, "files": []} + return await delete_webhook(owner, repo, user, db) @router.get("/github/branches/{owner}/{repo}") @@ -522,37 +244,26 @@ async def list_github_branches( if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - gh_headers = { - "Authorization": f"Bearer {user.github_access_token}", - "Accept": 
"application/vnd.github.v3+json", - } - async with httpx.AsyncClient(timeout=20.0) as client: response = await client.get( f"https://api.github.com/repos/{owner}/{repo}/branches", - headers=gh_headers, + headers={ + "Authorization": f"Bearer {user.github_access_token}", + "Accept": "application/vnd.github.v3+json", + }, ) - if response.status_code == 401: raise HTTPException( - status_code=401, - detail="GitHub token expired or revoked. Please reconnect GitHub.", + status_code=401, detail="GitHub token expired or revoked. Please reconnect GitHub." ) if response.status_code == 404: raise HTTPException( - status_code=404, - detail=f"Repository {owner}/{repo} not found or not accessible.", + status_code=404, detail=f"Repository {owner}/{repo} not found or not accessible." ) if response.status_code != 200: - raise HTTPException( - status_code=400, - detail="Failed to fetch branches from GitHub.", - ) - - branches_raw = response.json() + raise HTTPException(status_code=400, detail="Failed to fetch branches from GitHub.") return [ - {"name": branch["name"], "protected": branch.get("protected", False)} - for branch in branches_raw + {"name": b["name"], "protected": b.get("protected", False)} for b in response.json() ] @@ -567,74 +278,7 @@ async def list_github_commits( """List commits for a repository with stats (additions, deletions, files changed).""" if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - - gh_headers = { - "Authorization": f"Bearer {user.github_access_token}", - "Accept": "application/vnd.github.v3+json", - } - - async with httpx.AsyncClient(timeout=20.0) as client: - # 1. Fetch commit list - response = await client.get( - f"https://api.github.com/repos/{owner}/{repo}/commits", - params={"sha": sha, "per_page": per_page}, - headers=gh_headers, - ) - - if response.status_code == 401: - raise HTTPException( - status_code=401, - detail="GitHub token expired or revoked. 
Please reconnect GitHub.", - ) - if response.status_code == 404: - raise HTTPException( - status_code=404, - detail=f"Repository {owner}/{repo} not found or not accessible.", - ) - if response.status_code != 200: - raise HTTPException( - status_code=400, - detail="Failed to fetch commits from GitHub.", - ) - - commits_raw = response.json() - - # 2. Fetch stats for all commits in parallel - stats_list = await asyncio.gather( - *[_fetch_commit_stats(client, owner, repo, c["sha"], gh_headers) for c in commits_raw] - ) - - # 3. Normalise to the shape expected by the frontend - result = [] - for commit_raw, stats in zip(commits_raw, stats_list, strict=True): - commit = commit_raw.get("commit", {}) - author_info = commit_raw.get("author") or {} # GitHub user object (may be null) - commit_author = commit.get("author") or {} - - # Prefer the GitHub user login name; fall back to git commit author name - author_name = author_info.get("login") or commit_author.get("name") or "Unknown" - # Avatar from GitHub user object; fall back to gravatar-style placeholder - avatar_url = author_info.get("avatar_url") or "https://avatars.githubusercontent.com/u/0" - - result.append( - { - "hash": commit_raw.get("sha", ""), - "message": (commit.get("message") or "").split("\n")[0], # first line only - "author": { - "name": author_name, - "avatarUrl": avatar_url, - }, - "timestamp": commit_author.get("date") or "", - "stats": { - "additions": stats.get("additions", 0), - "deletions": stats.get("deletions", 0), - "filesChanged": stats.get("filesChanged", 0), - }, - "files": stats.get("files", []), - } - ) - - return result + return await fetch_commits_with_stats(owner, repo, sha, per_page, user.github_access_token) @router.get("/github/commits/{owner}/{repo}/{sha}") @@ -645,10 +289,9 @@ async def get_github_commit_details( if not user.github_access_token: raise HTTPException(status_code=400, detail="GitHub not connected") - from apps.integrations.github.app.github.client import GitHubClient + 
from apps.integrations.github.services.github.client import GitHubClient client = GitHubClient(user.github_access_token) - commit_raw = await client.get_commit_details(owner, repo, sha) if not commit_raw: raise HTTPException(status_code=404, detail="Commit not found") @@ -656,34 +299,16 @@ async def get_github_commit_details( commit = commit_raw.get("commit", {}) author_info = commit_raw.get("author") or {} commit_author = commit.get("author") or {} - - # Prefer the GitHub user login name; fall back to git commit author name - author_name = author_info.get("login") or commit_author.get("name") or "Unknown" - # Avatar from GitHub user object; fall back to gravatar-style placeholder - avatar_url = author_info.get("avatar_url") or "https://avatars.githubusercontent.com/u/0" - stats = commit_raw.get("stats", {}) files_raw = commit_raw.get("files", []) - # Optional: fetch raw diff if needed, otherwise rely on the patches in files - # diff = await client.get_commit_diff(owner, repo, sha) - - files = [ - { - "path": f.get("filename"), - "additions": f.get("additions", 0), - "deletions": f.get("deletions", 0), - "patch": f.get("patch", ""), - } - for f in files_raw - ] - return { "hash": commit_raw.get("sha", ""), "message": commit.get("message", ""), "author": { - "name": author_name, - "avatarUrl": avatar_url, + "name": author_info.get("login") or commit_author.get("name") or "Unknown", + "avatarUrl": author_info.get("avatar_url") + or "https://avatars.githubusercontent.com/u/0", }, "timestamp": commit_author.get("date") or "", "stats": { @@ -691,8 +316,15 @@ async def get_github_commit_details( "deletions": stats.get("deletions", 0), "filesChanged": len(files_raw), }, - "files": files, - # "diff": diff + "files": [ + { + "path": f.get("filename"), + "additions": f.get("additions", 0), + "deletions": f.get("deletions", 0), + "patch": f.get("patch", ""), + } + for f in files_raw + ], } @@ -704,19 +336,15 @@ async def analyze_github_commit( if not user.github_access_token: 
raise HTTPException(status_code=400, detail="GitHub not connected") - from apps.integrations.github.app.ai.commit_analyzer import analyze_commit - from apps.integrations.github.app.github.client import GitHubClient + from apps.integrations.github.services.ai.commit_analyzer import analyze_commit + from apps.integrations.github.services.github.client import GitHubClient client = GitHubClient(user.github_access_token) - commit_raw = await client.get_commit_details(owner, repo, sha) if not commit_raw: raise HTTPException(status_code=404, detail="Commit not found") commit = commit_raw.get("commit", {}) - files_raw = commit_raw.get("files", []) - - # Format files for the analyzer files = [ { "path": f.get("filename"), @@ -724,17 +352,14 @@ async def analyze_github_commit( "deletions": f.get("deletions", 0), "patch": f.get("patch", ""), } - for f in files_raw + for f in commit_raw.get("files", []) ] analysis = await analyze_commit(message=commit.get("message", ""), files=files, force=True) - if not analysis: - # Provide a fallback if AI is disabled or fails return { "summary": "AI analysis is currently unavailable.", "issuesFound": [], "securityScore": 100, } - return analysis diff --git a/apps/integrations/jira/routes/payment.py b/apps/integrations/jira/routes/payment.py index 176b2bf..22dbd8c 100644 --- a/apps/integrations/jira/routes/payment.py +++ b/apps/integrations/jira/routes/payment.py @@ -5,10 +5,19 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request from jose import jwt from polar_sdk import Polar -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from apps.agents.agent_server.src.common.redis_client import RedisManager +from apps.integrations.jira.services.billing_service import ( + PLAN_META, + handle_order_or_customer, + handle_subscription_active, + handle_subscription_canceled, + handle_subscription_revoked, + handle_subscription_updated, + sync_subscription_from_polar, 
+) from libs.common.database import get_async_db from libs.common.models import User from libs.common.subscription_limits import ( @@ -18,15 +27,57 @@ ) logger = logging.getLogger(__name__) - router = APIRouter() -# Configuration + +def _get(obj, key, default=None): + """Get a field from a dict or SDK object uniformly.""" + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _extract_sub_fields(data) -> dict: + """Normalize a Polar subscription dict or SDK object into a flat dict.""" + customer = _get(data, "customer") or {} + return { + "subscription_id": _get(data, "id"), + "status": _get(data, "status"), + "ends_at": _get(data, "ends_at"), + "product_id": _get(data, "product_id"), + "customer_email": _get(customer, "email") + if isinstance(customer, dict) + else getattr(customer, "email", None), + } + + +def _resolve_tier(product_id: str | None) -> str | None: + """Map a Polar product_id to an internal tier name via env vars.""" + if not product_id: + return None + mapping = { + v: k + for k, v in { + "starter": os.getenv("POLAR_PRODUCT_ID_STARTER", ""), + "pro": os.getenv("POLAR_PRODUCT_ID_PRO", ""), + "unlimited": os.getenv("POLAR_PRODUCT_ID_UNLIMITED", ""), + }.items() + if v # skip empty env vars + } + return mapping.get(product_id) + + SECRET_KEY = os.getenv("SECRET_KEY") if not SECRET_KEY: raise RuntimeError("SECRET_KEY environment variable is not set") ALGORITHM = "HS256" +try: + polar_client = Polar(access_token=os.getenv("POLAR_ACCESS_TOKEN", "")) +except Exception as e: + logger.warning(f"Polar Client Init Warning: {e}") + polar_client = None + async def get_current_user( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) @@ -39,144 +90,53 @@ async def get_current_user( email = payload.get("sub") if not email: raise HTTPException(status_code=401, detail="Invalid token - missing sub") - - from sqlalchemy import func - result = await 
@router.post("/payment/checkout")
async def create_checkout(plan: str = "pro", user: User = Depends(get_current_user)):
    """Create a Polar.sh checkout session for the specified plan.

    Args:
        plan: One of "starter", "pro", "unlimited"/"ultimate",
            "enterprise"/"custom".  Unknown plans fall back to "pro".
        user: Authenticated user; their email is attached to the checkout.

    Returns:
        {"checkout_url": <hosted Polar checkout URL>}

    Raises:
        HTTPException: 500 when the Polar client or product is not
            configured, or both checkout attempts fail.
    """
    if not polar_client:
        raise HTTPException(status_code=500, detail="Polar client not initialized")

    plan_products = {
        "starter": os.getenv("POLAR_PRODUCT_ID_STARTER") or os.getenv("POLAR_PRODUCT_ID_TEST"),
        "unlimited": os.getenv("POLAR_PRODUCT_ID_UNLIMITED")
        or os.getenv("POLAR_PRODUCT_ID_ULTIMATE"),
        "ultimate": os.getenv("POLAR_PRODUCT_ID_UNLIMITED")
        or os.getenv("POLAR_PRODUCT_ID_ULTIMATE"),
        "pro": os.getenv("POLAR_PRODUCT_ID_PRO"),
        # Fix: restore the pre-refactor mapping for enterprise/custom plans,
        # which used the test product rather than silently selling "pro".
        "enterprise": os.getenv("POLAR_PRODUCT_ID_TEST"),
        "custom": os.getenv("POLAR_PRODUCT_ID_TEST"),
    }
    product_id = plan_products.get(plan, os.getenv("POLAR_PRODUCT_ID_PRO"))

    if not product_id:
        raise HTTPException(status_code=500, detail=f"Polar product ID for '{plan}' not configured")

    frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
    checkout_params = {
        "products": [product_id],
        "customer_email": user.email,
        "success_url": f"{frontend_url}/dashboard?payment=success",
        "cancel_url": f"{frontend_url}/pricing",
    }

    try:
        checkout = polar_client.checkouts.create(request=checkout_params)
        return {"checkout_url": checkout.url}
    except Exception as primary_exc:
        # Fallback without customer_email (some accounts reject prefilled emails).
        try:
            checkout = polar_client.checkouts.create(
                request={k: v for k, v in checkout_params.items() if k != "customer_email"}
            )
            return {"checkout_url": checkout.url}
        except Exception as fallback_exc:
            # Fix: surface the primary failure (usually the informative one),
            # as the pre-refactor code did, while chaining the fallback error.
            raise HTTPException(
                status_code=500, detail=f"Polar API Error: {primary_exc}"
            ) from fallback_exc
Returns None if unknown.""" - if not product_id: - return None - mapping = { - k: v - for k, v in { - os.getenv("POLAR_PRODUCT_ID_STARTER"): "starter", - os.getenv("POLAR_PRODUCT_ID_PRO"): "pro", - os.getenv("POLAR_PRODUCT_ID_ULTIMATE"): "unlimited", - }.items() - if k # exclude empty/missing env vars - } - return mapping.get(product_id) + raise HTTPException(status_code=500, detail=f"Polar API Error: {str(e2)}") from e2 @router.post("/webhooks/polar") @@ -200,14 +160,10 @@ async def polar_webhook(request: Request, db: AsyncSession = Depends(get_async_d event_type = event.type data = event.data except ImportError: - logger.error( - "polar_sdk.validate_event not available — cannot verify webhook signature" - ) raise HTTPException( status_code=500, detail="Webhook validation unavailable" ) from None except Exception as exc: - logger.error(f"Webhook signature validation failed: {exc}") raise HTTPException(status_code=400, detail="Invalid webhook signature") from exc else: raw = await request.json() @@ -216,34 +172,19 @@ async def polar_webhook(request: Request, db: AsyncSession = Depends(get_async_d logger.info(f"Polar webhook received: {event_type}") - # ── Subscription lifecycle events ───────────────────────────────────── if event_type == "subscription.created": - # Do NOT grant access yet — wait for subscription.active - logger.info( - "subscription.created received — awaiting subscription.active to grant access" - ) + logger.info("subscription.created — awaiting subscription.active to grant access") return {"status": "received"} - if event_type == "subscription.active": - return await _handle_subscription_active(data, db) - + return await handle_subscription_active(data, db) if event_type == "subscription.updated": - return await _handle_subscription_updated(data, db) - + return await handle_subscription_updated(data, db) if event_type == "subscription.canceled": - return await _handle_subscription_canceled(data, db) - + return await 
handle_subscription_canceled(data, db) if event_type == "subscription.revoked": - return await _handle_subscription_revoked(data, db) - - # ── Order / customer events ──────────────────────────────────────────── - if event_type in ( - "order.created", - "order.updated", - "customer.created", - "customer.updated", - ): - return await _handle_order_or_customer(event_type, data, db) + return await handle_subscription_revoked(data, db) + if event_type in ("order.created", "order.updated", "customer.created", "customer.updated"): + return await handle_order_or_customer(event_type, data, db) return {"status": "received", "note": f"unhandled event type: {event_type}"} @@ -254,142 +195,11 @@ async def polar_webhook(request: Request, db: AsyncSession = Depends(get_async_d raise HTTPException(status_code=500, detail=str(exc)) from exc -async def _handle_subscription_active(data, db: AsyncSession): - """Grant access when subscription becomes active (payment confirmed).""" - fields = _extract_sub_fields(data) - user = await _find_user_by_email(fields["customer_email"], db) - if not user: - return {"status": "ignored", "reason": "user_not_found"} - - # Idempotency: if already processed this subscription+active combo, skip - if ( - user.polar_subscription_id == fields["subscription_id"] - and user.subscription_status == "active" - ): - return {"status": "skipped", "reason": "already_active"} - - new_tier = _resolve_tier(fields["product_id"]) - if not new_tier: - logger.error(f"Unknown product_id in subscription.active: {fields['product_id']!r}") - return {"status": "ignored", "reason": "unknown_product_id"} - - user.subscription_tier = new_tier - user.subscription_status = "active" - user.subscription_ends_at = None - user.polar_subscription_id = fields["subscription_id"] - if fields["customer_id"]: - user.polar_customer_id = fields["customer_id"] - await db.commit() - - logger.info(f"Granted {new_tier} tier to {fields['customer_email']} via subscription.active") - return 
{"status": "success", "tier": new_tier} - - -async def _handle_subscription_updated(data, db: AsyncSession): - """Sync tier on any subscription change (upgrade/downgrade).""" - fields = _extract_sub_fields(data) - user = await _find_user_by_email(fields["customer_email"], db) - if not user: - return {"status": "ignored", "reason": "user_not_found"} - - new_tier = _resolve_tier(fields["product_id"]) - if not new_tier: - logger.error(f"Unknown product_id in subscription.updated: {fields['product_id']!r}") - return {"status": "ignored", "reason": "unknown_product_id"} - - user.subscription_tier = new_tier - user.subscription_status = fields["status"] or "active" - user.polar_subscription_id = fields["subscription_id"] - await db.commit() - - logger.info(f"Updated {fields['customer_email']} tier to {new_tier} via subscription.updated") - return {"status": "success", "tier": new_tier} - - -async def _handle_subscription_canceled(data, db: AsyncSession): - """Mark subscription as canceled; keep tier until ends_at.""" - fields = _extract_sub_fields(data) - user = await _find_user_by_email(fields["customer_email"], db) - if not user: - return {"status": "ignored", "reason": "user_not_found"} - - ends_at = fields["ends_at"] - if isinstance(ends_at, str): - try: - ends_at = datetime.datetime.fromisoformat(ends_at.replace("Z", "+00:00")).replace( - tzinfo=None - ) - except ValueError: - ends_at = None - - user.subscription_status = "canceled" - user.subscription_ends_at = ends_at - await db.commit() - - logger.info(f"Marked {fields['customer_email']} subscription as canceled (ends: {ends_at})") - return {"status": "success", "action": "scheduled_downgrade"} - - -async def _handle_subscription_revoked(data, db: AsyncSession): - """Immediately downgrade to free on revocation.""" - fields = _extract_sub_fields(data) - user = await _find_user_by_email(fields["customer_email"], db) - if not user: - return {"status": "ignored", "reason": "user_not_found"} - - user.subscription_tier 
= "free" - user.subscription_status = "revoked" - user.subscription_ends_at = None - await db.commit() - - logger.info(f"Revoked subscription for {fields['customer_email']} — downgraded to free") - return {"status": "success", "action": "downgraded_to_free"} - - -async def _handle_order_or_customer(event_type: str, data, db: AsyncSession): - """Store polar_customer_id from order/customer events.""" - if isinstance(data, dict): - customer = data.get("customer") or {} - customer_id = customer.get("id") or data.get("customer_id") - customer_email = customer.get("email") or data.get("customer_email") - else: - customer = getattr(data, "customer", None) - customer_id = (getattr(customer, "id", None) if customer else None) or getattr( - data, "customer_id", None - ) - customer_email = (getattr(customer, "email", None) if customer else None) or getattr( - data, "customer_email", None - ) - - if customer_email and customer_id: - user = await _find_user_by_email(customer_email, db) - if user and not user.polar_customer_id: - user.polar_customer_id = customer_id - await db.commit() - - return {"status": "received"} - - -async def _find_user_by_email(email: str | None, db: AsyncSession): - if not email: - return None - from sqlalchemy import func - - result = await db.execute(select(User).where(func.lower(User.email) == func.lower(email))) - return result.scalar_one_or_none() - - @router.get("/payment/portal") async def get_customer_portal( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """ - Generate a secure Polar.sh Customer Portal URL for the user. - Strategies: - 0. Use user.polar_customer_id if available. - 1. Try external_customer_id (user email). - 2. Fallback to customers.list() by email. 
- """ + """Generate a secure Polar.sh Customer Portal URL for the user.""" if not polar_client: raise HTTPException(status_code=500, detail="Polar client not initialized") @@ -416,46 +226,25 @@ async def get_customer_portal( except Exception: pass - # Strategy 2: Look up by email (requires customers:read scope) + # Strategy 2: Look up by email try: - print(f"DEBUG: Searching Polar for email: {user.email}") response = polar_client.customers.list(email=user.email) - print(f"DEBUG: Polar list response: {response}") - - # Determine items list from response (ListResource object or similar) - items = [] - target = response - if hasattr(response, "result"): - target = response.result - - if hasattr(target, "items"): - items = target.items - elif isinstance(target, list): - items = target - elif hasattr(target, "__getitem__"): # Try as subscriptable - try: - items = list(target) - except Exception: - pass + target = getattr(response, "result", response) + items = getattr(target, "items", None) or (target if isinstance(target, list) else []) if not items: - print(f"DEBUG: No customers found for {user.email}") raise HTTPException( status_code=404, detail=f"Subscription not found in Polar for {user.email}. 
Please use the email you paid with or contact support.", ) customer = items[0] - # Get customer ID - handle both pydantic object and dict customer_id = getattr(customer, "id", None) or ( customer.get("id") if isinstance(customer, dict) else None ) - if not customer_id: - print(f"DEBUG: Customer object found but no ID: {customer}") raise HTTPException(status_code=500, detail="Could not retrieve Customer ID from Polar") - # Save it for next time (Strategy 0) user.polar_customer_id = customer_id await db.commit() @@ -469,114 +258,21 @@ async def get_customer_portal( raise except Exception as e: err = str(e) - print(f"DEBUG: Portal exception: {err}") - # Check if it's a scope issue (403 Forbidden) if any(x in err.lower() for x in ["403", "forbidden", "scope", "privilege"]): raise HTTPException( status_code=503, - detail=( - "Billing portal scope error. Please ensure your Polar API key has " - "'customers:read' and 'customer_sessions:write' scopes." - ), + detail="Billing portal scope error. Ensure your Polar API key has 'customers:read' and 'customer_sessions:write' scopes.", ) from None raise HTTPException(status_code=500, detail=f"Polar API Error: {err}") from None -# ── Billing Info Endpoints ──────────────────────────────────────────────────── - -PLAN_META = { - "free": {"name": "Free", "price": "$0", "price_monthly": 0}, - "starter": {"name": "Starter", "price": "$15", "price_monthly": 15}, - "pro": {"name": "Professional", "price": "$35", "price_monthly": 35}, - "unlimited": {"name": "Truly Unlimited", "price": "$89", "price_monthly": 89}, - "enterprise": {"name": "Enterprise", "price": "Custom", "price_monthly": -1}, -} - - -async def _sync_subscription_from_polar(user: User, db: AsyncSession): - """Proactively sync subscription status from Polar if it seems stale.""" - if not polar_client: - return - - try: - # 1. 
Find or verify customer - customer_id = user.polar_customer_id - if not customer_id: - logger.info(f"Sync: Searching Polar for customer {user.email}") - res = polar_client.customers.list(email=user.email) - items = [] - target = res - if hasattr(res, "result"): - target = res.result - if hasattr(target, "items"): - items = target.items - elif isinstance(target, list): - items = target - - if items: - customer_id = getattr(items[0], "id", None) or items[0].get("id") - if customer_id: - user.polar_customer_id = customer_id - await db.commit() - - if not customer_id: - return - - # 2. List subscriptions - logger.info(f"Sync: Fetching subscriptions for customer {customer_id}") - sub_res = polar_client.subscriptions.list(customer_id=customer_id) - subs = [] - target = sub_res - if hasattr(sub_res, "result"): - target = sub_res.result - if hasattr(target, "items"): - subs = target.items - elif isinstance(target, list): - subs = target - - # 3. Find active/trialing sub - active_sub = next( - (s for s in subs if getattr(s, "status", None) in ["active", "trialing"]), None - ) - if active_sub: - product_id = getattr(active_sub, "product_id", None) - new_tier = _resolve_tier(product_id) - current_status = getattr(active_sub, "status", "active") - - logger.info( - f"Sync: Found active sub {getattr(active_sub, 'id', 'N/A')} with tier {new_tier}, status {current_status}" - ) - - if new_tier and ( - user.subscription_tier != new_tier or user.subscription_status != current_status - ): - logger.info( - f"Sync: Updating user {user.email} -> tier: {new_tier}, status: {current_status}" - ) - user.subscription_tier = new_tier - user.subscription_status = current_status - user.polar_subscription_id = getattr(active_sub, "id", None) - await db.commit() - else: - logger.info(f"Sync: No active/trialing subscription found in Polar for {user.email}") - if user.subscription_tier != "free": - # Only log, don't auto-downgrade yet to be safe - logger.warning( - f"Sync: User {user.email} has 
local tier {user.subscription_tier} but no active sub in Polar" - ) - - except Exception as e: - logger.warning(f"Subscription sync failed for {user.email}: {e}") - - @router.get("/billing/subscription") async def get_subscription( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """Return the user's current subscription plan details.""" - # Proactively sync if they are on free or we have a customer ID if user.subscription_tier == "free" or user.polar_customer_id: - await _sync_subscription_from_polar(user, db) + await sync_subscription_from_polar(user, db, polar_client) tier = (user.subscription_tier or "free").lower() meta = PLAN_META.get(tier, PLAN_META["free"]) @@ -600,7 +296,7 @@ async def get_subscription( @router.get("/billing/usage") async def get_usage(user: User = Depends(get_current_user)): - """Return the user's current month usage from Redis with overage projections.""" + """Return the user's current month usage with overage projections.""" tier = (user.subscription_tier or "free").lower() limits = get_tier_limits(tier) chat_limit = limits.get(FEATURE_CHAT_LIMIT, 30) @@ -608,12 +304,11 @@ async def get_usage(user: User = Depends(get_current_user)): now = datetime.datetime.utcnow() year_month = now.strftime("%Y-%m") - - # Days until end of month - if now.month == 12: - next_month = now.replace(year=now.year + 1, month=1, day=1) - else: - next_month = now.replace(month=now.month + 1, day=1) + next_month = ( + now.replace(year=now.year + 1, month=1, day=1) + if now.month == 12 + else now.replace(month=now.month + 1, day=1) + ) resets_in_days = (next_month - now).days transcription_used = 0.0 @@ -622,34 +317,24 @@ async def get_usage(user: User = Depends(get_current_user)): try: r = RedisManager.get_client() - - transcription_key = f"transcription_usage:{user.email}:{year_month}" - chat_key = f"chat_usage:{user.id}:{year_month}" - - transcription_raw, chat_raw = await r.mget(transcription_key, chat_key) + 
transcription_raw, chat_raw = await r.mget( + f"transcription_usage:{user.email}:{year_month}", + f"chat_usage:{user.id}:{year_month}", + ) overage_flag = await r.get(f"transcription_overage:{user.id}") transcription_overage = overage_flag == "1" - transcription_used = round(float(transcription_raw or 0), 1) chat_used = int(chat_raw or 0) except Exception as e: - print(f"Billing usage Redis error: {e}") + logger.warning(f"Billing usage Redis error: {e}") - # Calculate Overage for UI + overage_mins = ( + max(0, transcription_used - transcription_limit) if transcription_limit > 0 else 0.0 + ) + is_hard_capped = (tier == "free") or (overage_mins >= 120.0) projected_overage = 0.0 - OVERAGE_LIMIT = 120.0 # 2 hour hard cap on overage - - overage_mins = 0.0 - if transcription_limit > 0: - overage_mins = max(0, transcription_used - transcription_limit) - - is_hard_capped = (tier == "free") or (overage_mins >= OVERAGE_LIMIT) - if not is_hard_capped and tier != "free" and overage_mins > 0: - # Starter: $1.50/hr, Professional: $1.00/hr - rate = 1.50 if tier == "starter" else 1.00 - if tier in ["unlimited", "enterprise"]: - rate = 0.0 + rate = 1.50 if tier == "starter" else (0.0 if tier in ["unlimited", "enterprise"] else 1.00) projected_overage = round((overage_mins / 60.0) * rate, 2) return { @@ -664,11 +349,10 @@ async def get_usage(user: User = Depends(get_current_user)): "used": chat_used, "limit": chat_limit, "is_overage": (chat_used > chat_limit > 0), - "is_hard_capped": (tier == "free") - or (chat_used > chat_limit > 0 and tier == "free"), # Placeholder for chat + "is_hard_capped": (tier == "free") or (chat_used > chat_limit > 0 and tier == "free"), }, "resets_in_days": resets_in_days, - "is_hard_capped": is_hard_capped, # Lock UI if over buffer + "is_hard_capped": is_hard_capped, "transcription_overage": transcription_overage, } diff --git a/apps/integrations/jira/services/billing_service.py b/apps/integrations/jira/services/billing_service.py new file mode 100644 
"""
Subscription management business logic for Polar.sh integration.
HTTP routing and Polar API calls stay in routes/payment.py.
"""

import datetime
import logging
import os

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from libs.common.models import User

logger = logging.getLogger(__name__)

# Static display metadata per internal tier.
# "price_monthly" of -1 means "contact sales" (Enterprise has no fixed price).
PLAN_META = {
    "free": {"name": "Free", "price": "$0", "price_monthly": 0},
    "starter": {"name": "Starter", "price": "$15", "price_monthly": 15},
    "pro": {"name": "Professional", "price": "$35", "price_monthly": 35},
    "unlimited": {"name": "Truly Unlimited", "price": "$89", "price_monthly": 89},
    "enterprise": {"name": "Enterprise", "price": "Custom", "price_monthly": -1},
}


def resolve_tier(product_id: str | None) -> str | None:
    """Map a Polar product_id to an internal tier name. Returns None if unknown.

    The env-var mapping is rebuilt on every call so configuration changes
    (e.g. in tests) are picked up without a process restart.
    """
    if not product_id:
        return None
    mapping = {
        k: v
        for k, v in {
            os.getenv("POLAR_PRODUCT_ID_STARTER"): "starter",
            os.getenv("POLAR_PRODUCT_ID_PRO"): "pro",
            os.getenv("POLAR_PRODUCT_ID_ULTIMATE"): "unlimited",
        }.items()
        if k  # drop entries whose env var is unset
    }
    return mapping.get(product_id)


def extract_sub_fields(data) -> dict:
    """Extract subscription fields from either a dict or Polar SDK model object.

    Returns a dict with keys: subscription_id, product_id, customer_id,
    customer_email, status, ends_at — any of which may be None.
    """
    if isinstance(data, dict):
        customer = data.get("customer") or {}
        return {
            "subscription_id": data.get("id"),
            "product_id": data.get("product_id"),
            "customer_id": customer.get("id") or data.get("customer_id"),
            "customer_email": customer.get("email")
            or data.get("customer_email")
            or data.get("email"),
            "status": data.get("status"),
            "ends_at": data.get("ends_at"),
        }
    # SDK model object: read attributes defensively — field names vary by SDK version.
    customer = getattr(data, "customer", None)
    customer_email = getattr(customer, "email", None) if customer else None
    customer_id = getattr(customer, "id", None) if customer else None
    return {
        "subscription_id": getattr(data, "id", None),
        "product_id": getattr(data, "product_id", None),
        "customer_id": customer_id or getattr(data, "customer_id", None),
        "customer_email": customer_email or getattr(data, "customer_email", None),
        "status": getattr(data, "status", None),
        "ends_at": getattr(data, "ends_at", None),
    }


async def find_user_by_email(email: str | None, db: AsyncSession) -> User | None:
    """Case-insensitive user lookup by email; returns None when email is falsy."""
    if not email:
        return None
    result = await db.execute(select(User).where(func.lower(User.email) == func.lower(email)))
    return result.scalar_one_or_none()


async def handle_subscription_active(data, db: AsyncSession) -> dict:
    """Grant access when subscription becomes active (payment confirmed)."""
    fields = extract_sub_fields(data)
    user = await find_user_by_email(fields["customer_email"], db)
    if not user:
        return {"status": "ignored", "reason": "user_not_found"}

    # Idempotency: webhooks can be redelivered — skip if already applied.
    if (
        user.polar_subscription_id == fields["subscription_id"]
        and user.subscription_status == "active"
    ):
        return {"status": "skipped", "reason": "already_active"}

    new_tier = resolve_tier(fields["product_id"])
    if not new_tier:
        logger.error("Unknown product_id in subscription.active: %r", fields["product_id"])
        return {"status": "ignored", "reason": "unknown_product_id"}

    user.subscription_tier = new_tier
    user.subscription_status = "active"
    user.subscription_ends_at = None  # clear any pending scheduled downgrade
    user.polar_subscription_id = fields["subscription_id"]
    if fields["customer_id"]:
        user.polar_customer_id = fields["customer_id"]
    await db.commit()

    logger.info(
        "Granted %s tier to %s via subscription.active", new_tier, fields["customer_email"]
    )
    return {"status": "success", "tier": new_tier}


async def handle_subscription_updated(data, db: AsyncSession) -> dict:
    """Sync tier on any subscription change (upgrade/downgrade)."""
    fields = extract_sub_fields(data)
    user = await find_user_by_email(fields["customer_email"], db)
    if not user:
        return {"status": "ignored", "reason": "user_not_found"}

    new_tier = resolve_tier(fields["product_id"])
    if not new_tier:
        logger.error("Unknown product_id in subscription.updated: %r", fields["product_id"])
        return {"status": "ignored", "reason": "unknown_product_id"}

    user.subscription_tier = new_tier
    user.subscription_status = fields["status"] or "active"
    user.polar_subscription_id = fields["subscription_id"]
    await db.commit()

    logger.info(
        "Updated %s tier to %s via subscription.updated", fields["customer_email"], new_tier
    )
    return {"status": "success", "tier": new_tier}


async def handle_subscription_canceled(data, db: AsyncSession) -> dict:
    """Mark subscription as canceled; keep tier until ends_at."""
    fields = extract_sub_fields(data)
    user = await find_user_by_email(fields["customer_email"], db)
    if not user:
        return {"status": "ignored", "reason": "user_not_found"}

    # Polar may send ends_at as an ISO-8601 string; normalize to a naive
    # datetime (DB column appears to store naive timestamps — tzinfo stripped).
    ends_at = fields["ends_at"]
    if isinstance(ends_at, str):
        try:
            ends_at = datetime.datetime.fromisoformat(ends_at.replace("Z", "+00:00")).replace(
                tzinfo=None
            )
        except ValueError:
            ends_at = None  # unparseable timestamp: cancel now, no scheduled end

    user.subscription_status = "canceled"
    user.subscription_ends_at = ends_at
    await db.commit()

    logger.info(
        "Marked %s subscription as canceled (ends: %s)", fields["customer_email"], ends_at
    )
    return {"status": "success", "action": "scheduled_downgrade"}


async def handle_subscription_revoked(data, db: AsyncSession) -> dict:
    """Immediately downgrade to free on revocation."""
    fields = extract_sub_fields(data)
    user = await find_user_by_email(fields["customer_email"], db)
    if not user:
        return {"status": "ignored", "reason": "user_not_found"}

    user.subscription_tier = "free"
    user.subscription_status = "revoked"
    user.subscription_ends_at = None
    await db.commit()

    logger.info("Revoked subscription for %s — downgraded to free", fields["customer_email"])
    return {"status": "success", "action": "downgraded_to_free"}


async def handle_order_or_customer(event_type: str, data, db: AsyncSession) -> dict:
    """Store polar_customer_id from order/customer events.

    Best-effort: only fills in polar_customer_id when it is not already set.
    """
    if isinstance(data, dict):
        customer = data.get("customer") or {}
        customer_id = customer.get("id") or data.get("customer_id")
        customer_email = customer.get("email") or data.get("customer_email")
    else:
        customer = getattr(data, "customer", None)
        customer_id = (getattr(customer, "id", None) if customer else None) or getattr(
            data, "customer_id", None
        )
        customer_email = (getattr(customer, "email", None) if customer else None) or getattr(
            data, "customer_email", None
        )

    if customer_email and customer_id:
        user = await find_user_by_email(customer_email, db)
        if user and not user.polar_customer_id:
            user.polar_customer_id = customer_id
            await db.commit()

    return {"status": "received"}


async def sync_subscription_from_polar(user: User, db: AsyncSession, polar_client) -> None:
    """Proactively sync subscription status from Polar if it seems stale.

    Best-effort: any failure is logged and swallowed so a Polar outage never
    blocks the caller.
    """
    if not polar_client:
        return

    try:
        # Resolve the Polar customer id, discovering it by email if unknown.
        customer_id = user.polar_customer_id
        if not customer_id:
            logger.info("Sync: Searching Polar for customer %s", user.email)
            res = polar_client.customers.list(email=user.email)
            # SDK responses may wrap the page in a .result attribute.
            target = getattr(res, "result", res)
            items = getattr(target, "items", None) or (target if isinstance(target, list) else [])
            if items:
                customer_id = getattr(items[0], "id", None) or (
                    items[0].get("id") if isinstance(items[0], dict) else None
                )
                if customer_id:
                    user.polar_customer_id = customer_id
                    await db.commit()

        if not customer_id:
            return

        logger.info("Sync: Fetching subscriptions for customer %s", customer_id)
        sub_res = polar_client.subscriptions.list(customer_id=customer_id)
        target = getattr(sub_res, "result", sub_res)
        subs = getattr(target, "items", None) or (target if isinstance(target, list) else [])

        active_sub = next(
            (s for s in subs if getattr(s, "status", None) in ["active", "trialing"]), None
        )
        if active_sub:
            product_id = getattr(active_sub, "product_id", None)
            new_tier = resolve_tier(product_id)
            current_status = getattr(active_sub, "status", "active")

            # Only write when something actually changed to avoid needless commits.
            if new_tier and (
                user.subscription_tier != new_tier or user.subscription_status != current_status
            ):
                logger.info(
                    "Sync: Updating %s -> tier: %s, status: %s",
                    user.email,
                    new_tier,
                    current_status,
                )
                user.subscription_tier = new_tier
                user.subscription_status = current_status
                user.polar_subscription_id = getattr(active_sub, "id", None)
                await db.commit()
        else:
            logger.info("Sync: No active/trialing subscription found for %s", user.email)
            if user.subscription_tier != "free":
                # Deliberately do not auto-downgrade here; just surface the drift.
                logger.warning(
                    "Sync: User %s has local tier %s but no active sub in Polar",
                    user.email,
                    user.subscription_tier,
                )

    except Exception as e:
        logger.warning("Subscription sync failed for %s: %s", user.email, e)
+""" + +import asyncio +import logging +import os + +import httpx +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from libs.common.models import GitHubInstallation, GitHubWebhook, User + +logger = logging.getLogger(__name__) + + +async def confirm_installation( + installation_id: int, + user: User, + db: AsyncSession, +) -> dict: + """Verify a GitHub App installation with GitHub and persist it to the DB.""" + async with httpx.AsyncClient() as client: + resp = await client.get( + "https://api.github.com/user/installations", + headers={ + "Authorization": f"Bearer {user.github_access_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + if resp.status_code != 200: + raise HTTPException(status_code=400, detail="Failed to fetch installations from GitHub") + + installations = resp.json().get("installations", []) + inst_data = next((i for i in installations if i["id"] == installation_id), None) + if not inst_data: + raise HTTPException( + status_code=404, + detail=f"Installation {installation_id} not found for this user", + ) + + account = inst_data.get("account", {}) + existing_result = await db.execute( + select(GitHubInstallation).where(GitHubInstallation.id == installation_id) + ) + existing = existing_result.scalar_one_or_none() + + if existing: + existing.user_id = user.id + existing.account_login = account.get("login") + existing.account_id = account.get("id") + existing.account_type = account.get("type") + existing.repository_selection = inst_data.get("repository_selection", "all") + else: + db.add( + GitHubInstallation( + id=installation_id, + user_id=user.id, + account_login=account.get("login"), + account_id=account.get("id"), + account_type=account.get("type"), + repository_selection=inst_data.get("repository_selection", "all"), + ) + ) + + await db.commit() + return {"status": "success", "installation_id": installation_id} + + +async def create_webhook( + owner: str, + repo: 
str, + user: User, + db: AsyncSession, +) -> dict: + """Create a GitHub webhook for the given repo and persist it.""" + existing_result = await db.execute( + select(GitHubWebhook).where( + GitHubWebhook.user_id == user.id, + GitHubWebhook.repo_full_name == f"{owner}/{repo}", + ) + ) + existing = existing_result.scalar_one_or_none() + if existing: + return { + "status": "success", + "message": "Webhook already exists", + "webhook_id": existing.webhook_id, + } + + url_base = os.getenv("GITHUB_WEBHOOK_PUBLIC_URL") or os.getenv("JIRA_SERVICE_PUBLIC_URL") + if not url_base or "localhost" in url_base: + raise HTTPException( + status_code=503, + detail=( + "Webhook creation is not available in local development. " + "Set GITHUB_WEBHOOK_PUBLIC_URL to your public server URL." + ), + ) + + webhook_url = f"{url_base.rstrip('/')}/api/webhooks/github" + webhook_secret = os.getenv("GITHUB_WEBHOOK_SECRET", "") + if not webhook_secret: + logger.warning("GITHUB_WEBHOOK_SECRET is not set. Webhooks will not be secured.") + + payload: dict = { + "name": "web", + "active": True, + "events": ["push", "pull_request", "issues", "pull_request_review", "workflow_run"], + "config": {"url": webhook_url, "content_type": "json", "insecure_ssl": "0"}, + } + if webhook_secret: + payload["config"]["secret"] = webhook_secret + + logger.info(f"Creating GitHub webhook for {owner}/{repo} at {webhook_url}") + + async with httpx.AsyncClient() as client: + response = await client.post( + f"https://api.github.com/repos/{owner}/{repo}/hooks", + json=payload, + headers={ + "Authorization": f"Bearer {user.github_access_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + + if response.status_code == 201: + data = response.json() + db.add( + GitHubWebhook( + user_id=user.id, repo_full_name=f"{owner}/{repo}", webhook_id=data["id"] + ) + ) + await db.commit() + return {"status": "success", "webhook_id": data["id"]} + + error_msg = response.json().get("message", "Unknown error") + if "already 
exists" in error_msg.lower(): + return await sync_existing_webhook(owner, repo, user, db, webhook_url) + + raise HTTPException(status_code=400, detail=f"Failed to create webhook: {error_msg}") + + +async def sync_existing_webhook( + owner: str, + repo: str, + user: User, + db: AsyncSession, + target_url: str, +) -> dict: + """Find an existing webhook on GitHub and sync it to the DB.""" + async with httpx.AsyncClient() as client: + response = await client.get( + f"https://api.github.com/repos/{owner}/{repo}/hooks", + headers={ + "Authorization": f"Bearer {user.github_access_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + if response.status_code == 200: + for hook in response.json(): + if hook.get("config", {}).get("url") == target_url: + db.add( + GitHubWebhook( + user_id=user.id, + repo_full_name=f"{owner}/{repo}", + webhook_id=hook["id"], + ) + ) + await db.commit() + return { + "status": "success", + "message": "Synced existing webhook", + "webhook_id": hook["id"], + } + + raise HTTPException(status_code=400, detail="Webhook already exists but could not be synced") + + +async def delete_webhook( + owner: str, + repo: str, + user: User, + db: AsyncSession, +) -> dict: + """Remove a GitHub webhook from the repo and delete from DB.""" + existing_result = await db.execute( + select(GitHubWebhook).where( + GitHubWebhook.user_id == user.id, + GitHubWebhook.repo_full_name == f"{owner}/{repo}", + ) + ) + existing = existing_result.scalar_one_or_none() + if not existing: + return {"status": "success", "message": "Webhook not found in database"} + + async with httpx.AsyncClient() as client: + response = await client.delete( + f"https://api.github.com/repos/{owner}/{repo}/hooks/{existing.webhook_id}", + headers={ + "Authorization": f"Bearer {user.github_access_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + if response.status_code in (204, 404): + await db.delete(existing) + await db.commit() + return {"status": "success", "message": 
"Webhook deleted"} + + raise HTTPException(status_code=400, detail="Failed to delete webhook on GitHub API") + + +async def _fetch_commit_stats( + client: httpx.AsyncClient, + owner: str, + repo: str, + sha: str, + headers: dict, +) -> dict: + """Fetch additions/deletions/filesChanged stats for a single commit.""" + try: + resp = await client.get( + f"https://api.github.com/repos/{owner}/{repo}/commits/{sha}", + headers=headers, + timeout=15.0, + ) + if resp.status_code != 200: + return {"additions": 0, "deletions": 0, "filesChanged": 0, "files": []} + data = resp.json() + stats = data.get("stats", {}) + files_raw = data.get("files", []) + return { + "additions": stats.get("additions", 0), + "deletions": stats.get("deletions", 0), + "filesChanged": len(files_raw), + "files": [ + { + "path": f.get("filename"), + "additions": f.get("additions", 0), + "deletions": f.get("deletions", 0), + "patch": f.get("patch", ""), + } + for f in files_raw + ], + } + except Exception: + return {"additions": 0, "deletions": 0, "filesChanged": 0, "files": []} + + +async def fetch_commits_with_stats( + owner: str, + repo: str, + branch: str, + per_page: int, + github_token: str, +) -> list[dict]: + """Fetch commits for a branch and enrich each with parallel stat fetches.""" + gh_headers = { + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github.v3+json", + } + + async with httpx.AsyncClient(timeout=20.0) as client: + response = await client.get( + f"https://api.github.com/repos/{owner}/{repo}/commits", + params={"sha": branch, "per_page": per_page}, + headers=gh_headers, + ) + if response.status_code == 401: + raise HTTPException( + status_code=401, + detail="GitHub token expired or revoked. Please reconnect GitHub.", + ) + if response.status_code == 404: + raise HTTPException( + status_code=404, detail=f"Repository {owner}/{repo} not found or not accessible." 
+ ) + if response.status_code != 200: + raise HTTPException(status_code=400, detail="Failed to fetch commits from GitHub.") + + commits_raw = response.json() + stats_list = await asyncio.gather( + *[_fetch_commit_stats(client, owner, repo, c["sha"], gh_headers) for c in commits_raw] + ) + + result = [] + for commit_raw, stats in zip(commits_raw, stats_list, strict=True): + commit = commit_raw.get("commit", {}) + author_info = commit_raw.get("author") or {} + commit_author = commit.get("author") or {} + result.append( + { + "hash": commit_raw.get("sha", ""), + "message": (commit.get("message") or "").split("\n")[0], + "author": { + "name": author_info.get("login") or commit_author.get("name") or "Unknown", + "avatarUrl": author_info.get("avatar_url") + or "https://avatars.githubusercontent.com/u/0", + }, + "timestamp": commit_author.get("date") or "", + "stats": { + "additions": stats.get("additions", 0), + "deletions": stats.get("deletions", 0), + "filesChanged": stats.get("filesChanged", 0), + }, + "files": stats.get("files", []), + } + ) + return result diff --git a/apps/integrations/jira/trello_client.py b/apps/integrations/jira/trello_client.py new file mode 100644 index 0000000..20476a6 --- /dev/null +++ b/apps/integrations/jira/trello_client.py @@ -0,0 +1,93 @@ +import logging + +import httpx + +logger = logging.getLogger(__name__) + +TRELLO_API_BASE = "https://api.trello.com/1" + + +class TrelloClient: + """Async Trello API client using the same Atlassian OAuth access token as Jira.""" + + def __init__(self, access_token: str, workspace_id: str): + self.access_token = access_token + self.workspace_id = workspace_id + self._client = httpx.AsyncClient( + headers={"Authorization": f"Bearer {access_token}", "Accept": "application/json"}, + timeout=30.0, + ) + + async def aclose(self): + await self._client.aclose() + + async def _request(self, method: str, path: str, **kwargs) -> dict | list: + url = f"{TRELLO_API_BASE}{path}" + resp = await 
import logging

import httpx

logger = logging.getLogger(__name__)

TRELLO_API_BASE = "https://api.trello.com/1"


class TrelloAuthError(Exception):
    """Raised when the Trello access token is invalid or revoked."""


class TrelloClient:
    """Async Trello API client using the same Atlassian OAuth access token as Jira."""

    def __init__(self, access_token: str, workspace_id: str):
        self.access_token = access_token
        self.workspace_id = workspace_id
        self._client = httpx.AsyncClient(
            headers={"Authorization": f"Bearer {access_token}", "Accept": "application/json"},
            timeout=30.0,
        )

    async def aclose(self):
        """Release the underlying HTTP connection pool."""
        await self._client.aclose()

    async def _request(self, method: str, path: str, **kwargs) -> dict | list:
        """Issue one Trello API call and return the decoded JSON body.

        Raises TrelloAuthError on 401; other HTTP errors propagate via
        raise_for_status().
        """
        response = await self._client.request(method, f"{TRELLO_API_BASE}{path}", **kwargs)
        if response.status_code == 401:
            raise TrelloAuthError("Trello access token is invalid or expired")
        response.raise_for_status()
        return response.json()

    async def get_boards(self) -> list[dict]:
        """List every open board visible to the authenticated member."""
        return await self._request("GET", "/members/me/boards", params={"filter": "open"})

    async def get_board_with_cards(self, board_id: str) -> dict:
        """Fetch a board plus its open lists, open cards, and members in one call."""
        query = {"lists": "open", "cards": "open", "members": "all", "card_members": "true"}
        return await self._request("GET", f"/boards/{board_id}", params=query)

    async def get_board_lists(self, board_id: str) -> list[dict]:
        """List the open lists (columns) of a board."""
        return await self._request("GET", f"/boards/{board_id}/lists", params={"filter": "open"})

    async def get_board_members(self, board_id: str) -> list[dict]:
        """List the members of a board."""
        return await self._request("GET", f"/boards/{board_id}/members")

    async def create_card(
        self,
        list_id: str,
        name: str,
        desc: str = "",
        due: str | None = None,
        member_ids: list[str] | None = None,
        label_ids: list[str] | None = None,
    ) -> dict:
        """Create a card in the given list; optional fields are sent only when set."""
        body: dict = {"idList": list_id, "name": name, "desc": desc}
        if due:
            body["due"] = due
        if member_ids:
            # Trello expects comma-separated id lists rather than JSON arrays.
            body["idMembers"] = ",".join(member_ids)
        if label_ids:
            body["idLabels"] = ",".join(label_ids)
        return await self._request("POST", "/cards", json=body)

    async def update_card(self, card_id: str, **fields) -> dict:
        """Partially update a card. Supported fields: name, desc, due, idList, closed, etc."""
        return await self._request("PUT", f"/cards/{card_id}", json=fields)

    async def move_card(self, card_id: str, list_id: str) -> dict:
        """Move a card to a different list (status change)."""
        # Moving is just a partial update of idList.
        return await self.update_card(card_id, idList=list_id)

    async def delete_card(self, card_id: str) -> None:
        """Delete a card permanently."""
        await self._request("DELETE", f"/cards/{card_id}")

    async def get_card(self, card_id: str) -> dict:
        """Fetch a single card including member names."""
        query = {"members": "true", "member_fields": "fullName,username"}
        return await self._request("GET", f"/cards/{card_id}", params=query)
HTTPException(status_code=401, detail="Unauthorized") - token = authorization.split(" ")[1] try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) @@ -45,11 +43,26 @@ async def get_user_notion_token(authorization: str, db: AsyncSession) -> NotionT raise HTTPException(status_code=401, detail="Invalid token") from e +async def _get_current_user(authorization: str, db: AsyncSession) -> User: + token = authorization.split(" ")[1] + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + email = payload.get("sub") + result = await db.execute(select(User).where(User.email == email)) + user = result.scalar_one_or_none() + if not user: + raise HTTPException(status_code=401, detail="User not found") + return user + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid token") from e + + @router.get("/notion/config") async def get_notion_config( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) ): - """Fetch current user's Notion configuration.""" notion_token = await get_user_notion_token(authorization, db) return { "workspace_name": notion_token.workspace_name, @@ -71,9 +84,7 @@ async def update_notion_config( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db), ): - """Update user's Notion configuration.""" notion_token = await get_user_notion_token(authorization, db) - if payload.database_id is not None: notion_token.database_id = payload.database_id if payload.page_id is not None: @@ -88,7 +99,6 @@ async def update_notion_config( notion_token.backlog_database_id = payload.backlog_database_id if payload.backlog_status_property is not None: notion_token.backlog_status_property = payload.backlog_status_property - await db.commit() return {"status": "success"} @@ -99,69 +109,9 @@ async def export_to_notion( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db), ): - """Export a meeting to Notion using the DB-stored 
token.""" notion_token = await get_user_notion_token(authorization, db) - - client = NotionClient(access_token=notion_token.access_token) - - # Use config from DB if not fully provided in payload (or merge) - # The frontend currently sends it in payload, so we use that but fallback to DB. - sync_mode = payload.config.sync_mode or notion_token.sync_mode - database_id = payload.config.database_id or notion_token.database_id - data_source_id = notion_token.data_source_id - page_id = payload.config.page_id or notion_token.page_id - include_transcript = payload.config.include_transcript or bool(notion_token.include_transcript) - backlog_db_id = ( - payload.backlog_config.database_id - if payload.backlog_config - else notion_token.backlog_database_id - ) - try: - try: - if sync_mode == "database" and database_id: - res_id = await client.add_meeting_to_database( - database_id=database_id, - data_source_id=data_source_id, - title=payload.title, - date=payload.date, - summary=payload.summary, - action_items=payload.action_items, - transcript=payload.transcript, - include_transcript=include_transcript, - meeting_url=payload.meeting_url, - duration=payload.duration, - ) - elif page_id: - res_id = await client.add_meeting_as_page( - parent_page_id=page_id, - title=payload.title, - date=payload.date, - summary=payload.summary, - action_items=payload.action_items, - transcript=payload.transcript, - include_transcript=include_transcript, - meeting_url=payload.meeting_url, - duration=payload.duration, - ) - else: - raise HTTPException(status_code=400, detail="No target database or page configured") - - backlog_ids = [] - if backlog_db_id: - backlog_ids = await client.add_action_items_to_backlog( - backlog_database_id=backlog_db_id, - action_items=payload.action_items, - meeting_title=payload.title, - ) - - return { - "status": "success", - "page_id": res_id, - "backlog_items_created": len(backlog_ids), - } - finally: - await client.aclose() + return await 
sync_service.export_meeting(notion_token, payload) except HTTPException: raise except Exception as e: @@ -174,165 +124,23 @@ async def sync_kanban_to_notion( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db), ): - """ - Export the user's active Jira sprint board to a Notion Kanban database. - Idempotent: creates the database on first call, then upserts on subsequent calls. - Uses 'Issue Key' as the idempotency key so issues are updated, not duplicated. - Returns {"status", "created", "updated", "database_url"}. - """ - # 1. Authenticate with Notion notion_token = await get_user_notion_token(authorization, db) - - # 3. Build Jira client (reuses shared helper that handles token refresh) + user = await _get_current_user(authorization, db) jira_client = await _get_jira_client_core(authorization, db) - - # 4. Get current user to resolve their active Jira project - token = authorization.split(" ")[1] try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - email = payload.get("sub") - user_result = await db.execute(select(User).where(User.email == email)) - user = user_result.scalar_one_or_none() - if not user: - raise HTTPException(status_code=401, detail="User not found") - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=401, detail="Invalid token") from e - - # 5. Find the first active Jira project for this user - project_result = await db.execute( - select(ActiveJiraProject) - .where(ActiveJiraProject.user_id == user.id) - .order_by(ActiveJiraProject.created_at) - ) - active_project = project_result.scalars().first() - if not active_project: - raise HTTPException( - status_code=400, - detail="No active Jira project found. Please activate a project in the dashboard first.", - ) - - project_key = active_project.project_key - project_title = active_project.project_name or project_key - - # 6. 
Fetch Jira issues for this project (ordered by last updated, max 100) - try: - jira_resp = await jira_client.search_issues( - f'project = "{project_key}" ORDER BY updated DESC', - max_results=100, - fields=["summary", "status", "priority", "assignee"], - ) - except Exception as e: - logger.error(f"[Notion sync-kanban] Jira fetch failed: {e}") - raise HTTPException(status_code=502, detail=f"Failed to fetch Jira issues: {e}") from e - - issues_raw = jira_resp.get("issues", []) - - # 7. Map raw Jira fields to flat dicts for the Notion client - issues = [] - for issue in issues_raw: - fields = issue.get("fields", {}) - assignee_obj = fields.get("assignee") - priority_obj = fields.get("priority") - issues.append( - { - "key": issue["key"], - "summary": fields.get("summary", ""), - "status": fields.get("status", {}).get("name", "To Do"), - "assignee": assignee_obj.get("displayName", "Unassigned") - if assignee_obj - else "Unassigned", - "priority": priority_obj.get("name", "Medium") if priority_obj else "Medium", - } - ) - - # 8. 
Create or reuse the Kanban database in Notion - notion_client = NotionClient(access_token=notion_token.access_token) - try: - # Check if an existing kanban DB is still accessible - if notion_token.kanban_database_id: - exists = await notion_client.check_database_exists(notion_token.kanban_database_id) - if not exists: - logger.warning( - f"[Notion] Kanban DB {notion_token.kanban_database_id} not accessible, recreating" - ) - notion_token.kanban_database_id = None - - # Create a new database if we don't have one - board_view_hint: str | None = None - if not notion_token.kanban_database_id: - # Resolve parent page: prefer stored page_id, else auto-discover first accessible page - parent_page_id = notion_token.page_id - if not parent_page_id: - workspaces = await notion_client.get_workspaces() - pages = [w for w in workspaces if w["type"] == "page"] - if not pages: - raise HTTPException( - status_code=400, - detail=( - "No accessible Notion pages found. Please share at least one page " - "with the Kwillo integration in Notion, then try again." 
- ), - ) - parent_page_id = pages[0]["id"] - # Persist auto-discovered page so next sync skips this step - notion_token.page_id = parent_page_id - - kanban_result = await notion_client.create_kanban_database( - parent_page_id=parent_page_id, - title=f"{project_title} — Kanban", - ) - notion_token.kanban_database_id = kanban_result["database_id"] - notion_token.kanban_data_source_id = kanban_result["data_source_id"] - board_view_hint = kanban_result["board_view_hint"] - await db.commit() - logger.info( - f"[Notion] Kanban DB created and saved: {notion_token.kanban_database_id}, " - f"data_source_id: {notion_token.kanban_data_source_id}" - ) - elif not notion_token.kanban_data_source_id: - # Migration path: existing DB but no data_source_id stored — retrieve it now - ds_id = await notion_client.get_data_source_id(notion_token.kanban_database_id) - if ds_id: - notion_token.kanban_data_source_id = ds_id - await db.commit() - logger.info(f"[Notion] Backfilled kanban_data_source_id: {ds_id}") - - # 9. 
Upsert all issues into the Kanban database - result = await notion_client.sync_issues_to_kanban( - database_id=notion_token.kanban_database_id, - issues=issues, - data_source_id=notion_token.kanban_data_source_id, - ) - - db_id_clean = notion_token.kanban_database_id.replace("-", "") - response: dict = { - "status": "success", - "created": result["created"], - "updated": result["updated"], - "total": len(issues), - "database_url": f"https://notion.so/{db_id_clean}", - } - if board_view_hint: - response["board_view_hint"] = board_view_hint - return response + return await sync_service.sync_kanban(notion_token, jira_client, user, db) except HTTPException: raise except Exception as e: logger.exception("[Notion] sync-kanban failed") raise HTTPException(status_code=500, detail=str(e)) from e - finally: - await notion_client.aclose() @router.get("/notion/workspaces") async def get_notion_workspaces( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) ): - """List databases and pages available to this Notion integration.""" notion_token = await get_user_notion_token(authorization, db) - client = NotionClient(access_token=notion_token.access_token) try: workspaces = await client.get_workspaces() @@ -347,65 +155,9 @@ async def get_notion_workspaces( async def init_notion_database( authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) ): - """ - Initialize a meetings database in Notion and update the user's config. - Idempotent: if a database is already configured and still exists in Notion, - returns the existing one without creating a duplicate. - """ notion_token = await get_user_notion_token(authorization, db) - client = NotionClient(access_token=notion_token.access_token) - try: - try: - # 1. 
Check if database already configured and still exists in Notion - if notion_token.database_id: - exists = await client.check_database_exists(notion_token.database_id) - if exists: - logger.info( - f"[Notion] init-database: existing DB {notion_token.database_id} still valid, returning it" - ) - return { - "status": "existing", - "database_id": notion_token.database_id, - "data_source_id": notion_token.data_source_id, - "message": "Notion database already configured and accessible.", - } - else: - logger.warning( - f"[Notion] init-database: configured DB {notion_token.database_id} not found in Notion, creating new one" - ) - - # 2. Find a suitable parent page if not configured - parent_page_id = notion_token.page_id - if not parent_page_id: - workspaces = await client.get_workspaces() - pages = [w for w in workspaces if w["type"] == "page"] - if not pages: - raise HTTPException( - status_code=400, - detail="No accessible Notion pages found or shared with the integration. Please share at least one page with the 'Kwillo' integration in Notion.", - ) - parent_page_id = pages[0]["id"] - - # 3. Create the database - db_res = await client.create_meetings_database(parent_page_id=parent_page_id) - db_id = db_res["database_id"] - ds_id = db_res["data_source_id"] - - # 4. 
Update config - notion_token.database_id = db_id - notion_token.data_source_id = ds_id - notion_token.sync_mode = "database" - await db.commit() - - return { - "status": "success", - "database_id": db_id, - "data_source_id": ds_id, - "message": "Notion database initialized and configured as export target.", - } - finally: - await client.aclose() + return await sync_service.init_database(notion_token, db) except HTTPException: raise except Exception as e: @@ -415,120 +167,15 @@ async def init_notion_database( @router.post("/notion/sync-all-meetings") async def sync_all_meetings_to_notion( - authorization: str = Header(None), db: AsyncSession = Depends(get_async_db) + authorization: str = Header(None), + db: AsyncSession = Depends(get_async_db), ): - """ - Sync all user's ready meetings to Notion. - Meetings already synced (notion_page_id set) are skipped — idempotent by design. - Returns a summary of synced, skipped, and errored meetings. - """ notion_token = await get_user_notion_token(authorization, db) - - # Extract user_id from JWT to query meetings - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Unauthorized") - token = authorization.split(" ")[1] + user = await _get_current_user(authorization, db) try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) - email = payload.get("sub") - user_result = await db.execute(select(User).where(User.email == email)) - user = user_result.scalar_one_or_none() - if not user: - raise HTTPException(status_code=401, detail="User not found") + return await sync_service.sync_all_meetings(notion_token, user, db) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=401, detail="Invalid token") from e - - sync_mode = notion_token.sync_mode or "database" - database_id = notion_token.database_id - page_id = notion_token.page_id - include_transcript = bool(notion_token.include_transcript) - data_source_id = 
notion_token.data_source_id - frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") - - if sync_mode == "database" and not database_id: - raise HTTPException( - status_code=400, detail="No Notion database configured. Run init-database first." - ) - if sync_mode == "page" and not page_id: - raise HTTPException(status_code=400, detail="No Notion page configured.") - - # Fetch all ready meetings for this user - meetings_result = await db.execute( - select(Meeting) - .where(Meeting.user_id == user.id, Meeting.status == "ready") - .order_by(Meeting.created_at.desc()) - ) - meetings = meetings_result.scalars().all() - - synced_count = 0 - skipped_count = 0 - errors = [] - - client = NotionClient(access_token=notion_token.access_token) - try: - for meeting in meetings: - # Skip already-synced meetings (idempotency key) - if meeting.notion_page_id: - skipped_count += 1 - continue - - try: - meeting_date = meeting.created_at.strftime("%Y-%m-%d") if meeting.created_at else "" - meeting_link = f"{frontend_url.rstrip('/')}/meetings/{meeting.id}" - - from apps.agents.orchestrator.meeting_routes.shared import format_duration - - duration_str = format_duration(meeting.duration_seconds or 0) - - action_items_detailed = [ - {"title": item, "description": ""} if isinstance(item, str) else item - for item in (meeting.action_items or []) - ] - - synced_page_id = None - if sync_mode == "database" and database_id: - synced_page_id = await client.add_meeting_to_database( - database_id=database_id, - data_source_id=data_source_id, - title=meeting.title or "Untitled Meeting", - date=meeting_date, - summary=meeting.summary or "", - action_items=action_items_detailed, - transcript=meeting.raw_transcript, - include_transcript=include_transcript, - duration=duration_str, - meeting_url=meeting_link, - ) - elif sync_mode == "page" and page_id: - synced_page_id = await client.add_meeting_as_page( - parent_page_id=page_id, - title=meeting.title or "Untitled Meeting", - 
date=meeting_date, - summary=meeting.summary or "", - action_items=action_items_detailed, - transcript=meeting.raw_transcript, - include_transcript=include_transcript, - duration=duration_str, - meeting_url=meeting_link, - ) - - if synced_page_id: - meeting.notion_page_id = synced_page_id - await db.commit() - synced_count += 1 - - except Exception as e: - logger.error(f"[Notion] Failed to sync meeting {meeting.id}: {e}") - errors.append({"meeting_id": meeting.id, "title": meeting.title, "error": str(e)}) - finally: - await client.aclose() - - return { - "status": "success", - "synced": synced_count, - "skipped": skipped_count, - "total": len(meetings), - "errors": errors, - } + logger.exception("[Notion] sync-all-meetings failed") + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/apps/integrations/notion/services/__init__.py b/apps/integrations/notion/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/integrations/notion/services/sync_service.py b/apps/integrations/notion/services/sync_service.py new file mode 100644 index 0000000..cd0fb88 --- /dev/null +++ b/apps/integrations/notion/services/sync_service.py @@ -0,0 +1,350 @@ +""" +Business logic for Notion sync operations. +Routes in routes/api.py delegate to these functions after auth. 
+""" + +import logging +import os + +from fastapi import HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from apps.integrations.notion.notion_client import NotionClient +from apps.integrations.notion.types import MeetingExportPayload +from libs.common.models import ActiveJiraProject, Meeting, NotionToken, User + +logger = logging.getLogger(__name__) + + +async def export_meeting( + notion_token: NotionToken, + payload: MeetingExportPayload, +) -> dict: + """Export a single meeting to Notion (database or page mode).""" + client = NotionClient(access_token=notion_token.access_token) + sync_mode = payload.config.sync_mode or notion_token.sync_mode + database_id = payload.config.database_id or notion_token.database_id + data_source_id = notion_token.data_source_id + page_id = payload.config.page_id or notion_token.page_id + include_transcript = payload.config.include_transcript or bool(notion_token.include_transcript) + backlog_db_id = ( + payload.backlog_config.database_id + if payload.backlog_config + else notion_token.backlog_database_id + ) + + try: + if sync_mode == "database" and database_id: + res_id = await client.add_meeting_to_database( + database_id=database_id, + data_source_id=data_source_id, + title=payload.title, + date=payload.date, + summary=payload.summary, + action_items=payload.action_items, + transcript=payload.transcript, + include_transcript=include_transcript, + meeting_url=payload.meeting_url, + duration=payload.duration, + ) + elif page_id: + res_id = await client.add_meeting_as_page( + parent_page_id=page_id, + title=payload.title, + date=payload.date, + summary=payload.summary, + action_items=payload.action_items, + transcript=payload.transcript, + include_transcript=include_transcript, + meeting_url=payload.meeting_url, + duration=payload.duration, + ) + else: + raise HTTPException(status_code=400, detail="No target database or page configured") + + backlog_ids = [] + if backlog_db_id: + 
backlog_ids = await client.add_action_items_to_backlog( + backlog_database_id=backlog_db_id, + action_items=payload.action_items, + meeting_title=payload.title, + ) + + return { + "status": "success", + "page_id": res_id, + "backlog_items_created": len(backlog_ids), + } + finally: + await client.aclose() + + +async def init_database( + notion_token: NotionToken, + db: AsyncSession, +) -> dict: + """ + Initialize a meetings database in Notion and persist the IDs. + Idempotent: returns existing DB if still accessible. + """ + client = NotionClient(access_token=notion_token.access_token) + try: + if notion_token.database_id: + exists = await client.check_database_exists(notion_token.database_id) + if exists: + logger.info( + f"[Notion] init-database: existing DB {notion_token.database_id} still valid" + ) + return { + "status": "existing", + "database_id": notion_token.database_id, + "data_source_id": notion_token.data_source_id, + "message": "Notion database already configured and accessible.", + } + logger.warning( + f"[Notion] init-database: configured DB {notion_token.database_id} not found, creating new one" + ) + + parent_page_id = notion_token.page_id + if not parent_page_id: + workspaces = await client.get_workspaces() + pages = [w for w in workspaces if w["type"] == "page"] + if not pages: + raise HTTPException( + status_code=400, + detail=( + "No accessible Notion pages found or shared with the integration. " + "Please share at least one page with the 'Kwillo' integration in Notion." 
+ ), + ) + parent_page_id = pages[0]["id"] + + db_res = await client.create_meetings_database(parent_page_id=parent_page_id) + notion_token.database_id = db_res["database_id"] + notion_token.data_source_id = db_res["data_source_id"] + notion_token.sync_mode = "database" + await db.commit() + + return { + "status": "success", + "database_id": db_res["database_id"], + "data_source_id": db_res["data_source_id"], + "message": "Notion database initialized and configured as export target.", + } + finally: + await client.aclose() + + +async def sync_kanban( + notion_token: NotionToken, + jira_client, + user: User, + db: AsyncSession, +) -> dict: + """ + Sync the user's active Jira sprint board to a Notion Kanban database. + Idempotent: creates the DB on first call, upserts on subsequent calls. + """ + active_project_result = await db.execute( + select(ActiveJiraProject) + .where(ActiveJiraProject.user_id == user.id) + .order_by(ActiveJiraProject.created_at) + ) + active_project = active_project_result.scalars().first() + if not active_project: + raise HTTPException( + status_code=400, + detail="No active Jira project found. 
Please activate a project in the dashboard first.", + ) + + project_key = active_project.project_key + project_title = active_project.project_name or project_key + + try: + jira_resp = await jira_client.search_issues( + f'project = "{project_key}" ORDER BY updated DESC', + max_results=100, + fields=["summary", "status", "priority", "assignee"], + ) + except Exception as e: + logger.error(f"[Notion sync-kanban] Jira fetch failed: {e}") + raise HTTPException(status_code=502, detail=f"Failed to fetch Jira issues: {e}") from e + + issues = [ + { + "key": issue["key"], + "summary": issue.get("fields", {}).get("summary", ""), + "status": issue.get("fields", {}).get("status", {}).get("name", "To Do"), + "assignee": (issue.get("fields", {}).get("assignee") or {}).get( + "displayName", "Unassigned" + ), + "priority": (issue.get("fields", {}).get("priority") or {}).get("name", "Medium"), + } + for issue in jira_resp.get("issues", []) + ] + + notion_client = NotionClient(access_token=notion_token.access_token) + try: + if notion_token.kanban_database_id: + exists = await notion_client.check_database_exists(notion_token.kanban_database_id) + if not exists: + logger.warning( + f"[Notion] Kanban DB {notion_token.kanban_database_id} not accessible, recreating" + ) + notion_token.kanban_database_id = None + + board_view_hint: str | None = None + if not notion_token.kanban_database_id: + parent_page_id = notion_token.page_id + if not parent_page_id: + workspaces = await notion_client.get_workspaces() + pages = [w for w in workspaces if w["type"] == "page"] + if not pages: + raise HTTPException( + status_code=400, + detail=( + "No accessible Notion pages found. Please share at least one page " + "with the Kwillo integration in Notion, then try again." 
+ ), + ) + parent_page_id = pages[0]["id"] + notion_token.page_id = parent_page_id + + kanban_result = await notion_client.create_kanban_database( + parent_page_id=parent_page_id, + title=f"{project_title} — Kanban", + ) + notion_token.kanban_database_id = kanban_result["database_id"] + notion_token.kanban_data_source_id = kanban_result["data_source_id"] + board_view_hint = kanban_result["board_view_hint"] + await db.commit() + logger.info( + f"[Notion] Kanban DB created: {notion_token.kanban_database_id}, " + f"data_source_id: {notion_token.kanban_data_source_id}" + ) + elif not notion_token.kanban_data_source_id: + ds_id = await notion_client.get_data_source_id(notion_token.kanban_database_id) + if ds_id: + notion_token.kanban_data_source_id = ds_id + await db.commit() + logger.info(f"[Notion] Backfilled kanban_data_source_id: {ds_id}") + + result = await notion_client.sync_issues_to_kanban( + database_id=notion_token.kanban_database_id, + issues=issues, + data_source_id=notion_token.kanban_data_source_id, + ) + + db_id_clean = notion_token.kanban_database_id.replace("-", "") + response: dict = { + "status": "success", + "created": result["created"], + "updated": result["updated"], + "total": len(issues), + "database_url": f"https://notion.so/{db_id_clean}", + } + if board_view_hint: + response["board_view_hint"] = board_view_hint + return response + finally: + await notion_client.aclose() + + +async def sync_all_meetings( + notion_token: NotionToken, + user: User, + db: AsyncSession, +) -> dict: + """ + Sync all ready meetings for the user to Notion. + Idempotent: meetings with notion_page_id already set are skipped. 
+ """ + sync_mode = notion_token.sync_mode or "database" + database_id = notion_token.database_id + page_id = notion_token.page_id + include_transcript = bool(notion_token.include_transcript) + data_source_id = notion_token.data_source_id + frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173") + + if sync_mode == "database" and not database_id: + raise HTTPException( + status_code=400, detail="No Notion database configured. Run init-database first." + ) + if sync_mode == "page" and not page_id: + raise HTTPException(status_code=400, detail="No Notion page configured.") + + meetings_result = await db.execute( + select(Meeting) + .where(Meeting.user_id == user.id, Meeting.status == "ready") + .order_by(Meeting.created_at.desc()) + ) + meetings = meetings_result.scalars().all() + + synced_count = 0 + skipped_count = 0 + errors = [] + + client = NotionClient(access_token=notion_token.access_token) + try: + for meeting in meetings: + if meeting.notion_page_id: + skipped_count += 1 + continue + + try: + from apps.agents.orchestrator.meeting_routes.shared import format_duration + + meeting_date = meeting.created_at.strftime("%Y-%m-%d") if meeting.created_at else "" + meeting_link = f"{frontend_url.rstrip('/')}/meetings/{meeting.id}" + duration_str = format_duration(meeting.duration_seconds or 0) + action_items_detailed = [ + {"title": item, "description": ""} if isinstance(item, str) else item + for item in (meeting.action_items or []) + ] + + synced_page_id = None + if sync_mode == "database" and database_id: + synced_page_id = await client.add_meeting_to_database( + database_id=database_id, + data_source_id=data_source_id, + title=meeting.title or "Untitled Meeting", + date=meeting_date, + summary=meeting.summary or "", + action_items=action_items_detailed, + transcript=meeting.raw_transcript, + include_transcript=include_transcript, + duration=duration_str, + meeting_url=meeting_link, + ) + elif sync_mode == "page" and page_id: + synced_page_id = await 
client.add_meeting_as_page( + parent_page_id=page_id, + title=meeting.title or "Untitled Meeting", + date=meeting_date, + summary=meeting.summary or "", + action_items=action_items_detailed, + transcript=meeting.raw_transcript, + include_transcript=include_transcript, + duration=duration_str, + meeting_url=meeting_link, + ) + + if synced_page_id: + meeting.notion_page_id = synced_page_id + await db.commit() + synced_count += 1 + + except Exception as e: + logger.error(f"[Notion] Failed to sync meeting {meeting.id}: {e}") + errors.append({"meeting_id": meeting.id, "title": meeting.title, "error": str(e)}) + finally: + await client.aclose() + + return { + "status": "success", + "synced": synced_count, + "skipped": skipped_count, + "total": len(meetings), + "errors": errors, + } diff --git a/build_log.txt b/build_log.txt new file mode 100644 index 0000000..e49f989 Binary files /dev/null and b/build_log.txt differ diff --git a/build_log_no_cache.txt b/build_log_no_cache.txt new file mode 100644 index 0000000..6b143ac Binary files /dev/null and b/build_log_no_cache.txt differ diff --git a/libs/common/jira/__init__.py b/libs/common/jira/__init__.py new file mode 100644 index 0000000..de6fcec --- /dev/null +++ b/libs/common/jira/__init__.py @@ -0,0 +1,32 @@ +""" +libs/common/jira — JiraClient composed from domain mixins. 
+ +Usage (unchanged from before): + from libs.common.jira_client import JiraClient, JiraSessionExpired + # or: + from libs.common.jira import JiraClient, JiraSessionExpired +""" + +from libs.common.jira.base import JiraClientBase, JiraSessionExpired, run_with_interactive_retry +from libs.common.jira.boards import BoardsMixin +from libs.common.jira.comments import CommentsMixin +from libs.common.jira.issues import IssuesMixin +from libs.common.jira.projects import ProjectsMixin +from libs.common.jira.users import UsersMixin + + +class JiraClient( + IssuesMixin, + BoardsMixin, + UsersMixin, + CommentsMixin, + ProjectsMixin, + JiraClientBase, +): + """ + Full Jira API client. + Composed from domain-specific mixins — see the individual modules for details. + """ + + +__all__ = ["JiraClient", "JiraSessionExpired", "run_with_interactive_retry"] diff --git a/libs/common/jira/base.py b/libs/common/jira/base.py new file mode 100644 index 0000000..a0c9a3f --- /dev/null +++ b/libs/common/jira/base.py @@ -0,0 +1,260 @@ +""" +JiraClient base infrastructure: auth, token refresh, HTTP request wrappers. 
+""" + +import asyncio +import logging +from typing import Any + +import httpx + +from apps.agents.agent_server.src.common.auth import load_token, save_token +from apps.agents.agent_server.src.common.cache_manager import CacheManager +from apps.agents.agent_server.src.common.config_manager import ConfigManager +from apps.agents.agent_server.src.common.phrases import get_random_phrase +from libs.common.auth import refresh_access_token + +logger = logging.getLogger(__name__) +config_manager = ConfigManager() + + +class JiraSessionExpired(Exception): + """Raised when the Jira OAuth refresh token is permanently invalid.""" + + +async def run_with_interactive_retry(func, *args, **kwargs): + """Run a coroutine with exponential backoff on transient errors (429, 5xx).""" + total_retries = 3 + backoff_factor = 1 + + for attempt in range(total_retries + 1): + try: + response = await func(*args, **kwargs) + if response.status_code in [429, 500, 502, 503, 504]: + if attempt == total_retries: + return response + sleep_time = backoff_factor * (2**attempt) + if response.status_code == 429: + retry_after = response.headers.get("Retry-After") + if retry_after: + try: + sleep_time = float(retry_after) + except ValueError: + pass + if sleep_time > 0.5: + phrase = get_random_phrase(round(sleep_time, 1)) + logger.info(f"[Jira] {phrase}") + await asyncio.sleep(sleep_time) + continue + return response + except (httpx.RequestError, httpx.TimeoutException): + if attempt == total_retries: + raise + await asyncio.sleep(backoff_factor * (2**attempt)) + return None + + +class JiraClientBase: + def __init__( + self, + client_id: str, + client_secret: str, + site_id: str | None = None, + headless: bool = False, + access_token: str | None = None, + refresh_token: str | None = None, + token_saver: Any = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.token_saver = token_saver + self.site_id = site_id + + if access_token: + self.access_token = access_token + 
self.refresh_token = refresh_token + else: + raise ValueError( + "Jira authentication required. No access token provided for the current user session." + ) + + self.cloud_id: str | None = None + self.base_url: str | None = None + self._cache: CacheManager | None = None + self._client = httpx.AsyncClient(timeout=30.0) + + async def __aenter__(self): + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def initialize(self): + """Fetch cloud_id and set base_url if not already set.""" + if not self.cloud_id: + self.cloud_id = await self._get_cloud_id(self.site_id) + self.base_url = f"https://api.atlassian.com/ex/jira/{self.cloud_id}/rest/api/3" + self._cache = CacheManager(cloud_id=self.cloud_id) + + async def close(self): + if self._client: + await self._client.aclose() + + async def _get_cloud_id(self, preferred_site_id=None): + """Fetch the authorized Jira cloud resource ID.""" + headers = {"Authorization": f"Bearer {self.access_token}", "Accept": "application/json"} + client = self._client + + async def _call(): + return await client.get( + "https://api.atlassian.com/oauth/token/accessible-resources", + headers=headers, + ) + + response = await run_with_interactive_retry(_call) + if response.status_code == 401: + await self._handle_refresh() + headers["Authorization"] = f"Bearer {self.access_token}" + response = await run_with_interactive_retry(_call) + + if response.status_code != 200: + raise Exception(f"Failed to get accessible resources: {response.text}") + + resources = response.json() + if not resources: + raise Exception("No Jira resources found. 
Please authorize the app for a site.") + + available_sites = {r["id"]: r for r in resources} + + if preferred_site_id and preferred_site_id in available_sites: + return preferred_site_id + if preferred_site_id: + logger.warning(f"Preferred site {preferred_site_id} not found in accessible resources.") + + default_site = config_manager.get_default_site_id() + if default_site and default_site in available_sites: + return default_site + + if len(resources) == 1: + site = resources[0] + config_manager.set_default_site_id(site["id"]) + config_manager.add_site(site["id"], site["name"], site["url"]) + return site["id"] + + tips = "\n".join([f"- {r['name']} (ID: {r['id']})" for r in resources]) + raise Exception( + f"Multiple Jira sites found. Please configure a default site.\nAvailable:\n{tips}" + ) + + async def _handle_refresh(self): + """Refresh the OAuth token and update the instance.""" + logger.info("Token expired, refreshing...") + + if self.refresh_token: + try: + new_token_data = await refresh_access_token( + self.client_id, self.client_secret, self.refresh_token + ) + self.access_token = new_token_data["access_token"] + self.refresh_token = new_token_data["refresh_token"] + if self.token_saver: + try: + token_data = { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + } + if asyncio.iscoroutinefunction(self.token_saver): + await self.token_saver(token_data) + else: + self.token_saver(token_data) + logger.info("New tokens persisted to backend.") + except Exception as e: + logger.error(f"Failed to persist refreshed tokens: {e}") + except Exception as e: + error_msg = str(e).lower() + logger.error(f"Failed to refresh user token: {e}") + if "invalid_grant" in error_msg or "unauthorized_client" in error_msg: + logger.warning("Permanent refresh failure. 
Signaling token invalidation.") + if self.token_saver: + try: + if asyncio.iscoroutinefunction(self.token_saver): + await self.token_saver(None) + else: + self.token_saver(None) + except Exception as saver_error: + logger.error(f"Failed to signal token invalidation: {saver_error}") + raise JiraSessionExpired("Session expired. Please re-authenticate.") from e + return + + # Legacy / CLI Mode + token_data = load_token() + if not token_data or "refresh_token" not in token_data: + raise Exception("No refresh token available. Please re-authenticate.") + new_token_data = await refresh_access_token( + self.client_id, self.client_secret, token_data["refresh_token"] + ) + save_token(new_token_data) + self.access_token = new_token_data["access_token"] + + async def request(self, method: str, endpoint: str, **kwargs) -> httpx.Response: + """Generic Jira API request with auto-refresh and rate-limit logging.""" + if not self.cloud_id: + await self.initialize() + + url = f"{self.base_url}/{endpoint.lstrip('/')}" + headers = kwargs.pop("headers", {}) + headers["Authorization"] = f"Bearer {self.access_token}" + headers["Accept"] = "application/json" + if "files" not in kwargs: + headers["Content-Type"] = "application/json" + + async def _make_request(): + return await self._client.request(method, url, headers=headers, **kwargs) + + try: + logger.info(f"JIRA REQUEST: {method} {url}") + response = await run_with_interactive_retry(_make_request) + + limit = response.headers.get("X-RateLimit-Limit") + remaining = response.headers.get("X-RateLimit-Remaining") + if limit and remaining: + try: + if int(remaining) < int(limit) * 0.1: + logger.warning(f"Jira API Rate Limit Low! 
{remaining}/{limit} remaining.") + except ValueError: + pass + + if response.status_code == 401: + await self._handle_refresh() + headers["Authorization"] = f"Bearer {self.access_token}" + response = await run_with_interactive_retry(_make_request) + + if not response.is_success: + logger.error(f"Jira API Error {response.status_code} at {url}: {response.text}") + + return response + except Exception as e: + logger.error(f"Request failed: {e}") + raise + + async def agile_request(self, method: str, endpoint: str, **kwargs) -> httpx.Response: + """Request to the Jira Software Agile API (rest/agile/1.0/).""" + if not self.cloud_id: + await self.initialize() + agile_base = f"https://api.atlassian.com/ex/jira/{self.cloud_id}/rest/agile/1.0" + url = f"{agile_base}/{endpoint.lstrip('/')}" + headers = kwargs.pop("headers", {}) + headers["Authorization"] = f"Bearer {self.access_token}" + headers["Accept"] = "application/json" + + async def _make_request(): + return await self._client.request(method, url, headers=headers, **kwargs) + + logger.info(f"JIRA AGILE REQUEST: {method} {url}") + res = await run_with_interactive_retry(_make_request) + if res.status_code == 401: + await self._handle_refresh() + headers["Authorization"] = f"Bearer {self.access_token}" + res = await run_with_interactive_retry(_make_request) + return res diff --git a/libs/common/jira/boards.py b/libs/common/jira/boards.py new file mode 100644 index 0000000..34184b1 --- /dev/null +++ b/libs/common/jira/boards.py @@ -0,0 +1,46 @@ +"""BoardsMixin — Jira Software Agile boards and sprints.""" + + +class BoardsMixin: + async def get_boards_for_project(self, project_key: str) -> list[dict]: + """Return all boards for a project. 
Cached for 15 minutes.""" + cache_key = f"boards:{project_key}" + if self._cache: + cached = await self._cache.get(cache_key) + if cached: + return cached + res = await self.agile_request("GET", f"board?projectKeyOrId={project_key}") + res.raise_for_status() + boards = [{"id": b["id"], "name": b["name"]} for b in res.json().get("values", [])] + if self._cache: + await self._cache.set(cache_key, boards, ttl_seconds=900) + return boards + + async def move_issue_to_sprint(self, sprint_id: int, issue_key: str) -> None: + """Move an issue to a sprint using the Agile API.""" + res = await self.agile_request( + "POST", f"sprint/{sprint_id}/issue", json={"issues": [issue_key]} + ) + res.raise_for_status() + + async def get_sprints_for_board(self, board_id: int) -> list[dict]: + """Return active and future sprints for a board. Cached for 5 minutes.""" + cache_key = f"sprints:{board_id}" + if self._cache: + cached = await self._cache.get(cache_key) + if cached: + return cached + res = await self.agile_request("GET", f"board/{board_id}/sprint?state=active,future") + res.raise_for_status() + sprints = [ + { + "id": s["id"], + "name": s["name"], + "state": s["state"], + "boardId": s.get("originBoardId", board_id), + } + for s in res.json().get("values", []) + ] + if self._cache: + await self._cache.set(cache_key, sprints, ttl_seconds=300) + return sprints diff --git a/libs/common/jira/comments.py b/libs/common/jira/comments.py new file mode 100644 index 0000000..e326024 --- /dev/null +++ b/libs/common/jira/comments.py @@ -0,0 +1,76 @@ +"""CommentsMixin — Jira issue comments and worklogs.""" + +from apps.agents.agent_server.src.common.adf_utils import adf_to_markdown, text_to_adf + + +class CommentsMixin: + def _format_comment(self, c: dict) -> dict: + body = c.get("body") + content = adf_to_markdown(body) if isinstance(body, dict) else str(body) + return { + "id": c.get("id"), + "author": c.get("author", {}).get("displayName", "Unknown"), + "author_id": c.get("author", 
{}).get("accountId"), + "created": c.get("created"), + "body": content, + } + + async def get_comments(self, issue_id_or_key: str) -> list[dict]: + res = await self.request("GET", f"issue/{issue_id_or_key}/comment") + res.raise_for_status() + return [self._format_comment(c) for c in res.json().get("comments", [])] + + async def read_comments(self, issue_id_or_key: str, comment_ids: list[str]) -> str: + """Fetch and format specific comments by ID.""" + output = [] + for cid in comment_ids: + try: + res = await self.request("GET", f"issue/{issue_id_or_key}/comment/{cid}") + if res.is_success: + data = res.json() + body = adf_to_markdown(data.get("body")) + author = data.get("author", {}).get("displayName", "Unknown") + created = data.get("created") + output.append(f"### Comment {cid} by {author} ({created})\n{body}\n") + else: + output.append(f"Error fetching comment {cid}: {res.status_code}") + except Exception as e: + output.append(f"Error fetching comment {cid}: {e}") + return "\n---\n".join(output) + + async def add_comment(self, issue_id_or_key: str, body: str) -> dict: + payload = { + "body": { + "type": "doc", + "version": 1, + "content": [{"type": "paragraph", "content": [{"type": "text", "text": body}]}], + } + } + res = await self.request("POST", f"issue/{issue_id_or_key}/comment", json=payload) + res.raise_for_status() + return res.json() + + async def update_comment(self, issue_key: str, comment_id: str, body_text: str) -> dict: + payload = {"body": text_to_adf(body_text)} + res = await self.request("PUT", f"issue/{issue_key}/comment/{comment_id}", json=payload) + res.raise_for_status() + return res.json() + + async def delete_comment(self, issue_key: str, comment_id: str) -> dict: + res = await self.request("DELETE", f"issue/{issue_key}/comment/{comment_id}") + res.raise_for_status() + return {"status": "success", "message": f"Comment {comment_id} deleted"} + + async def add_worklog( + self, issue_id_or_key: str, time_spent_string: str, comment: str | 
None = None + ) -> dict: + payload = {"timeSpent": time_spent_string} + if comment: + payload["comment"] = { + "type": "doc", + "version": 1, + "content": [{"type": "paragraph", "content": [{"type": "text", "text": comment}]}], + } + res = await self.request("POST", f"issue/{issue_id_or_key}/worklog", json=payload) + res.raise_for_status() + return res.json() diff --git a/libs/common/jira/issues.py b/libs/common/jira/issues.py new file mode 100644 index 0000000..4399e2f --- /dev/null +++ b/libs/common/jira/issues.py @@ -0,0 +1,385 @@ +"""IssuesMixin — Jira issue CRUD, transitions, search, attachments.""" + +import logging + +from apps.agents.agent_server.src.common.adf_utils import adf_to_markdown + +logger = logging.getLogger(__name__) + + +class IssuesMixin: + async def search_issues( + self, + jql: str, + fields: list[str] | None = None, + start_at: int = 0, + max_results: int = 50, + ) -> dict: + all_issues = [] + PAGE_SIZE = 100 + target_count = max_results if max_results else float("inf") + next_token = None + data = {} + + while len(all_issues) < target_count: + remaining = target_count - len(all_issues) + payload = { + "jql": jql, + "maxResults": min(remaining, PAGE_SIZE), + "fields": fields + or [ + "summary", + "status", + "issuetype", + "created", + "priority", + "assignee", + "reporter", + ], + } + if next_token: + payload["nextPageToken"] = next_token + + res = await self.request("POST", "search/jql", json=payload) + res.raise_for_status() + data = res.json() + issues = data.get("issues", []) + all_issues.extend(issues) + next_token = data.get("nextPageToken") + if not issues or not next_token: + break + + return { + "startAt": start_at, + "maxResults": len(all_issues), + "total": data.get("total", len(all_issues)), + "issues": all_issues, + } + + async def validate_jql(self, jql: str) -> dict: + try: + res = await self.request( + "POST", "jql/parse", json={"queries": [jql], "validation": "strict"} + ) + if not res.is_success: + return {"valid": False, 
"errors": [f"API Error: {res.status_code} {res.text}"]} + data = res.json() + queries = data.get("queries", []) + if not queries: + return {"valid": False, "errors": ["No validation result returned from API."]} + errors = queries[0].get("errors", []) + return {"valid": not errors, "errors": errors} + except Exception as e: + logger.error(f"JQL Validation failed: {e}") + return {"valid": False, "errors": [str(e)]} + + async def get_issue(self, issue_id_or_key: str) -> dict: + params = { + "expand": "renderedFields,names,schema,operations,editmeta,changelog,versionedRepresentations" + } + res = await self.request("GET", f"issue/{issue_id_or_key}", params=params) + res.raise_for_status() + data = res.json() + fields = data.get("fields", {}) + + issue_view = { + "key": data.get("key"), + "id": data.get("id"), + "summary": fields.get("summary"), + "status": fields.get("status", {}).get("name"), + "priority": fields.get("priority", {}).get("name"), + "assignee": fields.get("assignee", {}).get("displayName", "Unassigned"), + "reporter": fields.get("reporter", {}).get("displayName", "Unknown"), + "created": fields.get("created"), + "updated": fields.get("updated"), + "labels": fields.get("labels", []), + "issuetype": fields.get("issuetype", {}).get("name"), + "url": f"{self.base_url.replace('/rest/api/3', '')}/browse/{data.get('key')}", + } + + desc_node = fields.get("description") + issue_view["description"] = adf_to_markdown(desc_node) if desc_node else "" + + raw_comments = fields.get("comment", {}).get("comments", []) + total_comments = len(raw_comments) + processed_comments = [] + + def _format_toc(c): + author = c.get("author", {}).get("displayName", "Unknown") + created = c.get("created")[:10] + body = c.get("body") + preview = "..." + if isinstance(body, dict): + try: + preview = body["content"][0]["content"][0]["text"][:50] + "..." 
+ except (KeyError, IndexError, TypeError): + preview = "[Complex User Content]" + return f'[ID: {c.get("id")}] {created} | @{author} | "{preview}"' + + if total_comments <= 4: + processed_comments = [self._format_comment(c) for c in raw_comments] + else: + processed_comments.append(self._format_comment(raw_comments[0])) + middle_comments = raw_comments[1:-3] + processed_comments.append( + f"--- HIDDEN {len(middle_comments)} COMMENTS (Use get_comments to view all) ---" + ) + for c in middle_comments: + processed_comments.append(_format_toc(c)) + processed_comments.append("--- END HIDDEN ---") + for c in raw_comments[-3:]: + processed_comments.append(self._format_comment(c)) + + issue_view["comments"] = processed_comments + issue_view["subtasks"] = [ + { + "key": s["key"], + "summary": s["fields"]["summary"], + "status": s["fields"]["status"]["name"], + "issuetype": s["fields"]["issuetype"]["name"], + } + for s in fields.get("subtasks", []) + ] + + issue_view["links"] = [] + for link in fields.get("issuelinks", []): + if "outwardIssue" in link: + issue_view["links"].append( + { + "type": link["type"]["outward"], + "key": link["outwardIssue"]["key"], + "status": link["outwardIssue"]["fields"]["status"]["name"], + "direction": "outward", + } + ) + elif "inwardIssue" in link: + issue_view["links"].append( + { + "type": link["type"]["inward"], + "key": link["inwardIssue"]["key"], + "status": link["inwardIssue"]["fields"]["status"]["name"], + "direction": "inward", + } + ) + + return issue_view + + async def search_issue_history(self, issue_id_or_key: str, query: str) -> str: + res = await self.request( + "GET", + f"issue/{issue_id_or_key}", + params={"expand": "changelog", "fields": "comment,summary"}, + ) + res.raise_for_status() + data = res.json() + query_lower = query.lower() + results = [] + + for c in data.get("fields", {}).get("comment", {}).get("comments", []): + body_md = adf_to_markdown(c.get("body", {})) + if query_lower in body_md.lower(): + cid = 
c.get("id") + author = c.get("author", {}).get("displayName") + results.append(f"[Comment {cid}] {author}: found match in body.") + idx = body_md.lower().find(query_lower) + snippet = body_md[ + max(0, idx - 30) : min(len(body_md), idx + 30 + len(query)) + ].replace("\n", " ") + results.append(f' Context: "...{snippet}..."') + + for history in data.get("changelog", {}).get("histories", []): + author = history.get("author", {}).get("displayName") + created = history.get("created") + for item in history.get("items", []): + f = item.get("field", "") + t_from, t_to = str(item.get("fromString", "")), str(item.get("toString", "")) + if any(query_lower in x.lower() for x in [f, t_from, t_to]): + results.append( + f"[Changelog] {created} {author} changed '{f}': '{t_from}' -> '{t_to}'" + ) + + return ( + "\n".join(results) if results else f"No matches found for '{query}' in issue history." + ) + + async def create_issue( + self, + project_key: str, + summary: str, + description: str, + issue_type: str = "Task", + assignee_id: str | None = None, + parent_key: str | None = None, + priority: str | None = None, + labels: list[str] | None = None, + due_date: str | None = None, + sprint_id: int | None = None, + ) -> dict: + fields = { + "project": {"key": project_key}, + "summary": summary, + "description": { + "type": "doc", + "version": 1, + "content": [ + {"type": "paragraph", "content": [{"type": "text", "text": description}]} + ], + }, + "issuetype": {"name": issue_type}, + } + if assignee_id: + fields["assignee"] = {"accountId": assignee_id} + if parent_key: + fields["parent"] = {"key": parent_key} + if priority: + fields["priority"] = {"name": priority} + if labels: + fields["labels"] = labels + if due_date: + fields["duedate"] = due_date + if sprint_id: + fields["customfield_10020"] = sprint_id + + res = await self.request("POST", "issue", json={"fields": fields}) + res.raise_for_status() + return res.json() + + async def update_issue(self, issue_id_or_key: str, 
fields_dict: dict) -> dict: + processed_fields = fields_dict.copy() + assignee_msg = None + + if "assignee" in processed_fields: + val = processed_fields.pop("assignee") + account_id = val["accountId"] if isinstance(val, dict) and "accountId" in val else val + if isinstance(account_id, str): + try: + res_assign = await self.request( + "PUT", f"issue/{issue_id_or_key}/assignee", json={"accountId": account_id} + ) + res_assign.raise_for_status() + except Exception as e: + assignee_msg = f"Failed to assign issue: {e}" + + if "customfield_10020" in processed_fields: + val = processed_fields["customfield_10020"] + if isinstance(val, int) and val: + processed_fields["customfield_10020"] = {"id": val} + elif not val: + del processed_fields["customfield_10020"] + + for key, wrapper in [("priority", "name"), ("issuetype", "name")]: + if key in processed_fields and isinstance(processed_fields[key], str): + processed_fields[key] = {wrapper: processed_fields[key]} + + if "description" in processed_fields and isinstance(processed_fields["description"], str): + processed_fields["description"] = { + "type": "doc", + "version": 1, + "content": [ + { + "type": "paragraph", + "content": [{"type": "text", "text": processed_fields["description"]}], + } + ], + } + + if processed_fields: + res = await self.request( + "PUT", f"issue/{issue_id_or_key}", json={"fields": processed_fields} + ) + res.raise_for_status() + if res.status_code == 204: + return {"status": "success", "assignee_error": assignee_msg} + return res.json() + return {"status": "success", "assignee_error": assignee_msg} + + async def bulk_update_issues(self, issue_keys: list[str], fields_dict: dict) -> dict: + results: dict = {"success": [], "failed": []} + for key in issue_keys: + try: + await self.update_issue(key, fields_dict) + results["success"].append(key) + except Exception as e: + results["failed"].append({"key": key, "error": str(e)}) + return results + + async def delete_issue(self, issue_id_or_key: str) -> 
None: + res = await self.request("DELETE", f"issue/{issue_id_or_key}") + res.raise_for_status() + + async def get_transitions(self, issue_id_or_key: str) -> dict: + res = await self.request("GET", f"issue/{issue_id_or_key}/transitions") + res.raise_for_status() + return res.json() + + async def transition_issue(self, issue_id_or_key: str, transition_id: str) -> None: + res = await self.request( + "POST", + f"issue/{issue_id_or_key}/transitions", + json={"transition": {"id": str(transition_id)}}, + ) + if res.status_code == 404: + raise Exception( + f"Issue {issue_id_or_key} not found (404). Verify the Issue Key and permissions." + ) + res.raise_for_status() + + async def bulk_transition_issues(self, issue_keys: list[str], transition_id: str) -> dict: + results: dict = {"success": [], "failed": []} + for key in issue_keys: + try: + await self.transition_issue(key, transition_id) + results["success"].append(key) + except Exception as e: + results["failed"].append({"key": key, "error": str(e)}) + return results + + async def update_status(self, issue_key: str, status_name: str) -> dict: + """Transition an issue to a status by name (fuzzy match).""" + data = await self.get_transitions(issue_key) + transitions = data.get("transitions", []) + status_lower = status_name.lower() + target_id = next((t["id"] for t in transitions if t["name"].lower() == status_lower), None) + if not target_id: + target_id = next( + (t["id"] for t in transitions if status_lower in t["name"].lower()), None + ) + if not target_id: + names = [t["name"] for t in transitions] + msg = f"Transition '{status_name}' not found for {issue_key}." 
+ if names: + msg += f" Available: {', '.join(names)}" + return {"error": msg} + await self.transition_issue(issue_key, target_id) + return {"status": "success", "message": f"Moved {issue_key} to {status_name}"} + + async def get_issue_details(self, issue_key: str) -> dict: + """Alias for get_issue to match old client interface.""" + return await self.get_issue(issue_key) + + async def upload_attachment(self, issue_key: str, file_content: bytes, filename: str) -> dict: + from libs.common.jira.base import run_with_interactive_retry + + url = f"issue/{issue_key}/attachments" + headers = { + "Authorization": f"Bearer {self.access_token}", + "Accept": "application/json", + "X-Atlassian-Token": "no-check", + } + files = {"file": (filename, file_content)} + + async def _call(): + return await self._client.post(f"{self.base_url}/{url}", headers=headers, files=files) + + res = await run_with_interactive_retry(_call) + if res.status_code == 401: + await self._handle_refresh() + headers["Authorization"] = f"Bearer {self.access_token}" + res = await run_with_interactive_retry(_call) + res.raise_for_status() + return res.json() + + async def delete_attachment(self, attachment_id: str) -> dict: + res = await self.request("DELETE", f"attachment/{attachment_id}") + res.raise_for_status() + return {"status": "success", "message": f"Attachment {attachment_id} deleted"} diff --git a/libs/common/jira/projects.py b/libs/common/jira/projects.py new file mode 100644 index 0000000..5e48f95 --- /dev/null +++ b/libs/common/jira/projects.py @@ -0,0 +1,144 @@ +"""ProjectsMixin — Jira project metadata, fields, and context generation.""" + +import logging + +logger = logging.getLogger(__name__) + + +class ProjectsMixin: + async def get_projects(self) -> list[dict]: + res = await self.request("GET", "project") + res.raise_for_status() + return res.json() + + async def get_priorities(self) -> list[dict]: + res = await self.request("GET", "priority") + res.raise_for_status() + return res.json() + 
+ async def get_fields(self) -> list[dict]: + cache_key = "jira_fields" + if self._cache: + cached = await self._cache.get(cache_key) + if cached: + return cached + res = await self.request("GET", "field") + res.raise_for_status() + data = res.json() + if self._cache: + await self._cache.set(cache_key, data, ttl_seconds=86400) + return data + + async def get_project_issue_types(self, project_key: str) -> list[dict]: + """Fetch issue types for a project (id, name, iconUrl, subtask).""" + res = await self.request("GET", f"project/{project_key}") + res.raise_for_status() + return [ + { + "id": it["id"], + "name": it["name"], + "iconUrl": it.get("iconUrl"), + "description": it.get("description"), + "subtask": it.get("subtask"), + } + for it in res.json().get("issueTypes", []) + ] + + async def get_project_context(self, project_key: str) -> dict: + """Return project metadata (name, issue types, statuses). Cached 1 hour.""" + cache_key = f"project_meta_{project_key}" + if self._cache: + cached = await self._cache.get(cache_key) + if cached: + return cached + + res_p = await self.request("GET", f"project/{project_key}") + res_p.raise_for_status() + project_data = res_p.json() + + res_s = await self.request("GET", f"project/{project_key}/statuses") + res_s.raise_for_status() + statuses_data = res_s.json() + + context = { + "key": project_key, + "name": project_data["name"], + "issue_types": [ + {"name": it["name"], "statuses": [s["name"] for s in it["statuses"]]} + for it in statuses_data + ], + } + if self._cache: + await self._cache.set(cache_key, context, ttl_seconds=3600) + return context + + async def get_project_metadata(self, project_key: str) -> dict: + try: + ctx = await self.get_project_context(project_key) + return { + "key": project_key, + "issue_types": [it["name"] for it in ctx.get("issue_types", [])], + } + except Exception as e: + logger.error(f"Error in get_project_metadata: {e}") + return {"error": str(e)} + + async def get_create_metadata(self, 
project_key: str, issue_type_id: str) -> dict: + res = await self.request( + "GET", f"issue/createmeta/{project_key}/issuetypes/{issue_type_id}" + ) + res.raise_for_status() + data = res.json() + fields = data.get("values") or data.get("fields") or [] + + items = ( + fields.items() + if isinstance(fields, dict) + else [(f.get("fieldId", f.get("id", "unknown")), f) for f in fields] + if isinstance(fields, list) + else [] + ) + required_fields = [ + { + "id": field_key, + "name": field_val.get("name"), + "required": True, + "schema": field_val.get("schema"), + "allowedValues": field_val.get("allowedValues", []), + } + for field_key, field_val in items + if field_val.get("required") + ] + return { + "project_key": project_key, + "issue_type_id": issue_type_id, + "required_fields": required_fields, + "all_fields_count": len(fields) if isinstance(fields, (list, dict)) else 0, + } + + async def get_project_summary_text(self) -> str: + """Generate a human-readable context summary of projects and active tasks.""" + try: + projects = await self.get_projects() + summary_lines = [] + for project in projects: + key, name = project["key"], project["name"] + summary_lines.append(f"Project: {name} (Key: {key})") + jql = f"project = {key} AND statusCategory in ('To Do', 'In Progress') ORDER BY updated DESC" + search_results = await self.search_issues(jql, max_results=10) + issues = search_results.get("issues", []) + if issues: + summary_lines.append(" Active Tasks:") + for issue in issues: + f = issue["fields"] + assignee = ( + f["assignee"]["displayName"] if f.get("assignee") else "Unassigned" + ) + summary_lines.append( + f" - {issue['key']}: {f['summary']} (Assignee: {assignee}, Status: {f['status']['name']})" + ) + summary_lines.append("") + return "\n".join(summary_lines) + except Exception as e: + logger.error(f"Error fetching context: {e}") + return f"Error fetching context: {str(e)}" diff --git a/libs/common/jira/users.py b/libs/common/jira/users.py new file mode 100644 
index 0000000..6aefedd --- /dev/null +++ b/libs/common/jira/users.py @@ -0,0 +1,121 @@ +"""UsersMixin — Jira user search and identity resolution.""" + +import logging + +logger = logging.getLogger(__name__) + + +class UsersMixin: + async def get_myself(self) -> dict: + res = await self.request("GET", "myself") + res.raise_for_status() + return res.json() + + async def get_myself_context(self) -> dict: + """Return cached user identity context for AI prompts.""" + cache_key = "user_myself" + if self._cache: + cached = await self._cache.get(cache_key) + if cached: + return cached + me = await self.get_myself() + context = { + "name": me["displayName"], + "accountId": me["accountId"], + "email": me.get("emailAddress", "hidden"), + } + if self._cache: + await self._cache.set(cache_key, context, ttl_seconds=3600) + return context + + async def find_user_account_id(self, name: str) -> str | None: + users = await self.search_users(name) + if isinstance(users, list): + for user in users: + if user.get("accountType") == "atlassian": + return user.get("accountId") + return None + + async def search_users(self, query: str, max_results: int = 10) -> list[dict]: + res = await self.request( + "GET", "user/search", params={"query": query, "maxResults": max_results} + ) + res.raise_for_status() + return res.json() + + async def get_assignable_users( + self, project_key: str, query: str | None = None, max_results: int = 50 + ) -> list[dict]: + params = {"project": project_key, "maxResults": max_results} + if query: + params["query"] = query + res = await self.request("GET", "user/assignable/search", params=params) + res.raise_for_status() + return res.json() + + async def search_users_async( + self, query: str, max_results: int = 10, project_key: str | None = None + ) -> list[dict]: + if project_key: + return await self.get_assignable_users(project_key, query, max_results) + return await self.search_users(query, max_results) + + async def get_user_active_tasks(self, email: str, 
project_key: str | None = None) -> list[dict]: + """Get active tasks (In Progress, To Do) for a specific user by email.""" + if not email: + return [] + + account_id = None + + try: + myself = await self.get_myself() + if myself and myself.get("accountId"): + account_id = myself.get("accountId") + except Exception as e: + logger.info(f"'myself' check failed: {e}") + + if not account_id: + try: + user_search = await self.search_users(email) + if user_search and isinstance(user_search, list): + account_id = user_search[0].get("accountId") + except Exception as e: + logger.info(f"Email search failed: {e}") + + if not account_id and project_key: + try: + assignable = await self.get_assignable_users(project_key, max_results=100) + if assignable: + for u in assignable: + if u.get("emailAddress", "").lower() == email.lower(): + account_id = u.get("accountId") + break + if email.split("@")[0].lower() in u.get("displayName", "").lower(): + account_id = u.get("accountId") + break + except Exception as e: + logger.info(f"Assignable search failed: {e}") + + if not account_id: + logger.warning(f"Could not find user accountId for {email}") + return [] + + try: + results = await self.search_issues( + f"assignee = '{account_id}' AND statusCategory != Done ORDER BY updated DESC", + max_results=10, + ) + return [ + { + "key": i["key"], + "summary": i["fields"].get("summary"), + "description": i["fields"].get("description"), + "status": i["fields"].get("status", {}).get("name"), + "priority": i["fields"].get("priority", {}).get("name"), + "updated": i["fields"].get("updated"), + } + for i in results.get("issues", []) + ] + except Exception as e: + logger.error(f"Error fetching active tasks for {email}: {e}") + return [] diff --git a/libs/common/jira_client.py b/libs/common/jira_client.py index 8c71022..1e626f2 100644 --- a/libs/common/jira_client.py +++ b/libs/common/jira_client.py @@ -1,1076 +1,10 @@ """ -libs/common/jira_client.py — Canonical JiraClient (single source of truth). 
- -Moved from apps/agents/agent_server/src/common/jira_client.py. -Old location is a thin re-export shim. - -Note: imports adf_utils, cache_manager, config_manager, phrases from -apps.agents.agent_server.src.common — acceptable since those utils have no -reverse dependency on libs.common.jira_client. +Shim: JiraClient moved to libs/common/jira/. +All existing imports continue to work unchanged. """ -import asyncio -import logging -import sys -from typing import Any - -import httpx - -from apps.agents.agent_server.src.common.adf_utils import adf_to_markdown, text_to_adf -from apps.agents.agent_server.src.common.auth import load_token, save_token -from apps.agents.agent_server.src.common.cache_manager import CacheManager -from apps.agents.agent_server.src.common.config_manager import ConfigManager -from apps.agents.agent_server.src.common.phrases import get_random_phrase -from libs.common.auth import refresh_access_token - -config_manager = ConfigManager() - - -logger = logging.getLogger(__name__) - - -class JiraSessionExpired(Exception): - """Raised when the Jira OAuth refresh token is permanently invalid.""" - - -async def run_with_interactive_retry(func, *args, **kwargs): - """ - Helper to run a function with retries and interactive feedback. - Mimics InteractiveRetry logic from requests. 
- """ - total_retries = 3 - backoff_factor = 1 - - for attempt in range(total_retries + 1): - try: - response = await func(*args, **kwargs) - if response.status_code in [429, 500, 502, 503, 504]: - if attempt == total_retries: - return response - - # Calculate sleep time - sleep_time = backoff_factor * (2**attempt) - if response.status_code == 429: - retry_after = response.headers.get("Retry-After") - if retry_after: - try: - sleep_time = float(retry_after) - except ValueError: - pass - - if sleep_time > 0.5: - phrase = get_random_phrase(round(sleep_time, 1)) - logger.info(f"[Jira] {phrase}") - - await asyncio.sleep(sleep_time) - continue - return response - except (httpx.RequestError, httpx.TimeoutException) as e: - if attempt == total_retries: - raise e - sleep_time = backoff_factor * (2**attempt) - await asyncio.sleep(sleep_time) - return None # Should not reach here - - -class JiraClient: - def __init__( - self, - client_id: str, - client_secret: str, - site_id: str | None = None, - headless: bool = False, - access_token: str | None = None, - refresh_token: str | None = None, - token_saver: Any = None, # Callback to save new tokens - ): - self.client_id = client_id - self.client_secret = client_secret - self.token_saver = token_saver - self.site_id = site_id - - # If tokens provided (SaaS mode), use them; otherwise load from file (legacy) - if access_token: - self.access_token = access_token - self.refresh_token = refresh_token - else: - raise ValueError( - "Jira authentication required. No access token provided for the current user session." 
- ) - - self.cloud_id: str | None = None - self.base_url: str | None = None - self._cache: CacheManager | None = ( - None # initialised in initialize() once cloud_id is known - ) - self._client = httpx.AsyncClient(timeout=30.0) - - async def __aenter__(self): - await self.initialize() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - async def initialize(self): - """Initializes the client, fetches cloud_id if not set.""" - if not self.cloud_id: - self.cloud_id = await self._get_cloud_id(self.site_id) - self.base_url = f"https://api.atlassian.com/ex/jira/{self.cloud_id}/rest/api/3" - # Instantiate per-tenant cache once cloud_id is known - self._cache = CacheManager(cloud_id=self.cloud_id) - - async def close(self): - if self._client: - await self._client.aclose() - - async def _get_cloud_id(self, preferred_site_id=None): - """Fetches the authorized cloud resource ID (Site ID).""" - headers = { - "Authorization": f"Bearer {self.access_token}", - "Accept": "application/json", - } - - # We use a temporary client if initialize isn't fully called yet or just use self._client - client = self._client - - async def _call(): - return await client.get( - "https://api.atlassian.com/oauth/token/accessible-resources", - headers=headers, - ) - - response = await run_with_interactive_retry(_call) - - if response.status_code == 401: - await self._handle_refresh() - headers["Authorization"] = f"Bearer {self.access_token}" - response = await run_with_interactive_retry(_call) - - if response.status_code != 200: - raise Exception(f"Failed to get accessible resources: {response.text}") - - resources = response.json() - if not resources: - raise Exception("No Jira resources found. Please authorize the app for a site.") - - # Site Resolution Logic - available_sites = {r["id"]: r for r in resources} - - # 1. 
Try passed site_id - if preferred_site_id: - if preferred_site_id in available_sites: - return preferred_site_id - else: - logger.warning( - f"Preferred site {preferred_site_id} not found in accessible resources." - ) - - # 2. Try configured default - default_site = config_manager.get_default_site_id() - if default_site and default_site in available_sites: - return default_site - - # 3. If only one exists, use it - if len(resources) == 1: - site = resources[0] - # Auto-save validation - config_manager.set_default_site_id(site["id"]) - config_manager.add_site(site["id"], site["name"], site["url"]) - return site["id"] - - # 4. Ambiguous - tips = "\n".join([f"- {r['name']} (ID: {r['id']})" for r in resources]) - raise Exception( - f"Multiple Jira sites found. Please configure a default site.\nAvailable:\n{tips}" - ) - - async def _handle_refresh(self): - """Refreshes the token and updates the instance.""" - logger.info("Token expired, refreshing...") - - if self.refresh_token: - try: - new_token_data = await refresh_access_token( - self.client_id, self.client_secret, self.refresh_token - ) - self.access_token = new_token_data["access_token"] - self.refresh_token = new_token_data["refresh_token"] - - if self.token_saver: - try: - token_data = { - "access_token": self.access_token, - "refresh_token": self.refresh_token, - } - if asyncio.iscoroutinefunction(self.token_saver): - await self.token_saver(token_data) - else: - self.token_saver(token_data) - logger.info("New tokens persisted to backend.") - except Exception as e: - logger.error(f"Failed to persist refreshed tokens: {e}") - - except Exception as e: - error_msg = str(e).lower() - logger.error(f"Failed to refresh user token: {e}") - if "invalid_grant" in error_msg or "unauthorized_client" in error_msg: - logger.warning( - "Permanent refresh failure detected. Signaling token invalidation." 
- ) - if self.token_saver: - try: - if asyncio.iscoroutinefunction(self.token_saver): - await self.token_saver(None) - else: - self.token_saver(None) - except Exception as saver_error: - logger.error(f"Failed to signal token invalidation: {saver_error}") - - raise JiraSessionExpired("Session expired. Please re-authenticate.") from e - return - - # Legacy / CLI Mode - token_data = load_token() - if not token_data or "refresh_token" not in token_data: - raise Exception("No refresh token available. Please re-authenticate.") - - new_token_data = await refresh_access_token( - self.client_id, self.client_secret, token_data["refresh_token"] - ) - save_token(new_token_data) - self.access_token = new_token_data["access_token"] - - async def request(self, method: str, endpoint: str, **kwargs) -> httpx.Response: - """Generic request wrapper with auto-refresh and rate limit handling.""" - if not self.cloud_id: - await self.initialize() - - url = f"{self.base_url}/{endpoint.lstrip('/')}" - headers = kwargs.pop("headers", {}) - headers["Authorization"] = f"Bearer {self.access_token}" - headers["Accept"] = "application/json" - if "files" not in kwargs: - headers["Content-Type"] = "application/json" - - async def _make_request(): - return await self._client.request(method, url, headers=headers, **kwargs) - - try: - logger.info(f"JIRA REQUEST: {method} {url}") - response = await run_with_interactive_retry(_make_request) - - # --- Rate Limit Handling --- - limit = response.headers.get("X-RateLimit-Limit") - remaining = response.headers.get("X-RateLimit-Remaining") - - if limit and remaining: - try: - limit_int = int(limit) - remaining_int = int(remaining) - logger.debug(f"Rate Limit: {remaining_int}/{limit_int}") - if remaining_int < (limit_int * 0.1): - logger.warning( - f"Jira API Rate Limit Low! {remaining_int}/{limit_int} remaining." 
- ) - except ValueError: - pass - - if response.status_code == 401: - await self._handle_refresh() - headers["Authorization"] = f"Bearer {self.access_token}" - response = await run_with_interactive_retry(_make_request) - - if response.status_code == 429: - retry_after = response.headers.get("Retry-After") - logger.error(f"Rate Limit Exceeded (429). Retry-After: {retry_after}") - - if not response.is_success: - logger.error(f"Jira API Error {response.status_code} at {url}: {response.text}") - - return response - except Exception as e: - logger.error(f"Request failed: {e}") - raise - - async def agile_request(self, method: str, endpoint: str, **kwargs): - """Make a request to the Jira Agile API (rest/agile/1.0/).""" - if not self.cloud_id: - await self.initialize() - agile_base = f"https://api.atlassian.com/ex/jira/{self.cloud_id}/rest/agile/1.0" - url = f"{agile_base}/{endpoint.lstrip('/')}" - headers = kwargs.pop("headers", {}) - headers["Authorization"] = f"Bearer {self.access_token}" - headers["Accept"] = "application/json" - - async def _make_request(): - return await self._client.request(method, url, headers=headers, **kwargs) - - logger.info(f"JIRA AGILE REQUEST: {method} {url}") - res = await run_with_interactive_retry(_make_request) - - if res.status_code == 401: - logger.error(f"JIRA AGILE 401 body: {res.text}") - await self._handle_refresh() - headers["Authorization"] = f"Bearer {self.access_token}" - res = await run_with_interactive_retry(_make_request) - if res.status_code == 401: - logger.error(f"JIRA AGILE 401 after refresh body: {res.text}") - - return res - - async def get_boards_for_project(self, project_key: str) -> list[dict]: - """Return all boards for a project. 
Cached for 15 minutes.""" - cache_key = f"boards:{project_key}" - if self._cache: - cached = await self._cache.get(cache_key) - if cached: - return cached - res = await self.agile_request("GET", f"board?projectKeyOrId={project_key}") - res.raise_for_status() - boards = [{"id": b["id"], "name": b["name"]} for b in res.json().get("values", [])] - if self._cache: - await self._cache.set(cache_key, boards, ttl_seconds=900) # 15 min - return boards - - async def move_issue_to_sprint(self, sprint_id: int, issue_key: str) -> None: - """Move an issue to a sprint using the Agile API.""" - res = await self.agile_request( - "POST", f"sprint/{sprint_id}/issue", json={"issues": [issue_key]} - ) - res.raise_for_status() - - async def get_sprints_for_board(self, board_id: int) -> list[dict]: - """Return active and future sprints for a board. Cached for 5 minutes.""" - cache_key = f"sprints:{board_id}" - if self._cache: - cached = await self._cache.get(cache_key) - if cached: - return cached - res = await self.agile_request("GET", f"board/{board_id}/sprint?state=active,future") - res.raise_for_status() - sprints = [ - { - "id": s["id"], - "name": s["name"], - "state": s["state"], - "boardId": s.get("originBoardId", board_id), - } - for s in res.json().get("values", []) - ] - if self._cache: - await self._cache.set(cache_key, sprints, ttl_seconds=300) # 5 min - return sprints - - async def get_projects(self): - res = await self.request("GET", "project") - res.raise_for_status() - return res.json() - - async def find_user_account_id(self, name: str): - users = await self.search_users(name) - if isinstance(users, list): - for user in users: - if user.get("accountType") == "atlassian": - return user.get("accountId") - return None - - async def search_users(self, query: str, max_results: int = 10): - params = {"query": query, "maxResults": max_results} - res = await self.request("GET", "user/search", params=params) - res.raise_for_status() - return res.json() - - async def 
get_assignable_users( - self, project_key: str, query: str | None = None, max_results: int = 50 - ): - params = {"project": project_key, "maxResults": max_results} - if query: - params["query"] = query - res = await self.request("GET", "user/assignable/search", params=params) - res.raise_for_status() - return res.json() - - async def search_users_async( - self, query: str, max_results: int = 10, project_key: str | None = None - ): - if project_key: - return await self.get_assignable_users(project_key, query, max_results) - return await self.search_users(query, max_results) - - async def get_priorities(self): - res = await self.request("GET", "priority") - res.raise_for_status() - return res.json() - - async def get_project_metadata(self, project_key: str): - try: - ctx = await self.get_project_context(project_key) - issue_types = [it["name"] for it in ctx.get("issue_types", [])] - return {"key": project_key, "issue_types": issue_types} - except Exception as e: - logger.error(f"Error in get_project_metadata: {e}") - return {"error": str(e)} - - async def search_issues( - self, jql: str, fields: list[str] | None = None, start_at: int = 0, max_results: int = 50 - ): - all_issues = [] - PAGE_SIZE = 100 - target_count = max_results if max_results else float("inf") - next_token = None - - while len(all_issues) < target_count: - remaining = target_count - len(all_issues) - limit_for_request = min(remaining, PAGE_SIZE) - payload = {"jql": jql, "maxResults": limit_for_request} - if fields: - payload["fields"] = fields - else: - payload["fields"] = [ - "summary", - "status", - "issuetype", - "created", - "priority", - "assignee", - "reporter", - ] - - if next_token: - payload["nextPageToken"] = next_token - - res = await self.request("POST", "search/jql", json=payload) - res.raise_for_status() - data = res.json() - issues = data.get("issues", []) - all_issues.extend(issues) - next_token = data.get("nextPageToken") - if not issues or not next_token: - break - - return { - 
"startAt": start_at, - "maxResults": len(all_issues), - "total": data.get("total", len(all_issues)), - "issues": all_issues, - } - - async def validate_jql(self, jql: str): - payload = {"queries": [jql], "validation": "strict"} - try: - res = await self.request("POST", "jql/parse", json=payload) - if not res.is_success: - return {"valid": False, "errors": [f"API Error: {res.status_code} {res.text}"]} - data = res.json() - queries = data.get("queries", []) - if not queries: - return {"valid": False, "errors": ["No validation result returned from API."]} - result = queries[0] - errors = result.get("errors", []) - if errors: - return {"valid": False, "errors": errors} - return {"valid": True, "errors": []} - except Exception as e: - logger.error(f"JQL Validation failed: {e}") - return {"valid": False, "errors": [str(e)]} - - async def get_myself(self): - res = await self.request("GET", "myself") - res.raise_for_status() - return res.json() - - async def add_worklog( - self, issue_id_or_key: str, time_spent_string: str, comment: str | None = None - ): - payload = {"timeSpent": time_spent_string} - if comment: - payload["comment"] = { - "type": "doc", - "version": 1, - "content": [{"type": "paragraph", "content": [{"type": "text", "text": comment}]}], - } - res = await self.request("POST", f"issue/{issue_id_or_key}/worklog", json=payload) - res.raise_for_status() - return res.json() - - async def get_issue(self, issue_id_or_key: str): - params = { - "expand": "renderedFields,names,schema,operations,editmeta,changelog,versionedRepresentations" - } - res = await self.request("GET", f"issue/{issue_id_or_key}", params=params) - res.raise_for_status() - data = res.json() - fields = data.get("fields", {}) - - issue_view = { - "key": data.get("key"), - "id": data.get("id"), - "summary": fields.get("summary"), - "status": fields.get("status", {}).get("name"), - "priority": fields.get("priority", {}).get("name"), - "assignee": fields.get("assignee", {}).get("displayName", 
"Unassigned"), - "reporter": fields.get("reporter", {}).get("displayName", "Unknown"), - "created": fields.get("created"), - "updated": fields.get("updated"), - "labels": fields.get("labels", []), - "issuetype": fields.get("issuetype", {}).get("name"), - "url": f"{self.base_url.replace('/rest/api/3', '')}/browse/{data.get('key')}", - } - desc_node = fields.get("description") - issue_view["description"] = adf_to_markdown(desc_node) if desc_node else "" - raw_comments = fields.get("comment", {}).get("comments", []) - total_comments = len(raw_comments) - processed_comments = [] - - def _format_toc(c): - author = c.get("author", {}).get("displayName", "Unknown") - created = c.get("created")[:10] - body = c.get("body") - preview = "..." - if isinstance(body, dict): - try: - preview = body["content"][0]["content"][0]["text"][:50] + "..." - except (KeyError, IndexError, TypeError): - preview = "[Complex User Content]" - return f'[ID: {c.get("id")}] {created} | @{author} | "{preview}"' - - if total_comments <= 4: - processed_comments = [self._format_comment(c) for c in raw_comments] - else: - processed_comments.append(self._format_comment(raw_comments[0])) - middle_comments = raw_comments[1:-3] - processed_comments.append( - f"--- HIDDEN {len(middle_comments)} COMMENTS (Use get_comments to view all) ---" - ) - for c in middle_comments: - processed_comments.append(_format_toc(c)) - processed_comments.append("--- END HIDDEN ---") - for c in raw_comments[-3:]: - processed_comments.append(self._format_comment(c)) - - issue_view["comments"] = processed_comments - subtasks = fields.get("subtasks", []) - issue_view["subtasks"] = [ - { - "key": s["key"], - "summary": s["fields"]["summary"], - "status": s["fields"]["status"]["name"], - "issuetype": s["fields"]["issuetype"]["name"], - } - for s in subtasks - ] - links = fields.get("issuelinks", []) - issue_view["links"] = [] - for link in links: - if "outwardIssue" in link: - issue_view["links"].append( - { - "type": 
link["type"]["outward"], - "key": link["outwardIssue"]["key"], - "status": link["outwardIssue"]["fields"]["status"]["name"], - "direction": "outward", - } - ) - elif "inwardIssue" in link: - issue_view["links"].append( - { - "type": link["type"]["inward"], - "key": link["inwardIssue"]["key"], - "status": link["inwardIssue"]["fields"]["status"]["name"], - "direction": "inward", - } - ) - return issue_view - - async def read_comments(self, issue_id_or_key: str, comment_ids: list[str]): - output = [] - for cid in comment_ids: - try: - res = await self.request("GET", f"issue/{issue_id_or_key}/comment/{cid}") - if res.is_success: - data = res.json() - body = adf_to_markdown(data.get("body")) - author = data.get("author", {}).get("displayName", "Unknown") - created = data.get("created") - output.append(f"### Comment {cid} by {author} ({created})\n{body}\n") - else: - output.append(f"Error fetching comment {cid}: {res.status_code}") - except Exception as e: - output.append(f"Error fetching comment {cid}: {e}") - return "\n---\n".join(output) - - async def search_issue_history(self, issue_id_or_key: str, query: str): - params = {"expand": "changelog", "fields": "comment,summary"} - res = await self.request("GET", f"issue/{issue_id_or_key}", params=params) - res.raise_for_status() - data = res.json() - query_lower = query.lower() - results = [] - comments = data.get("fields", {}).get("comment", {}).get("comments", []) - for c in comments: - body_md = adf_to_markdown(c.get("body", {})) - if query_lower in body_md.lower(): - cid = c.get("id") - author = c.get("author", {}).get("displayName") - results.append(f"[Comment {cid}] {author}: found match in body.") - idx = body_md.lower().find(query_lower) - start = max(0, idx - 30) - end = min(len(body_md), idx + 30 + len(query)) - snippet = body_md[start:end].replace("\n", " ") - results.append(f' Context: "...{snippet}..."') - changelog = data.get("changelog", {}).get("histories", []) - for history in changelog: - author = 
history.get("author", {}).get("displayName") - created = history.get("created") - for item in history.get("items", []): - f = item.get("field", "") - t_from = str(item.get("fromString", "")) - t_to = str(item.get("toString", "")) - if ( - query_lower in f.lower() - or query_lower in t_from.lower() - or query_lower in t_to.lower() - ): - results.append( - f"[Changelog] {created} {author} changed '{f}': '{t_from}' -> '{t_to}'" - ) - if not results: - return f"No matches found for '{query}' in issue history." - return "\n".join(results) - - async def create_issue( - self, - project_key: str, - summary: str, - description: str, - issue_type: str = "Task", - assignee_id: str | None = None, - parent_key: str | None = None, - priority: str | None = None, - labels: list[str] | None = None, - due_date: str | None = None, - sprint_id: int | None = None, - ): - fields = { - "project": {"key": project_key}, - "summary": summary, - "description": { - "type": "doc", - "version": 1, - "content": [ - {"type": "paragraph", "content": [{"type": "text", "text": description}]} - ], - }, - "issuetype": {"name": issue_type}, - } - if assignee_id: - fields["assignee"] = {"accountId": assignee_id} - if parent_key: - fields["parent"] = {"key": parent_key} - if priority: - fields["priority"] = {"name": priority} - if labels: - fields["labels"] = labels - if due_date: - fields["duedate"] = due_date - if sprint_id: - fields["customfield_10020"] = sprint_id - - res = await self.request("POST", "issue", json={"fields": fields}) - res.raise_for_status() - return res.json() - - async def update_issue(self, issue_id_or_key: str, fields_dict: dict): - processed_fields = fields_dict.copy() - assignee_msg = None - if "assignee" in processed_fields: - val = processed_fields.pop("assignee") - account_id = val["accountId"] if isinstance(val, dict) and "accountId" in val else val - if isinstance(account_id, str): - try: - res_assign = await self.request( - "PUT", f"issue/{issue_id_or_key}/assignee", 
json={"accountId": account_id} - ) - res_assign.raise_for_status() - except Exception as e: - assignee_msg = f"Failed to assign issue: {e}" - if "customfield_10020" in processed_fields: - val = processed_fields["customfield_10020"] - if isinstance(val, int) and val: - processed_fields["customfield_10020"] = {"id": val} - elif not val: - del processed_fields["customfield_10020"] # omit to avoid Jira error on null sprint - if "priority" in processed_fields and isinstance(processed_fields["priority"], str): - processed_fields["priority"] = {"name": processed_fields["priority"]} - if "issuetype" in processed_fields and isinstance(processed_fields["issuetype"], str): - processed_fields["issuetype"] = {"name": processed_fields["issuetype"]} - if "description" in processed_fields and isinstance(processed_fields["description"], str): - processed_fields["description"] = { - "type": "doc", - "version": 1, - "content": [ - { - "type": "paragraph", - "content": [{"type": "text", "text": processed_fields["description"]}], - } - ], - } - if processed_fields: - res = await self.request( - "PUT", f"issue/{issue_id_or_key}", json={"fields": processed_fields} - ) - res.raise_for_status() - if res.status_code == 204: - return {"status": "success", "assignee_error": assignee_msg} - return res.json() - return {"status": "success", "assignee_error": assignee_msg} - - async def bulk_update_issues(self, issue_keys: list[str], fields_dict: dict): - results = {"success": [], "failed": []} - for key in issue_keys: - try: - await self.update_issue(key, fields_dict) - results["success"].append(key) - except Exception as e: - results["failed"].append({"key": key, "error": str(e)}) - return results - - async def bulk_transition_issues(self, issue_keys: list[str], transition_id: str): - results = {"success": [], "failed": []} - for key in issue_keys: - try: - await self.transition_issue(key, transition_id) - results["success"].append(key) - except Exception as e: - 
results["failed"].append({"key": key, "error": str(e)}) - return results - - async def get_fields(self): - cache_key = "jira_fields" - if self._cache: - cached = await self._cache.get(cache_key) - if cached: - return cached - res = await self.request("GET", "field") - res.raise_for_status() - data = res.json() - if self._cache: - await self._cache.set(cache_key, data, ttl_seconds=86400) - return data - - async def delete_issue(self, issue_id_or_key: str): - res = await self.request("DELETE", f"issue/{issue_id_or_key}") - res.raise_for_status() - return None - - async def add_comment(self, issue_id_or_key: str, body: str): - payload = { - "body": { - "type": "doc", - "version": 1, - "content": [{"type": "paragraph", "content": [{"type": "text", "text": body}]}], - } - } - res = await self.request("POST", f"issue/{issue_id_or_key}/comment", json=payload) - res.raise_for_status() - return res.json() - - async def get_transitions(self, issue_id_or_key: str): - res = await self.request("GET", f"issue/{issue_id_or_key}/transitions") - res.raise_for_status() - return res.json() - - async def transition_issue(self, issue_id_or_key: str, transition_id: str): - payload = {"transition": {"id": str(transition_id)}} - res = await self.request("POST", f"issue/{issue_id_or_key}/transitions", json=payload) - if res.status_code == 404: - raise Exception( - f"Issue {issue_id_or_key} check failed (404). It likely does not exist, or you lack permissions. Please verify the Issue Key." 
- ) - res.raise_for_status() - return None - - async def get_project_issue_types(self, project_key: str): - """Fetch detailed issue types for a project (id, name, iconUrl, subtask).""" - res = await self.request("GET", f"project/{project_key}") - res.raise_for_status() - data = res.json() - issue_types = data.get("issueTypes", []) - return [ - { - "id": it["id"], - "name": it["name"], - "iconUrl": it.get("iconUrl"), - "description": it.get("description"), - "subtask": it.get("subtask"), - } - for it in issue_types - ] - - async def get_project_context(self, project_key: str): - cache_key = f"project_meta_{project_key}" - if self._cache: - cached = await self._cache.get(cache_key) - if cached: - return cached - res_p = await self.request("GET", f"project/{project_key}") - res_p.raise_for_status() - project_data = res_p.json() - res_s = await self.request("GET", f"project/{project_key}/statuses") - res_s.raise_for_status() - statuses_data = res_s.json() - context = {"key": project_key, "name": project_data["name"], "issue_types": []} - for it in statuses_data: - itype = {"name": it["name"], "statuses": [s["name"] for s in it["statuses"]]} - context["issue_types"].append(itype) - if self._cache: - await self._cache.set(cache_key, context, ttl_seconds=3600) - return context - - async def upload_attachment(self, issue_key: str, file_content: bytes, filename: str): - url = f"issue/{issue_key}/attachments" - headers = { - "Authorization": f"Bearer {self.access_token}", - "Accept": "application/json", - "X-Atlassian-Token": "no-check", - } - files = {"file": (filename, file_content)} - - async def _call(): - return await self._client.post(f"{self.base_url}/{url}", headers=headers, files=files) - - res = await run_with_interactive_retry(_call) - if res.status_code == 401: - await self._handle_refresh() - headers["Authorization"] = f"Bearer {self.access_token}" - res = await run_with_interactive_retry(_call) - res.raise_for_status() - return res.json() - - async def 
delete_attachment(self, attachment_id: str): - res = await self.request("DELETE", f"attachment/{attachment_id}") - res.raise_for_status() - return {"status": "success", "message": f"Attachment {attachment_id} deleted"} - - async def update_comment(self, issue_key: str, comment_id: str, body_text: str): - payload = {"body": text_to_adf(body_text)} - res = await self.request("PUT", f"issue/{issue_key}/comment/{comment_id}", json=payload) - res.raise_for_status() - return res.json() - - async def delete_comment(self, issue_key: str, comment_id: str): - res = await self.request("DELETE", f"issue/{issue_key}/comment/{comment_id}") - res.raise_for_status() - return {"status": "success", "message": f"Comment {comment_id} deleted"} - - async def get_create_metadata(self, project_key: str, issue_type_id: str): - res = await self.request( - "GET", f"issue/createmeta/{project_key}/issuetypes/{issue_type_id}" - ) - res.raise_for_status() - data = res.json() - fields = data.get("values") or data.get("fields") or [] - required_fields = [] - if isinstance(fields, dict): - items = fields.items() - elif isinstance(fields, list): - items = [(f.get("fieldId", f.get("id", "unknown")), f) for f in fields] - else: - items = [] - for field_key, field_val in items: - if field_val.get("required"): - required_fields.append( - { - "id": field_key, - "name": field_val.get("name"), - "required": True, - "schema": field_val.get("schema"), - "allowedValues": field_val.get("allowedValues", []), - } - ) - return { - "project_key": project_key, - "issue_type_id": issue_type_id, - "required_fields": required_fields, - "all_fields_count": len(fields) if isinstance(fields, (list, dict)) else 0, - } - - def _format_comment(self, c: dict): - body = c.get("body") - content = adf_to_markdown(body) if isinstance(body, dict) else str(body) - return { - "id": c.get("id"), - "author": c.get("author", {}).get("displayName", "Unknown"), - "author_id": c.get("author", {}).get("accountId"), - "created": 
c.get("created"), - "body": content, - } - - async def get_comments(self, issue_id_or_key: str): - res = await self.request("GET", f"issue/{issue_id_or_key}/comment") - res.raise_for_status() - data = res.json() - raw_comments = data.get("comments", []) - return [self._format_comment(c) for c in raw_comments] - - async def get_myself_context(self): - cache_key = "user_myself" - if self._cache: - cached = await self._cache.get(cache_key) - if cached: - return cached - me = await self.get_myself() - context = { - "name": me["displayName"], - "accountId": me["accountId"], - "email": me.get("emailAddress", "hidden"), - } - if self._cache: - await self._cache.set(cache_key, context, ttl_seconds=3600) - return context - - async def get_project_summary_text(self): - """ - Fetches projects, sprints, and active tasks to generate a context summary. - Merged from apps/integrations/jira/jira_client.py. - """ - try: - projects = await self.get_projects() - summary_lines = [] - - for project in projects: - key = project["key"] - name = project["name"] - summary_lines.append(f"Project: {name} (Key: {key})") - - jql = f"project = {key} AND statusCategory in ('To Do', 'In Progress') ORDER BY updated DESC" - search_results = await self.search_issues(jql, max_results=10) - - issues = search_results.get("issues", []) - if issues: - summary_lines.append(" Active Tasks:") - for issue in issues: - fields = issue["fields"] - summary = fields["summary"] - status = fields["status"]["name"] - assignee = ( - fields["assignee"]["displayName"] - if fields["assignee"] - else "Unassigned" - ) - issue_key = issue["key"] - summary_lines.append( - f" - {issue_key}: {summary} (Assignee: {assignee}, Status: {status})" - ) - - summary_lines.append("") - - return "\n".join(summary_lines) - except Exception as e: - logger.error(f"Error fetching context: {e}") - return f"Error fetching context: {str(e)}" - - async def get_user_active_tasks(self, email: str, project_key: str = None) -> list[dict]: - """ - 
Get active tasks (In Progress, To Do) for a specific user email. - Merged from apps/integrations/jira/jira_client.py with all strategies. - """ - if not email: - return [] - - account_id = None - - # Strategy 0: 'myself' - try: - myself = await self.get_myself() - if myself and myself.get("accountId"): - account_id = myself.get("accountId") - logger.info(f"Found user via 'myself' endpoint: {account_id}") - except Exception as e: - logger.info(f"'myself' check failed: {e}") - - # Strategy 1: Search by Email - if not account_id: - try: - user_search = await self.search_users(email) - if user_search and isinstance(user_search, list) and len(user_search) > 0: - account_id = user_search[0].get("accountId") - logger.info(f"Found user via email search: {account_id}") - except Exception as e: - logger.info(f"Email search failed: {e}") - - # Strategy 2: Project Assignable Users - if not account_id and project_key: - try: - assignable = await self.get_assignable_users(project_key, max_results=100) - if assignable and isinstance(assignable, list): - for u in assignable: - if u.get("emailAddress", "").lower() == email.lower(): - account_id = u.get("accountId") - break - if email.split("@")[0].lower() in u.get("displayName", "").lower(): - account_id = u.get("accountId") - break - except Exception as e: - logger.info(f"Assignable search failed: {e}") - - if not account_id: - logger.warning(f"Could not find user accountId for {email}") - return [] - - # 2. 
Search for issues - try: - jql = f"assignee = '{account_id}' AND statusCategory != Done ORDER BY updated DESC" - results = await self.search_issues(jql, max_results=10) - issues = results.get("issues", []) - return [ - { - "key": i["key"], - "summary": i["fields"].get("summary"), - "description": i["fields"].get("description"), - "status": i["fields"].get("status", {}).get("name"), - "priority": i["fields"].get("priority", {}).get("name"), - "updated": i["fields"].get("updated"), - } - for i in issues - ] - except Exception as e: - logger.error(f"Error fetching active tasks for {email}: {e}") - return [] - - async def update_status(self, issue_key: str, status_name: str): - """ - Transition an issue to a status by name (fuzzy match). - Merged from apps/integrations/jira/jira_client.py. - """ - data = await self.get_transitions(issue_key) - transitions = data.get("transitions", []) - - target_id = None - status_lower = status_name.lower() - - for t in transitions: - if t["name"].lower() == status_lower: - target_id = t["id"] - break - - if not target_id: - for t in transitions: - if status_lower in t["name"].lower(): - target_id = t["id"] - break - - if not target_id: - msg = f"Transition '{status_name}' not found for {issue_key}." 
- if transitions: - names = [t["name"] for t in transitions] - msg += f" Available: {', '.join(names)}" - return {"error": msg} - - await self.transition_issue(issue_key, target_id) - return {"status": "success", "message": f"Moved {issue_key} to {status_name}"} - - async def get_issue_details(self, issue_key: str): - """Alias for get_issue to match old client interface.""" - return await self.get_issue(issue_key) +from libs.common.jira import ( # noqa: F401 + JiraClient, + JiraSessionExpired, + run_with_interactive_retry, +) diff --git a/libs/common/models.py b/libs/common/models.py index 3a403ac..da4163a 100644 --- a/libs/common/models.py +++ b/libs/common/models.py @@ -25,6 +25,8 @@ class UserToken(Base): access_token = Column(Text) refresh_token = Column(Text) cloud_id = Column(String) # The Jira Cloud ID (site ID) + trello_workspace_id = Column(String, nullable=True) # Atlassian resource ID for Trello + trello_board_id = Column(String, nullable=True) # Selected Trello board ID # Relationship user = relationship("User", back_populates="jira_token") diff --git a/tests/test_chat_jira_tasks_response.py b/tests/test_chat_jira_tasks_response.py index 9125872..7e2e3cf 100644 --- a/tests/test_chat_jira_tasks_response.py +++ b/tests/test_chat_jira_tasks_response.py @@ -1,6 +1,6 @@ def test_proposed_tasks_cleared_between_calls(): """Tasks from a previous call must not leak after clear.""" - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, @@ -18,7 +18,7 @@ def test_response_includes_jira_tasks_when_proposed(): import re from apps.agents.agent_server.api import build_chat_response - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, propose_jira_tasks, ) @@ -46,7 +46,7 @@ def 
test_response_includes_jira_tasks_when_proposed(): def test_response_excludes_jira_tasks_when_none_proposed(): from apps.agents.agent_server.api import build_chat_response - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, ) diff --git a/tests/test_jira_task_proposal_tools.py b/tests/test_jira_task_proposal_tools.py index 3fc9df2..1ac9fcf 100644 --- a/tests/test_jira_task_proposal_tools.py +++ b/tests/test_jira_task_proposal_tools.py @@ -1,5 +1,5 @@ def test_propose_jira_tasks_returns_confirmation(): - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, @@ -14,7 +14,7 @@ def test_propose_jira_tasks_returns_confirmation(): def test_propose_jira_tasks_stores_all_fields(): - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, @@ -45,7 +45,7 @@ def test_propose_jira_tasks_stores_all_fields(): def test_clear_proposed_tasks(): - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, @@ -57,7 +57,7 @@ def test_clear_proposed_tasks(): def test_propose_jira_tasks_replaces_on_second_call(): - from apps.agents.agent_server.src.server.tools.jira_task_proposal_tools import ( + from apps.agents.agent_server.tools.jira_task_proposal_tools import ( clear_proposed_tasks, get_proposed_tasks, propose_jira_tasks, diff --git a/tmp_test_health.py b/tmp_test_health.py deleted file mode 100644 index 34dc762..0000000 --- a/tmp_test_health.py +++ /dev/null @@ -1,19 +0,0 @@ -import asyncio - 
-import httpx - -from apps.api_server.main import app - - -async def test_health(): - transport = httpx.ASGITransport(app=app) - async with httpx.AsyncClient(transport=transport, base_url="http://test") as ac: - response = await ac.get("/health") - print("Status Code:", response.status_code) - import json - - print("Response:", json.dumps(response.json(), indent=2)) - - -if __name__ == "__main__": - asyncio.run(test_health())