diff --git a/.gitignore b/.gitignore index 3ce1cc5e..94908b9a 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,4 @@ bots/conf/ # IDE files .vscode/ .idea/ +improvements \ No newline at end of file diff --git a/config.py b/config.py index 2f1b08dc..a77a886d 100644 --- a/config.py +++ b/config.py @@ -1,3 +1,4 @@ +import logging from typing import List from pydantic import Field @@ -58,19 +59,60 @@ class MarketDataSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="MARKET_DATA_", extra="ignore") -class SecuritySettings(BaseSettings): - """Security and authentication configuration.""" +# Insecure default credential values (SEC-018), mapped to the environment variables that override them. +# They are kept only for local development convenience and MUST be overridden in production deployments. +_INSECURE_SECURITY_DEFAULTS = { + "USERNAME": "admin", + "PASSWORD": "admin", + "CONFIG_PASSWORD": "a", +} + - username: str = Field(default="admin", description="API basic auth username") - password: str = Field(default="admin", description="API basic auth password") - debug_mode: bool = Field(default=False, description="Enable debug mode (disables auth)") - config_password: str = Field(default="a", description="Bot configuration encryption password") +class SecuritySettings(BaseSettings): + """Security and authentication configuration. + + All fields are read from environment variables without a prefix (or from .env): + - USERNAME: API basic auth username (default "admin" — local development only, never use in production) + - PASSWORD: API basic auth password (default "admin" — local development only, never use in production) + - CONFIG_PASSWORD: password used to encrypt ALL connector credentials (default "a" — local development only, + never use in production) + """ + + username: str = Field(default="admin", description="API basic auth username (override via USERNAME in production)") + password: str = Field(default="admin", description="API basic auth password (override via PASSWORD in production)") + config_password: str = Field( + default="a", + description="Bot configuration encryption password (override via CONFIG_PASSWORD in production)" + ) model_config = SettingsConfigDict( env_prefix="", extra="ignore" # Ignore extra environment variables ) + def insecure_defaults_in_use(self) -> List[str]: + """Return the env var names of security settings still set to their insecure default values.""" + current_values = {"USERNAME": self.username, "PASSWORD": self.password, "CONFIG_PASSWORD": self.config_password} + return [name for name, default in _INSECURE_SECURITY_DEFAULTS.items() if current_values[name] == default] + + +def warn_if_insecure_security_defaults(security: SecuritySettings) -> List[str]: + """Emit a high-severity log if any security setting still uses its insecure default value (SEC-018). + + Returns the list of env var names that are still at their defaults (empty list when fully configured). + """ + insecure = security.insecure_defaults_in_use() + if insecure: + logging.critical( + "SECURITY WARNING: insecure default credentials in use for: %s. " + "Anyone who can reach this API can authenticate with the default basic auth credentials, and all " + "connector credentials are encrypted with a trivially guessable password. " + "Set the USERNAME, PASSWORD and CONFIG_PASSWORD environment variables (e.g. in .env) before deploying " + "to production. Do NOT run a production deployment with these defaults.", + ", ".join(insecure), + ) + return insecure + class AWSSettings(BaseSettings): """AWS configuration for S3 archiving.""" @@ -93,6 +135,33 @@ class GatewaySettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="GATEWAY_", extra="ignore") +class CORSSettings(BaseSettings): + """CORS configuration for the API (SEC-019). + + A wildcard origin ("*") must never be combined with allow_credentials=True: browsers reject that + combination per the CORS spec, and Starlette works around it by reflecting any Origin, which lets + arbitrary third-party pages call the API from an authenticated operator's browser. Origins are + therefore restricted by default and configurable via environment variables: + - CORS_ALLOW_ORIGINS: JSON list of explicit trusted origins, e.g. '["https://dashboard.example.com"]' + - CORS_ALLOW_ORIGIN_REGEX: regex for trusted origins (defaults to localhost-only for local development; + set to an empty string to disable regex matching entirely) + """ + + allow_origins: List[str] = Field( + default=[], + description='Explicit list of trusted CORS origins, e.g. CORS_ALLOW_ORIGINS=\'["https://dashboard.example.com"]\'' + ) + allow_origin_regex: str = Field( + default=r"https?://(localhost|127\.0\.0\.1)(:\d+)?", + description="Regex matching trusted CORS origins; defaults to localhost-only. Empty string disables regex matching." + ) + allow_credentials: bool = Field(default=True, description="Allow credentialed (cookies/auth) cross-origin requests") + allow_methods: List[str] = Field(default=["*"], description="HTTP methods allowed for cross-origin requests") + allow_headers: List[str] = Field(default=["*"], description="HTTP headers allowed for cross-origin requests") + + model_config = SettingsConfigDict(env_prefix="CORS_", extra="ignore") + + class AppSettings(BaseSettings): """Main application settings.""" @@ -130,6 +199,7 @@ class Settings(BaseSettings): security: SecuritySettings = Field(default_factory=SecuritySettings) aws: AWSSettings = Field(default_factory=AWSSettings) gateway: GatewaySettings = Field(default_factory=GatewaySettings) + cors: CORSSettings = Field(default_factory=CORSSettings) app: AppSettings = Field(default_factory=AppSettings) # Direct banned_tokens field to handle env parsing @@ -145,6 +215,4 @@ class Settings(BaseSettings): extra="ignore" ) - -# Create global settings instance settings = Settings() diff --git a/database/__init__.py b/database/__init__.py index 0f2a6ece..7f759fb4 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -17,6 +17,7 @@ AccountRepository, BotRunRepository, ControllerPerformanceRepository, + ExecutorRepository, FundingRepository, GatewayCLMMRepository, GatewaySwapRepository, @@ -30,6 +31,7 @@ "ControllerPerformanceSnapshot", "Base", "AsyncDatabaseManager", "AccountRepository", "BotRunRepository", "ControllerPerformanceRepository", + "ExecutorRepository", "OrderRepository", "TradeRepository", "FundingRepository", "GatewaySwapRepository", "GatewayCLMMRepository" ] diff --git a/database/repositories/account_repository.py b/database/repositories/account_repository.py index a799c130..baa88519 100644 --- a/database/repositories/account_repository.py +++ b/database/repositories/account_repository.py @@ -1,10 +1,8 @@ -from datetime import datetime, timedelta +from datetime import datetime from decimal import Decimal from typing import Dict, List, Optional, Tuple -import base64 -import json -from sqlalchemy import desc, select, func +from sqlalchemy import desc, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -15,6 +13,17 @@ class AccountRepository: def __init__(self, session: AsyncSession): self.session = session + @staticmethod + def _token_state_to_dict(token_state: TokenState) -> Dict: + """Serialize a TokenState into the standard token info dict with float casts.""" + return { + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + } + @staticmethod def _interval_to_minutes(interval: str) -> int: """Convert interval string to minutes.""" @@ -63,11 +72,16 @@ def _sample_history_by_interval(history: List[Dict], interval_minutes: int) -> L return sampled - async def save_account_state(self, account_name: str, connector_name: str, tokens_info: List[Dict], + async def save_account_state(self, account_name: str, connector_name: str, tokens_info: List[Dict], snapshot_timestamp: Optional[datetime] = None) -> AccountState: """ Save account state with token information to the database. If snapshot_timestamp is provided, use it instead of server default. + + Note: this method does NOT commit; it only flushes to obtain the AccountState id. + The caller's session context owns the transaction and commits once + (e.g. get_session_context commits on successful exit), so a snapshot spanning + multiple accounts/connectors persists atomically in a single transaction. """ account_state_data = { "account_name": account_name, @@ -93,8 +107,7 @@ async def save_account_state(self, account_name: str, connector_name: str, token available_units=Decimal(str(token_info["available_units"])) ) self.session.add(token_state) - - await self.session.commit() + return account_state async def get_latest_account_states(self) -> Dict[str, Dict[str, List[Dict]]]: @@ -133,16 +146,8 @@ async def get_latest_account_states(self) -> Dict[str, Dict[str, List[Dict]]]: if account_state.account_name not in accounts_state: accounts_state[account_state.account_name] = {} - token_info = [] - for token_state in account_state.token_states: - token_info.append({ - "token": token_state.token, - "units": float(token_state.units), - "price": float(token_state.price), - "value": float(token_state.value), - "available_units": float(token_state.available_units) - }) - + token_info = [self._token_state_to_dict(token_state) for token_state in account_state.token_states] + accounts_state[account_state.account_name][account_state.connector_name] = token_info return accounts_state @@ -150,6 +155,7 @@ async def get_latest_account_states(self) -> Dict[str, Dict[str, List[Dict]]]: async def get_account_state_history(self, limit: Optional[int] = None, account_name: Optional[str] = None, + account_names: Optional[List[str]] = None, connector_name: Optional[str] = None, cursor: Optional[str] = None, start_time: Optional[datetime] = None, @@ -160,7 +166,8 @@ async def get_account_state_history(self, Args: limit: Maximum number of records to return - account_name: Filter by account name + account_name: Filter by a single account name + account_names: Filter by multiple account names (IN filter) connector_name: Filter by connector name cursor: Cursor for pagination start_time: Start time filter @@ -171,52 +178,67 @@ async def get_account_state_history(self, Tuple of (data, next_cursor, has_more) """ interval_minutes = self._interval_to_minutes(interval) - query = ( - select(AccountState) - .options(joinedload(AccountState.token_states)) - .order_by(desc(AccountState.timestamp)) - ) - - # Apply filters - if account_name: - query = query.filter(AccountState.account_name == account_name) - if connector_name: - query = query.filter(AccountState.connector_name == connector_name) - if start_time: - query = query.filter(AccountState.timestamp >= start_time) - if end_time: - query = query.filter(AccountState.timestamp <= end_time) - - # Handle cursor-based pagination - if cursor: - try: - cursor_time = datetime.fromisoformat(cursor.replace('Z', '+00:00')) - query = query.filter(AccountState.timestamp < cursor_time) - except (ValueError, TypeError): - # Invalid cursor, ignore it - pass - - # Fetch more records than requested to ensure we have enough after sampling - # For intervals > 5m, we need to fetch more data to get enough sampled points + + # Minute bucket expression: a single logical snapshot fans out into one row per + # (account_name, connector_name) but all share the same minute. Paginate by these + # distinct minute buckets so the limit/cursor are independent of the account/connector + # fan-out (a row-based limit would collapse N*M rows into far fewer buckets than `limit`). + minute_bucket = func.date_trunc("minute", AccountState.timestamp) + + def _apply_filters(stmt): + if account_name: + stmt = stmt.filter(AccountState.account_name == account_name) + if account_names: + stmt = stmt.filter(AccountState.account_name.in_(account_names)) + if connector_name: + stmt = stmt.filter(AccountState.connector_name == connector_name) + if start_time: + stmt = stmt.filter(AccountState.timestamp >= start_time) + if end_time: + stmt = stmt.filter(AccountState.timestamp <= end_time) + # Handle cursor-based pagination: the cursor is a minute-bucket timestamp, so + # everything strictly before it excludes all already-returned buckets. + if cursor: + try: + cursor_time = datetime.fromisoformat(cursor.replace('Z', '+00:00')) + stmt = stmt.filter(AccountState.timestamp < cursor_time) + except (ValueError, TypeError): + # Invalid cursor, ignore it + pass + return stmt + + # Step 1: select the distinct minute buckets that match the filters, most recent first. + # For intervals > 5m we widen the window so sampling still has enough buckets to pick from. sampling_multiplier = max(1, interval_minutes // 5) # How many 5m intervals per sample fetch_limit = (limit * sampling_multiplier + 1) if limit else (100 * sampling_multiplier + 1) - query = query.limit(fetch_limit) - - result = await self.session.execute(query) - account_states = result.unique().scalars().all() + timestamps_query = ( + select(minute_bucket.label("minute")) + .distinct() + .order_by(desc(minute_bucket)) + .limit(fetch_limit) + ) + timestamps_query = _apply_filters(timestamps_query) + timestamps_result = await self.session.execute(timestamps_query) + selected_minutes = [row.minute for row in timestamps_result.all()] + + # Step 2: fetch the AccountState (+token) rows only for the selected minute buckets. + if selected_minutes: + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .filter(minute_bucket.in_(selected_minutes)) + .order_by(desc(AccountState.timestamp)) + ) + query = _apply_filters(query) + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + else: + account_states = [] # Format response - Group by minute to aggregate account/connector states minute_groups = {} for account_state in account_states: - token_info = [] - for token_state in account_state.token_states: - token_info.append({ - "token": token_state.token, - "units": float(token_state.units), - "price": float(token_state.price), - "value": float(token_state.value), - "available_units": float(token_state.available_units) - }) + token_info = [self._token_state_to_dict(token_state) for token_state in account_state.token_states] # Round timestamp to the nearest minute for grouping minute_timestamp = account_state.timestamp.replace(second=0, microsecond=0) @@ -235,9 +257,9 @@ async def get_account_state_history(self, minute_groups[minute_key]["state"][account_state.account_name][account_state.connector_name] = token_info - # Convert to list and maintain chronological order (most recent first) + # Already ordered most-recent-first: Step 2 fetched rows ordered by descending + # timestamp and minute truncation is monotonic, so dict insertion order is descending. history = list(minute_groups.values()) - history.sort(key=lambda x: x["timestamp"], reverse=True) # Apply interval sampling sampled_history = self._sample_history_by_interval(history, interval_minutes) @@ -284,15 +306,7 @@ async def get_account_current_state(self, account_name: str) -> Dict[str, List[D state = {} for account_state in account_states: - token_info = [] - for token_state in account_state.token_states: - token_info.append({ - "token": token_state.token, - "units": float(token_state.units), - "price": float(token_state.price), - "value": float(token_state.value), - "available_units": float(token_state.available_units) - }) + token_info = [self._token_state_to_dict(token_state) for token_state in account_state.token_states] state[account_state.connector_name] = token_info return state @@ -318,16 +332,8 @@ async def get_connector_current_state(self, account_name: str, connector_name: s if not account_state: return [] - token_info = [] - for token_state in account_state.token_states: - token_info.append({ - "token": token_state.token, - "units": float(token_state.units), - "price": float(token_state.price), - "value": float(token_state.value), - "available_units": float(token_state.available_units) - }) - + token_info = [self._token_state_to_dict(token_state) for token_state in account_state.token_states] + return token_info async def get_all_unique_tokens(self) -> List[str]: diff --git a/database/repositories/bot_run_repository.py b/database/repositories/bot_run_repository.py index 57f4bad5..f6566cc4 100644 --- a/database/repositories/bot_run_repository.py +++ b/database/repositories/bot_run_repository.py @@ -1,8 +1,8 @@ import json from datetime import datetime, timezone -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional -from sqlalchemy import delete, desc, select, and_, or_, func +from sqlalchemy import and_, desc, func, or_, select from sqlalchemy.ext.asyncio import AsyncSession from database.models import BotRun diff --git a/database/repositories/controller_performance_repository.py b/database/repositories/controller_performance_repository.py index 83b22f84..e7494c52 100644 --- a/database/repositories/controller_performance_repository.py +++ b/database/repositories/controller_performance_repository.py @@ -66,6 +66,32 @@ async def save_controller_performance( await self.session.flush() return snapshot + async def save_controller_performances(self, snapshots: List[Dict]) -> List[ControllerPerformanceSnapshot]: + """Save multiple controller performance snapshots with a single add_all/flush. + + Each item in `snapshots` is a dict with keys: bot_name, controller_id, status, + performance, custom_info and optionally snapshot_timestamp. + """ + if not snapshots: + return [] + + rows = [] + for item in snapshots: + data = { + "bot_name": item["bot_name"], + "controller_id": item["controller_id"], + "status": item["status"], + "performance": json.dumps(item["performance"]) if item.get("performance") else None, + "custom_info": json.dumps(item["custom_info"]) if item.get("custom_info") else None, + } + if item.get("snapshot_timestamp"): + data["timestamp"] = item["snapshot_timestamp"] + rows.append(ControllerPerformanceSnapshot(**data)) + + self.session.add_all(rows) + await self.session.flush() + return rows + async def get_latest_performance( self, bot_name: Optional[str] = None diff --git a/database/repositories/funding_repository.py b/database/repositories/funding_repository.py index e9b8dd42..ab9e685d 100644 --- a/database/repositories/funding_repository.py +++ b/database/repositories/funding_repository.py @@ -1,8 +1,7 @@ -from datetime import datetime -from typing import Dict, List, Optional from decimal import Decimal +from typing import Dict, List -from sqlalchemy import desc, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.models import FundingPayment diff --git a/database/repositories/gateway_clmm_repository.py b/database/repositories/gateway_clmm_repository.py index af11b0df..f292c69f 100644 --- a/database/repositories/gateway_clmm_repository.py +++ b/database/repositories/gateway_clmm_repository.py @@ -1,11 +1,11 @@ from datetime import datetime, timezone -from typing import Dict, List, Optional, Set, Tuple from decimal import Decimal +from typing import Dict, List, Optional, Set -from sqlalchemy import desc, select, distinct +from sqlalchemy import distinct, select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import GatewayCLMMPosition, GatewayCLMMEvent +from database.models import GatewayCLMMEvent, GatewayCLMMPosition class GatewayCLMMRepository: @@ -30,6 +30,13 @@ async def get_position_by_address(self, position_address: str) -> Optional[Gatew ) return result.scalar_one_or_none() + async def get_position_by_id(self, position_id: int) -> Optional[GatewayCLMMPosition]: + """Get a position by its primary key id.""" + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.id == position_id) + ) + return result.scalar_one_or_none() + async def update_position_liquidity( self, position_address: str, diff --git a/database/repositories/gateway_swap_repository.py b/database/repositories/gateway_swap_repository.py index 57871fb8..c5aea52f 100644 --- a/database/repositories/gateway_swap_repository.py +++ b/database/repositories/gateway_swap_repository.py @@ -1,8 +1,8 @@ from datetime import datetime -from typing import Dict, List, Optional from decimal import Decimal +from typing import Dict, List, Optional -from sqlalchemy import desc, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.models import GatewaySwap diff --git a/database/repositories/order_repository.py b/database/repositories/order_repository.py index 27acdc7f..5036bb10 100644 --- a/database/repositories/order_repository.py +++ b/database/repositories/order_repository.py @@ -1,8 +1,8 @@ from datetime import datetime -from typing import Dict, List, Optional from decimal import Decimal +from typing import Dict, List, Optional -from sqlalchemy import desc, select +from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.models import Order @@ -26,13 +26,10 @@ async def get_order_by_client_id(self, client_order_id: str) -> Optional[Order]: ) return result.scalar_one_or_none() - async def update_order_status(self, client_order_id: str, status: str, - error_message: Optional[str] = None) -> Optional[Order]: + async def update_order_status(self, client_order_id: str, status: str, + error_message: Optional[str] = None) -> Optional[Order]: """Update order status and optional error message.""" - result = await self.session.execute( - select(Order).where(Order.client_order_id == client_order_id) - ) - order = result.scalar_one_or_none() + order = await self.get_order_by_client_id(client_order_id) if order: order.status = status if error_message: @@ -132,20 +129,30 @@ async def get_active_orders(self, account_name: Optional[str] = None, async def get_orders_summary(self, account_name: Optional[str] = None, start_time: Optional[int] = None, end_time: Optional[int] = None) -> Dict: - """Get order summary statistics.""" - orders = await self.get_orders( - account_name=account_name, - start_time=start_time, - end_time=end_time, - limit=10000 # Get all for summary + """Get order summary statistics using a single DB-level aggregate query.""" + query = select(Order.status, func.count()).group_by(Order.status) + + # Apply the same filters as get_orders + if account_name: + query = query.where(Order.account_name == account_name) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Order.created_at >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Order.created_at <= end_dt) + + result = await self.session.execute(query) + counts = {status: count for status, count in result.all()} + + total_orders = sum(counts.values()) + filled_orders = counts.get("FILLED", 0) + cancelled_orders = counts.get("CANCELLED", 0) + failed_orders = counts.get("FAILED", 0) + active_orders = ( + counts.get("SUBMITTED", 0) + counts.get("OPEN", 0) + counts.get("PARTIALLY_FILLED", 0) ) - - total_orders = len(orders) - filled_orders = sum(1 for o in orders if o.status == "FILLED") - cancelled_orders = sum(1 for o in orders if o.status == "CANCELLED") - failed_orders = sum(1 for o in orders if o.status == "FAILED") - active_orders = sum(1 for o in orders if o.status in ["SUBMITTED", "OPEN", "PARTIALLY_FILLED"]) - + return { "total_orders": total_orders, "filled_orders": filled_orders, diff --git a/database/repositories/trade_repository.py b/database/repositories/trade_repository.py index f718a643..612859d8 100644 --- a/database/repositories/trade_repository.py +++ b/database/repositories/trade_repository.py @@ -1,11 +1,11 @@ from datetime import datetime from typing import Dict, List, Optional -from sqlalchemy import desc, select +from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from database.models import Trade, Order +from database.models import Order, Trade class TradeRepository: @@ -41,40 +41,6 @@ async def get_trade_by_id(self, trade_id: str) -> Optional[Trade]: result = await self.session.execute(query) return result.scalar_one_or_none() - async def get_trades(self, account_name: Optional[str] = None, - connector_name: Optional[str] = None, - trading_pair: Optional[str] = None, - trade_type: Optional[str] = None, - start_time: Optional[int] = None, - end_time: Optional[int] = None, - limit: int = 100, offset: int = 0) -> List[Trade]: - """Get trades with filtering and pagination.""" - # Join trades with orders to get account information - query = select(Trade).join(Order, Trade.order_id == Order.id) - - # Apply filters - if account_name: - query = query.where(Order.account_name == account_name) - if connector_name: - query = query.where(Order.connector_name == connector_name) - if trading_pair: - query = query.where(Trade.trading_pair == trading_pair) - if trade_type: - query = query.where(Trade.trade_type == trade_type) - if start_time: - start_dt = datetime.fromtimestamp(start_time / 1000) - query = query.where(Trade.timestamp >= start_dt) - if end_time: - end_dt = datetime.fromtimestamp(end_time / 1000) - query = query.where(Trade.timestamp <= end_dt) - - # Apply ordering and pagination - query = query.order_by(Trade.timestamp.desc()) - query = query.limit(limit).offset(offset) - - result = await self.session.execute(query) - return result.scalars().all() - async def get_trades_with_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, trading_pair: Optional[str] = None, diff --git a/deps.py b/deps.py index 36db7f1a..7b4ff4ae 100644 --- a/deps.py +++ b/deps.py @@ -2,15 +2,16 @@ from database import AsyncDatabaseManager from services.accounts_service import AccountsService +from services.backtesting_service import BacktestingService from services.bots_orchestrator import BotsOrchestrator from services.docker_service import DockerService from services.executor_service import ExecutorService from services.executor_ws_manager import ExecutorWebSocketManager from services.gateway_service import GatewayService from services.market_data_service import MarketDataService +from services.trading_history_service import TradingHistoryService from services.trading_service import TradingService from services.unified_connector_service import UnifiedConnectorService -from services.backtesting_service import BacktestingService from services.websocket_manager import WebSocketManager from utils.bot_archiver import BotArchiver @@ -50,6 +51,11 @@ def get_trading_service(request: Request) -> TradingService: return request.app.state.trading_service +def get_trading_history_service(request: Request) -> TradingHistoryService: + """Get TradingHistoryService from app state.""" + return request.app.state.trading_history_service + + def get_executor_service(request: Request) -> ExecutorService: """Get ExecutorService from app state.""" return request.app.state.executor_service diff --git a/main.py b/main.py index 0f6e3d34..28328e0a 100644 --- a/main.py +++ b/main.py @@ -38,7 +38,7 @@ def patched_save_to_yml(yml_path, cm): from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient # noqa: E402 from hummingbot.core.rate_oracle.rate_oracle import RATE_ORACLE_SOURCES, RateOracle # noqa: E402 -from config import settings # noqa: E402 +from config import settings, warn_if_insecure_security_defaults # noqa: E402 from database import AsyncDatabaseManager # noqa: E402 from routers import ( # noqa: E402 accounts, @@ -68,6 +68,7 @@ def patched_save_to_yml(yml_path, cm): from services.executor_ws_manager import ExecutorWebSocketManager # noqa: E402 from services.gateway_service import GatewayService # noqa: E402 from services.market_data_service import MarketDataService # noqa: E402 +from services.trading_history_service import TradingHistoryService # noqa: E402 from services.trading_service import TradingService # noqa: E402 from services.unified_connector_service import UnifiedConnectorService # noqa: E402 from services.websocket_manager import WebSocketManager # noqa: E402 @@ -86,7 +87,6 @@ def patched_save_to_yml(yml_path, cm): # Get settings from Pydantic Settings username = settings.security.username password = settings.security.password -debug_mode = settings.security.debug_mode # Security setup security = HTTPBasic() @@ -98,6 +98,9 @@ async def lifespan(app: FastAPI): Lifespan context manager for the FastAPI application. Handles startup and shutdown events. """ + # SEC-018: warn loudly if USERNAME/PASSWORD/CONFIG_PASSWORD are still the insecure defaults + warn_if_insecure_security_defaults(settings.security) + # Ensure password verification file exists if BackendAPISecurity.new_password_required(): # Create secrets manager with CONFIG_PASSWORD @@ -195,15 +198,19 @@ async def lifespan(app: FastAPI): # AccountsService - account management, balances, portfolio (simplified) accounts_service = AccountsService( + db_manager=db_manager, + connector_service=connector_service, + market_data_service=market_data_service, + trading_service=trading_service, account_update_interval=settings.app.account_update_interval, gateway_url=settings.gateway.url ) - # Inject services into AccountsService - accounts_service._connector_service = connector_service - accounts_service._market_data_service = market_data_service - accounts_service._trading_service = trading_service logging.info("AccountsService initialized") + # TradingHistoryService - read-only persistence queries for orders/trades/funding + trading_history_service = TradingHistoryService(db_manager=db_manager) + logging.info("TradingHistoryService initialized") + # ========================================================================= # 4. ExecutorService - depends on TradingService (NO circular dependency) # ========================================================================= @@ -226,6 +233,7 @@ async def lifespan(app: FastAPI): broker_port=settings.broker.port, broker_username=settings.broker.username, broker_password=settings.broker.password, + db_manager=db_manager, performance_dump_interval=settings.broker.performance_dump_interval ) @@ -270,6 +278,7 @@ async def lifespan(app: FastAPI): app.state.market_data_service = market_data_service app.state.trading_service = trading_service app.state.accounts_service = accounts_service + app.state.trading_history_service = trading_history_service app.state.executor_service = executor_service websocket_manager = WebSocketManager(market_data_service) app.state.websocket_manager = websocket_manager @@ -296,7 +305,7 @@ async def lifespan(app: FastAPI): websocket_manager.shutdown() await executor_ws_manager.shutdown() - bots_orchestrator.stop() + await bots_orchestrator.stop() await accounts_service.stop() await executor_service.stop() market_data_service.stop() @@ -315,13 +324,16 @@ async def lifespan(app: FastAPI): redirect_slashes=False, ) -# Add CORS middleware +# Add CORS middleware (SEC-019). Origins are restricted by default: a wildcard origin must not be +# combined with allow_credentials=True. Trusted origins are configured via CORS_ALLOW_ORIGINS / +# CORS_ALLOW_ORIGIN_REGEX (see config.CORSSettings); the default only allows localhost origins. app.add_middleware( CORSMiddleware, - allow_origins=["*"], # Modify in production to specific origins - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_origins=settings.cors.allow_origins, + allow_origin_regex=settings.cors.allow_origin_regex or None, + allow_credentials=settings.cors.allow_credentials, + allow_methods=settings.cors.allow_methods, + allow_headers=settings.cors.allow_headers, ) @@ -367,7 +379,7 @@ def auth_user( is_correct_password = secrets.compare_digest( current_password_bytes, correct_password_bytes ) - if not (is_correct_username and is_correct_password) and not debug_mode: + if not (is_correct_username and is_correct_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Incorrect username or password", diff --git a/models/bot_orchestration.py b/models/bot_orchestration.py index a23dd242..5378ef0b 100644 --- a/models/bot_orchestration.py +++ b/models/bot_orchestration.py @@ -1,6 +1,28 @@ +import re from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator + +# Safe single path component names: prevents path traversal via '/', '\' or '..'. +# Mirrors services.accounts_service.SAFE_NAME_PATTERN (replicated locally to avoid a +# heavy/circular import of accounts_service into the model layer). +SAFE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$") + + +def _validate_safe_name(name: str, label: str) -> str: + """Validate that a name is safe to use as a single path component (no separators or traversal sequences).""" + if not name or not SAFE_NAME_PATTERN.fullmatch(name): + raise ValueError( + f"Invalid {label}: '{name}'. Only letters, numbers, underscores and hyphens are allowed." + ) + return name + + +def _validate_safe_config_name(name: str, label: str) -> str: + """Validate a config file name, ignoring an optional .yml extension before checking the base name.""" + base_name = name[:-4] if name.endswith(".yml") else name + _validate_safe_name(base_name, label) + return name class BotAction(BaseModel): @@ -103,6 +125,23 @@ class V2ScriptDeployment(BaseModel): script_config: Optional[str] = Field(default=None, description="Script configuration file name (without .yml extension)") headless: bool = Field(default=False, description="Run in headless mode (no UI)") + @field_validator("instance_name") + @classmethod + def _validate_instance_name(cls, v: str) -> str: + return _validate_safe_name(v, "instance_name") + + @field_validator("credentials_profile") + @classmethod + def _validate_credentials_profile(cls, v: str) -> str: + return _validate_safe_name(v, "credentials_profile") + + @field_validator("script_config") + @classmethod + def _validate_script_config(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + return _validate_safe_config_name(v, "script_config") + class V2ControllerDeployment(BaseModel): """Configuration for deploying a bot with controllers""" @@ -120,3 +159,25 @@ class V2ControllerDeployment(BaseModel): image: str = Field(default="hummingbot/hummingbot:latest", description="Docker image for the Hummingbot instance") script_config: Optional[str] = Field(default=None, description="Generated script configuration file name") headless: bool = Field(default=False, description="Run in headless mode (no UI)") + + @field_validator("instance_name") + @classmethod + def _validate_instance_name(cls, v: str) -> str: + return _validate_safe_name(v, "instance_name") + + @field_validator("credentials_profile") + @classmethod + def _validate_credentials_profile(cls, v: str) -> str: + return _validate_safe_name(v, "credentials_profile") + + @field_validator("controllers_config") + @classmethod + def _validate_controllers_config(cls, v: List[str]) -> List[str]: + return [_validate_safe_config_name(controller, "controllers_config") for controller in v] + + @field_validator("script_config") + @classmethod + def _validate_script_config(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + return _validate_safe_config_name(v, "script_config") diff --git a/models/pagination.py b/models/pagination.py index 32309218..453ac4ec 100644 --- a/models/pagination.py +++ b/models/pagination.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Callable from pydantic import BaseModel, Field, ConfigDict @@ -34,4 +34,62 @@ class PaginatedResponse(BaseModel): ) data: List[Dict[str, Any]] - pagination: Dict[str, Any] \ No newline at end of file + pagination: Dict[str, Any] + + +def paginate_by_cursor( + items: List[Dict[str, Any]], + cursor: Optional[str], + limit: int, + sort_key: Optional[Callable[[Dict[str, Any]], Any]] = None, + reverse: bool = False, +) -> PaginatedResponse: + """ + Apply in-memory cursor-based pagination over items carrying a "_cursor_id" key. + + Each item must already have a "_cursor_id" assigned by the caller. The items are sorted + (by "_cursor_id" unless a custom sort_key is provided), the page after the cursor is sliced, + has_more/next_cursor are computed, and "_cursor_id" is stripped from the returned page. + + Args: + items: Items to paginate, each with a "_cursor_id" key + cursor: Cursor value ("_cursor_id" of the last item of the previous page), if any + limit: Number of items per page + sort_key: Optional sort key; defaults to sorting by "_cursor_id" + reverse: Whether to sort in descending order + + Returns: + PaginatedResponse with the page data and pagination metadata + """ + # Sort for consistent pagination + items.sort(key=sort_key if sort_key is not None else (lambda x: x.get("_cursor_id", "")), reverse=reverse) + + # Find the item after the cursor + start_index = 0 + if cursor: + for i, item in enumerate(items): + if item.get("_cursor_id") == cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + limit + page_items = items[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(items) + next_cursor = page_items[-1].get("_cursor_id") if page_items and has_more else None + + # Clean up cursor_id from response data + for item in page_items: + item.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_items, + pagination={ + "limit": limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(items), + }, + ) \ No newline at end of file diff --git a/routers/accounts.py b/routers/accounts.py index a5e4106a..3a80ad88 100644 --- a/routers/accounts.py +++ b/routers/accounts.py @@ -5,7 +5,7 @@ from deps import get_accounts_service from models import GatewayWalletCredential, SetDefaultWalletRequest -from services.accounts_service import AccountsService +from services.accounts_service import AccountsService, validate_safe_name router = APIRouter(tags=["Accounts"], prefix="/accounts") @@ -58,8 +58,9 @@ async def add_account(account_name: str, accounts_service: AccountsService = Dep Success message when account is created Raises: - HTTPException: 400 if account already exists + HTTPException: 400 if account already exists or the account name is invalid """ + validate_safe_name(account_name, "account name") try: accounts_service.add_account(account_name) return {"message": "Account added successfully."} @@ -79,8 +80,9 @@ async def delete_account(account_name: str, accounts_service: AccountsService = Success message when account is deleted Raises: - HTTPException: 400 if trying to delete master account, 404 if account not found + HTTPException: 400 if trying to delete master account or the account name is invalid, 404 if account not found """ + validate_safe_name(account_name, "account name") try: if account_name == "master_account": raise HTTPException(status_code=400, detail="Cannot delete master account.") @@ -132,7 +134,8 @@ async def add_credential(account_name: str, connector_name: str, credentials: Di await accounts_service.add_credentials(account_name, connector_name, credentials) return {"message": "Connector credentials added successfully."} except Exception as e: - await accounts_service.delete_credentials(account_name, connector_name) + # Rollback is handled inside add_credentials, which only deletes the file for a + # brand-new creation and preserves pre-existing credentials on a failed update. raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/archived_bots.py b/routers/archived_bots.py index 729c471f..0b2bb93b 100644 --- a/routers/archived_bots.py +++ b/routers/archived_bots.py @@ -13,6 +13,20 @@ router = APIRouter(tags=["Archived Bots"], prefix="/archived-bots") +def _validate_db_path(db_path: str) -> str: + """ + Resolve db_path and ensure it points to a database file inside the archived bots directory. + + Raises HTTPException 400 for paths escaping the archived directory and 404 for missing files. + """ + try: + return fs_util.get_archived_db_path(db_path) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + @router.get("/", response_model=List[str]) async def list_databases(): """ @@ -79,8 +93,9 @@ async def get_database_status(db_path: str): Returns: Database status including table health """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) return { "db_path": db_path, "status": db.status, @@ -101,8 +116,9 @@ async def get_database_summary(db_path: str): Returns: Summary statistics of the database contents """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) # Get basic counts orders = db.get_orders() @@ -136,8 +152,9 @@ async def get_database_performance(db_path: str): Returns: Trade-based performance metrics with rolling calculations """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) # Use new trade-based performance calculation performance_data = db.calculate_trade_based_performance() @@ -194,8 +211,9 @@ async def get_database_trades( Returns: List of trades with pagination info """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) trades = db.get_trade_fills() # Apply pagination @@ -235,8 +253,9 @@ async def get_database_orders( Returns: List of orders with pagination info """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) orders = db.get_orders() # Apply status filter if provided @@ -272,8 +291,9 @@ async def get_database_executors(db_path: str): Returns: List of executors with their configurations and results """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) executors = db.get_executors_data() return { @@ -302,8 +322,9 @@ async def get_database_positions( Returns: List of positions with pagination info """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) positions = db.get_positions() # Apply pagination @@ -335,8 +356,9 @@ async def get_database_controllers(db_path: str): Returns: List of controllers that were running with their configurations """ + resolved_db_path = _validate_db_path(db_path) try: - db = HummingbotDatabase(db_path) + db = HummingbotDatabase(resolved_db_path) controllers = db.get_controllers_data() return { diff --git a/routers/bot_orchestration.py b/routers/bot_orchestration.py index df789dfa..ec48d66f 100644 --- a/routers/bot_orchestration.py +++ b/routers/bot_orchestration.py @@ -1,13 +1,10 @@ -import asyncio import logging import os -import shutil -from datetime import datetime, timezone +from datetime import datetime from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query -from database import AsyncDatabaseManager, BotRunRepository -from deps import get_bot_archiver, get_bots_orchestrator, get_database_manager, get_docker_service +from deps import get_bot_archiver, get_bots_orchestrator, get_docker_service from models import StartBotAction, StopBotAction, V2ControllerDeployment, V2ScriptDeployment from services.bots_orchestrator import BotsOrchestrator from services.docker_service import DockerService @@ -188,8 +185,7 @@ async def get_bot_history( @router.post("/start-bot") async def start_bot( action: StartBotAction, - bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Start a bot with the specified configuration. @@ -197,7 +193,6 @@ async def start_bot( Args: action: StartBotAction containing bot configuration parameters bots_manager: Bot orchestrator service dependency - db_manager: Database manager dependency Returns: Dictionary with status and response from bot start operation @@ -215,8 +210,7 @@ async def start_bot( @router.post("/stop-bot") async def stop_bot( action: StopBotAction, - bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Stop a bot with the specified configuration. @@ -224,7 +218,6 @@ async def stop_bot( Args: action: StopBotAction containing bot stop parameters bots_manager: Bot orchestrator service dependency - db_manager: Database manager dependency Returns: Dictionary with status and response from bot stop operation @@ -245,13 +238,7 @@ async def stop_bot( # Update bot run status to STOPPED if stop was successful if response.get("success"): try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.update_bot_run_stopped( - action.bot_name, - final_status=final_status - ) - logger.info(f"Updated bot run status to STOPPED for {action.bot_name}") + await bots_manager.mark_bot_run_stopped(action.bot_name, final_status=final_status) except Exception as e: logger.error(f"Failed to update bot run status: {e}") # Don't fail the stop operation if bot run update fails @@ -269,7 +256,7 @@ async def get_bot_runs( deployment_status: str = None, limit: int = 100, offset: int = 0, - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Get bot runs with optional filtering. @@ -283,54 +270,30 @@ async def get_bot_runs( deployment_status: Filter by deployment status (DEPLOYED, FAILED, ARCHIVED) limit: Maximum number of results to return offset: Number of results to skip - db_manager: Database manager dependency + bots_manager: Bot orchestrator service dependency Returns: List of bot runs with their details """ try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - bot_runs = await bot_run_repo.get_bot_runs( - bot_name=bot_name, - account_name=account_name, - strategy_type=strategy_type, - strategy_name=strategy_name, - run_status=run_status, - deployment_status=deployment_status, - limit=limit, - offset=offset - ) - - # Convert bot runs to dictionaries for JSON serialization - runs_data = [] - for run in bot_runs: - run_dict = { - "id": run.id, - "bot_name": run.bot_name, - "instance_name": run.instance_name, - "deployed_at": run.deployed_at.isoformat() if run.deployed_at else None, - "stopped_at": run.stopped_at.isoformat() if run.stopped_at else None, - "strategy_type": run.strategy_type, - "strategy_name": run.strategy_name, - "config_name": run.config_name, - "account_name": run.account_name, - "image_version": run.image_version, - "deployment_status": run.deployment_status, - "run_status": run.run_status, - "deployment_config": run.deployment_config, - "final_status": run.final_status, - "error_message": run.error_message - } - runs_data.append(run_dict) + runs_data = await bots_manager.get_bot_runs( + bot_name=bot_name, + account_name=account_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + run_status=run_status, + deployment_status=deployment_status, + limit=limit, + offset=offset + ) - return { - "status": "success", - "data": runs_data, - "total": len(runs_data), - "limit": limit, - "offset": offset - } + return { + "status": "success", + "data": runs_data, + "total": len(runs_data), + "limit": limit, + "offset": offset + } except Exception as e: logger.error(f"Failed to get bot runs: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -338,23 +301,20 @@ async def get_bot_runs( @router.get("/bot-runs/stats") async def get_bot_run_stats( - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Get statistics about bot runs. Args: - db_manager: Database manager dependency + bots_manager: Bot orchestrator service dependency Returns: Bot run statistics """ try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - stats = await bot_run_repo.get_bot_run_stats() - - return {"status": "success", "data": stats} + stats = await bots_manager.get_bot_run_stats() + return {"status": "success", "data": stats} except Exception as e: logger.error(f"Failed to get bot run stats: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -363,14 +323,14 @@ async def get_bot_run_stats( @router.get("/bot-runs/{bot_run_id}") async def get_bot_run_by_id( bot_run_id: int, - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Get a specific bot run by ID. Args: bot_run_id: ID of the bot run - db_manager: Database manager dependency + bots_manager: Bot orchestrator service dependency Returns: Bot run details @@ -379,32 +339,12 @@ async def get_bot_run_by_id( HTTPException: 404 if bot run not found """ try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - bot_run = await bot_run_repo.get_bot_run_by_id(bot_run_id) - - if not bot_run: - raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") - - run_dict = { - "id": bot_run.id, - "bot_name": bot_run.bot_name, - "instance_name": bot_run.instance_name, - "deployed_at": bot_run.deployed_at.isoformat() if bot_run.deployed_at else None, - "stopped_at": bot_run.stopped_at.isoformat() if bot_run.stopped_at else None, - "strategy_type": bot_run.strategy_type, - "strategy_name": bot_run.strategy_name, - "config_name": bot_run.config_name, - "account_name": bot_run.account_name, - "image_version": bot_run.image_version, - "deployment_status": bot_run.deployment_status, - "run_status": bot_run.run_status, - "deployment_config": bot_run.deployment_config, - "final_status": bot_run.final_status, - "error_message": bot_run.error_message - } + run_dict = await bots_manager.get_bot_run_by_id(bot_run_id) + + if not run_dict: + raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") - return {"status": "success", "data": run_dict} + return {"status": "success", "data": run_dict} except HTTPException: raise except Exception as e: @@ -415,14 +355,14 @@ async def get_bot_run_by_id( @router.delete("/bot-runs/{bot_run_id}") async def delete_bot_run( bot_run_id: int, - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Delete a bot run record by ID. Args: bot_run_id: ID of the bot run to delete - db_manager: Database manager dependency + bots_manager: Bot orchestrator service dependency Returns: Confirmation of deletion @@ -431,34 +371,17 @@ async def delete_bot_run( HTTPException: 404 if bot run not found """ try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - bot_run = await bot_run_repo.delete_bot_run(bot_run_id) - - if not bot_run: - raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") - - # Also delete the archived bot folder if it exists - archived_dir = os.path.join('bots', 'archived', bot_run.instance_name) - archived_deleted = False - if os.path.isdir(archived_dir): - try: - import subprocess, platform - if platform.system() == 'Darwin': - # Strip macOS ACLs (Docker adds "deny delete" ACLs) - subprocess.run(['chmod', '-R', '-N', archived_dir], check=False) - shutil.rmtree(archived_dir) - archived_deleted = True - logger.info(f"Deleted archived folder: {archived_dir}") - except Exception as e: - logger.warning(f"Failed to delete archived folder {archived_dir}: {e}") + result = await bots_manager.delete_bot_run(bot_run_id) - return { - "status": "success", - "message": f"Bot run {bot_run_id} deleted successfully", - "bot_name": bot_run.bot_name, - "archived_folder_deleted": archived_deleted - } + if not result: + raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") + + return { + "status": "success", + "message": f"Bot run {bot_run_id} deleted successfully", + "bot_name": result["bot_name"], + "archived_folder_deleted": result["archived_folder_deleted"] + } except HTTPException: raise except Exception as e: @@ -466,159 +389,6 @@ async def delete_bot_run( raise HTTPException(status_code=500, detail=str(e)) -async def _background_stop_and_archive( - bot_name: str, - container_name: str, - bot_name_for_orchestrator: str, - skip_order_cancellation: bool, - archive_locally: bool, - s3_bucket: str, - bots_manager: BotsOrchestrator, - docker_manager: DockerService, - bot_archiver: BotArchiver, - db_manager: AsyncDatabaseManager -): - """Background task to handle the stop and archive process""" - try: - logger.info(f"Starting background stop-and-archive for {bot_name}") - - # Step 1: Capture bot final status before stopping (while bot is still running) - logger.info(f"Capturing final status for {bot_name_for_orchestrator}") - final_status = None - try: - final_status = bots_manager.get_bot_status(bot_name_for_orchestrator) - logger.info(f"Captured final status for {bot_name_for_orchestrator}: {final_status}") - except Exception as e: - logger.warning(f"Failed to capture final status for {bot_name_for_orchestrator}: {e}") - - # Step 2: Update bot run with stopped_at timestamp and final status before stopping - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.update_bot_run_stopped( - bot_name, - final_status=final_status - ) - logger.info(f"Updated bot run with stopped_at timestamp and final status for {bot_name}") - except Exception as e: - logger.error(f"Failed to update bot run with stopped status: {e}") - # Continue with stop process even if database update fails - - # Step 3: Mark the bot as stopping, and stop the bot trading process - bots_manager.set_bot_stopping(bot_name_for_orchestrator) - logger.info(f"Stopping bot trading process for {bot_name_for_orchestrator}") - stop_response = await bots_manager.stop_bot( - bot_name_for_orchestrator, - skip_order_cancellation=skip_order_cancellation, - async_backend=True # Always use async for background tasks - ) - - if not stop_response or not stop_response.get("success", False): - error_msg = stop_response.get('error', 'Unknown error') if stop_response else 'No response from bot orchestrator' - logger.error(f"Failed to stop bot process: {error_msg}") - return - - # Step 4: Wait for graceful shutdown (15 seconds as requested) - logger.info(f"Waiting 15 seconds for bot {bot_name} to gracefully shutdown") - await asyncio.sleep(15) - - # Step 5: Stop the container with monitoring - max_retries = 10 - retry_interval = 2 - container_stopped = False - - for i in range(max_retries): - logger.info(f"Attempting to stop container {container_name} (attempt {i+1}/{max_retries})") - docker_manager.stop_container(container_name) - - # Check if container is already stopped - container_status = docker_manager.get_container_status(container_name) - if container_status.get("state", {}).get("status") == "exited": - container_stopped = True - logger.info(f"Container {container_name} is already stopped") - break - - await asyncio.sleep(retry_interval) - - if not container_stopped: - logger.error(f"Failed to stop container {container_name} after {max_retries} attempts") - return - - # Step 6: Archive the bot data - instance_dir = os.path.join('bots', 'instances', container_name) - logger.info(f"Archiving bot data from {instance_dir}") - - try: - if archive_locally: - bot_archiver.archive_locally(container_name, instance_dir) - else: - bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) - logger.info(f"Successfully archived bot data for {container_name}") - except Exception as e: - logger.error(f"Archive failed: {str(e)}") - # Continue with removal even if archive fails - - # Step 7: Remove the container - logging.info(f"Removing container {container_name}") - remove_response = docker_manager.remove_container(container_name, force=False) - - if not remove_response.get("success"): - # If graceful remove fails, try force remove - logging.warning("Graceful container removal failed, attempting force removal") - remove_response = docker_manager.remove_container(container_name, force=True) - - if remove_response.get("success"): - logging.info(f"Successfully completed stop-and-archive for bot {bot_name}") - - # Step 8: Update bot run deployment status to ARCHIVED - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.update_bot_run_archived(bot_name) - logger.info(f"Updated bot run deployment status to ARCHIVED for {bot_name}") - except Exception as e: - logger.error(f"Failed to update bot run to archived: {e}") - else: - logging.error(f"Failed to remove container {container_name}") - - # Update bot run with error status (but keep stopped_at timestamp from earlier) - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.update_bot_run_stopped( - bot_name, - error_message="Failed to remove container during archive process" - ) - logger.info(f"Updated bot run with error status for {bot_name}") - except Exception as e: - logger.error(f"Failed to update bot run with error: {e}") - - except Exception as e: - logging.error(f"Error in background stop-and-archive for {bot_name}: {str(e)}") - - # Update bot run with error status - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.update_bot_run_stopped( - bot_name, - error_message=str(e) - ) - logger.info(f"Updated bot run with error status for {bot_name}") - except Exception as db_error: - logger.error(f"Failed to update bot run with error: {db_error}") - finally: - # Always clear the stopping status when the background task completes - bots_manager.clear_bot_stopping(bot_name_for_orchestrator) - logger.info(f"Cleared stopping status for bot {bot_name}") - - # Remove bot from active_bots and clear all MQTT data - if bot_name_for_orchestrator in bots_manager.active_bots: - bots_manager.mqtt_manager.clear_bot_data(bot_name_for_orchestrator) - del bots_manager.active_bots[bot_name_for_orchestrator] - logger.info(f"Removed bot {bot_name_for_orchestrator} from active_bots and cleared MQTT data") - - @router.post("/stop-and-archive-bot/{bot_name}") async def stop_and_archive_bot( bot_name: str, @@ -628,8 +398,7 @@ async def stop_and_archive_bot( s3_bucket: str = None, bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), docker_manager: DockerService = Depends(get_docker_service), - bot_archiver: BotArchiver = Depends(get_bot_archiver), - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bot_archiver: BotArchiver = Depends(get_bot_archiver) ): """ Gracefully stop a bot and archive its data in the background. @@ -677,17 +446,15 @@ async def stop_and_archive_bot( # Add the background task background_tasks.add_task( - _background_stop_and_archive, + bots_manager.stop_and_archive_bot, bot_name=actual_bot_name, container_name=container_name, bot_name_for_orchestrator=bot_name_for_orchestrator, skip_order_cancellation=skip_order_cancellation, archive_locally=archive_locally, s3_bucket=s3_bucket, - bots_manager=bots_manager, docker_manager=docker_manager, - bot_archiver=bot_archiver, - db_manager=db_manager + bot_archiver=bot_archiver ) return { @@ -713,7 +480,7 @@ async def stop_and_archive_bot( async def deploy_v2_controllers( deployment: V2ControllerDeployment, docker_manager: DockerService = Depends(get_docker_service), - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Deploy a V2 strategy with controllers by generating the script config and creating the instance. @@ -777,23 +544,16 @@ async def deploy_v2_controllers( response["unique_instance_name"] = unique_instance_name # Track bot run if deployment was successful - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.create_bot_run( - bot_name=unique_instance_name, - instance_name=unique_instance_name, - strategy_type="controller", - strategy_name="v2_with_controllers", - account_name=deployment.credentials_profile, - config_name=script_config_filename, - image_version=deployment.image, - deployment_config=deployment.dict() - ) - logger.info(f"Created bot run record for controller deployment {unique_instance_name}") - except Exception as e: - logger.error(f"Failed to create bot run record: {e}") - # Don't fail the deployment if bot run creation fails + await bots_manager.create_bot_run( + bot_name=unique_instance_name, + instance_name=unique_instance_name, + strategy_type="controller", + strategy_name="v2_with_controllers", + account_name=deployment.credentials_profile, + config_name=script_config_filename, + image_version=deployment.image, + deployment_config=deployment.dict() + ) return response @@ -806,7 +566,7 @@ async def deploy_v2_controllers( async def deploy_v2_script( deployment: V2ScriptDeployment, docker_manager: DockerService = Depends(get_docker_service), - db_manager: AsyncDatabaseManager = Depends(get_database_manager) + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) ): """ Deploy a V2 script bot with optional script configuration. @@ -839,23 +599,16 @@ async def deploy_v2_script( response["unique_instance_name"] = unique_instance_name # Track bot run if deployment was successful - try: - async with db_manager.get_session_context() as session: - bot_run_repo = BotRunRepository(session) - await bot_run_repo.create_bot_run( - bot_name=unique_instance_name, - instance_name=unique_instance_name, - strategy_type="script", - strategy_name=deployment.script or "default", - account_name=deployment.credentials_profile, - config_name=deployment.script_config, - image_version=deployment.image, - deployment_config=deployment.dict() - ) - logger.info(f"Created bot run record for script deployment {unique_instance_name}") - except Exception as e: - logger.error(f"Failed to create bot run record: {e}") - # Don't fail the deployment if bot run creation fails + await bots_manager.create_bot_run( + bot_name=unique_instance_name, + instance_name=unique_instance_name, + strategy_type="script", + strategy_name=deployment.script or "default", + account_name=deployment.credentials_profile, + config_name=deployment.script_config, + image_version=deployment.image, + deployment_config=deployment.dict() + ) return response diff --git a/routers/market_data.py b/routers/market_data.py index 1d358a28..c29508bd 100644 --- a/routers/market_data.py +++ b/routers/market_data.py @@ -68,18 +68,16 @@ async def get_candles(request: Request, candles_config: CandlesConfigRequest): try: market_data_service: MarketDataService = request.app.state.market_data_service - # Validate trading pair exists on the exchange before starting a feed - try: - await market_data_service.validate_trading_pair( - candles_config.connector_name, candles_config.trading_pair, candles_config.interval - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - candles_cfg = CandlesConfig( connector=candles_config.connector_name, trading_pair=candles_config.trading_pair, interval=candles_config.interval, max_records=candles_config.max_records) - candles_feed = market_data_service.get_candles_feed(candles_cfg) + + # Creating the feed validates the trading pair on first use (cache hit afterwards); + # an invalid pair raises ValueError. + try: + candles_feed = await market_data_service.get_candles_feed(candles_cfg) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) # Wait for the candles feed to be ready with a timeout timeout = settings.market_data.candles_ready_timeout @@ -143,21 +141,18 @@ async def get_historical_candles(request: Request, config: HistoricalCandlesConf try: market_data_service: MarketDataService = request.app.state.market_data_service - # Validate trading pair exists on the exchange before fetching - try: - await market_data_service.validate_trading_pair( - config.connector_name, config.trading_pair, config.interval - ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - candles_config = CandlesConfig( connector=config.connector_name, trading_pair=config.trading_pair, interval=config.interval ) - candles = market_data_service.get_candles_feed(candles_config) + # Creating the feed validates the trading pair on first use (cache hit afterwards); + # an invalid pair raises ValueError. + try: + candles = await market_data_service.get_candles_feed(candles_config) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) timeout = settings.market_data.candles_ready_timeout historical_data = await asyncio.wait_for( diff --git a/routers/portfolio.py b/routers/portfolio.py index 0981701b..fc951f8b 100644 --- a/routers/portfolio.py +++ b/routers/portfolio.py @@ -92,47 +92,31 @@ async def get_portfolio_history( start_time_dt = datetime.fromtimestamp(filter_request.start_time / 1000) if filter_request.start_time else None end_time_dt = datetime.fromtimestamp(filter_request.end_time / 1000) if filter_request.end_time else None - if not filter_request.account_names: - # Get history for all accounts - data, next_cursor, has_more = await accounts_service.load_account_state_history( - limit=filter_request.limit, - cursor=filter_request.cursor, - start_time=start_time_dt, - end_time=end_time_dt, - interval=filter_request.interval - ) - else: - # Get history for specific accounts - need to aggregate - all_data = [] - for account_name in filter_request.account_names: - acc_data, _, _ = await accounts_service.get_account_state_history( - account_name=account_name, - limit=filter_request.limit, - cursor=filter_request.cursor, - start_time=start_time_dt, - end_time=end_time_dt, - interval=filter_request.interval - ) - all_data.extend(acc_data) - - # Sort by timestamp and apply pagination - all_data.sort(key=lambda x: x.get("timestamp", ""), reverse=True) - - # Apply limit - data = all_data[:filter_request.limit] - has_more = len(all_data) > filter_request.limit - next_cursor = data[-1]["timestamp"] if data and has_more else None - - # Apply connector filter to the data if specified + # Single query handles both all-accounts and filtered-accounts cases (IN filter), + # returning data ordered by timestamp desc with a consistent pagination cursor. + data, next_cursor, has_more = await accounts_service.load_account_state_history( + limit=filter_request.limit, + cursor=filter_request.cursor, + start_time=start_time_dt, + end_time=end_time_dt, + interval=filter_request.interval, + account_names=filter_request.account_names + ) + + # Apply connector filter to the data if specified. Each history item is + # {"timestamp": ..., "state": {account_name: {connector_name: [tokens]}}}, + # so connectors live directly under each account inside "state". if filter_request.connector_names: for item in data: - for account_name, account_data in item.items(): - if isinstance(account_data, dict) and "connectors" in account_data: - filtered_connectors = {} - for connector_name in filter_request.connector_names: - if connector_name in account_data["connectors"]: - filtered_connectors[connector_name] = account_data["connectors"][connector_name] - account_data["connectors"] = filtered_connectors + state = item.get("state", {}) + for account_name, account_data in state.items(): + if isinstance(account_data, dict): + filtered_connectors = { + connector_name: account_data[connector_name] + for connector_name in filter_request.connector_names + if connector_name in account_data + } + state[account_name] = filtered_connectors return PaginatedResponse( data=data, diff --git a/routers/trading.py b/routers/trading.py index a7f25d34..4b6df289 100644 --- a/routers/trading.py +++ b/routers/trading.py @@ -1,17 +1,13 @@ import logging import math - from typing import Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException - -# Create module-specific logger -logger = logging.getLogger(__name__) from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType from pydantic import BaseModel from starlette import status -from deps import get_accounts_service, get_connector_service +from deps import get_accounts_service, get_connector_service, get_trading_history_service from models import ( ActiveOrderFilterRequest, FundingPaymentFilterRequest, @@ -23,7 +19,12 @@ TradeResponse, ) from models.accounts import LeverageRequest, PositionModeRequest +from models.pagination import paginate_by_cursor from services.accounts_service import AccountsService +from services.trading_history_service import TradingHistoryService + +# Create module-specific logger +logger = logging.getLogger(__name__) router = APIRouter(tags=["Trading"], prefix="/trading") @@ -167,39 +168,8 @@ async def get_positions( logger.warning(f"Failed to get positions for {account_name}/{connector_name}: {e}") - # Sort by cursor_id for consistent pagination - all_positions.sort(key=lambda x: x.get("_cursor_id", "")) - - # Apply cursor-based pagination - start_index = 0 - if filter_request.cursor: - # Find the position after the cursor - for i, position in enumerate(all_positions): - if position.get("_cursor_id") == filter_request.cursor: - start_index = i + 1 - break - - # Get page of results - end_index = start_index + filter_request.limit - page_positions = all_positions[start_index:end_index] - - # Determine next cursor and has_more - has_more = end_index < len(all_positions) - next_cursor = page_positions[-1].get("_cursor_id") if page_positions and has_more else None - - # Clean up cursor_id from response data - for position in page_positions: - position.pop("_cursor_id", None) - - return PaginatedResponse( - data=page_positions, - pagination={ - "limit": filter_request.limit, - "has_more": has_more, - "next_cursor": next_cursor, - "total_count": len(all_positions), - }, - ) + # Sort by cursor_id and apply cursor-based pagination + return paginate_by_cursor(all_positions, filter_request.cursor, filter_request.limit) except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching positions: {str(e)}") @@ -265,39 +235,8 @@ async def get_active_orders( logger.warning(f"Failed to get active orders for {account_name}/{connector_name}: {e}") - # Sort by cursor_id for consistent pagination - all_active_orders.sort(key=lambda x: x.get("_cursor_id", "")) - - # Apply cursor-based pagination - start_index = 0 - if filter_request.cursor: - # Find the order after the cursor - for i, order in enumerate(all_active_orders): - if order.get("_cursor_id") == filter_request.cursor: - start_index = i + 1 - break - - # Get page of results - end_index = start_index + filter_request.limit - page_orders = all_active_orders[start_index:end_index] - - # Determine next cursor and has_more - has_more = end_index < len(all_active_orders) - next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None - - # Clean up cursor_id from response data - for order in page_orders: - order.pop("_cursor_id", None) - - return PaginatedResponse( - data=page_orders, - pagination={ - "limit": filter_request.limit, - "has_more": has_more, - "next_cursor": next_cursor, - "total_count": len(all_active_orders), - }, - ) + # Sort by cursor_id and apply cursor-based pagination + return paginate_by_cursor(all_active_orders, filter_request.cursor, filter_request.limit) except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching active orders: {str(e)}") @@ -307,7 +246,7 @@ async def get_active_orders( @router.post("/orders/search", response_model=PaginatedResponse) async def get_orders( filter_request: OrderFilterRequest, - accounts_service: AccountsService = Depends(get_accounts_service), + trading_history_service: TradingHistoryService = Depends(get_trading_history_service), connector_service = Depends(get_connector_service) ): """ @@ -333,7 +272,7 @@ async def get_orders( # Collect orders from all specified accounts for account_name in accounts_to_check: try: - orders = await accounts_service.get_orders( + orders = await trading_history_service.get_orders( account_name=account_name, connector_name=( filter_request.connector_names[0] @@ -367,38 +306,13 @@ async def get_orders( if filter_request.trading_pairs and len(filter_request.trading_pairs) > 1: all_orders = [order for order in all_orders if order.get("trading_pair") in filter_request.trading_pairs] - # Sort by timestamp (most recent first) and then by cursor_id for consistency - all_orders.sort(key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), reverse=True) - - # Apply cursor-based pagination - start_index = 0 - if filter_request.cursor: - # Find the order after the cursor - for i, order in enumerate(all_orders): - if order.get("_cursor_id") == filter_request.cursor: - start_index = i + 1 - break - - # Get page of results - end_index = start_index + filter_request.limit - page_orders = all_orders[start_index:end_index] - - # Determine next cursor and has_more - has_more = end_index < len(all_orders) - next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None - - # Clean up cursor_id from response data - for order in page_orders: - order.pop("_cursor_id", None) - - return PaginatedResponse( - data=page_orders, - pagination={ - "limit": filter_request.limit, - "has_more": has_more, - "next_cursor": next_cursor, - "total_count": len(all_orders), - }, + # Sort by timestamp (most recent first) then cursor_id, and apply cursor-based pagination + return paginate_by_cursor( + all_orders, + filter_request.cursor, + filter_request.limit, + sort_key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), + reverse=True, ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching orders: {str(e)}") @@ -408,7 +322,7 @@ async def get_orders( @router.post("/trades", response_model=PaginatedResponse) async def get_trades( filter_request: TradeFilterRequest, - accounts_service: AccountsService = Depends(get_accounts_service), + trading_history_service: TradingHistoryService = Depends(get_trading_history_service), connector_service = Depends(get_connector_service) ): """ @@ -434,7 +348,7 @@ async def get_trades( # Collect trades from all specified accounts for account_name in accounts_to_check: try: - trades = await accounts_service.get_trades( + trades = await trading_history_service.get_trades( account_name=account_name, connector_name=( filter_request.connector_names[0] @@ -474,38 +388,13 @@ async def get_trades( if filter_request.trade_types and len(filter_request.trade_types) > 1: all_trades = [trade for trade in all_trades if trade.get("trade_type") in filter_request.trade_types] - # Sort by timestamp (most recent first) and then by cursor_id for consistency - all_trades.sort(key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), reverse=True) - - # Apply cursor-based pagination - start_index = 0 - if filter_request.cursor: - # Find the trade after the cursor - for i, trade in enumerate(all_trades): - if trade.get("_cursor_id") == filter_request.cursor: - start_index = i + 1 - break - - # Get page of results - end_index = start_index + filter_request.limit - page_trades = all_trades[start_index:end_index] - - # Determine next cursor and has_more - has_more = end_index < len(all_trades) - next_cursor = page_trades[-1].get("_cursor_id") if page_trades and has_more else None - - # Clean up cursor_id from response data - for trade in page_trades: - trade.pop("_cursor_id", None) - - return PaginatedResponse( - data=page_trades, - pagination={ - "limit": filter_request.limit, - "has_more": has_more, - "next_cursor": next_cursor, - "total_count": len(all_trades), - }, + # Sort by timestamp (most recent first) then cursor_id, and apply cursor-based pagination + return paginate_by_cursor( + all_trades, + filter_request.cursor, + filter_request.limit, + sort_key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), + reverse=True, ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching trades: {str(e)}") @@ -609,7 +498,7 @@ async def set_leverage( @router.post("/funding-payments", response_model=PaginatedResponse) async def get_funding_payments( filter_request: FundingPaymentFilterRequest, - accounts_service: AccountsService = Depends(get_accounts_service), + trading_history_service: TradingHistoryService = Depends(get_trading_history_service), connector_service = Depends(get_connector_service) ): """ @@ -647,7 +536,7 @@ async def get_funding_payments( # Only fetch funding payments from perpetual connectors if connector_name in all_connectors[account_name] and "_perpetual" in connector_name: try: - payments = await accounts_service.get_funding_payments( + payments = await trading_history_service.get_funding_payments( account_name=account_name, connector_name=connector_name, trading_pair=filter_request.trading_pair, @@ -665,38 +554,13 @@ async def get_funding_payments( logger.warning(f"Failed to get funding payments for {account_name}/{connector_name}: {e}") - # Sort by timestamp (most recent first) and then by cursor_id for consistency - all_funding_payments.sort(key=lambda x: (x.get("timestamp", ""), x.get("_cursor_id", "")), reverse=True) - - # Apply cursor-based pagination - start_index = 0 - if filter_request.cursor: - # Find the payment after the cursor - for i, payment in enumerate(all_funding_payments): - if payment.get("_cursor_id") == filter_request.cursor: - start_index = i + 1 - break - - # Get page of results - end_index = start_index + filter_request.limit - page_payments = all_funding_payments[start_index:end_index] - - # Determine next cursor and has_more - has_more = end_index < len(all_funding_payments) - next_cursor = page_payments[-1].get("_cursor_id") if page_payments and has_more else None - - # Clean up cursor_id from response data - for payment in page_payments: - payment.pop("_cursor_id", None) - - return PaginatedResponse( - data=page_payments, - pagination={ - "limit": filter_request.limit, - "has_more": has_more, - "next_cursor": next_cursor, - "total_count": len(all_funding_payments), - }, + # Sort by timestamp (most recent first) then cursor_id, and apply cursor-based pagination + return paginate_by_cursor( + all_funding_payments, + filter_request.cursor, + filter_request.limit, + sort_key=lambda x: (x.get("timestamp", ""), x.get("_cursor_id", "")), + reverse=True, ) except Exception as e: diff --git a/routers/websocket.py b/routers/websocket.py index 31241e5e..96d8d83d 100644 --- a/routers/websocket.py +++ b/routers/websocket.py @@ -24,11 +24,8 @@ def _authenticate_websocket(websocket: WebSocket) -> bool: """ Authenticate a WebSocket connection using Basic Auth from headers or query params. - Returns True if authenticated (or debug mode), False otherwise. + Returns True if authenticated, False otherwise. """ - if settings.security.debug_mode: - return True - # Try Authorization header first auth_header = websocket.headers.get("authorization", "") if auth_header.startswith("Basic "): diff --git a/services/accounts_service.py b/services/accounts_service.py index 9d48830e..61f92bae 100644 --- a/services/accounts_service.py +++ b/services/accounts_service.py @@ -1,434 +1,42 @@ import asyncio import logging -import time +import re from datetime import datetime, timezone from decimal import Decimal -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional from fastapi import HTTPException from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger -from hummingbot.connector.connector_base import ConnectorBase from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType from config import settings -from database import AccountRepository, AsyncDatabaseManager, FundingRepository, OrderRepository, TradeRepository +from database import AccountRepository, AsyncDatabaseManager from services.gateway_client import GatewayClient from services.gateway_transaction_poller import GatewayTransactionPoller +from services.gateway_wallet_service import GatewayWalletService, balance_entry +from services.perpetual_trading_service import PerpetualTradingService +from services.portfolio_analytics_service import PortfolioAnalyticsService from utils.file_system import fs_util # Create module-specific logger logger = logging.getLogger(__name__) +# Safe single path component names: prevents path traversal via '/', '\' or '..' +SAFE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9_-]+$") -class AccountTradingInterface: - """ - ScriptStrategyBase-compatible interface for executor trading. - - This class provides the exact interface that Hummingbot executors expect - from a strategy object, backed by AccountsService resources. - - IMPORTANT: This class does NOT maintain its own connector cache. Instead, it - uses the shared ConnectorManager via AccountsService which is the single source - of truth for all connector instances. - - Executors use the following interface from strategy: - - current_timestamp: float property - - buy(connector_name, trading_pair, amount, order_type, price, position_action) -> str - - sell(connector_name, trading_pair, amount, order_type, price, position_action) -> str - - cancel(connector_name, trading_pair, order_id) -> str - - get_active_orders(connector_name) -> List - ExecutorBase also accesses: - - connectors: Dict[str, ConnectorBase] (accessed directly in ExecutorBase.__init__) +def validate_safe_name(name: str, label: str = "name") -> str: """ - - def __init__( - self, - accounts_service: 'AccountsService', - account_name: str - ): - """ - Initialize AccountTradingInterface. - - Args: - accounts_service: AccountsService instance for connector access - account_name: Account to use for connectors - """ - self._accounts_service = accounts_service - self._account_name = account_name - - # Track active markets (connector_name -> set of trading_pairs) - self._markets: Dict[str, Set[str]] = {} - - # Timestamp tracking - self._current_timestamp: float = time.time() - - # Lock for async operations - self._lock = asyncio.Lock() - - @property - def account_name(self) -> str: - """Return the account name for this trading interface.""" - return self._account_name - - @property - def connectors(self) -> Dict[str, ConnectorBase]: - """ - Return connectors for this account from the connector service. - - This returns the actual connectors that are already initialized and running, - avoiding any duplicate caching or connector management. - """ - if not self._accounts_service._connector_service: - return {} - all_connectors = self._accounts_service._connector_service.get_all_trading_connectors() - return all_connectors.get(self._account_name, {}) - - @property - def markets(self) -> Dict[str, Set[str]]: - """Return active markets configuration.""" - return self._markets - - @property - def current_timestamp(self) -> float: - """Return current timestamp (updated by control loop).""" - return self._current_timestamp - - def update_timestamp(self): - """Update the current timestamp. Called by ExecutorService control loop.""" - self._current_timestamp = time.time() - - async def ensure_connector(self, connector_name: str) -> ConnectorBase: - """ - Ensure connector is loaded and available. - - This method uses the connector service which already caches connectors. - It also ensures the MarketDataProvider has access to the connector for - order book initialization. - - Args: - connector_name: Name of the connector - - Returns: - The connector instance - """ - # Get connector from connector service (already cached there) - connector = await self._accounts_service._connector_service.get_trading_connector( - self._account_name, - connector_name - ) - return connector - - async def add_market( - self, - connector_name: str, - trading_pair: str, - order_book_timeout: float = 10.0 - ): - """ - Add a trading pair to active markets with full order book support. - - This method ensures: - 1. Connector is loaded - 2. Order book is initialized and has valid data - 3. Rate sources are initialized for price feeds - - Args: - connector_name: Name of the connector - trading_pair: Trading pair to add - order_book_timeout: Timeout in seconds to wait for order book data - """ - await self.ensure_connector(connector_name) - - if connector_name not in self._markets: - self._markets[connector_name] = set() - - # Check if already tracking this pair - if trading_pair in self._markets[connector_name]: - logger.debug(f"Market {connector_name}/{trading_pair} already active") - return - - self._markets[connector_name].add(trading_pair) - - # Get connector and its order book tracker - connector = self.connectors.get(connector_name) - if not connector: - raise ValueError(f"Connector {connector_name} not available. Check credentials.") - tracker = connector.order_book_tracker - - # Check if order book already exists, if not initialize it dynamically - if trading_pair in tracker.order_books: - logger.debug(f"Order book already exists for {connector_name}/{trading_pair}") - else: - logger.debug(f"Order book not found for {connector_name}/{trading_pair}, initializing dynamically") - market_data_service = self._accounts_service._market_data_service - if market_data_service: - try: - success = await market_data_service.initialize_order_book( - connector_name, trading_pair, - account_name=self._account_name, - timeout=order_book_timeout - ) - if not success: - logger.warning(f"Order book for {connector_name}/{trading_pair} not ready after timeout") - except Exception as e: - logger.warning(f"Exception initializing order book: {e}") - - # Register the trading pair with the connector - self._register_trading_pair_with_connector(connector, trading_pair) - - async def _wait_for_order_book_ready( - self, - tracker, - trading_pair: str, - timeout: float = 30.0 - ) -> bool: - """ - Wait for an order book to have valid data. - - Args: - tracker: Order book tracker instance - trading_pair: Trading pair to wait for - timeout: Maximum time to wait in seconds - - Returns: - True if order book is ready, False if timeout - """ - import asyncio - waited = 0 - interval = 0.5 - while waited < timeout: - if trading_pair in tracker.order_books: - ob = tracker.order_books[trading_pair] - try: - bids, asks = ob.snapshot - if len(bids) > 0 and len(asks) > 0: - logger.info(f"Order book for {trading_pair} is ready with {len(bids)} bids and {len(asks)} asks") - return True - except Exception: - pass - await asyncio.sleep(interval) - waited += interval - logger.warning(f"Timeout waiting for {trading_pair} order book to be ready") - return False - - def _register_trading_pair_with_connector( - self, - connector: ConnectorBase, - trading_pair: str - ): - """ - Register a trading pair with the connector's internal structures. - - This is needed for methods like get_order_book() to work properly. - Different connector types may store trading pairs differently. - - Args: - connector: The connector instance - trading_pair: Trading pair to register - """ - if trading_pair not in connector._trading_pairs: - connector._trading_pairs.append(trading_pair) - logger.debug(f"Registered {trading_pair} with connector {type(connector).__name__}") - - async def remove_market( - self, - connector_name: str, - trading_pair: str, - remove_order_book: bool = True - ): - """ - Remove a trading pair from active markets and optionally cleanup order book. - - Args: - connector_name: Name of the connector - trading_pair: Trading pair to remove - remove_order_book: Whether to remove the order book (default True) - """ - if connector_name not in self._markets: - return - - self._markets[connector_name].discard(trading_pair) - if not self._markets[connector_name]: - del self._markets[connector_name] - - # Remove order book if requested - if remove_order_book: - market_data_service = self._accounts_service._market_data_service - if market_data_service: - try: - success = await market_data_service.remove_trading_pair( - connector_name, - trading_pair, - account_name=self._account_name - ) - if success: - logger.info(f"Removed order book for {connector_name}/{trading_pair}") - else: - logger.debug(f"Order book for {trading_pair} was not being tracked") - except Exception as e: - logger.warning(f"Failed to remove order book for {trading_pair}: {e}") - - # ======================================== - # ScriptStrategyBase-compatible methods - # These are called by executors via self._strategy.method() - # ======================================== - - def buy( - self, - connector_name: str, - trading_pair: str, - amount: Decimal, - order_type: OrderType, - price: Decimal = Decimal("NaN"), - position_action: PositionAction = PositionAction.NIL - ) -> str: - """ - Place a buy order. - - Args: - connector_name: Name of the connector - trading_pair: Trading pair - amount: Order amount in base currency - order_type: Type of order (LIMIT, MARKET, etc.) - price: Order price (for limit orders) - position_action: Position action for perpetuals - - Returns: - Client order ID - """ - connector = self.connectors.get(connector_name) - if not connector: - raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") - - return connector.buy( - trading_pair=trading_pair, - amount=amount, - order_type=order_type, - price=price, - position_action=position_action - ) - - def sell( - self, - connector_name: str, - trading_pair: str, - amount: Decimal, - order_type: OrderType, - price: Decimal = Decimal("NaN"), - position_action: PositionAction = PositionAction.NIL - ) -> str: - """ - Place a sell order. - - Args: - connector_name: Name of the connector - trading_pair: Trading pair - amount: Order amount in base currency - order_type: Type of order (LIMIT, MARKET, etc.) - price: Order price (for limit orders) - position_action: Position action for perpetuals - - Returns: - Client order ID - """ - connector = self.connectors.get(connector_name) - if not connector: - raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") - - return connector.sell( - trading_pair=trading_pair, - amount=amount, - order_type=order_type, - price=price, - position_action=position_action - ) - - def cancel( - self, - connector_name: str, - trading_pair: str, - order_id: str - ) -> str: - """ - Cancel an order. - - Args: - connector_name: Name of the connector - trading_pair: Trading pair - order_id: Client order ID to cancel - - Returns: - Client order ID that was cancelled - """ - connector = self.connectors.get(connector_name) - if not connector: - raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") - - return connector.cancel(trading_pair=trading_pair, client_order_id=order_id) - - def get_active_orders(self, connector_name: str) -> List: - """ - Get active orders for a connector. - - Args: - connector_name: Name of the connector - - Returns: - List of active in-flight orders - """ - connector = self.connectors.get(connector_name) - if not connector: - return [] - return list(connector.in_flight_orders.values()) - - # ======================================== - # Additional helper methods - # ======================================== - - def get_connector(self, connector_name: str) -> Optional[ConnectorBase]: - """ - Get a connector by name from the shared ConnectorManager. - - Args: - connector_name: Name of the connector - - Returns: - The connector instance or None if not loaded - """ - return self.connectors.get(connector_name) - - def is_connector_loaded(self, connector_name: str) -> bool: - """ - Check if a connector is loaded in the shared ConnectorManager. - - Args: - connector_name: Name of the connector - - Returns: - True if connector is loaded - """ - return connector_name in self.connectors - - def get_all_trading_pairs(self) -> Dict[str, Set[str]]: - """ - Get all active trading pairs by connector. - - Returns: - Dictionary mapping connector names to sets of trading pairs - """ - return {k: v.copy() for k, v in self._markets.items()} - - async def cleanup(self): - """ - Cleanup resources. Called when shutting down. - - Note: This does NOT clean up connectors since they are managed by the - shared ConnectorManager, not by AccountTradingInterface. - """ - # Clear only local state (markets tracking) - self._markets.clear() - logger.info(f"AccountTradingInterface cleanup completed for account {self._account_name}") + Validate that a name is safe to use as a single path component (no separators or traversal sequences). + :param name: The name to validate. + :param label: Human readable label used in the error message. + :return: The validated name. + :raises HTTPException: 400 if the name is invalid. + """ + if not name or not SAFE_NAME_PATTERN.fullmatch(name): + raise HTTPException(status_code=400, + detail=f"Invalid {label}: '{name}'. Only letters, numbers, underscores and hyphens are allowed.") + return name class AccountsService: @@ -440,15 +48,18 @@ class AccountsService: default_quotes = { "hyperliquid": "USDC", "hyperliquid_perpetual": "USD", + "lighter": "USDC", + "lighter_perpetual": "USDC", "xrpl": "RLUSD", "kraken": "USD", } potential_wrapped_tokens = ["ETH", "SOL", "BNB", "POL", "AVAX"] - - # Cache for storing last successful prices by trading pair - _last_known_prices = {} def __init__(self, + db_manager: AsyncDatabaseManager, + connector_service, + market_data_service, + trading_service, account_update_interval: int = 5, default_quote: str = "USDT", gateway_url: str = "http://localhost:15888"): @@ -456,6 +67,10 @@ def __init__(self, Initialize the AccountsService. Args: + db_manager: AsyncDatabaseManager for persistence (shared, created once at startup) + connector_service: UnifiedConnectorService (required, injected from main.py) + market_data_service: MarketDataService (required, injected from main.py) + trading_service: TradingService (required, injected from main.py) account_update_interval: How often to update account states in minutes (default: 5) default_quote: Default quote currency for trading pairs (default: "USDT") gateway_url: URL for Gateway service (default: "http://localhost:15888") @@ -468,19 +83,29 @@ def __init__(self, self._update_account_state_task: Optional[asyncio.Task] = None self._order_status_polling_task: Optional[asyncio.Task] = None - # Database setup for account states and orders - self.db_manager = AsyncDatabaseManager(settings.database.url) - self._db_initialized = False + # Cache for storing last successful prices by trading pair (per-instance) + self._last_known_prices = {} - # Services injected from main.py - self._connector_service = None # UnifiedConnectorService - self._market_data_service = None # MarketDataService - self._trading_service = None # TradingService + # Database setup for account states and orders (shared manager injected from main.py; + # tables are created once at startup so no per-service bootstrap is needed) + self.db_manager = db_manager + + # Services injected from main.py (required). Set BEFORE any composed service below + # uses them: perpetual_trading_service binds self.get_connector_instance, which relies + # on _connector_service being available. + self._connector_service = connector_service # UnifiedConnectorService + self._market_data_service = market_data_service # MarketDataService + self._trading_service = trading_service # TradingService # Initialize Gateway client self.gateway_base_url = gateway_url self.gateway_client = GatewayClient(gateway_url) + # Composed services: gateway wallet CRUD/balances, perpetual trading and pure portfolio analytics + self.gateway_wallet_service = GatewayWalletService(self.gateway_client) + self.perpetual_trading_service = PerpetualTradingService(self.get_connector_instance) + self.portfolio_analytics_service = PortfolioAnalyticsService() + # Initialize Gateway transaction poller self.gateway_tx_poller = GatewayTransactionPoller( db_manager=self.db_manager, @@ -491,35 +116,6 @@ def __init__(self, ) self._gateway_poller_started = False - # Trading interfaces per account (for executor use) - self._trading_interfaces: Dict[str, AccountTradingInterface] = {} - - def get_trading_interface(self, account_name: str) -> AccountTradingInterface: - """ - Get or create a trading interface for the specified account. - - This interface provides ScriptStrategyBase-compatible methods - that executors can use for trading operations. - - Args: - account_name: Account to get trading interface for - - Returns: - AccountTradingInterface instance for the account - """ - if account_name not in self._trading_interfaces: - self._trading_interfaces[account_name] = AccountTradingInterface( - accounts_service=self, - account_name=account_name - ) - return self._trading_interfaces[account_name] - - async def ensure_db_initialized(self): - """Ensure database is initialized before using it.""" - if not self._db_initialized: - await self.db_manager.create_tables() - self._db_initialized = True - def get_accounts_state(self): return self.accounts_state @@ -585,15 +181,8 @@ async def stop(self): except Exception as e: logger.error(f"Error stopping Gateway transaction poller: {e}", exc_info=True) - # Cleanup trading interfaces - for interface in self._trading_interfaces.values(): - await interface.cleanup() - self._trading_interfaces.clear() - logger.info("Cleaned up trading interfaces") - # Stop all connectors through the connector service - if self._connector_service: - await self._connector_service.stop_all() + await self._connector_service.stop_all() logger.info("AccountsService stopped successfully") @@ -603,12 +192,11 @@ async def _refresh_and_get_tokens_info(self, connector, connector_name: str, acc Combines the connector state refresh and token info retrieval into a single awaitable so both can run in parallel across all connectors. """ - if self._connector_service: - try: - await self._connector_service._update_connector_state(connector, connector_name, account_name) - except Exception as e: - logger.error(f"Error refreshing {connector_name}, using stale data: {e}") - # skip_balance_refresh=True since _update_connector_state already called _update_balances + try: + await self._connector_service.refresh_connector_state(connector, connector_name, account_name) + except Exception as e: + logger.error(f"Error refreshing {connector_name}, using stale data: {e}") + # skip_balance_refresh=True since refresh_connector_state already called _update_balances return await self._get_connector_tokens_info(connector, connector_name, skip_balance_refresh=True) async def update_account_state_loop(self): @@ -621,7 +209,7 @@ async def update_account_state_loop(self): await self.check_all_connectors() # Single parallel pass: refresh connector state + get token info + gateway - all_connectors = self._connector_service.get_all_trading_connectors() if self._connector_service else {} + all_connectors = self._connector_service.get_all_trading_connectors() tasks = [] task_meta = [] # (account_name, connector_name) @@ -664,8 +252,7 @@ async def order_status_polling_loop(self): """ while True: try: - if self._connector_service: - await self._connector_service.sync_all_orders_to_database() + await self._connector_service.sync_all_orders_to_database() except Exception as e: logger.error(f"Error syncing order state to database: {e}") finally: @@ -675,23 +262,29 @@ async def dump_account_state(self): """ Save the current account state to the database. All account/connector combinations from the same snapshot will use the same timestamp. + The whole snapshot is persisted atomically in a single transaction: save_account_state + only flushes, and get_session_context commits once on successful exit. :return: """ - await self.ensure_db_initialized() - + # Snapshot the live dict synchronously (no awaits) so concurrent mutations of + # accounts_state cannot raise "dictionary changed size during iteration" + accounts_state_snapshot = {account: dict(connectors) for account, connectors in self.accounts_state.items()} + try: # Generate a single timestamp for this entire snapshot snapshot_timestamp = datetime.now(timezone.utc) - + async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) - - # Save each account-connector combination with the same timestamp - for account_name, connectors in self.accounts_state.items(): + + # Save each account-connector combination with the same timestamp. + # No commit happens inside the loop; the session context commits once + # after all rows are added (one transaction per snapshot). + for account_name, connectors in accounts_state_snapshot.items(): for connector_name, tokens_info in connectors.items(): if tokens_info: # Only save if there's token data await repository.save_account_state(account_name, connector_name, tokens_info, snapshot_timestamp) - + except Exception as e: logger.error(f"Error saving account state to database: {e}") # Re-raise the exception since we no longer have a fallback @@ -702,7 +295,8 @@ async def load_account_state_history(self, cursor: Optional[str] = None, start_time: Optional[datetime] = None, end_time: Optional[datetime] = None, - interval: str = "5m"): + interval: str = "5m", + account_names: Optional[List[str]] = None): """ Load the account state history from the database with pagination and interval sampling. @@ -712,16 +306,16 @@ async def load_account_state_history(self, start_time: Start time filter end_time: End time filter interval: Sampling interval (5m, 15m, 30m, 1h, 4h, 12h, 1d) + account_names: Optional list of account names to filter by (single IN query) :return: Tuple of (data, next_cursor, has_more). """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) return await repository.get_account_state_history( limit=limit, + account_names=account_names, cursor=cursor, start_time=start_time, end_time=end_time, @@ -747,9 +341,6 @@ async def _ensure_account_connectors_initialized(self, account_name: str): :param account_name: The name of the account to initialize connectors for. """ - if not self._connector_service: - return - # Initialize missing connectors for connector_name in self._connector_service.list_available_credentials(account_name): try: @@ -774,7 +365,7 @@ async def update_account_state( connector_names: If provided, only update these connectors. If None, update all connectors. For Gateway, this filters by chain-network (e.g., 'solana-mainnet-beta'). """ - all_connectors = self._connector_service.get_all_trading_connectors() if self._connector_service else {} + all_connectors = self._connector_service.get_all_trading_connectors() # Prepare parallel tasks tasks = [] @@ -847,9 +438,7 @@ async def _get_connector_tokens_info(self, connector, connector_name: str, skip_ price = Decimal("1") else: # Try RateOracle first (instant, cached) - rate = None - if self._market_data_service: - rate = self._market_data_service.get_rate(token, "USDT") + rate = self._market_data_service.get_rate(token, "USDT") if rate and rate > 0: price = rate else: @@ -859,13 +448,12 @@ async def _get_connector_tokens_info(self, connector, connector_name: str, skip_ missing_indices.append(len(tokens_info)) price = None # resolved below - tokens_info.append({ - "token": token, - "units": float(balance["units"]), - "price": float(price) if price is not None else 0.0, - "value": float(price * balance["units"]) if price is not None else 0.0, - "available_units": float(connector.get_available_balance(token)) - }) + tokens_info.append(balance_entry( + token, + balance["units"], + price, + available_units=connector.get_available_balance(token), + )) # Batch-fetch only the missing prices from the exchange if missing_pairs: @@ -905,14 +493,8 @@ async def _fetch_single(pair): last_traded[pair] = price # Fill in fallbacks for any pairs that failed - for pair in trading_pairs: - if pair not in last_traded: - if pair in self._last_known_prices: - last_traded[pair] = self._last_known_prices[pair] - logger.info(f"Using cached price {self._last_known_prices[pair]} for {pair}") - else: - last_traded[pair] = Decimal("0") - logger.warning(f"No cached price available for {pair}, using 0") + missing_pairs = [pair for pair in trading_pairs if pair not in last_traded] + last_traded.update(self._get_fallback_prices(missing_pairs)) return last_traded @@ -946,8 +528,15 @@ async def add_credentials(self, account_name: str, connector_name: str, credenti :param credentials: Dictionary containing the connector credentials. :raises Exception: If credentials are invalid or connector cannot be initialized. """ - if not self._connector_service: - raise HTTPException(status_code=500, detail="Connector service not initialized") + validate_safe_name(account_name, "account name") + validate_safe_name(connector_name, "connector name") + + # Capture the original credential file BEFORE the in-place overwrite performed by + # update_connector_keys. This determines whether a failure is a brand-new CREATE + # (rollback the partial file) or an UPDATE (restore the previous file byte-for-byte). + credentials_path = f"credentials/{account_name}/connectors/{connector_name}.yml" + credentials_existed = fs_util.path_exists(credentials_path) + original_content = fs_util.read_file(credentials_path) if credentials_existed else None try: # Update the connector keys (this saves the credentials to file and validates them) @@ -956,7 +545,13 @@ async def add_credentials(self, account_name: str, connector_name: str, credenti await self.update_account_state() except Exception as e: logger.error(f"Error adding connector credentials for account {account_name}: {e}") - await self.delete_credentials(account_name, connector_name) + # Roll back the file write. For a brand-new creation, delete the partial file. For an + # update, update_connector_keys overwrote the previous (valid) credentials in-place, so + # we restore the captured original content to keep the file byte-for-byte intact. + if not credentials_existed: + await self.delete_credentials(account_name, connector_name) + elif original_content is not None: + fs_util.ensure_file_and_dump_text(credentials_path, original_content) raise e @staticmethod @@ -974,6 +569,7 @@ def list_credentials(account_name: str): :param account_name: The name of the account. :return: List of credentials. """ + validate_safe_name(account_name, "account name") try: return [file for file in fs_util.list_files(f'credentials/{account_name}/connectors') if file.endswith('.yml')] @@ -987,16 +583,17 @@ async def delete_credentials(self, account_name: str, connector_name: str): :param connector_name: :return: """ + validate_safe_name(account_name, "account name") + validate_safe_name(connector_name, "connector name") # Delete credentials file if it exists if fs_util.path_exists(f"credentials/{account_name}/connectors/{connector_name}.yml"): fs_util.delete_file(directory=f"credentials/{account_name}/connectors", file_name=f"{connector_name}.yml") # Always perform cleanup regardless of file existence - if self._connector_service: - # Stop the connector if it's running - await self._connector_service.stop_trading_connector(account_name, connector_name) - # Clear the connector from cache - self._connector_service.clear_trading_connector(account_name, connector_name) + # Stop the connector if it's running + await self._connector_service.stop_trading_connector(account_name, connector_name) + # Clear the connector from cache + self._connector_service.clear_trading_connector(account_name, connector_name) # Remove from account state if account_name in self.accounts_state and connector_name in self.accounts_state[account_name]: @@ -1008,6 +605,7 @@ def add_account(self, account_name: str): :param account_name: :return: """ + validate_safe_name(account_name, "account name") # Check if account already exists by looking at folders if account_name in self.list_accounts(): raise HTTPException(status_code=400, detail="Account already exists.") @@ -1027,12 +625,12 @@ async def delete_account(self, account_name: str): :param account_name: :return: """ + validate_safe_name(account_name, "account name") # Stop all connectors for this account - if self._connector_service: - for connector_name in self._connector_service.list_account_connectors(account_name): - await self._connector_service.stop_trading_connector(account_name, connector_name) - # Clear all connectors for this account from cache - self._connector_service.clear_trading_connector(account_name) + for connector_name in self._connector_service.list_account_connectors(account_name): + await self._connector_service.stop_trading_connector(account_name, connector_name) + # Clear all connectors for this account from cache + self._connector_service.clear_trading_connector(account_name) # Delete account folder fs_util.delete_folder('credentials', account_name) @@ -1045,8 +643,6 @@ async def get_account_current_state(self, account_name: str) -> Dict[str, List[D """ Get current state for a specific account from database. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1074,8 +670,6 @@ async def get_account_state_history(self, end_time: End time filter interval: Sampling interval (5m, 15m, 30m, 1h, 4h, 12h, 1d) """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1095,8 +689,6 @@ async def get_connector_current_state(self, account_name: str, connector_name: s """ Get current state for a specific connector. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1116,8 +708,6 @@ async def get_connector_state_history(self, """ Get historical state for a specific connector with pagination. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1137,8 +727,6 @@ async def get_all_unique_tokens(self) -> List[str]: """ Get all unique tokens across all accounts and connectors. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1157,8 +745,6 @@ async def get_token_current_state(self, token: str) -> List[Dict]: """ Get current state of a specific token across all accounts. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1167,12 +753,10 @@ async def get_token_current_state(self, token: str) -> List[Dict]: logger.error(f"Error getting token current state: {e}") return [] - async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict[str, any]: + async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict[str, Any]: """ Get total portfolio value, optionally filtered by account. """ - await self.ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repository = AccountRepository(session) @@ -1195,175 +779,20 @@ async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict[ return portfolio - def get_portfolio_distribution(self, account_name: Optional[str] = None) -> Dict[str, any]: + def get_portfolio_distribution(self, account_name: Optional[str] = None) -> Dict[str, Any]: """ Get portfolio distribution by tokens with percentages. + Delegates the pure math to PortfolioAnalyticsService (snapshots the live state internally). """ - try: - # Get accounts to process - accounts_to_process = [account_name] if account_name else list(self.accounts_state.keys()) - - # Aggregate all tokens across accounts and connectors - token_values = {} - total_value = 0 - - for acc_name in accounts_to_process: - if acc_name in self.accounts_state: - for connector_name, connector_data in self.accounts_state[acc_name].items(): - for token_info in connector_data: - token = token_info.get("token", "") - value = token_info.get("value", 0) - - if token not in token_values: - token_values[token] = { - "token": token, - "total_value": 0, - "total_units": 0, - "accounts": {} - } - - token_values[token]["total_value"] += value - token_values[token]["total_units"] += token_info.get("units", 0) - total_value += value - - # Track by account - if acc_name not in token_values[token]["accounts"]: - token_values[token]["accounts"][acc_name] = { - "value": 0, - "units": 0, - "connectors": {} - } - - token_values[token]["accounts"][acc_name]["value"] += value - token_values[token]["accounts"][acc_name]["units"] += token_info.get("units", 0) - - # Track by connector within account - if connector_name not in token_values[token]["accounts"][acc_name]["connectors"]: - token_values[token]["accounts"][acc_name]["connectors"][connector_name] = { - "value": 0, - "units": 0 - } - - token_values[token]["accounts"][acc_name]["connectors"][connector_name]["value"] += value - token_values[token]["accounts"][acc_name]["connectors"][connector_name]["units"] += token_info.get("units", 0) - - # Calculate percentages - distribution = [] - for token_data in token_values.values(): - percentage = (token_data["total_value"] / total_value * 100) if total_value > 0 else 0 - - token_dist = { - "token": token_data["token"], - "total_value": round(token_data["total_value"], 6), - "total_units": token_data["total_units"], - "percentage": round(percentage, 4), - "accounts": {} - } - - # Add account-level percentages - for acc_name, acc_data in token_data["accounts"].items(): - acc_percentage = (acc_data["value"] / total_value * 100) if total_value > 0 else 0 - token_dist["accounts"][acc_name] = { - "value": round(acc_data["value"], 6), - "units": acc_data["units"], - "percentage": round(acc_percentage, 4), - "connectors": {} - } - - # Add connector-level data - for conn_name, conn_data in acc_data["connectors"].items(): - token_dist["accounts"][acc_name]["connectors"][conn_name] = { - "value": round(conn_data["value"], 6), - "units": conn_data["units"] - } - - distribution.append(token_dist) - - # Sort by value (descending) - distribution.sort(key=lambda x: x["total_value"], reverse=True) - - return { - "total_portfolio_value": round(total_value, 6), - "token_count": len(distribution), - "distribution": distribution, - "account_filter": account_name if account_name else "all_accounts" - } - - except Exception as e: - logger.error(f"Error calculating portfolio distribution: {e}") - return { - "total_portfolio_value": 0, - "token_count": 0, - "distribution": [], - "account_filter": account_name if account_name else "all_accounts", - "error": str(e) - } - - def get_account_distribution(self) -> Dict[str, any]: + return self.portfolio_analytics_service.get_portfolio_distribution(self.accounts_state, account_name) + + def get_account_distribution(self) -> Dict[str, Any]: """ Get portfolio distribution by accounts with percentages. + Delegates the pure math to PortfolioAnalyticsService (snapshots the live state internally). """ - try: - account_values = {} - total_value = 0 - - for acc_name, account_data in self.accounts_state.items(): - account_value = 0 - connector_values = {} - - for connector_name, connector_data in account_data.items(): - connector_value = 0 - for token_info in connector_data: - value = token_info.get("value", 0) - connector_value += value - account_value += value - - connector_values[connector_name] = round(connector_value, 6) - - account_values[acc_name] = { - "total_value": round(account_value, 6), - "connectors": connector_values - } - total_value += account_value - - # Calculate percentages - distribution = [] - for acc_name, acc_data in account_values.items(): - percentage = (acc_data["total_value"] / total_value * 100) if total_value > 0 else 0 - - connector_dist = {} - for conn_name, conn_value in acc_data["connectors"].items(): - conn_percentage = (conn_value / total_value * 100) if total_value > 0 else 0 - connector_dist[conn_name] = { - "value": conn_value, - "percentage": round(conn_percentage, 4) - } - - distribution.append({ - "account": acc_name, - "total_value": acc_data["total_value"], - "percentage": round(percentage, 4), - "connectors": connector_dist - }) - - # Sort by value (descending) - distribution.sort(key=lambda x: x["total_value"], reverse=True) - - return { - "total_portfolio_value": round(total_value, 6), - "account_count": len(distribution), - "distribution": distribution - } - - except Exception as e: - logger.error(f"Error calculating account distribution: {e}") - return { - "total_portfolio_value": 0, - "account_count": 0, - "distribution": [], - "error": str(e) - } - + return self.portfolio_analytics_service.get_account_distribution(self.accounts_state) + async def place_trade(self, account_name: str, connector_name: str, trading_pair: str, trade_type: TradeType, amount: Decimal, order_type: OrderType = OrderType.LIMIT, price: Optional[Decimal] = None, position_action: PositionAction = PositionAction.OPEN) -> str: @@ -1390,11 +819,8 @@ async def place_trade(self, account_name: str, connector_name: str, trading_pair if account_name not in self.list_accounts(): raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found") - if not self._connector_service: - raise HTTPException(status_code=500, detail="Connector service not initialized") - connector = await self._connector_service.get_trading_connector(account_name, connector_name) - + # Validate price for limit orders if order_type in [OrderType.LIMIT, OrderType.LIMIT_MAKER] and price is None: raise HTTPException(status_code=400, detail="Price is required for LIMIT and LIMIT_MAKER orders") @@ -1439,13 +865,12 @@ async def place_trade(self, account_name: str, connector_name: str, trading_pair notional_size = quantized_price * quantized_amount else: # For market orders without price, get current market price for validation - if self._market_data_service: - try: - prices = await self._market_data_service.get_prices(connector_name, [trading_pair]) - if trading_pair in prices and "error" not in prices: - price = Decimal(str(prices[trading_pair])) - except Exception as e: - logger.error(f"Error getting market price for {trading_pair}: {e}") + try: + prices = await self._market_data_service.get_prices(connector_name, [trading_pair]) + if trading_pair in prices and "error" not in prices: + price = Decimal(str(prices[trading_pair])) + except Exception as e: + logger.error(f"Error getting market price for {trading_pair}: {e}") notional_size = price * quantized_amount if price else Decimal("0") if notional_size < trading_rule.min_notional_size: @@ -1504,30 +929,9 @@ async def get_connector_instance(self, account_name: str, connector_name: str): if account_name not in self.list_accounts(): raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found") - if not self._connector_service: - raise HTTPException(status_code=500, detail="Connector service not initialized") - return await self._connector_service.get_trading_connector(account_name, connector_name) - async def _get_perpetual_connector(self, account_name: str, connector_name: str): - """ - Get a perpetual connector instance with validation. - - Args: - account_name: Name of the account - connector_name: Name of the connector (must be perpetual) - - Returns: - Perpetual connector instance - - Raises: - HTTPException: If connector is not perpetual or not found - """ - if "_perpetual" not in connector_name: - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") - return await self.get_connector_instance(account_name, connector_name) - - async def get_active_orders(self, account_name: str, connector_name: str) -> Dict[str, any]: + async def get_active_orders(self, account_name: str, connector_name: str) -> Dict[str, Any]: """ Get active orders for a specific connector. @@ -1574,311 +978,31 @@ async def set_leverage(self, account_name: str, connector_name: str, trading_pair: str, leverage: int) -> Dict[str, str]: """ Set leverage for a specific trading pair on a perpetual connector. - - Args: - account_name: Name of the account - connector_name: Name of the connector (must be perpetual) - trading_pair: Trading pair to set leverage for - leverage: Leverage value (typically 1-125) - - Returns: - Dictionary with success status and message - - Raises: - HTTPException: If account/connector not found, not perpetual, or operation fails + Delegates to PerpetualTradingService. """ - connector = await self._get_perpetual_connector(account_name, connector_name) - - if not hasattr(connector, '_execute_set_leverage'): - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support leverage setting") - - try: - await connector._execute_set_leverage(trading_pair, leverage) - message = f"Leverage for {trading_pair} set to {leverage} on {connector_name}" - logger.info(f"Set leverage for {trading_pair} to {leverage} on {connector_name} (Account: {account_name})") - return {"status": "success", "message": message} - - except Exception as e: - logger.error(f"Failed to set leverage for {trading_pair} to {leverage}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to set leverage: {str(e)}") + return await self.perpetual_trading_service.set_leverage(account_name, connector_name, trading_pair, leverage) async def set_position_mode(self, account_name: str, connector_name: str, position_mode: PositionMode) -> Dict[str, str]: """ Set position mode for a perpetual connector. - - Args: - account_name: Name of the account - connector_name: Name of the connector (must be perpetual) - position_mode: PositionMode.HEDGE or PositionMode.ONEWAY - - Returns: - Dictionary with success status and message - - Raises: - HTTPException: If account/connector not found, not perpetual, or operation fails + Delegates to PerpetualTradingService. """ - connector = await self._get_perpetual_connector(account_name, connector_name) - - # Check if the requested position mode is supported - supported_modes = connector.supported_position_modes() - if position_mode not in supported_modes: - supported_values = [mode.value for mode in supported_modes] - raise HTTPException( - status_code=400, - detail=f"Position mode '{position_mode.value}' not supported. Supported modes: {supported_values}" - ) - - try: - # Try to call the method - it might be sync or async - result = connector.set_position_mode(position_mode) - # If it's a coroutine, await it - if asyncio.iscoroutine(result): - await result - - message = f"Position mode set to {position_mode.value} on {connector_name}" - logger.info(f"Set position mode to {position_mode.value} on {connector_name} (Account: {account_name})") - return {"status": "success", "message": message} - - except Exception as e: - logger.error(f"Failed to set position mode to {position_mode.value}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to set position mode: {str(e)}") + return await self.perpetual_trading_service.set_position_mode(account_name, connector_name, position_mode) async def get_position_mode(self, account_name: str, connector_name: str) -> Dict[str, str]: """ Get current position mode for a perpetual connector. - - Args: - account_name: Name of the account - connector_name: Name of the connector (must be perpetual) - - Returns: - Dictionary with current position mode - - Raises: - HTTPException: If account/connector not found, not perpetual, or operation fails + Delegates to PerpetualTradingService. """ - connector = await self._get_perpetual_connector(account_name, connector_name) - - if not hasattr(connector, 'position_mode'): - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position mode") - - try: - current_mode = connector.position_mode - return { - "position_mode": current_mode.value if current_mode else "UNKNOWN", - "connector": connector_name, - "account": account_name - } - - except Exception as e: - logger.error(f"Failed to get position mode: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get position mode: {str(e)}") - - async def get_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, - trading_pair: Optional[str] = None, status: Optional[str] = None, - start_time: Optional[int] = None, end_time: Optional[int] = None, - limit: int = 100, offset: int = 0) -> List[Dict]: - """Get order history using OrderRepository.""" - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - order_repo = OrderRepository(session) - orders = await order_repo.get_orders( - account_name=account_name, - connector_name=connector_name, - trading_pair=trading_pair, - status=status, - start_time=start_time, - end_time=end_time, - limit=limit, - offset=offset - ) - return [order_repo.to_dict(order) for order in orders] - except Exception as e: - logger.error(f"Error getting orders: {e}") - return [] - - async def get_active_orders_history(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, - trading_pair: Optional[str] = None) -> List[Dict]: - """Get active orders from database using OrderRepository.""" - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - order_repo = OrderRepository(session) - orders = await order_repo.get_active_orders( - account_name=account_name, - connector_name=connector_name, - trading_pair=trading_pair - ) - return [order_repo.to_dict(order) for order in orders] - except Exception as e: - logger.error(f"Error getting active orders: {e}") - return [] - - async def get_orders_summary(self, account_name: Optional[str] = None, start_time: Optional[int] = None, - end_time: Optional[int] = None) -> Dict: - """Get order summary statistics using OrderRepository.""" - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - order_repo = OrderRepository(session) - return await order_repo.get_orders_summary( - account_name=account_name, - start_time=start_time, - end_time=end_time - ) - except Exception as e: - logger.error(f"Error getting orders summary: {e}") - return { - "total_orders": 0, - "filled_orders": 0, - "cancelled_orders": 0, - "failed_orders": 0, - "active_orders": 0, - "fill_rate": 0, - } - - async def get_trades(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, - trading_pair: Optional[str] = None, trade_type: Optional[str] = None, - start_time: Optional[int] = None, end_time: Optional[int] = None, - limit: int = 100, offset: int = 0) -> List[Dict]: - """Get trade history using TradeRepository.""" - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - trade_repo = TradeRepository(session) - trade_order_pairs = await trade_repo.get_trades_with_orders( - account_name=account_name, - connector_name=connector_name, - trading_pair=trading_pair, - trade_type=trade_type, - start_time=start_time, - end_time=end_time, - limit=limit, - offset=offset - ) - return [trade_repo.to_dict(trade, order) for trade, order in trade_order_pairs] - except Exception as e: - logger.error(f"Error getting trades: {e}") - return [] + return await self.perpetual_trading_service.get_position_mode(account_name, connector_name) async def get_account_positions(self, account_name: str, connector_name: str) -> List[Dict]: """ Get current positions for a specific perpetual connector. - - Args: - account_name: Name of the account - connector_name: Name of the connector (must be perpetual) - - Returns: - List of position dictionaries - - Raises: - HTTPException: If account/connector not found or not perpetual - """ - connector = await self._get_perpetual_connector(account_name, connector_name) - - if not hasattr(connector, 'account_positions'): - raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position tracking") - - try: - # Force position update to ensure current market prices are used - await connector._update_positions() - - positions = [] - raw_positions = connector.account_positions - - for trading_pair, position_info in raw_positions.items(): - # Convert position data to dict format - position_dict = { - "account_name": account_name, - "connector_name": connector_name, - "trading_pair": position_info.trading_pair, - "side": position_info.position_side.name if hasattr(position_info, 'position_side') else "UNKNOWN", - "amount": float(position_info.amount) if hasattr(position_info, 'amount') else 0.0, - "entry_price": float(position_info.entry_price) if hasattr(position_info, 'entry_price') else None, - "unrealized_pnl": float(position_info.unrealized_pnl) if hasattr(position_info, 'unrealized_pnl') else None, - "leverage": float(position_info.leverage) if hasattr(position_info, 'leverage') else None, - } - - # Only include positions with non-zero amounts - if position_dict["amount"] != 0: - positions.append(position_dict) - - return positions - - except Exception as e: - logger.error(f"Failed to get positions for {connector_name}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get positions: {str(e)}") - - async def get_funding_payments(self, account_name: str, connector_name: str = None, - trading_pair: str = None, limit: int = 100) -> List[Dict]: - """ - Get funding payment history for an account. - - Args: - account_name: Name of the account - connector_name: Optional connector name filter - trading_pair: Optional trading pair filter - limit: Maximum number of records to return - - Returns: - List of funding payment dictionaries + Delegates to PerpetualTradingService. """ - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - funding_repo = FundingRepository(session) - funding_payments = await funding_repo.get_funding_payments( - account_name=account_name, - connector_name=connector_name, - trading_pair=trading_pair, - limit=limit - ) - return [funding_repo.to_dict(payment) for payment in funding_payments] - - except Exception as e: - logger.error(f"Error getting funding payments: {e}") - return [] - - async def get_total_funding_fees(self, account_name: str, connector_name: str, - trading_pair: str) -> Dict: - """ - Get total funding fees for a specific trading pair. - - Args: - account_name: Name of the account - connector_name: Name of the connector - trading_pair: Trading pair to get fees for - - Returns: - Dictionary with total funding fees information - """ - await self.ensure_db_initialized() - - try: - async with self.db_manager.get_session_context() as session: - funding_repo = FundingRepository(session) - return await funding_repo.get_total_funding_fees( - account_name=account_name, - connector_name=connector_name, - trading_pair=trading_pair - ) - - except Exception as e: - logger.error(f"Error getting total funding fees: {e}") - return { - "total_funding_fees": 0, - "payment_count": 0, - "fee_currency": None, - "error": str(e) - } + return await self.perpetual_trading_service.get_account_positions(account_name, connector_name) # ============================================ # Gateway Wallet Management Methods @@ -1917,22 +1041,33 @@ async def _update_gateway_balances(self, chain_networks: Optional[List[str]] = N balance_tasks = [] task_metadata = [] # Store (chain, network, address) for each task - # For each chain, get its config with defaultWallet and defaultNetworks + # Fetch every chain's config concurrently first, instead of one HTTP round-trip + # per chain in serial. Each config is the merged chain-network namespace + # (e.g., solana-mainnet-beta), returning both chain-level fields + # (defaultWallet, defaultNetworks) and network fields. + chains_with_networks = [ + chain_info for chain_info in chains_result["chains"] if chain_info.get("networks") + ] for chain_info in chains_result["chains"]: - chain = chain_info["chain"] - networks = chain_info.get("networks", []) + if not chain_info.get("networks"): + logger.debug(f"Chain '{chain_info['chain']}' has no networks configured, skipping") + + config_results = await asyncio.gather( + *[ + self.gateway_client.get_config(f"{chain_info['chain']}-{chain_info['networks'][0]}") + for chain_info in chains_with_networks + ], + return_exceptions=True, + ) - if not networks: - logger.debug(f"Chain '{chain}' has no networks configured, skipping") - continue + # For each chain, build balance tasks from its resolved config + for chain_info, config in zip(chains_with_networks, config_results): + chain = chain_info["chain"] + first_network = chain_info["networks"][0] - # Get merged config using chain-network namespace (e.g., solana-mainnet-beta) - # This returns both chain-level fields (defaultWallet, defaultNetworks) and network fields - first_network = networks[0] - try: - config = await self.gateway_client.get_config(f"{chain}-{first_network}") - except Exception as e: - logger.warning(f"Could not get config for '{chain}-{first_network}': {e}") + # A chain whose get_config raised is skipped/logged, same as before + if isinstance(config, Exception): + logger.warning(f"Could not get config for '{chain}-{first_network}': {config}") continue default_wallet = config.get("defaultWallet") @@ -2013,263 +1148,31 @@ async def _update_gateway_balances(self, chain_networks: Optional[List[str]] = N async def get_gateway_wallets(self) -> List[Dict]: """ Get all wallets from Gateway. Gateway manages its own encrypted wallets. - - Returns: - List of wallet information from Gateway, with default_address included for each chain + Delegates to GatewayWalletService. """ - if not await self.gateway_client.ping(): - raise HTTPException(status_code=503, detail="Gateway service is not available") - - try: - wallets = await self.gateway_client.get_wallets() - - # Enrich with default wallet info for each chain - for wallet_group in wallets: - chain = wallet_group.get("chain") - if chain: - default_wallet = await self.gateway_client.get_default_wallet_address(chain) - wallet_group["default_address"] = default_wallet or "" - - return wallets - except Exception as e: - logger.error(f"Error getting Gateway wallets: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get wallets: {str(e)}") + return await self.gateway_wallet_service.get_gateway_wallets() async def add_gateway_wallet(self, chain: str, private_key: str, set_default: bool = True) -> Dict: """ Add a wallet to Gateway. Gateway handles encryption internally. - - Args: - chain: Blockchain chain (e.g., 'solana', 'ethereum') - private_key: Wallet private key - set_default: Set as default wallet for this chain (default: True) - - Returns: - Dictionary with wallet information from Gateway + Delegates to GatewayWalletService. """ - if not await self.gateway_client.ping(): - raise HTTPException(status_code=503, detail="Gateway service is not available") - - try: - result = await self.gateway_client.add_wallet(chain, private_key, set_default=set_default) - - if "error" in result: - raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") - - logger.info(f"Added {chain} wallet {result.get('address')} to Gateway") - return result - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error adding Gateway wallet: {e}") - raise HTTPException(status_code=500, detail=f"Failed to add wallet: {str(e)}") + return await self.gateway_wallet_service.add_gateway_wallet(chain, private_key, set_default=set_default) async def remove_gateway_wallet(self, chain: str, address: str) -> Dict: """ Remove a wallet from Gateway. - - Args: - chain: Blockchain chain - address: Wallet address to remove - - Returns: - Success message + Delegates to GatewayWalletService. """ - if not await self.gateway_client.ping(): - raise HTTPException(status_code=503, detail="Gateway service is not available") - - try: - result = await self.gateway_client.remove_wallet(chain, address) - - if "error" in result: - raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") - - logger.info(f"Removed {chain} wallet {address} from Gateway") - return {"success": True, "message": f"Successfully removed {chain} wallet"} - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error removing Gateway wallet: {e}") - raise HTTPException(status_code=500, detail=f"Failed to remove wallet: {str(e)}") + return await self.gateway_wallet_service.remove_gateway_wallet(chain, address) - async def get_gateway_balances(self, chain: str, address: str, network: Optional[str] = None, tokens: Optional[List[str]] = None) -> List[Dict]: + async def get_gateway_balances(self, chain: str, address: str, network: Optional[str] = None, + tokens: Optional[List[str]] = None) -> List[Dict]: """ Get Gateway wallet balances with pricing from rate sources. - - Args: - chain: Blockchain chain - address: Wallet address - network: Optional network name (if not provided, uses default network for chain) - tokens: Optional list of token symbols to query - - Returns: - List of token balance dictionaries with prices from rate sources - """ - if not await self.gateway_client.ping(): - raise HTTPException(status_code=503, detail="Gateway service is not available") - - try: - # Get default network for chain if not provided - if not network: - network = await self.gateway_client.get_default_network(chain) - if not network: - raise HTTPException(status_code=400, detail=f"Could not determine network for chain '{chain}'") - - # Get balances from Gateway - balances_response = await self.gateway_client.get_balances(chain, network, address, tokens=tokens) - - if "error" in balances_response: - raise HTTPException(status_code=400, detail=f"Gateway error: {balances_response['error']}") - - # Format balances list - balances = balances_response.get("balances", {}) - balances_list = [] - - for token, balance in balances.items(): - if balance and float(balance) > 0: - balances_list.append({ - "token": token, - "units": Decimal(str(balance)) - }) - - # Get prices for tokens - unique_tokens = [b["token"] for b in balances_list] - all_prices = {} - - # Fetch prices for Gateway tokens - if unique_tokens: - try: - fetched_prices = await self._fetch_gateway_prices_immediate( - chain, network, unique_tokens - ) - for token, price in fetched_prices.items(): - if price > 0: - all_prices[token] = price - except Exception as e: - logger.warning(f"Error fetching gateway prices: {e}") - - # Format final result with prices - formatted_balances = [] - for balance in balances_list: - token = balance["token"] - if "USD" in token: - price = Decimal("1") - else: - # all_prices is now keyed by token name directly - price = Decimal(str(all_prices.get(token, 0))) - - formatted_balances.append({ - "token": token, - "units": float(balance["units"]), - "price": float(price), - "value": float(price * balance["units"]), - "available_units": float(balance["units"]) - }) - - return formatted_balances - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error getting Gateway balances: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get balances: {str(e)}") - - async def _fetch_gateway_prices_immediate(self, chain: str, network: str, - tokens: List[str]) -> Dict[str, Decimal]: + Delegates to GatewayWalletService. """ - Fetch prices immediately from Gateway for the given tokens. - This is used to get prices right away instead of waiting for the background update task. - - Args: - chain: Blockchain chain (e.g., 'solana', 'ethereum') - network: Network name (e.g., 'mainnet-beta', 'mainnet') - tokens: List of token symbols to get prices for - - Returns: - Dictionary mapping token symbol to price in USDC - """ - from hummingbot.core.data_type.common import TradeType - from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient - from hummingbot.core.rate_oracle.rate_oracle import RateOracle - - gateway_client = GatewayHttpClient.get_instance() - rate_oracle = RateOracle.get_instance() - prices = {} - - # Construct full network name (e.g., "solana-mainnet-beta") - full_network = f"{chain}-{network}" - - # Create tasks for all tokens in parallel - tasks = [] - task_tokens = [] - quote_asset = "USDC" - - # On ethereum networks, use WETH price for ETH to avoid duplicate calls - eth_needs_weth_price = False - if chain == "ethereum": - has_eth = any(t.upper() == "ETH" for t in tokens) - has_weth = any(t.upper() == "WETH" for t in tokens) - if has_eth and not has_weth: - # Replace ETH with WETH for fetching - tokens = [t if t.upper() != "ETH" else "WETH" for t in tokens] - eth_needs_weth_price = True - logger.debug("Replacing ETH with WETH for price fetch on ethereum") - elif has_eth and has_weth: - # Remove ETH, will copy WETH price later - tokens = [t for t in tokens if t.upper() != "ETH"] - eth_needs_weth_price = True - logger.debug("Removing duplicate ETH, will use WETH price on ethereum") - - for token in tokens: - token_upper = token.upper() - - # Skip same-token quotes (e.g., USDC/USDC) - price is always 1 - if token_upper == quote_asset.upper(): - prices[token] = Decimal("1") - rate_oracle.set_price(f"{token}-{quote_asset}", Decimal("1")) - logger.debug(f"Skipping same-token quote for {token}, price=1") - continue - - try: - # get_price will auto-fetch dex/trading_type from network's swap provider - task = gateway_client.get_price( - network=full_network, - base_asset=token, - quote_asset=quote_asset, - amount=Decimal("1"), - side=TradeType.SELL - ) - tasks.append(task) - task_tokens.append(token) - except Exception as e: - logger.warning(f"Error preparing price request for {token}: {e}") - continue - - if tasks: - try: - results = await asyncio.gather(*tasks, return_exceptions=True) - for token, result in zip(task_tokens, results): - if isinstance(result, Exception): - logger.warning(f"Error fetching price for {token}: {result}") - elif result and "price" in result: - price = Decimal(str(result["price"])) - prices[token] = price - # Also update the rate oracle so future lookups can find it - trading_pair = f"{token}-USDC" - rate_oracle.set_price(trading_pair, price) - logger.debug(f"Fetched immediate price for {token}: {price} USDC") - except Exception as e: - logger.error(f"Error fetching gateway prices: {e}", exc_info=True) - - # Copy WETH price to ETH on ethereum networks - if eth_needs_weth_price and "WETH" in prices: - prices["ETH"] = prices["WETH"] - rate_oracle.set_price("ETH-USDC", prices["WETH"]) - logger.debug(f"Copied WETH price to ETH: {prices['WETH']} USDC") - - return prices + return await self.gateway_wallet_service.get_gateway_balances(chain, address, network=network, tokens=tokens) def get_unwrapped_token(self, token: str) -> str: """Get the unwrapped version of a wrapped token symbol (e.g., WSOL -> SOL).""" diff --git a/services/bots_orchestrator.py b/services/bots_orchestrator.py index 88a36bff..b274a5c2 100644 --- a/services/bots_orchestrator.py +++ b/services/bots_orchestrator.py @@ -1,27 +1,26 @@ import asyncio import logging +import os import re +import shutil from datetime import datetime, timezone -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import docker -from config import settings -from database import AsyncDatabaseManager, ControllerPerformanceRepository +from database import AsyncDatabaseManager, BotRunRepository, ControllerPerformanceRepository +from services.docker_service import DockerService +from utils.bot_archiver import BotArchiver from utils.mqtt_manager import MQTTManager logger = logging.getLogger(__name__) -# HummingbotPerformanceListener class is no longer needed -# All functionality is now handled by MQTTManager - - class BotsOrchestrator: """Orchestrates Hummingbot instances using Docker and MQTT communication.""" def __init__(self, broker_host, broker_port, broker_username, broker_password, - performance_dump_interval: int = 5): + db_manager: AsyncDatabaseManager, performance_dump_interval: int = 5): self.broker_host = broker_host self.broker_port = broker_port self.broker_username = broker_username @@ -43,8 +42,9 @@ def __init__(self, broker_host, broker_port, broker_username, broker_password, # Controller performance dump (similar to AccountsService.dump_account_state) self.performance_dump_interval = performance_dump_interval * 60 # Convert minutes to seconds self._performance_dump_task: Optional[asyncio.Task] = None - self.db_manager = AsyncDatabaseManager(settings.database.url) - self._db_initialized = False + # Shared manager injected from main.py; tables are created once at startup, + # so no per-service bootstrap is needed here. + self.db_manager = db_manager # MQTT manager will be started asynchronously later @@ -87,18 +87,26 @@ async def _start_async(self): # Then start the update loop await self.update_active_bots() - def stop(self): + async def stop(self): """Stop the active bots monitoring loop.""" if self._update_bots_task: self._update_bots_task.cancel() + try: + await self._update_bots_task + except asyncio.CancelledError: + pass self._update_bots_task = None if self._performance_dump_task: self._performance_dump_task.cancel() + try: + await self._performance_dump_task + except asyncio.CancelledError: + pass self._performance_dump_task = None - # Stop MQTT manager asynchronously - asyncio.create_task(self.mqtt_manager.stop()) + # Stop MQTT manager + await self.mqtt_manager.stop() async def update_active_bots(self, sleep_time=1.0): """Monitor and update active bots list using both Docker and MQTT discovery.""" @@ -294,7 +302,6 @@ def determine_controller_performance(controller_reports): return cleaned_data def get_all_bots_status(self): - # TODO: improve logic of bots state management """Get status information for all active bots.""" all_bots_status = {} for bot in [bot for bot in self.active_bots if not self.is_bot_stopping(bot)]: @@ -368,12 +375,6 @@ def is_bot_stopping(self, bot_name: str) -> bool: # Controller Performance Snapshots # ============================================ - async def _ensure_db_initialized(self): - """Ensure database is initialized before using it.""" - if not self._db_initialized: - await self.db_manager.create_tables() - self._db_initialized = True - async def _performance_dump_loop(self): """Periodically dump controller performance to the database (default every 5 minutes).""" while True: @@ -386,8 +387,6 @@ async def _performance_dump_loop(self): async def dump_controller_performance(self): """Save current controller performance for all active bots to the database.""" - await self._ensure_db_initialized() - snapshot_timestamp = datetime.now(timezone.utc) saved_count = 0 @@ -395,6 +394,7 @@ async def dump_controller_performance(self): async with self.db_manager.get_session_context() as session: repo = ControllerPerformanceRepository(session) + snapshots = [] for bot_name in list(self.active_bots): if self.is_bot_stopping(bot_name): continue @@ -403,15 +403,17 @@ async def dump_controller_performance(self): performance_data = self.determine_controller_performance(controller_reports) for controller_id, data in performance_data.items(): - await repo.save_controller_performance( - bot_name=bot_name, - controller_id=controller_id, - status=data.get("status", "unknown"), - performance=data.get("performance", {}), - custom_info=data.get("custom_info", {}), - snapshot_timestamp=snapshot_timestamp, - ) - saved_count += 1 + snapshots.append({ + "bot_name": bot_name, + "controller_id": controller_id, + "status": data.get("status", "unknown"), + "performance": data.get("performance", {}), + "custom_info": data.get("custom_info", {}), + "snapshot_timestamp": snapshot_timestamp, + }) + + saved_rows = await repo.save_controller_performances(snapshots) + saved_count = len(saved_rows) if saved_count > 0: logger.info(f"Dumped {saved_count} controller performance snapshots") @@ -430,8 +432,6 @@ async def get_controller_performance_history( interval: str = "5m" ): """Get historical controller performance with pagination and interval sampling.""" - await self._ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repo = ControllerPerformanceRepository(session) @@ -453,8 +453,6 @@ async def get_latest_controller_performance( bot_name: Optional[str] = None ) -> List[Dict]: """Get the most recent performance snapshot for each bot/controller.""" - await self._ensure_db_initialized() - try: async with self.db_manager.get_session_context() as session: repo = ControllerPerformanceRepository(session) @@ -462,3 +460,276 @@ async def get_latest_controller_performance( except Exception as e: logger.error(f"Error getting latest controller performance: {e}") return [] + + # ============================================ + # Bot Run persistence + # ============================================ + + async def mark_bot_run_stopped(self, bot_name: str, final_status: Optional[Dict] = None): + """Update a bot run status to STOPPED, capturing the final status snapshot.""" + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped(bot_name, final_status=final_status) + logger.info(f"Updated bot run status to STOPPED for {bot_name}") + + async def get_bot_runs( + self, + bot_name: Optional[str] = None, + account_name: Optional[str] = None, + strategy_type: Optional[str] = None, + strategy_name: Optional[str] = None, + run_status: Optional[str] = None, + deployment_status: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> List[Dict]: + """Get bot runs with optional filtering, serialized as dictionaries.""" + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_runs = await bot_run_repo.get_bot_runs( + bot_name=bot_name, + account_name=account_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + run_status=run_status, + deployment_status=deployment_status, + limit=limit, + offset=offset, + ) + return [self._serialize_bot_run(run) for run in bot_runs] + + async def get_bot_run_stats(self) -> Dict[str, Any]: + """Get statistics about bot runs.""" + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + return await bot_run_repo.get_bot_run_stats() + + async def get_bot_run_by_id(self, bot_run_id: int) -> Optional[Dict]: + """Get a specific bot run by ID, serialized as a dictionary (None if not found).""" + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_run = await bot_run_repo.get_bot_run_by_id(bot_run_id) + if not bot_run: + return None + return self._serialize_bot_run(bot_run) + + async def delete_bot_run(self, bot_run_id: int) -> Optional[Dict]: + """Delete a bot run record and its archived folder. + + Returns a dict with ``bot_name`` and ``archived_folder_deleted`` keys, + or None if the bot run does not exist. + """ + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_run = await bot_run_repo.delete_bot_run(bot_run_id) + + if not bot_run: + return None + + # Also delete the archived bot folder if it exists + archived_dir = os.path.join('bots', 'archived', bot_run.instance_name) + archived_deleted = False + if os.path.isdir(archived_dir): + try: + import platform + import subprocess + if platform.system() == 'Darwin': + # Strip macOS ACLs (Docker adds "deny delete" ACLs) + subprocess.run(['chmod', '-R', '-N', archived_dir], check=False) + shutil.rmtree(archived_dir) + archived_deleted = True + logger.info(f"Deleted archived folder: {archived_dir}") + except Exception as e: + logger.warning(f"Failed to delete archived folder {archived_dir}: {e}") + + return { + "bot_name": bot_run.bot_name, + "archived_folder_deleted": archived_deleted, + } + + async def create_bot_run(self, **kwargs): + """Create a bot run record. Errors are logged and swallowed so that a + failed tracking write never fails the caller's deployment.""" + try: + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.create_bot_run(**kwargs) + logger.info(f"Created bot run record for deployment {kwargs.get('instance_name')}") + except Exception as e: + logger.error(f"Failed to create bot run record: {e}") + # Don't fail the deployment if bot run creation fails + + @staticmethod + def _serialize_bot_run(run) -> Dict: + """Serialize a BotRun ORM object into a JSON-friendly dictionary.""" + return { + "id": run.id, + "bot_name": run.bot_name, + "instance_name": run.instance_name, + "deployed_at": run.deployed_at.isoformat() if run.deployed_at else None, + "stopped_at": run.stopped_at.isoformat() if run.stopped_at else None, + "strategy_type": run.strategy_type, + "strategy_name": run.strategy_name, + "config_name": run.config_name, + "account_name": run.account_name, + "image_version": run.image_version, + "deployment_status": run.deployment_status, + "run_status": run.run_status, + "deployment_config": run.deployment_config, + "final_status": run.final_status, + "error_message": run.error_message, + } + + # ============================================ + # Stop & Archive orchestration + # ============================================ + + async def stop_and_archive_bot( + self, + bot_name: str, + container_name: str, + bot_name_for_orchestrator: str, + skip_order_cancellation: bool, + archive_locally: bool, + s3_bucket: Optional[str], + docker_manager: DockerService, + bot_archiver: BotArchiver, + ): + """Stop a bot and archive its data (8-step workflow). + + This is the background-task body for ``stop-and-archive-bot``. It is + FastAPI-agnostic and can be invoked/tested directly. + """ + try: + logger.info(f"Starting background stop-and-archive for {bot_name}") + + # Step 1: Capture bot final status before stopping (while bot is still running) + logger.info(f"Capturing final status for {bot_name_for_orchestrator}") + final_status = None + try: + final_status = self.get_bot_status(bot_name_for_orchestrator) + logger.info(f"Captured final status for {bot_name_for_orchestrator}: {final_status}") + except Exception as e: + logger.warning(f"Failed to capture final status for {bot_name_for_orchestrator}: {e}") + + # Step 2: Update bot run with stopped_at timestamp and final status before stopping + try: + await self.mark_bot_run_stopped(bot_name, final_status=final_status) + logger.info(f"Updated bot run with stopped_at timestamp and final status for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run with stopped status: {e}") + # Continue with stop process even if database update fails + + # Step 3: Mark the bot as stopping, and stop the bot trading process + self.set_bot_stopping(bot_name_for_orchestrator) + logger.info(f"Stopping bot trading process for {bot_name_for_orchestrator}") + stop_response = await self.stop_bot( + bot_name_for_orchestrator, + skip_order_cancellation=skip_order_cancellation, + async_backend=True # Always use async for background tasks + ) + + if not stop_response or not stop_response.get("success", False): + error_msg = stop_response.get('error', 'Unknown error') if stop_response else 'No response from bot orchestrator' + logger.error(f"Failed to stop bot process: {error_msg}") + return + + # Step 4: Wait for graceful shutdown (15 seconds as requested) + logger.info(f"Waiting 15 seconds for bot {bot_name} to gracefully shutdown") + await asyncio.sleep(15) + + # Step 5: Stop the container with monitoring + max_retries = 10 + retry_interval = 2 + container_stopped = False + + for i in range(max_retries): + logger.info(f"Attempting to stop container {container_name} (attempt {i+1}/{max_retries})") + docker_manager.stop_container(container_name) + + # Check if container is already stopped + container_status = docker_manager.get_container_status(container_name) + if container_status.get("state", {}).get("status") == "exited": + container_stopped = True + logger.info(f"Container {container_name} is already stopped") + break + + await asyncio.sleep(retry_interval) + + if not container_stopped: + logger.error(f"Failed to stop container {container_name} after {max_retries} attempts") + return + + # Step 6: Archive the bot data + instance_dir = os.path.join('bots', 'instances', container_name) + logger.info(f"Archiving bot data from {instance_dir}") + + try: + if archive_locally: + bot_archiver.archive_locally(container_name, instance_dir) + else: + bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) + logger.info(f"Successfully archived bot data for {container_name}") + except Exception as e: + logger.error(f"Archive failed: {str(e)}") + # Continue with removal even if archive fails + + # Step 7: Remove the container + logging.info(f"Removing container {container_name}") + remove_response = docker_manager.remove_container(container_name, force=False) + + if not remove_response.get("success"): + # If graceful remove fails, try force remove + logging.warning("Graceful container removal failed, attempting force removal") + remove_response = docker_manager.remove_container(container_name, force=True) + + if remove_response.get("success"): + logging.info(f"Successfully completed stop-and-archive for bot {bot_name}") + + # Step 8: Update bot run deployment status to ARCHIVED + try: + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_archived(bot_name) + logger.info(f"Updated bot run deployment status to ARCHIVED for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run to archived: {e}") + else: + logging.error(f"Failed to remove container {container_name}") + + # Update bot run with error status (but keep stopped_at timestamp from earlier) + try: + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + error_message="Failed to remove container during archive process" + ) + logger.info(f"Updated bot run with error status for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run with error: {e}") + + except Exception as e: + logging.error(f"Error in background stop-and-archive for {bot_name}: {str(e)}") + + # Update bot run with error status + try: + async with self.db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + error_message=str(e) + ) + logger.info(f"Updated bot run with error status for {bot_name}") + except Exception as db_error: + logger.error(f"Failed to update bot run with error: {db_error}") + finally: + # Always clear the stopping status when the background task completes + self.clear_bot_stopping(bot_name_for_orchestrator) + logger.info(f"Cleared stopping status for bot {bot_name}") + + # Remove bot from active_bots and clear all MQTT data + if bot_name_for_orchestrator in self.active_bots: + self.mqtt_manager.clear_bot_data(bot_name_for_orchestrator) + del self.active_bots[bot_name_for_orchestrator] + logger.info(f"Removed bot {bot_name_for_orchestrator} from active_bots and cleared MQTT data") diff --git a/services/docker_service.py b/services/docker_service.py index a474499f..6aa6da40 100644 --- a/services/docker_service.py +++ b/services/docker_service.py @@ -161,17 +161,33 @@ def remove_container(self, container_name, force=True): except DockerException as e: return {"success": False, "message": str(e)} + @staticmethod + def _ensure_contained(path: str, base_dir: str, label: str): + """ + Defense in depth: verify that `path` stays inside `base_dir` after resolving symlinks and + traversal sequences. Raises ValueError if it escapes the allowed base directory. + """ + resolved_base = os.path.realpath(base_dir) + resolved_path = os.path.realpath(path) + if os.path.commonpath([resolved_base, resolved_path]) != resolved_base: + raise ValueError(f"Invalid {label}: '{path}' resolves outside of '{base_dir}'.") + return resolved_path + def create_hummingbot_instance(self, config: V2ControllerDeployment): bots_path = os.environ.get('BOTS_PATH', self.SOURCE_PATH) # Default to 'SOURCE_PATH' if BOTS_PATH is not set instance_name = config.instance_name instance_dir = os.path.join("bots", 'instances', instance_name) + # Defense in depth: ensure the resolved paths stay within their allowed base directories + # before any filesystem mutation (makedirs/copytree) takes place. + self._ensure_contained(instance_dir, os.path.join("bots", "instances"), "instance_name") + source_credentials_dir = os.path.join("bots", 'credentials', config.credentials_profile) + self._ensure_contained(source_credentials_dir, os.path.join("bots", "credentials"), "credentials_profile") if not os.path.exists(instance_dir): os.makedirs(instance_dir) os.makedirs(os.path.join(instance_dir, 'data')) os.makedirs(os.path.join(instance_dir, 'logs')) # Copy credentials to instance directory - source_credentials_dir = os.path.join("bots", 'credentials', config.credentials_profile) destination_credentials_dir = os.path.join(instance_dir, 'conf') # Remove the destination directory if it already exists diff --git a/services/executor_service.py b/services/executor_service.py index 08459dc2..21c235c1 100644 --- a/services/executor_service.py +++ b/services/executor_service.py @@ -32,7 +32,7 @@ from hummingbot.strategy_v2.executors.xemm_executor.xemm_executor import XEMMExecutor from hummingbot.strategy_v2.models.executors import CloseType, TrackedOrder -from database import AsyncDatabaseManager +from database import AsyncDatabaseManager, ExecutorRepository from models.executors import PositionHold from services.trading_service import AccountTradingInterface, TradingService from utils.executor_log_capture import ExecutorLogCapture, current_executor_id @@ -61,6 +61,43 @@ def _json_default(obj): raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") +def _coerce_json_compatible(obj): + """Recursively coerce a value into JSON-compatible primitives. + + Mirrors the result of ``json.loads(json.dumps(obj, default=_json_default))`` + without the string round-trip: containers are walked recursively and any + object handled by ``_json_default`` is coerced to the same output type. + """ + # JSON-native primitives are returned as-is. + if obj is None or isinstance(obj, (str, bool, int, float)): + return obj + if isinstance(obj, dict): + # json.dumps coerces non-string scalar keys (int/float/bool/None) to + # strings; replicate that so the output shape is identical. + coerced = {} + for key, value in obj.items(): + if isinstance(key, str): + str_key = key + elif isinstance(key, bool): + str_key = "true" if key else "false" + elif key is None: + str_key = "null" + elif isinstance(key, (int, float)): + str_key = json.dumps(key) + else: + raise TypeError( + f"keys must be str, int, float, bool or None, not {type(key).__name__}" + ) + coerced[str_key] = _coerce_json_compatible(value) + return coerced + if isinstance(obj, (list, tuple)): + # json.dumps serializes tuples as JSON arrays (-> lists on decode). + return [_coerce_json_compatible(item) for item in obj] + # Non-native types: route through the same coercion as the JSON encoder, + # then recurse into the (possibly nested) replacement value. + return _coerce_json_compatible(_json_default(obj)) + + class ExecutorService: """ Service for managing trading executors without Docker containers. @@ -145,7 +182,6 @@ async def recover_positions_from_db(self): try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) records = await repo.get_active_position_holds() @@ -206,7 +242,6 @@ async def cleanup_orphaned_executors(self): active_executor_ids = list(self._active_executors.keys()) async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) # Clean up orphaned executors @@ -283,25 +318,27 @@ def _get_trading_interface(self, account_name: str) -> AccountTradingInterface: self._trading_interfaces[account_name] = self._trading_service.get_trading_interface(account_name) return self._trading_interfaces[account_name] - async def create_executor( + def _validate_executor_config( self, executor_config: Dict[str, Any], - account_name: Optional[str] = None, - controller_id: Optional[str] = None - ) -> Dict[str, Any]: + default_timestamp: Optional[float] = None + ) -> tuple[Type[ExecutorBase], Type[ExecutorConfigBase], ExecutorConfigBase]: """ - Create and start a new executor. + Validate the executor type and build the typed executor config. + + Pure validation step: no IO, no executor started, no DB access. Args: executor_config: Executor configuration dictionary (must include 'type') - account_name: Account to use (defaults to master_account) + default_timestamp: Timestamp to set on the config if not provided + (required for time-based features like time_limit) Returns: - Dictionary with executor_id and initial status - """ - account = account_name or self.default_account + Tuple of (executor_class, config_class, typed_config) - # Get executor type from config + Raises: + HTTPException: 400 if the type is missing/invalid or the config is invalid + """ executor_type = executor_config.get("type") if not executor_type: raise HTTPException( @@ -309,32 +346,15 @@ async def create_executor( detail="executor_config must include 'type' field" ) - # Validate executor type if executor_type not in self.EXECUTOR_REGISTRY: raise HTTPException( status_code=400, detail=f"Invalid executor type '{executor_type}'. Valid types: {list(self.EXECUTOR_REGISTRY.keys())}" ) - # Get trading interface for this account - trading_interface = self._get_trading_interface(account) - - # Extract connector and trading pair from config - connector_name = executor_config.get("connector_name") - trading_pair = executor_config.get("trading_pair") - - # Ensure connector and market are ready - if connector_name: - if trading_pair: - await trading_interface.add_market(connector_name, trading_pair) - else: - await trading_interface.ensure_connector(connector_name) - - # Set timestamp if not provided (required for time-based features like time_limit) if "timestamp" not in executor_config or executor_config["timestamp"] is None: - executor_config["timestamp"] = trading_interface.current_timestamp + executor_config["timestamp"] = default_timestamp - # Create typed executor config executor_class, config_class = self.EXECUTOR_REGISTRY[executor_type] try: typed_config = config_class(**executor_config) @@ -344,7 +364,39 @@ async def create_executor( detail=f"Invalid executor config: {str(e)}" ) - # Create the executor instance + return executor_class, config_class, typed_config + + async def _prepare_market(self, account: str, connector_name: Optional[str], trading_pair: Optional[str]): + """Ensure the connector and market for the executor are ready on the account's trading interface.""" + trading_interface = self._get_trading_interface(account) + if connector_name: + if trading_pair: + await trading_interface.add_market(connector_name, trading_pair) + else: + await trading_interface.ensure_connector(connector_name) + + def _instantiate_and_register( + self, + executor_class: Type[ExecutorBase], + typed_config: ExecutorConfigBase, + trading_interface: AccountTradingInterface, + metadata: Dict[str, Any] + ) -> tuple[str, ExecutorBase]: + """ + Instantiate the executor, register it in memory and start it. + + Args: + executor_class: Executor class to instantiate + typed_config: Validated typed executor config + trading_interface: Trading interface acting as the executor's strategy + metadata: Metadata dict to register for the executor + + Returns: + Tuple of (executor_id, executor) + + Raises: + HTTPException: 400 if the executor fails to instantiate + """ try: executor = executor_class( strategy=trading_interface, @@ -358,11 +410,50 @@ async def create_executor( detail=f"Failed to create executor: {str(e)}" ) - # Store executor and metadata executor_id = typed_config.id - controller_id = controller_id or getattr(typed_config, "controller_id", "main") or "main" self._active_executors[executor_id] = executor - self._executor_metadata[executor_id] = { + self._executor_metadata[executor_id] = metadata + + # Set ContextVar so the asyncio Task created by start() inherits it + token = current_executor_id.set(executor_id) + executor.start() + current_executor_id.reset(token) + + return executor_id, executor + + async def create_executor( + self, + executor_config: Dict[str, Any], + account_name: Optional[str] = None, + controller_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + Create and start a new executor. + + Args: + executor_config: Executor configuration dictionary (must include 'type') + account_name: Account to use (defaults to master_account) + + Returns: + Dictionary with executor_id and initial status + """ + account = account_name or self.default_account + trading_interface = self._get_trading_interface(account) + + # Validate executor type and build the typed config + executor_class, _config_class, typed_config = self._validate_executor_config( + executor_config, default_timestamp=trading_interface.current_timestamp + ) + executor_type = executor_config["type"] + + # Ensure connector and market are ready + connector_name = executor_config.get("connector_name") + trading_pair = executor_config.get("trading_pair") + await self._prepare_market(account, connector_name, trading_pair) + + # Instantiate the executor, register it in memory and start it + controller_id = controller_id or getattr(typed_config, "controller_id", "main") or "main" + metadata = { "account_name": account, "connector_name": connector_name, "trading_pair": trading_pair, @@ -371,17 +462,13 @@ async def create_executor( "created_at": datetime.now(timezone.utc), "config": executor_config } - - # Set ContextVar so the asyncio Task created by start() inherits it - token = current_executor_id.set(executor_id) - executor.start() - current_executor_id.reset(token) + executor_id, executor = self._instantiate_and_register(executor_class, typed_config, trading_interface, metadata) # Persist to database await self._persist_executor_created(executor_id, executor) # Capture created_at before potential cleanup - created_at = self._executor_metadata[executor_id]["created_at"].isoformat() + created_at = metadata["created_at"].isoformat() # Check if executor terminated immediately (e.g., insufficient balance) # If so, handle completion now rather than waiting for control loop @@ -452,7 +539,6 @@ async def get_executors( if self.db_manager: try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) db_executors = await repo.get_executors( @@ -495,7 +581,6 @@ async def get_executor(self, executor_id: str) -> Optional[Dict[str, Any]]: if self.db_manager: try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) record = await repo.get_executor_by_id(executor_id) @@ -566,8 +651,11 @@ async def stop_executor( async def _handle_executor_completion(self, executor_id: str): """Handle cleanup when an executor completes.""" - executor = self._active_executors.get(executor_id) - if not executor: + # Atomically claim the executor so a concurrent completion (e.g. the + # control loop racing with the synchronous call in create_executor) + # returns early instead of double-persisting / double-aggregating. + executor = self._active_executors.pop(executor_id, None) + if executor is None: return metadata = self._executor_metadata.get(executor_id, {}) @@ -579,8 +667,9 @@ async def _handle_executor_completion(self, executor_id: str): # Persist final state to database await self._persist_executor_completed(executor_id, executor) - # Remove from active executors - del self._active_executors[executor_id] + # Active executor already claimed via pop above; drop its metadata last + # (metadata is read above and re-fetched inside the persist/aggregate + # helpers, so it must stay until after those awaits complete). if executor_id in self._executor_metadata: del self._executor_metadata[executor_id] @@ -599,9 +688,14 @@ def _format_executor_info( metadata = self._executor_metadata.get(executor_id, {}) executor_type = metadata.get("executor_type") - # Get executor_info and serialize + # Get executor_info as a dict and strip heavy custom_info fields BEFORE + # serialization so they never get coerced (fill_events, grid + # levels_by_state, etc.); then coerce in-place to JSON-compatible + # primitives instead of doing a json.dumps/json.loads string round-trip. executor_info = executor.executor_info - result = json.loads(json.dumps(executor_info.model_dump(), default=_json_default)) + dumped = executor_info.model_dump() + dumped["custom_info"] = self._strip_heavy_fields(dumped.get("custom_info"), executor_type) + result = _coerce_json_compatible(dumped) # Add metadata result["executor_id"] = executor_id @@ -626,9 +720,6 @@ def _format_executor_info( # Convert TradeType enum or int to string result["side"] = side.name if hasattr(side, 'name') else str(side) - # Filter out heavy fields from custom_info - result["custom_info"] = self._strip_heavy_fields(result.get("custom_info"), executor_type) - # Add log capture info result["error_count"] = self._log_capture.get_error_count(executor_id) result["last_error"] = self._log_capture.get_last_error(executor_id) @@ -766,7 +857,6 @@ async def get_performance_report( if self.db_manager: try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) db_data = await repo.get_performance_report(controller_id=controller_id) @@ -861,7 +951,6 @@ async def _persist_executor_created(self, executor_id: str, executor: ExecutorBa metadata = self._executor_metadata.get(executor_id, {}) async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) await repo.create_executor( @@ -950,7 +1039,6 @@ async def _persist_executor_completed(self, executor_id: str, executor: Executor logger.debug(f"Failed to serialize error logs for {executor_id}: {e}") async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) await repo.update_executor( @@ -1119,7 +1207,6 @@ async def _persist_position_hold(self, position: PositionHold): return try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) await repo.upsert_position_hold( account_name=position.account_name, @@ -1223,7 +1310,6 @@ async def clear_position_held( if self.db_manager: try: async with self.db_manager.get_session_context() as session: - from database.repositories.executor_repository import ExecutorRepository repo = ExecutorRepository(session) cleared = await repo.clear_position_hold( account_name=account_name, diff --git a/services/funding_recorder.py b/services/funding_recorder.py index 9560939c..a4a57350 100644 --- a/services/funding_recorder.py +++ b/services/funding_recorder.py @@ -1,12 +1,11 @@ import asyncio import logging -from datetime import datetime from decimal import Decimal, InvalidOperation from typing import Dict, Optional from hummingbot.connector.connector_base import ConnectorBase from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder -from hummingbot.core.event.events import MarketEvent, FundingPaymentCompletedEvent +from hummingbot.core.event.events import FundingPaymentCompletedEvent, MarketEvent from database import AsyncDatabaseManager, FundingRepository @@ -23,6 +22,9 @@ def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connecto self.connector_name = connector_name self._connector: Optional[ConnectorBase] = None self.logger = logging.getLogger(__name__) + + # Strong references to in-flight event handler tasks so they are not garbage-collected before completing + self._pending_tasks: set[asyncio.Task] = set() # Create event forwarder for funding payments self._funding_payment_forwarder = SourceInfoEventForwarder(self._did_funding_payment) @@ -53,11 +55,26 @@ async def stop(self): for event, forwarder in self._event_pairs: self._connector.remove_listener(event, forwarder) self.logger.info(f"FundingRecorder stopped for {self.account_name}/{self.connector_name}") - + + # Wait for in-flight write tasks so no funding payment records are lost + if self._pending_tasks: + await asyncio.gather(*self._pending_tasks, return_exceptions=True) + + def _create_tracked_task(self, coro) -> asyncio.Task: + """Create a task and keep a strong reference to it until it completes. + + The event loop only keeps weak references to tasks, so without this a pending + task could be garbage-collected before finishing, dropping the DB write. + """ + task = asyncio.create_task(coro) + self._pending_tasks.add(task) + task.add_done_callback(self._pending_tasks.discard) + return task + def _did_funding_payment(self, event_tag: int, market: ConnectorBase, event: FundingPaymentCompletedEvent): """Handle funding payment events - called by SourceInfoEventForwarder""" try: - asyncio.create_task(self._handle_funding_payment(event)) + self._create_tracked_task(self._handle_funding_payment(event)) except Exception as e: self.logger.error(f"Error in _did_funding_payment: {e}") @@ -120,17 +137,16 @@ async def record_funding_payment(self, event: FundingPaymentCompletedEvent, }) # Save to database - async with self.db_manager.get_session() as session: + async with self.db_manager.get_session_context() as session: funding_repo = FundingRepository(session) - + # Check if funding payment already exists if await funding_repo.funding_payment_exists(funding_data["funding_payment_id"]): self.logger.info(f"Funding payment {funding_data['funding_payment_id']} already exists, skipping") return - + funding_payment = await funding_repo.create_funding_payment(funding_data) - await session.commit() - + self.logger.info( f"Recorded funding payment for {account_name}/{connector_name}: " f"{event.trading_pair} - Rate: {funding_rate}, Payment: {funding_payment} " diff --git a/services/gateway_client.py b/services/gateway_client.py index 5d1f1f39..31a46d24 100644 --- a/services/gateway_client.py +++ b/services/gateway_client.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import aiohttp @@ -266,7 +266,7 @@ async def get_config(self, namespace: str) -> Dict: """Get configuration for a specific namespace (connector or chain-network)""" return await self._request("GET", "config", params={"namespace": namespace}) - async def update_config(self, namespace: str, path: str, value: any) -> Dict: + async def update_config(self, namespace: str, path: str, value: Any) -> Dict: """Update a configuration value for a namespace""" return await self._request("POST", "config/update", json={ "namespace": namespace, diff --git a/services/gateway_service.py b/services/gateway_service.py index d7378838..2448013d 100644 --- a/services/gateway_service.py +++ b/services/gateway_service.py @@ -2,7 +2,7 @@ import os import platform import shutil -from typing import Optional, Dict +from typing import Any, Dict, Optional import docker from docker.errors import DockerException @@ -95,7 +95,7 @@ def get_status(self) -> GatewayStatus: port=port ) - def start(self, config: GatewayConfig) -> Dict[str, any]: + def start(self, config: GatewayConfig) -> Dict[str, Any]: """ Start the Gateway container. If a container already exists, it will be stopped and removed before creating a new one. @@ -201,7 +201,7 @@ def start(self, config: GatewayConfig) -> Dict[str, any]: "message": f"Failed to start Gateway: {str(e)}" } - def stop(self) -> Dict[str, any]: + def stop(self) -> Dict[str, Any]: """Stop the Gateway container""" container = self._get_gateway_container() @@ -226,7 +226,7 @@ def stop(self) -> Dict[str, any]: "message": f"Failed to stop Gateway: {str(e)}" } - def restart(self, config: Optional[GatewayConfig] = None) -> Dict[str, any]: + def restart(self, config: Optional[GatewayConfig] = None) -> Dict[str, Any]: """ Restart the Gateway container. If config is provided, the container will be recreated with the new configuration. @@ -271,7 +271,7 @@ def restart(self, config: Optional[GatewayConfig] = None) -> Dict[str, any]: "message": f"Failed to restart Gateway: {str(e)}" } - def remove(self, remove_data: bool = False) -> Dict[str, any]: + def remove(self, remove_data: bool = False) -> Dict[str, Any]: """ Remove the Gateway container and optionally its data. @@ -337,7 +337,7 @@ def remove(self, remove_data: bool = False) -> Dict[str, any]: "message": f"Gateway container removed but failed to remove data: {str(e)}" } - def get_logs(self, tail: int = 100) -> Dict[str, any]: + def get_logs(self, tail: int = 100) -> Dict[str, Any]: """Get logs from the Gateway container""" container = self._get_gateway_container() diff --git a/services/gateway_transaction_poller.py b/services/gateway_transaction_poller.py index 33d810f4..8c30bc5a 100644 --- a/services/gateway_transaction_poller.py +++ b/services/gateway_transaction_poller.py @@ -8,15 +8,12 @@ """ import asyncio import logging -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from decimal import Decimal -from typing import Dict, List, Optional - -from sqlalchemy import select -from sqlalchemy.orm import selectinload +from typing import Dict, Optional from database import AsyncDatabaseManager -from database.models import GatewayCLMMEvent, GatewayCLMMPosition +from database.models import GatewayCLMMPosition from database.repositories import GatewayCLMMRepository, GatewaySwapRepository from services.gateway_client import GatewayClient @@ -194,10 +191,7 @@ async def _poll_clmm_event_transaction(self, event, clmm_repo: GatewayCLMMReposi """Poll a specific CLMM event transaction status.""" try: # Get the position by ID from the event's position_id foreign key - result = await clmm_repo.session.execute( - select(GatewayCLMMPosition).where(GatewayCLMMPosition.id == event.position_id) - ) - position = result.scalar_one_or_none() + position = await clmm_repo.get_position_by_id(event.position_id) if not position: logger.error(f"Position not found for CLMM event {event.transaction_hash}") @@ -245,11 +239,8 @@ async def _poll_clmm_event_transaction(self, event, clmm_repo: GatewayCLMMReposi async def _update_position_from_event(self, event, clmm_repo: GatewayCLMMRepository): """Update CLMM position state based on confirmed event.""" try: - # Get position by ID using the existing clmm_repo session - result = await clmm_repo.session.execute( - select(GatewayCLMMPosition).where(GatewayCLMMPosition.id == event.position_id) - ) - position = result.scalar_one_or_none() + # Get position by ID using the repository + position = await clmm_repo.get_position_by_id(event.position_id) if not position: logger.error(f"Position not found for event {event.id}") diff --git a/services/gateway_wallet_service.py b/services/gateway_wallet_service.py new file mode 100644 index 00000000..687b4ae8 --- /dev/null +++ b/services/gateway_wallet_service.py @@ -0,0 +1,305 @@ +import asyncio +import logging +from decimal import Decimal +from typing import Dict, List, Optional + +from fastapi import HTTPException + +from services.gateway_client import GatewayClient + +# Create module-specific logger +logger = logging.getLogger(__name__) + + +def balance_entry(token: str, units: Decimal, price: Optional[Decimal], + available_units: Optional[Decimal] = None) -> Dict: + """Build the standard token balance entry dict shared across balance endpoints. + + Args: + token: Token symbol + units: Token balance + price: Token price (None means unknown -> price/value reported as 0.0) + available_units: Available balance (defaults to units when not provided) + """ + if available_units is None: + available_units = units + return { + "token": token, + "units": float(units), + "price": float(price) if price is not None else 0.0, + "value": float(price * units) if price is not None else 0.0, + "available_units": float(available_units), + } + + +class GatewayWalletService: + """ + Gateway wallet management: wallet CRUD plus balance and price retrieval through the Gateway service. + Gateway manages its own encrypted wallets; this service only talks to it over HTTP via GatewayClient. + """ + + def __init__(self, gateway_client: GatewayClient): + """ + Initialize the GatewayWalletService. + + Args: + gateway_client: Client used for all Gateway HTTP interactions. + """ + self.gateway_client = gateway_client + + async def _require_gateway(self) -> None: + """Raise a 503 HTTPException if the Gateway service is not reachable.""" + if not await self.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + async def get_gateway_wallets(self) -> List[Dict]: + """ + Get all wallets from Gateway. Gateway manages its own encrypted wallets. + + Returns: + List of wallet information from Gateway, with default_address included for each chain + """ + await self._require_gateway() + + try: + wallets = await self.gateway_client.get_wallets() + + # Enrich with default wallet info for each chain + for wallet_group in wallets: + chain = wallet_group.get("chain") + if chain: + default_wallet = await self.gateway_client.get_default_wallet_address(chain) + wallet_group["default_address"] = default_wallet or "" + + return wallets + except Exception as e: + logger.error(f"Error getting Gateway wallets: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get wallets: {str(e)}") + + async def add_gateway_wallet(self, chain: str, private_key: str, set_default: bool = True) -> Dict: + """ + Add a wallet to Gateway. Gateway handles encryption internally. + + Args: + chain: Blockchain chain (e.g., 'solana', 'ethereum') + private_key: Wallet private key + set_default: Set as default wallet for this chain (default: True) + + Returns: + Dictionary with wallet information from Gateway + """ + await self._require_gateway() + + try: + result = await self.gateway_client.add_wallet(chain, private_key, set_default=set_default) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") + + logger.info(f"Added {chain} wallet {result.get('address')} to Gateway") + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error adding Gateway wallet: {e}") + raise HTTPException(status_code=500, detail=f"Failed to add wallet: {str(e)}") + + async def remove_gateway_wallet(self, chain: str, address: str) -> Dict: + """ + Remove a wallet from Gateway. + + Args: + chain: Blockchain chain + address: Wallet address to remove + + Returns: + Success message + """ + await self._require_gateway() + + try: + result = await self.gateway_client.remove_wallet(chain, address) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") + + logger.info(f"Removed {chain} wallet {address} from Gateway") + return {"success": True, "message": f"Successfully removed {chain} wallet"} + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error removing Gateway wallet: {e}") + raise HTTPException(status_code=500, detail=f"Failed to remove wallet: {str(e)}") + + async def get_gateway_balances(self, chain: str, address: str, network: Optional[str] = None, + tokens: Optional[List[str]] = None) -> List[Dict]: + """ + Get Gateway wallet balances with pricing from rate sources. + + Args: + chain: Blockchain chain + address: Wallet address + network: Optional network name (if not provided, uses default network for chain) + tokens: Optional list of token symbols to query + + Returns: + List of token balance dictionaries with prices from rate sources + """ + await self._require_gateway() + + try: + # Get default network for chain if not provided + if not network: + network = await self.gateway_client.get_default_network(chain) + if not network: + raise HTTPException(status_code=400, detail=f"Could not determine network for chain '{chain}'") + + # Get balances from Gateway + balances_response = await self.gateway_client.get_balances(chain, network, address, tokens=tokens) + + if "error" in balances_response: + raise HTTPException(status_code=400, detail=f"Gateway error: {balances_response['error']}") + + # Format balances list + balances = balances_response.get("balances", {}) + balances_list = [] + + for token, balance in balances.items(): + if balance and float(balance) > 0: + balances_list.append({ + "token": token, + "units": Decimal(str(balance)) + }) + + # Get prices for tokens + unique_tokens = [b["token"] for b in balances_list] + all_prices = {} + + # Fetch prices for Gateway tokens + if unique_tokens: + try: + fetched_prices = await self._fetch_gateway_prices_immediate( + chain, network, unique_tokens + ) + for token, price in fetched_prices.items(): + if price > 0: + all_prices[token] = price + except Exception as e: + logger.warning(f"Error fetching gateway prices: {e}") + + # Format final result with prices + formatted_balances = [] + for balance in balances_list: + token = balance["token"] + if "USD" in token: + price = Decimal("1") + else: + # all_prices is now keyed by token name directly + price = Decimal(str(all_prices.get(token, 0))) + + formatted_balances.append(balance_entry(token, balance["units"], price)) + + return formatted_balances + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting Gateway balances: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get balances: {str(e)}") + + async def _fetch_gateway_prices_immediate(self, chain: str, network: str, + tokens: List[str]) -> Dict[str, Decimal]: + """ + Fetch prices immediately from Gateway for the given tokens. + This is used to get prices right away instead of waiting for the background update task. + + Args: + chain: Blockchain chain (e.g., 'solana', 'ethereum') + network: Network name (e.g., 'mainnet-beta', 'mainnet') + tokens: List of token symbols to get prices for + + Returns: + Dictionary mapping token symbol to price in USDC + """ + from hummingbot.core.data_type.common import TradeType + from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient + from hummingbot.core.rate_oracle.rate_oracle import RateOracle + + gateway_client = GatewayHttpClient.get_instance() + rate_oracle = RateOracle.get_instance() + prices = {} + + # Construct full network name (e.g., "solana-mainnet-beta") + full_network = f"{chain}-{network}" + + # Create tasks for all tokens in parallel + tasks = [] + task_tokens = [] + quote_asset = "USDC" + + # On ethereum networks, use WETH price for ETH to avoid duplicate calls + eth_needs_weth_price = False + if chain == "ethereum": + has_eth = any(t.upper() == "ETH" for t in tokens) + has_weth = any(t.upper() == "WETH" for t in tokens) + if has_eth and not has_weth: + # Replace ETH with WETH for fetching + tokens = [t if t.upper() != "ETH" else "WETH" for t in tokens] + eth_needs_weth_price = True + logger.debug("Replacing ETH with WETH for price fetch on ethereum") + elif has_eth and has_weth: + # Remove ETH, will copy WETH price later + tokens = [t for t in tokens if t.upper() != "ETH"] + eth_needs_weth_price = True + logger.debug("Removing duplicate ETH, will use WETH price on ethereum") + + for token in tokens: + token_upper = token.upper() + + # Skip same-token quotes (e.g., USDC/USDC) - price is always 1 + if token_upper == quote_asset.upper(): + prices[token] = Decimal("1") + rate_oracle.set_price(f"{token}-{quote_asset}", Decimal("1")) + logger.debug(f"Skipping same-token quote for {token}, price=1") + continue + + try: + # get_price will auto-fetch dex/trading_type from network's swap provider + task = gateway_client.get_price( + network=full_network, + base_asset=token, + quote_asset=quote_asset, + amount=Decimal("1"), + side=TradeType.SELL + ) + tasks.append(task) + task_tokens.append(token) + except Exception as e: + logger.warning(f"Error preparing price request for {token}: {e}") + continue + + if tasks: + try: + results = await asyncio.gather(*tasks, return_exceptions=True) + for token, result in zip(task_tokens, results): + if isinstance(result, Exception): + logger.warning(f"Error fetching price for {token}: {result}") + elif result and "price" in result: + price = Decimal(str(result["price"])) + prices[token] = price + # Also update the rate oracle so future lookups can find it + trading_pair = f"{token}-USDC" + rate_oracle.set_price(trading_pair, price) + logger.debug(f"Fetched immediate price for {token}: {price} USDC") + except Exception as e: + logger.error(f"Error fetching gateway prices: {e}", exc_info=True) + + # Copy WETH price to ETH on ethereum networks + if eth_needs_weth_price and "WETH" in prices: + prices["ETH"] = prices["WETH"] + rate_oracle.set_price("ETH-USDC", prices["WETH"]) + logger.debug(f"Copied WETH price to ETH: {prices['WETH']} USDC") + + return prices diff --git a/services/market_data_service.py b/services/market_data_service.py index ab002857..4b5d46c5 100644 --- a/services/market_data_service.py +++ b/services/market_data_service.py @@ -381,24 +381,28 @@ def validate_connector(connector_name: str) -> None: raise UnsupportedConnectorException(connector_name) @staticmethod - async def validate_trading_pair(connector_name: str, trading_pair: str, interval: str = "1m") -> None: + async def _validate_pair(feed, connector_name: str, trading_pair: str) -> None: """ - Validate that a trading pair exists on the exchange by attempting a small REST candle fetch. + Validate that a trading pair exists on the exchange by loading the feed's exchange + data and probing a single REST candle. + + Called once per feed, at creation time, so the cost is not paid on every request. Raises: ValueError: If the trading pair does not exist or the exchange returns an error. """ - import time as _time - feed = CandlesFactory.get_candle(CandlesConfig( - connector=connector_name, - trading_pair=trading_pair, - interval=interval, - max_records=10, - )) try: - end_time = int(_time.time()) - candles = await feed.fetch_candles(end_time=end_time, limit=1) - if candles is None or len(candles) == 0: + # Some feeds (e.g. hyperliquid spot) need exchange data (symbol maps, + # quanto multipliers, etc.) loaded before a REST candle fetch can build + # its payload. start_network() does this internally, but fetch_candles() + # does not, so initialize explicitly here. No-op on feeds that don't need it. + await feed.initialize_exchange_data() + # Probe a generous window: a 1-candle probe spans only the current (often + # incomplete) interval, which is empty for illiquid pairs. fetch_candles + # returns a 0-d numpy array (np.array(None)) when no candles come back, so + # check ndim before len() to stay numpy-safe. + candles = await feed.fetch_candles(end_time=int(time.time()), limit=50) + if candles is None or getattr(candles, "ndim", 0) < 2 or len(candles) == 0: raise ValueError( f"Trading pair '{trading_pair}' not found on '{connector_name}'. " f"No candle data returned." @@ -410,33 +414,40 @@ async def validate_trading_pair(connector_name: str, trading_pair: str, interval f"Trading pair '{trading_pair}' appears to be invalid on '{connector_name}': {e}" ) - def get_candles_feed(self, config: CandlesConfig): + async def get_candles_feed(self, config: CandlesConfig): """ Get or create a candles feed. + On first creation the trading pair is validated (exchange data load + a one-candle + REST probe). Cached feeds are returned directly, so repeated requests for the same + feed pay no extra REST cost and never re-initialize exchange data. + Args: config: CandlesConfig for the desired feed Returns: Candle feed instance + + Raises: + ValueError: If the trading pair does not exist on the exchange. """ feed_key = self._generate_feed_key( FeedType.CANDLES, config.connector, config.trading_pair, config.interval ) - self._last_access_times[feed_key] = time.time() - self._feed_configs[feed_key] = (FeedType.CANDLES, config) - if feed_key not in self._candle_feeds: self.validate_connector(config.connector) feed = CandlesFactory.get_candle(config) + await self._validate_pair(feed, config.connector, config.trading_pair) feed.start() self._candle_feeds[feed_key] = feed + self._feed_configs[feed_key] = (FeedType.CANDLES, config) logger.info(f"Created candle feed: {feed_key}") + self._last_access_times[feed_key] = time.time() return self._candle_feeds[feed_key] - def get_candles_df( + async def get_candles_df( self, connector_name: str, trading_pair: str, @@ -462,7 +473,7 @@ def get_candles_df( max_records=max_records ) - feed = self.get_candles_feed(config) + feed = await self.get_candles_feed(config) return feed.candles_df def stop_candle_feed(self, config: CandlesConfig): diff --git a/services/orders_recorder.py b/services/orders_recorder.py index ad563617..6aa6828f 100644 --- a/services/orders_recorder.py +++ b/services/orders_recorder.py @@ -28,6 +28,9 @@ def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connecto self.connector_name = connector_name self._connector: Optional[ConnectorBase] = None + # Strong references to in-flight event handler tasks so they are not garbage-collected before completing + self._pending_tasks: set[asyncio.Task] = set() + # Create event forwarders similar to MarketsRecorder self._create_order_forwarder = SourceInfoEventForwarder(self._did_create_order) self._fill_order_forwarder = SourceInfoEventForwarder(self._did_fill_order) @@ -59,27 +62,28 @@ def start(self, connector: ConnectorBase): # Subscribe to order events using the same pattern as MarketsRecorder for event, forwarder in self._event_pairs: connector.add_listener(event, forwarder) - logger.info(f"OrdersRecorder: Added listener for {event} with forwarder {forwarder}") + logger.debug(f"OrdersRecorder: Added listener for {event} with forwarder {forwarder}") # Debug: Check if listeners were actually added - if hasattr(connector, '_event_listeners'): + if logger.isEnabledFor(logging.DEBUG) and hasattr(connector, '_event_listeners'): listeners = connector._event_listeners.get(event, []) - logger.info(f"OrdersRecorder: Event {event} now has {len(listeners)} listeners") + logger.debug(f"OrdersRecorder: Event {event} now has {len(listeners)} listeners") for i, listener in enumerate(listeners): - logger.info(f"OrdersRecorder: Listener {i}: {listener}") + logger.debug(f"OrdersRecorder: Listener {i}: {listener}") logger.info( f"OrdersRecorder started for {self.account_name}/{self.connector_name} with {len(self._event_pairs)} event listeners") # Debug: Print connector info - logger.info(f"OrdersRecorder: Connector type: {type(connector)}") - logger.info(f"OrdersRecorder: Connector name: {getattr(connector, 'name', 'unknown')}") - logger.info(f"OrdersRecorder: Connector ready: {getattr(connector, 'ready', 'unknown')}") + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"OrdersRecorder: Connector type: {type(connector)}") + logger.debug(f"OrdersRecorder: Connector name: {getattr(connector, 'name', 'unknown')}") + logger.debug(f"OrdersRecorder: Connector ready: {getattr(connector, 'ready', 'unknown')}") # Test if forwarders are callable for event, forwarder in self._event_pairs: if callable(forwarder): - logger.info(f"OrdersRecorder: Forwarder for {event} is callable") + logger.debug(f"OrdersRecorder: Forwarder for {event} is callable") else: logger.error(f"OrdersRecorder: Forwarder for {event} is NOT callable: {type(forwarder)}") @@ -90,8 +94,23 @@ async def stop(self): for event, forwarder in self._event_pairs: self._connector.remove_listener(event, forwarder) + # Wait for in-flight write tasks so no order/trade records are lost + if self._pending_tasks: + await asyncio.gather(*self._pending_tasks, return_exceptions=True) + logger.info(f"OrdersRecorder stopped for {self.account_name}/{self.connector_name}") + def _create_tracked_task(self, coro) -> asyncio.Task: + """Create a task and keep a strong reference to it until it completes. + + The event loop only keeps weak references to tasks, so without this a pending + task could be garbage-collected before finishing, dropping the DB write. + """ + task = asyncio.create_task(coro) + self._pending_tasks.add(task) + task.add_done_callback(self._pending_tasks.discard) + return task + def _extract_error_message(self, event) -> str: """Extract error message from various possible event attributes.""" # Try different possible attribute names for error messages @@ -107,47 +126,47 @@ def _extract_error_message(self, event) -> str: def _did_create_order(self, event_tag: int, market: ConnectorBase, event: Union[BuyOrderCreatedEvent, SellOrderCreatedEvent]): """Handle order creation events - called by SourceInfoEventForwarder""" - logger.info(f"OrdersRecorder: _did_create_order called for order {getattr(event, 'order_id', 'unknown')}") + logger.debug(f"OrdersRecorder: _did_create_order called for order {getattr(event, 'order_id', 'unknown')}") try: # Determine trade type from event trade_type = TradeType.BUY if isinstance(event, BuyOrderCreatedEvent) else TradeType.SELL - logger.info(f"OrdersRecorder: Creating task to handle order created - {trade_type} order") - asyncio.create_task(self._handle_order_created(event, trade_type)) + logger.debug(f"OrdersRecorder: Creating task to handle order created - {trade_type} order") + self._create_tracked_task(self._handle_order_created(event, trade_type)) except Exception as e: logger.error(f"Error in _did_create_order: {e}") def _did_fill_order(self, event_tag: int, market: ConnectorBase, event: OrderFilledEvent): """Handle order fill events - called by SourceInfoEventForwarder""" try: - asyncio.create_task(self._handle_order_filled(event)) + self._create_tracked_task(self._handle_order_filled(event)) except Exception as e: logger.error(f"Error in _did_fill_order: {e}") def _did_cancel_order(self, event_tag: int, market: ConnectorBase, event: Any): """Handle order cancel events - called by SourceInfoEventForwarder""" try: - asyncio.create_task(self._handle_order_cancelled(event)) + self._create_tracked_task(self._handle_order_cancelled(event)) except Exception as e: logger.error(f"Error in _did_cancel_order: {e}") def _did_fail_order(self, event_tag: int, market: ConnectorBase, event: Any): """Handle order failure events - called by SourceInfoEventForwarder""" try: - asyncio.create_task(self._handle_order_failed(event)) + self._create_tracked_task(self._handle_order_failed(event)) except Exception as e: logger.error(f"Error in _did_fail_order: {e}") def _did_complete_order(self, event_tag: int, market: ConnectorBase, event: Any): """Handle order completion events - called by SourceInfoEventForwarder""" try: - asyncio.create_task(self._handle_order_completed(event)) + self._create_tracked_task(self._handle_order_completed(event)) except Exception as e: logger.error(f"Error in _did_complete_order: {e}") async def _handle_order_created(self, event: Union[BuyOrderCreatedEvent, SellOrderCreatedEvent], trade_type: TradeType): """Handle order creation events""" - logger.info(f"OrdersRecorder: _handle_order_created started for order {event.order_id}") + logger.debug(f"OrdersRecorder: _handle_order_created started for order {event.order_id}") try: async with self.db_manager.get_session_context() as session: order_repo = OrderRepository(session) @@ -155,20 +174,20 @@ async def _handle_order_created(self, event: Union[BuyOrderCreatedEvent, SellOrd # Check if order already exists first existing_order = await order_repo.get_order_by_client_id(event.order_id) if existing_order: - logger.info( + logger.debug( f"OrdersRecorder: Order {event.order_id} already exists with status {existing_order.status}") # Update exchange_order_id if we have it now and it was missing exchange_order_id = getattr(event, 'exchange_order_id', None) if exchange_order_id and not existing_order.exchange_order_id: existing_order.exchange_order_id = exchange_order_id - logger.info( + logger.debug( f"OrdersRecorder: Updated exchange_order_id to {exchange_order_id} for order {event.order_id}") # Update status if it's still in PENDING_CREATE or similar early state if existing_order.status in ["PENDING_CREATE", "PENDING", "SUBMITTED"]: existing_order.status = "OPEN" - logger.info(f"OrdersRecorder: Updated status to OPEN for order {event.order_id}") + logger.debug(f"OrdersRecorder: Updated status to OPEN for order {event.order_id}") await session.flush() return @@ -187,7 +206,7 @@ async def _handle_order_created(self, event: Union[BuyOrderCreatedEvent, SellOrd } await order_repo.create_order(order_data) - logger.info(f"OrdersRecorder: Successfully recorded order created: {event.order_id}") + logger.debug(f"OrdersRecorder: Successfully recorded order created: {event.order_id}") except Exception as e: logger.error(f"OrdersRecorder: Error recording order created: {e}") @@ -466,7 +485,9 @@ async def _handle_order_completed(self, event: Any): order = await order_repo.get_order_by_client_id(event.order_id) if order: order.status = "FILLED" - order.exchange_order_id = getattr(event, 'exchange_order_id', None) + eoid = getattr(event, 'exchange_order_id', None) + if eoid: + order.exchange_order_id = eoid logger.debug(f"Recorded order completed: {event.order_id}") except Exception as e: diff --git a/services/perpetual_trading_service.py b/services/perpetual_trading_service.py new file mode 100644 index 00000000..b0b24485 --- /dev/null +++ b/services/perpetual_trading_service.py @@ -0,0 +1,199 @@ +import asyncio +import logging +from typing import Any, Awaitable, Callable, Dict, List + +from fastapi import HTTPException +from hummingbot.core.data_type.common import PositionMode + +# Create module-specific logger +logger = logging.getLogger(__name__) + + +class PerpetualTradingService: + """ + Perpetual-specific trading operations: leverage, position mode and position queries. + Connector instances are resolved through an injected provider so this service stays + decoupled from account/credential management. + """ + + def __init__(self, connector_provider: Callable[[str, str], Awaitable[Any]]): + """ + Initialize the PerpetualTradingService. + + Args: + connector_provider: Async callable (account_name, connector_name) -> connector instance. + Expected to raise HTTPException if the account or connector is not found. + """ + self._connector_provider = connector_provider + + async def _get_perpetual_connector(self, account_name: str, connector_name: str): + """ + Get a perpetual connector instance with validation. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Perpetual connector instance + + Raises: + HTTPException: If connector is not perpetual or not found + """ + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + return await self._connector_provider(account_name, connector_name) + + async def set_leverage(self, account_name: str, connector_name: str, + trading_pair: str, leverage: int) -> Dict[str, str]: + """ + Set leverage for a specific trading pair on a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + trading_pair: Trading pair to set leverage for + leverage: Leverage value (typically 1-125) + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, '_execute_set_leverage'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support leverage setting") + + try: + await connector._execute_set_leverage(trading_pair, leverage) + message = f"Leverage for {trading_pair} set to {leverage} on {connector_name}" + logger.info(f"Set leverage for {trading_pair} to {leverage} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set leverage for {trading_pair} to {leverage}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set leverage: {str(e)}") + + async def set_position_mode(self, account_name: str, connector_name: str, + position_mode: PositionMode) -> Dict[str, str]: + """ + Set position mode for a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + position_mode: PositionMode.HEDGE or PositionMode.ONEWAY + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + # Check if the requested position mode is supported + supported_modes = connector.supported_position_modes() + if position_mode not in supported_modes: + supported_values = [mode.value for mode in supported_modes] + raise HTTPException( + status_code=400, + detail=f"Position mode '{position_mode.value}' not supported. Supported modes: {supported_values}" + ) + + try: + # Try to call the method - it might be sync or async + result = connector.set_position_mode(position_mode) + # If it's a coroutine, await it + if asyncio.iscoroutine(result): + await result + + message = f"Position mode set to {position_mode.value} on {connector_name}" + logger.info(f"Set position mode to {position_mode.value} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set position mode to {position_mode.value}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set position mode: {str(e)}") + + async def get_position_mode(self, account_name: str, connector_name: str) -> Dict[str, str]: + """ + Get current position mode for a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Dictionary with current position mode + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, 'position_mode'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position mode") + + try: + current_mode = connector.position_mode + return { + "position_mode": current_mode.value if current_mode else "UNKNOWN", + "connector": connector_name, + "account": account_name + } + + except Exception as e: + logger.error(f"Failed to get position mode: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get position mode: {str(e)}") + + async def get_account_positions(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get current positions for a specific perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + List of position dictionaries + + Raises: + HTTPException: If account/connector not found or not perpetual + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, 'account_positions'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position tracking") + + try: + # Force position update to ensure current market prices are used + await connector._update_positions() + + positions = [] + raw_positions = connector.account_positions + + for trading_pair, position_info in raw_positions.items(): + # Convert position data to dict format + position_dict = { + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": position_info.trading_pair, + "side": position_info.position_side.name if hasattr(position_info, 'position_side') else "UNKNOWN", + "amount": float(position_info.amount) if hasattr(position_info, 'amount') else 0.0, + "entry_price": float(position_info.entry_price) if hasattr(position_info, 'entry_price') else None, + "unrealized_pnl": float(position_info.unrealized_pnl) if hasattr(position_info, 'unrealized_pnl') else None, + "leverage": float(position_info.leverage) if hasattr(position_info, 'leverage') else None, + } + + # Only include positions with non-zero amounts + if position_dict["amount"] != 0: + positions.append(position_dict) + + return positions + + except Exception as e: + logger.error(f"Failed to get positions for {connector_name}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get positions: {str(e)}") diff --git a/services/portfolio_analytics_service.py b/services/portfolio_analytics_service.py new file mode 100644 index 00000000..d41145b9 --- /dev/null +++ b/services/portfolio_analytics_service.py @@ -0,0 +1,201 @@ +import logging +from typing import Any, Dict, List, Optional + +# Create module-specific logger +logger = logging.getLogger(__name__) + + +class PortfolioAnalyticsService: + """ + Pure portfolio-distribution math over account state data. + + This service performs no IO: it has no database, gateway or connector dependencies. It operates on plain + account-state dictionaries shaped as {account_name: {connector_name: [token_info, ...]}} where each + token_info dict contains at least "token", "units" and "value" keys. Callers may pass a live dict; the + methods snapshot it before iterating so concurrent mutations cannot affect the calculation. + """ + + def get_portfolio_distribution(self, + accounts_state: Dict[str, Dict[str, List[Dict[str, Any]]]], + account_name: Optional[str] = None) -> Dict[str, Any]: + """ + Get portfolio distribution by tokens with percentages. + + Args: + accounts_state: Account state data shaped as {account_name: {connector_name: [token_info, ...]}} + account_name: Optional account name to filter by (None aggregates all accounts) + """ + try: + # Snapshot the live dict so concurrent mutations cannot affect the iteration + accounts_state_snapshot = {account: dict(connectors) for account, connectors in accounts_state.items()} + + # Get accounts to process + accounts_to_process = [account_name] if account_name else list(accounts_state_snapshot.keys()) + + # Aggregate all tokens across accounts and connectors + token_values = {} + total_value = 0 + + for acc_name in accounts_to_process: + if acc_name in accounts_state_snapshot: + for connector_name, connector_data in accounts_state_snapshot[acc_name].items(): + for token_info in connector_data: + token = token_info.get("token", "") + value = token_info.get("value", 0) + + if token not in token_values: + token_values[token] = { + "token": token, + "total_value": 0, + "total_units": 0, + "accounts": {} + } + + token_values[token]["total_value"] += value + token_values[token]["total_units"] += token_info.get("units", 0) + total_value += value + + # Track by account + if acc_name not in token_values[token]["accounts"]: + token_values[token]["accounts"][acc_name] = { + "value": 0, + "units": 0, + "connectors": {} + } + + token_values[token]["accounts"][acc_name]["value"] += value + token_values[token]["accounts"][acc_name]["units"] += token_info.get("units", 0) + + # Track by connector within account + if connector_name not in token_values[token]["accounts"][acc_name]["connectors"]: + token_values[token]["accounts"][acc_name]["connectors"][connector_name] = { + "value": 0, + "units": 0 + } + + connector_totals = token_values[token]["accounts"][acc_name]["connectors"][connector_name] + connector_totals["value"] += value + connector_totals["units"] += token_info.get("units", 0) + + # Calculate percentages + distribution = [] + for token_data in token_values.values(): + percentage = (token_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + token_dist = { + "token": token_data["token"], + "total_value": round(token_data["total_value"], 6), + "total_units": token_data["total_units"], + "percentage": round(percentage, 4), + "accounts": {} + } + + # Add account-level percentages + for acc_name, acc_data in token_data["accounts"].items(): + acc_percentage = (acc_data["value"] / total_value * 100) if total_value > 0 else 0 + token_dist["accounts"][acc_name] = { + "value": round(acc_data["value"], 6), + "units": acc_data["units"], + "percentage": round(acc_percentage, 4), + "connectors": {} + } + + # Add connector-level data + for conn_name, conn_data in acc_data["connectors"].items(): + token_dist["accounts"][acc_name]["connectors"][conn_name] = { + "value": round(conn_data["value"], 6), + "units": conn_data["units"] + } + + distribution.append(token_dist) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "token_count": len(distribution), + "distribution": distribution, + "account_filter": account_name if account_name else "all_accounts" + } + + except Exception as e: + logger.error(f"Error calculating portfolio distribution: {e}") + return { + "total_portfolio_value": 0, + "token_count": 0, + "distribution": [], + "account_filter": account_name if account_name else "all_accounts", + "error": str(e) + } + + def get_account_distribution(self, accounts_state: Dict[str, Dict[str, List[Dict[str, Any]]]]) -> Dict[str, Any]: + """ + Get portfolio distribution by accounts with percentages. + + Args: + accounts_state: Account state data shaped as {account_name: {connector_name: [token_info, ...]}} + """ + try: + # Snapshot the live dict so concurrent mutations cannot affect the iteration + accounts_state_snapshot = {account: dict(connectors) for account, connectors in accounts_state.items()} + + account_values = {} + total_value = 0 + + for acc_name, account_data in accounts_state_snapshot.items(): + account_value = 0 + connector_values = {} + + for connector_name, connector_data in account_data.items(): + connector_value = 0 + for token_info in connector_data: + value = token_info.get("value", 0) + connector_value += value + account_value += value + + connector_values[connector_name] = round(connector_value, 6) + + account_values[acc_name] = { + "total_value": round(account_value, 6), + "connectors": connector_values + } + total_value += account_value + + # Calculate percentages + distribution = [] + for acc_name, acc_data in account_values.items(): + percentage = (acc_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + connector_dist = {} + for conn_name, conn_value in acc_data["connectors"].items(): + conn_percentage = (conn_value / total_value * 100) if total_value > 0 else 0 + connector_dist[conn_name] = { + "value": conn_value, + "percentage": round(conn_percentage, 4) + } + + distribution.append({ + "account": acc_name, + "total_value": acc_data["total_value"], + "percentage": round(percentage, 4), + "connectors": connector_dist + }) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "account_count": len(distribution), + "distribution": distribution + } + + except Exception as e: + logger.error(f"Error calculating account distribution: {e}") + return { + "total_portfolio_value": 0, + "account_count": 0, + "distribution": [], + "error": str(e) + } diff --git a/services/trading_history_service.py b/services/trading_history_service.py new file mode 100644 index 00000000..96f07d09 --- /dev/null +++ b/services/trading_history_service.py @@ -0,0 +1,188 @@ +""" +TradingHistoryService provides read-only access to persisted trading history +(orders, trades and funding payments). + +This concern was extracted out of the AccountsService god-class: AccountsService +stays focused on account/credential/balance state, while the database read +wrappers for orders/trades/funding live here behind a single session+error +helper (``_run_in_repo``). +""" +import logging +from typing import Dict, List, Optional + +from database import AsyncDatabaseManager, FundingRepository, OrderRepository, TradeRepository + +logger = logging.getLogger(__name__) + + +class TradingHistoryService: + """Read-only queries over persisted orders, trades and funding payments.""" + + def __init__(self, db_manager: AsyncDatabaseManager): + """ + Initialize the TradingHistoryService. + + Args: + db_manager: AsyncDatabaseManager for persistence (shared, created once at startup) + """ + self.db_manager = db_manager + + async def _run_in_repo(self, repo_cls, fn, default, error_message): + """Run ``fn`` against a freshly constructed repository inside a session. + + Collapses the repeated ``get_session_context + try/except`` scaffold: a + new session is opened, ``repo_cls(session)`` is built and passed to + ``fn`` (which performs the read and any to_dict conversion). On any + exception the error is logged and ``default`` is returned. + + Args: + repo_cls: Repository class to instantiate with the session. + fn: Async callable receiving the repository instance. + default: Value returned (defaults-on-error) if ``fn`` raises. May be + a callable that receives the raised exception and returns the + default value (used when the default embeds the error). + error_message: Prefix used when logging the exception. + + Returns: + The result of ``fn`` or ``default`` on error. + """ + try: + async with self.db_manager.get_session_context() as session: + return await fn(repo_cls(session)) + except Exception as e: + logger.error(f"{error_message}: {e}") + return default(e) if callable(default) else default + + async def get_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, status: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get order history using OrderRepository.""" + async def _fn(order_repo): + orders = await order_repo.get_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + status=status, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + return [order_repo.to_dict(order) for order in orders] + + return await self._run_in_repo(OrderRepository, _fn, [], "Error getting orders") + + async def get_active_orders_history(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None) -> List[Dict]: + """Get active orders from database using OrderRepository.""" + async def _fn(order_repo): + orders = await order_repo.get_active_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + return [order_repo.to_dict(order) for order in orders] + + return await self._run_in_repo(OrderRepository, _fn, [], "Error getting active orders") + + async def get_orders_summary(self, account_name: Optional[str] = None, start_time: Optional[int] = None, + end_time: Optional[int] = None) -> Dict: + """Get order summary statistics using OrderRepository.""" + async def _fn(order_repo): + return await order_repo.get_orders_summary( + account_name=account_name, + start_time=start_time, + end_time=end_time + ) + + return await self._run_in_repo( + OrderRepository, + _fn, + { + "total_orders": 0, + "filled_orders": 0, + "cancelled_orders": 0, + "failed_orders": 0, + "active_orders": 0, + "fill_rate": 0, + }, + "Error getting orders summary", + ) + + async def get_trades(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, trade_type: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get trade history using TradeRepository.""" + async def _fn(trade_repo): + trade_order_pairs = await trade_repo.get_trades_with_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + trade_type=trade_type, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + return [trade_repo.to_dict(trade, order) for trade, order in trade_order_pairs] + + return await self._run_in_repo(TradeRepository, _fn, [], "Error getting trades") + + async def get_funding_payments(self, account_name: str, connector_name: str = None, + trading_pair: str = None, limit: int = 100) -> List[Dict]: + """ + Get funding payment history for an account. + + Args: + account_name: Name of the account + connector_name: Optional connector name filter + trading_pair: Optional trading pair filter + limit: Maximum number of records to return + + Returns: + List of funding payment dictionaries + """ + async def _fn(funding_repo): + funding_payments = await funding_repo.get_funding_payments( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + limit=limit + ) + return [funding_repo.to_dict(payment) for payment in funding_payments] + + return await self._run_in_repo(FundingRepository, _fn, [], "Error getting funding payments") + + async def get_total_funding_fees(self, account_name: str, connector_name: str, + trading_pair: str) -> Dict: + """ + Get total funding fees for a specific trading pair. + + Args: + account_name: Name of the account + connector_name: Name of the connector + trading_pair: Trading pair to get fees for + + Returns: + Dictionary with total funding fees information + """ + async def _fn(funding_repo): + return await funding_repo.get_total_funding_fees( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + + return await self._run_in_repo( + FundingRepository, + _fn, + lambda e: { + "total_funding_fees": 0, + "payment_count": 0, + "fee_currency": None, + "error": str(e), + }, + "Error getting total funding fees", + ) diff --git a/services/trading_service.py b/services/trading_service.py index 4e648725..11041d61 100644 --- a/services/trading_service.py +++ b/services/trading_service.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set from hummingbot.connector.connector_base import ConnectorBase -from hummingbot.core.data_type.common import OrderType, PositionAction, TradeType +from hummingbot.core.data_type.common import OrderType, PositionAction if TYPE_CHECKING: from services.market_data_service import MarketDataService @@ -395,10 +395,7 @@ class TradingService: """ Centralized trading service using UnifiedConnectorService. - This service manages: - - Trading interfaces for each account (executor-compatible) - - Order placement and cancellation - - Position management for perpetuals + This service manages trading interfaces for each account (executor-compatible). """ def __init__( @@ -448,165 +445,6 @@ def get_all_trading_interfaces(self) -> Dict[str, AccountTradingInterface]: """Get all active trading interfaces.""" return self._trading_interfaces.copy() - # ==================== Direct Trading Operations ==================== - - async def place_order( - self, - account_name: str, - connector_name: str, - trading_pair: str, - trade_type: TradeType, - amount: Decimal, - order_type: OrderType, - price: Optional[Decimal] = None, - position_action: PositionAction = PositionAction.NIL - ) -> str: - """ - Place an order on an exchange. - - Args: - account_name: Account to use - connector_name: Exchange connector name - trading_pair: Trading pair - trade_type: BUY or SELL - amount: Order amount - order_type: LIMIT, MARKET, etc. - price: Order price (required for LIMIT orders) - position_action: Position action for perpetuals - - Returns: - Client order ID - """ - interface = self.get_trading_interface(account_name) - await interface.ensure_connector(connector_name) - - if trade_type == TradeType.BUY: - return interface.buy( - connector_name=connector_name, - trading_pair=trading_pair, - amount=amount, - order_type=order_type, - price=price if price else Decimal("NaN"), - position_action=position_action - ) - else: - return interface.sell( - connector_name=connector_name, - trading_pair=trading_pair, - amount=amount, - order_type=order_type, - price=price if price else Decimal("NaN"), - position_action=position_action - ) - - async def cancel_order( - self, - account_name: str, - connector_name: str, - trading_pair: str, - order_id: str - ) -> str: - """ - Cancel an order. - - Args: - account_name: Account name - connector_name: Exchange connector name - trading_pair: Trading pair - order_id: Client order ID to cancel - - Returns: - Client order ID that was cancelled - """ - interface = self.get_trading_interface(account_name) - return interface.cancel(connector_name, trading_pair, order_id) - - def get_active_orders( - self, - account_name: str, - connector_name: str - ) -> List: - """ - Get active orders for an account/connector. - - Args: - account_name: Account name - connector_name: Exchange connector name - - Returns: - List of active orders - """ - interface = self.get_trading_interface(account_name) - return interface.get_active_orders(connector_name) - - # ==================== Position Management ==================== - - async def get_positions( - self, - account_name: str, - connector_name: str - ) -> Dict: - """ - Get positions for a perpetual connector. - - Args: - account_name: Account name - connector_name: Exchange connector name - - Returns: - Dictionary of positions - """ - connector = await self._connector_service.get_trading_connector( - account_name, connector_name - ) - - if hasattr(connector, 'account_positions'): - return { - str(pos.trading_pair): { - "trading_pair": pos.trading_pair, - "position_side": pos.position_side.name, - "unrealized_pnl": float(pos.unrealized_pnl), - "entry_price": float(pos.entry_price), - "amount": float(pos.amount), - "leverage": pos.leverage - } - for pos in connector.account_positions.values() - } - return {} - - async def set_leverage( - self, - account_name: str, - connector_name: str, - trading_pair: str, - leverage: int - ) -> bool: - """ - Set leverage for a trading pair on a perpetual connector. - - Args: - account_name: Account name - connector_name: Exchange connector name - trading_pair: Trading pair - leverage: Leverage value - - Returns: - True if successful - """ - connector = await self._connector_service.get_trading_connector( - account_name, connector_name - ) - - if hasattr(connector, 'set_leverage'): - try: - await connector.set_leverage(trading_pair, leverage) - logger.info(f"Set leverage to {leverage}x for {trading_pair} on {connector_name}") - return True - except Exception as e: - logger.error(f"Error setting leverage: {e}") - return False - return False - # ==================== Lifecycle ==================== async def stop(self): diff --git a/services/unified_connector_service.py b/services/unified_connector_service.py index b72979b2..2a54a7c8 100644 --- a/services/unified_connector_service.py +++ b/services/unified_connector_service.py @@ -14,7 +14,7 @@ import logging import time from decimal import Decimal -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger from hummingbot.client.config.config_helpers import ClientConfigAdapter, api_keys_from_connector_config_map, get_connector_class @@ -64,8 +64,8 @@ def __init__(self, secrets_manager: ETHKeyFileSecretManger, db_manager=None): self._data_connectors_started: Dict[str, bool] = {} # Order and funding recorders (for trading connectors) - self._orders_recorders: Dict[str, any] = {} - self._funding_recorders: Dict[str, any] = {} + self._orders_recorders: Dict[str, Any] = {} + self._funding_recorders: Dict[str, Any] = {} self._metrics_collectors: Dict[str, TradeVolumeMetricCollector] = {} # Locks to prevent race conditions in connector creation @@ -768,6 +768,19 @@ async def _stop_connector_network(self, connector: ConnectorBase): if hasattr(connector, 'stop_network'): await connector.stop_network() + async def refresh_connector_state( + self, + connector: ConnectorBase, + connector_name: str, + account_name: str = None + ): + """Public API to refresh a single connector's state (balances, positions, orders). + + Delegates to the internal _update_connector_state implementation so callers + in sibling services don't depend on the underscore-prefixed helper. + """ + await self._update_connector_state(connector, connector_name, account_name) + async def _update_connector_state( self, connector: ConnectorBase, @@ -886,30 +899,37 @@ async def _sync_orders_to_database( if not self.db_manager: return + from database import OrderRepository + terminal_states = [ OrderState.FILLED, OrderState.CANCELED, OrderState.FAILED, OrderState.COMPLETED ] orders_to_remove = [] - for client_order_id, order in list(connector.in_flight_orders.items()): - try: - from database import OrderRepository + try: + # Single session/transaction per connector: one SELECT per order and one commit on context exit. + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) - async with self.db_manager.get_session_context() as session: - order_repo = OrderRepository(session) - db_order = await order_repo.get_order_by_client_id(client_order_id) + for client_order_id, order in list(connector.in_flight_orders.items()): + try: + db_order = await order_repo.get_order_by_client_id(client_order_id) - if db_order: - new_status = self._map_order_state_to_status(order.current_state) - if db_order.status != new_status: - await order_repo.update_order_status(client_order_id, new_status) + if db_order: + new_status = self._map_order_state_to_status(order.current_state) + if db_order.status != new_status: + db_order.status = new_status + await session.flush() - if order.current_state in terminal_states: - orders_to_remove.append(client_order_id) + if order.current_state in terminal_states: + orders_to_remove.append(client_order_id) - except Exception as e: - logger.error(f"Error syncing order {client_order_id}: {e}") + except Exception as e: + logger.error(f"Error syncing order {client_order_id}: {e}") + + except Exception as e: + logger.error(f"Error syncing orders for {account_name}/{connector_name}: {e}") for order_id in orders_to_remove: connector.in_flight_orders.pop(order_id, None) @@ -962,46 +982,53 @@ async def reconcile_active_orders(self) -> Dict[str, int]: # Snapshot tracked orders (the set was loaded from the DB at init). tracked_orders = list(connector.in_flight_orders.values()) - for order in tracked_orders: - client_order_id = order.client_order_id - note = None - try: - order_update = await connector._request_order_status(order) - new_state = order_update.new_state - except Exception as exc: - if connector._is_order_not_found_during_status_update_error(exc): - # The exchange does not know this order -> it is gone. - new_state = OrderState.CANCELED - note = "Reconciled on startup: order not found on exchange" - else: - # Transient/unknown error - do not touch the order. - logger.warning( - f"Could not verify order {client_order_id} on " - f"{account_name}/{connector_name}: {exc}" - ) + # Single session/transaction per connector: every reconciled status update is + # flushed into one shared session and committed once on context exit. Each + # order's write runs inside its own savepoint so a SQLAlchemy error on one + # order is rolled back in isolation and does not poison the rest. + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + for order in tracked_orders: + client_order_id = order.client_order_id + note = None + try: + order_update = await connector._request_order_status(order) + new_state = order_update.new_state + except Exception as exc: + if connector._is_order_not_found_during_status_update_error(exc): + # The exchange does not know this order -> it is gone. + new_state = OrderState.CANCELED + note = "Reconciled on startup: order not found on exchange" + else: + # Transient/unknown error - do not touch the order. + logger.warning( + f"Could not verify order {client_order_id} on " + f"{account_name}/{connector_name}: {exc}" + ) + summary["unverified"] += 1 + continue + + db_status = self._map_order_state_to_status(new_state) + try: + async with session.begin_nested(): + await order_repo.update_order_status( + client_order_id=client_order_id, + status=db_status, + error_message=note, + ) + except Exception as exc: + # Savepoint rolled back: this order failed to persist but the + # session stays usable for the remaining orders. + logger.error(f"Failed to persist reconciled order {client_order_id}: {exc}") summary["unverified"] += 1 continue - db_status = self._map_order_state_to_status(new_state) - try: - async with self.db_manager.get_session_context() as session: - order_repo = OrderRepository(session) - await order_repo.update_order_status( - client_order_id=client_order_id, - status=db_status, - error_message=note, - ) - except Exception as exc: - logger.error(f"Failed to persist reconciled order {client_order_id}: {exc}") - summary["unverified"] += 1 - continue - - if new_state in terminal_states: - connector.in_flight_orders.pop(client_order_id, None) - summary["reconciled_terminal"] += 1 - else: - # Keep tracking so it stays cancelable via the trading endpoints. - summary["still_open"] += 1 + if new_state in terminal_states: + connector.in_flight_orders.pop(client_order_id, None) + summary["reconciled_terminal"] += 1 + else: + # Keep tracking so it stays cancelable via the trading endpoints. + summary["still_open"] += 1 logger.info( "Order reconciliation complete: " @@ -1018,15 +1045,21 @@ async def sync_all_orders_to_database(self): The connector's built-in polling already updates in_flight_orders from the exchange. This method syncs that state to our database and cleans up closed orders. """ + tasks = [] + task_keys = [] for account_name, connectors in self._trading_connectors.items(): for connector_name, connector in connectors.items(): - try: - if not connector.in_flight_orders: - continue - await self._sync_orders_to_database(connector, account_name, connector_name) - logger.debug(f"Synced order state to DB for {account_name}/{connector_name}") - except Exception as e: - logger.error(f"Error syncing order state for {account_name}/{connector_name}: {e}") + if not connector.in_flight_orders: + continue + tasks.append(self._sync_orders_to_database(connector, account_name, connector_name)) + task_keys.append(f"{account_name}/{connector_name}") + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + for key, result in zip(task_keys, results): + if isinstance(result, Exception): + logger.error(f"Error syncing order state for {key}: {result}") + else: + logger.debug(f"Synced order state to DB for {key}") def _convert_db_order_to_in_flight(self, order_record) -> InFlightOrder: """Convert database order to InFlightOrder.""" diff --git a/services/websocket_manager.py b/services/websocket_manager.py index f7ef806b..eedb36e2 100644 --- a/services/websocket_manager.py +++ b/services/websocket_manager.py @@ -5,11 +5,11 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional +from fastapi import WebSocket +from fastapi.websockets import WebSocketDisconnect from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder from hummingbot.core.event.events import OrderBookEvent, OrderBookTradeEvent from hummingbot.data_feed.candles_feed.data_types import CandlesConfig -from fastapi import WebSocket -from fastapi.websockets import WebSocketDisconnect from config import settings from services.market_data_service import MarketDataService @@ -101,17 +101,8 @@ async def handle_subscribe(self, conn_id: str, websocket: WebSocket, msg: dict): if sub_id in subs: self._cleanup_subscription(subs.pop(sub_id)) - # Validate trading pair exists before starting feed - try: - if sub_type == "candles": - await self._market_data_service.validate_trading_pair( - connector, trading_pair, sub.interval or "1m" - ) - except ValueError as e: - await self._send_error(websocket, str(e)) - return - - # Start the feed / ensure it exists + # Start the feed / ensure it exists. For candles, creating the feed also validates + # the trading pair on first use (cache hit afterwards); an invalid pair raises ValueError. try: if sub_type == "candles": config = CandlesConfig( @@ -120,10 +111,13 @@ async def handle_subscribe(self, conn_id: str, websocket: WebSocket, msg: dict): interval=sub.interval, max_records=sub.max_records, ) - self._market_data_service.get_candles_feed(config) + await self._market_data_service.get_candles_feed(config) else: # Both order_book and trades need the order book initialized await self._market_data_service.initialize_order_book(connector, trading_pair) + except ValueError as e: + await self._send_error(websocket, str(e)) + return except Exception as e: await self._send_error(websocket, f"Failed to start feed: {e}") return @@ -189,7 +183,7 @@ async def _candles_push_loop(self, websocket: WebSocket, sub: Subscription): while True: await asyncio.sleep(sub.update_interval) try: - feed = self._market_data_service.get_candles_feed(config) + feed = await self._market_data_service.get_candles_feed(config) if not feed.ready: continue df = feed.candles_df diff --git a/test/test_cors_settings.py b/test/test_cors_settings.py new file mode 100644 index 00000000..6c62ebff --- /dev/null +++ b/test/test_cors_settings.py @@ -0,0 +1,86 @@ +""" +Tests for the CORS configuration (SEC-019). + +Run with: pytest test/test_cors_settings.py -v +""" +import pytest + +from config import CORSSettings + + +def _build_client(cors: CORSSettings): + """Build a minimal app with CORSMiddleware wired exactly like main.py does.""" + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.testclient import TestClient + + app = FastAPI() + app.add_middleware( + CORSMiddleware, + allow_origins=cors.allow_origins, + allow_origin_regex=cors.allow_origin_regex or None, + allow_credentials=cors.allow_credentials, + allow_methods=cors.allow_methods, + allow_headers=cors.allow_headers, + ) + + @app.get("/") + async def root(): + return {"status": "running"} + + return TestClient(app) + + +class TestCORSSettings: + """Tests for CORSSettings defaults and env-driven configuration.""" + + def test_default_origins_are_not_wildcard_with_credentials(self): + cors = CORSSettings() + assert cors.allow_credentials is True + assert "*" not in cors.allow_origins + assert cors.allow_origin_regex != ".*" + + def test_origins_configurable_via_environment(self, monkeypatch): + monkeypatch.setenv("CORS_ALLOW_ORIGINS", '["https://dashboard.example.com"]') + monkeypatch.setenv("CORS_ALLOW_ORIGIN_REGEX", "") + cors = CORSSettings() + assert cors.allow_origins == ["https://dashboard.example.com"] + assert cors.allow_origin_regex == "" + + +class TestCORSMiddlewareBehavior: + """Tests that the middleware (configured as in main.py) rejects untrusted origins.""" + + def test_default_allows_localhost_origins(self): + client = _build_client(CORSSettings()) + for origin in ("http://localhost:3000", "http://127.0.0.1:8501"): + response = client.get("/", headers={"Origin": origin}) + assert response.headers.get("access-control-allow-origin") == origin + + def test_default_rejects_untrusted_origin(self): + client = _build_client(CORSSettings()) + response = client.get("/", headers={"Origin": "https://evil.example.com"}) + assert "access-control-allow-origin" not in response.headers + + preflight = client.options( + "/", + headers={"Origin": "https://evil.example.com", "Access-Control-Request-Method": "GET"}, + ) + assert preflight.status_code == 400 + assert "access-control-allow-origin" not in preflight.headers + + def test_explicit_origin_list_from_env(self, monkeypatch): + monkeypatch.setenv("CORS_ALLOW_ORIGINS", '["https://dashboard.example.com"]') + monkeypatch.setenv("CORS_ALLOW_ORIGIN_REGEX", "") + client = _build_client(CORSSettings()) + + allowed = client.get("/", headers={"Origin": "https://dashboard.example.com"}) + assert allowed.headers.get("access-control-allow-origin") == "https://dashboard.example.com" + assert allowed.headers.get("access-control-allow-credentials") == "true" + + rejected = client.get("/", headers={"Origin": "http://localhost:3000"}) + assert "access-control-allow-origin" not in rejected.headers + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_portfolio_analytics.py b/test/test_portfolio_analytics.py new file mode 100644 index 00000000..1ed13297 --- /dev/null +++ b/test/test_portfolio_analytics.py @@ -0,0 +1,174 @@ +""" +Tests for PortfolioAnalyticsService pure portfolio-distribution math. + +Run with: pytest test/test_portfolio_analytics.py -v +""" +import pytest + +from services.portfolio_analytics_service import PortfolioAnalyticsService + + +@pytest.fixture +def analytics(): + return PortfolioAnalyticsService() + + +@pytest.fixture +def accounts_state(): + """Plain dict fixture shaped like AccountsService.accounts_state.""" + return { + "master_account": { + "binance": [ + {"token": "BTC", "units": 0.5, "price": 50000.0, "value": 25000.0, "available_units": 0.5}, + {"token": "USDT", "units": 5000.0, "price": 1.0, "value": 5000.0, "available_units": 5000.0}, + ], + "kraken": [ + {"token": "BTC", "units": 0.1, "price": 50000.0, "value": 5000.0, "available_units": 0.1}, + ], + }, + "sub_account": { + "binance": [ + {"token": "ETH", "units": 5.0, "price": 3000.0, "value": 15000.0, "available_units": 5.0}, + ], + }, + } + + +class TestPortfolioDistribution: + def test_total_value_and_token_count(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state) + + assert result["total_portfolio_value"] == 50000.0 + assert result["token_count"] == 3 + assert result["account_filter"] == "all_accounts" + assert "error" not in result + + def test_response_shape(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state) + + assert set(result.keys()) == {"total_portfolio_value", "token_count", "distribution", "account_filter"} + token_dist = result["distribution"][0] + assert set(token_dist.keys()) == {"token", "total_value", "total_units", "percentage", "accounts"} + account_entry = next(iter(token_dist["accounts"].values())) + assert set(account_entry.keys()) == {"value", "units", "percentage", "connectors"} + connector_entry = next(iter(account_entry["connectors"].values())) + assert set(connector_entry.keys()) == {"value", "units"} + + def test_token_percentages(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state) + by_token = {d["token"]: d for d in result["distribution"]} + + # BTC: 25000 (binance) + 5000 (kraken) = 30000 -> 60% + assert by_token["BTC"]["total_value"] == 30000.0 + assert by_token["BTC"]["total_units"] == 0.6 + assert by_token["BTC"]["percentage"] == 60.0 + # ETH: 15000 -> 30% + assert by_token["ETH"]["percentage"] == 30.0 + # USDT: 5000 -> 10% + assert by_token["USDT"]["percentage"] == 10.0 + + def test_account_and_connector_breakdown(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state) + btc = next(d for d in result["distribution"] if d["token"] == "BTC") + + master = btc["accounts"]["master_account"] + assert master["value"] == 30000.0 + assert master["units"] == 0.6 + assert master["percentage"] == 60.0 + assert master["connectors"]["binance"] == {"value": 25000.0, "units": 0.5} + assert master["connectors"]["kraken"] == {"value": 5000.0, "units": 0.1} + + def test_sorted_by_value_descending(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state) + values = [d["total_value"] for d in result["distribution"]] + + assert values == sorted(values, reverse=True) + assert [d["token"] for d in result["distribution"]] == ["BTC", "ETH", "USDT"] + + def test_account_filter(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state, "sub_account") + + assert result["account_filter"] == "sub_account" + assert result["total_portfolio_value"] == 15000.0 + assert result["token_count"] == 1 + assert result["distribution"][0]["token"] == "ETH" + assert result["distribution"][0]["percentage"] == 100.0 + + def test_unknown_account_filter_returns_empty(self, analytics, accounts_state): + result = analytics.get_portfolio_distribution(accounts_state, "missing_account") + + assert result["total_portfolio_value"] == 0 + assert result["token_count"] == 0 + assert result["distribution"] == [] + assert result["account_filter"] == "missing_account" + + def test_empty_state(self, analytics): + result = analytics.get_portfolio_distribution({}) + + assert result["total_portfolio_value"] == 0 + assert result["token_count"] == 0 + assert result["distribution"] == [] + assert "error" not in result + + def test_zero_total_value_has_zero_percentages(self, analytics): + state = {"acc": {"conn": [{"token": "XYZ", "units": 1.0, "price": 0.0, "value": 0.0}]}} + result = analytics.get_portfolio_distribution(state) + + assert result["total_portfolio_value"] == 0 + assert result["distribution"][0]["percentage"] == 0 + + def test_error_path_returns_error_shape(self, analytics): + result = analytics.get_portfolio_distribution(None) + + assert result["total_portfolio_value"] == 0 + assert result["token_count"] == 0 + assert result["distribution"] == [] + assert result["account_filter"] == "all_accounts" + assert "error" in result + + +class TestAccountDistribution: + def test_totals_and_percentages(self, analytics, accounts_state): + result = analytics.get_account_distribution(accounts_state) + + assert result["total_portfolio_value"] == 50000.0 + assert result["account_count"] == 2 + by_account = {d["account"]: d for d in result["distribution"]} + assert by_account["master_account"]["total_value"] == 35000.0 + assert by_account["master_account"]["percentage"] == 70.0 + assert by_account["sub_account"]["total_value"] == 15000.0 + assert by_account["sub_account"]["percentage"] == 30.0 + + def test_connector_percentages_relative_to_total(self, analytics, accounts_state): + result = analytics.get_account_distribution(accounts_state) + master = next(d for d in result["distribution"] if d["account"] == "master_account") + + assert master["connectors"]["binance"] == {"value": 30000.0, "percentage": 60.0} + assert master["connectors"]["kraken"] == {"value": 5000.0, "percentage": 10.0} + + def test_response_shape(self, analytics, accounts_state): + result = analytics.get_account_distribution(accounts_state) + + assert set(result.keys()) == {"total_portfolio_value", "account_count", "distribution"} + entry = result["distribution"][0] + assert set(entry.keys()) == {"account", "total_value", "percentage", "connectors"} + connector_entry = next(iter(entry["connectors"].values())) + assert set(connector_entry.keys()) == {"value", "percentage"} + + def test_sorted_by_value_descending(self, analytics, accounts_state): + result = analytics.get_account_distribution(accounts_state) + + assert [d["account"] for d in result["distribution"]] == ["master_account", "sub_account"] + + def test_empty_state(self, analytics): + result = analytics.get_account_distribution({}) + + assert result == {"total_portfolio_value": 0, "account_count": 0, "distribution": []} + + def test_error_path_returns_error_shape(self, analytics): + result = analytics.get_account_distribution(None) + + assert result["total_portfolio_value"] == 0 + assert result["account_count"] == 0 + assert result["distribution"] == [] + assert "error" in result diff --git a/utils/bot_archiver.py b/utils/bot_archiver.py index 9d7c7183..b32c3b79 100644 --- a/utils/bot_archiver.py +++ b/utils/bot_archiver.py @@ -1,9 +1,12 @@ +import logging import os import shutil import boto3 from botocore.exceptions import NoCredentialsError +logger = logging.getLogger(__name__) + class BotArchiver: def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, default_bucket_name=None): @@ -28,16 +31,16 @@ def archive_and_upload(self, instance_name, instance_dir, bucket_name=None): try: self.s3.upload_file(archive_path, bucket_name, archive_name) - print(f"Archive {archive_name} uploaded successfully to S3.") + logger.info(f"Archive {archive_name} uploaded successfully to S3.") os.remove(archive_path) # Remove the local archive file shutil.rmtree(instance_dir) # Remove the instance directory except NoCredentialsError: - print("Credentials not available for AWS S3.") + logger.error("Credentials not available for AWS S3.") @staticmethod def compress_directory(source_dir, output_path): shutil.make_archive(output_path.replace('.tar.gz', ''), 'gztar', source_dir) - print(f"Compressed {source_dir} into {output_path}") + logger.info(f"Compressed {source_dir} into {output_path}") def archive_locally(self, instance_name, instance_dir, compress=False): if compress: diff --git a/utils/file_system.py b/utils/file_system.py index 2b555202..b62737e1 100644 --- a/utils/file_system.py +++ b/utils/file_system.py @@ -56,9 +56,12 @@ def list_files(self, directory: str) -> List[str]: Lists all files in a given directory. :param directory: The directory to list files from. :return: List of file names in the directory. + :raises ValueError: If the directory contains '..' path components. :raises FileNotFoundError: If the directory does not exist. :raises PermissionError: If access is denied to the directory. """ + if any(part == ".." for part in directory.replace("\\", "/").split("/")): + raise ValueError(f"Invalid directory: '{directory}'") excluded_files = ["__init__.py", "__pycache__", ".DS_Store", ".dockerignore", ".gitignore"] dir_path = self._get_full_path(directory) if not os.path.exists(dir_path): @@ -140,9 +143,12 @@ def delete_folder(self, directory: str, folder_name: str) -> None: Deletes a folder in a specified directory. :param directory: The directory to delete the folder from. :param folder_name: The name of the folder to be deleted. + :raises ValueError: If folder_name is empty, contains path separators or is a '.'/'..' component. :raises FileNotFoundError: If folder doesn't exist. :raises PermissionError: If permission is denied. """ + if not folder_name or folder_name in (".", "..") or '/' in folder_name or '\\' in folder_name: + raise ValueError(f"Invalid folder name: '{folder_name}'") folder_path = self._get_full_path(os.path.join(directory, folder_name)) if not os.path.exists(folder_path): raise FileNotFoundError(f"Folder '{folder_name}' not found in '{directory}'") @@ -415,21 +421,36 @@ def list_directories(self, path): except Exception: return [] - def delete_archived_bot(self, db_path: str) -> str: + def get_archived_db_path(self, db_path: str) -> str: """ - Deletes an archived bot directory given a database file path. + Resolves a database path and validates that it is contained within the archived bots directory. :param db_path: Path to a database file (as returned by list_databases, e.g. bots/archived/{instance}/data/file.sqlite) - :return: The name of the deleted archived bot directory. - :raises FileNotFoundError: If the path or archived directory doesn't exist. - :raises ValueError: If the path doesn't point to a valid archived bot. + :return: The resolved absolute path to the database file. + :raises ValueError: If the resolved path escapes the archived bots directory. + :raises FileNotFoundError: If the database file does not exist. """ # list_databases returns paths that already include base_path prefix (e.g. bots/archived/...) # Strip it to avoid double-prefixing when _get_full_path adds it again prefix = self.base_path + os.sep normalized = db_path[len(prefix):] if db_path.startswith(prefix) else db_path full_path = normalized if os.path.isabs(normalized) else self._get_full_path(normalized) - if not os.path.exists(full_path): + archived_root = os.path.realpath(self._get_full_path("archived")) + resolved_path = os.path.realpath(full_path) + if os.path.commonpath([archived_root, resolved_path]) != archived_root: + raise ValueError(f"Path '{db_path}' is outside the archived bots directory") + if not os.path.isfile(resolved_path): raise FileNotFoundError(f"Database path '{db_path}' not found") + return resolved_path + + def delete_archived_bot(self, db_path: str) -> str: + """ + Deletes an archived bot directory given a database file path. + :param db_path: Path to a database file (as returned by list_databases, e.g. bots/archived/{instance}/data/file.sqlite) + :return: The name of the deleted archived bot directory. + :raises FileNotFoundError: If the path or archived directory doesn't exist. + :raises ValueError: If the path doesn't point to a valid archived bot. + """ + full_path = self.get_archived_db_path(db_path) # Navigate up from .../archived/{instance}/data/file.sqlite to .../archived/{instance} archived_bot_dir = os.path.dirname(os.path.dirname(full_path)) diff --git a/utils/hummingbot_database_reader.py b/utils/hummingbot_database_reader.py index 6a57d260..110d057d 100644 --- a/utils/hummingbot_database_reader.py +++ b/utils/hummingbot_database_reader.py @@ -13,6 +13,8 @@ class HummingbotDatabase: def __init__(self, db_path: str): + if not os.path.isfile(db_path): + raise FileNotFoundError(f"Database file '{db_path}' not found") self.db_name = os.path.basename(db_path) self.db_path = db_path self.db_path = f'sqlite:///{os.path.join(db_path)}'