diff --git a/.dockerignore b/.dockerignore index 2eea525d..8e571199 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,62 @@ -.env \ No newline at end of file +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# Virtual environments +venv/ +ENV/ +env/ +.venv + +# IDEs +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS files +.DS_Store +Thumbs.db + +# Git +.git/ +.gitignore + +# Documentation +*.md +docs/ + +# Tests +test/ +tests/ +pytest_cache/ +.coverage +.pytest_cache/ + +# Development files +.env +.env.local +*.log + +# Build artifacts +build/ +dist/ +*.egg-info/ + +# Docker files (don't copy themselves) +Dockerfile* +docker-compose*.yml +.dockerignore + +# Bot data that should be mounted as volumes +bots/instances/* +bots/data/* +bots/credentials/* +!bots/credentials/master_account/ + +# Archives +bots/archived/ \ No newline at end of file diff --git a/.github/workflows/docker_buildx_workflow.yml b/.github/workflows/docker_buildx_workflow.yml index 156388d9..682f832e 100644 --- a/.github/workflows/docker_buildx_workflow.yml +++ b/.github/workflows/docker_buildx_workflow.yml @@ -1,59 +1,24 @@ -name: Backend-API Docker Buildx Workflow +name: Hummingbot-API Docker Buildx Workflow on: pull_request: types: [closed] branches: - main - - development - release: - types: [published, edited] jobs: - build_pr: - if: github.event_name == 'pull_request' && github.event.pull_request.merged == true + build_and_push: + if: github.event.pull_request.merged == true runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4.1.1 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3.1.0 - - - name: Login to DockerHub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Build and push Development Image - if: github.base_ref == 'development' - uses: docker/build-push-action@v5 - with: - context: . - platforms: linux/amd64,linux/arm64 - push: true - tags: hummingbot/backend-api:development - - - name: Build and push Latest Image - if: github.base_ref == 'main' - uses: docker/build-push-action@v5 - with: - context: . - file: ./Dockerfile - platforms: linux/amd64,linux/arm64 - push: true - tags: hummingbot/backend-api:latest - - build_release: - if: github.event_name == 'release' - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4.1.1 + - name: Extract version from main.py + id: get_version + run: | + VERSION=$(grep -E '^VERSION *= *' main.py | head -1 | sed -E 's/^VERSION *= *["\x27]?([^"\x27]*)["\x27]?/\1/') + echo "VERSION=$VERSION" >> $GITHUB_OUTPUT - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -67,14 +32,12 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Extract tag name - id: get_tag - run: echo ::set-output name=VERSION::${GITHUB_REF#refs/tags/} - - - name: Build and push + - name: Build and push Docker images (latest and versioned) uses: docker/build-push-action@v5 with: context: . platforms: linux/amd64,linux/arm64 push: true - tags: hummingbot/backend-api:${{ steps.get_tag.outputs.VERSION }} + tags: | + hummingbot/hummingbot-api:latest + hummingbot/hummingbot-api:${{ steps.get_version.outputs.VERSION }} diff --git a/.gitignore b/.gitignore index 68bc17f9..3ce1cc5e 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,9 @@ instance/ # Scrapy stuff: .scrapy +# Setup sentinel +.setup-complete + # Sphinx documentation docs/_build/ @@ -158,3 +161,21 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Hummingbot Gateway files +gateway-files/ + +# Hummingbot wheel files (for local Docker builds) +*.whl + +# Hummingbot credentials and local data +bots/credentials/ +bots/instances/ +bots/conf/ + +# Local MCP configuration (project-specific overrides) +.mcp.json + +# IDE files +.vscode/ +.idea/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dc82c98b..ccdb09ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,8 +26,7 @@ repos: args: [--settings-path=pyproject.toml] - repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + rev: 7.1.1 hooks: - id: flake8 - additional_dependencies: ['flake8'] args: [--max-line-length=130] diff --git a/Dockerfile b/Dockerfile index 8b310580..27e4b497 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,22 +1,56 @@ -# Start from a base image with Miniconda installed -FROM continuumio/miniconda3 +# Stage 1: Builder stage +FROM continuumio/miniconda3 AS builder -# Install system dependencies +# Install build dependencies RUN apt-get update && \ - apt-get install -y sudo libusb-1.0 python3-dev gcc && \ + apt-get install -y python3-dev gcc g++ build-essential && \ rm -rf /var/lib/apt/lists/* -# Set the working directory in the container -WORKDIR /backend-api +# Set working directory +WORKDIR /build + +# Copy only the environment file first (for better layer caching) +COPY environment.yml . + +# Create the conda environment +RUN conda env create -f environment.yml && \ + conda clean -afy && \ + rm -rf /root/.cache/pip/* + +# Stage 2: Runtime stage +FROM continuumio/miniconda3 + +# Install only runtime dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + libusb-1.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Copy the conda environment from builder +COPY --from=builder /opt/conda/envs/hummingbot-api /opt/conda/envs/hummingbot-api + +# Set the working directory +WORKDIR /hummingbot-api + +# Copy only necessary application files +COPY main.py config.py deps.py ./ +COPY models ./models +COPY routers ./routers +COPY services ./services +COPY utils ./utils +COPY database ./database +COPY bots/controllers ./bots/controllers +COPY bots/scripts ./bots/scripts -# Copy the current directory contents and the Conda environment file into the container -COPY . . +# Create necessary directories +RUN mkdir -p bots/instances bots/conf bots/credentials bots/data bots/archived -# Create the environment from the environment.yml file -RUN conda env create -f environment.yml +# Expose port +EXPOSE 8000 -# Make RUN commands use the new environment -SHELL ["conda", "run", "-n", "backend-api", "/bin/bash", "-c"] +# Set environment variables to ensure conda env is used +ENV PATH="/opt/conda/envs/hummingbot-api/bin:$PATH" +ENV CONDA_DEFAULT_ENV=hummingbot-api -# The code to run when container is started -ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "backend-api", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] +# Run the application +ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/Makefile b/Makefile index b240cd28..8c4a43aa 100644 --- a/Makefile +++ b/Makefile @@ -1,49 +1,48 @@ -.ONESHELL: -.SHELLFLAGS := -c - -.PHONY: run -.PHONY: uninstall -.PHONY: install -.PHONY: install-pre-commit -.PHONY: docker_build -.PHONY: docker_run - - -detect_conda_bin := $(shell bash -c 'if [ "${CONDA_EXE} " == " " ]; then \ - CONDA_EXE=$$((find /opt/conda/bin/conda || find ~/anaconda3/bin/conda || \ - find /usr/local/anaconda3/bin/conda || find ~/miniconda3/bin/conda || \ - find /root/miniconda/bin/conda || find ~/Anaconda3/Scripts/conda || \ - find $$CONDA/bin/conda) 2>/dev/null); fi; \ - if [ "${CONDA_EXE}_" == "_" ]; then \ - echo "Please install Anaconda w/ Python 3.10+ first"; \ - echo "See: https://www.anaconda.com/distribution/"; \ - exit 1; fi; \ - echo $$(dirname $${CONDA_EXE})') - -CONDA_BIN := $(detect_conda_bin) +.PHONY: setup run deploy stop install uninstall build install-pre-commit +SETUP_SENTINEL := .setup-complete + +setup: $(SETUP_SENTINEL) + +$(SETUP_SENTINEL): + chmod +x setup.sh + ./setup.sh + +# Run locally (dev mode) run: - uvicorn main:app --reload + docker compose up emqx postgres -d + conda run --no-capture-output -n hummingbot-api uvicorn main:app --reload -uninstall: - conda env remove -n backend-api +# Deploy with Docker +deploy: $(SETUP_SENTINEL) + docker compose up -d + +# Stop all services +stop: + docker compose down +# Install conda environment install: - if conda env list | grep -q '^backend-api '; then \ - echo "Environment already exists."; \ + @if ! command -v conda >/dev/null 2>&1; then \ + echo "Error: Conda is not found in PATH. Please install Conda or add it to your PATH."; \ + exit 1; \ + fi + @if conda env list | grep -q '^hummingbot-api '; then \ + echo "Environment already exists."; \ else \ - conda env create -f environment.yml; \ + conda env create -f environment.yml; \ fi $(MAKE) install-pre-commit + $(MAKE) setup -install-pre-commit: - /bin/bash -c 'source "${CONDA_BIN}/activate" backend-api && \ - if ! conda list pre-commit | grep pre-commit &> /dev/null; then \ - pip install pre-commit; \ - fi && pre-commit install' +uninstall: + conda env remove -n hummingbot-api -y + rm -f $(SETUP_SENTINEL) -docker_build: - docker build -t hummingbot/backend-api:latest . +install-pre-commit: + conda run -n hummingbot-api pip install pre-commit + conda run -n hummingbot-api pre-commit install -docker_run: - docker compose up -d +# Build Docker image +build: + docker build -t hummingbot/hummingbot-api:latest . \ No newline at end of file diff --git a/README.md b/README.md index 7d23d29b..9486a201 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,138 @@ -# Backend API - -## Overview -Backend-api is a dedicated solution for managing Hummingbot instances. It offers a robust backend API to streamline the deployment, management, and interaction with Hummingbot containers. This tool is essential for administrators and developers looking to efficiently handle various aspects of Hummingbot operations. - -## Features -- **Deployment File Management**: Manage files necessary for deploying new Hummingbot instances. -- **Container Control**: Effortlessly start and stop Hummingbot containers. -- **Archiving Options**: Securely archive containers either locally or on Amazon S3 post-removal. -- **Direct Messaging**: Communicate with Hummingbots through the broker for effective control and coordination. - -## Getting Started - -### Conda Installation -1. Install the environment using Conda: - ```bash - conda env create -f environment.yml - ``` -2. Activate the Conda environment: - ```bash - conda activate backend-api - ``` - -### Running the API with Conda -Run the API using uvicorn with the following command: - ```bash - uvicorn main:app --reload - ``` - -### Docker Installation and Running the API -For running the project using Docker, follow these steps: - -1. **Set up Environment Variables**: - - Execute the `set_environment.sh` script to configure the necessary environment variables in the `.env` file: - ```bash - ./set_environment.sh - ``` - -2. **Build and Run with Docker Compose**: - - After setting up the environment variables, use Docker Compose to build and run the project: - ```bash - docker compose up --build - ``` - - - This command will build the Docker image and start the containers as defined in your `docker-compose.yml` file. - -### Usage -This API is designed for: -- **Deploying Hummingbot instances** -- **Starting/Stopping Containers** -- **Archiving Hummingbots** -- **Messaging with Hummingbot instances** - -To test these endpoints, you can use the [Swagger UI](http://localhost:8000/docs) or [Redoc](http://localhost:8000/redoc). - -## Contributing -Contributions are welcome! For support or queries, please contact us on Discord. +# Hummingbot API + +A REST API for managing Hummingbot trading bots across multiple exchanges, with AI assistant integration via MCP. + +## Quick Start + +```bash +git clone https://github.com/hummingbot/hummingbot-api.git +cd hummingbot-api +make setup # Creates .env (prompts for passwords) +make deploy # Starts all services +``` + +That's it! The API is now running at http://localhost:8000 + +## Available Commands + +| Command | Description | +|---------|-------------| +| `make setup` | Create `.env` file with configuration | +| `make deploy` | Start all services (API, PostgreSQL, EMQX) | +| `make stop` | Stop all services | +| `make run` | Run API locally in dev mode | +| `make install` | Install conda environment for development | +| `make build` | Build Docker image | + +## Services + +After `make deploy`, these services are available: + +| Service | URL | Description | +|---------|-----|-------------| +| **API** | http://localhost:8000 | REST API | +| **Swagger UI** | http://localhost:8000/docs | Interactive API documentation | +| **PostgreSQL** | localhost:5432 | Database | +| **EMQX** | localhost:1883 | MQTT broker | +| **EMQX Dashboard** | http://localhost:18083 | Broker admin (admin/public) | + +## Connect AI Assistant (MCP) + +### Claude Code (CLI) + +```bash +claude mcp add --transport stdio hummingbot -- \ + docker run --rm -i \ + -e HUMMINGBOT_API_URL=http://host.docker.internal:8000 \ + -v hummingbot_mcp:/root/.hummingbot_mcp \ + hummingbot/hummingbot-mcp:latest +``` + +Then use natural language: +- "Show my portfolio balances" +- "Set up my Binance account" +- "Create a market making strategy for ETH-USDT" + +### Claude Desktop + +Add to your config file: +- **macOS**: `~/Library/Application Support/Claude/claude_desktop_config.json` +- **Windows**: `%APPDATA%\Claude\claude_desktop_config.json` + +```json +{ + "mcpServers": { + "hummingbot": { + "command": "docker", + "args": ["run", "--rm", "-i", "-e", "HUMMINGBOT_API_URL=http://host.docker.internal:8000", "-v", "hummingbot_mcp:/root/.hummingbot_mcp", "hummingbot/hummingbot-mcp:latest"] + } + } +} +``` + +Restart Claude Desktop after adding. + +## Gateway (DEX Trading) + +Gateway enables decentralized exchange trading. Start it via MCP: + +> "Start Gateway in development mode with passphrase 'admin'" + +Or via API at http://localhost:8000/docs using the Gateway endpoints. + +Once running, Gateway is available at http://localhost:15888 + +## Configuration + +The `.env` file contains all configuration. Key settings: + +```bash +USERNAME=admin # API username +PASSWORD=admin # API password +CONFIG_PASSWORD=admin # Encrypts bot credentials +DATABASE_URL=... # PostgreSQL connection +GATEWAY_URL=... # Gateway URL (for DEX) +``` + +Edit `.env` and restart with `make deploy` to apply changes. + +## API Features + +- **Portfolio**: Balances, positions, P&L across all exchanges +- **Trading**: Place orders, manage positions, track history +- **Bots**: Deploy, monitor, and control trading bots +- **Market Data**: Prices, orderbooks, candles, funding rates +- **Strategies**: Create and manage trading strategies + +Full API documentation at http://localhost:8000/docs + +## Development + +```bash +make install # Create conda environment +conda activate hummingbot-api +make run # Run with hot-reload +``` + +## Troubleshooting + +**API won't start?** +```bash +docker compose logs hummingbot-api +``` + +**Database issues?** +```bash +docker compose down -v # Reset all data +make deploy # Fresh start +``` + +**Check service status:** +```bash +docker ps | grep hummingbot +``` + +## Support + +- **API Docs**: http://localhost:8000/docs +- **Issues**: https://github.com/hummingbot/hummingbot-api/issues diff --git a/bots/data/.gitignore b/bots/archived/.gitignore similarity index 100% rename from bots/data/.gitignore rename to bots/archived/.gitignore diff --git a/bots/controllers/directional_trading/ai_livestream.py b/bots/controllers/directional_trading/ai_livestream.py new file mode 100644 index 00000000..28a1157a --- /dev/null +++ b/bots/controllers/directional_trading/ai_livestream.py @@ -0,0 +1,84 @@ +from decimal import Decimal +from typing import List + +import pandas_ta as ta # noqa: F401 +from pydantic import Field + +from hummingbot.core.data_type.common import TradeType +from hummingbot.remote_iface.mqtt import ExternalTopicFactory +from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( + DirectionalTradingControllerBase, + DirectionalTradingControllerConfigBase, +) +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig + + +class AILivestreamControllerConfig(DirectionalTradingControllerConfigBase): + controller_name: str = "ai_livestream" + long_threshold: float = Field(default=0.5, json_schema_extra={"is_updatable": True}) + short_threshold: float = Field(default=0.5, json_schema_extra={"is_updatable": True}) + topic: str = "hbot/predictions" + + +class AILivestreamController(DirectionalTradingControllerBase): + def __init__(self, config: AILivestreamControllerConfig, *args, **kwargs): + self.config = config + super().__init__(config, *args, **kwargs) + # Start ML signal listener + self._init_ml_signal_listener() + + def _init_ml_signal_listener(self): + """Initialize a listener for ML signals from the MQTT broker""" + try: + normalized_pair = self.config.trading_pair.replace("-", "_").lower() + topic = f"{self.config.topic}/{normalized_pair}/ML_SIGNALS" + self._ml_signal_listener = ExternalTopicFactory.create_async( + topic=topic, + callback=self._handle_ml_signal, + use_bot_prefix=False, + ) + self.logger().info("ML signal listener initialized successfully") + except Exception as e: + self.logger().error(f"Failed to initialize ML signal listener: {str(e)}") + self._ml_signal_listener = None + + def _handle_ml_signal(self, signal: dict, topic: str): + """Handle incoming ML signal""" + # self.logger().info(f"Received ML signal: {signal}") + short, neutral, long = signal["probabilities"] + if short > self.config.short_threshold: + self.processed_data["signal"] = -1 + elif long > self.config.long_threshold: + self.processed_data["signal"] = 1 + else: + self.processed_data["signal"] = 0 + self.processed_data["features"] = signal + + async def update_processed_data(self): + pass + + def get_executor_config(self, trade_type: TradeType, price: Decimal, amount: Decimal): + """ + Get the executor config based on the trade_type, price and amount. This method can be overridden by the + subclasses if required. + """ + return PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=trade_type, + entry_price=price, + amount=amount, + triple_barrier_config=self.config.triple_barrier_config.new_instance_with_adjusted_volatility( + volatility_factor=self.processed_data["features"].get("target_pct", 0.01)), + leverage=self.config.leverage, + ) + + def to_format_status(self) -> List[str]: + lines = [] + features = self.processed_data.get("features", {}) + lines.append(f"Signal: {self.processed_data.get('signal', 'N/A')}") + lines.append(f"Timestamp: {features.get('timestamp', 'N/A')}") + lines.append(f"Probabilities: {features.get('probabilities', 'N/A')}") + lines.append(f"Target Pct: {features.get('target_pct', 'N/A')}") + return lines diff --git a/bots/controllers/directional_trading/bollinger_v1.py b/bots/controllers/directional_trading/bollinger_v1.py index 8f1e92e2..afa0d772 100644 --- a/bots/controllers/directional_trading/bollinger_v1.py +++ b/bots/controllers/directional_trading/bollinger_v1.py @@ -1,56 +1,52 @@ from typing import List import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( DirectionalTradingControllerBase, DirectionalTradingControllerConfigBase, ) -from pydantic import Field, validator class BollingerV1ControllerConfig(DirectionalTradingControllerConfigBase): - controller_name = "bollinger_v1" - candles_config: List[CandlesConfig] = [] - candles_connector: str = Field(default=None) - candles_trading_pair: str = Field(default=None) + controller_name: str = "bollinger_v1" + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) interval: str = Field( default="3m", - client_data=ClientFieldData( - prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", - prompt_on_new=False)) + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) bb_length: int = Field( default=100, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands length: ", - prompt_on_new=True)) - bb_std: float = Field( - default=2.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands standard deviation: ", - prompt_on_new=False)) - bb_long_threshold: float = Field( - default=0.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands long threshold: ", - prompt_on_new=True)) - bb_short_threshold: float = Field( - default=1.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands short threshold: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the Bollinger Bands length: ", "prompt_on_new": True}) + bb_std: float = Field(default=2.0) + bb_long_threshold: float = Field(default=0.0) + bb_short_threshold: float = Field(default=1.0) - @validator("candles_connector", pre=True, always=True) - def set_candles_connector(cls, v, values): + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("connector_name") + return validation_info.data.get("connector_name") return v - @validator("candles_trading_pair", pre=True, always=True) - def set_candles_trading_pair(cls, v, values): + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("trading_pair") + return validation_info.data.get("trading_pair") return v @@ -58,23 +54,24 @@ class BollingerV1Controller(DirectionalTradingControllerBase): def __init__(self, config: BollingerV1ControllerConfig, *args, **kwargs): self.config = config self.max_records = self.config.bb_length - if len(self.config.candles_config) == 0: - self.config.candles_config = [CandlesConfig( - connector=config.candles_connector, - trading_pair=config.candles_trading_pair, - interval=config.interval, - max_records=self.max_records - )] super().__init__(config, *args, **kwargs) + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] + async def update_processed_data(self): df = self.market_data_provider.get_candles_df(connector_name=self.config.candles_connector, trading_pair=self.config.candles_trading_pair, interval=self.config.interval, max_records=self.max_records) # Add indicators - df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std, append=True) - bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] + df.ta.bbands(length=self.config.bb_length, lower_std=self.config.bb_std, upper_std=self.config.bb_std, append=True) + bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}_{self.config.bb_std}"] # Generate signal long_condition = bbp < self.config.bb_long_threshold diff --git a/bots/controllers/directional_trading/bollinger_v2.py b/bots/controllers/directional_trading/bollinger_v2.py new file mode 100644 index 00000000..83718137 --- /dev/null +++ b/bots/controllers/directional_trading/bollinger_v2.py @@ -0,0 +1,118 @@ +from sys import float_info as sflt +from typing import List + +import pandas as pd +import pandas_ta as ta # noqa: F401 +import talib +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo +from talib import MA_Type + +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( + DirectionalTradingControllerBase, + DirectionalTradingControllerConfigBase, +) + + +class BollingerV2ControllerConfig(DirectionalTradingControllerConfigBase): + controller_name: str = "bollinger_v2" + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) + interval: str = Field( + default="3m", + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) + bb_length: int = Field( + default=100, + json_schema_extra={"prompt": "Enter the Bollinger Bands length: ", "prompt_on_new": True}) + bb_std: float = Field(default=2.0) + bb_long_threshold: float = Field(default=0.0) + bb_short_threshold: float = Field(default=1.0) + + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("connector_name") + return v + + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("trading_pair") + return v + + +class BollingerV2Controller(DirectionalTradingControllerBase): + def __init__(self, config: BollingerV2ControllerConfig, *args, **kwargs): + self.config = config + self.max_records = self.config.bb_length * 5 + super().__init__(config, *args, **kwargs) + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] + + def non_zero_range(self, x: pd.Series, y: pd.Series) -> pd.Series: + """Non-Zero Range + + Calculates the difference of two Series plus epsilon to any zero values. + Technically: ```x - y + epsilon``` + + Parameters: + x (Series): Series of 'x's + y (Series): Series of 'y's + + Returns: + (Series): 1 column + """ + diff = x - y + if diff.eq(0).any().any(): + diff += sflt.epsilon + return diff + + async def update_processed_data(self): + df = self.market_data_provider.get_candles_df(connector_name=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records) + # Add indicators + df.ta.bbands(length=self.config.bb_length, lower_std=self.config.bb_std, upper_std=self.config.bb_std, append=True) + df["upperband"], df["middleband"], df["lowerband"] = talib.BBANDS(real=df["close"], timeperiod=self.config.bb_length, nbdevup=self.config.bb_std, nbdevdn=self.config.bb_std, matype=MA_Type.SMA) + + ulr = self.non_zero_range(df["upperband"], df["lowerband"]) + bbp = self.non_zero_range(df["close"], df["lowerband"]) / ulr + df["percent"] = bbp + + # Generate signal + long_condition = bbp < self.config.bb_long_threshold + short_condition = bbp > self.config.bb_short_threshold + + # Generate signal + df["signal"] = 0 + df.loc[long_condition, "signal"] = 1 + df.loc[short_condition, "signal"] = -1 + + # Debug + # We skip the last row which is live candle + with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None): + self.logger().info(df.head(-1).tail(15)) + + # Update processed data + self.processed_data["signal"] = df["signal"].iloc[-1] + self.processed_data["features"] = df diff --git a/bots/controllers/directional_trading/bollingrid.py b/bots/controllers/directional_trading/bollingrid.py new file mode 100644 index 00000000..0122b772 --- /dev/null +++ b/bots/controllers/directional_trading/bollingrid.py @@ -0,0 +1,160 @@ +from decimal import Decimal +from typing import List + +import pandas_ta as ta # noqa: F401 +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + +from hummingbot.core.data_type.common import TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( + DirectionalTradingControllerBase, + DirectionalTradingControllerConfigBase, +) +from hummingbot.strategy_v2.executors.grid_executor.data_types import GridExecutorConfig + + +class BollinGridControllerConfig(DirectionalTradingControllerConfigBase): + controller_name: str = "bollingrid" + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) + interval: str = Field( + default="3m", + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) + bb_length: int = Field( + default=100, + json_schema_extra={"prompt": "Enter the Bollinger Bands length: ", "prompt_on_new": True}) + bb_std: float = Field(default=2.0) + bb_long_threshold: float = Field(default=0.0) + bb_short_threshold: float = Field(default=1.0) + + # Grid-specific parameters + grid_start_price_coefficient: float = Field( + default=0.25, + json_schema_extra={"prompt": "Grid start price coefficient (multiplier of BB width): ", "prompt_on_new": True}) + grid_end_price_coefficient: float = Field( + default=0.75, + json_schema_extra={"prompt": "Grid end price coefficient (multiplier of BB width): ", "prompt_on_new": True}) + grid_limit_price_coefficient: float = Field( + default=0.35, + json_schema_extra={"prompt": "Grid limit price coefficient (multiplier of BB width): ", "prompt_on_new": True}) + min_spread_between_orders: Decimal = Field( + default=Decimal("0.005"), + json_schema_extra={"prompt": "Minimum spread between grid orders (e.g., 0.005 for 0.5%): ", "prompt_on_new": True}) + order_frequency: int = Field( + default=2, + json_schema_extra={"prompt": "Order frequency (seconds between grid orders): ", "prompt_on_new": True}) + max_orders_per_batch: int = Field( + default=1, + json_schema_extra={"prompt": "Maximum orders per batch: ", "prompt_on_new": True}) + min_order_amount_quote: Decimal = Field( + default=Decimal("6"), + json_schema_extra={"prompt": "Minimum order amount in quote currency: ", "prompt_on_new": True}) + max_open_orders: int = Field( + default=5, + json_schema_extra={"prompt": "Maximum number of open orders: ", "prompt_on_new": True}) + + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("connector_name") + return v + + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("trading_pair") + return v + + +class BollinGridController(DirectionalTradingControllerBase): + def __init__(self, config: BollinGridControllerConfig, *args, **kwargs): + self.config = config + self.max_records = self.config.bb_length + super().__init__(config, *args, **kwargs) + + async def update_processed_data(self): + df = self.market_data_provider.get_candles_df(connector_name=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records) + # Add indicators + df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std, append=True) + bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] + bb_width = df[f"BBB_{self.config.bb_length}_{self.config.bb_std}"] + + # Generate signal + long_condition = bbp < self.config.bb_long_threshold + short_condition = bbp > self.config.bb_short_threshold + + # Generate signal + df["signal"] = 0 + df.loc[long_condition, "signal"] = 1 + df.loc[short_condition, "signal"] = -1 + signal = df["signal"].iloc[-1] + close = df["close"].iloc[-1] + current_bb_width = bb_width.iloc[-1] / 100 + if signal == -1: + end_price = close * (1 + current_bb_width * self.config.grid_start_price_coefficient) + start_price = close * (1 - current_bb_width * self.config.grid_end_price_coefficient) + limit_price = close * (1 + current_bb_width * self.config.grid_limit_price_coefficient) + elif signal == 1: + start_price = close * (1 - current_bb_width * self.config.grid_start_price_coefficient) + end_price = close * (1 + current_bb_width * self.config.grid_end_price_coefficient) + limit_price = close * (1 - current_bb_width * self.config.grid_limit_price_coefficient) + else: + start_price = None + end_price = None + limit_price = None + + # Update processed data + self.processed_data["signal"] = df["signal"].iloc[-1] + self.processed_data["features"] = df + self.processed_data["grid_params"] = { + "start_price": start_price, + "end_price": end_price, + "limit_price": limit_price + } + + def get_executor_config(self, trade_type: TradeType, price: Decimal, amount: Decimal): + """ + Get the grid executor config based on the trade_type, price and amount. + Uses configurable grid parameters from the controller config. + """ + return GridExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + start_price=self.processed_data["grid_params"]["start_price"], + end_price=self.processed_data["grid_params"]["end_price"], + limit_price=self.processed_data["grid_params"]["limit_price"], + side=trade_type, + triple_barrier_config=self.config.triple_barrier_config, + leverage=self.config.leverage, + min_spread_between_orders=self.config.min_spread_between_orders, + total_amount_quote=amount * price, + order_frequency=self.config.order_frequency, + max_orders_per_batch=self.config.max_orders_per_batch, + min_order_amount_quote=self.config.min_order_amount_quote, + max_open_orders=self.config.max_open_orders, + ) + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] diff --git a/bots/controllers/directional_trading/dman_v3.py b/bots/controllers/directional_trading/dman_v3.py index cdf3c13a..7562af50 100644 --- a/bots/controllers/directional_trading/dman_v3.py +++ b/bots/controllers/directional_trading/dman_v3.py @@ -3,7 +3,9 @@ from typing import List, Optional, Tuple import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.core.data_type.common import TradeType from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( @@ -12,74 +14,69 @@ ) from hummingbot.strategy_v2.executors.dca_executor.data_types import DCAExecutorConfig, DCAMode from hummingbot.strategy_v2.executors.position_executor.data_types import TrailingStop -from pydantic import Field, validator class DManV3ControllerConfig(DirectionalTradingControllerConfigBase): controller_name: str = "dman_v3" - candles_config: List[CandlesConfig] = [] - candles_connector: str = Field(default=None) - candles_trading_pair: str = Field(default=None) + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) interval: str = Field( - default="30m", - client_data=ClientFieldData( - prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", - prompt_on_new=True)) + default="3m", + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) bb_length: int = Field( default=100, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands length: ", - prompt_on_new=True)) - bb_std: float = Field( - default=2.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands standard deviation: ", - prompt_on_new=False)) - bb_long_threshold: float = Field( - default=0.0, - client_data=ClientFieldData( - is_updatable=True, - prompt=lambda mi: "Enter the Bollinger Bands long threshold: ", - prompt_on_new=True)) - bb_short_threshold: float = Field( - default=1.0, - client_data=ClientFieldData( - is_updatable=True, - prompt=lambda mi: "Enter the Bollinger Bands short threshold: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the Bollinger Bands length: ", "prompt_on_new": True}) + bb_std: float = Field(default=2.0) + bb_long_threshold: float = Field(default=0.0) + bb_short_threshold: float = Field(default=1.0) + trailing_stop: Optional[TrailingStop] = Field( + default="0.015,0.005", + json_schema_extra={ + "prompt": "Enter the trailing stop parameters (activation_price, trailing_delta) as a comma-separated list: ", + "prompt_on_new": True, + } + ) dca_spreads: List[Decimal] = Field( default="0.001,0.018,0.15,0.25", - client_data=ClientFieldData( - prompt=lambda - mi: "Enter the spreads for each DCA level (comma-separated) if dynamic_spread=True this value " - "will multiply the Bollinger Bands width, e.g. if the Bollinger Bands width is 0.1 (10%)" - "and the spread is 0.2, the distance of the order to the current price will be 0.02 (2%) ", - prompt_on_new=True)) + json_schema_extra={ + "prompt": "Enter the spreads for each DCA level (comma-separated) if dynamic_spread=True this value " + "will multiply the Bollinger Bands width, e.g. if the Bollinger Bands width is 0.1 (10%)" + "and the spread is 0.2, the distance of the order to the current price will be 0.02 (2%) ", + "prompt_on_new": True}, + ) dca_amounts_pct: List[Decimal] = Field( default=None, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the amounts for each DCA level (as a percentage of the total balance, " - "comma-separated). Don't worry about the final sum, it will be normalized. ", - prompt_on_new=True)) + json_schema_extra={ + "prompt": "Enter the amounts for each DCA level (as a percentage of the total balance, " + "comma-separated). Don't worry about the final sum, it will be normalized. ", + "prompt_on_new": True}, + ) dynamic_order_spread: bool = Field( default=None, - client_data=ClientFieldData( - prompt=lambda mi: "Do you want to make the spread dynamic? (Yes/No) ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Do you want to make the spread dynamic? (Yes/No) ", "prompt_on_new": True}) dynamic_target: bool = Field( default=None, - client_data=ClientFieldData( - prompt=lambda mi: "Do you want to make the target dynamic? (Yes/No) ", - prompt_on_new=True)) - + json_schema_extra={"prompt": "Do you want to make the target dynamic? (Yes/No) ", "prompt_on_new": True}) activation_bounds: Optional[List[Decimal]] = Field( default=None, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the activation bounds for the orders " - "(e.g., 0.01 activates the next order when the price is closer than 1%): ", - prompt_on_new=True)) - - @validator("activation_bounds", pre=True, always=True) + json_schema_extra={ + "prompt": "Enter the activation bounds for the orders (e.g., 0.01 activates the next order when the price is closer than 1%): ", + "prompt_on_new": True, + } + ) + + @field_validator("activation_bounds", mode="before") + @classmethod def parse_activation_bounds(cls, v): if isinstance(v, str): if v == "": @@ -89,15 +86,17 @@ def parse_activation_bounds(cls, v): return [Decimal(val) for val in v] return v - @validator('dca_spreads', pre=True, always=True) + @field_validator('dca_spreads', mode="before") + @classmethod def validate_spreads(cls, v): if isinstance(v, str): return [Decimal(val) for val in v.split(",")] return v - @validator('dca_amounts_pct', pre=True, always=True) - def validate_amounts(cls, v, values): - spreads = values.get("dca_spreads") + @field_validator('dca_amounts_pct', mode="before") + @classmethod + def validate_amounts(cls, v, validation_info: ValidationInfo): + spreads = validation_info.data.get("dca_spreads") if isinstance(v, str): if v == "": return [Decimal('1.0') / len(spreads) for _ in spreads] @@ -109,9 +108,21 @@ def validate_amounts(cls, v, values): return [Decimal('1.0') / len(spreads) for _ in spreads] return v - def get_spreads_and_amounts_in_quote(self, - trade_type: TradeType, - total_amount_quote: Decimal) -> Tuple[List[Decimal], List[Decimal]]: + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("connector_name") + return v + + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): + if v is None or v == "": + return validation_info.data.get("trading_pair") + return v + + def get_spreads_and_amounts_in_quote(self, trade_type: TradeType, total_amount_quote: Decimal) -> Tuple[List[Decimal], List[Decimal]]: amounts_pct = self.dca_amounts_pct if amounts_pct is None: # Equally distribute if amounts_pct is not set @@ -125,18 +136,6 @@ def get_spreads_and_amounts_in_quote(self, return self.dca_spreads, [amt_pct * total_amount_quote for amt_pct in normalized_amounts_pct] - @validator("candles_connector", pre=True, always=True) - def set_candles_connector(cls, v, values): - if v is None or v == "": - return values.get("connector_name") - return v - - @validator("candles_trading_pair", pre=True, always=True) - def set_candles_trading_pair(cls, v, values): - if v is None or v == "": - return values.get("trading_pair") - return v - class DManV3Controller(DirectionalTradingControllerBase): """ @@ -147,13 +146,6 @@ class DManV3Controller(DirectionalTradingControllerBase): def __init__(self, config: DManV3ControllerConfig, *args, **kwargs): self.config = config self.max_records = config.bb_length - if len(self.config.candles_config) == 0: - self.config.candles_config = [CandlesConfig( - connector=config.candles_connector, - trading_pair=config.candles_trading_pair, - interval=config.interval, - max_records=self.max_records - )] super().__init__(config, *args, **kwargs) async def update_processed_data(self): @@ -162,11 +154,11 @@ async def update_processed_data(self): interval=self.config.interval, max_records=self.max_records) # Add indicators - df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std, append=True) + df.ta.bbands(length=self.config.bb_length, lower_std=self.config.bb_std, upper_std=self.config.bb_std, append=True) # Generate signal - long_condition = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] < self.config.bb_long_threshold - short_condition = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] > self.config.bb_short_threshold + long_condition = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}_{self.config.bb_std}"] < self.config.bb_long_threshold + short_condition = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}_{self.config.bb_std}"] > self.config.bb_short_threshold # Generate signal df["signal"] = 0 @@ -180,7 +172,7 @@ async def update_processed_data(self): def get_spread_multiplier(self) -> Decimal: if self.config.dynamic_order_spread: df = self.processed_data["features"] - bb_width = df[f"BBB_{self.config.bb_length}_{self.config.bb_std}"].iloc[-1] + bb_width = df[f"BBB_{self.config.bb_length}_{self.config.bb_std}_{self.config.bb_std}"].iloc[-1] return Decimal(bb_width / 200) else: return Decimal("1.0") @@ -194,9 +186,12 @@ def get_executor_config(self, trade_type: TradeType, price: Decimal, amount: Dec prices = [price * (1 + spread * spread_multiplier) for spread in spread] if self.config.dynamic_target: stop_loss = self.config.stop_loss * spread_multiplier - trailing_stop = TrailingStop( - activation_price=self.config.trailing_stop.activation_price * spread_multiplier, - trailing_delta=self.config.trailing_stop.trailing_delta * spread_multiplier) + if self.config.trailing_stop: + trailing_stop = TrailingStop( + activation_price=self.config.trailing_stop.activation_price * spread_multiplier, + trailing_delta=self.config.trailing_stop.trailing_delta * spread_multiplier) + else: + trailing_stop = None else: stop_loss = self.config.stop_loss trailing_stop = self.config.trailing_stop @@ -214,3 +209,11 @@ def get_executor_config(self, trade_type: TradeType, price: Decimal, amount: Dec leverage=self.config.leverage, activation_bounds=self.config.activation_bounds, ) + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] diff --git a/bots/controllers/directional_trading/macd_bb_v1.py b/bots/controllers/directional_trading/macd_bb_v1.py index ab215cbd..151a2990 100644 --- a/bots/controllers/directional_trading/macd_bb_v1.py +++ b/bots/controllers/directional_trading/macd_bb_v1.py @@ -1,71 +1,61 @@ from typing import List import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( DirectionalTradingControllerBase, DirectionalTradingControllerConfigBase, ) -from pydantic import Field, validator class MACDBBV1ControllerConfig(DirectionalTradingControllerConfigBase): - controller_name = "macd_bb_v1" - candles_config: List[CandlesConfig] = [] - candles_connector: str = Field(default=None) - candles_trading_pair: str = Field(default=None) + controller_name: str = "macd_bb_v1" + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) interval: str = Field( default="3m", - client_data=ClientFieldData( - prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", - prompt_on_new=False)) + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) bb_length: int = Field( default=100, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands length: ", - prompt_on_new=True)) - bb_std: float = Field( - default=2.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands standard deviation: ", - prompt_on_new=False)) - bb_long_threshold: float = Field( - default=0.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands long threshold: ", - prompt_on_new=True)) - bb_short_threshold: float = Field( - default=1.0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the Bollinger Bands short threshold: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the Bollinger Bands length: ", "prompt_on_new": True}) + bb_std: float = Field(default=2.0) + bb_long_threshold: float = Field(default=0.0) + bb_short_threshold: float = Field(default=1.0) macd_fast: int = Field( default=21, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD fast period: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the MACD fast period: ", "prompt_on_new": True}) macd_slow: int = Field( default=42, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD slow period: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the MACD slow period: ", "prompt_on_new": True}) macd_signal: int = Field( default=9, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD signal period: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the MACD signal period: ", "prompt_on_new": True}) - @validator("candles_connector", pre=True, always=True) - def set_candles_connector(cls, v, values): + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("connector_name") + return validation_info.data.get("connector_name") return v - @validator("candles_trading_pair", pre=True, always=True) - def set_candles_trading_pair(cls, v, values): + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("trading_pair") + return validation_info.data.get("trading_pair") return v @@ -73,14 +63,7 @@ class MACDBBV1Controller(DirectionalTradingControllerBase): def __init__(self, config: MACDBBV1ControllerConfig, *args, **kwargs): self.config = config - self.max_records = max(config.macd_slow, config.macd_fast, config.macd_signal, config.bb_length) - if len(self.config.candles_config) == 0: - self.config.candles_config = [CandlesConfig( - connector=config.candles_connector, - trading_pair=config.candles_trading_pair, - interval=config.interval, - max_records=self.max_records - )] + self.max_records = max(config.macd_slow, config.macd_fast, config.macd_signal, config.bb_length) + 20 super().__init__(config, *args, **kwargs) async def update_processed_data(self): @@ -89,10 +72,10 @@ async def update_processed_data(self): interval=self.config.interval, max_records=self.max_records) # Add indicators - df.ta.bbands(length=self.config.bb_length, std=self.config.bb_std, append=True) + df.ta.bbands(length=self.config.bb_length, lower_std=self.config.bb_std, upper_std=self.config.bb_std, append=True) df.ta.macd(fast=self.config.macd_fast, slow=self.config.macd_slow, signal=self.config.macd_signal, append=True) - bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}"] + bbp = df[f"BBP_{self.config.bb_length}_{self.config.bb_std}_{self.config.bb_std}"] macdh = df[f"MACDh_{self.config.macd_fast}_{self.config.macd_slow}_{self.config.macd_signal}"] macd = df[f"MACD_{self.config.macd_fast}_{self.config.macd_slow}_{self.config.macd_signal}"] @@ -107,3 +90,11 @@ async def update_processed_data(self): # Update processed data self.processed_data["signal"] = df["signal"].iloc[-1] self.processed_data["features"] = df + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] diff --git a/bots/controllers/directional_trading/supertrend_v1.py b/bots/controllers/directional_trading/supertrend_v1.py index e96f4465..37f85bcc 100644 --- a/bots/controllers/directional_trading/supertrend_v1.py +++ b/bots/controllers/directional_trading/supertrend_v1.py @@ -1,39 +1,53 @@ -from typing import List, Optional +from typing import List import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.directional_trading_controller_base import ( DirectionalTradingControllerBase, DirectionalTradingControllerConfigBase, ) -from pydantic import Field, validator class SuperTrendConfig(DirectionalTradingControllerConfigBase): controller_name: str = "supertrend_v1" - candles_config: List[CandlesConfig] = [] - candles_connector: Optional[str] = Field(default=None) - candles_trading_pair: Optional[str] = Field(default=None) - interval: str = Field(default="3m") - length: int = Field(default=20, client_data=ClientFieldData(prompt=lambda mi: "Enter the supertrend length: ", - prompt_on_new=True)) - multiplier: float = Field(default=4.0, - client_data=ClientFieldData(prompt=lambda mi: "Enter the supertrend multiplier: ", - prompt_on_new=True)) - percentage_threshold: float = Field(default=0.01, client_data=ClientFieldData( - prompt=lambda mi: "Enter the percentage threshold: ", prompt_on_new=True)) + candles_connector: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) + candles_trading_pair: str = Field( + default=None, + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) + interval: str = Field( + default="3m", + json_schema_extra={"prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", "prompt_on_new": True}) + length: int = Field( + default=20, + json_schema_extra={"prompt": "Enter the supertrend length: ", "prompt_on_new": True}) + multiplier: float = Field( + default=4.0, + json_schema_extra={"prompt": "Enter the supertrend multiplier: ", "prompt_on_new": True}) + percentage_threshold: float = Field( + default=0.01, + json_schema_extra={"prompt": "Enter the percentage threshold: ", "prompt_on_new": True}) - @validator("candles_connector", pre=True, always=True) - def set_candles_connector(cls, v, values): + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("connector_name") + return validation_info.data.get("connector_name") return v - @validator("candles_trading_pair", pre=True, always=True) - def set_candles_trading_pair(cls, v, values): + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("trading_pair") + return validation_info.data.get("trading_pair") return v @@ -41,13 +55,6 @@ class SuperTrend(DirectionalTradingControllerBase): def __init__(self, config: SuperTrendConfig, *args, **kwargs): self.config = config self.max_records = config.length + 10 - if len(self.config.candles_config) == 0: - self.config.candles_config = [CandlesConfig( - connector=config.candles_connector, - trading_pair=config.candles_trading_pair, - interval=config.interval, - max_records=self.max_records - )] super().__init__(config, *args, **kwargs) async def update_processed_data(self): @@ -57,14 +64,11 @@ async def update_processed_data(self): max_records=self.max_records) # Add indicators df.ta.supertrend(length=self.config.length, multiplier=self.config.multiplier, append=True) - df["percentage_distance"] = abs(df["close"] - df[f"SUPERT_{self.config.length}_{self.config.multiplier}"]) / df[ - "close"] + df["percentage_distance"] = abs(df["close"] - df[f"SUPERT_{self.config.length}_{self.config.multiplier}"]) / df["close"] # Generate long and short conditions - long_condition = (df[f"SUPERTd_{self.config.length}_{self.config.multiplier}"] == 1) & ( - df["percentage_distance"] < self.config.percentage_threshold) - short_condition = (df[f"SUPERTd_{self.config.length}_{self.config.multiplier}"] == -1) & ( - df["percentage_distance"] < self.config.percentage_threshold) + long_condition = (df[f"SUPERTd_{self.config.length}_{self.config.multiplier}"] == 1) & (df["percentage_distance"] < self.config.percentage_threshold) + short_condition = (df[f"SUPERTd_{self.config.length}_{self.config.multiplier}"] == -1) & (df["percentage_distance"] < self.config.percentage_threshold) # Choose side df['signal'] = 0 @@ -74,3 +78,11 @@ async def update_processed_data(self): # Update processed data self.processed_data["signal"] = df["signal"].iloc[-1] self.processed_data["features"] = df + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] diff --git a/bots/controllers/generic/arbitrage_controller.py b/bots/controllers/generic/arbitrage_controller.py new file mode 100644 index 00000000..5cb0cac5 --- /dev/null +++ b/bots/controllers/generic/arbitrage_controller.py @@ -0,0 +1,191 @@ +from decimal import Decimal +from typing import List, Optional + +import pandas as pd + +from hummingbot.client.ui.interface_utils import format_df_for_printout +from hummingbot.core.data_type.common import MarketDict +from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.arbitrage_executor.data_types import ArbitrageExecutorConfig +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.models.base import RunnableStatus +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction + + +class ArbitrageControllerConfig(ControllerConfigBase): + controller_name: str = "arbitrage_controller" + exchange_pair_1: ConnectorPair = ConnectorPair(connector_name="binance", trading_pair="SOL-USDT") + exchange_pair_2: ConnectorPair = ConnectorPair(connector_name="jupiter/router", trading_pair="SOL-USDC") + min_profitability: Decimal = Decimal("0.01") + delay_between_executors: int = 10 # in seconds + max_executors_imbalance: int = 1 + rate_connector: str = "binance" + quote_conversion_asset: str = "USDT" + + def update_markets(self, markets: MarketDict) -> MarketDict: + return [markets.add_or_update(cp.connector_name, cp.trading_pair) for cp in [self.exchange_pair_1, self.exchange_pair_2]][-1] + + +class ArbitrageController(ControllerBase): + def __init__(self, config: ArbitrageControllerConfig, *args, **kwargs): + self.config = config + super().__init__(config, *args, **kwargs) + self._imbalance = 0 + self._last_buy_closed_timestamp = 0 + self._last_sell_closed_timestamp = 0 + self._len_active_buy_arbitrages = 0 + self._len_active_sell_arbitrages = 0 + self.base_asset = self.config.exchange_pair_1.trading_pair.split("-")[0] + self._gas_token_cache = {} # Cache for gas tokens by connector + self._initialize_gas_tokens() # Fetch gas tokens during init + self.initialize_rate_sources() + + def initialize_rate_sources(self): + rates_required = [] + for connector_pair in [self.config.exchange_pair_1, self.config.exchange_pair_2]: + base, quote = connector_pair.trading_pair.split("-") + + # Add rate source for gas token if it's an AMM connector + if connector_pair.is_amm_connector(): + gas_token = self.get_gas_token(connector_pair.connector_name) + if gas_token and gas_token != quote: + rates_required.append(ConnectorPair(connector_name=self.config.rate_connector, + trading_pair=f"{gas_token}-{quote}")) + + # Add rate source for quote conversion asset + if quote != self.config.quote_conversion_asset: + rates_required.append(ConnectorPair(connector_name=self.config.rate_connector, + trading_pair=f"{quote}-{self.config.quote_conversion_asset}")) + + # Add rate source for trading pairs + rates_required.append(ConnectorPair(connector_name=connector_pair.connector_name, + trading_pair=connector_pair.trading_pair)) + if len(rates_required) > 0: + self.market_data_provider.initialize_rate_sources(rates_required) + + def _initialize_gas_tokens(self): + """Initialize gas tokens for AMM connectors during controller initialization.""" + import asyncio + + async def fetch_gas_tokens(): + for connector_pair in [self.config.exchange_pair_1, self.config.exchange_pair_2]: + if connector_pair.is_amm_connector(): + connector_name = connector_pair.connector_name + if connector_name not in self._gas_token_cache: + try: + gateway_client = GatewayHttpClient.get_instance() + + # Get chain and network for the connector + chain, network, error = await gateway_client.get_connector_chain_network( + connector_name + ) + + if error: + self.logger().warning(f"Failed to get chain info for {connector_name}: {error}") + continue + + # Get native currency symbol + native_currency = await gateway_client.get_native_currency_symbol(chain, network) + + if native_currency: + self._gas_token_cache[connector_name] = native_currency + self.logger().info(f"Gas token for {connector_name}: {native_currency}") + else: + self.logger().warning(f"Failed to get native currency for {connector_name}") + except Exception as e: + self.logger().error(f"Error getting gas token for {connector_name}: {e}") + + # Run the async function to fetch gas tokens + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(fetch_gas_tokens()) + else: + loop.run_until_complete(fetch_gas_tokens()) + + def get_gas_token(self, connector_name: str) -> Optional[str]: + """Get the cached gas token for a connector.""" + return self._gas_token_cache.get(connector_name) + + async def update_processed_data(self): + pass + + def determine_executor_actions(self) -> List[ExecutorAction]: + self.update_arbitrage_stats() + executor_actions = [] + current_time = self.market_data_provider.time() + if (abs(self._imbalance) >= self.config.max_executors_imbalance or + self._last_buy_closed_timestamp + self.config.delay_between_executors > current_time or + self._last_sell_closed_timestamp + self.config.delay_between_executors > current_time): + return executor_actions + if self._len_active_buy_arbitrages == 0: + executor_actions.append(self.create_arbitrage_executor_action(self.config.exchange_pair_1, + self.config.exchange_pair_2)) + if self._len_active_sell_arbitrages == 0: + executor_actions.append(self.create_arbitrage_executor_action(self.config.exchange_pair_2, + self.config.exchange_pair_1)) + return [action for action in executor_actions if action is not None] + + def create_arbitrage_executor_action(self, buying_exchange_pair: ConnectorPair, + selling_exchange_pair: ConnectorPair): + try: + if buying_exchange_pair.is_amm_connector(): + gas_token = self.get_gas_token(buying_exchange_pair.connector_name) + if gas_token: + pair = buying_exchange_pair.trading_pair.split("-")[0] + "-" + gas_token + gas_conversion_price = self.market_data_provider.get_rate(pair) + else: + gas_conversion_price = None + elif selling_exchange_pair.is_amm_connector(): + gas_token = self.get_gas_token(selling_exchange_pair.connector_name) + if gas_token: + pair = selling_exchange_pair.trading_pair.split("-")[0] + "-" + gas_token + gas_conversion_price = self.market_data_provider.get_rate(pair) + else: + gas_conversion_price = None + else: + gas_conversion_price = None + rate = self.market_data_provider.get_rate(self.base_asset + "-" + self.config.quote_conversion_asset) + if not rate: + self.logger().warning( + f"Cannot get conversion rate for {self.base_asset}-{self.config.quote_conversion_asset}. " + f"Skipping executor creation.") + return None + amount_quantized = self.market_data_provider.quantize_order_amount( + buying_exchange_pair.connector_name, buying_exchange_pair.trading_pair, + self.config.total_amount_quote / rate) + arbitrage_config = ArbitrageExecutorConfig( + timestamp=self.market_data_provider.time(), + buying_market=buying_exchange_pair, + selling_market=selling_exchange_pair, + order_amount=amount_quantized, + min_profitability=self.config.min_profitability, + gas_conversion_price=gas_conversion_price, + ) + return CreateExecutorAction( + executor_config=arbitrage_config, + controller_id=self.config.id) + except Exception as e: + self.logger().error( + f"Error creating executor to buy on {buying_exchange_pair.connector_name} and sell on {selling_exchange_pair.connector_name}, {e}") + + def update_arbitrage_stats(self): + closed_executors = [e for e in self.executors_info if e.status == RunnableStatus.TERMINATED] + active_executors = [e for e in self.executors_info if e.status != RunnableStatus.TERMINATED] + buy_arbitrages = [arbitrage for arbitrage in closed_executors if + arbitrage.config.buying_market == self.config.exchange_pair_1] + sell_arbitrages = [arbitrage for arbitrage in closed_executors if + arbitrage.config.buying_market == self.config.exchange_pair_2] + self._imbalance = len(buy_arbitrages) - len(sell_arbitrages) + self._last_buy_closed_timestamp = max([arbitrage.close_timestamp for arbitrage in buy_arbitrages]) if len( + buy_arbitrages) > 0 else 0 + self._last_sell_closed_timestamp = max([arbitrage.close_timestamp for arbitrage in sell_arbitrages]) if len( + sell_arbitrages) > 0 else 0 + self._len_active_buy_arbitrages = len([arbitrage for arbitrage in active_executors if + arbitrage.config.buying_market == self.config.exchange_pair_1]) + self._len_active_sell_arbitrages = len([arbitrage for arbitrage in active_executors if + arbitrage.config.buying_market == self.config.exchange_pair_2]) + + def to_format_status(self) -> List[str]: + all_executors_custom_info = pd.DataFrame(e.custom_info for e in self.executors_info) + return [format_df_for_printout(all_executors_custom_info, table_format="psql", )] diff --git a/bots/controllers/generic/examples/__init__.py b/bots/controllers/generic/examples/__init__.py new file mode 100644 index 00000000..1887b4a1 --- /dev/null +++ b/bots/controllers/generic/examples/__init__.py @@ -0,0 +1,2 @@ +# Examples package for Hummingbot V2 Controllers +# This package contains example controllers migrated from the original scripts/ diff --git a/bots/controllers/generic/examples/basic_order_example.py b/bots/controllers/generic/examples/basic_order_example.py new file mode 100644 index 00000000..0ddde676 --- /dev/null +++ b/bots/controllers/generic/examples/basic_order_example.py @@ -0,0 +1,48 @@ +from decimal import Decimal + +from hummingbot.core.data_type.common import MarketDict, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction + + +class BasicOrderExampleConfig(ControllerConfigBase): + controller_name: str = "examples.basic_order_example" + connector_name: str = "binance_perpetual" + trading_pair: str = "WLD-USDT" + side: TradeType = TradeType.BUY + position_mode: PositionMode = PositionMode.HEDGE + leverage: int = 20 + amount_quote: Decimal = Decimal("10") + order_frequency: int = 10 + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class BasicOrderExample(ControllerBase): + def __init__(self, config: BasicOrderExampleConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.last_timestamp = 0 + + async def update_processed_data(self): + mid_price = self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + n_active_executors = len([executor for executor in self.executors_info if executor.is_active]) + self.processed_data = {"mid_price": mid_price, "n_active_executors": n_active_executors} + + def determine_executor_actions(self) -> list[ExecutorAction]: + if (self.processed_data["n_active_executors"] == 0 and + self.market_data_provider.time() - self.last_timestamp > self.config.order_frequency): + self.last_timestamp = self.market_data_provider.time() + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=self.config.side, + amount=self.config.amount_quote / self.processed_data["mid_price"], + execution_strategy=ExecutionStrategy.MARKET, + price=self.processed_data["mid_price"], + ) + return [CreateExecutorAction(controller_id=self.config.id, executor_config=config)] + return [] diff --git a/bots/controllers/generic/examples/basic_order_open_close_example.py b/bots/controllers/generic/examples/basic_order_open_close_example.py new file mode 100644 index 00000000..f959fd77 --- /dev/null +++ b/bots/controllers/generic/examples/basic_order_open_close_example.py @@ -0,0 +1,83 @@ +from decimal import Decimal + +from hummingbot.core.data_type.common import MarketDict, PositionAction, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction + + +class BasicOrderOpenCloseExampleConfig(ControllerConfigBase): + controller_name: str = "examples.basic_order_open_close_example" + controller_type: str = "generic" + connector_name: str = "binance_perpetual" + trading_pair: str = "WLD-USDT" + side: TradeType = TradeType.BUY + position_mode: PositionMode = PositionMode.HEDGE + leverage: int = 50 + close_order_delay: int = 10 + open_short_to_close_long: bool = False + close_partial_position: bool = False + amount_quote: Decimal = Decimal("20") + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class BasicOrderOpenClose(ControllerBase): + def __init__(self, config: BasicOrderOpenCloseExampleConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.open_order_placed = False + self.closed_order_placed = False + self.last_timestamp = 0 + self.open_side = self.config.side + self.close_side = TradeType.SELL if self.config.side == TradeType.BUY else TradeType.BUY + + def get_position(self, connector_name, trading_pair): + for position in self.positions_held: + if position.connector_name == connector_name and position.trading_pair == trading_pair: + return position + + def determine_executor_actions(self) -> list[ExecutorAction]: + mid_price = self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + if not self.open_order_placed: + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=self.config.side, + amount=self.config.amount_quote / mid_price, + execution_strategy=ExecutionStrategy.MARKET, + position_action=PositionAction.OPEN, + price=mid_price, + ) + self.open_order_placed = True + self.last_timestamp = self.market_data_provider.time() + return [CreateExecutorAction( + controller_id=self.config.id, + executor_config=config)] + else: + if self.market_data_provider.time() - self.last_timestamp > self.config.close_order_delay and not self.closed_order_placed: + current_position = self.get_position(self.config.connector_name, self.config.trading_pair) + if current_position is None: + self.logger().info("The original position is not found, can close the position") + else: + amount = current_position.amount / 2 if self.config.close_partial_position else current_position.amount + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=self.close_side, + amount=amount, + execution_strategy=ExecutionStrategy.MARKET, + position_action=PositionAction.OPEN if self.config.open_short_to_close_long else PositionAction.CLOSE, + price=mid_price, + ) + self.closed_order_placed = True + return [CreateExecutorAction( + controller_id=self.config.id, + executor_config=config)] + return [] + + async def update_processed_data(self): + pass diff --git a/bots/controllers/generic/examples/buy_three_times_example.py b/bots/controllers/generic/examples/buy_three_times_example.py new file mode 100644 index 00000000..19f6fb2d --- /dev/null +++ b/bots/controllers/generic/examples/buy_three_times_example.py @@ -0,0 +1,69 @@ +from decimal import Decimal +from typing import List + +from hummingbot.core.data_type.common import MarketDict, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction + + +class BuyThreeTimesExampleConfig(ControllerConfigBase): + controller_name: str = "examples.buy_three_times_example" + connector_name: str = "binance_perpetual" + trading_pair: str = "WLD-USDT" + position_mode: PositionMode = PositionMode.HEDGE + leverage: int = 20 + amount_quote: Decimal = Decimal("10") + order_frequency: int = 10 + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class BuyThreeTimesExample(ControllerBase): + def __init__(self, config: BuyThreeTimesExampleConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.last_timestamp = 0 + self.buy_count = 0 + self.max_buys = 3 + + async def update_processed_data(self): + mid_price = self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + n_active_executors = len([executor for executor in self.executors_info if executor.is_active]) + self.processed_data = { + "mid_price": mid_price, + "n_active_executors": n_active_executors, + "buy_count": self.buy_count, + "max_buys_reached": self.buy_count >= self.max_buys + } + + def determine_executor_actions(self) -> list[ExecutorAction]: + if (self.buy_count < self.max_buys and + self.processed_data["n_active_executors"] == 0 and + self.market_data_provider.time() - self.last_timestamp > self.config.order_frequency): + + self.last_timestamp = self.market_data_provider.time() + self.buy_count += 1 + + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=TradeType.BUY, + amount=self.config.amount_quote / self.processed_data["mid_price"], + execution_strategy=ExecutionStrategy.MARKET, + price=self.processed_data["mid_price"], + ) + return [CreateExecutorAction(controller_id=self.config.id, executor_config=config)] + return [] + + def to_format_status(self) -> List[str]: + lines = [] + lines.append("Buy Three Times Example Status:") + lines.append(f" Buys completed: {self.buy_count}/{self.max_buys}") + lines.append(f" Max buys reached: {self.buy_count >= self.max_buys}") + if hasattr(self, 'processed_data') and self.processed_data: + lines.append(f" Mid price: {self.processed_data.get('mid_price', 'N/A')}") + lines.append(f" Active executors: {self.processed_data.get('n_active_executors', 'N/A')}") + return lines diff --git a/bots/controllers/generic/examples/candles_data_controller.py b/bots/controllers/generic/examples/candles_data_controller.py new file mode 100644 index 00000000..38a1de5c --- /dev/null +++ b/bots/controllers/generic/examples/candles_data_controller.py @@ -0,0 +1,202 @@ +from typing import List + +import pandas as pd +import pandas_ta as ta # noqa: F401 +from pydantic import Field, field_validator + +from hummingbot.core.data_type.common import MarketDict +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.models.executor_actions import ExecutorAction + + +class CandlesDataControllerConfig(ControllerConfigBase): + controller_name: str = "examples.candles_data_controller" + + # Candles configuration - user can modify these + candles_config: List[CandlesConfig] = Field( + default_factory=lambda: [ + CandlesConfig(connector="binance", trading_pair="ETH-USDT", interval="1m", max_records=1000), + CandlesConfig(connector="binance", trading_pair="ETH-USDT", interval="1h", max_records=1000), + CandlesConfig(connector="binance", trading_pair="ETH-USDT", interval="1w", max_records=200), + ], + json_schema_extra={ + "prompt": "Enter candles configurations (format: connector.pair.interval.max_records, separated by colons): ", + "prompt_on_new": True, + } + ) + + @field_validator('candles_config', mode="before") + @classmethod + def parse_candles_config(cls, v) -> List[CandlesConfig]: + # Handle string input (user provided) + if isinstance(v, str): + return cls.parse_candles_config_str(v) + # Handle list input (could be already CandlesConfig objects or dicts) + elif isinstance(v, list): + # If empty list, return as is + if not v: + return v + # If already CandlesConfig objects, return as is + if isinstance(v[0], CandlesConfig): + return v + # Otherwise, let Pydantic handle the conversion + return v + # Return as-is and let Pydantic validate + return v + + @staticmethod + def parse_candles_config_str(v: str) -> List[CandlesConfig]: + configs = [] + if v.strip(): + entries = v.split(':') + for entry in entries: + parts = entry.split('.') + if len(parts) != 4: + raise ValueError(f"Invalid candles config format in segment '{entry}'. " + "Expected format: 'exchange.tradingpair.interval.maxrecords'") + connector, trading_pair, interval, max_records_str = parts + try: + max_records = int(max_records_str) + except ValueError: + raise ValueError(f"Invalid max_records value '{max_records_str}' in segment '{entry}'. " + "max_records should be an integer.") + config = CandlesConfig( + connector=connector, + trading_pair=trading_pair, + interval=interval, + max_records=max_records + ) + configs.append(config) + return configs + + def update_markets(self, markets: MarketDict) -> MarketDict: + # This controller doesn't require any trading markets since it's only consuming data + return markets + + +class CandlesDataController(ControllerBase): + def __init__(self, config: CandlesDataControllerConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + # Initialize candles based on config + for candles_config in self.config.candles_config: + self.market_data_provider.initialize_candles_feed(candles_config) + self.logger().info(f"Initialized {len(self.config.candles_config)} candle feeds successfully") + + @property + def all_candles_ready(self): + """ + Checks if all configured candles are ready. + """ + for candle in self.config.candles_config: + candles_feed = self.market_data_provider.get_candles_feed(candle) + # Check if the feed is ready and has data + if not candles_feed.ready or candles_feed.candles_df.empty: + return False + return True + + async def update_processed_data(self): + candles_data = {} + if self.all_candles_ready: + for i, candle_config in enumerate(self.config.candles_config): + candles_df = self.market_data_provider.get_candles_df( + connector_name=candle_config.connector, + trading_pair=candle_config.trading_pair, + interval=candle_config.interval, + max_records=50 + ) + if candles_df is not None and not candles_df.empty: + candles_df = candles_df.copy() + + # Calculate indicators if enough data + if len(candles_df) >= 20: + candles_df.ta.rsi(length=14, append=True) + candles_df.ta.bbands(length=20, std=2, append=True) + candles_df.ta.ema(length=14, append=True) + + candles_data[f"{candle_config.connector}_{candle_config.trading_pair}_{candle_config.interval}"] = candles_df + + self.processed_data = {"candles_data": candles_data, "all_candles_ready": self.all_candles_ready} + + def determine_executor_actions(self) -> list[ExecutorAction]: + # This controller is for data monitoring only, no trading actions + return [] + + def to_format_status(self) -> List[str]: + lines = [] + lines.extend(["\n" + "=" * 100]) + lines.extend([" CANDLES DATA CONTROLLER"]) + lines.extend(["=" * 100]) + + if self.all_candles_ready: + for i, candle_config in enumerate(self.config.candles_config): + candles_df = self.market_data_provider.get_candles_df( + connector_name=candle_config.connector, + trading_pair=candle_config.trading_pair, + interval=candle_config.interval, + max_records=50 + ) + + if candles_df is not None and not candles_df.empty: + candles_df = candles_df.copy() + + # Calculate indicators if we have enough data + if len(candles_df) >= 20: + candles_df.ta.rsi(length=14, append=True) + candles_df.ta.bbands(length=20, std=2, append=True) + candles_df.ta.ema(length=14, append=True) + + candles_df["timestamp"] = pd.to_datetime(candles_df["timestamp"], unit="s") + + # Display candles info + lines.extend([f"\n[{i + 1}] {candle_config.connector.upper()} | {candle_config.trading_pair} | {candle_config.interval}"]) + lines.extend(["-" * 80]) + + # Show last 5 rows with basic columns (OHLC + volume) + basic_columns = ["timestamp", "open", "high", "low", "close", "volume"] + indicator_columns = [] + + # Include indicators if they exist and have data + if "RSI_14" in candles_df.columns and candles_df["RSI_14"].notna().any(): + indicator_columns.append("RSI_14") + if "BBP_20_2.0_2.0" in candles_df.columns and candles_df["BBP_20_2.0_2.0"].notna().any(): + indicator_columns.append("BBP_20_2.0_2.0") + if "EMA_14" in candles_df.columns and candles_df["EMA_14"].notna().any(): + indicator_columns.append("EMA_14") + + display_columns = basic_columns + indicator_columns + display_df = candles_df.tail(5)[display_columns].copy() + + # Round numeric columns only, handle datetime columns separately + numeric_columns = display_df.select_dtypes(include=['number']).columns + display_df[numeric_columns] = display_df[numeric_columns].round(4) + lines.extend([" " + line for line in display_df.to_string(index=False).split("\n")]) + + # Current values + current = candles_df.iloc[-1] + lines.extend([""]) + current_price = f"Current Price: ${current['close']:.4f}" + + # Add indicator values if available + if "RSI_14" in candles_df.columns and pd.notna(current.get('RSI_14')): + current_price += f" | RSI: {current['RSI_14']:.2f}" + + if "BBP_20_2.0_2.0" in candles_df.columns and pd.notna(current.get('BBP_20_2.0_2.0')): + current_price += f" | BB%: {current['BBP_20_2.0_2.0']:.3f}" + + lines.extend([f" {current_price}"]) + else: + lines.extend([f"\n[{i + 1}] {candle_config.connector.upper()} | {candle_config.trading_pair} | {candle_config.interval}"]) + lines.extend([" No data available yet..."]) + else: + lines.extend(["\n⏳ Waiting for candles data to be ready..."]) + for candle_config in self.config.candles_config: + candles_feed = self.market_data_provider.get_candles_feed(candle_config) + ready = candles_feed.ready and not candles_feed.candles_df.empty + status = "✅" if ready else "❌" + lines.extend([f" {status} {candle_config.connector}.{candle_config.trading_pair}.{candle_config.interval}"]) + + lines.extend(["\n" + "=" * 100 + "\n"]) + return lines diff --git a/bots/controllers/generic/examples/full_trading_example.py b/bots/controllers/generic/examples/full_trading_example.py new file mode 100644 index 00000000..e91b7d26 --- /dev/null +++ b/bots/controllers/generic/examples/full_trading_example.py @@ -0,0 +1,190 @@ +from decimal import Decimal + +from hummingbot.core.data_type.common import MarketDict, PriceType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, LimitChaserConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import ExecutorAction + + +class FullTradingExampleConfig(ControllerConfigBase): + controller_name: str = "examples.full_trading_example" + connector_name: str = "binance_perpetual" + trading_pair: str = "ETH-USDT" + amount: Decimal = Decimal("0.1") + spread: Decimal = Decimal("0.002") # 0.2% spread + max_open_orders: int = 3 + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class FullTradingExample(ControllerBase): + """ + Example controller demonstrating the full trading API built into ControllerBase. + + This controller shows how to use buy(), sell(), cancel(), open_orders(), + and open_positions() methods for intuitive trading operations. + """ + + def __init__(self, config: FullTradingExampleConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + async def update_processed_data(self): + """Update market data for decision making.""" + mid_price = self.get_current_price( + self.config.connector_name, + self.config.trading_pair, + PriceType.MidPrice + ) + + open_orders = self.open_orders( + self.config.connector_name, + self.config.trading_pair + ) + + open_positions = self.open_positions( + self.config.connector_name, + self.config.trading_pair + ) + + self.processed_data = { + "mid_price": mid_price, + "open_orders": open_orders, + "open_positions": open_positions, + "n_open_orders": len(open_orders) + } + + def determine_executor_actions(self) -> list[ExecutorAction]: + """ + Demonstrate different trading scenarios using the beautiful API. + """ + actions = [] + mid_price = self.processed_data["mid_price"] + n_open_orders = self.processed_data["n_open_orders"] + + # Scenario 1: Market buy with risk management + if n_open_orders == 0: + # Create a market buy with triple barrier for risk management + triple_barrier = TripleBarrierConfig( + stop_loss=Decimal("0.02"), # 2% stop loss + take_profit=Decimal("0.03"), # 3% take profit + time_limit=300 # 5 minutes time limit + ) + + executor_id = self.buy( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + amount=self.config.amount, + execution_strategy=ExecutionStrategy.MARKET, + triple_barrier_config=triple_barrier, + keep_position=True + ) + + self.logger().info(f"Created market buy order with triple barrier: {executor_id}") + + # Scenario 2: Limit orders with spread + elif n_open_orders < self.config.max_open_orders: + # Place limit buy below market + buy_price = mid_price * (Decimal("1") - self.config.spread) + buy_executor_id = self.buy( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + amount=self.config.amount, + price=buy_price, + execution_strategy=ExecutionStrategy.LIMIT_MAKER, + keep_position=True + ) + + # Place limit sell above market + sell_price = mid_price * (Decimal("1") + self.config.spread) + sell_executor_id = self.sell( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + amount=self.config.amount, + price=sell_price, + execution_strategy=ExecutionStrategy.LIMIT_MAKER, + keep_position=True + ) + + self.logger().info(f"Created limit orders - Buy: {buy_executor_id}, Sell: {sell_executor_id}") + + # Scenario 3: Limit chaser example + elif n_open_orders < self.config.max_open_orders + 1: + # Use limit chaser for better fill rates + chaser_config = LimitChaserConfig( + distance=Decimal("0.001"), # 0.1% from best price + refresh_threshold=Decimal("0.002") # Refresh if price moves 0.2% + ) + + chaser_executor_id = self.buy( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + amount=self.config.amount, + execution_strategy=ExecutionStrategy.LIMIT_CHASER, + chaser_config=chaser_config, + keep_position=True + ) + + self.logger().info(f"Created limit chaser order: {chaser_executor_id}") + + return actions # Actions are handled automatically by the mixin + + def demonstrate_cancel_operations(self): + """ + Example of how to use cancel operations. + """ + # Cancel a specific order by executor ID + open_orders = self.open_orders() + if open_orders: + executor_id = open_orders[0]['executor_id'] + success = self.cancel(executor_id) + self.logger().info(f"Cancelled executor {executor_id}: {success}") + + # Cancel all orders for a specific trading pair + cancelled_ids = self.cancel_all( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair + ) + self.logger().info(f"Cancelled {len(cancelled_ids)} orders: {cancelled_ids}") + + def to_format_status(self) -> list[str]: + """Display controller status with trading information.""" + lines = [] + + if self.processed_data: + mid_price = self.processed_data["mid_price"] + open_orders = self.processed_data["open_orders"] + open_positions = self.processed_data["open_positions"] + + lines.append("=== Beautiful Trading Example Controller ===") + lines.append(f"Trading Pair: {self.config.trading_pair}") + lines.append(f"Current Price: {mid_price:.6f}") + lines.append(f"Open Orders: {len(open_orders)}") + lines.append(f"Open Positions: {len(open_positions)}") + + if open_orders: + lines.append("--- Open Orders ---") + for order in open_orders: + lines.append(f" {order['side']} {order['amount']:.4f} @ {order.get('price', 'MARKET')} " + f"(Filled: {order['filled_amount']:.4f}) - {order['status']}") + + if open_positions: + lines.append("--- Held Positions ---") + for position in open_positions: + lines.append(f" {position['side']} {position['amount']:.4f} @ {position['entry_price']:.6f} " + f"(PnL: {position['pnl_percentage']:.2f}%)") + + return lines + + def get_custom_info(self) -> dict: + """Return custom information for MQTT reporting.""" + if self.processed_data: + return { + "mid_price": float(self.processed_data["mid_price"]), + "n_open_orders": len(self.processed_data["open_orders"]), + "n_open_positions": len(self.processed_data["open_positions"]), + "total_open_volume": sum(order["amount"] for order in self.processed_data["open_orders"]) + } + return {} diff --git a/bots/controllers/generic/examples/liquidations_monitor_controller.py b/bots/controllers/generic/examples/liquidations_monitor_controller.py new file mode 100644 index 00000000..c67c631a --- /dev/null +++ b/bots/controllers/generic/examples/liquidations_monitor_controller.py @@ -0,0 +1,94 @@ +from typing import List + +from pydantic import Field + +from hummingbot.client.ui.interface_utils import format_df_for_printout +from hummingbot.core.data_type.common import MarketDict +from hummingbot.data_feed.liquidations_feed.liquidations_factory import LiquidationsConfig, LiquidationsFactory +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.models.executor_actions import ExecutorAction + + +class LiquidationsMonitorControllerConfig(ControllerConfigBase): + controller_name: str = "examples.liquidations_monitor_controller" + exchange: str = Field(default="binance_paper_trade") + trading_pair: str = Field(default="BTC-USDT") + liquidations_trading_pairs: list = Field(default=["BTC-USDT", "1000PEPE-USDT", "1000BONK-USDT", "HBAR-USDT"]) + max_retention_seconds: int = Field(default=10) + + def update_markets(self, markets: MarketDict) -> MarketDict: + markets[self.exchange] = markets.get(self.exchange, set()) | {self.trading_pair} + return markets + + +class LiquidationsMonitorController(ControllerBase): + def __init__(self, config: LiquidationsMonitorControllerConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + # Initialize liquidations feed + self.binance_liquidations_config = LiquidationsConfig( + connector="binance", # the source for liquidation data (currently only binance is supported) + max_retention_seconds=self.config.max_retention_seconds, # how many seconds the data should be stored + trading_pairs=self.config.liquidations_trading_pairs + ) + self.binance_liquidations_feed = LiquidationsFactory.get_liquidations_feed(self.binance_liquidations_config) + self.binance_liquidations_feed.start() + + async def update_processed_data(self): + liquidations_data = { + "feed_ready": self.binance_liquidations_feed.ready, + "trading_pairs": self.config.liquidations_trading_pairs + } + + if self.binance_liquidations_feed.ready: + try: + # Get combined liquidations dataframe + liquidations_data["combined_df"] = self.binance_liquidations_feed.liquidations_df() + + # Get individual trading pair dataframes + liquidations_data["individual_dfs"] = {} + for trading_pair in self.config.liquidations_trading_pairs: + liquidations_data["individual_dfs"][trading_pair] = self.binance_liquidations_feed.liquidations_df(trading_pair) + except Exception as e: + self.logger().error(f"Error getting liquidations data: {e}") + liquidations_data["error"] = str(e) + + self.processed_data = liquidations_data + + def determine_executor_actions(self) -> list[ExecutorAction]: + # This controller is for monitoring only, no trading actions + return [] + + def to_format_status(self) -> List[str]: + lines = [] + lines.extend(["", "LIQUIDATIONS MONITOR"]) + lines.extend(["=" * 50]) + + if not self.binance_liquidations_feed.ready: + lines.append("Feed not ready yet!") + else: + try: + # Combined liquidations + lines.append("Combined liquidations:") + combined_df = self.binance_liquidations_feed.liquidations_df().tail(10) + lines.extend([format_df_for_printout(df=combined_df, table_format="psql")]) + lines.append("") + lines.append("") + + # Individual trading pairs + for trading_pair in self.binance_liquidations_config.trading_pairs: + lines.append("Liquidations for trading pair: {}".format(trading_pair)) + pair_df = self.binance_liquidations_feed.liquidations_df(trading_pair).tail(5) + lines.extend([format_df_for_printout(df=pair_df, table_format="psql")]) + lines.append("") + except Exception as e: + lines.append(f"Error displaying liquidations data: {e}") + + return lines + + async def stop(self): + """Clean shutdown of the liquidations feed""" + if hasattr(self, 'binance_liquidations_feed'): + self.binance_liquidations_feed.stop() + await super().stop() diff --git a/bots/controllers/generic/examples/market_status_controller.py b/bots/controllers/generic/examples/market_status_controller.py new file mode 100644 index 00000000..3aa328e9 --- /dev/null +++ b/bots/controllers/generic/examples/market_status_controller.py @@ -0,0 +1,131 @@ +from typing import List + +import pandas as pd +from pydantic import Field + +from hummingbot.core.data_type.common import MarketDict, PriceType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.models.executor_actions import ExecutorAction + + +class MarketStatusControllerConfig(ControllerConfigBase): + controller_name: str = "examples.market_status_controller" + exchanges: list = Field(default=["binance_paper_trade", "kucoin_paper_trade", "gate_io_paper_trade"]) + trading_pairs: list = Field(default=["ETH-USDT", "BTC-USDT", "POL-USDT", "AVAX-USDT", "WLD-USDT", "DOGE-USDT", "SHIB-USDT", "XRP-USDT", "SOL-USDT"]) + + def update_markets(self, markets: MarketDict) -> MarketDict: + # Add all combinations of exchanges and trading pairs + for exchange in self.exchanges: + markets[exchange] = markets.get(exchange, set()) | set(self.trading_pairs) + return markets + + +class MarketStatusController(ControllerBase): + def __init__(self, config: MarketStatusControllerConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + @property + def ready_to_trade(self) -> bool: + """ + Check if all configured exchanges and trading pairs are ready for trading. + """ + try: + for exchange in self.config.exchanges: + for trading_pair in self.config.trading_pairs: + # Try to get price data to verify connectivity + price = self.market_data_provider.get_price_by_type(exchange, trading_pair, PriceType.MidPrice) + if price is None: + return False + return True + except Exception: + return False + + async def update_processed_data(self): + market_status_data = {} + if self.ready_to_trade: + try: + market_status_df = self.get_market_status_df_with_depth() + market_status_data = { + "market_status_df": market_status_df, + "ready_to_trade": True + } + except Exception as e: + self.logger().error(f"Error getting market status: {e}") + market_status_data = { + "error": str(e), + "ready_to_trade": False + } + else: + market_status_data = {"ready_to_trade": False} + + self.processed_data = market_status_data + + def determine_executor_actions(self) -> list[ExecutorAction]: + # This controller is for monitoring only, no trading actions + return [] + + def to_format_status(self) -> List[str]: + if not self.ready_to_trade: + return ["Market connectors are not ready."] + + lines = [] + lines.extend(["", " Market Status Data Frame:"]) + + try: + market_status_df = self.get_market_status_df_with_depth() + lines.extend([" " + line for line in market_status_df.to_string(index=False).split("\n")]) + except Exception as e: + lines.extend([f" Error: {str(e)}"]) + + return lines + + def get_market_status_df_with_depth(self): + """ + Create a DataFrame with market status information including prices and volumes. + """ + data = [] + for exchange in self.config.exchanges: + for trading_pair in self.config.trading_pairs: + try: + best_ask = self.market_data_provider.get_price_by_type(exchange, trading_pair, PriceType.BestAsk) + best_bid = self.market_data_provider.get_price_by_type(exchange, trading_pair, PriceType.BestBid) + mid_price = self.market_data_provider.get_price_by_type(exchange, trading_pair, PriceType.MidPrice) + + # Calculate volumes at +/-1% from mid price + volume_plus_1 = None + volume_minus_1 = None + if mid_price: + try: + price_plus_1 = mid_price * 1.01 + price_minus_1 = mid_price * 0.99 + volume_plus_1 = self.market_data_provider.get_volume_for_price(exchange, trading_pair, float(price_plus_1), True) + volume_minus_1 = self.market_data_provider.get_volume_for_price(exchange, trading_pair, float(price_minus_1), False) + except Exception: + volume_plus_1 = "N/A" + volume_minus_1 = "N/A" + + data.append({ + "Exchange": exchange.replace("_paper_trade", "").title(), + "Market": trading_pair, + "Best Bid": best_bid, + "Best Ask": best_ask, + "Mid Price": mid_price, + "Volume (+1%)": volume_plus_1, + "Volume (-1%)": volume_minus_1 + }) + except Exception as e: + self.logger().error(f"Error getting market status: {e}") + data.append({ + "Exchange": exchange.replace("_paper_trade", "").title(), + "Market": trading_pair, + "Best Bid": "Error", + "Best Ask": "Error", + "Mid Price": "Error", + "Volume (+1%)": "Error", + "Volume (-1%)": "Error" + }) + + market_status_df = pd.DataFrame(data) + market_status_df.sort_values(by=["Market"], inplace=True) + return market_status_df diff --git a/bots/controllers/generic/examples/price_monitor_controller.py b/bots/controllers/generic/examples/price_monitor_controller.py new file mode 100644 index 00000000..a3468a31 --- /dev/null +++ b/bots/controllers/generic/examples/price_monitor_controller.py @@ -0,0 +1,119 @@ +from typing import List + +from pydantic import Field + +from hummingbot.core.data_type.common import MarketDict, PriceType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.models.executor_actions import ExecutorAction + + +class PriceMonitorControllerConfig(ControllerConfigBase): + controller_name: str = "examples.price_monitor_controller" + exchanges: list = Field(default=["binance_paper_trade", "kucoin_paper_trade", "gate_io_paper_trade"]) + trading_pair: str = Field(default="ETH-USDT") + log_interval: int = Field(default=60) # seconds between price logs + + def update_markets(self, markets: MarketDict) -> MarketDict: + # Add the trading pair to all exchanges + for exchange in self.exchanges: + markets[exchange] = markets.get(exchange, set()) | {self.trading_pair} + return markets + + +class PriceMonitorController(ControllerBase): + def __init__(self, config: PriceMonitorControllerConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.last_log_time = 0 + + async def update_processed_data(self): + price_data = {} + current_time = self.market_data_provider.time() + + # Log prices at specified intervals + if current_time - self.last_log_time >= self.config.log_interval: + self.last_log_time = current_time + + for connector_name in self.config.exchanges: + try: + best_ask = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.BestAsk) + best_bid = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.BestBid) + mid_price = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.MidPrice) + + price_info = { + "best_ask": best_ask, + "best_bid": best_bid, + "mid_price": mid_price, + "spread": best_ask - best_bid if best_ask and best_bid else None, + "spread_pct": ((best_ask - best_bid) / mid_price * 100) if best_ask and best_bid and mid_price else None + } + + price_data[connector_name] = price_info + + # Log to console + self.logger().info(f"Connector: {connector_name}") + self.logger().info(f"Best ask: {best_ask}") + self.logger().info(f"Best bid: {best_bid}") + self.logger().info(f"Mid price: {mid_price}") + if price_info["spread"]: + self.logger().info(f"Spread: {price_info['spread']:.6f} ({price_info['spread_pct']:.3f}%)") + + except Exception as e: + self.logger().error(f"Error getting price data for {connector_name}: {e}") + price_data[connector_name] = {"error": str(e)} + + self.processed_data = { + "price_data": price_data, + "last_log_time": self.last_log_time, + "trading_pair": self.config.trading_pair + } + + def determine_executor_actions(self) -> list[ExecutorAction]: + # This controller is for monitoring only, no trading actions + return [] + + def to_format_status(self) -> List[str]: + lines = [] + lines.extend(["", f"PRICE MONITOR - {self.config.trading_pair}"]) + lines.extend(["=" * 60]) + + if hasattr(self, 'processed_data') and self.processed_data.get("price_data"): + for connector_name, price_info in self.processed_data["price_data"].items(): + lines.extend([f"\n{connector_name.upper()}:"]) + + if "error" in price_info: + lines.extend([f" Error: {price_info['error']}"]) + else: + lines.extend([f" Best Ask: {price_info.get('best_ask', 'N/A')}"]) + lines.extend([f" Best Bid: {price_info.get('best_bid', 'N/A')}"]) + lines.extend([f" Mid Price: {price_info.get('mid_price', 'N/A')}"]) + + if price_info.get('spread') is not None: + lines.extend([f" Spread: {price_info['spread']:.6f} ({price_info['spread_pct']:.3f}%)"]) + else: + # Get current prices for display + for connector_name in self.config.exchanges: + try: + best_ask = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.BestAsk) + best_bid = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.BestBid) + mid_price = self.market_data_provider.get_price_by_type(connector_name, self.config.trading_pair, PriceType.MidPrice) + + lines.extend([f"\n{connector_name.upper()}:"]) + lines.extend([f" Best Ask: {best_ask}"]) + lines.extend([f" Best Bid: {best_bid}"]) + lines.extend([f" Mid Price: {mid_price}"]) + + if best_ask and best_bid and mid_price: + spread = best_ask - best_bid + spread_pct = spread / mid_price * 100 + lines.extend([f" Spread: {spread:.6f} ({spread_pct:.3f}%)"]) + + except Exception as e: + lines.extend([f"\n{connector_name.upper()}:"]) + lines.extend([f" Error: {str(e)}"]) + + next_log_time = self.last_log_time + self.config.log_interval + time_until_next_log = max(0, next_log_time - self.market_data_provider.time()) + lines.extend([f"\nNext price log in: {time_until_next_log:.0f} seconds"]) + + return lines diff --git a/bots/controllers/generic/grid_strike.py b/bots/controllers/generic/grid_strike.py new file mode 100644 index 00000000..0af3f412 --- /dev/null +++ b/bots/controllers/generic/grid_strike.py @@ -0,0 +1,194 @@ +from decimal import Decimal +from typing import List, Optional + +from pydantic import Field + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.grid_executor.data_types import GridExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo + + +class GridStrikeConfig(ControllerConfigBase): + """ + Configuration required to run the GridStrike strategy for one connector and trading pair. + """ + controller_type: str = "generic" + controller_name: str = "grid_strike" + + # Account configuration + leverage: int = 20 + position_mode: PositionMode = PositionMode.HEDGE + + # Boundaries + connector_name: str = "binance_perpetual" + trading_pair: str = "WLD-USDT" + side: TradeType = TradeType.BUY + start_price: Decimal = Field(default=Decimal("0.58"), json_schema_extra={"is_updatable": True}) + end_price: Decimal = Field(default=Decimal("0.95"), json_schema_extra={"is_updatable": True}) + limit_price: Decimal = Field(default=Decimal("0.55"), json_schema_extra={"is_updatable": True}) + + # Profiling + total_amount_quote: Decimal = Field(default=Decimal("1000"), json_schema_extra={"is_updatable": True}) + min_spread_between_orders: Optional[Decimal] = Field(default=Decimal("0.001"), json_schema_extra={"is_updatable": True}) + min_order_amount_quote: Optional[Decimal] = Field(default=Decimal("5"), json_schema_extra={"is_updatable": True}) + + # Execution + max_open_orders: int = Field(default=2, json_schema_extra={"is_updatable": True}) + max_orders_per_batch: Optional[int] = Field(default=1, json_schema_extra={"is_updatable": True}) + order_frequency: int = Field(default=3, json_schema_extra={"is_updatable": True}) + activation_bounds: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + keep_position: bool = Field(default=False, json_schema_extra={"is_updatable": True}) + + # Risk Management + triple_barrier_config: TripleBarrierConfig = TripleBarrierConfig( + take_profit=Decimal("0.001"), + open_order_type=OrderType.LIMIT_MAKER, + take_profit_order_type=OrderType.LIMIT_MAKER, + ) + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class GridStrike(ControllerBase): + def __init__(self, config: GridStrikeConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self._last_grid_levels_update = 0 + self.trading_rules = None + self.grid_levels = [] + self.initialize_rate_sources() + + def initialize_rate_sources(self): + self.market_data_provider.initialize_rate_sources([ConnectorPair(connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair)]) + + def active_executors(self) -> List[ExecutorInfo]: + return [ + executor for executor in self.executors_info + if executor.is_active + ] + + def is_inside_bounds(self, price: Decimal) -> bool: + return self.config.start_price <= price <= self.config.end_price + + def determine_executor_actions(self) -> List[ExecutorAction]: + mid_price = self.market_data_provider.get_price_by_type( + self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + if len(self.active_executors()) == 0 and self.is_inside_bounds(mid_price): + return [CreateExecutorAction( + controller_id=self.config.id, + executor_config=GridExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + start_price=self.config.start_price, + end_price=self.config.end_price, + leverage=self.config.leverage, + limit_price=self.config.limit_price, + side=self.config.side, + total_amount_quote=self.config.total_amount_quote, + min_spread_between_orders=self.config.min_spread_between_orders, + min_order_amount_quote=self.config.min_order_amount_quote, + max_open_orders=self.config.max_open_orders, + max_orders_per_batch=self.config.max_orders_per_batch, + order_frequency=self.config.order_frequency, + activation_bounds=self.config.activation_bounds, + triple_barrier_config=self.config.triple_barrier_config, + level_id=None, + keep_position=self.config.keep_position, + ))] + return [] + + async def update_processed_data(self): + pass + + def to_format_status(self) -> List[str]: + status = [] + mid_price = self.market_data_provider.get_price_by_type( + self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + # Define standard box width for consistency + box_width = 114 + # Top Grid Configuration box with simple borders + status.append("┌" + "─" * box_width + "┐") + # First line: Grid Configuration and Mid Price + left_section = "Grid Configuration:" + padding = box_width - len(left_section) - 4 # -4 for the border characters and spacing + config_line1 = f"│ {left_section}{' ' * padding}" + padding2 = box_width - len(config_line1) + 1 # +1 for correct right border alignment + config_line1 += " " * padding2 + "│" + status.append(config_line1) + # Second line: Configuration parameters + config_line2 = f"│ Start: {self.config.start_price:.4f} │ End: {self.config.end_price:.4f} │ Side: {self.config.side} │ Limit: {self.config.limit_price:.4f} │ Mid Price: {mid_price:.4f} │" + padding = box_width - len(config_line2) + 1 # +1 for correct right border alignment + config_line2 += " " * padding + "│" + status.append(config_line2) + # Third line: Max orders and Inside bounds + config_line3 = f"│ Max Orders: {self.config.max_open_orders} │ Inside bounds: {1 if self.is_inside_bounds(mid_price) else 0}" + padding = box_width - len(config_line3) + 1 # +1 for correct right border alignment + config_line3 += " " * padding + "│" + status.append(config_line3) + status.append("└" + "─" * box_width + "┘") + for level in self.active_executors(): + # Define column widths for perfect alignment + col_width = box_width // 3 # Dividing the total width by 3 for equal columns + total_width = box_width + # Grid Status header - use long line and running status + status_header = f"Grid Status: {level.id} (RunnableStatus.RUNNING)" + status_line = f"┌ {status_header}" + "─" * (total_width - len(status_header) - 2) + "┐" + status.append(status_line) + # Calculate exact column widths for perfect alignment + col1_end = col_width + # Column headers + header_line = "│ Level Distribution" + " " * (col1_end - 20) + "│" + header_line += " Order Statistics" + " " * (col_width - 18) + "│" + header_line += " Performance Metrics" + " " * (col_width - 21) + "│" + status.append(header_line) + # Data for the three columns + level_dist_data = [ + f"NOT_ACTIVE: {len(level.custom_info['levels_by_state'].get('NOT_ACTIVE', []))}", + f"OPEN_ORDER_PLACED: {len(level.custom_info['levels_by_state'].get('OPEN_ORDER_PLACED', []))}", + f"OPEN_ORDER_FILLED: {len(level.custom_info['levels_by_state'].get('OPEN_ORDER_FILLED', []))}", + f"CLOSE_ORDER_PLACED: {len(level.custom_info['levels_by_state'].get('CLOSE_ORDER_PLACED', []))}", + f"COMPLETE: {len(level.custom_info['levels_by_state'].get('COMPLETE', []))}" + ] + order_stats_data = [ + f"Total: {sum(len(level.custom_info[k]) for k in ['filled_orders', 'failed_orders', 'canceled_orders'])}", + f"Filled: {len(level.custom_info['filled_orders'])}", + f"Failed: {len(level.custom_info['failed_orders'])}", + f"Canceled: {len(level.custom_info['canceled_orders'])}" + ] + perf_metrics_data = [ + f"Buy Vol: {level.custom_info['realized_buy_size_quote']:.4f}", + f"Sell Vol: {level.custom_info['realized_sell_size_quote']:.4f}", + f"R. PnL: {level.custom_info['realized_pnl_quote']:.4f}", + f"R. Fees: {level.custom_info['realized_fees_quote']:.4f}", + f"P. PnL: {level.custom_info['position_pnl_quote']:.4f}", + f"Position: {level.custom_info['position_size_quote']:.4f}" + ] + # Build rows with perfect alignment + max_rows = max(len(level_dist_data), len(order_stats_data), len(perf_metrics_data)) + for i in range(max_rows): + col1 = level_dist_data[i] if i < len(level_dist_data) else "" + col2 = order_stats_data[i] if i < len(order_stats_data) else "" + col3 = perf_metrics_data[i] if i < len(perf_metrics_data) else "" + row = "│ " + col1 + row += " " * (col1_end - len(col1) - 2) # -2 for the "│ " at the start + row += "│ " + col2 + row += " " * (col_width - len(col2) - 2) # -2 for the "│ " before col2 + row += "│ " + col3 + row += " " * (col_width - len(col3) - 2) # -2 for the "│ " before col3 + row += "│" + status.append(row) + # Liquidity line with perfect alignment + status.append("├" + "─" * total_width + "┤") + liquidity_line = f"│ Open Liquidity: {level.custom_info['open_liquidity_placed']:.4f} │ Close Liquidity: {level.custom_info['close_liquidity_placed']:.4f} │" + liquidity_line += " " * (total_width - len(liquidity_line) + 1) # +1 for correct right border alignment + liquidity_line += "│" + status.append(liquidity_line) + status.append("└" + "─" * total_width + "┘") + return status diff --git a/bots/controllers/generic/hedge_asset.py b/bots/controllers/generic/hedge_asset.py new file mode 100644 index 00000000..8b79facb --- /dev/null +++ b/bots/controllers/generic/hedge_asset.py @@ -0,0 +1,174 @@ +""" +Explanation: + +This strategy tracks the spot balance of a single asset on one exchange and maintains a hedge on a perpetual exchange +using a fixed, user-defined hedge ratio. It continuously compares the target hedge size (spot_balance × hedge_ratio) +with the actual short position and adjusts only when the difference exceeds a minimum notional threshold and enough +time has passed since the last order. This prevents overtrading while keeping the exposure appropriately hedged. The +user can manually update the hedge ratio in the config, and the controller will rebalance toward the new target size, +reducing or increasing the short position as needed. This allows safe, controlled management of spot inventory with +minimal noise and predictable hedge behavior. +""" +from decimal import Decimal +from typing import List + +from pydantic import Field + +from hummingbot.core.data_type.common import MarketDict, PositionAction, PositionMode, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction + + +class HedgeAssetConfig(ControllerConfigBase): + """ + Configuration required to run the GridStrike strategy for one connector and trading pair. + """ + controller_type: str = "generic" + controller_name: str = "hedge_asset" + total_amount_quote: Decimal = Decimal(0) + + # Spot connector + spot_connector_name: str = "binance" + asset_to_hedge: str = "SOL" + + # Perpetual connector + hedge_connector_name: str = "binance_perpetual" + hedge_trading_pair: str = "SOL-USDT" + leverage: int = 20 + position_mode: PositionMode = PositionMode.HEDGE + + # Hedge params + hedge_ratio: Decimal = Field(default=Decimal("0"), ge=0, le=1, json_schema_extra={"is_updatable": True}) + min_notional_size: float = Field(default=10, ge=0) + cooldown_time: float = Field(default=10.0, ge=0) + + def update_markets(self, markets: MarketDict) -> MarketDict: + markets.add_or_update(self.spot_connector_name, self.asset_to_hedge + "-USDC") + markets.add_or_update(self.hedge_connector_name, self.hedge_trading_pair) + return markets + + +class HedgeAssetController(ControllerBase): + def __init__(self, config: HedgeAssetConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.perp_collateral_asset = self.config.hedge_trading_pair.split("-")[1] + self.set_leverage_and_position_mode() + + def set_leverage_and_position_mode(self): + connector = self.market_data_provider.get_connector(self.config.hedge_connector_name) + connector.set_leverage(leverage=self.config.leverage, trading_pair=self.config.hedge_trading_pair) + connector.set_position_mode(self.config.position_mode) + + @property + def hedge_position_size(self) -> Decimal: + hedge_positions = [position for position in self.positions_held if + position.connector_name == self.config.hedge_connector_name and + position.trading_pair == self.config.hedge_trading_pair and + position.side == TradeType.SELL] + if len(hedge_positions) > 0: + hedge_position = hedge_positions[0] + hedge_position_size = hedge_position.amount + else: + hedge_position_size = Decimal("0") + return hedge_position_size + + @property + def last_hedge_timestamp(self) -> float: + if len(self.executors_info) > 0: + return self.executors_info[-1].timestamp + return 0 + + async def update_processed_data(self): + """ + Compute current spot balance, hedge position size, current hedge ratio, last hedge time, current hedge gap quote + """ + current_price = self.market_data_provider.get_price_by_type(self.config.hedge_connector_name, self.config.hedge_trading_pair) + spot_balance = self.market_data_provider.get_balance(self.config.spot_connector_name, self.config.asset_to_hedge) + perp_available_balance = self.market_data_provider.get_available_balance(self.config.hedge_connector_name, self.perp_collateral_asset) + hedge_position_size = self.hedge_position_size + hedge_position_gap = spot_balance * self.config.hedge_ratio - hedge_position_size + hedge_position_gap_quote = hedge_position_gap * current_price + last_hedge_timestamp = self.last_hedge_timestamp + + # if these conditions are true we are allowed to execute a trade + cool_down_time_condition = last_hedge_timestamp + self.config.cooldown_time < self.market_data_provider.time() + min_notional_size_condition = abs(hedge_position_gap_quote) >= self.config.min_notional_size + self.processed_data.update({ + "current_price": current_price, + "spot_balance": spot_balance, + "perp_available_balance": perp_available_balance, + "hedge_position_size": hedge_position_size, + "hedge_position_gap": hedge_position_gap, + "hedge_position_gap_quote": hedge_position_gap_quote, + "last_hedge_timestamp": last_hedge_timestamp, + "cool_down_time_condition": cool_down_time_condition, + "min_notional_size_condition": min_notional_size_condition, + }) + + def determine_executor_actions(self) -> List[ExecutorAction]: + if self.processed_data["cool_down_time_condition"] and self.processed_data["min_notional_size_condition"]: + side = TradeType.SELL if self.processed_data["hedge_position_gap"] >= 0 else TradeType.BUY + order_executor_config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.hedge_connector_name, + trading_pair=self.config.hedge_trading_pair, + side=side, + amount=abs(self.processed_data["hedge_position_gap"]), + price=self.processed_data["current_price"], + leverage=self.config.leverage, + position_action=PositionAction.CLOSE if side == TradeType.BUY else PositionAction.OPEN, + execution_strategy=ExecutionStrategy.MARKET + ) + return [CreateExecutorAction(controller_id=self.config.id, executor_config=order_executor_config)] + return [] + + def to_format_status(self) -> List[str]: + """ + These report will be showing the metrics that are important to determine the state of the hedge. + """ + lines = [] + + # Get data + spot_balance = self.processed_data.get("spot_balance", Decimal("0")) + hedge_position = self.processed_data.get("hedge_position_size", Decimal("0")) + perp_balance = self.processed_data.get("perp_available_balance", Decimal("0")) + current_price = self.processed_data.get("current_price", Decimal("0")) + gap = self.processed_data.get("hedge_position_gap", Decimal("0")) + gap_quote = self.processed_data.get("hedge_position_gap_quote", Decimal("0")) + cooldown_ok = self.processed_data.get("cool_down_time_condition", False) + notional_ok = self.processed_data.get("min_notional_size_condition", False) + + # Calculate theoretical hedge + theoretical_hedge = spot_balance * self.config.hedge_ratio + + # Status indicators + cooldown_status = "✓" if cooldown_ok else "✗" + notional_status = "✓" if notional_ok else "✗" + + # Header + lines.append(f"\n{'=' * 65}") + lines.append(f" HEDGE ASSET CONTROLLER: {self.config.asset_to_hedge} @ {current_price:.4f} {self.perp_collateral_asset}") + lines.append(f"{'=' * 65}") + + # Calculation flow + lines.append(f" Spot Balance: {spot_balance:>10.4f} {self.config.asset_to_hedge}") + lines.append(f" × Hedge Ratio: {self.config.hedge_ratio:>10.1%}") + lines.append(f" {'─' * 61}") + lines.append(f" = Target Hedge: {theoretical_hedge:>10.4f} {self.config.asset_to_hedge}") + lines.append(f" - Current Hedge: {hedge_position:>10.4f} {self.config.asset_to_hedge}") + lines.append(f" {'─' * 61}") + lines.append(f" = Gap: {gap:>10.4f} {self.config.asset_to_hedge} ({gap_quote:>8.2f} {self.perp_collateral_asset})") + lines.append("") + lines.append(f" Perp Balance: {perp_balance:>10.2f} {self.perp_collateral_asset}") + lines.append("") + + # Trading conditions + lines.append(" Trading Conditions:") + lines.append(f" Cooldown ({self.config.cooldown_time:.0f}s): {cooldown_status}") + lines.append(f" Min Notional (≥{self.config.min_notional_size:.0f} {self.perp_collateral_asset}): {notional_status}") + + lines.append(f"{'=' * 65}\n") + + return lines diff --git a/bots/controllers/generic/lp_rebalancer/__init__.py b/bots/controllers/generic/lp_rebalancer/__init__.py new file mode 100644 index 00000000..5553f773 --- /dev/null +++ b/bots/controllers/generic/lp_rebalancer/__init__.py @@ -0,0 +1,3 @@ +from .lp_rebalancer import LPRebalancer, LPRebalancerConfig + +__all__ = ["LPRebalancer", "LPRebalancerConfig"] diff --git a/bots/controllers/generic/lp_rebalancer/lp_rebalancer.py b/bots/controllers/generic/lp_rebalancer/lp_rebalancer.py new file mode 100644 index 00000000..6c400629 --- /dev/null +++ b/bots/controllers/generic/lp_rebalancer/lp_rebalancer.py @@ -0,0 +1,942 @@ +import logging +from decimal import Decimal +from typing import List, Optional + +from hummingbot.core.data_type.common import MarketDict +from hummingbot.core.utils.async_utils import safe_ensure_future +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.logger import HummingbotLogger +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.lp_executor.data_types import LPExecutorConfig, LPExecutorStates +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo +from pydantic import Field, field_validator, model_validator + + +class LPRebalancerConfig(ControllerConfigBase): + """ + Configuration for LP Rebalancer Controller. + + Uses total_amount_quote and side for position sizing. + Implements KEEP vs REBALANCE logic based on price limits. + """ + controller_type: str = "generic" + controller_name: str = "lp_rebalancer" + candles_config: List[CandlesConfig] = [] + + # Pool configuration (required) + connector_name: str = "meteora/clmm" + network: str = "solana-mainnet-beta" + trading_pair: str = "" + pool_address: str = "" + + # Position parameters + total_amount_quote: Decimal = Field(default=Decimal("50"), json_schema_extra={"is_updatable": True}) + side: int = Field(default=1, json_schema_extra={"is_updatable": True}) # 0=BOTH, 1=BUY, 2=SELL + position_width_pct: Decimal = Field(default=Decimal("0.5"), json_schema_extra={"is_updatable": True}) + position_offset_pct: Decimal = Field( + default=Decimal("0.01"), + json_schema_extra={"is_updatable": True}, + description="Offset from current price to ensure single-sided positions start out-of-range (e.g., 0.1 = 0.1%)" + ) + + # Rebalancing + rebalance_seconds: int = Field(default=60, json_schema_extra={"is_updatable": True}) + rebalance_threshold_pct: Decimal = Field( + default=Decimal("0.1"), + json_schema_extra={"is_updatable": True}, + description="Price must be this % out of range before rebalance timer starts (e.g., 0.1 = 0.1%, 2 = 2%)" + ) + + # Price limits - overlapping grids for sell and buy ranges + # Sell range: [sell_price_min, sell_price_max] + # Buy range: [buy_price_min, buy_price_max] + sell_price_max: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + sell_price_min: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + buy_price_max: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + buy_price_min: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + + # Connector-specific params (optional) + strategy_type: Optional[int] = Field(default=None, json_schema_extra={"is_updatable": True}) + + @field_validator("sell_price_min", "sell_price_max", "buy_price_min", "buy_price_max", mode="before") + @classmethod + def validate_price_limits(cls, v): + """Allow null/None values for price limits.""" + if v is None: + return None + return Decimal(str(v)) + + @field_validator("side", mode="before") + @classmethod + def validate_side(cls, v): + """Validate side is 0, 1, or 2.""" + v = int(v) + if v not in (0, 1, 2): + raise ValueError("side must be 0 (BOTH), 1 (BUY), or 2 (SELL)") + return v + + @model_validator(mode="after") + def validate_price_limit_ranges(self): + """Validate that price limit ranges are valid.""" + if self.buy_price_max is not None and self.buy_price_min is not None: + if self.buy_price_max < self.buy_price_min: + raise ValueError("buy_price_max must be >= buy_price_min") + if self.sell_price_max is not None and self.sell_price_min is not None: + if self.sell_price_max < self.sell_price_min: + raise ValueError("sell_price_max must be >= sell_price_min") + return self + + def update_markets(self, markets: MarketDict) -> MarketDict: + """Register the LP connector with trading pair""" + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class LPRebalancer(ControllerBase): + """ + Controller for LP position management with smart rebalancing. + + Key features: + - Uses total_amount_quote for all positions (initial and rebalance) + - Derives rebalance side from price vs last executor's range + - KEEP position when already at limit, REBALANCE when not + - Validates bounds before creating positions + """ + + _logger: Optional[HummingbotLogger] = None + + @classmethod + def logger(cls) -> HummingbotLogger: + if cls._logger is None: + cls._logger = logging.getLogger(__name__) + return cls._logger + + def __init__(self, config: LPRebalancerConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config: LPRebalancerConfig = config + + # Parse token symbols from trading pair + parts = config.trading_pair.split("-") + self._base_token: str = parts[0] if len(parts) >= 2 else "" + self._quote_token: str = parts[1] if len(parts) >= 2 else "" + + # Rebalance tracking + self._pending_rebalance: bool = False + self._pending_rebalance_side: Optional[int] = None # Side for pending rebalance + + # Track the executor we created + self._current_executor_id: Optional[str] = None + + # Track amounts from last closed position (for rebalance sizing) + self._last_closed_base_amount: Optional[Decimal] = None + self._last_closed_quote_amount: Optional[Decimal] = None + self._last_closed_base_fee: Optional[Decimal] = None + self._last_closed_quote_fee: Optional[Decimal] = None + + # Track initial balances for comparison + self._initial_base_balance: Optional[Decimal] = None + self._initial_quote_balance: Optional[Decimal] = None + + # Flag to trigger balance update after position creation + self._pending_balance_update: bool = False + + # Cached pool price (updated in update_processed_data) + self._pool_price: Optional[Decimal] = None + + # Initialize rate sources + self.market_data_provider.initialize_rate_sources([ + ConnectorPair( + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair + ) + ]) + + def active_executor(self) -> Optional[ExecutorInfo]: + """Get current active executor (should be 0 or 1)""" + active = [e for e in self.executors_info if e.is_active] + return active[0] if active else None + + def get_tracked_executor(self) -> Optional[ExecutorInfo]: + """Get the executor we're currently tracking (by ID)""" + if not self._current_executor_id: + return None + for e in self.executors_info: + if e.id == self._current_executor_id: + return e + return None + + def is_tracked_executor_terminated(self) -> bool: + """Check if the executor we created has terminated""" + from hummingbot.strategy_v2.models.base import RunnableStatus + if not self._current_executor_id: + return True + executor = self.get_tracked_executor() + if executor is None: + return True + return executor.status == RunnableStatus.TERMINATED + + def _trigger_balance_update(self): + """Trigger a balance update on the connector after position changes.""" + try: + connector = self.market_data_provider.get_connector(self.config.connector_name) + if hasattr(connector, 'update_balances'): + safe_ensure_future(connector.update_balances()) + self.logger().info("Triggered balance update after position creation") + except Exception as e: + self.logger().debug(f"Could not trigger balance update: {e}") + + def determine_executor_actions(self) -> List[ExecutorAction]: + """Decide whether to create/stop executors""" + # Capture initial balances on first run + if self._initial_base_balance is None: + try: + self._initial_base_balance = self.market_data_provider.get_balance( + self.config.connector_name, self._base_token + ) + self._initial_quote_balance = self.market_data_provider.get_balance( + self.config.connector_name, self._quote_token + ) + except Exception as e: + self.logger().debug(f"Could not capture initial balances: {e}") + + actions = [] + executor = self.active_executor() + + # Track the active executor's ID if we don't have one yet + if executor and not self._current_executor_id: + self._current_executor_id = executor.id + self.logger().info(f"Tracking executor: {executor.id}") + + # No active executor - check if we should create one + if executor is None: + if not self.is_tracked_executor_terminated(): + tracked = self.get_tracked_executor() + self.logger().debug( + f"Waiting for executor {self._current_executor_id} to terminate " + f"(status: {tracked.status if tracked else 'not found'})" + ) + return actions + + # Previous executor terminated - capture final amounts for rebalance sizing + terminated_executor = self.get_tracked_executor() + if terminated_executor and self._pending_rebalance: + self._last_closed_base_amount = Decimal(str(terminated_executor.custom_info.get("base_amount", 0))) + self._last_closed_quote_amount = Decimal(str(terminated_executor.custom_info.get("quote_amount", 0))) + self._last_closed_base_fee = Decimal(str(terminated_executor.custom_info.get("base_fee", 0))) + self._last_closed_quote_fee = Decimal(str(terminated_executor.custom_info.get("quote_fee", 0))) + self.logger().info( + f"Captured closed position amounts: base={self._last_closed_base_amount}, " + f"quote={self._last_closed_quote_amount}, base_fee={self._last_closed_base_fee}, " + f"quote_fee={self._last_closed_quote_fee}" + ) + + # Clear tracking + self._current_executor_id = None + + # Determine side for new position + if self._pending_rebalance and self._pending_rebalance_side is not None: + side = self._pending_rebalance_side + self._pending_rebalance = False + self._pending_rebalance_side = None + else: + side = self.config.side + + # Create executor config with calculated bounds + executor_config = self._create_executor_config(side) + if executor_config is None: + self.logger().warning("Skipping position creation - invalid bounds") + return actions + + actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + self._pending_balance_update = True + return actions + + # Trigger balance update after position is created + if self._pending_balance_update: + state = executor.custom_info.get("state") + if state in ("IN_RANGE", "OUT_OF_RANGE"): + self._pending_balance_update = False + self._trigger_balance_update() + + # Check executor state + state = executor.custom_info.get("state") + + # Don't take action while executor is in transition states + if state in [LPExecutorStates.OPENING.value, LPExecutorStates.CLOSING.value]: + return actions + + # Check for rebalancing when out of range + if state == LPExecutorStates.OUT_OF_RANGE.value: + # Check if price is beyond threshold before considering timer + if self._is_beyond_rebalance_threshold(executor): + out_of_range_seconds = executor.custom_info.get("out_of_range_seconds") + if out_of_range_seconds is not None and out_of_range_seconds >= self.config.rebalance_seconds: + rebalance_action = self._handle_rebalance(executor) + if rebalance_action: + actions.append(rebalance_action) + + return actions + + def _handle_rebalance(self, executor: ExecutorInfo) -> Optional[StopExecutorAction]: + """ + Handle rebalancing logic. + + Returns StopExecutorAction if rebalance needed, None if KEEP. + """ + current_price = executor.custom_info.get("current_price") + lower_price = executor.custom_info.get("lower_price") + upper_price = executor.custom_info.get("upper_price") + + if current_price is None or lower_price is None or upper_price is None: + return None + + current_price = Decimal(str(current_price)) + lower_price = Decimal(str(lower_price)) + upper_price = Decimal(str(upper_price)) + + # Step 1: Determine side from price direction (using [lower, upper) convention) + if current_price >= upper_price: + new_side = 1 # BUY - price at or above range + elif current_price < lower_price: + new_side = 2 # SELL - price below range + else: + # Price is in range, shouldn't happen in OUT_OF_RANGE state + self.logger().warning(f"Price {current_price} appears in range [{lower_price}, {upper_price})") + return None + + # Step 2: Check if new position would be valid (price within limits) + if not self._is_price_within_limits(current_price, new_side): + # Don't log repeatedly - this is checked every tick + return None + + # Step 4: Initiate rebalance + self._pending_rebalance = True + self._pending_rebalance_side = new_side + self.logger().info( + f"REBALANCE initiated (side={new_side}, price={current_price}, " + f"old_bounds=[{lower_price}, {upper_price}])" + ) + + return StopExecutorAction( + controller_id=self.config.id, + executor_id=executor.id, + keep_position=False, + ) + + def _is_beyond_rebalance_threshold(self, executor: ExecutorInfo) -> bool: + """ + Check if price is beyond the rebalance threshold. + + Price must be this % out of range before rebalance timer is considered. + """ + current_price = executor.custom_info.get("current_price") + lower_price = executor.custom_info.get("lower_price") + upper_price = executor.custom_info.get("upper_price") + + if current_price is None or lower_price is None or upper_price is None: + return False + + threshold = self.config.rebalance_threshold_pct / Decimal("100") + + # Check if price is beyond threshold above upper or below lower + if current_price > upper_price: + deviation_pct = (current_price - upper_price) / upper_price + return deviation_pct >= threshold + elif current_price < lower_price: + deviation_pct = (lower_price - current_price) / lower_price + return deviation_pct >= threshold + + return False # Price is in range + + def _create_executor_config(self, side: int) -> Optional[LPExecutorConfig]: + """ + Create executor config for the given side. + + Returns None if bounds are invalid. + """ + # Use pool price (fetched in update_processed_data every tick) + current_price = self._pool_price + if current_price is None or current_price == 0: + self.logger().warning("No pool price available - waiting for update_processed_data") + return None + + # Calculate amounts based on side + base_amt, quote_amt = self._calculate_amounts(side, current_price) + + # Calculate bounds + lower_price, upper_price = self._calculate_price_bounds(side, current_price) + + # Validate bounds + if lower_price >= upper_price: + self.logger().warning(f"Invalid bounds [{lower_price}, {upper_price}] - skipping position") + return None + + # Build extra params (connector-specific) + extra_params = {} + if self.config.strategy_type is not None: + extra_params["strategyType"] = self.config.strategy_type + + # Check if bounds were clamped by price limits + clamped = [] + if side == 1: # BUY + if self.config.buy_price_max and upper_price == self.config.buy_price_max: + clamped.append(f"upper=buy_price_max({self.config.buy_price_max})") + if self.config.buy_price_min and lower_price == self.config.buy_price_min: + clamped.append(f"lower=buy_price_min({self.config.buy_price_min})") + elif side == 2: # SELL + if self.config.sell_price_min and lower_price == self.config.sell_price_min: + clamped.append(f"lower=sell_price_min({self.config.sell_price_min})") + if self.config.sell_price_max and upper_price == self.config.sell_price_max: + clamped.append(f"upper=sell_price_max({self.config.sell_price_max})") + + clamped_info = f", clamped: {', '.join(clamped)}" if clamped else "" + offset_pct = self.config.position_offset_pct + self.logger().info( + f"Creating position: side={side}, pool_price={current_price:.2f}, " + f"bounds=[{lower_price:.4f}, {upper_price:.4f}], offset_pct={offset_pct}, " + f"base={base_amt:.4f}, quote={quote_amt:.4f}{clamped_info}" + ) + + return LPExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + pool_address=self.config.pool_address, + lower_price=lower_price, + upper_price=upper_price, + base_amount=base_amt, + quote_amount=quote_amt, + side=side, + position_offset_pct=self.config.position_offset_pct, + extra_params=extra_params if extra_params else None, + keep_position=False, + ) + + def _calculate_amounts(self, side: int, current_price: Decimal) -> tuple: + """ + Calculate base and quote amounts based on side and total_amount_quote. + + For rebalances, clamps to the actual amounts returned from the closed position + to avoid order failures when balance is less than configured total (due to + impermanent loss, fees, or price movement). + + Side 0 (BOTH): split 50/50 + Side 1 (BUY): all quote - clamp to closed position's quote + quote_fee + Side 2 (SELL): all base - clamp to closed position's base + base_fee + """ + total = self.config.total_amount_quote + + # For rebalances, clamp to actual amounts from closed position + # Check if we have captured amounts (indicates this is a rebalance) + has_closed_amounts = (self._last_closed_base_amount is not None or + self._last_closed_quote_amount is not None) + if has_closed_amounts: + if side == 1: # BUY - needs quote token + if self._last_closed_quote_amount is not None: + # Total available = position amount + fees earned + available_quote = self._last_closed_quote_amount + if self._last_closed_quote_fee: + available_quote += self._last_closed_quote_fee + if available_quote < total: + self.logger().info( + f"Clamping quote amount from {total} to {available_quote} {self._quote_token} " + f"(closed position returned {self._last_closed_quote_amount} + {self._last_closed_quote_fee} fees)" + ) + total = available_quote + elif side == 2: # SELL - needs base token + if self._last_closed_base_amount is not None: + # Total available = position amount + fees earned + available_base = self._last_closed_base_amount + if self._last_closed_base_fee: + available_base += self._last_closed_base_fee + available_as_quote = available_base * current_price + if available_as_quote < total: + self.logger().info( + f"Clamping total from {total} to {available_as_quote:.4f} " + f"{self._quote_token} (closed: {self._last_closed_base_amount} + " + f"{self._last_closed_base_fee} fees {self._base_token})" + ) + total = available_as_quote + + # Clear the cached amounts after use + self._last_closed_base_amount = None + self._last_closed_quote_amount = None + self._last_closed_base_fee = None + self._last_closed_quote_fee = None + + if side == 0: # BOTH + quote_amt = total / Decimal("2") + base_amt = quote_amt / current_price + elif side == 1: # BUY + base_amt = Decimal("0") + quote_amt = total + else: # SELL + base_amt = total / current_price + quote_amt = Decimal("0") + + return base_amt, quote_amt + + def _calculate_price_bounds(self, side: int, current_price: Decimal) -> tuple: + """ + Calculate position bounds based on side and price limits. + + Side 0 (BOTH): centered on current price, clamped to [buy_min, sell_max] + Side 1 (BUY): upper = min(current, buy_price_max) * (1 - offset), lower extends width below + Side 2 (SELL): lower = max(current, sell_price_min) * (1 + offset), upper extends width above + + The offset ensures single-sided positions start out-of-range so they only + require one token (SOL for SELL, USDC for BUY). + """ + width = self.config.position_width_pct / Decimal("100") + offset = self.config.position_offset_pct / Decimal("100") + + if side == 0: # BOTH + half_width = width / Decimal("2") + lower_price = current_price * (Decimal("1") - half_width) + upper_price = current_price * (Decimal("1") + half_width) + # Clamp to limits + if self.config.buy_price_min: + lower_price = max(lower_price, self.config.buy_price_min) + if self.config.sell_price_max: + upper_price = min(upper_price, self.config.sell_price_max) + + elif side == 1: # BUY + # Position BELOW current price so we only need quote token (USDC) + if self.config.buy_price_max: + upper_price = min(current_price, self.config.buy_price_max) + else: + upper_price = current_price + # Apply offset to decrease upper bound (ensures out-of-range) + upper_price = upper_price * (Decimal("1") - offset) + lower_price = upper_price * (Decimal("1") - width) + # Clamp lower to floor + if self.config.buy_price_min: + lower_price = max(lower_price, self.config.buy_price_min) + + else: # SELL + # Position ABOVE current price so we only need base token (SOL) + if self.config.sell_price_min: + lower_price = max(current_price, self.config.sell_price_min) + else: + lower_price = current_price + # Apply offset to increase lower bound (ensures out-of-range) + lower_price = lower_price * (Decimal("1") + offset) + upper_price = lower_price * (Decimal("1") + width) + # Clamp upper to ceiling + if self.config.sell_price_max: + upper_price = min(upper_price, self.config.sell_price_max) + + return lower_price, upper_price + + def _is_price_within_limits(self, price: Decimal, side: int) -> bool: + """ + Check if price is within configured limits for the position type. + + Price must be within the range to create a position that's IN_RANGE: + - BUY: price must be within [buy_price_min, buy_price_max] + - SELL: price must be within [sell_price_min, sell_price_max] + - BOTH: price must be within the intersection of both ranges + + If price is outside the range, the position would be immediately OUT_OF_RANGE. + """ + if side == 2: # SELL + if self.config.sell_price_min and price < self.config.sell_price_min: + return False + if self.config.sell_price_max and price > self.config.sell_price_max: + return False + elif side == 1: # BUY + if self.config.buy_price_min and price < self.config.buy_price_min: + return False + if self.config.buy_price_max and price > self.config.buy_price_max: + return False + else: # BOTH - must be within intersection of ranges + # Check buy range + if self.config.buy_price_min and price < self.config.buy_price_min: + return False + if self.config.buy_price_max and price > self.config.buy_price_max: + return False + # Check sell range + if self.config.sell_price_min and price < self.config.sell_price_min: + return False + if self.config.sell_price_max and price > self.config.sell_price_max: + return False + return True + + async def update_processed_data(self): + """Called every tick - always fetch fresh pool price for accurate position creation.""" + try: + connector = self.market_data_provider.get_connector(self.config.connector_name) + if hasattr(connector, 'get_pool_info_by_address'): + pool_info = await connector.get_pool_info_by_address(self.config.pool_address) + if pool_info and pool_info.price: + self._pool_price = Decimal(str(pool_info.price)) + except Exception as e: + self.logger().debug(f"Could not fetch pool price: {e}") + + def to_format_status(self) -> List[str]: + """Format status for display.""" + status = [] + box_width = 100 + price_decimals = 8 # For small-value tokens like memecoins + + # Header + status.append("+" + "-" * box_width + "+") + header = f"| LP Rebalancer: {self.config.trading_pair} on {self.config.connector_name}" + status.append(header + " " * (box_width - len(header) + 1) + "|") + status.append("+" + "-" * box_width + "+") + + # Network, connector, pool + line = f"| Network: {self.config.network}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + line = f"| Pool: {self.config.pool_address}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Position info from current executor (active or transitioning) + executor = self.active_executor() or self.get_tracked_executor() + if executor and not executor.is_done: + position_address = executor.custom_info.get("position_address", "N/A") + line = f"| Position: {position_address}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Config summary + side_names = {0: "BOTH", 1: "BUY", 2: "SELL"} + side_str = side_names.get(self.config.side, '?') + amt = self.config.total_amount_quote + width = self.config.position_width_pct + rebal = self.config.rebalance_seconds + line = f"| Config: side={side_str}, amount={amt} {self._quote_token}, width={width}%, rebal={rebal}s" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Position fees and assets + if executor and not executor.is_done: + custom = executor.custom_info + + # Fees row: base_fee + quote_fee = total + base_fee = Decimal(str(custom.get("base_fee", 0))) + quote_fee = Decimal(str(custom.get("quote_fee", 0))) + fees_earned_quote = Decimal(str(custom.get("fees_earned_quote", 0))) + line = ( + f"| Fees: {float(base_fee):.6f} {self._base_token} + " + f"{float(quote_fee):.6f} {self._quote_token} = {float(fees_earned_quote):.6f}" + ) + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Value row: base_amount + quote_amount = total value + base_amount = Decimal(str(custom.get("base_amount", 0))) + quote_amount = Decimal(str(custom.get("quote_amount", 0))) + total_value_quote = Decimal(str(custom.get("total_value_quote", 0))) + line = ( + f"| Value: {float(base_amount):.6f} {self._base_token} + " + f"{float(quote_amount):.6f} {self._quote_token} = {float(total_value_quote):.4f}" + ) + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Position range visualization + lower_price = executor.custom_info.get("lower_price") + upper_price = executor.custom_info.get("upper_price") + + if lower_price and upper_price and self._pool_price: + # Show rebalance thresholds (convert % to decimal) + # Takes into account price limits - rebalance only happens within limits + threshold = self.config.rebalance_threshold_pct / Decimal("100") + lower_threshold = Decimal(str(lower_price)) * (Decimal("1") - threshold) + upper_threshold = Decimal(str(upper_price)) * (Decimal("1") + threshold) + + # Lower threshold triggers SELL - check sell_price_min + if self.config.sell_price_min and lower_threshold < self.config.sell_price_min: + lower_str = "N/A" # Below sell limit, no rebalance possible + else: + lower_str = f"{float(lower_threshold):.{price_decimals}f}" + + # Upper threshold triggers BUY - check buy_price_max + if self.config.buy_price_max and upper_threshold > self.config.buy_price_max: + upper_str = "N/A" # Above buy limit, no rebalance possible + else: + upper_str = f"{float(upper_threshold):.{price_decimals}f}" + + line = f"| Price: {float(self._pool_price):.{price_decimals}f} | Rebalance if: <{lower_str} or >{upper_str}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + state = executor.custom_info.get("state", "UNKNOWN") + state_icons = { + "IN_RANGE": "●", + "OUT_OF_RANGE": "○", + "OPENING": "◐", + "CLOSING": "◑", + "COMPLETE": "◌", + "NOT_ACTIVE": "○", + } + state_icon = state_icons.get(state, "?") + + status.append("|" + " " * box_width + "|") + line = f"| Position Status: [{state_icon} {state}]" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + range_viz = self._create_price_range_visualization( + Decimal(str(lower_price)), + self._pool_price, + Decimal(str(upper_price)) + ) + for viz_line in range_viz.split('\n'): + line = f"| {viz_line}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Show rebalance timer if out of range + out_of_range_seconds = executor.custom_info.get("out_of_range_seconds") + if out_of_range_seconds is not None: + # Check if beyond threshold + beyond_threshold = self._is_beyond_rebalance_threshold(executor) + if beyond_threshold: + line = f"| Rebalance: {out_of_range_seconds}s / {self.config.rebalance_seconds}s" + else: + line = f"| Rebalance: waiting (below {float(self.config.rebalance_threshold_pct):.2f}% threshold)" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Price limits visualization + has_limits = any([ + self.config.sell_price_min, self.config.sell_price_max, + self.config.buy_price_min, self.config.buy_price_max + ]) + if has_limits and self._pool_price: + # Get position bounds if available + pos_lower = None + pos_upper = None + if executor: + pos_lower = executor.custom_info.get("lower_price") + pos_upper = executor.custom_info.get("upper_price") + if pos_lower: + pos_lower = Decimal(str(pos_lower)) + if pos_upper: + pos_upper = Decimal(str(pos_upper)) + + status.append("|" + " " * box_width + "|") + limits_viz = self._create_price_limits_visualization( + self._pool_price, pos_lower, pos_upper, price_decimals + ) + if limits_viz: + for viz_line in limits_viz.split('\n'): + line = f"| {viz_line}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Balance comparison table (formatted like main balance table) + status.append("|" + " " * box_width + "|") + try: + current_base = self.market_data_provider.get_balance( + self.config.connector_name, self._base_token + ) + current_quote = self.market_data_provider.get_balance( + self.config.connector_name, self._quote_token + ) + + line = "| Balances:" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Table header + header = f"| {'Asset':<12} {'Initial':>14} {'Current':>14} {'Change':>16}" + status.append(header + " " * (box_width - len(header) + 1) + "|") + + # Base token row + if self._initial_base_balance is not None: + base_change = current_base - self._initial_base_balance + init_b = float(self._initial_base_balance) + curr_b = float(current_base) + chg_b = float(base_change) + line = f"| {self._base_token:<12} {init_b:>14.6f} {curr_b:>14.6f} {chg_b:>+16.6f}" + else: + curr_b = float(current_base) + line = f"| {self._base_token:<12} {'N/A':>14} {curr_b:>14.6f} {'N/A':>16}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Quote token row + if self._initial_quote_balance is not None: + quote_change = current_quote - self._initial_quote_balance + init_q = float(self._initial_quote_balance) + curr_q = float(current_quote) + chg_q = float(quote_change) + line = f"| {self._quote_token:<12} {init_q:>14.6f} {curr_q:>14.6f} {chg_q:>+16.6f}" + else: + curr_q = float(current_quote) + line = f"| {self._quote_token:<12} {'N/A':>14} {curr_q:>14.6f} {'N/A':>16}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + except Exception as e: + line = f"| Balances: Error fetching ({e})" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + # Closed positions summary + status.append("|" + " " * box_width + "|") + + closed = [e for e in self.executors_info if e.is_done] + + # Count closed by side (config.side: 0=both, 1=buy, 2=sell) + both_count = len([e for e in closed if getattr(e.config, "side", None) == 0]) + buy_count = len([e for e in closed if getattr(e.config, "side", None) == 1]) + sell_count = len([e for e in closed if getattr(e.config, "side", None) == 2]) + + # Calculate fees from closed positions + total_fees_base = Decimal("0") + total_fees_quote = Decimal("0") + + for e in closed: + total_fees_base += Decimal(str(e.custom_info.get("base_fee", 0))) + total_fees_quote += Decimal(str(e.custom_info.get("quote_fee", 0))) + + pool_price = self._pool_price or Decimal("0") + total_fees_value = total_fees_base * pool_price + total_fees_quote + + line = f"| Closed: {len(closed)} (both:{both_count} buy:{buy_count} sell:{sell_count})" + status.append(line + " " * (box_width - len(line) + 1) + "|") + fb = float(total_fees_base) + fq = float(total_fees_quote) + fv = float(total_fees_value) + line = f"| Fees Collected: {fb:.6f} {self._base_token} + {fq:.6f} {self._quote_token} = {fv:.6f}" + status.append(line + " " * (box_width - len(line) + 1) + "|") + + status.append("+" + "-" * box_width + "+") + return status + + def _create_price_range_visualization(self, lower_price: Decimal, current_price: Decimal, + upper_price: Decimal) -> str: + """Create visual representation of price range with current price marker""" + price_range = upper_price - lower_price + if price_range == 0: + return f"[{float(lower_price):.6f}] (zero width)" + + current_position = (current_price - lower_price) / price_range + bar_width = 50 + current_pos = int(current_position * bar_width) + + range_bar = ['─'] * bar_width + range_bar[0] = '├' + range_bar[-1] = '┤' + + if current_pos < 0: + marker_line = '● ' + ''.join(range_bar) + elif current_pos >= bar_width: + marker_line = ''.join(range_bar) + ' ●' + else: + range_bar[current_pos] = '●' + marker_line = ''.join(range_bar) + + viz_lines = [] + viz_lines.append(marker_line) + lower_str = f'{float(lower_price):.6f}' + upper_str = f'{float(upper_price):.6f}' + viz_lines.append(lower_str + ' ' * (bar_width - len(lower_str) - len(upper_str)) + upper_str) + + return '\n'.join(viz_lines) + + def _create_price_limits_visualization( + self, + current_price: Decimal, + pos_lower: Optional[Decimal] = None, + pos_upper: Optional[Decimal] = None, + price_decimals: int = 8 + ) -> Optional[str]: + """Create visualization of sell/buy price limits on unified scale.""" + viz_lines = [] + + bar_width = 50 + + # Collect all price points to determine unified scale + prices = [current_price] + if self.config.sell_price_min: + prices.append(self.config.sell_price_min) + if self.config.sell_price_max: + prices.append(self.config.sell_price_max) + if self.config.buy_price_min: + prices.append(self.config.buy_price_min) + if self.config.buy_price_max: + prices.append(self.config.buy_price_max) + if pos_lower: + prices.append(pos_lower) + if pos_upper: + prices.append(pos_upper) + + scale_min = min(prices) + scale_max = max(prices) + scale_range = scale_max - scale_min + + if scale_range <= 0: + return None + + def pos_to_idx(price: Decimal) -> int: + return int((price - scale_min) / scale_range * (bar_width - 1)) + + # Get position marker index + price_idx = pos_to_idx(current_price) + + # Helper to create a range bar on unified scale with position marker + def make_range_bar(range_min: Optional[Decimal], range_max: Optional[Decimal], + label: str, fill_char: str = '═', show_position: bool = False) -> str: + if range_min is None or range_max is None: + return "" + + bar = [' '] * bar_width + start_idx = max(0, pos_to_idx(range_min)) + end_idx = min(bar_width - 1, pos_to_idx(range_max)) + + # Fill the range + for i in range(start_idx, end_idx + 1): + bar[i] = fill_char + # Mark boundaries + if 0 <= start_idx < bar_width: + bar[start_idx] = '[' + if 0 <= end_idx < bar_width: + bar[end_idx] = ']' + + # Add position marker if requested + if show_position and 0 <= price_idx < bar_width: + bar[price_idx] = '●' + + return f" {label}: {''.join(bar)}" + + # Build visualization with aligned bars + viz_lines.append("Price Limits:") + + # Create labels with price ranges + if self.config.sell_price_min and self.config.sell_price_max: + s_min = float(self.config.sell_price_min) + s_max = float(self.config.sell_price_max) + sell_label = f"Sell [{s_min:.{price_decimals}f}-{s_max:.{price_decimals}f}]" + else: + sell_label = "Sell" + if self.config.buy_price_min and self.config.buy_price_max: + b_min = float(self.config.buy_price_min) + b_max = float(self.config.buy_price_max) + buy_label = f"Buy [{b_min:.{price_decimals}f}-{b_max:.{price_decimals}f}]" + else: + buy_label = "Buy " + + # Find max label length for alignment + max_label_len = max(len(sell_label), len(buy_label)) + + # Sell range (with position marker) + if self.config.sell_price_min and self.config.sell_price_max: + viz_lines.append(make_range_bar( + self.config.sell_price_min, self.config.sell_price_max, + sell_label.ljust(max_label_len), '═', show_position=True + )) + else: + viz_lines.append(" Sell: No limits set") + + # Buy range (with position marker) + if self.config.buy_price_min and self.config.buy_price_max: + viz_lines.append(make_range_bar( + self.config.buy_price_min, self.config.buy_price_max, + buy_label.ljust(max_label_len), '─', show_position=True + )) + else: + viz_lines.append(" Buy : No limits set") + + # Scale line (aligned with bar start) + min_str = f'{float(scale_min):.{price_decimals}f}' + max_str = f'{float(scale_max):.{price_decimals}f}' + label_padding = max_label_len + 4 # " " prefix + ": " suffix + viz_lines.append(f"{' ' * label_padding}{min_str}{' ' * (bar_width - len(min_str) - len(max_str))}{max_str}") + + return '\n'.join(viz_lines) diff --git a/bots/controllers/generic/multi_grid_strike.py b/bots/controllers/generic/multi_grid_strike.py new file mode 100644 index 00000000..c1b7e683 --- /dev/null +++ b/bots/controllers/generic/multi_grid_strike.py @@ -0,0 +1,291 @@ +from decimal import Decimal +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.grid_executor.data_types import GridExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo + + +class GridConfig(BaseModel): + """Configuration for an individual grid""" + grid_id: str + start_price: Decimal = Field(json_schema_extra={"is_updatable": True}) + end_price: Decimal = Field(json_schema_extra={"is_updatable": True}) + limit_price: Decimal = Field(json_schema_extra={"is_updatable": True}) + side: TradeType = Field(json_schema_extra={"is_updatable": True}) + amount_quote_pct: Decimal = Field(json_schema_extra={"is_updatable": True}) # Percentage of total amount (0.0 to 1.0) + enabled: bool = Field(default=True, json_schema_extra={"is_updatable": True}) + + +class MultiGridStrikeConfig(ControllerConfigBase): + """ + Configuration for MultiGridStrike strategy supporting multiple grids + """ + controller_type: str = "generic" + controller_name: str = "multi_grid_strike" + + # Account configuration + leverage: int = 20 + position_mode: PositionMode = PositionMode.HEDGE + + # Common configuration + connector_name: str = "binance_perpetual" + trading_pair: str = "WLD-USDT" + + # Total capital allocation + total_amount_quote: Decimal = Field(default=Decimal("1000"), json_schema_extra={"is_updatable": True}) + + # Grid configurations + grids: List[GridConfig] = Field(default_factory=list, json_schema_extra={"is_updatable": True}) + + # Common grid parameters + min_spread_between_orders: Optional[Decimal] = Field(default=Decimal("0.001"), json_schema_extra={"is_updatable": True}) + min_order_amount_quote: Optional[Decimal] = Field(default=Decimal("5"), json_schema_extra={"is_updatable": True}) + + # Execution + max_open_orders: int = Field(default=2, json_schema_extra={"is_updatable": True}) + max_orders_per_batch: Optional[int] = Field(default=1, json_schema_extra={"is_updatable": True}) + order_frequency: int = Field(default=3, json_schema_extra={"is_updatable": True}) + activation_bounds: Optional[Decimal] = Field(default=None, json_schema_extra={"is_updatable": True}) + keep_position: bool = Field(default=False, json_schema_extra={"is_updatable": True}) + + # Risk Management + triple_barrier_config: TripleBarrierConfig = TripleBarrierConfig( + take_profit=Decimal("0.001"), + open_order_type=OrderType.LIMIT_MAKER, + take_profit_order_type=OrderType.LIMIT_MAKER, + ) + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class MultiGridStrike(ControllerBase): + def __init__(self, config: MultiGridStrikeConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self._last_config_hash = self._get_config_hash() + self._grid_executor_mapping: Dict[str, str] = {} # grid_id -> executor_id + self.trading_rules = None + self.initialize_rate_sources() + + def initialize_rate_sources(self): + self.market_data_provider.initialize_rate_sources([ConnectorPair(connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair)]) + + def _get_config_hash(self) -> str: + """Generate a hash of the current grid configurations""" + return str(hash(tuple( + (g.grid_id, g.start_price, g.end_price, g.limit_price, g.side, g.amount_quote_pct, g.enabled) + for g in self.config.grids + ))) + + def _has_config_changed(self) -> bool: + """Check if configuration has changed""" + current_hash = self._get_config_hash() + changed = current_hash != self._last_config_hash + if changed: + self._last_config_hash = current_hash + return changed + + def active_executors(self) -> List[ExecutorInfo]: + return [ + executor for executor in self.executors_info + if executor.is_active + ] + + def get_executor_by_grid_id(self, grid_id: str) -> Optional[ExecutorInfo]: + """Get executor associated with a specific grid""" + executor_id = self._grid_executor_mapping.get(grid_id) + if executor_id: + for executor in self.executors_info: + if executor.id == executor_id: + return executor + return None + + def calculate_grid_amount(self, grid: GridConfig) -> Decimal: + """Calculate the actual amount for a grid based on its percentage allocation""" + return self.config.total_amount_quote * grid.amount_quote_pct + + def is_inside_bounds(self, price: Decimal, grid: GridConfig) -> bool: + """Check if price is within grid bounds""" + return grid.start_price <= price <= grid.end_price + + def determine_executor_actions(self) -> List[ExecutorAction]: + actions = [] + mid_price = self.market_data_provider.get_price_by_type( + self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + + # Check for config changes + if self._has_config_changed(): + # Handle removed or disabled grids + current_grid_ids = {g.grid_id for g in self.config.grids if g.enabled} + for grid_id, executor_id in list(self._grid_executor_mapping.items()): + if grid_id not in current_grid_ids: + # Stop executor for removed/disabled grid + actions.append(StopExecutorAction( + controller_id=self.config.id, + executor_id=executor_id + )) + del self._grid_executor_mapping[grid_id] + + # Process each enabled grid + for grid in self.config.grids: + if not grid.enabled: + continue + + executor = self.get_executor_by_grid_id(grid.grid_id) + + # Create new executor if none exists and price is in bounds + if executor is None and self.is_inside_bounds(mid_price, grid): + executor_action = CreateExecutorAction( + controller_id=self.config.id, + executor_config=GridExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + start_price=grid.start_price, + end_price=grid.end_price, + leverage=self.config.leverage, + limit_price=grid.limit_price, + side=grid.side, + total_amount_quote=self.calculate_grid_amount(grid), + min_spread_between_orders=self.config.min_spread_between_orders, + min_order_amount_quote=self.config.min_order_amount_quote, + max_open_orders=self.config.max_open_orders, + max_orders_per_batch=self.config.max_orders_per_batch, + order_frequency=self.config.order_frequency, + activation_bounds=self.config.activation_bounds, + triple_barrier_config=self.config.triple_barrier_config, + level_id=grid.grid_id, # Use grid_id as level_id for identification + keep_position=self.config.keep_position, + )) + actions.append(executor_action) + # Note: We'll update the mapping after executor is created + + # Update executor mapping if needed + if executor is None and len(actions) > 0: + # This will be handled in the next cycle after executor is created + pass + + return actions + + async def update_processed_data(self): + # Update executor mapping for newly created executors + for executor in self.active_executors(): + if hasattr(executor.config, 'level_id') and executor.config.level_id: + self._grid_executor_mapping[executor.config.level_id] = executor.id + + def to_format_status(self) -> List[str]: + status = [] + mid_price = self.market_data_provider.get_price_by_type( + self.config.connector_name, self.config.trading_pair, PriceType.MidPrice) + + # Define standard box width for consistency + box_width = 114 + + # Top Multi-Grid Configuration box + status.append("┌" + "─" * box_width + "┐") + + # Header + header = f"│ Multi-Grid Configuration - {self.config.connector_name} {self.config.trading_pair}" + header += " " * (box_width - len(header) + 1) + "│" + status.append(header) + + # Mid price, grid count, and total amount + active_grids = len([g for g in self.config.grids if g.enabled]) + total_grids = len(self.config.grids) + total_amount = self.config.total_amount_quote + info_line = f"│ Mid Price: {mid_price:.4f} │ Active Grids: {active_grids}/{total_grids} │ Total Amount: {total_amount:.2f} │" + info_line += " " * (box_width - len(info_line) + 1) + "│" + status.append(info_line) + + status.append("└" + "─" * box_width + "┘") + + # Display each grid configuration + for grid in self.config.grids: + if not grid.enabled: + continue + + executor = self.get_executor_by_grid_id(grid.grid_id) + in_bounds = self.is_inside_bounds(mid_price, grid) + + # Grid header + grid_status = "ACTIVE" if executor else ("READY" if in_bounds else "OUT_OF_BOUNDS") + status_header = f"Grid {grid.grid_id}: {grid_status}" + status_line = f"┌ {status_header}" + "─" * (box_width - len(status_header) - 2) + "┐" + status.append(status_line) + + # Grid configuration + grid_amount = self.calculate_grid_amount(grid) + pct_display = f"{grid.amount_quote_pct * 100:.1f}%" + config_line = f"│ Start: {grid.start_price:.4f} │ End: {grid.end_price:.4f} │ Side: {grid.side} │ Limit: {grid.limit_price:.4f} │ Amount: {grid_amount:.2f} ({pct_display}) │" + config_line += " " * (box_width - len(config_line) + 1) + "│" + status.append(config_line) + + if executor: + # Display executor statistics + col_width = box_width // 3 + + # Column headers + header_line = "│ Level Distribution" + " " * (col_width - 20) + "│" + header_line += " Order Statistics" + " " * (col_width - 18) + "│" + header_line += " Performance Metrics" + " " * (col_width - 21) + "│" + status.append(header_line) + + # Data columns + level_dist_data = [ + f"NOT_ACTIVE: {len(executor.custom_info.get('levels_by_state', {}).get('NOT_ACTIVE', []))}", + f"OPEN_ORDER_PLACED: {len(executor.custom_info.get('levels_by_state', {}).get('OPEN_ORDER_PLACED', []))}", + f"OPEN_ORDER_FILLED: {len(executor.custom_info.get('levels_by_state', {}).get('OPEN_ORDER_FILLED', []))}", + f"CLOSE_ORDER_PLACED: {len(executor.custom_info.get('levels_by_state', {}).get('CLOSE_ORDER_PLACED', []))}", + f"COMPLETE: {len(executor.custom_info.get('levels_by_state', {}).get('COMPLETE', []))}" + ] + + order_stats_data = [ + f"Total: {sum(len(executor.custom_info.get(k, [])) for k in ['filled_orders', 'failed_orders', 'canceled_orders'])}", + f"Filled: {len(executor.custom_info.get('filled_orders', []))}", + f"Failed: {len(executor.custom_info.get('failed_orders', []))}", + f"Canceled: {len(executor.custom_info.get('canceled_orders', []))}" + ] + + perf_metrics_data = [ + f"Buy Vol: {executor.custom_info.get('realized_buy_size_quote', 0):.4f}", + f"Sell Vol: {executor.custom_info.get('realized_sell_size_quote', 0):.4f}", + f"R. PnL: {executor.custom_info.get('realized_pnl_quote', 0):.4f}", + f"R. Fees: {executor.custom_info.get('realized_fees_quote', 0):.4f}", + f"P. PnL: {executor.custom_info.get('position_pnl_quote', 0):.4f}", + f"Position: {executor.custom_info.get('position_size_quote', 0):.4f}" + ] + + # Build rows + max_rows = max(len(level_dist_data), len(order_stats_data), len(perf_metrics_data)) + for i in range(max_rows): + col1 = level_dist_data[i] if i < len(level_dist_data) else "" + col2 = order_stats_data[i] if i < len(order_stats_data) else "" + col3 = perf_metrics_data[i] if i < len(perf_metrics_data) else "" + + row = "│ " + col1 + row += " " * (col_width - len(col1) - 2) + row += "│ " + col2 + row += " " * (col_width - len(col2) - 2) + row += "│ " + col3 + row += " " * (col_width - len(col3) - 2) + row += "│" + status.append(row) + + # Liquidity line + status.append("├" + "─" * box_width + "┤") + liquidity_line = f"│ Open Liquidity: {executor.custom_info.get('open_liquidity_placed', 0):.4f} │ Close Liquidity: {executor.custom_info.get('close_liquidity_placed', 0):.4f} │" + liquidity_line += " " * (box_width - len(liquidity_line) + 1) + "│" + status.append(liquidity_line) + + status.append("└" + "─" * box_width + "┘") + + return status diff --git a/bots/controllers/generic/pmm.py b/bots/controllers/generic/pmm.py new file mode 100644 index 00000000..97e55135 --- /dev/null +++ b/bots/controllers/generic/pmm.py @@ -0,0 +1,647 @@ +from decimal import Decimal +from typing import List, Optional, Tuple, Union + +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors import CloseType + + +class PMMConfig(ControllerConfigBase): + """ + This class represents the base configuration for a market making controller. + """ + controller_type: str = "generic" + controller_name: str = "pmm" + candles_config: List[CandlesConfig] = [] + connector_name: str = Field( + default="binance", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the name of the connector to use (e.g., binance):", + } + ) + trading_pair: str = Field( + default="BTC-FDUSD", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the trading pair to trade on (e.g., BTC-FDUSD):", + } + ) + portfolio_allocation: Decimal = Field( + default=Decimal("0.05"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum quote exposure percentage around mid price (e.g., 0.05 for 5% of total quote allocation):", + } + ) + target_base_pct: Decimal = Field( + default=Decimal("0.2"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the target base percentage (e.g., 0.2 for 20%):", + } + ) + min_base_pct: Decimal = Field( + default=Decimal("0.1"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the minimum base percentage (e.g., 0.1 for 10%):", + } + ) + max_base_pct: Decimal = Field( + default=Decimal("0.4"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum base percentage (e.g., 0.4 for 40%):", + } + ) + buy_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy spreads (e.g., '0.01, 0.02'):", + } + ) + sell_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell spreads (e.g., '0.01, 0.02'):", + } + ) + buy_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + sell_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + executor_refresh_time: int = Field( + default=60 * 5, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the refresh time in seconds for executors (e.g., 300 for 5 minutes):", + } + ) + cooldown_time: int = Field( + default=15, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the cooldown time in seconds between after replacing an executor that traded (e.g., 15):", + } + ) + leverage: int = Field( + default=20, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the leverage to use for trading (e.g., 20 for 20x leverage). Set it to 1 for spot trading:", + } + ) + position_mode: PositionMode = Field(default="HEDGE") + take_profit: Optional[Decimal] = Field( + default=Decimal("0.02"), gt=0, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the take profit as a decimal (e.g., 0.02 for 2%):", + } + ) + take_profit_order_type: Optional[OrderType] = Field( + default="LIMIT_MAKER", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the order type for take profit (e.g., LIMIT_MAKER):", + } + ) + max_skew: Decimal = Field( + default=Decimal("1.0"), + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the maximum skew factor (e.g., 1.0):", + } + ) + global_take_profit: Decimal = Decimal("0.02") + global_stop_loss: Decimal = Decimal("0.05") + + @field_validator("take_profit", mode="before") + @classmethod + def validate_target(cls, v): + if isinstance(v, str): + if v == "": + return None + return Decimal(v) + return v + + @field_validator('take_profit_order_type', mode="before") + @classmethod + def validate_order_type(cls, v) -> OrderType: + if isinstance(v, OrderType): + return v + elif v is None: + return OrderType.MARKET + elif isinstance(v, str): + if v.upper() in OrderType.__members__: + return OrderType[v.upper()] + elif isinstance(v, int): + try: + return OrderType(v) + except ValueError: + pass + raise ValueError(f"Invalid order type: {v}. Valid options are: {', '.join(OrderType.__members__)}") + + @field_validator('buy_spreads', 'sell_spreads', mode="before") + @classmethod + def parse_spreads(cls, v): + if v is None: + return [] + if isinstance(v, str): + if v == "": + return [] + return [float(x.strip()) for x in v.split(',')] + return v + + @field_validator('buy_amounts_pct', 'sell_amounts_pct', mode="before") + @classmethod + def parse_and_validate_amounts(cls, v, validation_info: ValidationInfo): + field_name = validation_info.field_name + if v is None or v == "": + spread_field = field_name.replace('amounts_pct', 'spreads') + return [1 for _ in validation_info.data[spread_field]] + if isinstance(v, str): + return [float(x.strip()) for x in v.split(',')] + elif isinstance(v, list) and len(v) != len(validation_info.data[field_name.replace('amounts_pct', 'spreads')]): + raise ValueError( + f"The number of {field_name} must match the number of {field_name.replace('amounts_pct', 'spreads')}.") + return v + + @field_validator('position_mode', mode="before") + @classmethod + def validate_position_mode(cls, v) -> PositionMode: + if isinstance(v, str): + if v.upper() in PositionMode.__members__: + return PositionMode[v.upper()] + raise ValueError(f"Invalid position mode: {v}. Valid options are: {', '.join(PositionMode.__members__)}") + return v + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + return TripleBarrierConfig( + take_profit=self.take_profit, + trailing_stop=None, + open_order_type=OrderType.LIMIT_MAKER, # Defaulting to LIMIT as is a Maker Controller + take_profit_order_type=self.take_profit_order_type, + stop_loss_order_type=OrderType.MARKET, # Defaulting to MARKET as per requirement + time_limit_order_type=OrderType.MARKET # Defaulting to MARKET as per requirement + ) + + def update_parameters(self, trade_type: TradeType, new_spreads: Union[List[float], str], new_amounts_pct: Optional[Union[List[int], str]] = None): + spreads_field = 'buy_spreads' if trade_type == TradeType.BUY else 'sell_spreads' + amounts_pct_field = 'buy_amounts_pct' if trade_type == TradeType.BUY else 'sell_amounts_pct' + + setattr(self, spreads_field, self.parse_spreads(new_spreads)) + if new_amounts_pct is not None: + setattr(self, amounts_pct_field, self.parse_and_validate_amounts(new_amounts_pct, self.__dict__, self.__fields__[amounts_pct_field])) + else: + setattr(self, amounts_pct_field, [1 for _ in getattr(self, spreads_field)]) + + def get_spreads_and_amounts_in_quote(self, trade_type: TradeType) -> Tuple[List[float], List[float]]: + buy_amounts_pct = getattr(self, 'buy_amounts_pct') + sell_amounts_pct = getattr(self, 'sell_amounts_pct') + + # Calculate total percentages across buys and sells + total_pct = sum(buy_amounts_pct) + sum(sell_amounts_pct) + + # Normalize amounts_pct based on total percentages + if trade_type == TradeType.BUY: + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in buy_amounts_pct] + else: # TradeType.SELL + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in sell_amounts_pct] + + spreads = getattr(self, f'{trade_type.name.lower()}_spreads') + return spreads, [amt_pct * self.total_amount_quote * self.portfolio_allocation for amt_pct in normalized_amounts_pct] + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class PMM(ControllerBase): + """ + This class represents the base class for a market making controller. + """ + + def __init__(self, config: PMMConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.market_data_provider.initialize_rate_sources([ConnectorPair( + connector_name=config.connector_name, trading_pair=config.trading_pair)]) + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + Determine actions based on the provided executor handler report. + """ + actions = [] + actions.extend(self.create_actions_proposal()) + actions.extend(self.stop_actions_proposal()) + return actions + + def create_actions_proposal(self) -> List[ExecutorAction]: + """ + Create actions proposal based on the current state of the controller. + """ + create_actions = [] + + # Check if a position reduction executor for TP/SL is already sent + reduction_executor_exists = any( + executor.is_active and + executor.custom_info.get("level_id") == "global_tp_sl" + for executor in self.executors_info + ) + + if (not reduction_executor_exists and + self.processed_data["current_base_pct"] > self.config.target_base_pct and + (self.processed_data["unrealized_pnl_pct"] > self.config.global_take_profit or + self.processed_data["unrealized_pnl_pct"] < -self.config.global_stop_loss)): + + # Create a global take profit or stop loss executor + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=TradeType.SELL, + amount=self.processed_data["position_amount"], + execution_strategy=ExecutionStrategy.MARKET, + price=self.processed_data["reference_price"], + level_id="global_tp_sl" # Use a specific level_id to identify this as a TP/SL executor + ) + )) + return create_actions + levels_to_execute = self.get_levels_to_execute() + # Pre-calculate all spreads and amounts for buy and sell sides + buy_spreads, buy_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.BUY) + sell_spreads, sell_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.SELL) + reference_price = Decimal(self.processed_data["reference_price"]) + # Get current position info for skew calculation + current_pct = self.processed_data["current_base_pct"] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + # Calculate skew factors (0 to 1) - how much to scale orders + if max_pct > min_pct: # Prevent division by zero + # For buys: full size at min_pct, decreasing as we approach max_pct + buy_skew = (max_pct - current_pct) / (max_pct - min_pct) + # For sells: full size at max_pct, decreasing as we approach min_pct + sell_skew = (current_pct - min_pct) / (max_pct - min_pct) + # Ensure values stay between 0.2 and 1.0 (never go below 20% of original size) + buy_skew = max(min(buy_skew, Decimal("1.0")), self.config.max_skew) + sell_skew = max(min(sell_skew, Decimal("1.0")), self.config.max_skew) + else: + buy_skew = sell_skew = Decimal("1.0") + # Create executors for each level + for level_id in levels_to_execute: + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + if trade_type == TradeType.BUY: + spread_in_pct = Decimal(buy_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(buy_amounts_quote[level]) + skew = buy_skew + else: # TradeType.SELL + spread_in_pct = Decimal(sell_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(sell_amounts_quote[level]) + skew = sell_skew + # Calculate price + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + # Calculate amount with skew applied + amount = self.market_data_provider.quantize_order_amount(self.config.connector_name, + self.config.trading_pair, + (amount_quote / price) * skew) + if amount == Decimal("0"): + self.logger().warning(f"The amount of the level {level_id} is 0. Skipping.") + executor_config = self.get_executor_config(level_id, price, amount) + if executor_config is not None: + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + return create_actions + + def get_levels_to_execute(self) -> List[str]: + working_levels = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active or (x.close_type == CloseType.STOP_LOSS and self.market_data_provider.time() - x.close_timestamp < self.config.cooldown_time) + ) + working_levels_ids = [executor.custom_info["level_id"] for executor in working_levels] + return self.get_not_active_levels_ids(working_levels_ids) + + def stop_actions_proposal(self) -> List[ExecutorAction]: + """ + Create a list of actions to stop the executors based on order refresh and early stop conditions. + """ + stop_actions = [] + stop_actions.extend(self.executors_to_refresh()) + stop_actions.extend(self.executors_to_early_stop()) + return stop_actions + + def executors_to_refresh(self) -> List[ExecutorAction]: + executors_to_refresh = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: not x.is_trading and x.is_active and self.market_data_provider.time() - x.timestamp > self.config.executor_refresh_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_refresh] + + def executors_to_early_stop(self) -> List[ExecutorAction]: + """ + Get the executors to early stop based on the current state of market data. This method can be overridden to + implement custom behavior. + """ + executors_to_early_stop = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active and x.is_trading and self.market_data_provider.time() - x.custom_info["open_order_last_update"] > self.config.cooldown_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_early_stop] + + async def update_processed_data(self): + """ + Update the processed data for the controller. This method should be reimplemented to modify the reference price + and spread multiplier based on the market data. By default, it will update the reference price as mid price and + the spread multiplier as 1. + """ + reference_price = self.market_data_provider.get_price_by_type(self.config.connector_name, + self.config.trading_pair, PriceType.MidPrice) + position_held = next((position for position in self.positions_held if + (position.trading_pair == self.config.trading_pair) & + (position.connector_name == self.config.connector_name)), None) + target_position = self.config.total_amount_quote * self.config.target_base_pct + if position_held is not None: + position_amount = position_held.amount + current_base_pct = position_held.amount_quote / self.config.total_amount_quote + deviation = (target_position - position_held.amount_quote) / target_position + unrealized_pnl_pct = position_held.unrealized_pnl_quote / position_held.amount_quote if position_held.amount_quote != 0 else Decimal("0") + else: + position_amount = 0 + current_base_pct = 0 + deviation = 1 + unrealized_pnl_pct = 0 + + self.processed_data = {"reference_price": Decimal(reference_price), "spread_multiplier": Decimal("1"), + "deviation": deviation, "current_base_pct": current_base_pct, + "unrealized_pnl_pct": unrealized_pnl_pct, "position_amount": position_amount} + + def get_executor_config(self, level_id: str, price: Decimal, amount: Decimal): + """ + Get the executor config for a given level id. + """ + trade_type = self.get_trade_type_from_level_id(level_id) + level_multiplier = self.get_level_from_level_id(level_id) + 1 + return PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + level_id=level_id, + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + entry_price=price, + amount=amount, + triple_barrier_config=self.config.triple_barrier_config.new_instance_with_adjusted_volatility(level_multiplier), + leverage=self.config.leverage, + side=trade_type, + ) + + def get_level_id_from_side(self, trade_type: TradeType, level: int) -> str: + """ + Get the level id based on the trade type and the level. + """ + return f"{trade_type.name.lower()}_{level}" + + def get_trade_type_from_level_id(self, level_id: str) -> TradeType: + return TradeType.BUY if level_id.startswith("buy") else TradeType.SELL + + def get_level_from_level_id(self, level_id: str) -> int: + return int(level_id.split('_')[1]) + + def get_not_active_levels_ids(self, active_levels_ids: List[str]) -> List[str]: + """ + Get the levels to execute based on the current state of the controller. + """ + buy_ids_missing = [self.get_level_id_from_side(TradeType.BUY, level) for level in range(len(self.config.buy_spreads)) + if self.get_level_id_from_side(TradeType.BUY, level) not in active_levels_ids] + sell_ids_missing = [self.get_level_id_from_side(TradeType.SELL, level) for level in range(len(self.config.sell_spreads)) + if self.get_level_id_from_side(TradeType.SELL, level) not in active_levels_ids] + if self.processed_data["current_base_pct"] < self.config.min_base_pct: + return buy_ids_missing + elif self.processed_data["current_base_pct"] > self.config.max_base_pct: + return sell_ids_missing + return buy_ids_missing + sell_ids_missing + + def to_format_status(self) -> List[str]: + """ + Get the status of the controller in a formatted way with ASCII visualizations. + """ + from decimal import Decimal + from itertools import zip_longest + + status = [] + + # Get all required data + base_pct = self.processed_data['current_base_pct'] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + target_pct = self.config.target_base_pct + skew = base_pct - target_pct + skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') + max_skew = getattr(self.config, 'max_skew', Decimal('0.0')) + + # Fixed widths - adjusted based on screenshot analysis + outer_width = 92 # Total width including outer borders + inner_width = outer_width - 4 # Inner content width + half_width = (inner_width) // 2 - 1 # Width of each column in split sections + bar_width = inner_width - 15 # Width of visualization bars (accounting for label) + + # Header - omit ID since it's shown above in controller header + status.append("╒" + "═" * (inner_width) + "╕") + + header_line = ( + f"{self.config.connector_name}:{self.config.trading_pair} " + f"Price: {self.processed_data['reference_price']} " + f"Alloc: {self.config.portfolio_allocation:.1%} " + f"Spread Mult: {self.processed_data['spread_multiplier']} |" + ) + + status.append(f"│ {header_line:<{inner_width}} │") + + # Position and PnL sections with precise widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'POSITION STATUS':<{half_width - 2}} │ {'PROFIT & LOSS':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Position data for left column + position_info = [ + f"Current: {base_pct:.2%}", + f"Target: {target_pct:.2%}", + f"Min/Max: {min_pct:.2%}/{max_pct:.2%}", + f"Skew: {skew_pct:+.2%} (max {max_skew:.2%})" + ] + + # PnL data for right column + pnl_info = [] + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + pnl_sign = "+" if pnl >= 0 else "" + pnl_info = [ + f"Unrealized: {pnl_sign}{pnl:.2%}", + f"Take Profit: {self.config.global_take_profit:.2%}", + f"Stop Loss: {-self.config.global_stop_loss:.2%}", + f"Leverage: {self.config.leverage}x" + ] + + # Display position and PnL info side by side with exact spacing + for pos_line, pnl_line in zip_longest(position_info, pnl_info, fillvalue=""): + status.append(f"│ {pos_line:<{half_width - 2}} │ {pnl_line:<{half_width - 2}} │") + + # Adjust visualization section - ensure consistent spacing + status.append(f"├{'─' * (inner_width)}┤") + status.append(f"│ {'VISUALIZATIONS':<{inner_width}} │") + status.append(f"├{'─' * (inner_width)}┤") + + # Position bar with exact spacing and characters + filled_width = int(base_pct * bar_width) + min_pos = int(min_pct * bar_width) + max_pos = int(max_pct * bar_width) + target_pos = int(target_pct * bar_width) + + # Build position bar character by character + position_bar = "" + for i in range(bar_width): + if i == filled_width: + position_bar += "◆" # Current position + elif i == min_pos: + position_bar += "┃" # Min threshold + elif i == max_pos: + position_bar += "┃" # Max threshold + elif i == target_pos: + position_bar += "┇" # Target threshold + elif i < filled_width: + position_bar += "█" # Filled area + else: + position_bar += "░" # Empty area + + # Ensure consistent label spacing as seen in screenshot + status.append(f"│ Position: [{position_bar}] │") + + # Skew visualization with exact spacing + skew_bar_width = bar_width + center = skew_bar_width // 2 + skew_pos = center + int(skew_pct * center * 2) + skew_pos = max(0, min(skew_bar_width - 1, skew_pos)) + + # Build skew bar character by character + skew_bar = "" + for i in range(skew_bar_width): + if i == center: + skew_bar += "┃" # Center line + elif i == skew_pos: + skew_bar += "⬤" # Current skew + else: + skew_bar += "─" # Empty line + + # Match spacing from screenshot with exact character counts + status.append(f"│ Skew: [{skew_bar}] │") + + # PnL visualization if available + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + take_profit = self.config.global_take_profit + stop_loss = -self.config.global_stop_loss + + pnl_bar_width = bar_width + center = pnl_bar_width // 2 + + # Calculate positions with exact scaling + max_range = max(abs(take_profit), abs(stop_loss), abs(pnl)) * Decimal("1.2") + scale = (pnl_bar_width // 2) / max_range + + pnl_pos = center + int(pnl * scale) + take_profit_pos = center + int(take_profit * scale) + stop_loss_pos = center + int(stop_loss * scale) + + # Ensure positions are within bounds + pnl_pos = max(0, min(pnl_bar_width - 1, pnl_pos)) + take_profit_pos = max(0, min(pnl_bar_width - 1, take_profit_pos)) + stop_loss_pos = max(0, min(pnl_bar_width - 1, stop_loss_pos)) + + # Build PnL bar character by character + pnl_bar = "" + for i in range(pnl_bar_width): + if i == center: + pnl_bar += "│" # Center line + elif i == pnl_pos: + pnl_bar += "⬤" # Current PnL + elif i == take_profit_pos: + pnl_bar += "T" # Take profit line + elif i == stop_loss_pos: + pnl_bar += "S" # Stop loss line + elif (pnl >= 0 and center <= i < pnl_pos) or (pnl < 0 and pnl_pos < i <= center): + pnl_bar += "█" if pnl >= 0 else "▓" + else: + pnl_bar += "─" + + # Match spacing from screenshot + status.append(f"│ PnL: [{pnl_bar}] │") + + # Executors section with precise column widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'EXECUTORS STATUS':<{half_width - 2}} │ {'EXECUTOR VISUALIZATION':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Count active executors by type + active_buy = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.BUY) + active_sell = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.SELL) + total_active = sum(1 for info in self.executors_info if info.is_active) + + # Executor information with fixed formatting + executor_info = [ + f"Total Active: {total_active}", + f"Total Created: {len(self.executors_info)}", + f"Buy Executors: {active_buy}", + f"Sell Executors: {active_sell}" + ] + + if 'deviation' in self.processed_data: + executor_info.append(f"Target Deviation: {self.processed_data['deviation']:.4f}") + + # Visualization with consistent block characters for buy/sell representation + buy_bars = "▮" * active_buy if active_buy > 0 else "─" + sell_bars = "▮" * active_sell if active_sell > 0 else "─" + + executor_viz = [ + f"Buy: {buy_bars}", + f"Sell: {sell_bars}" + ] + + # Display with fixed width columns + for exec_line, viz_line in zip_longest(executor_info, executor_viz, fillvalue=""): + status.append(f"│ {exec_line:<{half_width - 2}} │ {viz_line:<{half_width - 2}} │") + + # Bottom border with exact width + status.append(f"╘{'═' * (inner_width)}╛") + + return status diff --git a/bots/controllers/generic/pmm_adjusted.py b/bots/controllers/generic/pmm_adjusted.py new file mode 100644 index 00000000..e9bc2667 --- /dev/null +++ b/bots/controllers/generic/pmm_adjusted.py @@ -0,0 +1,669 @@ +from decimal import Decimal +from typing import List, Optional, Tuple, Union + +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors import CloseType + + +class PMMAdjustedConfig(ControllerConfigBase): + """ + This class represents the base configuration for a market making controller. + """ + controller_type: str = "generic" + controller_name: str = "pmm_adjusted" + candles_config: List[CandlesConfig] = [] + connector_name: str = Field( + default="binance", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the name of the connector to use (e.g., binance):", + } + ) + trading_pair: str = Field( + default="BTC-FDUSD", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the trading pair to trade on (e.g., BTC-FDUSD):", + } + ) + candles_connector_name: str = Field(default="binance") + candles_trading_pair: str = Field(default="BTC-USDT") + candles_interval: str = Field(default="1s") + + portfolio_allocation: Decimal = Field( + default=Decimal("0.05"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum quote exposure percentage around mid price (e.g., 0.05 for 5% of total quote allocation):", + } + ) + target_base_pct: Decimal = Field( + default=Decimal("0.2"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the target base percentage (e.g., 0.2 for 20%):", + } + ) + min_base_pct: Decimal = Field( + default=Decimal("0.1"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the minimum base percentage (e.g., 0.1 for 10%):", + } + ) + max_base_pct: Decimal = Field( + default=Decimal("0.4"), + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the maximum base percentage (e.g., 0.4 for 40%):", + } + ) + buy_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy spreads (e.g., '0.01, 0.02'):", + } + ) + sell_spreads: List[float] = Field( + default="0.01,0.02", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell spreads (e.g., '0.01, 0.02'):", + } + ) + buy_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of buy amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + sell_amounts_pct: Union[List[Decimal], None] = Field( + default=None, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter a comma-separated list of sell amounts as percentages (e.g., '50, 50'), or leave blank to distribute equally:", + } + ) + executor_refresh_time: int = Field( + default=60 * 5, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the refresh time in seconds for executors (e.g., 300 for 5 minutes):", + } + ) + cooldown_time: int = Field( + default=15, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the cooldown time in seconds between after replacing an executor that traded (e.g., 15):", + } + ) + leverage: int = Field( + default=20, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the leverage to use for trading (e.g., 20 for 20x leverage). Set it to 1 for spot trading:", + } + ) + position_mode: PositionMode = Field(default="HEDGE") + take_profit: Optional[Decimal] = Field( + default=Decimal("0.02"), gt=0, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the take profit as a decimal (e.g., 0.02 for 2%):", + } + ) + take_profit_order_type: Optional[OrderType] = Field( + default="LIMIT_MAKER", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the order type for take profit (e.g., LIMIT_MAKER):", + } + ) + max_skew: Decimal = Field( + default=Decimal("1.0"), + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the maximum skew factor (e.g., 1.0):", + } + ) + global_take_profit: Decimal = Decimal("0.02") + global_stop_loss: Decimal = Decimal("0.05") + + @field_validator("take_profit", mode="before") + @classmethod + def validate_target(cls, v): + if isinstance(v, str): + if v == "": + return None + return Decimal(v) + return v + + @field_validator('take_profit_order_type', mode="before") + @classmethod + def validate_order_type(cls, v) -> OrderType: + if isinstance(v, OrderType): + return v + elif v is None: + return OrderType.MARKET + elif isinstance(v, str): + if v.upper() in OrderType.__members__: + return OrderType[v.upper()] + elif isinstance(v, int): + try: + return OrderType(v) + except ValueError: + pass + raise ValueError(f"Invalid order type: {v}. Valid options are: {', '.join(OrderType.__members__)}") + + @field_validator('buy_spreads', 'sell_spreads', mode="before") + @classmethod + def parse_spreads(cls, v): + if v is None: + return [] + if isinstance(v, str): + if v == "": + return [] + return [float(x.strip()) for x in v.split(',')] + return v + + @field_validator('buy_amounts_pct', 'sell_amounts_pct', mode="before") + @classmethod + def parse_and_validate_amounts(cls, v, validation_info: ValidationInfo): + field_name = validation_info.field_name + if v is None or v == "": + spread_field = field_name.replace('amounts_pct', 'spreads') + return [1 for _ in validation_info.data[spread_field]] + if isinstance(v, str): + return [float(x.strip()) for x in v.split(',')] + elif isinstance(v, list) and len(v) != len(validation_info.data[field_name.replace('amounts_pct', 'spreads')]): + raise ValueError( + f"The number of {field_name} must match the number of {field_name.replace('amounts_pct', 'spreads')}.") + return v + + @field_validator('position_mode', mode="before") + @classmethod + def validate_position_mode(cls, v) -> PositionMode: + if isinstance(v, str): + if v.upper() in PositionMode.__members__: + return PositionMode[v.upper()] + raise ValueError(f"Invalid position mode: {v}. Valid options are: {', '.join(PositionMode.__members__)}") + return v + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + return TripleBarrierConfig( + take_profit=self.take_profit, + trailing_stop=None, + open_order_type=OrderType.LIMIT_MAKER, # Defaulting to LIMIT as is a Maker Controller + take_profit_order_type=self.take_profit_order_type, + stop_loss_order_type=OrderType.MARKET, # Defaulting to MARKET as per requirement + time_limit_order_type=OrderType.MARKET # Defaulting to MARKET as per requirement + ) + + def update_parameters(self, trade_type: TradeType, new_spreads: Union[List[float], str], new_amounts_pct: Optional[Union[List[int], str]] = None): + spreads_field = 'buy_spreads' if trade_type == TradeType.BUY else 'sell_spreads' + amounts_pct_field = 'buy_amounts_pct' if trade_type == TradeType.BUY else 'sell_amounts_pct' + + setattr(self, spreads_field, self.parse_spreads(new_spreads)) + if new_amounts_pct is not None: + setattr(self, amounts_pct_field, self.parse_and_validate_amounts(new_amounts_pct, self.__dict__, self.__fields__[amounts_pct_field])) + else: + setattr(self, amounts_pct_field, [1 for _ in getattr(self, spreads_field)]) + + def get_spreads_and_amounts_in_quote(self, trade_type: TradeType) -> Tuple[List[float], List[float]]: + buy_amounts_pct = getattr(self, 'buy_amounts_pct') + sell_amounts_pct = getattr(self, 'sell_amounts_pct') + + # Calculate total percentages across buys and sells + total_pct = sum(buy_amounts_pct) + sum(sell_amounts_pct) + + # Normalize amounts_pct based on total percentages + if trade_type == TradeType.BUY: + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in buy_amounts_pct] + else: # TradeType.SELL + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in sell_amounts_pct] + + spreads = getattr(self, f'{trade_type.name.lower()}_spreads') + return spreads, [amt_pct * self.total_amount_quote * self.portfolio_allocation for amt_pct in normalized_amounts_pct] + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class PMMAdjusted(ControllerBase): + """ + This class represents the base class for a market making controller. + """ + + def __init__(self, config: PMMAdjustedConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.market_data_provider.initialize_rate_sources([ConnectorPair( + connector_name=config.connector_name, trading_pair=config.trading_pair)]) + self.config.candles_config = [ + CandlesConfig(connector=self.config.candles_connector_name, + trading_pair=self.config.candles_trading_pair, + interval=self.config.candles_interval) + ] + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + Determine actions based on the provided executor handler report. + """ + actions = [] + actions.extend(self.create_actions_proposal()) + actions.extend(self.stop_actions_proposal()) + return actions + + def create_actions_proposal(self) -> List[ExecutorAction]: + """ + Create actions proposal based on the current state of the controller. + """ + create_actions = [] + + # Check if a position reduction executor for TP/SL is already sent + reduction_executor_exists = any( + executor.is_active and + executor.custom_info.get("level_id") == "global_tp_sl" + for executor in self.executors_info + ) + + if (not reduction_executor_exists and + self.processed_data["current_base_pct"] > self.config.target_base_pct and + (self.processed_data["unrealized_pnl_pct"] > self.config.global_take_profit or + self.processed_data["unrealized_pnl_pct"] < -self.config.global_stop_loss)): + + # Create a global take profit or stop loss executor + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=TradeType.SELL, + amount=self.processed_data["position_amount"], + execution_strategy=ExecutionStrategy.MARKET, + price=self.processed_data["reference_price"], + level_id="global_tp_sl" # Use a specific level_id to identify this as a TP/SL executor + ) + )) + return create_actions + levels_to_execute = self.get_levels_to_execute() + # Pre-calculate all spreads and amounts for buy and sell sides + buy_spreads, buy_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.BUY) + sell_spreads, sell_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.SELL) + reference_price = Decimal(self.processed_data["reference_price"]) + # Get current position info for skew calculation + current_pct = self.processed_data["current_base_pct"] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + # Calculate skew factors (0 to 1) - how much to scale orders + if max_pct > min_pct: # Prevent division by zero + # For buys: full size at min_pct, decreasing as we approach max_pct + buy_skew = (max_pct - current_pct) / (max_pct - min_pct) + # For sells: full size at max_pct, decreasing as we approach min_pct + sell_skew = (current_pct - min_pct) / (max_pct - min_pct) + # Ensure values stay between 0.2 and 1.0 (never go below 20% of original size) + buy_skew = max(min(buy_skew, Decimal("1.0")), self.config.max_skew) + sell_skew = max(min(sell_skew, Decimal("1.0")), self.config.max_skew) + else: + buy_skew = sell_skew = Decimal("1.0") + # Create executors for each level + for level_id in levels_to_execute: + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + if trade_type == TradeType.BUY: + spread_in_pct = Decimal(buy_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(buy_amounts_quote[level]) + skew = buy_skew + else: # TradeType.SELL + spread_in_pct = Decimal(sell_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(sell_amounts_quote[level]) + skew = sell_skew + # Calculate price + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + # Calculate amount with skew applied + amount = self.market_data_provider.quantize_order_amount(self.config.connector_name, + self.config.trading_pair, + (amount_quote / price) * skew) + if amount == Decimal("0"): + self.logger().warning(f"The amount of the level {level_id} is 0. Skipping.") + executor_config = self.get_executor_config(level_id, price, amount) + if executor_config is not None: + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + return create_actions + + def get_levels_to_execute(self) -> List[str]: + working_levels = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active or (x.close_type == CloseType.STOP_LOSS and self.market_data_provider.time() - x.close_timestamp < self.config.cooldown_time) + ) + working_levels_ids = [executor.custom_info["level_id"] for executor in working_levels] + return self.get_not_active_levels_ids(working_levels_ids) + + def stop_actions_proposal(self) -> List[ExecutorAction]: + """ + Create a list of actions to stop the executors based on order refresh and early stop conditions. + """ + stop_actions = [] + stop_actions.extend(self.executors_to_refresh()) + stop_actions.extend(self.executors_to_early_stop()) + return stop_actions + + def executors_to_refresh(self) -> List[ExecutorAction]: + executors_to_refresh = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: not x.is_trading and x.is_active and self.market_data_provider.time() - x.timestamp > self.config.executor_refresh_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_refresh] + + def executors_to_early_stop(self) -> List[ExecutorAction]: + """ + Get the executors to early stop based on the current state of market data. This method can be overridden to + implement custom behavior. + """ + executors_to_early_stop = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active and x.is_trading and self.market_data_provider.time() - x.custom_info["open_order_last_update"] > self.config.cooldown_time) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id) for executor in executors_to_early_stop] + + async def update_processed_data(self): + """ + Update the processed data for the controller. This method should be reimplemented to modify the reference price + and spread multiplier based on the market data. By default, it will update the reference price as mid price and + the spread multiplier as 1. + """ + reference_price = self.get_current_candles_price() + position_held = next((position for position in self.positions_held if + (position.trading_pair == self.config.trading_pair) & + (position.connector_name == self.config.connector_name)), None) + target_position = self.config.total_amount_quote * self.config.target_base_pct + if position_held is not None: + position_amount = position_held.amount + current_base_pct = position_held.amount_quote / self.config.total_amount_quote + deviation = (target_position - position_held.amount_quote) / target_position + unrealized_pnl_pct = position_held.unrealized_pnl_quote / position_held.amount_quote if position_held.amount_quote != 0 else Decimal("0") + else: + position_amount = 0 + current_base_pct = 0 + deviation = 1 + unrealized_pnl_pct = 0 + + self.processed_data = {"reference_price": Decimal(reference_price), "spread_multiplier": Decimal("1"), + "deviation": deviation, "current_base_pct": current_base_pct, + "unrealized_pnl_pct": unrealized_pnl_pct, "position_amount": position_amount} + + def get_current_candles_price(self) -> Decimal: + """ + Get the current price from the candles data provider. + """ + candles = self.market_data_provider.get_candles_df(self.config.candles_connector_name, + self.config.candles_trading_pair, + self.config.candles_interval) + if candles is not None and not candles.empty: + last_candle = candles.iloc[-1] + return Decimal(last_candle['close']) + else: + self.logger().warning(f"No candles data available for {self.config.candles_connector_name} - {self.config.candles_trading_pair} at {self.config.candles_interval}. Using last known price.") + return Decimal(self.market_data_provider.get_price_by_type(self.config.connector_name, self.config.trading_pair, PriceType.MidPrice)) + + def get_executor_config(self, level_id: str, price: Decimal, amount: Decimal): + """ + Get the executor config for a given level id. + """ + trade_type = self.get_trade_type_from_level_id(level_id) + level_multiplier = self.get_level_from_level_id(level_id) + 1 + return PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + level_id=level_id, + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + entry_price=price, + amount=amount, + triple_barrier_config=self.config.triple_barrier_config.new_instance_with_adjusted_volatility(level_multiplier), + leverage=self.config.leverage, + side=trade_type, + ) + + def get_level_id_from_side(self, trade_type: TradeType, level: int) -> str: + """ + Get the level id based on the trade type and the level. + """ + return f"{trade_type.name.lower()}_{level}" + + def get_trade_type_from_level_id(self, level_id: str) -> TradeType: + return TradeType.BUY if level_id.startswith("buy") else TradeType.SELL + + def get_level_from_level_id(self, level_id: str) -> int: + return int(level_id.split('_')[1]) + + def get_not_active_levels_ids(self, active_levels_ids: List[str]) -> List[str]: + """ + Get the levels to execute based on the current state of the controller. + """ + buy_ids_missing = [self.get_level_id_from_side(TradeType.BUY, level) for level in range(len(self.config.buy_spreads)) + if self.get_level_id_from_side(TradeType.BUY, level) not in active_levels_ids] + sell_ids_missing = [self.get_level_id_from_side(TradeType.SELL, level) for level in range(len(self.config.sell_spreads)) + if self.get_level_id_from_side(TradeType.SELL, level) not in active_levels_ids] + if self.processed_data["current_base_pct"] < self.config.min_base_pct: + return buy_ids_missing + elif self.processed_data["current_base_pct"] > self.config.max_base_pct: + return sell_ids_missing + return buy_ids_missing + sell_ids_missing + + def to_format_status(self) -> List[str]: + """ + Get the status of the controller in a formatted way with ASCII visualizations. + """ + from decimal import Decimal + from itertools import zip_longest + + status = [] + + # Get all required data + base_pct = self.processed_data['current_base_pct'] + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + target_pct = self.config.target_base_pct + skew = base_pct - target_pct + skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') + max_skew = getattr(self.config, 'max_skew', Decimal('0.0')) + + # Fixed widths - adjusted based on screenshot analysis + outer_width = 92 # Total width including outer borders + inner_width = outer_width - 4 # Inner content width + half_width = (inner_width) // 2 - 1 # Width of each column in split sections + bar_width = inner_width - 15 # Width of visualization bars (accounting for label) + + # Header - omit ID since it's shown above in controller header + status.append("╒" + "═" * (inner_width) + "╕") + + header_line = ( + f"{self.config.connector_name}:{self.config.trading_pair} " + f"Price: {self.processed_data['reference_price']} " + f"Alloc: {self.config.portfolio_allocation:.1%} " + f"Spread Mult: {self.processed_data['spread_multiplier']} |" + ) + + status.append(f"│ {header_line:<{inner_width}} │") + + # Position and PnL sections with precise widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'POSITION STATUS':<{half_width - 2}} │ {'PROFIT & LOSS':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Position data for left column + position_info = [ + f"Current: {base_pct:.2%}", + f"Target: {target_pct:.2%}", + f"Min/Max: {min_pct:.2%}/{max_pct:.2%}", + f"Skew: {skew_pct:+.2%} (max {max_skew:.2%})" + ] + + # PnL data for right column + pnl_info = [] + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + pnl_sign = "+" if pnl >= 0 else "" + pnl_info = [ + f"Unrealized: {pnl_sign}{pnl:.2%}", + f"Take Profit: {self.config.global_take_profit:.2%}", + f"Stop Loss: {-self.config.global_stop_loss:.2%}", + f"Leverage: {self.config.leverage}x" + ] + + # Display position and PnL info side by side with exact spacing + for pos_line, pnl_line in zip_longest(position_info, pnl_info, fillvalue=""): + status.append(f"│ {pos_line:<{half_width - 2}} │ {pnl_line:<{half_width - 2}} │") + + # Adjust visualization section - ensure consistent spacing + status.append(f"├{'─' * (inner_width)}┤") + status.append(f"│ {'VISUALIZATIONS':<{inner_width}} │") + status.append(f"├{'─' * (inner_width)}┤") + + # Position bar with exact spacing and characters + filled_width = int(base_pct * bar_width) + min_pos = int(min_pct * bar_width) + max_pos = int(max_pct * bar_width) + target_pos = int(target_pct * bar_width) + + # Build position bar character by character + position_bar = "" + for i in range(bar_width): + if i == filled_width: + position_bar += "◆" # Current position + elif i == min_pos: + position_bar += "┃" # Min threshold + elif i == max_pos: + position_bar += "┃" # Max threshold + elif i == target_pos: + position_bar += "┇" # Target threshold + elif i < filled_width: + position_bar += "█" # Filled area + else: + position_bar += "░" # Empty area + + # Ensure consistent label spacing as seen in screenshot + status.append(f"│ Position: [{position_bar}] │") + + # Skew visualization with exact spacing + skew_bar_width = bar_width + center = skew_bar_width // 2 + skew_pos = center + int(skew_pct * center * 2) + skew_pos = max(0, min(skew_bar_width - 1, skew_pos)) + + # Build skew bar character by character + skew_bar = "" + for i in range(skew_bar_width): + if i == center: + skew_bar += "┃" # Center line + elif i == skew_pos: + skew_bar += "⬤" # Current skew + else: + skew_bar += "─" # Empty line + + # Match spacing from screenshot with exact character counts + status.append(f"│ Skew: [{skew_bar}] │") + + # PnL visualization if available + if 'unrealized_pnl_pct' in self.processed_data: + pnl = self.processed_data['unrealized_pnl_pct'] + take_profit = self.config.global_take_profit + stop_loss = -self.config.global_stop_loss + + pnl_bar_width = bar_width + center = pnl_bar_width // 2 + + # Calculate positions with exact scaling + max_range = max(abs(take_profit), abs(stop_loss), abs(pnl)) * Decimal("1.2") + scale = (pnl_bar_width // 2) / max_range + + pnl_pos = center + int(pnl * scale) + take_profit_pos = center + int(take_profit * scale) + stop_loss_pos = center + int(stop_loss * scale) + + # Ensure positions are within bounds + pnl_pos = max(0, min(pnl_bar_width - 1, pnl_pos)) + take_profit_pos = max(0, min(pnl_bar_width - 1, take_profit_pos)) + stop_loss_pos = max(0, min(pnl_bar_width - 1, stop_loss_pos)) + + # Build PnL bar character by character + pnl_bar = "" + for i in range(pnl_bar_width): + if i == center: + pnl_bar += "│" # Center line + elif i == pnl_pos: + pnl_bar += "⬤" # Current PnL + elif i == take_profit_pos: + pnl_bar += "T" # Take profit line + elif i == stop_loss_pos: + pnl_bar += "S" # Stop loss line + elif (pnl >= 0 and center <= i < pnl_pos) or (pnl < 0 and pnl_pos < i <= center): + pnl_bar += "█" if pnl >= 0 else "▓" + else: + pnl_bar += "─" + + # Match spacing from screenshot + status.append(f"│ PnL: [{pnl_bar}] │") + + # Executors section with precise column widths + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'EXECUTORS STATUS':<{half_width - 2}} │ {'EXECUTOR VISUALIZATION':<{half_width - 2}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Count active executors by type + active_buy = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.BUY) + active_sell = sum(1 for info in self.executors_info + if info.is_active and self.get_trade_type_from_level_id(info.custom_info["level_id"]) == TradeType.SELL) + total_active = sum(1 for info in self.executors_info if info.is_active) + + # Executor information with fixed formatting + executor_info = [ + f"Total Active: {total_active}", + f"Total Created: {len(self.executors_info)}", + f"Buy Executors: {active_buy}", + f"Sell Executors: {active_sell}" + ] + + if 'deviation' in self.processed_data: + executor_info.append(f"Target Deviation: {self.processed_data['deviation']:.4f}") + + # Visualization with consistent block characters for buy/sell representation + buy_bars = "▮" * active_buy if active_buy > 0 else "─" + sell_bars = "▮" * active_sell if active_sell > 0 else "─" + + executor_viz = [ + f"Buy: {buy_bars}", + f"Sell: {sell_bars}" + ] + + # Display with fixed width columns + for exec_line, viz_line in zip_longest(executor_info, executor_viz, fillvalue=""): + status.append(f"│ {exec_line:<{half_width - 2}} │ {viz_line:<{half_width - 2}} │") + + # Bottom border with exact width + status.append(f"╘{'═' * (inner_width)}╛") + + return status diff --git a/bots/controllers/generic/pmm_mister.py b/bots/controllers/generic/pmm_mister.py new file mode 100644 index 00000000..03aef7f5 --- /dev/null +++ b/bots/controllers/generic/pmm_mister.py @@ -0,0 +1,1586 @@ +from decimal import Decimal +from typing import Dict, List, Optional, Set, Tuple, Union + +from hummingbot.core.data_type.common import MarketDict, OrderType, PositionMode, PriceType, TradeType +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.utils.common import parse_comma_separated_list, parse_enum_value +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + + +class PMMisterConfig(ControllerConfigBase): + """ + Advanced PMM (Pure Market Making) controller with sophisticated position management. + Features hanging executors, price distance requirements, and breakeven awareness. + """ + controller_type: str = "generic" + controller_name: str = "pmm_mister" + connector_name: str = Field(default="binance") + trading_pair: str = Field(default="BTC-FDUSD") + portfolio_allocation: Decimal = Field(default=Decimal("0.1"), json_schema_extra={"is_updatable": True}) + target_base_pct: Decimal = Field(default=Decimal("0.5"), json_schema_extra={"is_updatable": True}) + min_base_pct: Decimal = Field(default=Decimal("0.3"), json_schema_extra={"is_updatable": True}) + max_base_pct: Decimal = Field(default=Decimal("0.7"), json_schema_extra={"is_updatable": True}) + buy_spreads: List[float] = Field(default="0.0005", json_schema_extra={"is_updatable": True}) + sell_spreads: List[float] = Field(default="0.0005", json_schema_extra={"is_updatable": True}) + buy_amounts_pct: Union[List[Decimal], None] = Field(default="1", json_schema_extra={"is_updatable": True}) + sell_amounts_pct: Union[List[Decimal], None] = Field(default="1", json_schema_extra={"is_updatable": True}) + executor_refresh_time: int = Field(default=30, json_schema_extra={"is_updatable": True}) + + # Enhanced timing parameters + buy_cooldown_time: int = Field(default=60, json_schema_extra={"is_updatable": True}) + sell_cooldown_time: int = Field(default=60, json_schema_extra={"is_updatable": True}) + buy_position_effectivization_time: int = Field(default=120, json_schema_extra={"is_updatable": True}) + sell_position_effectivization_time: int = Field(default=120, json_schema_extra={"is_updatable": True}) + + # Price distance tolerance - prevents placing new orders when existing ones are too close to current price + price_distance_tolerance: Decimal = Field(default=Decimal("0.0005"), json_schema_extra={"is_updatable": True}) + # Refresh tolerance - triggers replacing open orders when price deviates from theoretical level + refresh_tolerance: Decimal = Field(default=Decimal("0.0005"), json_schema_extra={"is_updatable": True}) + tolerance_scaling: Decimal = Field(default=Decimal("1.2"), json_schema_extra={"is_updatable": True}) + + leverage: int = Field(default=20, json_schema_extra={"is_updatable": True}) + position_mode: PositionMode = Field(default="ONEWAY") + take_profit: Optional[Decimal] = Field(default=Decimal("0.0001"), gt=0, json_schema_extra={"is_updatable": True}) + take_profit_order_type: Optional[OrderType] = Field(default="LIMIT_MAKER", json_schema_extra={"is_updatable": True}) + open_order_type: Optional[OrderType] = Field(default="LIMIT_MAKER", json_schema_extra={"is_updatable": True}) + max_active_executors_by_level: Optional[int] = Field(default=4, json_schema_extra={"is_updatable": True}) + tick_mode: bool = Field(default=False, json_schema_extra={"is_updatable": True}) + position_profit_protection: bool = Field(default=False, json_schema_extra={"is_updatable": True}) + min_skew: Decimal = Field(default=Decimal("1.0"), json_schema_extra={"is_updatable": True}) + global_take_profit: Decimal = Field(default=Decimal("0.03"), json_schema_extra={"is_updatable": True}) + global_stop_loss: Decimal = Field(default=Decimal("0.05"), json_schema_extra={"is_updatable": True}) + + @field_validator("take_profit", mode="before") + @classmethod + def validate_target(cls, v): + if isinstance(v, str): + if v == "": + return None + return Decimal(v) + return v + + @field_validator('take_profit_order_type', mode="before") + @classmethod + def validate_order_type(cls, v) -> OrderType: + if v is None: + return OrderType.MARKET + return parse_enum_value(OrderType, v, "take_profit_order_type") + + @field_validator('open_order_type', mode="before") + @classmethod + def validate_open_order_type(cls, v) -> OrderType: + if v is None: + return OrderType.MARKET + return parse_enum_value(OrderType, v, "open_order_type") + + @field_validator('buy_spreads', 'sell_spreads', mode="before") + @classmethod + def parse_spreads(cls, v): + return parse_comma_separated_list(v) + + @field_validator('buy_amounts_pct', 'sell_amounts_pct', mode="before") + @classmethod + def parse_and_validate_amounts(cls, v, validation_info: ValidationInfo): + field_name = validation_info.field_name + if v is None or v == "": + spread_field = field_name.replace('amounts_pct', 'spreads') + return [1 for _ in validation_info.data[spread_field]] + parsed = parse_comma_separated_list(v) + if isinstance(parsed, list) and len(parsed) != len( + validation_info.data[field_name.replace('amounts_pct', 'spreads')]): + raise ValueError( + f"The number of {field_name} must match the number of {field_name.replace('amounts_pct', 'spreads')}.") + return parsed + + @field_validator('position_mode', mode="before") + @classmethod + def validate_position_mode(cls, v) -> PositionMode: + return parse_enum_value(PositionMode, v, "position_mode") + + @field_validator('price_distance_tolerance', 'refresh_tolerance', 'tolerance_scaling', mode="before") + @classmethod + def validate_tolerance_fields(cls, v, validation_info: ValidationInfo): + field_name = validation_info.field_name + if isinstance(v, str): + return Decimal(v) + if field_name == 'tolerance_scaling' and Decimal(str(v)) <= 0: + raise ValueError(f"{field_name} must be greater than 0") + return v + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + # Ensure we're passing OrderType enum values, not strings + open_order_type = self.open_order_type if isinstance(self.open_order_type, OrderType) else OrderType.LIMIT_MAKER + take_profit_order_type = self.take_profit_order_type if isinstance(self.take_profit_order_type, + OrderType) else OrderType.LIMIT_MAKER + + return TripleBarrierConfig( + take_profit=self.take_profit, + trailing_stop=None, + open_order_type=open_order_type, + take_profit_order_type=take_profit_order_type, + stop_loss_order_type=OrderType.MARKET, + time_limit_order_type=OrderType.MARKET + ) + + def get_cooldown_time(self, trade_type: TradeType) -> int: + """Get cooldown time for specific trade type""" + return self.buy_cooldown_time if trade_type == TradeType.BUY else self.sell_cooldown_time + + def get_position_effectivization_time(self, trade_type: TradeType) -> int: + """Get position effectivization time for specific trade type""" + return self.buy_position_effectivization_time if trade_type == TradeType.BUY else self.sell_position_effectivization_time + + def get_price_distance_level_tolerance(self, level: int) -> Decimal: + """Get level-specific price distance tolerance (for new order placement). + Prevents placing new orders when existing ones are too close to current price. + """ + return self.price_distance_tolerance * (self.tolerance_scaling ** level) + + def get_refresh_level_tolerance(self, level: int) -> Decimal: + """Get level-specific refresh tolerance (for order replacement). + Triggers replacing open orders when price deviates from theoretical level. + """ + return self.refresh_tolerance * (self.tolerance_scaling ** level) + + def update_parameters(self, trade_type: TradeType, new_spreads: Union[List[float], str], + new_amounts_pct: Optional[Union[List[int], str]] = None): + spreads_field = 'buy_spreads' if trade_type == TradeType.BUY else 'sell_spreads' + amounts_pct_field = 'buy_amounts_pct' if trade_type == TradeType.BUY else 'sell_amounts_pct' + + setattr(self, spreads_field, self.parse_spreads(new_spreads)) + if new_amounts_pct is not None: + setattr(self, amounts_pct_field, + self.parse_and_validate_amounts(new_amounts_pct, self.__dict__, self.__fields__[amounts_pct_field])) + else: + setattr(self, amounts_pct_field, [1 for _ in getattr(self, spreads_field)]) + + def get_spreads_and_amounts_in_quote(self, trade_type: TradeType) -> Tuple[List[float], List[float]]: + buy_amounts_pct = getattr(self, 'buy_amounts_pct') + sell_amounts_pct = getattr(self, 'sell_amounts_pct') + + total_pct = sum(buy_amounts_pct) + sum(sell_amounts_pct) + + if trade_type == TradeType.BUY: + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in buy_amounts_pct] + else: + normalized_amounts_pct = [amt_pct / total_pct for amt_pct in sell_amounts_pct] + + spreads = getattr(self, f'{trade_type.name.lower()}_spreads') + return spreads, [amt_pct * self.total_amount_quote * self.portfolio_allocation for amt_pct in + normalized_amounts_pct] + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class PMMister(ControllerBase): + """ + Advanced PMM (Pure Market Making) controller with sophisticated position management. + Features: + - Hanging executors system for better position control + - Price distance requirements to prevent over-accumulation + - Breakeven awareness for dynamic parameter adjustment + - Separate buy/sell cooldown and effectivization times + """ + + def __init__(self, config: PMMisterConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.market_data_provider.initialize_rate_sources( + [ConnectorPair(connector_name=config.connector_name, trading_pair=config.trading_pair)] + ) + # Price history for visualization (last 60 price points) + self.price_history = [] + self.max_price_history = 60 + # Order history for visualization + self.order_history = [] + self.max_order_history = 20 + # Initialize processed_data to prevent access errors + self.processed_data = {} + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + Determine actions based on the current state with advanced position management. + """ + actions = [] + + # Create new executors + actions.extend(self.create_actions_proposal()) + + # Stop executors (refresh and early stop) + actions.extend(self.stop_actions_proposal()) + + return actions + + def should_effectivize_executor(self, executor_info, current_time: int) -> bool: + """Check if a hanging executor should be effectivized""" + level_id = executor_info.custom_info.get("level_id", "") + fill_time = executor_info.custom_info["open_order_last_update"] + if not level_id or not fill_time: + return False + + trade_type = self.get_trade_type_from_level_id(level_id) + effectivization_time = self.config.get_position_effectivization_time(trade_type) + + return current_time - fill_time >= effectivization_time + + def calculate_theoretical_price(self, level_id: str, reference_price: Decimal) -> Decimal: + """Calculate the theoretical price for a given level""" + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + + if trade_type == TradeType.BUY: + spreads = self.config.buy_spreads + else: + spreads = self.config.sell_spreads + + if level >= len(spreads): + return reference_price + + spread_in_pct = Decimal(spreads[level]) * Decimal(self.processed_data.get("spread_multiplier", 1)) + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + theoretical_price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + + return theoretical_price + + def should_refresh_executor_by_distance(self, executor_info, reference_price: Decimal) -> bool: + """Check if executor should be refreshed due to price distance deviation""" + level_id = executor_info.custom_info.get("level_id", "") + if not level_id or not hasattr(executor_info.config, 'entry_price'): + return False + + current_order_price = executor_info.config.entry_price + theoretical_price = self.calculate_theoretical_price(level_id, reference_price) + + # Calculate distance deviation percentage + if theoretical_price == 0: + return False + + distance_deviation = abs(current_order_price - theoretical_price) / theoretical_price + + # Check if deviation exceeds level-specific refresh tolerance + level = self.get_level_from_level_id(level_id) + level_tolerance = self.config.get_refresh_level_tolerance(level) + return distance_deviation > level_tolerance + + def create_actions_proposal(self) -> List[ExecutorAction]: + """ + Create actions proposal with advanced position management logic. + """ + create_actions = [] + + # Get levels to execute with advanced logic + levels_to_execute = self.get_levels_to_execute() + + # Pre-calculate spreads and amounts + buy_spreads, buy_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.BUY) + sell_spreads, sell_amounts_quote = self.config.get_spreads_and_amounts_in_quote(TradeType.SELL) + reference_price = Decimal(self.processed_data["reference_price"]) + + # Use pre-calculated skew factors from processed_data + buy_skew = self.processed_data["buy_skew"] + sell_skew = self.processed_data["sell_skew"] + + # Create executors for each level + for level_id in levels_to_execute: + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + + if trade_type == TradeType.BUY: + spread_in_pct = Decimal(buy_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(buy_amounts_quote[level]) + else: + spread_in_pct = Decimal(sell_spreads[level]) * Decimal(self.processed_data["spread_multiplier"]) + amount_quote = Decimal(sell_amounts_quote[level]) + + # Apply skew to amount calculation + skew = buy_skew if trade_type == TradeType.BUY else sell_skew + + # Calculate price and amount + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + amount = self.market_data_provider.quantize_order_amount( + self.config.connector_name, + self.config.trading_pair, + (amount_quote / price) * skew + ) + + if amount == Decimal("0"): + self.logger().warning(f"The amount of the level {level_id} is 0. Skipping.") + continue + + # Position profit protection: don't place sell orders below breakeven + if self.config.position_profit_protection and trade_type == TradeType.SELL: + breakeven_price = self.processed_data.get("breakeven_price") + if breakeven_price is not None and breakeven_price > 0 and price < breakeven_price: + continue + + executor_config = self.get_executor_config(level_id, price, amount) + if executor_config is not None: + # Track order creation for visualization + self.order_history.append({ + 'timestamp': self.market_data_provider.time(), + 'price': price, + 'side': trade_type.name, + 'level_id': level_id, + 'action': 'CREATE' + }) + if len(self.order_history) > self.max_order_history: + self.order_history.pop(0) + + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + + return create_actions + + def get_levels_to_execute(self) -> List[str]: + """ + Get levels to execute with advanced hanging executor logic using the analyzer. + """ + current_time = self.market_data_provider.time() + + # Analyze all levels to understand executor states + all_levels_analysis = self.analyze_all_levels() + + # Get working levels (active or hanging with cooldown) + working_levels_ids = [] + + for analysis in all_levels_analysis: + level_id = analysis["level_id"] + trade_type = self.get_trade_type_from_level_id(level_id) + is_buy = level_id.startswith("buy") + current_price = Decimal(self.processed_data["reference_price"]) + + # Evaluate each condition separately for debugging + has_active_not_trading = len(analysis["active_executors_not_trading"]) > 0 + has_too_many_executors = analysis["total_active_executors"] >= self.config.max_active_executors_by_level + + # Check cooldown condition + has_active_cooldown = False + if analysis["open_order_last_update"]: + cooldown_time = self.config.get_cooldown_time(trade_type) + has_active_cooldown = current_time - analysis["open_order_last_update"] < cooldown_time + + # Enhanced price distance logic with level-specific tolerance + price_distance_violated = False + level = self.get_level_from_level_id(level_id) + + if is_buy and analysis["max_price"]: + # For buy orders, ensure they're not too close to current price + distance_from_current = (current_price - analysis["max_price"]) / current_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + if distance_from_current < level_tolerance: + price_distance_violated = True + elif not is_buy and analysis["min_price"]: + # For sell orders, ensure they're not too close to current price + distance_from_current = (analysis["min_price"] - current_price) / current_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + if distance_from_current < level_tolerance: + price_distance_violated = True + + # Level is working if any condition is true + if (has_active_not_trading or + has_too_many_executors or + has_active_cooldown or + price_distance_violated): + working_levels_ids.append(level_id) + continue + return self.get_not_active_levels_ids(working_levels_ids) + + def stop_actions_proposal(self) -> List[ExecutorAction]: + """ + Create stop actions with enhanced refresh logic. + """ + stop_actions = [] + stop_actions.extend(self.executors_to_refresh()) + stop_actions.extend(self.process_hanging_executors()) + return stop_actions + + def executors_to_refresh(self) -> List[ExecutorAction]: + """Refresh executors that have been active too long or deviated too far from theoretical price""" + current_time = self.market_data_provider.time() + reference_price = Decimal(self.processed_data.get("reference_price", Decimal("0"))) + + executors_to_refresh = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: ( + not x.is_trading and x.is_active and ( + # Time-based refresh condition + current_time - x.timestamp > self.config.executor_refresh_time or + # Distance-based refresh condition + (reference_price > 0 and self.should_refresh_executor_by_distance(x, reference_price)) + ) + ) + ) + return [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id + ) for executor in executors_to_refresh] + + def process_hanging_executors(self) -> List[ExecutorAction]: + """Process hanging executors and effectivize them when appropriate""" + current_time = self.market_data_provider.time() + # Find hanging executors that should be effectivized (only is_trading) + executors_to_effectivize = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_trading and self.should_effectivize_executor(x, current_time) + ) + + # Create actions for effectivization (keep position) + effectivize_actions = [StopExecutorAction( + controller_id=self.config.id, + keep_position=True, + executor_id=executor.id + ) for executor in executors_to_effectivize] + + return effectivize_actions + + async def update_processed_data(self): + """ + Update processed data with enhanced condition tracking and analysis. + """ + current_time = self.market_data_provider.time() + + # Safely get reference price with fallback + try: + reference_price = self.market_data_provider.get_price_by_type( + self.config.connector_name, self.config.trading_pair, PriceType.MidPrice + ) + if reference_price is None or reference_price <= 0: + self.logger().warning("Invalid reference price received, using previous price if available") + reference_price = self.processed_data.get("reference_price", Decimal("100")) # Default fallback + except Exception as e: + self.logger().warning(f"Error getting reference price: {e}, using previous price if available") + reference_price = self.processed_data.get("reference_price", Decimal("100")) # Default fallback + + # Update price history for visualization + self.price_history.append({ + 'timestamp': current_time, + 'price': Decimal(reference_price) + }) + if len(self.price_history) > self.max_price_history: + self.price_history.pop(0) + + position_held = next((position for position in self.positions_held if + (position.trading_pair == self.config.trading_pair) & + (position.connector_name == self.config.connector_name)), None) + + target_position = self.config.total_amount_quote * self.config.target_base_pct + + if position_held is not None: + position_amount = position_held.amount + current_base_pct = position_held.amount_quote / self.config.total_amount_quote + deviation = (target_position - position_held.amount_quote) / target_position + unrealized_pnl_pct = ( + position_held.unrealized_pnl_quote / position_held.amount_quote + if position_held.amount_quote != 0 else Decimal("0") + ) + breakeven_price = position_held.breakeven_price + else: + position_amount = 0 + current_base_pct = 0 + deviation = 1 + unrealized_pnl_pct = 0 + breakeven_price = None + + if self.config.tick_mode: + spread_multiplier = ( + self.market_data_provider.get_trading_rules( + self.config.connector_name, self.config.trading_pair + ).min_price_increment / reference_price + ) + else: + spread_multiplier = Decimal("1") + + # Calculate skew factors for position balancing + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + + if max_pct > min_pct: + # Calculate skew factors based on position deviation + buy_skew = (max_pct - current_base_pct) / (max_pct - min_pct) + sell_skew = (current_base_pct - min_pct) / (max_pct - min_pct) + # Apply minimum skew to prevent orders from becoming too small + buy_skew = max(min(buy_skew, Decimal("1.0")), self.config.min_skew) + sell_skew = max(min(sell_skew, Decimal("1.0")), self.config.min_skew) + else: + buy_skew = sell_skew = Decimal("1.0") + + # Enhanced condition tracking - only if we have valid data + cooldown_status = self._calculate_cooldown_status(current_time) + price_distance_analysis = self._calculate_price_distance_analysis(Decimal(reference_price)) + effectivization_tracking = self._calculate_effectivization_tracking(current_time) + level_conditions = self._analyze_level_conditions(current_time, Decimal(reference_price)) + executor_stats = self._calculate_executor_statistics(current_time) + refresh_tracking = self._calculate_refresh_tracking(current_time) + + self.processed_data = { + "reference_price": Decimal(reference_price), + "spread_multiplier": spread_multiplier, + "deviation": deviation, + "current_base_pct": current_base_pct, + "unrealized_pnl_pct": unrealized_pnl_pct, + "position_amount": position_amount, + "breakeven_price": breakeven_price, + "buy_skew": buy_skew, + "sell_skew": sell_skew, + # Enhanced tracking data + "cooldown_status": cooldown_status, + "price_distance_analysis": price_distance_analysis, + "effectivization_tracking": effectivization_tracking, + "level_conditions": level_conditions, + "executor_stats": executor_stats, + "refresh_tracking": refresh_tracking, + "current_time": current_time + } + + def get_executor_config(self, level_id: str, price: Decimal, amount: Decimal): + """Get executor config for a given level""" + trade_type = self.get_trade_type_from_level_id(level_id) + return PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + level_id=level_id, + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + entry_price=price, + amount=amount, + triple_barrier_config=self.config.triple_barrier_config, + leverage=self.config.leverage, + side=trade_type, + ) + + def get_level_id_from_side(self, trade_type: TradeType, level: int) -> str: + """Get level ID based on trade type and level""" + return f"{trade_type.name.lower()}_{level}" + + def get_trade_type_from_level_id(self, level_id: str) -> TradeType: + return TradeType.BUY if level_id.startswith("buy") else TradeType.SELL + + def get_level_from_level_id(self, level_id: str) -> int: + return int(level_id.split('_')[1]) + + def get_not_active_levels_ids(self, active_levels_ids: List[str]) -> List[str]: + """Get levels that should be executed based on position constraints""" + buy_ids_missing = [ + self.get_level_id_from_side(TradeType.BUY, level) + for level in range(len(self.config.buy_spreads)) + if self.get_level_id_from_side(TradeType.BUY, level) not in active_levels_ids + ] + sell_ids_missing = [ + self.get_level_id_from_side(TradeType.SELL, level) + for level in range(len(self.config.sell_spreads)) + if self.get_level_id_from_side(TradeType.SELL, level) not in active_levels_ids + ] + + current_pct = self.processed_data["current_base_pct"] + + if current_pct < self.config.min_base_pct: + return buy_ids_missing + elif current_pct > self.config.max_base_pct: + return sell_ids_missing + + # Position profit protection: filter based on breakeven + if self.config.position_profit_protection: + breakeven_price = self.processed_data.get("breakeven_price") + reference_price = self.processed_data["reference_price"] + target_pct = self.config.target_base_pct + + if breakeven_price is not None and breakeven_price > 0: + if current_pct < target_pct and reference_price < breakeven_price: + return buy_ids_missing # Don't sell at a loss when underweight + elif current_pct > target_pct and reference_price > breakeven_price: + return sell_ids_missing # Don't buy more when overweight and in profit + + return buy_ids_missing + sell_ids_missing + + def analyze_all_levels(self) -> List[Dict]: + """Analyze executors for all levels.""" + level_ids: Set[str] = {e.custom_info.get("level_id") for e in self.executors_info if + "level_id" in e.custom_info} + return [self._analyze_by_level_id(level_id) for level_id in level_ids] + + def _analyze_by_level_id(self, level_id: str) -> Dict: + """Analyze executors for a specific level ID.""" + # Get active executors for level calculations + filtered_executors = [e for e in self.executors_info if + e.custom_info.get("level_id") == level_id and e.is_active] + + active_not_trading = [e for e in filtered_executors if e.is_active and not e.is_trading] + active_trading = [e for e in filtered_executors if e.is_active and e.is_trading] + + # For cooldown calculation, include both active and recently completed executors + all_level_executors = [e for e in self.executors_info if e.custom_info.get("level_id") == level_id] + open_order_last_updates = [ + e.custom_info.get("open_order_last_update") for e in all_level_executors + if "open_order_last_update" in e.custom_info and e.custom_info["open_order_last_update"] is not None + ] + latest_open_order_update = max(open_order_last_updates) if open_order_last_updates else None + + prices = [e.config.entry_price for e in filtered_executors if hasattr(e.config, 'entry_price')] + + return { + "level_id": level_id, + "active_executors_not_trading": active_not_trading, + "active_executors_trading": active_trading, + "total_active_executors": len(active_not_trading) + len(active_trading), + "open_order_last_update": latest_open_order_update, + "min_price": min(prices) if prices else None, + "max_price": max(prices) if prices else None, + } + + def to_format_status(self) -> List[str]: + """ + Comprehensive real-time trading conditions dashboard. + """ + from decimal import Decimal + from itertools import zip_longest + + status = [] + + # Layout dimensions - set early for error cases + outer_width = 170 + inner_width = outer_width - 4 + + # Get all required data with safe fallbacks + if not hasattr(self, 'processed_data') or not self.processed_data: + # Return minimal status if processed_data is not available + status.append("╒" + "═" * inner_width + "╕") + status.append(f"│ {'Initializing controller... please wait':<{inner_width}} │") + status.append(f"╘{'═' * inner_width}╛") + return status + + base_pct = self.processed_data.get('current_base_pct', Decimal("0")) + min_pct = self.config.min_base_pct + max_pct = self.config.max_base_pct + target_pct = self.config.target_base_pct + pnl = self.processed_data.get('unrealized_pnl_pct', Decimal('0')) + breakeven = self.processed_data.get('breakeven_price') + current_price = self.processed_data.get('reference_price', Decimal("0")) + buy_skew = self.processed_data.get('buy_skew', Decimal("1.0")) + sell_skew = self.processed_data.get('sell_skew', Decimal("1.0")) + + # Enhanced condition data + cooldown_status = self.processed_data.get('cooldown_status', {}) + effectivization = self.processed_data.get('effectivization_tracking', {}) + level_conditions = self.processed_data.get('level_conditions', {}) + executor_stats = self.processed_data.get('executor_stats', {}) + refresh_tracking = self.processed_data.get('refresh_tracking', {}) + + # Layout dimensions already set above + + # Smart column distribution for 5 columns + col1_width = 28 # Cooldowns + col2_width = 35 # Price distances + col3_width = 28 # Effectivization + col4_width = 25 # Refresh tracking + col5_width = inner_width - col1_width - col2_width - col3_width - col4_width - 4 # Execution status + + half_width = inner_width // 2 - 1 + bar_width = inner_width - 25 + + # Header with enhanced info + status.append("╒" + "═" * inner_width + "╕") + + header_line = ( + f"{self.config.connector_name}:{self.config.trading_pair} @ {current_price:.2f} " + f"Alloc: {self.config.portfolio_allocation:.1%} " + f"Spread×{self.processed_data['spread_multiplier']:.3f} " + f"Dist: {self.config.price_distance_tolerance:.4%} " + f"Ref: {self.config.refresh_tolerance:.4%} (×{self.config.tolerance_scaling}) " + f"Pos Protect: {'ON' if self.config.position_profit_protection else 'OFF'}" + ) + status.append(f"│ {header_line:<{inner_width}} │") + + # REAL-TIME CONDITIONS DASHBOARD + status.append(f"├{'─' * inner_width}┤") + status.append(f"│ {'🔄 REAL-TIME CONDITIONS DASHBOARD':<{inner_width}} │") + status.append( + f"├{'─' * col1_width}┬{'─' * col2_width}┬{'─' * col3_width}┬{'─' * col4_width}┬{'─' * col5_width}┤") + status.append( + f"│ {'COOLDOWNS':<{col1_width}} │ {'PRICE DISTANCES':<{col2_width}} │ " + f"{'EFFECTIVIZATION':<{col3_width}} │ {'REFRESH TRACKING':<{col4_width}} │ " + f"{'EXECUTION':<{col5_width}} │") + status.append( + f"├{'─' * col1_width}┼{'─' * col2_width}┼{'─' * col3_width}┼{'─' * col4_width}┼{'─' * col5_width}┤") + + # Cooldown information + buy_cooldown = cooldown_status.get('buy', {}) + sell_cooldown = cooldown_status.get('sell', {}) + + cooldown_info = [ + f"BUY: {self._format_cooldown_status(buy_cooldown)}", + f"SELL: {self._format_cooldown_status(sell_cooldown)}", + f"Times: {self.config.buy_cooldown_time}/{self.config.sell_cooldown_time}s", + "" + ] + + # Calculate actual distances for current levels + current_buy_distance = "" + current_sell_distance = "" + + all_levels_analysis = self.analyze_all_levels() + for analysis in all_levels_analysis: + level_id = analysis["level_id"] + is_buy = level_id.startswith("buy") + + if is_buy and analysis["max_price"]: + distance = (current_price - analysis["max_price"]) / current_price + current_buy_distance = f"({distance:.3%})" + elif not is_buy and analysis["min_price"]: + distance = (analysis["min_price"] - current_price) / current_price + current_sell_distance = f"({distance:.3%})" + + # Enhanced price info with unified tolerance approach + violation_marker = " ⚠️" if (current_buy_distance and "(0.0" in current_buy_distance) or ( + current_sell_distance and "(0.0" in current_sell_distance) else "" + + # Show level-specific tolerances + dist_l0 = self.config.get_price_distance_level_tolerance(0) + dist_l1 = self.config.get_price_distance_level_tolerance(1) if len(self.config.buy_spreads) > 1 else None + + price_info = [ + f"L0 Dist: {dist_l0:.4%}{violation_marker}", + f"BUY Current: {current_buy_distance}", + f"L1 Dist: {dist_l1:.4%}" if dist_l1 else "L1: N/A", + f"SELL Current: {current_sell_distance}" + ] + + # Effectivization information + total_hanging = effectivization.get('total_hanging', 0) + ready_count = effectivization.get('ready_for_effectivization', 0) + + effect_info = [ + f"Hanging: {total_hanging}", + f"Ready: {ready_count}", + f"Times: {self.config.buy_position_effectivization_time}s/{self.config.sell_position_effectivization_time}s", + "" + ] + + # Refresh tracking information + near_refresh = refresh_tracking.get('near_refresh', 0) + refresh_ready = refresh_tracking.get('refresh_ready', 0) + distance_violations = refresh_tracking.get('distance_violations', 0) + + refresh_info = [ + f"Near Refresh: {near_refresh}", + f"Ready: {refresh_ready}", + f"Distance Violations: {distance_violations}", + f"Threshold: {self.config.executor_refresh_time}s" + ] + + # Execution status + can_execute_buy = len([level for level in level_conditions.values() if + level.get('trade_type') == 'BUY' and level.get('can_execute')]) + can_execute_sell = len([level for level in level_conditions.values() if + level.get('trade_type') == 'SELL' and level.get('can_execute')]) + total_buy_levels = len(self.config.buy_spreads) + total_sell_levels = len(self.config.sell_spreads) + + execution_info = [ + f"BUY: {can_execute_buy}/{total_buy_levels}", + f"SELL: {can_execute_sell}/{total_sell_levels}", + f"Active: {executor_stats.get('total_active', 0)}", + "" + ] + + # Display conditions in 5 columns + for cool_line, price_line, effect_line, refresh_line, exec_line in zip_longest(cooldown_info, price_info, + effect_info, refresh_info, + execution_info, fillvalue=""): + status.append( + f"│ {cool_line:<{col1_width}} │ {price_line:<{col2_width}} │ " + f"{effect_line:<{col3_width}} │ {refresh_line:<{col4_width}} │ " + f"{exec_line:<{col5_width}} │") + + # LEVEL-BY-LEVEL ANALYSIS + status.append(f"├{'─' * inner_width}┤") + status.append(f"│ {'📊 LEVEL-BY-LEVEL ANALYSIS':<{inner_width}} │") + status.append(f"├{'─' * inner_width}┤") + + # Show level conditions + status.extend(self._format_level_conditions(level_conditions, inner_width)) + + # VISUAL PROGRESS INDICATORS + status.append(f"├{'─' * inner_width}┤") + status.append(f"│ {'🔄 VISUAL PROGRESS INDICATORS':<{inner_width}} │") + status.append(f"├{'─' * inner_width}┤") + + # Cooldown progress bars + if buy_cooldown.get('active') or sell_cooldown.get('active'): + status.extend(self._format_cooldown_bars(buy_cooldown, sell_cooldown, bar_width, inner_width)) + + # Effectivization progress + if total_hanging > 0: + status.extend(self._format_effectivization_bars(effectivization, bar_width, inner_width)) + + # Refresh progress bars + if refresh_tracking.get('refresh_candidates', []): + status.extend(self._format_refresh_bars(refresh_tracking, bar_width, inner_width)) + + # POSITION & PNL DASHBOARD + status.append(f"├{'─' * half_width}┬{'─' * half_width}┤") + status.append(f"│ {'📍 POSITION STATUS':<{half_width}} │ {'💰 PROFIT & LOSS':<{half_width}} │") + status.append(f"├{'─' * half_width}┼{'─' * half_width}┤") + + # Position data with enhanced skew info + skew = base_pct - target_pct + skew_pct = skew / target_pct if target_pct != 0 else Decimal('0') + position_info = [ + f"Current: {base_pct:.2%} (Target: {target_pct:.2%})", + f"Range: {min_pct:.2%} - {max_pct:.2%}", + f"Skew: {skew_pct:+.2%} (min {self.config.min_skew:.2%})", + f"Buy Skew: {buy_skew:.2f} | Sell Skew: {sell_skew:.2f}" + ] + + # Enhanced PnL data + breakeven_str = f"{breakeven:.2f}" if breakeven is not None else "N/A" + pnl_sign = "+" if pnl >= 0 else "" + distance_to_tp = self.config.global_take_profit - pnl if pnl < self.config.global_take_profit else Decimal('0') + distance_to_sl = pnl + self.config.global_stop_loss if pnl > -self.config.global_stop_loss else Decimal('0') + + pnl_info = [ + f"Unrealized: {pnl_sign}{pnl:.2%}", + f"Take Profit: {self.config.global_take_profit:.2%} (Δ{distance_to_tp:.2%})", + f"Stop Loss: {-self.config.global_stop_loss:.2%} (Δ{distance_to_sl:.2%})", + f"Breakeven: {breakeven_str}" + ] + + # Display position and PnL info + for pos_line, pnl_line in zip_longest(position_info, pnl_info, fillvalue=""): + status.append(f"│ {pos_line:<{half_width}} │ {pnl_line:<{half_width}} │") + + # Position visualization with enhanced details + status.append(f"├{'─' * inner_width}┤") + status.extend( + self._format_position_visualization(base_pct, target_pct, min_pct, max_pct, skew_pct, pnl, bar_width, + inner_width)) + + # Bottom border + status.append(f"╘{'═' * inner_width}╛") + + return status + + def _is_executor_too_far_from_price(self, executor_info, current_price: Decimal) -> bool: + """Check if hanging executor is too far from current price and should be stopped""" + if not hasattr(executor_info.config, 'entry_price'): + return False + + entry_price = executor_info.config.entry_price + level_id = executor_info.custom_info.get("level_id", "") + + if not level_id: + return False + + is_buy = level_id.startswith("buy") + + # Calculate price distance + if is_buy: + # For buy orders, stop if they're above current price (inverted) + if entry_price >= current_price: + return True + distance = (current_price - entry_price) / current_price + max_distance = Decimal("0.05") # 5% maximum distance + else: + # For sell orders, stop if they're below current price + if entry_price <= current_price: + return True + distance = (entry_price - current_price) / current_price + max_distance = Decimal("0.05") # 5% maximum distance + + return distance > max_distance + + def _format_cooldown_status(self, cooldown_data: Dict) -> str: + """Format cooldown status for display""" + if not cooldown_data.get('active'): + return "READY ✓" + + remaining = cooldown_data.get('remaining_time', 0) + progress = cooldown_data.get('progress_pct', Decimal('0')) + return f"{remaining:.1f}s ({progress:.0%})" + + def _format_level_conditions(self, level_conditions: Dict, inner_width: int) -> List[str]: + """Format level-by-level conditions analysis""" + lines = [] + + # Group by trade type + buy_levels = {k: v for k, v in level_conditions.items() if v.get('trade_type') == 'BUY'} + sell_levels = {k: v for k, v in level_conditions.items() if v.get('trade_type') == 'SELL'} + + if not buy_levels and not sell_levels: + lines.append(f"│ {'No levels configured':<{inner_width}} │") + return lines + + # BUY levels analysis + if buy_levels: + lines.append(f"│ {'BUY LEVELS:':<{inner_width}} │") + for level_id, conditions in sorted(buy_levels.items()): + status_icon = "✓" if conditions.get('can_execute') else "✗" + blocking = ", ".join(conditions.get('blocking_conditions', [])) + active = conditions.get('active_executors', 0) + hanging = conditions.get('hanging_executors', 0) + + level_line = f" {level_id}: {status_icon} Active:{active} Hanging:{hanging}" + if blocking: + level_line += f" | Blocked: {blocking}" + + lines.append(f"│ {level_line:<{inner_width}} │") + + # SELL levels analysis + if sell_levels: + lines.append(f"│ {'SELL LEVELS:':<{inner_width}} │") + for level_id, conditions in sorted(sell_levels.items()): + status_icon = "✓" if conditions.get('can_execute') else "✗" + blocking = ", ".join(conditions.get('blocking_conditions', [])) + active = conditions.get('active_executors', 0) + hanging = conditions.get('hanging_executors', 0) + + level_line = f" {level_id}: {status_icon} Active:{active} Hanging:{hanging}" + if blocking: + level_line += f" | Blocked: {blocking}" + + lines.append(f"│ {level_line:<{inner_width}} │") + + return lines + + def _format_cooldown_bars( + self, buy_cooldown: Dict, sell_cooldown: Dict, bar_width: int, inner_width: int + ) -> List[str]: + """Format cooldown progress bars""" + lines = [] + + if buy_cooldown.get('active'): + progress = float(buy_cooldown.get('progress_pct', 0)) + remaining = buy_cooldown.get('remaining_time', 0) + bar = self._create_progress_bar(progress, bar_width // 2) # Same size as other bars + lines.append(f"│ BUY Cooldown: [{bar}] {remaining:.1f}s remaining │") + + if sell_cooldown.get('active'): + progress = float(sell_cooldown.get('progress_pct', 0)) + remaining = sell_cooldown.get('remaining_time', 0) + bar = self._create_progress_bar(progress, bar_width // 2) # Same size as other bars + lines.append(f"│ SELL Cooldown: [{bar}] {remaining:.1f}s remaining │") + + return lines + + def _format_effectivization_bars(self, effectivization: Dict, bar_width: int, inner_width: int) -> List[str]: + """Format effectivization progress bars""" + lines = [] + + hanging_executors = effectivization.get('hanging_executors', []) + if not hanging_executors: + return lines + + lines.append(f"│ {'EFFECTIVIZATION PROGRESS:':<{inner_width}} │") + + # Show up to 5 hanging executors with progress + for executor in hanging_executors[:5]: + level_id = executor.get('level_id', 'unknown') + trade_type = executor.get('trade_type', 'UNKNOWN') + progress = float(executor.get('progress_pct', 0)) + remaining = executor.get('remaining_time', 0) + ready = executor.get('ready', False) + + bar = self._create_progress_bar(progress, bar_width // 2) + status = "READY!" if ready else f"{remaining}s" + icon = "🔄" if not ready else "✓" + + lines.append(f"│ {icon} {level_id} ({trade_type}): [{bar}] {status:<10} │") + + if len(hanging_executors) > 5: + lines.append(f"│ {'... and ' + str(len(hanging_executors) - 5) + ' more':<{inner_width}} │") + + return lines + + def _format_position_visualization(self, base_pct: Decimal, target_pct: Decimal, min_pct: Decimal, + max_pct: Decimal, skew_pct: Decimal, pnl: Decimal, + bar_width: int, inner_width: int) -> List[str]: + """Format enhanced position visualization""" + lines = [] + + # Position bar + filled_width = int(float(base_pct) * bar_width) + min_pos = int(float(min_pct) * bar_width) + max_pos = int(float(max_pct) * bar_width) + target_pos = int(float(target_pct) * bar_width) + + position_bar = "" + for i in range(bar_width): + if i == filled_width: + position_bar += "◆" # Current position marker + elif i == target_pos: + position_bar += "┇" # Target line + elif i == min_pos: + position_bar += "┃" # Min threshold + elif i == max_pos: + position_bar += "┃" # Max threshold + elif i < filled_width: + position_bar += "█" # Filled area + else: + position_bar += "░" # Empty area + + lines.append(f"│ Position: [{position_bar}] {base_pct:.2%} │") + + # Skew visualization + center = bar_width // 2 + skew_pos = center + int(float(skew_pct) * center) + skew_pos = max(0, min(bar_width - 1, skew_pos)) + + skew_bar = "" + for i in range(bar_width): + if i == center: + skew_bar += "┃" # Center line (neutral) + elif i == skew_pos: + skew_bar += "⬤" # Current skew position + else: + skew_bar += "─" + + skew_direction = "BULLISH" if skew_pct > 0 else "BEARISH" if skew_pct < 0 else "NEUTRAL" + lines.append(f"│ Skew: [{skew_bar}] {skew_direction} │") + + # PnL visualization with dynamic scaling + max_range = max(abs(self.config.global_take_profit), abs(self.config.global_stop_loss), abs(pnl)) * Decimal( + "1.2") + if max_range > 0: + scale = (bar_width // 2) / float(max_range) + pnl_pos = center + int(float(pnl) * scale) + take_profit_pos = center + int(float(self.config.global_take_profit) * scale) + stop_loss_pos = center + int(float(-self.config.global_stop_loss) * scale) + + pnl_pos = max(0, min(bar_width - 1, pnl_pos)) + take_profit_pos = max(0, min(bar_width - 1, take_profit_pos)) + stop_loss_pos = max(0, min(bar_width - 1, stop_loss_pos)) + + pnl_bar = "" + for i in range(bar_width): + if i == center: + pnl_bar += "│" # Zero line + elif i == pnl_pos: + pnl_bar += "⬤" # Current PnL + elif i == take_profit_pos: + pnl_bar += "T" # Take profit target + elif i == stop_loss_pos: + pnl_bar += "S" # Stop loss target + elif ((pnl >= 0 and center <= i < pnl_pos) or + (pnl < 0 and pnl_pos < i <= center)): + pnl_bar += "█" if pnl >= 0 else "▓" # Fill to current PnL + else: + pnl_bar += "─" + else: + pnl_bar = "─" * bar_width + + pnl_status = "PROFIT" if pnl > 0 else "LOSS" if pnl < 0 else "BREAK-EVEN" + lines.append(f"│ PnL: [{pnl_bar}] {pnl_status} │") + + return lines + + def _create_progress_bar(self, progress: float, width: int) -> str: + """Create a progress bar string""" + progress = max(0, min(1, progress)) # Clamp between 0 and 1 + filled = int(progress * width) + + bar = "" + for i in range(width): + if i < filled: + bar += "█" # Filled + elif i == filled and filled < width: + bar += "▌" # Partial fill + else: + bar += "░" # Empty + + return bar + + def _calculate_cooldown_status(self, current_time: int) -> Dict: + """Calculate cooldown status for buy and sell sides""" + cooldown_status = { + "buy": {"active": False, "remaining_time": 0, "progress_pct": Decimal("0")}, + "sell": {"active": False, "remaining_time": 0, "progress_pct": Decimal("0")} + } + + # Get latest order timestamps for each trade type + buy_executors = [e for e in self.executors_info if e.custom_info.get("level_id", "").startswith("buy")] + sell_executors = [e for e in self.executors_info if e.custom_info.get("level_id", "").startswith("sell")] + + for trade_type, executors in [("buy", buy_executors), ("sell", sell_executors)]: + if not executors: + continue + + # Find most recent open order update + latest_updates = [ + e.custom_info.get("open_order_last_update") for e in executors + if "open_order_last_update" in e.custom_info and e.custom_info["open_order_last_update"] is not None + ] + + if not latest_updates: + continue + + latest_update = max(latest_updates) + cooldown_time = (self.config.buy_cooldown_time if trade_type == "buy" + else self.config.sell_cooldown_time) + + time_since_update = current_time - latest_update + remaining_time = max(0, cooldown_time - time_since_update) + + if remaining_time > 0: + cooldown_status[trade_type]["active"] = True + cooldown_status[trade_type]["remaining_time"] = remaining_time + cooldown_status[trade_type]["progress_pct"] = Decimal(str(time_since_update)) / Decimal( + str(cooldown_time)) + else: + cooldown_status[trade_type]["progress_pct"] = Decimal("1") + + return cooldown_status + + def _calculate_price_distance_analysis(self, reference_price: Decimal) -> Dict: + """Analyze price distance conditions for all levels with unified tolerance approach""" + price_analysis = { + "buy": {"violations": [], "distances": [], "base_tolerance": self.config.price_distance_tolerance}, + "sell": {"violations": [], "distances": [], "base_tolerance": self.config.price_distance_tolerance} + } + + # Analyze all levels for price distance violations + all_levels_analysis = self.analyze_all_levels() + + for analysis in all_levels_analysis: + level_id = analysis["level_id"] + is_buy = level_id.startswith("buy") + level = self.get_level_from_level_id(level_id) + + if is_buy and analysis["max_price"]: + current_distance = (reference_price - analysis["max_price"]) / reference_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + + price_analysis["buy"]["distances"].append({ + "level_id": level_id, + "level": level, + "current_distance": current_distance, + "distance_pct": current_distance, + "tolerance": level_tolerance, + "violates": current_distance < level_tolerance + }) + + if current_distance < level_tolerance: + price_analysis["buy"]["violations"].append(level_id) + + elif not is_buy and analysis["min_price"]: + current_distance = (analysis["min_price"] - reference_price) / reference_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + + price_analysis["sell"]["distances"].append({ + "level_id": level_id, + "level": level, + "current_distance": current_distance, + "distance_pct": current_distance, + "tolerance": level_tolerance, + "violates": current_distance < level_tolerance + }) + + if current_distance < level_tolerance: + price_analysis["sell"]["violations"].append(level_id) + + return price_analysis + + def _calculate_effectivization_tracking(self, current_time: int) -> Dict: + """Track hanging executor effectivization progress""" + effectivization_data = { + "hanging_executors": [], + "total_hanging": 0, + "ready_for_effectivization": 0 + } + + hanging_executors = [e for e in self.executors_info if e.is_active and e.is_trading] + effectivization_data["total_hanging"] = len(hanging_executors) + + for executor in hanging_executors: + level_id = executor.custom_info.get("level_id", "") + if not level_id: + continue + + trade_type = self.get_trade_type_from_level_id(level_id) + effectivization_time = self.config.get_position_effectivization_time(trade_type) + fill_time = executor.custom_info.get("open_order_last_update", current_time) + + time_elapsed = current_time - fill_time + remaining_time = max(0, effectivization_time - time_elapsed) + progress_pct = min(Decimal("1"), Decimal(str(time_elapsed)) / Decimal(str(effectivization_time))) + + ready = remaining_time == 0 + if ready: + effectivization_data["ready_for_effectivization"] += 1 + + effectivization_data["hanging_executors"].append({ + "level_id": level_id, + "trade_type": trade_type.name, + "time_elapsed": time_elapsed, + "remaining_time": remaining_time, + "progress_pct": progress_pct, + "ready": ready, + "executor_id": executor.id + }) + + return effectivization_data + + def _analyze_level_conditions(self, current_time: int, reference_price: Decimal) -> Dict: + """Analyze conditions preventing each level from executing""" + level_conditions = {} + + # Get all possible levels + all_buy_levels = [self.get_level_id_from_side(TradeType.BUY, i) for i in range(len(self.config.buy_spreads))] + all_sell_levels = [self.get_level_id_from_side(TradeType.SELL, i) for i in range(len(self.config.sell_spreads))] + all_levels = all_buy_levels + all_sell_levels + + # Cache level analysis to avoid redundant calculations + level_analysis_cache = {} + for level_id in all_levels: + level_analysis_cache[level_id] = self._analyze_by_level_id(level_id) + + # Pre-calculate position constraints with safe defaults + if hasattr(self, 'processed_data') and self.processed_data: + current_pct = self.processed_data.get("current_base_pct", Decimal("0")) + breakeven_price = self.processed_data.get("breakeven_price") + else: + current_pct = Decimal("0") + breakeven_price = None + + below_min_position = current_pct < self.config.min_base_pct + above_max_position = current_pct > self.config.max_base_pct + + # Analyze each level + for level_id in all_levels: + trade_type = self.get_trade_type_from_level_id(level_id) + is_buy = level_id.startswith("buy") + + conditions = { + "level_id": level_id, + "trade_type": trade_type.name, + "can_execute": True, + "blocking_conditions": [], + "active_executors": 0, + "hanging_executors": 0 + } + + # Get cached level analysis + level_analysis = level_analysis_cache[level_id] + + # Check various blocking conditions + # 1. Active executor limit + if level_analysis["total_active_executors"] >= self.config.max_active_executors_by_level: + conditions["blocking_conditions"].append("max_active_executors_reached") + conditions["can_execute"] = False + + # 2. Cooldown check + cooldown_time = self.config.get_cooldown_time(trade_type) + if level_analysis["open_order_last_update"]: + time_since_update = current_time - level_analysis["open_order_last_update"] + if time_since_update < cooldown_time: + conditions["blocking_conditions"].append("cooldown_active") + conditions["can_execute"] = False + + # 3. Price distance check with level-specific tolerance + level = self.get_level_from_level_id(level_id) + if is_buy and level_analysis["max_price"]: + distance = (reference_price - level_analysis["max_price"]) / reference_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + if distance < level_tolerance: + conditions["blocking_conditions"].append("price_distance_violation") + conditions["can_execute"] = False + elif not is_buy and level_analysis["min_price"]: + distance = (level_analysis["min_price"] - reference_price) / reference_price + level_tolerance = self.config.get_price_distance_level_tolerance(level) + if distance < level_tolerance: + conditions["blocking_conditions"].append("price_distance_violation") + conditions["can_execute"] = False + + # 4. Position constraints + if below_min_position and not is_buy: + conditions["blocking_conditions"].append("below_min_position") + conditions["can_execute"] = False + elif above_max_position and is_buy: + conditions["blocking_conditions"].append("above_max_position") + conditions["can_execute"] = False + + # 5. Position profit protection + if (self.config.position_profit_protection and not is_buy and + breakeven_price and breakeven_price > 0 and reference_price < breakeven_price): + conditions["blocking_conditions"].append("position_profit_protection") + conditions["can_execute"] = False + + conditions["active_executors"] = len(level_analysis["active_executors_not_trading"]) + conditions["hanging_executors"] = len(level_analysis["active_executors_trading"]) + + level_conditions[level_id] = conditions + + return level_conditions + + def _calculate_executor_statistics(self, current_time: int) -> Dict: + """Calculate performance statistics for executors""" + stats = { + "total_active": len([e for e in self.executors_info if e.is_active]), + "total_trading": len([e for e in self.executors_info if e.is_active and e.is_trading]), + "total_not_trading": len([e for e in self.executors_info if e.is_active and not e.is_trading]), + "avg_executor_age": Decimal("0"), + "oldest_executor_age": 0, + "refresh_candidates": 0 + } + + active_executors = [e for e in self.executors_info if e.is_active] + + if active_executors: + ages = [current_time - e.timestamp for e in active_executors] + stats["avg_executor_age"] = Decimal(str(sum(ages))) / Decimal(str(len(ages))) + stats["oldest_executor_age"] = max(ages) + + # Count refresh candidates + stats["refresh_candidates"] = len([ + e for e in active_executors + if not e.is_trading and current_time - e.timestamp > self.config.executor_refresh_time + ]) + + return stats + + def _calculate_refresh_tracking(self, current_time: int) -> Dict: + """Track executor refresh progress including distance-based refresh conditions""" + refresh_data = { + "refresh_candidates": [], + "near_refresh": 0, + "refresh_ready": 0, + "distance_violations": 0 + } + + # Get active non-trading executors + active_not_trading = [e for e in self.executors_info if e.is_active and not e.is_trading] + reference_price = Decimal(self.processed_data.get("reference_price", Decimal("0"))) + + for executor in active_not_trading: + age = current_time - executor.timestamp + time_to_refresh = max(0, self.config.executor_refresh_time - age) + progress_pct = min(Decimal("1"), Decimal(str(age)) / Decimal(str(self.config.executor_refresh_time))) + + # Check distance-based refresh condition + distance_violation = (reference_price > 0 and + self.should_refresh_executor_by_distance(executor, reference_price)) + # Calculate distance deviation for display + distance_deviation_pct = Decimal("0") + if reference_price > 0: + level_id = executor.custom_info.get("level_id", "") + if level_id and hasattr(executor.config, 'entry_price'): + theoretical_price = self.calculate_theoretical_price(level_id, reference_price) + if theoretical_price > 0: + distance_deviation_pct = abs( + executor.config.entry_price - theoretical_price) / theoretical_price + + ready_by_time = time_to_refresh == 0 + ready_by_distance = distance_violation + ready = ready_by_time or ready_by_distance + near_refresh = time_to_refresh <= (self.config.executor_refresh_time * 0.2) # Within 20% of refresh time + + if ready: + refresh_data["refresh_ready"] += 1 + elif near_refresh: + refresh_data["near_refresh"] += 1 + + if distance_violation: + refresh_data["distance_violations"] += 1 + + level_id = executor.custom_info.get("level_id", "unknown") + level = self.get_level_from_level_id(level_id) if level_id != "unknown" else 0 + + # Get level-specific refresh tolerance for display + level_tolerance = self.config.get_refresh_level_tolerance( + level) if level_id != "unknown" else self.config.refresh_tolerance + + refresh_data["refresh_candidates"].append({ + "executor_id": executor.id, + "level_id": level_id, + "level": level, + "age": age, + "time_to_refresh": time_to_refresh, + "progress_pct": progress_pct, + "ready": ready, + "ready_by_time": ready_by_time, + "ready_by_distance": ready_by_distance, + "distance_deviation_pct": distance_deviation_pct, + "distance_violation": distance_violation, + "level_tolerance": level_tolerance, + "near_refresh": near_refresh + }) + + return refresh_data + + def _format_refresh_bars(self, refresh_tracking: Dict, bar_width: int, inner_width: int) -> List[str]: + """Format refresh progress bars""" + lines = [] + + refresh_candidates = refresh_tracking.get('refresh_candidates', []) + if not refresh_candidates: + return lines + + lines.append(f"│ {'REFRESH PROGRESS:':<{inner_width}} │") + + # Show up to 5 executors approaching refresh + for candidate in refresh_candidates[:5]: + level_id = candidate.get('level_id', 'unknown') + time_to_refresh = candidate.get('time_to_refresh', 0) + progress = float(candidate.get('progress_pct', 0)) + ready = candidate.get('ready', False) + ready_by_distance = candidate.get('ready_by_distance', False) + distance_deviation_pct = candidate.get('distance_deviation_pct', Decimal('0')) + near_refresh = candidate.get('near_refresh', False) + + bar = self._create_progress_bar(progress, bar_width // 2) + + if ready: + if ready_by_distance: + status = f"DISTANCE! ({distance_deviation_pct:.1%})" + icon = "⚠️" + else: + status = "TIME REFRESH!" + icon = "🔄" + elif near_refresh: + status = f"{time_to_refresh}s (Soon)" + icon = "⏰" + else: + if distance_deviation_pct > 0: + status = f"{time_to_refresh}s ({distance_deviation_pct:.1%})" + else: + status = f"{time_to_refresh}s" + icon = "⏳" + + lines.append(f"│ {icon} {level_id}: [{bar}] {status:<15} │") + + if len(refresh_candidates) > 5: + lines.append(f"│ {'... and ' + str(len(refresh_candidates) - 5) + ' more':<{inner_width}} │") + + return lines + + def _format_price_graph( + self, current_price: Decimal, breakeven_price: Optional[Decimal], inner_width: int + ) -> List[str]: + """Format price graph with order zones and history""" + lines = [] + + if len(self.price_history) < 10: + lines.append(f"│ {'Collecting price data...':<{inner_width}} │") + return lines + + # Get last 30 price points for the graph + recent_prices = [p['price'] for p in self.price_history[-30:]] + min_price = min(recent_prices) + max_price = max(recent_prices) + + # Calculate price range with some padding + price_range = max_price - min_price + if price_range == 0: + price_range = current_price * Decimal('0.01') # 1% range if no movement + + padding = price_range * Decimal('0.1') # 10% padding + graph_min = min_price - padding + graph_max = max_price + padding + graph_range = graph_max - graph_min + + # Calculate order zones using level 0 price distance tolerance + level_0_tolerance = self.config.get_price_distance_level_tolerance(0) + buy_distance = current_price * level_0_tolerance + sell_distance = current_price * level_0_tolerance + buy_zone_price = current_price - buy_distance + sell_zone_price = current_price + sell_distance + + # Graph dimensions + graph_width = inner_width - 20 # Leave space for price labels and borders + graph_height = 8 + + # Create the graph + graph_lines = [] + for row in range(graph_height): + # Calculate price level for this row (top to bottom) + price_level = graph_max - (Decimal(row) / Decimal(graph_height - 1)) * graph_range + line = "" + + # Price label (left side) + price_label = f"{float(price_level):6.2f}" + line += price_label + " ┼" + + # Graph data + for col in range(graph_width): + # Calculate which price point this column represents + col_index = int((col / graph_width) * len(recent_prices)) + if col_index >= len(recent_prices): + col_index = len(recent_prices) - 1 + + price_at_col = recent_prices[col_index] + + # Determine what to show at this position + char = "─" # Default horizontal line + + # Check if current price line crosses this position + if abs(float(price_at_col - price_level)) < float(graph_range) / (graph_height * 2): + if price_at_col == current_price: + char = "●" # Current price marker + else: + char = "·" # Price history point + + # Mark breakeven line + if breakeven_price and abs(float(breakeven_price - price_level)) < float(graph_range) / ( + graph_height * 2): + char = "=" # Breakeven line + + # Mark order zones + if abs(float(buy_zone_price - price_level)) < float(graph_range) / (graph_height * 4): + char = "B" # Buy zone boundary + elif abs(float(sell_zone_price - price_level)) < float(graph_range) / (graph_height * 4): + char = "S" # Sell zone boundary + + # Mark recent orders + for order in self.order_history[-10:]: # Last 10 orders + order_price = order['price'] + if abs(float(order_price - price_level)) < float(graph_range) / (graph_height * 3): + if order['side'] == 'BUY': + char = "b" # Buy order + else: + char = "s" # Sell order + break + + line += char + + # Add right border and annotations + annotation = "" + if abs(float(current_price - price_level)) < float(graph_range) / (graph_height * 2): + annotation = " ← Current" + elif breakeven_price and abs(float(breakeven_price - price_level)) < float(graph_range) / ( + graph_height * 2): + annotation = " ← Breakeven" + elif abs(float(sell_zone_price - price_level)) < float(graph_range) / (graph_height * 4): + annotation = " ← Sell zone" + elif abs(float(buy_zone_price - price_level)) < float(graph_range) / (graph_height * 4): + annotation = " ← Buy zone" + + line += annotation + graph_lines.append(line) + + # Format graph lines with proper padding + for graph_line in graph_lines: + lines.append(f"│ {graph_line:<{inner_width}} │") + + # Add legend + lines.append( + f"│ {'Legend: ● Current price = Breakeven B/S Zone boundaries b/s Recent orders':<{inner_width}} │") + + # Add current metrics + dist_l0 = self.config.get_price_distance_level_tolerance(0) + ref_l0 = self.config.get_refresh_level_tolerance(0) + metrics_line = f"Dist: L0 {dist_l0:.4%} | Refresh: L0 {ref_l0:.4%} | Scaling: ×{self.config.tolerance_scaling}" + if breakeven_price: + distance_to_breakeven = ( + (current_price - breakeven_price) / current_price) if breakeven_price > 0 else Decimal(0) + metrics_line += f" | Breakeven gap: {distance_to_breakeven:+.2%}" + + lines.append(f"│ {metrics_line:<{inner_width}} │") + + return lines diff --git a/bots/controllers/generic/pmm_v1.py b/bots/controllers/generic/pmm_v1.py new file mode 100644 index 00000000..d7314ab6 --- /dev/null +++ b/bots/controllers/generic/pmm_v1.py @@ -0,0 +1,763 @@ +""" +PMM V1 Controller - Pure Market Making Controller + +This controller replicates the legacy pure_market_making strategy with: +- Multi-level spread/amount configuration (list-based) +- Inventory skew calculation matching legacy algorithm +- Order refresh with timing controls and tolerance +- Static and moving price bands +- Minimum spread enforcement +""" + +from decimal import Decimal +from typing import Dict, List, Optional, Tuple + +import numpy as np +from pydantic import Field, field_validator + +from hummingbot.core.data_type.common import MarketDict, PriceType, TradeType +from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.models.base import RunnableStatus +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors import CloseType + + +class PMMV1Config(ControllerConfigBase): + """ + Configuration for the PMM V1 controller - a pure market making controller. + + Implements the core features from legacy pure_market_making strategy. + """ + controller_type: str = "generic" + controller_name: str = "pmm_v1" + + # === Core Market Settings === + connector_name: str = Field( + default="binance", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the connector name (e.g., binance):", + } + ) + trading_pair: str = Field( + default="BTC-USDT", + json_schema_extra={ + "prompt_on_new": True, + "prompt": "Enter the trading pair (e.g., BTC-USDT):", + } + ) + + # === Spread & Amount Configuration === + # Override inherited total_amount_quote — PMM V1 uses order_amount in base asset + total_amount_quote: Decimal = Field(default=Decimal("0"), json_schema_extra={"prompt_on_new": False}) + + order_amount: Decimal = Field( + default=Decimal("1"), + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter the order amount in base asset (e.g., 0.01 for BTC):", + } + ) + buy_spreads: List[float] = Field( + default="0.01", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter comma-separated buy spreads as decimals (e.g., '0.01,0.02' for 1%, 2%):", + } + ) + sell_spreads: List[float] = Field( + default="0.01", + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter comma-separated sell spreads as decimals (e.g., '0.01,0.02' for 1%, 2%):", + } + ) + + # === Timing Configuration === + order_refresh_time: int = Field( + default=30, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter order refresh time in seconds (how often to refresh orders):", + } + ) + order_refresh_tolerance_pct: Decimal = Field( + default=Decimal("-1"), + json_schema_extra={ + "prompt_on_new": False, "is_updatable": True, + "prompt": "Enter order refresh tolerance as decimal (e.g., 0.01 = 1%). -1 to disable:", + } + ) + filled_order_delay: int = Field( + default=60, + json_schema_extra={ + "prompt_on_new": False, "is_updatable": True, + "prompt": "Enter delay in seconds after a fill before placing new orders:", + } + ) + + # === Inventory Skew Configuration === + inventory_skew_enabled: bool = Field( + default=False, + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enable inventory skew? (adjusts order sizes based on inventory):", + } + ) + target_base_pct: Decimal = Field( + default=Decimal("0.5"), + json_schema_extra={ + "prompt_on_new": True, "is_updatable": True, + "prompt": "Enter target base percentage (e.g., 0.5 for 50% base, 50% quote):", + } + ) + inventory_range_multiplier: Decimal = Field( + default=Decimal("1.0"), + json_schema_extra={ + "prompt_on_new": False, "is_updatable": True, + "prompt": "Enter inventory range multiplier for skew calculation:", + } + ) + + # === Static Price Band Configuration === + price_ceiling: Decimal = Field( + default=Decimal("-1"), + json_schema_extra={ + "prompt_on_new": False, "is_updatable": True, + "prompt": "Enter static price ceiling (-1 to disable). Only sell orders above this price:", + } + ) + price_floor: Decimal = Field( + default=Decimal("-1"), + json_schema_extra={ + "prompt_on_new": False, "is_updatable": True, + "prompt": "Enter static price floor (-1 to disable). Only buy orders below this price:", + } + ) + + # === Validators === + @field_validator('buy_spreads', 'sell_spreads', mode="before") + @classmethod + def parse_spreads(cls, v): + if v is None or v == "": + return [] + if isinstance(v, str): + return [float(x.strip()) for x in v.split(',')] + return [float(x) for x in v] + + def get_spreads(self, trade_type: TradeType) -> List[float]: + """Get spreads for a trade type. Each spread defines one order level.""" + if trade_type == TradeType.BUY: + return self.buy_spreads + return self.sell_spreads + + def update_markets(self, markets: MarketDict) -> MarketDict: + return markets.add_or_update(self.connector_name, self.trading_pair) + + +class PMMV1(ControllerBase): + """ + PMM V1 Controller - Pure Market Making Controller. + + Replicates legacy pure_market_making strategy with simple limit orders. + """ + + def __init__(self, config: PMMV1Config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.market_data_provider.initialize_rate_sources([ConnectorPair( + connector_name=config.connector_name, trading_pair=config.trading_pair)]) + + # Track when each level can next create orders (for filled_order_delay) + self._level_next_create_timestamps: Dict[str, float] = {} + # Track last seen executor states to detect fills + self._last_seen_executors: Dict[str, bool] = {} + + def _detect_filled_executors(self): + """Detect executors that were filled (not cancelled).""" + # Get current active executor IDs by level + current_active_by_level = {} + filled_levels = set() + + for executor in self.executors_info: + level_id = executor.custom_info.get("level_id", "") + + if executor.is_active: + current_active_by_level[level_id] = True + elif executor.close_type == CloseType.POSITION_HOLD: + # POSITION_HOLD means the order was filled + filled_levels.add(level_id) + + # Check for levels that were active before but aren't now and were filled + for level_id, was_active in self._last_seen_executors.items(): + if (was_active and + level_id not in current_active_by_level and + level_id in filled_levels): + # This level was active before, not now, and was filled + self._handle_filled_executor(level_id) + + # Update last seen state + self._last_seen_executors = current_active_by_level.copy() + + def _handle_filled_executor(self, level_id: str): + """Set the next create timestamp for a level when its executor is filled.""" + current_time = self.market_data_provider.time() + self._level_next_create_timestamps[level_id] = current_time + self.config.filled_order_delay + + # Log the filled order delay + self.logger().debug(f"Order on level {level_id} filled. Next order for this level can be created after {self.config.filled_order_delay}s delay.") + + def _get_reference_price(self) -> Decimal: + """Get reference price (mid price).""" + try: + price = self.market_data_provider.get_price_by_type( + self.config.connector_name, + self.config.trading_pair, + PriceType.MidPrice + ) + if price is None or (isinstance(price, float) and np.isnan(price)): + return Decimal("0") + return Decimal(str(price)) + except Exception: + return Decimal("0") + + async def update_processed_data(self): + """ + Update processed data with reference price, inventory info, and derived metrics. + """ + # Detect filled executors (executors that disappeared since last check) + self._detect_filled_executors() + + reference_price = self._get_reference_price() + + # Calculate inventory metrics for skew + base_balance, quote_balance = self._get_balances() + total_value_in_quote = base_balance * reference_price + quote_balance if reference_price > 0 else Decimal("0") + + if total_value_in_quote > 0: + current_base_pct = (base_balance * reference_price) / total_value_in_quote + else: + current_base_pct = Decimal("0") + + # Calculate inventory skew multipliers using legacy algorithm + buy_skew, sell_skew = self._calculate_inventory_skew_legacy( + current_base_pct, base_balance, quote_balance, reference_price + ) + + # Determine effective price ceiling and floor + effective_ceiling = self.config.price_ceiling if self.config.price_ceiling > 0 else None + effective_floor = self.config.price_floor if self.config.price_floor > 0 else None + + # Calculate proposal prices for tolerance comparison + buy_proposal_prices, sell_proposal_prices = self._calculate_proposal_prices(reference_price) + + self.processed_data = { + "reference_price": reference_price, + "current_base_pct": current_base_pct, + "base_balance": base_balance, + "quote_balance": quote_balance, + "buy_skew": buy_skew, + "sell_skew": sell_skew, + "price_ceiling": effective_ceiling, + "price_floor": effective_floor, + "buy_proposal_prices": buy_proposal_prices, + "sell_proposal_prices": sell_proposal_prices, + } + + def _get_balances(self) -> Tuple[Decimal, Decimal]: + """Get base and quote balances from the connector.""" + try: + base, quote = self.config.trading_pair.split("-") + base_balance = self.market_data_provider.get_balance( + self.config.connector_name, base + ) + quote_balance = self.market_data_provider.get_balance( + self.config.connector_name, quote + ) + return Decimal(str(base_balance)), Decimal(str(quote_balance)) + except Exception: + return Decimal("0"), Decimal("0") + + def _calculate_inventory_skew_legacy( + self, + current_base_pct: Decimal, + base_balance: Decimal, + quote_balance: Decimal, + reference_price: Decimal + ) -> Tuple[Decimal, Decimal]: + """ + Calculate inventory skew multipliers matching the legacy inventory_skew_calculator.pyx algorithm. + + The legacy algorithm: + 1. Uses total_order_size * inventory_range_multiplier for the range (in base asset) + 2. Calculates water marks around target + 3. Uses np.interp for smooth interpolation + 4. Returns bid/ask ratios from 0.0 to 2.0 + """ + if not self.config.inventory_skew_enabled: + return Decimal("1"), Decimal("1") + + if reference_price <= 0: + return Decimal("1"), Decimal("1") + + # Get total order size in base asset for range calculation + num_buy_levels = len(self.config.get_spreads(TradeType.BUY)) + num_sell_levels = len(self.config.get_spreads(TradeType.SELL)) + total_order_size_base = float(self.config.order_amount) * (num_buy_levels + num_sell_levels) + + if total_order_size_base <= 0: + return Decimal("1"), Decimal("1") + + # Calculate range in base asset (matching legacy) + base_asset_range = total_order_size_base * float(self.config.inventory_range_multiplier) + + # Call the legacy calculation + return self._c_calculate_bid_ask_ratios( + float(base_balance), + float(quote_balance), + float(reference_price), + float(self.config.target_base_pct), + base_asset_range + ) + + def _c_calculate_bid_ask_ratios( + self, + base_asset_amount: float, + quote_asset_amount: float, + price: float, + target_base_asset_ratio: float, + base_asset_range: float + ) -> Tuple[Decimal, Decimal]: + """ + Exact port of legacy c_calculate_bid_ask_ratios_from_base_asset_ratio. + """ + total_portfolio_value = base_asset_amount * price + quote_asset_amount + + if total_portfolio_value <= 0.0 or base_asset_range <= 0.0: + return Decimal("1"), Decimal("1") + + base_asset_value = base_asset_amount * price + base_asset_range_value = min(base_asset_range * price, total_portfolio_value * 0.5) + target_base_asset_value = total_portfolio_value * target_base_asset_ratio + left_base_asset_value_limit = max(target_base_asset_value - base_asset_range_value, 0.0) + right_base_asset_value_limit = target_base_asset_value + base_asset_range_value + + # Use np.interp for smooth interpolation (matching legacy) + left_inventory_ratio = float(np.interp( + base_asset_value, + [left_base_asset_value_limit, target_base_asset_value], + [0.0, 0.5] + )) + right_inventory_ratio = float(np.interp( + base_asset_value, + [target_base_asset_value, right_base_asset_value_limit], + [0.5, 1.0] + )) + + if base_asset_value < target_base_asset_value: + bid_adjustment = float(np.interp(left_inventory_ratio, [0, 0.5], [2.0, 1.0])) + else: + bid_adjustment = float(np.interp(right_inventory_ratio, [0.5, 1], [1.0, 0.0])) + + ask_adjustment = 2.0 - bid_adjustment + + return Decimal(str(bid_adjustment)), Decimal(str(ask_adjustment)) + + def _calculate_proposal_prices( + self, reference_price: Decimal + ) -> Tuple[List[Decimal], List[Decimal]]: + """Calculate what the proposal prices would be for tolerance comparison.""" + buy_spreads = self.config.get_spreads(TradeType.BUY) + sell_spreads = self.config.get_spreads(TradeType.SELL) + + buy_prices = [] + for spread in buy_spreads: + price = reference_price * (Decimal("1") - Decimal(str(spread))) + buy_prices.append(price) + + sell_prices = [] + for spread in sell_spreads: + price = reference_price * (Decimal("1") + Decimal(str(spread))) + sell_prices.append(price) + + return buy_prices, sell_prices + + def determine_executor_actions(self) -> List[ExecutorAction]: + """Determine actions based on current state.""" + # Don't create new actions if the controller is being stopped + if self.status == RunnableStatus.TERMINATED: + return [] + + actions = [] + actions.extend(self.create_actions_proposal()) + actions.extend(self.stop_actions_proposal()) + return actions + + def create_actions_proposal(self) -> List[ExecutorAction]: + """Create actions proposal for new executors.""" + create_actions = [] + + # Get levels to execute + levels_to_execute = self.get_levels_to_execute() + + buy_spreads = self.config.get_spreads(TradeType.BUY) + sell_spreads = self.config.get_spreads(TradeType.SELL) + + reference_price = Decimal(self.processed_data["reference_price"]) + if reference_price <= 0: + return [] + + buy_skew = self.processed_data["buy_skew"] + sell_skew = self.processed_data["sell_skew"] + + for level_id in levels_to_execute: + trade_type = self.get_trade_type_from_level_id(level_id) + level = self.get_level_from_level_id(level_id) + + # Get spread for this level + if trade_type == TradeType.BUY: + if level >= len(buy_spreads): + continue + spread_in_pct = Decimal(str(buy_spreads[level])) + skew = buy_skew + else: + if level >= len(sell_spreads): + continue + spread_in_pct = Decimal(str(sell_spreads[level])) + skew = sell_skew + + # Calculate order price + side_multiplier = Decimal("-1") if trade_type == TradeType.BUY else Decimal("1") + price = reference_price * (Decimal("1") + side_multiplier * spread_in_pct) + + # Apply inventory skew to order amount (already in base asset) + amount = self.config.order_amount * skew + amount = self.market_data_provider.quantize_order_amount( + self.config.connector_name, self.config.trading_pair, amount + ) + + if amount == Decimal("0"): + continue + + # Quantize price + price = self.market_data_provider.quantize_order_price( + self.config.connector_name, self.config.trading_pair, price + ) + + # Create executor config + executor_config = self._get_executor_config(level_id, price, amount, trade_type) + if executor_config is not None: + create_actions.append(CreateExecutorAction( + controller_id=self.config.id, + executor_config=executor_config + )) + + return create_actions + + def get_levels_to_execute(self) -> List[str]: + """Get levels that need new executors. + + A level is considered "working" (and won't get a new executor) if: + - It has an active executor, OR + - Its filled_order_delay period hasn't expired yet + """ + current_time = self.market_data_provider.time() + + # Get levels with active executors + active_levels = self.filter_executors( + executors=self.executors_info, + filter_func=lambda x: x.is_active + ) + active_level_ids = [executor.custom_info.get("level_id", "") for executor in active_levels] + + # Get missing levels + missing_levels = self._get_not_active_levels_ids(active_level_ids) + + # Filter out levels still in filled_order_delay period + missing_levels = [ + level_id for level_id in missing_levels + if current_time >= self._level_next_create_timestamps.get(level_id, 0) + ] + + # Apply price band filter + missing_levels = self._apply_price_band_filter(missing_levels) + + return missing_levels + + def _get_not_active_levels_ids(self, active_level_ids: List[str]) -> List[str]: + """Get level IDs that are not currently active.""" + buy_spreads = self.config.get_spreads(TradeType.BUY) + sell_spreads = self.config.get_spreads(TradeType.SELL) + + num_buy_levels = len(buy_spreads) + num_sell_levels = len(sell_spreads) + + buy_ids_missing = [ + self.get_level_id_from_side(TradeType.BUY, level) + for level in range(num_buy_levels) + if self.get_level_id_from_side(TradeType.BUY, level) not in active_level_ids + ] + sell_ids_missing = [ + self.get_level_id_from_side(TradeType.SELL, level) + for level in range(num_sell_levels) + if self.get_level_id_from_side(TradeType.SELL, level) not in active_level_ids + ] + return buy_ids_missing + sell_ids_missing + + def _apply_price_band_filter(self, level_ids: List[str]) -> List[str]: + """Filter out levels that violate price band constraints. + + Price band logic (matching legacy pure_market_making): + - If price >= ceiling: only sell orders (don't buy at high prices) + - If price <= floor: only buy orders (don't sell at low prices) + """ + reference_price = self.processed_data["reference_price"] + ceiling = self.processed_data.get("price_ceiling") + floor = self.processed_data.get("price_floor") + + filtered = [] + for level_id in level_ids: + trade_type = self.get_trade_type_from_level_id(level_id) + if trade_type == TradeType.BUY and ceiling is not None and reference_price >= ceiling: + # Price at or above ceiling: only sell orders + continue + if trade_type == TradeType.SELL and floor is not None and reference_price <= floor: + # Price at or below floor: only buy orders + continue + filtered.append(level_id) + return filtered + + def stop_actions_proposal(self) -> List[ExecutorAction]: + """Create actions to stop executors.""" + stop_actions = [] + stop_actions.extend(self._executors_to_refresh()) + return stop_actions + + def _executors_to_refresh(self) -> List[StopExecutorAction]: + """Get executors that should be refreshed. + + Matching legacy behavior: + - Compares current order prices to proposal prices (not just reference price) + - If ALL orders on a side are within tolerance, don't refresh that side + """ + current_time = self.market_data_provider.time() + + # Only consider refresh after refresh time + executors_past_refresh = [ + e for e in self.executors_info + if e.is_active and not e.is_trading + and current_time - e.timestamp > self.config.order_refresh_time + ] + + if not executors_past_refresh: + return [] + + # If tolerance is disabled, refresh all + if self.config.order_refresh_tolerance_pct < 0: + return [ + StopExecutorAction( + controller_id=self.config.id, + executor_id=executor.id, + keep_position=True + ) + for executor in executors_past_refresh + ] + + # Get current order prices and proposal prices + buy_proposal_prices = self.processed_data.get("buy_proposal_prices", []) + sell_proposal_prices = self.processed_data.get("sell_proposal_prices", []) + + # Get current buy/sell order prices + current_buy_prices = [] + current_sell_prices = [] + for executor in executors_past_refresh: + level_id = executor.custom_info.get("level_id", "") + order_price = getattr(executor.config, 'price', None) + if order_price is None: + continue + if level_id.startswith("buy"): + current_buy_prices.append(order_price) + elif level_id.startswith("sell"): + current_sell_prices.append(order_price) + + # Check if within tolerance (matching legacy c_is_within_tolerance) + buys_within_tolerance = self._is_within_tolerance( + current_buy_prices, buy_proposal_prices + ) + sells_within_tolerance = self._is_within_tolerance( + current_sell_prices, sell_proposal_prices + ) + + # Log tolerance decisions + if buys_within_tolerance and sells_within_tolerance: + if executors_past_refresh: + executor_level_ids = [e.custom_info.get("level_id", "unknown") for e in executors_past_refresh] + self.logger().debug(f"Orders {executor_level_ids} will not be canceled because they are within the order tolerance ({self.config.order_refresh_tolerance_pct:.2%}).") + return [] + + # Log which orders are being refreshed due to tolerance + if executors_past_refresh: + executor_level_ids = [e.custom_info.get("level_id", "unknown") for e in executors_past_refresh] + tolerance_reason = [] + if not buys_within_tolerance: + tolerance_reason.append("buy orders outside tolerance") + if not sells_within_tolerance: + tolerance_reason.append("sell orders outside tolerance") + reason = " and ".join(tolerance_reason) + self.logger().debug(f"Refreshing orders {executor_level_ids} due to {reason} (tolerance: {self.config.order_refresh_tolerance_pct:.2%}).") + + # Otherwise, refresh all executors + return [ + StopExecutorAction( + controller_id=self.config.id, + executor_id=executor.id, + keep_position=True + ) + for executor in executors_past_refresh + ] + + def _is_within_tolerance( + self, current_prices: List[Decimal], proposal_prices: List[Decimal] + ) -> bool: + """Check if current prices are within tolerance of proposal prices. + + Matching legacy c_is_within_tolerance behavior. + """ + if len(current_prices) != len(proposal_prices): + return False + + if not current_prices: + return True + + current_sorted = sorted(current_prices) + proposal_sorted = sorted(proposal_prices) + + for current, proposal in zip(current_sorted, proposal_sorted): + if current == 0: + return False + diff_pct = abs(proposal - current) / current + if diff_pct > self.config.order_refresh_tolerance_pct: + return False + + return True + + def _get_executor_config( + self, level_id: str, price: Decimal, amount: Decimal, trade_type: TradeType + ) -> Optional[OrderExecutorConfig]: + """Create executor config for a level (simple limit order like legacy PMM).""" + return OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=self.config.trading_pair, + side=trade_type, + amount=amount, + execution_strategy=ExecutionStrategy.LIMIT, + price=price, + level_id=level_id, + ) + + def get_level_id_from_side(self, trade_type: TradeType, level: int) -> str: + """Get level ID from trade type and level number.""" + return f"{trade_type.name.lower()}_{level}" + + def get_trade_type_from_level_id(self, level_id: str) -> TradeType: + """Get trade type from level ID.""" + return TradeType.BUY if level_id.startswith("buy") else TradeType.SELL + + def get_level_from_level_id(self, level_id: str) -> int: + """Get level number from level ID.""" + if "_" not in level_id: + return 0 + return int(level_id.split('_')[1]) + + def to_format_status(self) -> List[str]: + """Get formatted status display.""" + from itertools import zip_longest + + status = [] + + # Get data + base_pct = self.processed_data.get('current_base_pct', Decimal('0')) + target_pct = self.config.target_base_pct + buy_skew = self.processed_data.get('buy_skew', Decimal('1')) + sell_skew = self.processed_data.get('sell_skew', Decimal('1')) + ref_price = self.processed_data.get('reference_price', Decimal('0')) + ceiling = self.processed_data.get('price_ceiling') + floor = self.processed_data.get('price_floor') + + active_buy = sum(1 for e in self.executors_info + if e.is_active and e.custom_info.get("level_id", "").startswith("buy")) + active_sell = sum(1 for e in self.executors_info + if e.is_active and e.custom_info.get("level_id", "").startswith("sell")) + + # Layout + w = 89 # total width including outer pipes + hw = (w - 3) // 2 # half width for two-column rows (minus 3 for "| " + "|" + " |") + + def sep(char="-"): + return char * w + + def row2(left, right): + return f"| {left:<{hw}}| {right:<{hw}}|" + + def row1(content): + return f"| {content:<{w - 4}} |" + + # Header + status.append(sep("=")) + header = f"PMM V1 | {self.config.connector_name}:{self.config.trading_pair}" + status.append(f"|{header:^{w - 2}}|") + status.append(sep("=")) + + # Inventory & Settings + status.append(row2("INVENTORY", "SETTINGS")) + status.append(sep()) + inv = [ + f"Base %: {base_pct:.2%} (target {target_pct:.2%})", + f"Buy Skew: {buy_skew:.2f}x | Sell Skew: {sell_skew:.2f}x", + ] + settings = [ + f"Order Amount: {self.config.order_amount} base", + f"Spreads B: {self.config.buy_spreads} S: {self.config.sell_spreads}", + ] + for left, right in zip_longest(inv, settings, fillvalue=""): + status.append(row2(left, right)) + + # Market & Price Bands + status.append(sep()) + status.append(row2("MARKET", "PRICE BANDS")) + status.append(sep()) + ceiling_str = f"{ceiling:.8g}" if ceiling else "None" + floor_str = f"{floor:.8g}" if floor else "None" + market = [ + f"Ref Price: {ref_price:.8g}", + f"Active: Buy={active_buy} Sell={active_sell}", + ] + bands = [ + f"Ceiling: {ceiling_str}", + f"Floor: {floor_str}", + ] + for left, right in zip_longest(market, bands, fillvalue=""): + status.append(row2(left, right)) + + # Inventory bar + status.append(sep()) + bar_width = w - 17 # account for "| Inventory: [" + "] |" + filled = int(float(base_pct) * bar_width) + target_pos = int(float(target_pct) * bar_width) + bar = "" + for i in range(bar_width): + if i == filled: + bar += "X" + elif i == target_pos: + bar += ":" + elif i < filled: + bar += "#" + else: + bar += "." + status.append(f"| Inventory: [{bar}] |") + status.append(sep("=")) + + return status diff --git a/bots/controllers/generic/quantum_grid_allocator.py b/bots/controllers/generic/quantum_grid_allocator.py new file mode 100644 index 00000000..efde0cb1 --- /dev/null +++ b/bots/controllers/generic/quantum_grid_allocator.py @@ -0,0 +1,493 @@ +from decimal import Decimal +from typing import Dict, List, Set, Union + +import pandas_ta as ta # noqa: F401 +from pydantic import Field, field_validator + +from hummingbot.core.data_type.common import OrderType, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair +from hummingbot.strategy_v2.executors.grid_executor.data_types import GridExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, StopExecutorAction +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo + + +class QGAConfig(ControllerConfigBase): + controller_name: str = "quantum_grid_allocator" + + # Portfolio allocation zones + long_only_threshold: Decimal = Field(default=Decimal("0.2"), json_schema_extra={"is_updatable": True}) + short_only_threshold: Decimal = Field(default=Decimal("0.2"), json_schema_extra={"is_updatable": True}) + hedge_ratio: Decimal = Field(default=Decimal("2"), json_schema_extra={"is_updatable": True}) + + # Grid allocation multipliers + base_grid_value_pct: Decimal = Field(default=Decimal("0.08"), json_schema_extra={"is_updatable": True}) + max_grid_value_pct: Decimal = Field(default=Decimal("0.15"), json_schema_extra={"is_updatable": True}) + + # Order frequency settings + safe_extra_spread: Decimal = Field(default=Decimal("0.0001"), json_schema_extra={"is_updatable": True}) + favorable_order_frequency: int = Field(default=2, json_schema_extra={"is_updatable": True}) + unfavorable_order_frequency: int = Field(default=5, json_schema_extra={"is_updatable": True}) + max_orders_per_batch: int = Field(default=1, json_schema_extra={"is_updatable": True}) + + # Portfolio allocation + portfolio_allocation: Dict[str, Decimal] = Field( + default={ + "SOL": Decimal("0.50"), # 50% + }, + json_schema_extra={"is_updatable": True}) + # Grid parameters + grid_range: Decimal = Field(default=Decimal("0.002"), json_schema_extra={"is_updatable": True}) + tp_sl_ratio: Decimal = Field(default=Decimal("0.8"), json_schema_extra={"is_updatable": True}) + min_order_amount: Decimal = Field(default=Decimal("5"), json_schema_extra={"is_updatable": True}) + # Risk parameters + max_deviation: Decimal = Field(default=Decimal("0.05"), json_schema_extra={"is_updatable": True}) + max_open_orders: int = Field(default=2, json_schema_extra={"is_updatable": True}) + # Exchange settings + connector_name: str = "binance" + leverage: int = 1 + position_mode: PositionMode = PositionMode.HEDGE + quote_asset: str = "FDUSD" + fee_asset: str = "BNB" + # Grid price multipliers + min_spread_between_orders: Decimal = Field( + default=Decimal("0.0001"), # 0.01% between orders + json_schema_extra={"is_updatable": True}) + grid_tp_multiplier: Decimal = Field( + default=Decimal("0.0001"), # 0.2% take profit + json_schema_extra={"is_updatable": True}) + # Grid safety parameters + limit_price_spread: Decimal = Field( + default=Decimal("0.001"), # 0.1% spread for limit price + json_schema_extra={"is_updatable": True}) + activation_bounds: Decimal = Field( + default=Decimal("0.0002"), # Activation bounds for orders + json_schema_extra={"is_updatable": True}) + bb_length: int = 100 + bb_std_dev: float = 2.0 + interval: str = "1s" + dynamic_grid_range: bool = Field(default=False, json_schema_extra={"is_updatable": True}) + show_terminated_details: bool = False + + @property + def quote_asset_allocation(self) -> Decimal: + """Calculate the implicit quote asset (FDUSD) allocation""" + return Decimal("1") - sum(self.portfolio_allocation.values()) + + @field_validator("portfolio_allocation") + @classmethod + def validate_allocation(cls, v): + total = sum(v.values()) + if total >= Decimal("1"): + raise ValueError(f"Total allocation {total} exceeds or equals 100%. Must leave room for FDUSD allocation.") + if "FDUSD" in v: + raise ValueError("FDUSD should not be explicitly allocated as it is the quote asset") + return v + + def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: + if self.connector_name not in markets: + markets[self.connector_name] = set() + for asset in self.portfolio_allocation: + markets[self.connector_name].add(f"{asset}-{self.quote_asset}") + return markets + + +class QuantumGridAllocator(ControllerBase): + def __init__(self, config: QGAConfig, *args, **kwargs): + self.config = config + self.metrics = {} + # Track unfavorable grid IDs + self.unfavorable_grid_ids = set() + # Track held positions from unfavorable grids + self.unfavorable_positions = { + f"{asset}-{config.quote_asset}": { + 'long': {'size': Decimal('0'), 'value': Decimal('0'), 'weighted_price': Decimal('0')}, + 'short': {'size': Decimal('0'), 'value': Decimal('0'), 'weighted_price': Decimal('0')} + } + for asset in config.portfolio_allocation + } + super().__init__(config, *args, **kwargs) + self.initialize_rate_sources() + + def initialize_rate_sources(self): + fee_pair = ConnectorPair(connector_name=self.config.connector_name, trading_pair=f"{self.config.fee_asset}-{self.config.quote_asset}") + self.market_data_provider.initialize_rate_sources([fee_pair]) + + async def update_processed_data(self): + # Get the bb width to use it as the range for the grid + for asset in self.config.portfolio_allocation: + trading_pair = f"{asset}-{self.config.quote_asset}" + candles = self.market_data_provider.get_candles_df( + connector_name=self.config.connector_name, + trading_pair=trading_pair, + interval=self.config.interval, + max_records=self.config.bb_length + 100 + ) + if len(candles) == 0: + bb_width = self.config.grid_range + else: + bb = ta.bbands(candles["close"], length=self.config.bb_length, std=self.config.bb_std_dev) + bb_width = bb[f"BBB_{self.config.bb_length}_{self.config.bb_std_dev}"].iloc[-1] / 100 + self.processed_data[trading_pair] = { + "bb_width": bb_width + } + + def update_portfolio_metrics(self): + """ + Calculate theoretical vs actual portfolio allocations + """ + metrics = { + "theoretical": {}, + "actual": {}, + "difference": {}, + } + + # Get real balances and calculate total portfolio value + quote_balance = self.market_data_provider.get_balance(self.config.connector_name, self.config.quote_asset) + total_value_quote = quote_balance + + # Calculate actual allocations including positions + for asset in self.config.portfolio_allocation: + trading_pair = f"{asset}-{self.config.quote_asset}" + price = self.get_mid_price(trading_pair) + # Get balance and add any position from active grid + balance = self.market_data_provider.get_balance(self.config.connector_name, asset) + value = balance * price + total_value_quote += value + metrics["actual"][asset] = value + # Calculate theoretical allocations and differences + for asset in self.config.portfolio_allocation: + theoretical_value = total_value_quote * self.config.portfolio_allocation[asset] + metrics["theoretical"][asset] = theoretical_value + metrics["difference"][asset] = metrics["actual"][asset] - theoretical_value + # Add quote asset metrics + metrics["actual"][self.config.quote_asset] = quote_balance + metrics["theoretical"][self.config.quote_asset] = total_value_quote * self.config.quote_asset_allocation + metrics["difference"][self.config.quote_asset] = quote_balance - metrics["theoretical"][self.config.quote_asset] + metrics["total_portfolio_value"] = total_value_quote + self.metrics = metrics + + def get_active_grids_by_asset(self) -> Dict[str, List[ExecutorInfo]]: + """Group active grids by asset using filter_executors""" + active_grids = {} + for asset in self.config.portfolio_allocation: + if asset == self.config.quote_asset: + continue + trading_pair = f"{asset}-{self.config.quote_asset}" + active_executors = self.filter_executors( + executors=self.executors_info, + filter_func=lambda e: ( + e.is_active and + e.config.trading_pair == trading_pair + ) + ) + if active_executors: + active_grids[asset] = active_executors + return active_grids + + def to_format_status(self) -> List[str]: + """Generate a detailed status report with portfolio, grid, and position information""" + status_lines = [] + total_value = self.metrics.get("total_portfolio_value", Decimal("0")) + # Portfolio Status + status_lines.append(f"Total Portfolio Value: ${total_value:,.2f}") + status_lines.append("") + status_lines.append("Portfolio Status:") + status_lines.append("-" * 80) + status_lines.append( + f"{'Asset':<8} | " + f"{'Actual':>10} | " + f"{'Target':>10} | " + f"{'Diff':>10} | " + f"{'Dev %':>8}" + ) + status_lines.append("-" * 80) + # Show metrics for each asset + for asset in self.config.portfolio_allocation: + actual = self.metrics["actual"].get(asset, Decimal("0")) + theoretical = self.metrics["theoretical"].get(asset, Decimal("0")) + difference = self.metrics["difference"].get(asset, Decimal("0")) + deviation_pct = (difference / theoretical * 100) if theoretical != Decimal("0") else Decimal("0") + status_lines.append( + f"{asset:<8} | " + f"${actual:>9.2f} | " + f"${theoretical:>9.2f} | " + f"${difference:>+9.2f} | " + f"{deviation_pct:>+7.1f}%" + ) + # Add quote asset metrics + quote_asset = self.config.quote_asset + actual = self.metrics["actual"].get(quote_asset, Decimal("0")) + theoretical = self.metrics["theoretical"].get(quote_asset, Decimal("0")) + difference = self.metrics["difference"].get(quote_asset, Decimal("0")) + deviation_pct = (difference / theoretical * 100) if theoretical != Decimal("0") else Decimal("0") + status_lines.append("-" * 80) + status_lines.append( + f"{quote_asset:<8} | " + f"${actual:>9.2f} | " + f"${theoretical:>9.2f} | " + f"${difference:>+9.2f} | " + f"{deviation_pct:>+7.1f}%" + ) + # Active Grids Summary + active_grids = self.get_active_grids_by_asset() + if active_grids: + status_lines.append("") + status_lines.append("Active Grids:") + status_lines.append("-" * 140) + status_lines.append( + f"{'Asset':<8} {'Side':<6} | " + f"{'Total ($)':<10} {'Position':<10} {'Volume':<10} | " + f"{'PnL':<10} {'RPnL':<10} {'Fees':<10} | " + f"{'Start':<10} {'Current':<10} {'End':<10} {'Limit':<10}" + ) + status_lines.append("-" * 140) + for asset, executors in active_grids.items(): + for executor in executors: + config = executor.config + custom_info = executor.custom_info + trading_pair = config.trading_pair + current_price = self.get_mid_price(trading_pair) + # Get grid metrics + total_amount = Decimal(str(config.total_amount_quote)) + position_size = Decimal(str(custom_info.get('position_size_quote', '0'))) + volume = executor.filled_amount_quote + pnl = executor.net_pnl_quote + realized_pnl_quote = custom_info.get('realized_pnl_quote', Decimal('0')) + fees = executor.cum_fees_quote + status_lines.append( + f"{asset:<8} {config.side.name:<6} | " + f"${total_amount:<9.2f} ${position_size:<9.2f} ${volume:<9.2f} | " + f"${pnl:>+9.2f} ${realized_pnl_quote:>+9.2f} ${fees:>9.2f} | " + f"{config.start_price:<10.4f} {current_price:<10.4f} {config.end_price:<10.4f} {config.limit_price:<10.4f}" + ) + + status_lines.append("-" * 100 + "\n") + return status_lines + + def tp_multiplier(self): + return self.config.tp_sl_ratio + + def sl_multiplier(self): + return 1 - self.config.tp_sl_ratio + + def determine_executor_actions(self) -> List[Union[CreateExecutorAction, StopExecutorAction]]: + actions = [] + self.update_portfolio_metrics() + active_grids_by_asset = self.get_active_grids_by_asset() + for asset in self.config.portfolio_allocation: + if asset == self.config.quote_asset: + continue + trading_pair = f"{asset}-{self.config.quote_asset}" + # Check if there are any active grids for this asset + if asset in active_grids_by_asset: + self.logger().debug(f"Skipping {trading_pair} - Active grid exists") + continue + theoretical = self.metrics["theoretical"][asset] + difference = self.metrics["difference"][asset] + deviation = difference / theoretical if theoretical != Decimal("0") else Decimal("0") + mid_price = self.get_mid_price(trading_pair) + + # Calculate dynamic grid value percentage based on deviation + abs_deviation = abs(deviation) + grid_value_pct = self.config.max_grid_value_pct if abs_deviation > self.config.max_deviation else self.config.base_grid_value_pct + + self.logger().info( + f"{trading_pair} Grid Sizing - " + f"Deviation: {deviation:+.1%}, " + f"Grid Value %: {grid_value_pct:.1%}" + ) + if self.config.dynamic_grid_range: + grid_range = Decimal(self.processed_data[trading_pair]["bb_width"]) + else: + grid_range = self.config.grid_range + + # Determine which zone we're in by normalizing the deviation over the theoretical allocation + if deviation < -self.config.long_only_threshold: + # Long-only zone - only create buy grids + if difference < Decimal("0"): # Only if we need to buy + grid_value = min(abs(difference), theoretical * grid_value_pct) + start_price = mid_price * (1 - grid_range * self.sl_multiplier()) + end_price = mid_price * (1 + grid_range * self.tp_multiplier()) + grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.BUY, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if grid_action is not None: + actions.append(grid_action) + elif deviation > self.config.short_only_threshold: + # Short-only zone - only create sell grids + if difference > Decimal("0"): # Only if we need to sell + grid_value = min(abs(difference), theoretical * grid_value_pct) + start_price = mid_price * (1 - grid_range * self.tp_multiplier()) + end_price = mid_price * (1 + grid_range * self.sl_multiplier()) + grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.SELL, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if grid_action is not None: + actions.append(grid_action) + else: + # we create a buy and a sell grid with higher range pct and the base grid value pct + # to hedge the position + grid_value = theoretical * grid_value_pct + if difference < Decimal("0"): # create a bigger buy grid and sell grid + # Create buy grid + start_price = mid_price * (1 - 2 * grid_range * self.sl_multiplier()) + end_price = mid_price * (1 + grid_range * self.tp_multiplier()) + buy_grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.BUY, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if buy_grid_action is not None: + actions.append(buy_grid_action) + # Create sell grid + start_price = mid_price * (1 - grid_range * self.tp_multiplier()) + end_price = mid_price * (1 + 2 * grid_range * self.sl_multiplier()) + sell_grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.SELL, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if sell_grid_action is not None: + actions.append(sell_grid_action) + if difference > Decimal("0"): + # Create sell grid + start_price = mid_price * (1 - 2 * grid_range * self.tp_multiplier()) + end_price = mid_price * (1 + grid_range * self.sl_multiplier()) + sell_grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.SELL, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if sell_grid_action is not None: + actions.append(sell_grid_action) + # Create buy grid + start_price = mid_price * (1 - grid_range * self.sl_multiplier()) + end_price = mid_price * (1 + 2 * grid_range * self.tp_multiplier()) + buy_grid_action = self.create_grid_executor( + trading_pair=trading_pair, + side=TradeType.BUY, + start_price=start_price, + end_price=end_price, + grid_value=grid_value, + is_unfavorable=False + ) + if buy_grid_action is not None: + actions.append(buy_grid_action) + return actions + + def create_grid_executor( + self, + trading_pair: str, + side: TradeType, + start_price: Decimal, + end_price: Decimal, + grid_value: Decimal, + is_unfavorable: bool = False + ) -> CreateExecutorAction: + """Creates a grid executor with dynamic sizing and range adjustments""" + # Get trading rules and minimum notional + trading_rules = self.market_data_provider.get_trading_rules(self.config.connector_name, trading_pair) + min_notional = max( + self.config.min_order_amount, + trading_rules.min_notional_size if trading_rules else Decimal("5.0") + ) + # Add safety margin and check if grid value is sufficient + min_grid_value = min_notional * Decimal("5") # Ensure room for at least 5 levels + if grid_value < min_grid_value: + self.logger().info( + f"Grid value {grid_value} is too small for {trading_pair}. " + f"Minimum required for viable grid: {min_grid_value}" + ) + return None # Skip grid creation if value is too small + + # Select order frequency based on grid favorability + order_frequency = ( + self.config.unfavorable_order_frequency if is_unfavorable + else self.config.favorable_order_frequency + ) + # Calculate limit price to be more aggressive than grid boundaries + if side == TradeType.BUY: + # For buys, limit price should be lower than start price + limit_price = start_price * (1 - self.config.limit_price_spread) + else: + # For sells, limit price should be higher than end price + limit_price = end_price * (1 + self.config.limit_price_spread) + # Create the executor action + action = CreateExecutorAction( + controller_id=self.config.id, + executor_config=GridExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_name, + trading_pair=trading_pair, + side=side, + start_price=start_price, + end_price=end_price, + limit_price=limit_price, + leverage=self.config.leverage, + total_amount_quote=grid_value, + safe_extra_spread=self.config.safe_extra_spread, + min_spread_between_orders=self.config.min_spread_between_orders, + min_order_amount_quote=self.config.min_order_amount, + max_open_orders=self.config.max_open_orders, + order_frequency=order_frequency, # Use dynamic order frequency + max_orders_per_batch=self.config.max_orders_per_batch, + activation_bounds=self.config.activation_bounds, + keep_position=True, # Always keep position for potential reversal + coerce_tp_to_step=True, + triple_barrier_config=TripleBarrierConfig( + take_profit=self.config.grid_tp_multiplier, + open_order_type=OrderType.LIMIT_MAKER, + take_profit_order_type=OrderType.LIMIT_MAKER, + stop_loss=None, + time_limit=None, + trailing_stop=None, + ))) + # Track unfavorable grid configs + if is_unfavorable: + self.unfavorable_grid_ids.add(action.executor_config.id) + self.logger().info( + f"Created unfavorable grid for {trading_pair} - " + f"Side: {side.name}, Value: ${grid_value:,.2f}, " + f"Order Frequency: {order_frequency}s" + ) + else: + self.logger().info( + f"Created favorable grid for {trading_pair} - " + f"Side: {side.name}, Value: ${grid_value:,.2f}, " + f"Order Frequency: {order_frequency}s" + ) + + return action + + def get_mid_price(self, trading_pair: str) -> Decimal: + return self.market_data_provider.get_price_by_type(self.config.connector_name, trading_pair, PriceType.MidPrice) + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.connector_name, + trading_pair=trading_pair + "-" + self.config.quote_asset, + interval=self.config.interval, + max_records=self.config.bb_length + 100 + ) for trading_pair in self.config.portfolio_allocation.keys()] diff --git a/bots/controllers/generic/spot_perp_arbitrage.py b/bots/controllers/generic/spot_perp_arbitrage.py deleted file mode 100644 index b477e0f2..00000000 --- a/bots/controllers/generic/spot_perp_arbitrage.py +++ /dev/null @@ -1,192 +0,0 @@ -from decimal import Decimal -from typing import Dict, List, Set - -from hummingbot.client.config.config_data_types import ClientFieldData -from hummingbot.core.data_type.common import OrderType, PositionAction, PriceType, TradeType -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig -from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase -from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig -from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction -from pydantic import Field - - -class SpotPerpArbitrageConfig(ControllerConfigBase): - controller_name: str = "spot_perp_arbitrage" - candles_config: List[CandlesConfig] = [] - spot_connector: str = Field( - default="binance", - client_data=ClientFieldData( - prompt=lambda e: "Enter the spot connector: ", - prompt_on_new=True - )) - spot_trading_pair: str = Field( - default="DOGE-USDT", - client_data=ClientFieldData( - prompt=lambda e: "Enter the spot trading pair: ", - prompt_on_new=True - )) - perp_connector: str = Field( - default="binance_perpetual", - client_data=ClientFieldData( - prompt=lambda e: "Enter the perp connector: ", - prompt_on_new=True - )) - perp_trading_pair: str = Field( - default="DOGE-USDT", - client_data=ClientFieldData( - prompt=lambda e: "Enter the perp trading pair: ", - prompt_on_new=True - )) - profitability: Decimal = Field( - default=0.002, - client_data=ClientFieldData( - prompt=lambda e: "Enter the minimum profitability: ", - prompt_on_new=True - )) - position_size_quote: float = Field( - default=50, - client_data=ClientFieldData( - prompt=lambda e: "Enter the position size in quote currency: ", - prompt_on_new=True - )) - - def update_markets(self, markets: Dict[str, Set[str]]) -> Dict[str, Set[str]]: - if self.spot_connector not in markets: - markets[self.spot_connector] = set() - markets[self.spot_connector].add(self.spot_trading_pair) - if self.perp_connector not in markets: - markets[self.perp_connector] = set() - markets[self.perp_connector].add(self.perp_trading_pair) - return markets - - -class SpotPerpArbitrage(ControllerBase): - - def __init__(self, config: SpotPerpArbitrageConfig, *args, **kwargs): - self.config = config - super().__init__(config, *args, **kwargs) - - @property - def spot_connector(self): - return self.market_data_provider.connectors[self.config.spot_connector] - - @property - def perp_connector(self): - return self.market_data_provider.connectors[self.config.perp_connector] - - def get_current_profitability_after_fees(self): - """ - This methods compares the profitability of buying at market in the two exchanges. If the side is TradeType.BUY - means that the operation is long on connector 1 and short on connector 2. - """ - spot_trading_pair = self.config.spot_trading_pair - perp_trading_pair = self.config.perp_trading_pair - - connector_spot_price = Decimal(self.market_data_provider.get_price_for_quote_volume( - connector_name=self.config.spot_connector, - trading_pair=spot_trading_pair, - quote_volume=self.config.position_size_quote, - is_buy=True, - ).result_price) - connector_perp_price = Decimal(self.market_data_provider.get_price_for_quote_volume( - connector_name=self.config.spot_connector, - trading_pair=perp_trading_pair, - quote_volume=self.config.position_size_quote, - is_buy=False, - ).result_price) - estimated_fees_spot_connector = self.spot_connector.get_fee( - base_currency=spot_trading_pair.split("-")[0], - quote_currency=spot_trading_pair.split("-")[1], - order_type=OrderType.MARKET, - order_side=TradeType.BUY, - amount=self.config.position_size_quote / float(connector_spot_price), - price=connector_spot_price, - is_maker=False, - ).percent - estimated_fees_perp_connector = self.perp_connector.get_fee( - base_currency=perp_trading_pair.split("-")[0], - quote_currency=perp_trading_pair.split("-")[1], - order_type=OrderType.MARKET, - order_side=TradeType.BUY, - amount=self.config.position_size_quote / float(connector_perp_price), - price=connector_perp_price, - is_maker=False, - position_action=PositionAction.OPEN - ).percent - - estimated_trade_pnl_pct = (connector_perp_price - connector_spot_price) / connector_spot_price - return estimated_trade_pnl_pct - estimated_fees_spot_connector - estimated_fees_perp_connector - - def is_active_arbitrage(self): - executors = self.filter_executors( - executors=self.executors_info, - filter_func=lambda e: e.is_active - ) - return len(executors) > 0 - - def current_pnl_pct(self): - executors = self.filter_executors( - executors=self.executors_info, - filter_func=lambda e: e.is_active - ) - filled_amount = sum(e.filled_amount_quote for e in executors) - return sum(e.net_pnl_quote for e in executors) / filled_amount if filled_amount > 0 else 0 - - async def update_processed_data(self): - self.processed_data = { - "profitability": self.get_current_profitability_after_fees(), - "active_arbitrage": self.is_active_arbitrage(), - "current_pnl": self.current_pnl_pct() - } - - def determine_executor_actions(self) -> List[ExecutorAction]: - executor_actions = [] - executor_actions.extend(self.create_new_arbitrage_actions()) - executor_actions.extend(self.stop_arbitrage_actions()) - return executor_actions - - def create_new_arbitrage_actions(self): - create_actions = [] - if not self.processed_data["active_arbitrage"] and \ - self.processed_data["profitability"] > self.config.profitability: - mid_price = self.market_data_provider.get_price_by_type(self.config.spot_connector, - self.config.spot_trading_pair, PriceType.MidPrice) - create_actions.append(CreateExecutorAction( - controller_id=self.config.id, - executor_config=PositionExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.spot_connector, - trading_pair=self.config.spot_trading_pair, - side=TradeType.BUY, - amount=Decimal(self.config.position_size_quote) / mid_price, - triple_barrier_config=TripleBarrierConfig(open_order_type=OrderType.MARKET), - ) - )) - create_actions.append(CreateExecutorAction( - controller_id=self.config.id, - executor_config=PositionExecutorConfig( - timestamp=self.market_data_provider.time(), - connector_name=self.config.perp_connector, - trading_pair=self.config.perp_trading_pair, - side=TradeType.SELL, - amount=Decimal(self.config.position_size_quote) / mid_price, - triple_barrier_config=TripleBarrierConfig(open_order_type=OrderType.MARKET), - )) - ) - return create_actions - - def stop_arbitrage_actions(self): - stop_actions = [] - if self.processed_data["current_pnl"] > 0.003: - executors = self.filter_executors( - executors=self.executors_info, - filter_func=lambda e: e.is_active - ) - for executor in executors: - stop_actions.append(StopExecutorAction(controller_id=self.config.id, executor_id=executor.id)) - - def to_format_status(self) -> List[str]: - return [ - f"Current profitability: {self.processed_data['profitability']} | Min profitability: {self.config.profitability}", - f"Active arbitrage: {self.processed_data['active_arbitrage']}", - f"Current PnL: {self.processed_data['current_pnl']}"] diff --git a/bots/controllers/generic/stat_arb.py b/bots/controllers/generic/stat_arb.py new file mode 100644 index 00000000..fa21010c --- /dev/null +++ b/bots/controllers/generic/stat_arb.py @@ -0,0 +1,476 @@ +from decimal import Decimal +from typing import List + +import numpy as np +from sklearn.linear_model import LinearRegression + +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, PriceType, TradeType +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.strategy_v2.controllers import ControllerBase, ControllerConfigBase +from hummingbot.strategy_v2.executors.data_types import ConnectorPair, PositionSummary +from hummingbot.strategy_v2.executors.order_executor.data_types import ExecutionStrategy, OrderExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig, TripleBarrierConfig +from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction, StopExecutorAction + + +class StatArbConfig(ControllerConfigBase): + """ + Configuration for a statistical arbitrage controller that trades two cointegrated assets. + """ + controller_type: str = "generic" + controller_name: str = "stat_arb" + connector_pair_dominant: ConnectorPair = ConnectorPair(connector_name="binance_perpetual", trading_pair="SOL-USDT") + connector_pair_hedge: ConnectorPair = ConnectorPair(connector_name="binance_perpetual", trading_pair="POPCAT-USDT") + interval: str = "1m" + lookback_period: int = 300 + entry_threshold: Decimal = Decimal("2.0") + take_profit: Decimal = Decimal("0.0008") + tp_global: Decimal = Decimal("0.01") + sl_global: Decimal = Decimal("0.05") + min_amount_quote: Decimal = Decimal("10") + quoter_spread: Decimal = Decimal("0.0001") + quoter_cooldown: int = 30 + quoter_refresh: int = 10 + max_orders_placed_per_side: int = 2 + max_orders_filled_per_side: int = 2 + max_position_deviation: Decimal = Decimal("0.1") + pos_hedge_ratio: Decimal = Decimal("1.0") + leverage: int = 20 + position_mode: PositionMode = PositionMode.HEDGE + + @property + def triple_barrier_config(self) -> TripleBarrierConfig: + return TripleBarrierConfig( + take_profit=self.take_profit, + open_order_type=OrderType.LIMIT_MAKER, + take_profit_order_type=OrderType.LIMIT_MAKER, + ) + + def update_markets(self, markets: dict) -> dict: + """Update markets dictionary with both trading pairs""" + # Add dominant pair + if self.connector_pair_dominant.connector_name not in markets: + markets[self.connector_pair_dominant.connector_name] = set() + markets[self.connector_pair_dominant.connector_name].add(self.connector_pair_dominant.trading_pair) + + # Add hedge pair + if self.connector_pair_hedge.connector_name not in markets: + markets[self.connector_pair_hedge.connector_name] = set() + markets[self.connector_pair_hedge.connector_name].add(self.connector_pair_hedge.trading_pair) + + return markets + + +class StatArb(ControllerBase): + """ + Statistical arbitrage controller that trades two cointegrated assets. + """ + + def __init__(self, config: StatArbConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + self.theoretical_dominant_quote = self.config.total_amount_quote * (1 / (1 + self.config.pos_hedge_ratio)) + self.theoretical_hedge_quote = self.config.total_amount_quote * (self.config.pos_hedge_ratio / (1 + self.config.pos_hedge_ratio)) + + # Initialize processed data dictionary + self.processed_data = { + "dominant_price": None, + "hedge_price": None, + "spread": None, + "z_score": None, + "hedge_ratio": None, + "position_dominant": Decimal("0"), + "position_hedge": Decimal("0"), + "active_orders_dominant": [], + "active_orders_hedge": [], + "pair_pnl": Decimal("0"), + "signal": 0 # 0: no signal, 1: long dominant/short hedge, -1: short dominant/long hedge + } + + # Setup max records for safety + max_records = self.config.lookback_period + 20 + self.max_records = max_records + if "_perpetual" in self.config.connector_pair_dominant.connector_name: + connector = self.market_data_provider.get_connector(self.config.connector_pair_dominant.connector_name) + connector.set_position_mode(self.config.position_mode) + connector.set_leverage(self.config.connector_pair_dominant.trading_pair, self.config.leverage) + if "_perpetual" in self.config.connector_pair_hedge.connector_name: + connector = self.market_data_provider.get_connector(self.config.connector_pair_hedge.connector_name) + connector.set_position_mode(self.config.position_mode) + connector.set_leverage(self.config.connector_pair_hedge.trading_pair, self.config.leverage) + + def determine_executor_actions(self) -> List[ExecutorAction]: + """ + The execution logic for the statistical arbitrage strategy. + Market Data Conditions: Signal is generated based on the z-score of the spread between the two assets. + If signal == 1 --> long dominant/short hedge + If signal == -1 --> short dominant/long hedge + Execution Conditions: If the signal is generated add position executors to quote from the dominant and hedge markets. + We compare the current position with the theoretical position for the dominant and hedge assets. + If the current position + the active placed amount is greater than the theoretical position, can't place more orders. + If the imbalance scaled pct is greater than the threshold, we avoid placing orders in the market passed on filtered_connector_pair. + If the pnl of total position is greater than the take profit or lower than the stop loss, we close the position. + """ + actions: List[ExecutorAction] = [] + # Check global take profit and stop loss + if self.processed_data["pair_pnl_pct"] > self.config.tp_global or self.processed_data["pair_pnl_pct"] < -self.config.sl_global: + # Close all positions + for position in self.positions_held: + actions.extend(self.get_executors_to_reduce_position(position)) + return actions + # Check the signal + elif self.processed_data["signal"] != 0: + actions.extend(self.get_executors_to_quote()) + actions.extend(self.get_executors_to_reduce_position_on_opposite_signal()) + + # Get the executors to keep position after a cooldown is reached + actions.extend(self.get_executors_to_keep_position()) + actions.extend(self.get_executors_to_refresh()) + + return actions + + def get_executors_to_reduce_position_on_opposite_signal(self) -> List[ExecutorAction]: + if self.processed_data["signal"] == 1: + dominant_side, hedge_side = TradeType.SELL, TradeType.BUY + elif self.processed_data["signal"] == -1: + dominant_side, hedge_side = TradeType.BUY, TradeType.SELL + else: + return [] + # Get executors to stop + dominant_active_executors_to_stop = self.filter_executors(self.executors_info, filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.side == dominant_side) + hedge_active_executors_to_stop = self.filter_executors(self.executors_info, filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.side == hedge_side) + stop_actions = [StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=False) for executor in dominant_active_executors_to_stop + hedge_active_executors_to_stop] + + # Get order executors to reduce positions + reduce_actions: List[ExecutorAction] = [] + for position in self.positions_held: + if position.connector_name == self.config.connector_pair_dominant.connector_name and position.trading_pair == self.config.connector_pair_dominant.trading_pair and position.side == dominant_side: + reduce_actions.extend(self.get_executors_to_reduce_position(position)) + elif position.connector_name == self.config.connector_pair_hedge.connector_name and position.trading_pair == self.config.connector_pair_hedge.trading_pair and position.side == hedge_side: + reduce_actions.extend(self.get_executors_to_reduce_position(position)) + return stop_actions + reduce_actions + + def get_executors_to_keep_position(self) -> List[ExecutorAction]: + stop_actions: List[ExecutorAction] = [] + for executor in self.processed_data["executors_dominant_filled"] + self.processed_data["executors_hedge_filled"]: + if self.market_data_provider.time() - executor.timestamp >= self.config.quoter_cooldown: + # Create a new executor to keep the position + stop_actions.append(StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=True)) + return stop_actions + + def get_executors_to_refresh(self) -> List[ExecutorAction]: + refresh_actions: List[ExecutorAction] = [] + for executor in self.processed_data["executors_dominant_placed"] + self.processed_data["executors_hedge_placed"]: + if self.market_data_provider.time() - executor.timestamp >= self.config.quoter_refresh: + # Create a new executor to refresh the position + refresh_actions.append(StopExecutorAction(controller_id=self.config.id, executor_id=executor.id, keep_position=False)) + return refresh_actions + + def get_executors_to_quote(self) -> List[ExecutorAction]: + """ + Get Order Executor to quote from the dominant and hedge markets. + """ + actions: List[ExecutorAction] = [] + trade_type_dominant = TradeType.BUY if self.processed_data["signal"] == 1 else TradeType.SELL + trade_type_hedge = TradeType.SELL if self.processed_data["signal"] == 1 else TradeType.BUY + + # Analyze dominant active orders, max deviation and imbalance to create a new executor + if self.processed_data["dominant_gap"] > Decimal("0") and \ + self.processed_data["filter_connector_pair"] != self.config.connector_pair_dominant and \ + len(self.processed_data["executors_dominant_placed"]) < self.config.max_orders_placed_per_side and \ + len(self.processed_data["executors_dominant_filled"]) < self.config.max_orders_filled_per_side: + # Create Position Executor for dominant asset + if trade_type_dominant == TradeType.BUY: + price = self.processed_data["min_price_dominant"] * (1 - self.config.quoter_spread) + else: + price = self.processed_data["max_price_dominant"] * (1 + self.config.quoter_spread) + dominant_executor_config = PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_pair_dominant.connector_name, + trading_pair=self.config.connector_pair_dominant.trading_pair, + side=trade_type_dominant, + entry_price=price, + amount=self.config.min_amount_quote / self.processed_data["dominant_price"], + triple_barrier_config=self.config.triple_barrier_config, + leverage=self.config.leverage, + ) + actions.append(CreateExecutorAction(controller_id=self.config.id, executor_config=dominant_executor_config)) + + # Analyze hedge active orders, max deviation and imbalance to create a new executor + if self.processed_data["hedge_gap"] > Decimal("0") and \ + self.processed_data["filter_connector_pair"] != self.config.connector_pair_hedge and \ + len(self.processed_data["executors_hedge_placed"]) < self.config.max_orders_placed_per_side and \ + len(self.processed_data["executors_hedge_filled"]) < self.config.max_orders_filled_per_side: + # Create Position Executor for hedge asset + if trade_type_hedge == TradeType.BUY: + price = self.processed_data["min_price_hedge"] * (1 - self.config.quoter_spread) + else: + price = self.processed_data["max_price_hedge"] * (1 + self.config.quoter_spread) + hedge_executor_config = PositionExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=self.config.connector_pair_hedge.connector_name, + trading_pair=self.config.connector_pair_hedge.trading_pair, + side=trade_type_hedge, + entry_price=price, + amount=self.config.min_amount_quote / self.processed_data["hedge_price"], + triple_barrier_config=self.config.triple_barrier_config, + leverage=self.config.leverage, + ) + actions.append(CreateExecutorAction(controller_id=self.config.id, executor_config=hedge_executor_config)) + return actions + + def get_executors_to_reduce_position(self, position: PositionSummary) -> List[ExecutorAction]: + """ + Get Order Executor to reduce position. + """ + if position.amount > Decimal("0"): + # Close position + config = OrderExecutorConfig( + timestamp=self.market_data_provider.time(), + connector_name=position.connector_name, + trading_pair=position.trading_pair, + side=TradeType.BUY if position.side == TradeType.SELL else TradeType.SELL, + amount=position.amount, + position_action=PositionAction.CLOSE, + execution_strategy=ExecutionStrategy.MARKET, + leverage=self.config.leverage, + ) + return [CreateExecutorAction(controller_id=self.config.id, executor_config=config)] + return [] + + async def update_processed_data(self): + """ + Update processed data with the latest market information and statistical calculations + needed for the statistical arbitrage strategy. + """ + # Stat arb analysis + spread, z_score = self.get_spread_and_z_score() + + # Generate trading signal based on z-score + entry_threshold = float(self.config.entry_threshold) + if z_score > entry_threshold: + # Spread is too high, expect it to revert: long dominant, short hedge + signal = 1 + dominant_side, hedge_side = TradeType.BUY, TradeType.SELL + elif z_score < -entry_threshold: + # Spread is too low, expect it to revert: short dominant, long hedge + signal = -1 + dominant_side, hedge_side = TradeType.SELL, TradeType.BUY + else: + # No signal + signal = 0 + dominant_side, hedge_side = None, None + + # Current prices + dominant_price, hedge_price = self.get_pairs_prices() + + # Get current positions stats by signal + positions_dominant = next((position for position in self.positions_held if position.connector_name == self.config.connector_pair_dominant.connector_name and position.trading_pair == self.config.connector_pair_dominant.trading_pair and (position.side == dominant_side or dominant_side is None)), None) + positions_hedge = next((position for position in self.positions_held if position.connector_name == self.config.connector_pair_hedge.connector_name and position.trading_pair == self.config.connector_pair_hedge.trading_pair and (position.side == hedge_side or hedge_side is None)), None) + # Get position stats + position_dominant_quote = positions_dominant.amount_quote if positions_dominant else Decimal("0") + position_hedge_quote = positions_hedge.amount_quote if positions_hedge else Decimal("0") + position_dominant_pnl_quote = positions_dominant.global_pnl_quote if positions_dominant else Decimal("0") + position_hedge_pnl_quote = positions_hedge.global_pnl_quote if positions_hedge else Decimal("0") + pair_pnl_pct = (position_dominant_pnl_quote + position_hedge_pnl_quote) / (position_dominant_quote + position_hedge_quote) if (position_dominant_quote + position_hedge_quote) != 0 else Decimal("0") + # Get active executors + executors_dominant_placed, executors_dominant_filled = self.get_executors_dominant() + executors_hedge_placed, executors_hedge_filled = self.get_executors_hedge() + min_price_dominant = Decimal(str(min([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None + max_price_dominant = Decimal(str(max([executor.config.entry_price for executor in executors_dominant_placed]))) if executors_dominant_placed else None + min_price_hedge = Decimal(str(min([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None + max_price_hedge = Decimal(str(max([executor.config.entry_price for executor in executors_hedge_placed]))) if executors_hedge_placed else None + + active_amount_dominant = Decimal(str(sum([executor.filled_amount_quote for executor in executors_dominant_filled]))) + active_amount_hedge = Decimal(str(sum([executor.filled_amount_quote for executor in executors_hedge_filled]))) + + # Compute imbalance based on the hedge ratio + dominant_gap = self.theoretical_dominant_quote - position_dominant_quote - active_amount_dominant + hedge_gap = self.theoretical_hedge_quote - position_hedge_quote - active_amount_hedge + imbalance = position_dominant_quote - position_hedge_quote + imbalance_scaled = position_dominant_quote - position_hedge_quote * self.config.pos_hedge_ratio + imbalance_scaled_pct = imbalance_scaled / position_dominant_quote if position_dominant_quote != Decimal("0") else Decimal("0") + filter_connector_pair = None + if imbalance_scaled_pct > self.config.max_position_deviation: + # Avoid placing orders in the dominant market + filter_connector_pair = self.config.connector_pair_dominant + elif imbalance_scaled_pct < -self.config.max_position_deviation: + # Avoid placing orders in the hedge market + filter_connector_pair = self.config.connector_pair_hedge + + # Update processed data + self.processed_data.update({ + "dominant_price": Decimal(str(dominant_price)), + "hedge_price": Decimal(str(hedge_price)), + "spread": Decimal(str(spread)), + "z_score": Decimal(str(z_score)), + "dominant_gap": Decimal(str(dominant_gap)), + "hedge_gap": Decimal(str(hedge_gap)), + "position_dominant_quote": position_dominant_quote, + "position_hedge_quote": position_hedge_quote, + "active_amount_dominant": active_amount_dominant, + "active_amount_hedge": active_amount_hedge, + "signal": signal, + # Store full dataframes for reference + "imbalance": Decimal(str(imbalance)), + "imbalance_scaled_pct": Decimal(str(imbalance_scaled_pct)), + "filter_connector_pair": filter_connector_pair, + "min_price_dominant": min_price_dominant if min_price_dominant is not None else Decimal(str(dominant_price)), + "max_price_dominant": max_price_dominant if max_price_dominant is not None else Decimal(str(dominant_price)), + "min_price_hedge": min_price_hedge if min_price_hedge is not None else Decimal(str(hedge_price)), + "max_price_hedge": max_price_hedge if max_price_hedge is not None else Decimal(str(hedge_price)), + "executors_dominant_filled": executors_dominant_filled, + "executors_hedge_filled": executors_hedge_filled, + "executors_dominant_placed": executors_dominant_placed, + "executors_hedge_placed": executors_hedge_placed, + "pair_pnl_pct": pair_pnl_pct, + }) + + def get_spread_and_z_score(self): + # Fetch candle data for both assets + dominant_df = self.market_data_provider.get_candles_df( + connector_name=self.config.connector_pair_dominant.connector_name, + trading_pair=self.config.connector_pair_dominant.trading_pair, + interval=self.config.interval, + max_records=self.max_records + ) + + hedge_df = self.market_data_provider.get_candles_df( + connector_name=self.config.connector_pair_hedge.connector_name, + trading_pair=self.config.connector_pair_hedge.trading_pair, + interval=self.config.interval, + max_records=self.max_records + ) + + if dominant_df.empty or hedge_df.empty: + self.logger().warning("Not enough candle data available for statistical analysis") + return + + # Extract close prices + dominant_prices = dominant_df['close'].values + hedge_prices = hedge_df['close'].values + + # Ensure we have enough data and both series have the same length + min_length = min(len(dominant_prices), len(hedge_prices)) + if min_length < self.config.lookback_period: + self.logger().warning( + f"Not enough data points for analysis. Required: {self.config.lookback_period}, Available: {min_length}") + return + + # Use the most recent data points + dominant_prices = dominant_prices[-self.config.lookback_period:] + hedge_prices = hedge_prices[-self.config.lookback_period:] + + # Convert to numpy arrays + dominant_prices_np = np.array(dominant_prices, dtype=float) + hedge_prices_np = np.array(hedge_prices, dtype=float) + + # Calculate percentage returns + dominant_pct_change = np.diff(dominant_prices_np) / dominant_prices_np[:-1] + hedge_pct_change = np.diff(hedge_prices_np) / hedge_prices_np[:-1] + + # Convert to cumulative returns + dominant_cum_returns = np.cumprod(dominant_pct_change + 1) + hedge_cum_returns = np.cumprod(hedge_pct_change + 1) + + # Normalize to start at 1 + dominant_cum_returns = dominant_cum_returns / dominant_cum_returns[0] if len(dominant_cum_returns) > 0 else np.array([1.0]) + hedge_cum_returns = hedge_cum_returns / hedge_cum_returns[0] if len(hedge_cum_returns) > 0 else np.array([1.0]) + + # Perform linear regression + dominant_cum_returns_reshaped = dominant_cum_returns.reshape(-1, 1) + reg = LinearRegression().fit(dominant_cum_returns_reshaped, hedge_cum_returns) + alpha = reg.intercept_ + beta = reg.coef_[0] + self.processed_data.update({ + "alpha": alpha, + "beta": beta, + }) + + # Calculate spread as percentage difference from predicted value + y_pred = alpha + beta * dominant_cum_returns + spread_pct = (hedge_cum_returns - y_pred) / y_pred * 100 + + # Calculate z-score + mean_spread = np.mean(spread_pct) + std_spread = np.std(spread_pct) + if std_spread == 0: + self.logger().warning("Standard deviation of spread is zero, cannot calculate z-score") + return + + current_spread = spread_pct[-1] + current_z_score = (current_spread - mean_spread) / std_spread + + return current_spread, current_z_score + + def get_pairs_prices(self): + current_dominant_price = self.market_data_provider.get_price_by_type( + connector_name=self.config.connector_pair_dominant.connector_name, + trading_pair=self.config.connector_pair_dominant.trading_pair, price_type=PriceType.MidPrice) + + current_hedge_price = self.market_data_provider.get_price_by_type( + connector_name=self.config.connector_pair_hedge.connector_name, + trading_pair=self.config.connector_pair_hedge.trading_pair, price_type=PriceType.MidPrice) + return current_dominant_price, current_hedge_price + + def get_executors_dominant(self): + active_executors_dominant_placed = self.filter_executors( + self.executors_info, + filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.is_active and not e.is_trading and e.type == "position_executor" + ) + active_executors_dominant_filled = self.filter_executors( + self.executors_info, + filter_func=lambda e: e.connector_name == self.config.connector_pair_dominant.connector_name and e.trading_pair == self.config.connector_pair_dominant.trading_pair and e.is_active and e.is_trading and e.type == "position_executor" + ) + return active_executors_dominant_placed, active_executors_dominant_filled + + def get_executors_hedge(self): + active_executors_hedge_placed = self.filter_executors( + self.executors_info, + filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.is_active and not e.is_trading and e.type == "position_executor" + ) + active_executors_hedge_filled = self.filter_executors( + self.executors_info, + filter_func=lambda e: e.connector_name == self.config.connector_pair_hedge.connector_name and e.trading_pair == self.config.connector_pair_hedge.trading_pair and e.is_active and e.is_trading and e.type == "position_executor" + ) + return active_executors_hedge_placed, active_executors_hedge_filled + + def to_format_status(self) -> List[str]: + """ + Format the status of the controller for display. + """ + status_lines = [] + status_lines.append(f""" +Dominant Pair: {self.config.connector_pair_dominant} | Hedge Pair: {self.config.connector_pair_hedge} | +Timeframe: {self.config.interval} | Lookback Period: {self.config.lookback_period} | Entry Threshold: {self.config.entry_threshold} + +Positions targets: +Theoretical Dominant : {self.theoretical_dominant_quote} | Theoretical Hedge: {self.theoretical_hedge_quote} | Position Hedge Ratio: {self.config.pos_hedge_ratio} +Position Dominant : {self.processed_data['position_dominant_quote']:.2f} | Position Hedge: {self.processed_data['position_hedge_quote']:.2f} | Imbalance: {self.processed_data['imbalance']:.2f} | Imbalance Scaled: {self.processed_data['imbalance_scaled_pct']:.2f} % + +Current Executors: +Active Orders Dominant : {len(self.processed_data['executors_dominant_placed'])} | Active Orders Hedge : {len(self.processed_data['executors_hedge_placed'])} | +Active Orders Dominant Filled: {len(self.processed_data['executors_dominant_filled'])} | Active Orders Hedge Filled: {len(self.processed_data['executors_hedge_filled'])} + +Signal: {self.processed_data['signal']:.2f} | Z-Score: {self.processed_data['z_score']:.2f} | Spread: {self.processed_data['spread']:.2f} +Alpha : {self.processed_data['alpha']:.2f} | Beta: {self.processed_data['beta']:.2f} +Pair PnL PCT: {self.processed_data['pair_pnl_pct'] * 100:.2f} % +""") + return status_lines + + def get_candles_config(self) -> List[CandlesConfig]: + max_records = self.config.lookback_period + 20 + return [ + CandlesConfig( + connector=self.config.connector_pair_dominant.connector_name, + trading_pair=self.config.connector_pair_dominant.trading_pair, + interval=self.config.interval, + max_records=max_records + ), + CandlesConfig( + connector=self.config.connector_pair_hedge.connector_name, + trading_pair=self.config.connector_pair_hedge.trading_pair, + interval=self.config.interval, + max_records=max_records + ) + ] diff --git a/bots/controllers/generic/xemm_multiple_levels.py b/bots/controllers/generic/xemm_multiple_levels.py index 2eb70d67..4780ef8a 100644 --- a/bots/controllers/generic/xemm_multiple_levels.py +++ b/bots/controllers/generic/xemm_multiple_levels.py @@ -1,81 +1,56 @@ import time from decimal import Decimal -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set import pandas as pd -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator + from hummingbot.client.ui.interface_utils import format_df_for_printout from hummingbot.core.data_type.common import PriceType, TradeType -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient from hummingbot.strategy_v2.controllers.controller_base import ControllerBase, ControllerConfigBase from hummingbot.strategy_v2.executors.data_types import ConnectorPair from hummingbot.strategy_v2.executors.xemm_executor.data_types import XEMMExecutorConfig from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, ExecutorAction -from pydantic import Field, validator class XEMMMultipleLevelsConfig(ControllerConfigBase): controller_name: str = "xemm_multiple_levels" - candles_config: List[CandlesConfig] = [] maker_connector: str = Field( - default="kucoin", - client_data=ClientFieldData( - prompt=lambda e: "Enter the maker connector: ", - prompt_on_new=True - )) + default="mexc", + json_schema_extra={"prompt": "Enter the maker connector: ", "prompt_on_new": True}) maker_trading_pair: str = Field( - default="LBR-USDT", - client_data=ClientFieldData( - prompt=lambda e: "Enter the maker trading pair: ", - prompt_on_new=True - )) + default="PEPE-USDT", + json_schema_extra={"prompt": "Enter the maker trading pair: ", "prompt_on_new": True}) taker_connector: str = Field( - default="okx", - client_data=ClientFieldData( - prompt=lambda e: "Enter the taker connector: ", - prompt_on_new=True - )) + default="binance", + json_schema_extra={"prompt": "Enter the taker connector: ", "prompt_on_new": True}) taker_trading_pair: str = Field( - default="LBR-USDT", - client_data=ClientFieldData( - prompt=lambda e: "Enter the taker trading pair: ", - prompt_on_new=True - )) + default="PEPE-USDT", + json_schema_extra={"prompt": "Enter the taker trading pair: ", "prompt_on_new": True}) buy_levels_targets_amount: List[List[Decimal]] = Field( default="0.003,10-0.006,20-0.009,30", - client_data=ClientFieldData( - prompt=lambda e: "Enter the buy levels targets with the following structure: " - "(target_profitability1,amount1-target_profitability2,amount2): ", - prompt_on_new=True - )) + json_schema_extra={ + "prompt": "Enter the buy levels targets with the following structure: (target_profitability1,amount1-target_profitability2,amount2): ", + "prompt_on_new": True}) sell_levels_targets_amount: List[List[Decimal]] = Field( default="0.003,10-0.006,20-0.009,30", - client_data=ClientFieldData( - prompt=lambda e: "Enter the sell levels targets with the following structure: " - "(target_profitability1,amount1-target_profitability2,amount2): ", - prompt_on_new=True - )) + json_schema_extra={ + "prompt": "Enter the sell levels targets with the following structure: (target_profitability1,amount1-target_profitability2,amount2): ", + "prompt_on_new": True}) min_profitability: Decimal = Field( - default=0.002, - client_data=ClientFieldData( - prompt=lambda e: "Enter the minimum profitability: ", - prompt_on_new=True - )) + default=0.003, + json_schema_extra={"prompt": "Enter the minimum profitability: ", "prompt_on_new": True}) max_profitability: Decimal = Field( default=0.01, - client_data=ClientFieldData( - prompt=lambda e: "Enter the maximum profitability: ", - prompt_on_new=True - )) + json_schema_extra={"prompt": "Enter the maximum profitability: ", "prompt_on_new": True}) max_executors_imbalance: int = Field( default=1, - client_data=ClientFieldData( - prompt=lambda e: "Enter the maximum executors imbalance: ", - prompt_on_new=True - )) + json_schema_extra={"prompt": "Enter the maximum executors imbalance: ", "prompt_on_new": True}) - @validator("buy_levels_targets_amount", "sell_levels_targets_amount", pre=True, always=True) - def validate_levels_targets_amount(cls, v, values): + @field_validator("buy_levels_targets_amount", "sell_levels_targets_amount", mode="before") + @classmethod + def validate_levels_targets_amount(cls, v): if isinstance(v, str): v = [list(map(Decimal, x.split(","))) for x in v.split("-")] return v @@ -97,14 +72,80 @@ def __init__(self, config: XEMMMultipleLevelsConfig, *args, **kwargs): self.buy_levels_targets_amount = config.buy_levels_targets_amount self.sell_levels_targets_amount = config.sell_levels_targets_amount super().__init__(config, *args, **kwargs) + self._gas_token_cache = {} + self._initialize_gas_tokens() + self.initialize_rate_sources() + + def initialize_rate_sources(self): + rates_required = [] + for connector_pair in [ + ConnectorPair(connector_name=self.config.maker_connector, trading_pair=self.config.maker_trading_pair), + ConnectorPair(connector_name=self.config.taker_connector, trading_pair=self.config.taker_trading_pair) + ]: + base, quote = connector_pair.trading_pair.split("-") + + # Add rate source for gas token if it's an AMM connector + if connector_pair.is_amm_connector(): + gas_token = self.get_gas_token(connector_pair.connector_name) + if gas_token and gas_token != base and gas_token != quote: + rates_required.append(ConnectorPair(connector_name=self.config.maker_connector, + trading_pair=f"{base}-{gas_token}")) + + # Add rate source for trading pairs + rates_required.append(connector_pair) + + if len(rates_required) > 0: + self.market_data_provider.initialize_rate_sources(rates_required) + + def _initialize_gas_tokens(self): + """Initialize gas tokens for AMM connectors during controller initialization.""" + import asyncio + + async def fetch_gas_tokens(): + for connector_name in [self.config.maker_connector, self.config.taker_connector]: + connector_pair = ConnectorPair(connector_name=connector_name, trading_pair="") + if connector_pair.is_amm_connector(): + if connector_name not in self._gas_token_cache: + try: + gateway_client = GatewayHttpClient.get_instance() + + # Get chain and network for the connector + chain, network, error = await gateway_client.get_connector_chain_network( + connector_name + ) + + if error: + self.logger().warning(f"Failed to get chain info for {connector_name}: {error}") + continue + + # Get native currency symbol + native_currency = await gateway_client.get_native_currency_symbol(chain, network) + + if native_currency: + self._gas_token_cache[connector_name] = native_currency + self.logger().info(f"Gas token for {connector_name}: {native_currency}") + else: + self.logger().warning(f"Failed to get native currency for {connector_name}") + except Exception as e: + self.logger().error(f"Error getting gas token for {connector_name}: {e}") + + # Run the async function to fetch gas tokens + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(fetch_gas_tokens()) + else: + loop.run_until_complete(fetch_gas_tokens()) + + def get_gas_token(self, connector_name: str) -> Optional[str]: + """Get the cached gas token for a connector.""" + return self._gas_token_cache.get(connector_name) async def update_processed_data(self): pass def determine_executor_actions(self) -> List[ExecutorAction]: executor_actions = [] - mid_price = self.market_data_provider.get_price_by_type(self.config.maker_connector, - self.config.maker_trading_pair, PriceType.MidPrice) + mid_price = self.market_data_provider.get_price_by_type(self.config.maker_connector, self.config.maker_trading_pair, PriceType.MidPrice) active_buy_executors = self.filter_executors( executors=self.executors_info, filter_func=lambda e: not e.is_done and e.config.maker_side == TradeType.BUY @@ -115,18 +156,30 @@ def determine_executor_actions(self) -> List[ExecutorAction]: ) stopped_buy_executors = self.filter_executors( executors=self.executors_info, - filter_func=lambda e: e.is_done and e.config.maker_side == TradeType.BUY and e.filled_amount != 0 + filter_func=lambda e: e.is_done and e.config.maker_side == TradeType.BUY and e.filled_amount_quote != 0 ) stopped_sell_executors = self.filter_executors( executors=self.executors_info, - filter_func=lambda e: e.is_done and e.config.maker_side == TradeType.SELL and e.filled_amount != 0 + filter_func=lambda e: e.is_done and e.config.maker_side == TradeType.SELL and e.filled_amount_quote != 0 ) imbalance = len(stopped_buy_executors) - len(stopped_sell_executors) + + # Calculate total amounts for proportional allocation + total_buy_amount = sum(amount for _, amount in self.buy_levels_targets_amount) + total_sell_amount = sum(amount for _, amount in self.sell_levels_targets_amount) + + # Allocate 50% of total_amount_quote to each side + buy_side_quote = self.config.total_amount_quote * Decimal("0.5") + sell_side_quote = self.config.total_amount_quote * Decimal("0.5") + for target_profitability, amount in self.buy_levels_targets_amount: - active_buy_executors_target = [e.config.target_profitability == target_profitability for e in - active_buy_executors] + active_buy_executors_target = [e.config.target_profitability == target_profitability for e in active_buy_executors] if len(active_buy_executors_target) == 0 and imbalance < self.config.max_executors_imbalance: + # Calculate proportional amount: (level_amount / total_side_amount) * (total_quote * 0.5) + proportional_amount_quote = (amount / total_buy_amount) * buy_side_quote + min_profitability = target_profitability - self.config.min_profitability + max_profitability = target_profitability + self.config.max_profitability config = XEMMExecutorConfig( controller_id=self.config.id, timestamp=self.market_data_provider.time(), @@ -135,16 +188,19 @@ def determine_executor_actions(self) -> List[ExecutorAction]: selling_market=ConnectorPair(connector_name=self.config.taker_connector, trading_pair=self.config.taker_trading_pair), maker_side=TradeType.BUY, - order_amount=amount / mid_price, - min_profitability=self.config.min_profitability, + order_amount=proportional_amount_quote / mid_price, + min_profitability=min_profitability, target_profitability=target_profitability, - max_profitability=self.config.max_profitability + max_profitability=max_profitability ) executor_actions.append(CreateExecutorAction(executor_config=config, controller_id=self.config.id)) for target_profitability, amount in self.sell_levels_targets_amount: - active_sell_executors_target = [e.config.target_profitability == target_profitability for e in - active_sell_executors] + active_sell_executors_target = [e.config.target_profitability == target_profitability for e in active_sell_executors] if len(active_sell_executors_target) == 0 and imbalance > -self.config.max_executors_imbalance: + # Calculate proportional amount: (level_amount / total_side_amount) * (total_quote * 0.5) + proportional_amount_quote = (amount / total_sell_amount) * sell_side_quote + min_profitability = target_profitability - self.config.min_profitability + max_profitability = target_profitability + self.config.max_profitability config = XEMMExecutorConfig( controller_id=self.config.id, timestamp=time.time(), @@ -153,10 +209,10 @@ def determine_executor_actions(self) -> List[ExecutorAction]: selling_market=ConnectorPair(connector_name=self.config.maker_connector, trading_pair=self.config.maker_trading_pair), maker_side=TradeType.SELL, - order_amount=amount / mid_price, - min_profitability=self.config.min_profitability, + order_amount=proportional_amount_quote / mid_price, + min_profitability=min_profitability, target_profitability=target_profitability, - max_profitability=self.config.max_profitability + max_profitability=max_profitability ) executor_actions.append(CreateExecutorAction(executor_config=config, controller_id=self.config.id)) return executor_actions diff --git a/bots/controllers/market_making/dman_maker_v2.py b/bots/controllers/market_making/dman_maker_v2.py index b822d83b..3ead968c 100644 --- a/bots/controllers/market_making/dman_maker_v2.py +++ b/bots/controllers/market_making/dman_maker_v2.py @@ -2,16 +2,15 @@ from typing import List, Optional import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator + from hummingbot.core.data_type.common import TradeType -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.market_making_controller_base import ( MarketMakingControllerBase, MarketMakingControllerConfigBase, ) from hummingbot.strategy_v2.executors.dca_executor.data_types import DCAExecutorConfig, DCAMode from hummingbot.strategy_v2.models.executor_actions import ExecutorAction, StopExecutorAction -from pydantic import Field, validator class DManMakerV2Config(MarketMakingControllerConfigBase): @@ -19,43 +18,19 @@ class DManMakerV2Config(MarketMakingControllerConfigBase): Configuration required to run the D-Man Maker V2 strategy. """ controller_name: str = "dman_maker_v2" - candles_config: List[CandlesConfig] = [] # DCA configuration dca_spreads: List[Decimal] = Field( default="0.01,0.02,0.04,0.08", - client_data=ClientFieldData( - prompt_on_new=True, - prompt=lambda mi: "Enter a comma-separated list of spreads for each DCA level: ")) + json_schema_extra={"prompt": "Enter a comma-separated list of spreads for each DCA level: ", "prompt_on_new": True}) dca_amounts: List[Decimal] = Field( default="0.1,0.2,0.4,0.8", - client_data=ClientFieldData( - prompt_on_new=True, - prompt=lambda mi: "Enter a comma-separated list of amounts for each DCA level: ")) - time_limit: int = Field( - default=60 * 60 * 24 * 7, gt=0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the time limit for each DCA level: ", - prompt_on_new=False)) - stop_loss: Decimal = Field( - default=Decimal("0.03"), gt=0, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the stop loss (as a decimal, e.g., 0.03 for 3%): ", - prompt_on_new=True)) - top_executor_refresh_time: Optional[float] = Field( - default=None, - client_data=ClientFieldData( - is_updatable=True, - prompt_on_new=False)) - executor_activation_bounds: Optional[List[Decimal]] = Field( - default=None, - client_data=ClientFieldData( - is_updatable=True, - prompt=lambda mi: "Enter the activation bounds for the orders " - "(e.g., 0.01 activates the next order when the price is closer than 1%): ", - prompt_on_new=False)) + json_schema_extra={"prompt": "Enter a comma-separated list of amounts for each DCA level: ", "prompt_on_new": True}) + top_executor_refresh_time: Optional[float] = Field(default=None, json_schema_extra={"is_updatable": True}) + executor_activation_bounds: Optional[List[Decimal]] = Field(default=None, json_schema_extra={"is_updatable": True}) - @validator("executor_activation_bounds", pre=True, always=True) + @field_validator("executor_activation_bounds", mode="before") + @classmethod def parse_activation_bounds(cls, v): if isinstance(v, list): return [Decimal(val) for val in v] @@ -65,8 +40,9 @@ def parse_activation_bounds(cls, v): return [Decimal(val) for val in v.split(",")] return v - @validator('dca_spreads', pre=True, always=True) - def parse_spreads(cls, v): + @field_validator('dca_spreads', mode="before") + @classmethod + def parse_dca_spreads(cls, v): if v is None: return [] if isinstance(v, str): @@ -75,15 +51,16 @@ def parse_spreads(cls, v): return [float(x.strip()) for x in v.split(',')] return v - @validator('dca_amounts', pre=True, always=True) - def parse_and_validate_amounts(cls, v, values, field): + @field_validator('dca_amounts', mode="before") + @classmethod + def parse_and_validate_dca_amounts(cls, v, validation_info): if v is None or v == "": - return [1 for _ in values[values['dca_spreads']]] + return [1 for _ in validation_info.data['dca_spreads']] if isinstance(v, str): return [float(x.strip()) for x in v.split(',')] - elif isinstance(v, list) and len(v) != len(values['dca_spreads']): + elif isinstance(v, list) and len(v) != len(validation_info.data['dca_spreads']): raise ValueError( - f"The number of {field.name} must match the number of {values['dca_spreads']}.") + f"The number of dca amounts must match the number of {validation_info.data['dca_spreads']}.") return v @@ -97,17 +74,16 @@ def __init__(self, config: DManMakerV2Config, *args, **kwargs): def first_level_refresh_condition(self, executor): if self.config.top_executor_refresh_time is not None: if self.get_level_from_level_id(executor.custom_info["level_id"]) == 0: - return self.market_data_provider.time() - executor.timestamp > self.config.top_executor_refresh_time * 1000 + return self.market_data_provider.time() - executor.timestamp > self.config.top_executor_refresh_time return False def order_level_refresh_condition(self, executor): - return self.market_data_provider.time() - executor.timestamp > self.config.executor_refresh_time * 1000 + return self.market_data_provider.time() - executor.timestamp > self.config.executor_refresh_time def executors_to_refresh(self) -> List[ExecutorAction]: executors_to_refresh = self.filter_executors( executors=self.executors_info, - filter_func=lambda x: not x.is_trading and x.is_active and ( - self.order_level_refresh_condition(x) or self.first_level_refresh_condition(x))) + filter_func=lambda x: not x.is_trading and x.is_active and (self.order_level_refresh_condition(x) or self.first_level_refresh_condition(x))) return [StopExecutorAction( controller_id=self.config.id, executor_id=executor.id) for executor in executors_to_refresh] diff --git a/bots/controllers/market_making/pmm_dynamic.py b/bots/controllers/market_making/pmm_dynamic.py index baad3fda..adb062d5 100644 --- a/bots/controllers/market_making/pmm_dynamic.py +++ b/bots/controllers/market_making/pmm_dynamic.py @@ -2,82 +2,71 @@ from typing import List import pandas_ta as ta # noqa: F401 -from hummingbot.client.config.config_data_types import ClientFieldData +from pydantic import Field, field_validator +from pydantic_core.core_schema import ValidationInfo + from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.market_making_controller_base import ( MarketMakingControllerBase, MarketMakingControllerConfigBase, ) from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig -from pydantic import Field, validator class PMMDynamicControllerConfig(MarketMakingControllerConfigBase): - controller_name = "pmm_dynamic" - candles_config: List[CandlesConfig] = [] + controller_name: str = "pmm_dynamic" buy_spreads: List[float] = Field( default="1,2,4", - client_data=ClientFieldData( - is_updatable=True, - prompt_on_new=True, - prompt=lambda mi: "Enter a comma-separated list of buy spreads (e.g., '0.01, 0.02'):")) + json_schema_extra={ + "prompt": "Enter a comma-separated list of buy spreads measured in units of volatility(e.g., '1, 2'): ", + "prompt_on_new": True, "is_updatable": True} + ) sell_spreads: List[float] = Field( default="1,2,4", - client_data=ClientFieldData( - is_updatable=True, - prompt_on_new=True, - prompt=lambda mi: "Enter a comma-separated list of sell spreads (e.g., '0.01, 0.02'):")) + json_schema_extra={ + "prompt": "Enter a comma-separated list of sell spreads measured in units of volatility(e.g., '1, 2'): ", + "prompt_on_new": True, "is_updatable": True} + ) candles_connector: str = Field( default=None, - client_data=ClientFieldData( - prompt_on_new=True, - prompt=lambda mi: "Enter the connector for the candles data, leave empty to use the same " - "exchange as the connector: ", ) - ) + json_schema_extra={ + "prompt": "Enter the connector for the candles data, leave empty to use the same exchange as the connector: ", + "prompt_on_new": True}) candles_trading_pair: str = Field( default=None, - client_data=ClientFieldData( - prompt_on_new=True, - prompt=lambda mi: "Enter the trading pair for the candles data, leave empty to use the same " - "trading pair as the connector: ", ) - ) + json_schema_extra={ + "prompt": "Enter the trading pair for the candles data, leave empty to use the same trading pair as the connector: ", + "prompt_on_new": True}) interval: str = Field( default="3m", - client_data=ClientFieldData( - prompt=lambda mi: "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", - prompt_on_new=False)) - + json_schema_extra={ + "prompt": "Enter the candle interval (e.g., 1m, 5m, 1h, 1d): ", + "prompt_on_new": True}) macd_fast: int = Field( - default=12, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD fast length: ", - prompt_on_new=True)) + default=21, + json_schema_extra={"prompt": "Enter the MACD fast period: ", "prompt_on_new": True}) macd_slow: int = Field( - default=26, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD slow length: ", - prompt_on_new=True)) + default=42, + json_schema_extra={"prompt": "Enter the MACD slow period: ", "prompt_on_new": True}) macd_signal: int = Field( default=9, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the MACD signal length: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the MACD signal period: ", "prompt_on_new": True}) natr_length: int = Field( default=14, - client_data=ClientFieldData( - prompt=lambda mi: "Enter the NATR length: ", - prompt_on_new=True)) + json_schema_extra={"prompt": "Enter the NATR length: ", "prompt_on_new": True}) - @validator("candles_connector", pre=True, always=True) - def set_candles_connector(cls, v, values): + @field_validator("candles_connector", mode="before") + @classmethod + def set_candles_connector(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("connector_name") + return validation_info.data.get("connector_name") return v - @validator("candles_trading_pair", pre=True, always=True) - def set_candles_trading_pair(cls, v, values): + @field_validator("candles_trading_pair", mode="before") + @classmethod + def set_candles_trading_pair(cls, v, validation_info: ValidationInfo): if v is None or v == "": - return values.get("trading_pair") + return validation_info.data.get("trading_pair") return v @@ -89,14 +78,7 @@ class PMMDynamicController(MarketMakingControllerBase): def __init__(self, config: PMMDynamicControllerConfig, *args, **kwargs): self.config = config - self.max_records = max(config.macd_slow, config.macd_fast, config.macd_signal, config.natr_length) + 10 - if len(self.config.candles_config) == 0: - self.config.candles_config = [CandlesConfig( - connector=config.candles_connector, - trading_pair=config.candles_trading_pair, - interval=config.interval, - max_records=self.max_records - )] + self.max_records = max(config.macd_slow, config.macd_fast, config.macd_signal, config.natr_length) + 100 super().__init__(config, *args, **kwargs) async def update_processed_data(self): @@ -134,3 +116,11 @@ def get_executor_config(self, level_id: str, price: Decimal, amount: Decimal): leverage=self.config.leverage, side=trade_type, ) + + def get_candles_config(self) -> List[CandlesConfig]: + return [CandlesConfig( + connector=self.config.candles_connector, + trading_pair=self.config.candles_trading_pair, + interval=self.config.interval, + max_records=self.max_records + )] diff --git a/bots/controllers/market_making/pmm_simple.py b/bots/controllers/market_making/pmm_simple.py index 4e47e404..821755ec 100644 --- a/bots/controllers/market_making/pmm_simple.py +++ b/bots/controllers/market_making/pmm_simple.py @@ -1,20 +1,14 @@ from decimal import Decimal -from typing import List -from hummingbot.client.config.config_data_types import ClientFieldData -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig from hummingbot.strategy_v2.controllers.market_making_controller_base import ( MarketMakingControllerBase, MarketMakingControllerConfigBase, ) from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig -from pydantic import Field class PMMSimpleConfig(MarketMakingControllerConfigBase): - controller_name = "pmm_simple" - # As this controller is a simple version of the PMM, we are not using the candles feed - candles_config: List[CandlesConfig] = Field(default=[], client_data=ClientFieldData(prompt_on_new=False)) + controller_name: str = "pmm_simple" class PMMSimpleController(MarketMakingControllerBase): diff --git a/bots/credentials/master_account/.password_verification b/bots/credentials/master_account/.password_verification deleted file mode 100644 index b8c76184..00000000 --- a/bots/credentials/master_account/.password_verification +++ /dev/null @@ -1 +0,0 @@ -7b2263727970746f223a207b22636970686572223a20226165732d3132382d637472222c2022636970686572706172616d73223a207b226976223a20223864336365306436393461623131396334363135663935366464653839363063227d2c202263697068657274657874223a20223836333266323430613563306131623665353664222c20226b6466223a202270626b646632222c20226b6466706172616d73223a207b2263223a20313030303030302c2022646b6c656e223a2033322c2022707266223a2022686d61632d736861323536222c202273616c74223a20226566373330376531636464373964376132303338323534656139343433663930227d2c20226d6163223a202266393439383534613530633138363633386363353962336133363665633962353333386633613964373266636635343066313034333361353431636232306438227d2c202276657273696f6e223a20337d \ No newline at end of file diff --git a/bots/credentials/master_account/conf_client.yml b/bots/credentials/master_account/conf_client.yml index ccd8729a..cbba78d4 100644 --- a/bots/credentials/master_account/conf_client.yml +++ b/bots/credentials/master_account/conf_client.yml @@ -24,8 +24,6 @@ kill_switch_mode: {} # What to auto-fill in the prompt after each import command (start/config) autofill_import: disabled -telegram_mode: {} - # MQTT Bridge configuration. mqtt_bridge: mqtt_host: localhost @@ -44,9 +42,6 @@ mqtt_bridge: # Error log sharing send_error_logs: true -# Can store the previous strategy ran for quick retrieval. -previous_strategy: some-strategy.yml - # Advanced database options, currently supports SQLAlchemy's included dialects # Reference: https://docs.sqlalchemy.org/en/13/dialects/ # To use an instance of SQLite DB the required configuration is @@ -60,8 +55,6 @@ previous_strategy: some-strategy.yml db_mode: db_engine: sqlite -pmm_script_mode: {} - # Balance Limit Configurations # e.g. Setting USDT and BTC limits on Binance. # balance_asset_limit: @@ -111,26 +104,12 @@ manual_gas_price: 50.0 gateway: gateway_api_host: localhost gateway_api_port: '15888' - -certs_path: /Users/dardonacci/Documents/work/hummingbot/certs + gateway_use_ssl: false # Whether to enable aggregated order and trade data collection anonymized_metrics_mode: anonymized_metrics_interval_min: 15.0 -# Command Shortcuts -# Define abbreviations for often used commands -# or batch grouped commands together -command_shortcuts: -- command: spreads - help: Set bid and ask spread - arguments: - - Bid Spread - - Ask Spread - output: - - config bid_spread $1 - - config ask_spread $2 - # A source for rate oracle, currently ascend_ex, binance, coin_gecko, coin_cap, kucoin, gate_io rate_oracle_source: name: binance @@ -192,6 +171,6 @@ color: tick_size: 1.0 market_data_collection: - market_data_collection_enabled: true + market_data_collection_enabled: false market_data_collection_interval: 60 market_data_collection_depth: 20 diff --git a/bots/scripts/v2_with_controllers.py b/bots/scripts/v2_with_controllers.py index a2c1c8ab..241041d4 100644 --- a/bots/scripts/v2_with_controllers.py +++ b/bots/scripts/v2_with_controllers.py @@ -1,26 +1,22 @@ import os -import time -from typing import Dict, List, Optional, Set +from decimal import Decimal +from typing import Dict, List, Optional from hummingbot.client.hummingbot_application import HummingbotApplication from hummingbot.connector.connector_base import ConnectorBase -from hummingbot.core.clock import Clock -from hummingbot.data_feed.candles_feed.data_types import CandlesConfig -from hummingbot.remote_iface.mqtt import ETopicPublisher +from hummingbot.core.event.events import MarketOrderFailureEvent from hummingbot.strategy.strategy_v2_base import StrategyV2Base, StrategyV2ConfigBase from hummingbot.strategy_v2.models.base import RunnableStatus from hummingbot.strategy_v2.models.executor_actions import CreateExecutorAction, StopExecutorAction -from pydantic import Field -class GenericV2StrategyWithCashOutConfig(StrategyV2ConfigBase): - script_file_name: str = Field(default_factory=lambda: os.path.basename(__file__)) - candles_config: List[CandlesConfig] = [] - markets: Dict[str, Set[str]] = {} - time_to_cash_out: Optional[int] = None +class V2WithControllersConfig(StrategyV2ConfigBase): + script_file_name: str = os.path.basename(__file__) + max_global_drawdown_quote: Optional[float] = None + max_controller_drawdown_quote: Optional[float] = None -class GenericV2StrategyWithCashOut(StrategyV2Base): +class V2WithControllers(StrategyV2Base): """ This script runs a generic strategy with cash out feature. Will also check if the controllers configs have been updated and apply the new settings. @@ -31,68 +27,90 @@ class GenericV2StrategyWithCashOut(StrategyV2Base): specific controller and wait until the active executors finalize their execution. The rest of the executors will wait until the main strategy stops them. """ + performance_report_interval: int = 1 - def __init__(self, connectors: Dict[str, ConnectorBase], config: GenericV2StrategyWithCashOutConfig): + def __init__(self, connectors: Dict[str, ConnectorBase], config: V2WithControllersConfig): super().__init__(connectors, config) self.config = config - self.cashing_out = False + self.max_pnl_by_controller = {} + self.max_global_pnl = Decimal("0") + self.drawdown_exited_controllers = [] self.closed_executors_buffer: int = 30 - self.performance_report_interval: int = 1 self._last_performance_report_timestamp = 0 - hb_app = HummingbotApplication.main_application() - self.mqtt_enabled = hb_app._mqtt is not None - self._pub: Optional[ETopicPublisher] = None - if self.config.time_to_cash_out: - self.cash_out_time = self.config.time_to_cash_out + time.time() - else: - self.cash_out_time = None - - def start(self, clock: Clock, timestamp: float) -> None: - """ - Start the strategy. - :param clock: Clock to use. - :param timestamp: Current time. - """ - self._last_timestamp = timestamp - self.apply_initial_setting() - if self.mqtt_enabled: - self._pub = ETopicPublisher("performance", use_bot_prefix=True) - - def on_stop(self): - if self.mqtt_enabled: - self._pub({controller_id: {} for controller_id in self.controllers.keys()}) - self._pub = None def on_tick(self): super().on_tick() - self.control_cash_out() - self.send_performance_report() + if not self._is_stop_triggered: + self.check_manual_kill_switch() + self.control_max_drawdown() + self.send_performance_report() + + def control_max_drawdown(self): + if self.config.max_controller_drawdown_quote: + self.check_max_controller_drawdown() + if self.config.max_global_drawdown_quote: + self.check_max_global_drawdown() + + def check_max_controller_drawdown(self): + for controller_id, controller in self.controllers.items(): + if controller.status != RunnableStatus.RUNNING: + continue + controller_pnl = self.get_performance_report(controller_id).global_pnl_quote + last_max_pnl = self.max_pnl_by_controller[controller_id] + if controller_pnl > last_max_pnl: + self.max_pnl_by_controller[controller_id] = controller_pnl + else: + current_drawdown = last_max_pnl - controller_pnl + if current_drawdown > self.config.max_controller_drawdown_quote: + self.logger().info(f"Controller {controller_id} reached max drawdown. Stopping the controller.") + controller.stop() + executors_order_placed = self.filter_executors( + executors=self.get_executors_by_controller(controller_id), + filter_func=lambda x: x.is_active and not x.is_trading, + ) + self.executor_orchestrator.execute_actions( + actions=[ + StopExecutorAction(controller_id=controller_id, executor_id=executor.id) + for executor in executors_order_placed + ] + ) + self.drawdown_exited_controllers.append(controller_id) + + def check_max_global_drawdown(self): + current_global_pnl = sum([ + self.get_performance_report(controller_id).global_pnl_quote + for controller_id in self.controllers.keys() + ]) + if current_global_pnl > self.max_global_pnl: + self.max_global_pnl = current_global_pnl + else: + current_global_drawdown = self.max_global_pnl - current_global_pnl + if current_global_drawdown > self.config.max_global_drawdown_quote: + self.drawdown_exited_controllers.extend(list(self.controllers.keys())) + self.logger().info("Global drawdown reached. Stopping the strategy.") + self._is_stop_triggered = True + HummingbotApplication.main_application().stop() + + def get_controller_report(self, controller_id: str) -> dict: + """ + Get the full report for a controller including performance and custom info. + """ + performance_report = self.controller_reports.get(controller_id, {}).get("performance") + return { + "performance": performance_report.dict() if performance_report else {}, + "custom_info": self.controllers[controller_id].get_custom_info() + } def send_performance_report(self): - if self.current_timestamp - self._last_performance_report_timestamp >= self.performance_report_interval \ - and self.mqtt_enabled: - performance_reports = {controller_id: self.executor_orchestrator.generate_performance_report( - controller_id=controller_id).dict() for controller_id in self.controllers.keys()} - self._pub(performance_reports) + if self.current_timestamp - self._last_performance_report_timestamp >= self.performance_report_interval and self._pub: + controller_reports = { + controller_id: self.get_controller_report(controller_id) + for controller_id in self.controllers.keys() + } + self._pub(controller_reports) self._last_performance_report_timestamp = self.current_timestamp - def control_cash_out(self): - self.evaluate_cash_out_time() - if self.cashing_out: - self.check_executors_status() - else: - self.check_manual_cash_out() - - def evaluate_cash_out_time(self): - if self.cash_out_time and self.current_timestamp >= self.cash_out_time and not self.cashing_out: - self.logger().info("Cash out time reached. Stopping the controllers.") - for controller_id, controller in self.controllers.items(): - if controller.status == RunnableStatus.RUNNING: - self.logger().info(f"Cash out for controller {controller_id}.") - controller.stop() - self.cashing_out = True - - def check_manual_cash_out(self): + def check_manual_kill_switch(self): for controller_id, controller in self.controllers.items(): if controller.config.manual_kill_switch and controller.status == RunnableStatus.RUNNING: self.logger().info(f"Manual cash out for controller {controller_id}.") @@ -102,6 +120,8 @@ def check_manual_cash_out(self): [StopExecutorAction(executor_id=executor.id, controller_id=executor.controller_id) for executor in executors_to_stop]) if not controller.config.manual_kill_switch and controller.status == RunnableStatus.TERMINATED: + if controller_id in self.drawdown_exited_controllers: + continue self.logger().info(f"Restarting controller {controller_id}.") controller.start() @@ -131,14 +151,30 @@ def stop_actions_proposal(self) -> List[StopExecutorAction]: def apply_initial_setting(self): connectors_position_mode = {} for controller_id, controller in self.controllers.items(): - config_dict = controller.config.dict() + self.max_pnl_by_controller[controller_id] = Decimal("0") + config_dict = controller.config.model_dump() if "connector_name" in config_dict: if self.is_perpetual(config_dict["connector_name"]): if "position_mode" in config_dict: connectors_position_mode[config_dict["connector_name"]] = config_dict["position_mode"] - if "leverage" in config_dict: - self.connectors[config_dict["connector_name"]].set_leverage(leverage=config_dict["leverage"], - trading_pair=config_dict[ - "trading_pair"]) + if "leverage" in config_dict and "trading_pair" in config_dict: + self.connectors[config_dict["connector_name"]].set_leverage( + leverage=config_dict["leverage"], + trading_pair=config_dict["trading_pair"]) for connector_name, position_mode in connectors_position_mode.items(): self.connectors[connector_name].set_position_mode(position_mode) + + def did_fail_order(self, order_failed_event: MarketOrderFailureEvent): + """ + Handle order failure events by logging the error and stopping the strategy if necessary. + """ + if order_failed_event.error_message and "position side" in order_failed_event.error_message.lower(): + connectors_position_mode = {} + for controller_id, controller in self.controllers.items(): + config_dict = controller.config.model_dump() + if "connector_name" in config_dict: + if self.is_perpetual(config_dict["connector_name"]): + if "position_mode" in config_dict: + connectors_position_mode[config_dict["connector_name"]] = config_dict["position_mode"] + for connector_name, position_mode in connectors_position_mode.items(): + self.connectors[connector_name].set_position_mode(position_mode) diff --git a/config.py b/config.py index b36db450..5c2bc63b 100644 --- a/config.py +++ b/config.py @@ -1,15 +1,136 @@ -import os +from typing import List +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict -from dotenv import load_dotenv -load_dotenv() +class BrokerSettings(BaseSettings): + """MQTT Broker configuration for bot communication.""" + + host: str = Field(default="localhost", description="MQTT broker host") + port: int = Field(default=1883, description="MQTT broker port") + username: str = Field(default="admin", description="MQTT broker username") + password: str = Field(default="password", description="MQTT broker password") -CONTROLLERS_PATH = "bots/conf/controllers" -CONTROLLERS_MODULE = "bots.controllers" -CONFIG_PASSWORD = os.getenv("CONFIG_PASSWORD", "a") -BROKER_HOST = os.getenv("BROKER_HOST", "localhost") -BROKER_PORT = int(os.getenv("BROKER_PORT", 1883)) -BROKER_USERNAME = os.getenv("BROKER_USERNAME", "admin") -BROKER_PASSWORD = os.getenv("BROKER_PASSWORD", "password") -PASSWORD_VERIFICATION_PATH = "bots/credentials/master_account/.password_verification" -BANNED_TOKENS = os.getenv("BANNED_TOKENS", "NAV,ARS,ETHW").split(",") \ No newline at end of file + model_config = SettingsConfigDict(env_prefix="BROKER_", extra="ignore") + + +class DatabaseSettings(BaseSettings): + """Database configuration.""" + + url: str = Field( + default="postgresql+asyncpg://hbot:hummingbot-api@localhost:5432/hummingbot_api", + description="Database connection URL" + ) + + model_config = SettingsConfigDict(env_prefix="DATABASE_", extra="ignore") + + +class MarketDataSettings(BaseSettings): + """Market data feed manager configuration.""" + + cleanup_interval: int = Field( + default=300, + description="How often to run feed cleanup in seconds" + ) + feed_timeout: int = Field( + default=600, + description="How long to keep unused feeds alive in seconds" + ) + candles_ready_timeout: int = Field( + default=30, + description="How long to wait for a candle feed to become ready in seconds" + ) + + model_config = SettingsConfigDict(env_prefix="MARKET_DATA_", extra="ignore") + + +class SecuritySettings(BaseSettings): + """Security and authentication configuration.""" + + username: str = Field(default="admin", description="API basic auth username") + password: str = Field(default="admin", description="API basic auth password") + debug_mode: bool = Field(default=False, description="Enable debug mode (disables auth)") + config_password: str = Field(default="a", description="Bot configuration encryption password") + + model_config = SettingsConfigDict( + env_prefix="", + extra="ignore" # Ignore extra environment variables + ) + + +class AWSSettings(BaseSettings): + """AWS configuration for S3 archiving.""" + + api_key: str = Field(default="", description="AWS API key") + secret_key: str = Field(default="", description="AWS secret key") + s3_default_bucket_name: str = Field(default="", description="Default S3 bucket for archiving") + + model_config = SettingsConfigDict(env_prefix="AWS_", extra="ignore") + + +class GatewaySettings(BaseSettings): + """Gateway service configuration.""" + + url: str = Field( + default="http://localhost:15888", + description="Gateway service URL (use 'http://gateway:15888' when running in Docker)" + ) + + model_config = SettingsConfigDict(env_prefix="GATEWAY_", extra="ignore") + + +class AppSettings(BaseSettings): + """Main application settings.""" + + # Static paths + controllers_path: str = "bots/conf/controllers" + controllers_module: str = "bots.controllers" + password_verification_path: str = "credentials/master_account/.password_verification" + + # Environment-configurable settings + logfire_environment: str = Field( + default="dev", + description="Logfire environment name" + ) + + # Account state update interval + account_update_interval: int = Field( + default=5, + description="How often to update account states in minutes" + ) + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + case_sensitive=False, + extra="ignore" + ) + + +class Settings(BaseSettings): + """Combined application settings.""" + + broker: BrokerSettings = Field(default_factory=BrokerSettings) + database: DatabaseSettings = Field(default_factory=DatabaseSettings) + market_data: MarketDataSettings = Field(default_factory=MarketDataSettings) + security: SecuritySettings = Field(default_factory=SecuritySettings) + aws: AWSSettings = Field(default_factory=AWSSettings) + gateway: GatewaySettings = Field(default_factory=GatewaySettings) + app: AppSettings = Field(default_factory=AppSettings) + + # Direct banned_tokens field to handle env parsing + banned_tokens: List[str] = Field( + default=["NAV", "ARS", "ETHW", "ETHF"], + description="List of banned trading tokens" + ) + + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + env_prefix="", + extra="ignore" + ) + + +# Create global settings instance +settings = Settings() diff --git a/database/__init__.py b/database/__init__.py new file mode 100644 index 00000000..b0380ac7 --- /dev/null +++ b/database/__init__.py @@ -0,0 +1,19 @@ +from .models import ( + AccountState, TokenState, Order, Trade, PositionSnapshot, FundingPayment, BotRun, + GatewaySwap, GatewayCLMMPosition, GatewayCLMMEvent, + Base +) +from .connection import AsyncDatabaseManager +from .repositories import ( + AccountRepository, BotRunRepository, + OrderRepository, TradeRepository, FundingRepository, + GatewaySwapRepository, GatewayCLMMRepository +) + +__all__ = [ + "AccountState", "TokenState", "Order", "Trade", "PositionSnapshot", "FundingPayment", "BotRun", + "GatewaySwap", "GatewayCLMMPosition", "GatewayCLMMEvent", + "Base", "AsyncDatabaseManager", + "AccountRepository", "BotRunRepository", "OrderRepository", "TradeRepository", "FundingRepository", + "GatewaySwapRepository", "GatewayCLMMRepository" +] \ No newline at end of file diff --git a/database/connection.py b/database/connection.py new file mode 100644 index 00000000..adcaf82d --- /dev/null +++ b/database/connection.py @@ -0,0 +1,143 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from .models import Base + +logger = logging.getLogger(__name__) + + +class AsyncDatabaseManager: + def __init__(self, database_url: str): + # Convert postgresql:// to postgresql+asyncpg:// for async support + if database_url.startswith("postgresql://"): + database_url = database_url.replace("postgresql://", "postgresql+asyncpg://") + + self.engine = create_async_engine( + database_url, + # Connection pool settings for async + pool_size=5, + max_overflow=10, + pool_timeout=30, + pool_recycle=1800, # Recycle connections after 30 minutes + pool_pre_ping=True, # Test connections before using them + # Engine settings + echo=False, # Set to True for SQL query logging + echo_pool=False, # Set to True for connection pool logging + # Connection arguments for asyncpg + connect_args={ + "server_settings": {"application_name": "hummingbot-api"}, + "command_timeout": 60, + } + ) + self.async_session = async_sessionmaker( + self.engine, + class_=AsyncSession, + expire_on_commit=False + ) + + async def create_tables(self): + """Create all tables defined in the models.""" + try: + async with self.engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Run lightweight migrations for existing tables + await self._run_migrations(conn) + + # Drop Hummingbot's native tables since we use our custom orders/trades tables + await self._drop_hummingbot_tables(conn) + + logger.info("Database tables created successfully") + except Exception as e: + logger.error(f"Failed to create database tables: {e}") + raise + + async def _run_migrations(self, conn): + """Run lightweight schema migrations for existing tables.""" + migrations = [ + # Add controller_id to executors table (default "main" for existing rows) + ( + "executors", "controller_id", + "ALTER TABLE executors ADD COLUMN controller_id TEXT NOT NULL DEFAULT 'main'" + ), + ] + for table, column, sql in migrations: + try: + # Check if column already exists + result = await conn.execute( + text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name = :table AND column_name = :column" + ), + {"table": table, "column": column} + ) + if result.fetchone() is None: + await conn.execute(text(sql)) + logger.info(f"Migration: added {column} to {table}") + except Exception as e: + # Column-already-exists is expected on repeat startups + err_msg = str(e).lower() + if "already exists" in err_msg or "duplicate column" in err_msg: + logger.debug(f"Migration check for {table}.{column}: {e}") + else: + logger.warning(f"Unexpected migration error for {table}.{column}: {e}") + + async def _drop_hummingbot_tables(self, conn): + """Drop Hummingbot's native database tables since we use custom ones.""" + hummingbot_tables = [ + "hummingbot_orders", + "hummingbot_trade_fills", + "hummingbot_order_status" + ] + + for table_name in hummingbot_tables: + try: + await conn.execute(text(f"DROP TABLE IF EXISTS {table_name}")) + logger.info(f"Dropped Hummingbot table: {table_name}") + except Exception as e: + logger.debug(f"Could not drop table {table_name}: {e}") # Use debug since table might not exist + + async def close(self): + """Close all database connections.""" + await self.engine.dispose() + logger.info("Database connections closed") + + def get_session(self) -> AsyncSession: + """Get a new database session.""" + return self.async_session() + + @asynccontextmanager + async def get_session_context(self) -> AsyncGenerator[AsyncSession, None]: + """ + Get a database session with automatic error handling and cleanup. + Usage: + async with db_manager.get_session_context() as session: + # Use session here + """ + async with self.async_session() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() + + async def health_check(self) -> bool: + """ + Check if the database connection is healthy. + Returns: + bool: True if connection is healthy, False otherwise. + """ + try: + async with self.engine.connect() as conn: + await conn.execute(text("SELECT 1")) + return True + except Exception as e: + logger.error(f"Database health check failed: {e}") + return False diff --git a/database/models.py b/database/models.py new file mode 100644 index 00000000..524a6a0b --- /dev/null +++ b/database/models.py @@ -0,0 +1,418 @@ +from sqlalchemy import TIMESTAMP, Column, ForeignKey, Integer, Numeric, String, Text, func +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship + +Base = declarative_base() + + +class AccountState(Base): + __tablename__ = "account_states" + + id = Column(Integer, primary_key=True, index=True) + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + token_states = relationship("TokenState", back_populates="account_state", cascade="all, delete-orphan") + + +class TokenState(Base): + __tablename__ = "token_states" + + id = Column(Integer, primary_key=True, index=True) + account_state_id = Column(Integer, ForeignKey("account_states.id"), nullable=False) + token = Column(String, nullable=False, index=True) + units = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=False) + value = Column(Numeric(precision=30, scale=18), nullable=False) + available_units = Column(Numeric(precision=30, scale=18), nullable=False) + + account_state = relationship("AccountState", back_populates="token_states") + + +class Order(Base): + __tablename__ = "orders" + + id = Column(Integer, primary_key=True, index=True) + # Order identification + client_order_id = Column(String, nullable=False, unique=True, index=True) + exchange_order_id = Column(String, nullable=True, index=True) + + # Timestamps + created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + updated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + + # Account and connector info + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + # Order details + trading_pair = Column(String, nullable=False, index=True) + trade_type = Column(String, nullable=False) # BUY, SELL + order_type = Column(String, nullable=False) # LIMIT, MARKET, LIMIT_MAKER + amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=True) # Null for market orders + + # Order status and execution + status = Column(String, nullable=False, default="SUBMITTED", index=True) # SUBMITTED, OPEN, FILLED, CANCELLED, FAILED + filled_amount = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + average_fill_price = Column(Numeric(precision=30, scale=18), nullable=True) + + # Fee information + fee_paid = Column(Numeric(precision=30, scale=18), default=0, nullable=True) + fee_currency = Column(String, nullable=True) + + # Additional metadata + error_message = Column(Text, nullable=True) + + # Relationships for future enhancements + trades = relationship("Trade", back_populates="order", cascade="all, delete-orphan") + + +class Trade(Base): + __tablename__ = "trades" + + id = Column(Integer, primary_key=True, index=True) + order_id = Column(Integer, ForeignKey("orders.id"), nullable=False) + + # Trade identification + trade_id = Column(String, nullable=False, unique=True, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + + # Trade details + trading_pair = Column(String, nullable=False, index=True) + trade_type = Column(String, nullable=False) # BUY, SELL + amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=False) + + # Fee information + fee_paid = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + fee_currency = Column(String, nullable=True) + + # Relationship + order = relationship("Order", back_populates="trades") + + +class PositionSnapshot(Base): + __tablename__ = "position_snapshots" + + id = Column(Integer, primary_key=True, index=True) + + # Position identification + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + trading_pair = Column(String, nullable=False, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + + # Real-time exchange data (from connector.account_positions) + side = Column(String, nullable=False) # LONG, SHORT + exchange_size = Column(Numeric(precision=30, scale=18), nullable=False) # Size from exchange + entry_price = Column(Numeric(precision=30, scale=18), nullable=True) # Average entry price + mark_price = Column(Numeric(precision=30, scale=18), nullable=True) # Current mark price + + # Real-time PnL data (can't be derived from trades alone) + unrealized_pnl = Column(Numeric(precision=30, scale=18), nullable=True) # From exchange + percentage_pnl = Column(Numeric(precision=10, scale=6), nullable=True) # PnL percentage + + # Leverage and margin info + leverage = Column(Numeric(precision=10, scale=2), nullable=True) # Position leverage + initial_margin = Column(Numeric(precision=30, scale=18), nullable=True) # Initial margin + maintenance_margin = Column(Numeric(precision=30, scale=18), nullable=True) # Maintenance margin + + # Fee tracking (exchange provides cumulative data) + cumulative_funding_fees = Column(Numeric(precision=30, scale=18), nullable=False, default=0) # Funding fees + fee_currency = Column(String, nullable=True) # Fee currency (usually USDT) + + # Reconciliation fields (calculated from our trade data) + calculated_size = Column(Numeric(precision=30, scale=18), nullable=True) # Size from our trades + calculated_entry_price = Column(Numeric(precision=30, scale=18), nullable=True) # Entry from our trades + size_difference = Column(Numeric(precision=30, scale=18), nullable=True) # Difference for reconciliation + + # Additional metadata + exchange_position_id = Column(String, nullable=True, index=True) # Exchange position ID + is_reconciled = Column(String, nullable=False, default="PENDING") # RECONCILED, MISMATCH, PENDING + + +class FundingPayment(Base): + __tablename__ = "funding_payments" + + id = Column(Integer, primary_key=True, index=True) + + # Payment identification + funding_payment_id = Column(String, nullable=False, unique=True, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), nullable=False, index=True) + + # Account and connector info + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + + # Funding details + trading_pair = Column(String, nullable=False, index=True) + funding_rate = Column(Numeric(precision=20, scale=18), nullable=False) # Funding rate + funding_payment = Column(Numeric(precision=30, scale=18), nullable=False) # Payment amount + fee_currency = Column(String, nullable=False) # Payment currency (usually USDT) + + # Position association + position_size = Column(Numeric(precision=30, scale=18), nullable=True) # Position size at time of payment + position_side = Column(String, nullable=True) # LONG, SHORT + + # Additional metadata + exchange_funding_id = Column(String, nullable=True, index=True) # Exchange funding ID + + +class BotRun(Base): + __tablename__ = "bot_runs" + + id = Column(Integer, primary_key=True, index=True) + + # Bot identification + bot_name = Column(String, nullable=False, index=True) + instance_name = Column(String, nullable=False, index=True) + + # Deployment info + deployed_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + strategy_type = Column(String, nullable=False, index=True) # 'script' or 'controller' + strategy_name = Column(String, nullable=False, index=True) + config_name = Column(String, nullable=True, index=True) + + # Runtime tracking + stopped_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) + + # Status tracking + deployment_status = Column(String, nullable=False, default="DEPLOYED", index=True) # DEPLOYED, FAILED, ARCHIVED + run_status = Column(String, nullable=False, default="CREATED", index=True) # CREATED, RUNNING, STOPPED, ERROR + + # Configuration and final state + deployment_config = Column(Text, nullable=True) # JSON of full deployment config + final_status = Column(Text, nullable=True) # JSON of final bot state, performance, etc. + + # Account info + account_name = Column(String, nullable=False, index=True) + + # Metadata + image_version = Column(String, nullable=True, index=True) + error_message = Column(Text, nullable=True) + + +class GatewaySwap(Base): + __tablename__ = "gateway_swaps" + + id = Column(Integer, primary_key=True, index=True) + + # Transaction identification + transaction_hash = Column(String, nullable=False, unique=True, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + + # Network and connector info (unified format) + network = Column(String, nullable=False, index=True) # chain-network format: solana-mainnet-beta, ethereum-mainnet + connector = Column(String, nullable=False, index=True) # jupiter, 0x, etc. + wallet_address = Column(String, nullable=False, index=True) + + # Swap details + trading_pair = Column(String, nullable=False, index=True) + base_token = Column(String, nullable=False, index=True) + quote_token = Column(String, nullable=False, index=True) + side = Column(String, nullable=False) # BUY, SELL + + # Amounts + input_amount = Column(Numeric(precision=30, scale=18), nullable=False) + output_amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=False) + + # Slippage and fees + slippage_pct = Column(Numeric(precision=10, scale=6), nullable=True) + gas_fee = Column(Numeric(precision=30, scale=18), nullable=True) + gas_token = Column(String, nullable=True) # SOL, ETH, etc. + + # Status + status = Column(String, nullable=False, default="SUBMITTED", index=True) # SUBMITTED, CONFIRMED, FAILED + + # Pool information (optional) + pool_address = Column(String, nullable=True, index=True) + + # Additional metadata + quote_id = Column(String, nullable=True) # If swap was from a quote + error_message = Column(Text, nullable=True) + + +class GatewayCLMMPosition(Base): + __tablename__ = "gateway_clmm_positions" + + id = Column(Integer, primary_key=True, index=True) + + # Position identification + position_address = Column(String, nullable=False, unique=True, index=True) # CLMM position NFT address + pool_address = Column(String, nullable=False, index=True) + + # Network and connector info (unified format) + network = Column(String, nullable=False, index=True) # chain-network format: solana-mainnet-beta, ethereum-mainnet + connector = Column(String, nullable=False, index=True) # meteora, raydium, uniswap + wallet_address = Column(String, nullable=False, index=True) + + # Position pair + trading_pair = Column(String, nullable=False, index=True) + base_token = Column(String, nullable=False, index=True) + quote_token = Column(String, nullable=False, index=True) + + # Timestamps + created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + closed_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) + + # Status + status = Column(String, nullable=False, default="OPEN", index=True) # OPEN, CLOSED + + # Price range (CLMM) + lower_price = Column(Numeric(precision=30, scale=18), nullable=False) + upper_price = Column(Numeric(precision=30, scale=18), nullable=False) + lower_bin_id = Column(Integer, nullable=True) # For bin-based CLMM (Meteora) + upper_bin_id = Column(Integer, nullable=True) + + # Price tracking for PnL calculation + entry_price = Column(Numeric(precision=30, scale=18), nullable=True) # Pool price when position opened + current_price = Column(Numeric(precision=30, scale=18), nullable=True) # Latest price (becomes close price when closed) + + # Initial deposit amounts (for PnL calculation) + initial_base_token_amount = Column(Numeric(precision=30, scale=18), nullable=True) + initial_quote_token_amount = Column(Numeric(precision=30, scale=18), nullable=True) + + # Position rent (SOL locked for position NFT, returned on close) + position_rent = Column(Numeric(precision=30, scale=18), nullable=True) + + # Current liquidity amounts + base_token_amount = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + quote_token_amount = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + + # In range status + in_range = Column(String, nullable=False, default="UNKNOWN") # IN_RANGE, OUT_OF_RANGE, UNKNOWN + + # Price range percentage: (upper_price - lower_price) / lower_price + percentage = Column(Numeric(precision=10, scale=6), nullable=True) + + # Accumulated fees (CLMM) + base_fee_collected = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + quote_fee_collected = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + base_fee_pending = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + quote_fee_pending = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + + # Last update timestamp + last_updated = Column(TIMESTAMP(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + + # Relationships + events = relationship("GatewayCLMMEvent", back_populates="position", cascade="all, delete-orphan") + + +class GatewayCLMMEvent(Base): + __tablename__ = "gateway_clmm_events" + + id = Column(Integer, primary_key=True, index=True) + position_id = Column(Integer, ForeignKey("gateway_clmm_positions.id"), nullable=False) + + # Event identification + transaction_hash = Column(String, nullable=False, index=True) + + # Timestamps + timestamp = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + + # Event type + event_type = Column(String, nullable=False, index=True) # OPEN, ADD_LIQUIDITY, REMOVE_LIQUIDITY, COLLECT_FEES, CLOSE + + # Event amounts + base_token_amount = Column(Numeric(precision=30, scale=18), nullable=True) + quote_token_amount = Column(Numeric(precision=30, scale=18), nullable=True) + + # For fee collection + base_fee_collected = Column(Numeric(precision=30, scale=18), nullable=True) + quote_fee_collected = Column(Numeric(precision=30, scale=18), nullable=True) + + # Gas fee + gas_fee = Column(Numeric(precision=30, scale=18), nullable=True) + gas_token = Column(String, nullable=True) + + # Status + status = Column(String, nullable=False, default="SUBMITTED", index=True) # SUBMITTED, CONFIRMED, FAILED + error_message = Column(Text, nullable=True) + + # Relationship + position = relationship("GatewayCLMMPosition", back_populates="events") + + +class ExecutorRecord(Base): + """Database model for executor state persistence.""" + __tablename__ = "executors" + + id = Column(Integer, primary_key=True, index=True) + + # Executor identification + executor_id = Column(String, nullable=False, unique=True, index=True) + executor_type = Column(String, nullable=False, index=True) + + # Account and connector info + account_name = Column(String, nullable=False, index=True) + connector_name = Column(String, nullable=False, index=True) + trading_pair = Column(String, nullable=False, index=True) + controller_id = Column(String, nullable=False, default="main", index=True) + + # Timestamps + created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False, index=True) + closed_at = Column(TIMESTAMP(timezone=True), nullable=True, index=True) + + # Status + status = Column(String, nullable=False, default="RUNNING", index=True) + close_type = Column(String, nullable=True) + + # Performance metrics + net_pnl_quote = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + net_pnl_pct = Column(Numeric(precision=10, scale=6), nullable=False, default=0) + cum_fees_quote = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + filled_amount_quote = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + + # Configuration (JSON) + config = Column(Text, nullable=True) + + # Final state (JSON) + final_state = Column(Text, nullable=True) + + # Relationships + orders = relationship("ExecutorOrder", back_populates="executor", cascade="all, delete-orphan") + + +class ExecutorOrder(Base): + """Database model for orders created by executors.""" + __tablename__ = "executor_orders" + + id = Column(Integer, primary_key=True, index=True) + + # Executor reference + executor_id = Column(String, ForeignKey("executors.executor_id"), nullable=False, index=True) + + # Order identification + client_order_id = Column(String, nullable=False, index=True) + exchange_order_id = Column(String, nullable=True) + + # Order details + order_type = Column(String, nullable=False) # open, close, take_profit, stop_loss + trade_type = Column(String, nullable=False) # BUY, SELL + amount = Column(Numeric(precision=30, scale=18), nullable=False) + price = Column(Numeric(precision=30, scale=18), nullable=True) + + # Execution + status = Column(String, nullable=False, default="SUBMITTED") + filled_amount = Column(Numeric(precision=30, scale=18), nullable=False, default=0) + average_fill_price = Column(Numeric(precision=30, scale=18), nullable=True) + + # Timestamps + created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False) + updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now(), nullable=True) + + # Relationship + executor = relationship("ExecutorRecord", back_populates="orders") + + diff --git a/database/repositories/__init__.py b/database/repositories/__init__.py new file mode 100644 index 00000000..bc8dd107 --- /dev/null +++ b/database/repositories/__init__.py @@ -0,0 +1,19 @@ +from .account_repository import AccountRepository +from .bot_run_repository import BotRunRepository +from .executor_repository import ExecutorRepository +from .funding_repository import FundingRepository +from .gateway_clmm_repository import GatewayCLMMRepository +from .gateway_swap_repository import GatewaySwapRepository +from .order_repository import OrderRepository +from .trade_repository import TradeRepository + +__all__ = [ + "AccountRepository", + "BotRunRepository", + "ExecutorRepository", + "FundingRepository", + "OrderRepository", + "TradeRepository", + "GatewaySwapRepository", + "GatewayCLMMRepository", +] \ No newline at end of file diff --git a/database/repositories/account_repository.py b/database/repositories/account_repository.py new file mode 100644 index 00000000..a799c130 --- /dev/null +++ b/database/repositories/account_repository.py @@ -0,0 +1,438 @@ +from datetime import datetime, timedelta +from decimal import Decimal +from typing import Dict, List, Optional, Tuple +import base64 +import json + +from sqlalchemy import desc, select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from database import AccountState, TokenState + + +class AccountRepository: + def __init__(self, session: AsyncSession): + self.session = session + + @staticmethod + def _interval_to_minutes(interval: str) -> int: + """Convert interval string to minutes.""" + interval_map = { + "5m": 5, + "15m": 15, + "30m": 30, + "1h": 60, + "4h": 240, + "12h": 720, + "1d": 1440 + } + return interval_map.get(interval, 5) # Default to 5 minutes + + @staticmethod + def _sample_history_by_interval(history: List[Dict], interval_minutes: int) -> List[Dict]: + """ + Sample historical data points based on the specified interval. + + Args: + history: List of historical data points sorted by timestamp (descending) + interval_minutes: Sampling interval in minutes + + Returns: + Sampled list of data points + """ + if not history or interval_minutes <= 5: + return history # Return all data for 5m or less + + sampled = [] + last_sampled_time = None + + for item in history: + item_time = datetime.fromisoformat(item["timestamp"].replace('Z', '+00:00')) + + if last_sampled_time is None: + # Always include the first (most recent) data point + sampled.append(item) + last_sampled_time = item_time + else: + # Check if enough time has passed since last sampled point + time_diff = (last_sampled_time - item_time).total_seconds() / 60 + if time_diff >= interval_minutes: + sampled.append(item) + last_sampled_time = item_time + + return sampled + + async def save_account_state(self, account_name: str, connector_name: str, tokens_info: List[Dict], + snapshot_timestamp: Optional[datetime] = None) -> AccountState: + """ + Save account state with token information to the database. + If snapshot_timestamp is provided, use it instead of server default. + """ + account_state_data = { + "account_name": account_name, + "connector_name": connector_name + } + + # If a specific timestamp is provided, use it instead of server default + if snapshot_timestamp: + account_state_data["timestamp"] = snapshot_timestamp + + account_state = AccountState(**account_state_data) + + self.session.add(account_state) + await self.session.flush() # Get the ID + + for token_info in tokens_info: + token_state = TokenState( + account_state_id=account_state.id, + token=token_info["token"], + units=Decimal(str(token_info["units"])), + price=Decimal(str(token_info["price"])), + value=Decimal(str(token_info["value"])), + available_units=Decimal(str(token_info["available_units"])) + ) + self.session.add(token_state) + + await self.session.commit() + return account_state + + async def get_latest_account_states(self) -> Dict[str, Dict[str, List[Dict]]]: + """ + Get the latest account states for all accounts and connectors. + """ + # Get the latest timestamp for each account-connector combination + subquery = ( + select( + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name) + .subquery() + ) + + # Get the full records for the latest timestamps + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .join( + subquery, + (AccountState.account_name == subquery.c.account_name) & + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + ) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + # Convert to the expected format + accounts_state = {} + for account_state in account_states: + if account_state.account_name not in accounts_state: + accounts_state[account_state.account_name] = {} + + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + accounts_state[account_state.account_name][account_state.connector_name] = token_info + + return accounts_state + + async def get_account_state_history(self, + limit: Optional[int] = None, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + interval: str = "5m") -> Tuple[List[Dict], Optional[str], bool]: + """ + Get historical account states with cursor-based pagination and interval sampling. + + Args: + limit: Maximum number of records to return + account_name: Filter by account name + connector_name: Filter by connector name + cursor: Cursor for pagination + start_time: Start time filter + end_time: End time filter + interval: Sampling interval (5m, 15m, 30m, 1h, 4h, 12h, 1d) + + Returns: + Tuple of (data, next_cursor, has_more) + """ + interval_minutes = self._interval_to_minutes(interval) + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .order_by(desc(AccountState.timestamp)) + ) + + # Apply filters + if account_name: + query = query.filter(AccountState.account_name == account_name) + if connector_name: + query = query.filter(AccountState.connector_name == connector_name) + if start_time: + query = query.filter(AccountState.timestamp >= start_time) + if end_time: + query = query.filter(AccountState.timestamp <= end_time) + + # Handle cursor-based pagination + if cursor: + try: + cursor_time = datetime.fromisoformat(cursor.replace('Z', '+00:00')) + query = query.filter(AccountState.timestamp < cursor_time) + except (ValueError, TypeError): + # Invalid cursor, ignore it + pass + + # Fetch more records than requested to ensure we have enough after sampling + # For intervals > 5m, we need to fetch more data to get enough sampled points + sampling_multiplier = max(1, interval_minutes // 5) # How many 5m intervals per sample + fetch_limit = (limit * sampling_multiplier + 1) if limit else (100 * sampling_multiplier + 1) + query = query.limit(fetch_limit) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + # Format response - Group by minute to aggregate account/connector states + minute_groups = {} + for account_state in account_states: + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + # Round timestamp to the nearest minute for grouping + minute_timestamp = account_state.timestamp.replace(second=0, microsecond=0) + minute_key = minute_timestamp.isoformat() + + # Initialize minute group if it doesn't exist + if minute_key not in minute_groups: + minute_groups[minute_key] = { + "timestamp": minute_key, + "state": {} + } + + # Add account/connector to the minute group + if account_state.account_name not in minute_groups[minute_key]["state"]: + minute_groups[minute_key]["state"][account_state.account_name] = {} + + minute_groups[minute_key]["state"][account_state.account_name][account_state.connector_name] = token_info + + # Convert to list and maintain chronological order (most recent first) + history = list(minute_groups.values()) + history.sort(key=lambda x: x["timestamp"], reverse=True) + + # Apply interval sampling + sampled_history = self._sample_history_by_interval(history, interval_minutes) + + # Apply limit and check if there are more records after sampling + has_more = len(sampled_history) > limit if limit else False + if has_more: + sampled_history = sampled_history[:limit] + + # Generate next cursor from the last sampled item + next_cursor = None + if has_more and sampled_history: + next_cursor = sampled_history[-1]["timestamp"] + + return sampled_history, next_cursor, has_more + + async def get_account_current_state(self, account_name: str) -> Dict[str, List[Dict]]: + """ + Get the current state for a specific account. + """ + subquery = ( + select( + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .filter(AccountState.account_name == account_name) + .group_by(AccountState.connector_name) + .subquery() + ) + + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .join( + subquery, + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .filter(AccountState.account_name == account_name) + ) + + result = await self.session.execute(query) + account_states = result.unique().scalars().all() + + state = {} + for account_state in account_states: + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + state[account_state.connector_name] = token_info + + return state + + async def get_connector_current_state(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get the current state for a specific connector. + """ + query = ( + select(AccountState) + .options(joinedload(AccountState.token_states)) + .filter( + AccountState.account_name == account_name, + AccountState.connector_name == connector_name + ) + .order_by(desc(AccountState.timestamp)) + .limit(1) + ) + + result = await self.session.execute(query) + account_state = result.unique().scalar_one_or_none() + + if not account_state: + return [] + + token_info = [] + for token_state in account_state.token_states: + token_info.append({ + "token": token_state.token, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + return token_info + + async def get_all_unique_tokens(self) -> List[str]: + """ + Get all unique tokens across all accounts and connectors. + """ + query = ( + select(TokenState.token) + .distinct() + .order_by(TokenState.token) + ) + + result = await self.session.execute(query) + tokens = result.scalars().all() + + return list(tokens) + + async def get_token_current_state(self, token: str) -> List[Dict]: + """ + Get current state of a specific token across all accounts. + """ + # Get latest timestamps for each account-connector combination + subquery = ( + select( + AccountState.id, + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name, AccountState.id) + .subquery() + ) + + query = ( + select(TokenState, AccountState.account_name, AccountState.connector_name) + .join(AccountState) + .join( + subquery, + (AccountState.id == subquery.c.id) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .filter(TokenState.token == token) + ) + + result = await self.session.execute(query) + token_states = result.all() + + states = [] + for token_state, account_name, connector_name in token_states: + states.append({ + "account_name": account_name, + "connector_name": connector_name, + "units": float(token_state.units), + "price": float(token_state.price), + "value": float(token_state.value), + "available_units": float(token_state.available_units) + }) + + return states + + async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict: + """ + Get total portfolio value, optionally filtered by account. + """ + # Get latest timestamps + subquery = ( + select( + AccountState.account_name, + AccountState.connector_name, + func.max(AccountState.timestamp).label("max_timestamp") + ) + .group_by(AccountState.account_name, AccountState.connector_name) + ) + + if account_name: + subquery = subquery.filter(AccountState.account_name == account_name) + + subquery = subquery.subquery() + + # Get token values + query = ( + select( + AccountState.account_name, + func.sum(TokenState.value).label("total_value") + ) + .join(TokenState) + .join( + subquery, + (AccountState.account_name == subquery.c.account_name) & + (AccountState.connector_name == subquery.c.connector_name) & + (AccountState.timestamp == subquery.c.max_timestamp) + ) + .group_by(AccountState.account_name) + ) + + result = await self.session.execute(query) + values = result.all() + + portfolio = { + "accounts": {}, + "total_value": 0 + } + + for account, value in values: + portfolio["accounts"][account] = float(value or 0) + portfolio["total_value"] += float(value or 0) + + return portfolio \ No newline at end of file diff --git a/database/repositories/bot_run_repository.py b/database/repositories/bot_run_repository.py new file mode 100644 index 00000000..3999389e --- /dev/null +++ b/database/repositories/bot_run_repository.py @@ -0,0 +1,191 @@ +import json +from datetime import datetime, timezone +from typing import Dict, List, Optional, Any + +from sqlalchemy import desc, select, and_, or_, func +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import BotRun + + +class BotRunRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_bot_run( + self, + bot_name: str, + instance_name: str, + strategy_type: str, # 'script' or 'controller' + strategy_name: str, + account_name: str, + config_name: Optional[str] = None, + image_version: Optional[str] = None, + deployment_config: Optional[Dict[str, Any]] = None + ) -> BotRun: + """Create a new bot run record.""" + bot_run = BotRun( + bot_name=bot_name, + instance_name=instance_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + config_name=config_name, + account_name=account_name, + image_version=image_version, + deployment_config=json.dumps(deployment_config) if deployment_config else None, + deployment_status="DEPLOYED", + run_status="CREATED" + ) + + self.session.add(bot_run) + await self.session.flush() + await self.session.refresh(bot_run) + return bot_run + + + async def update_bot_run_stopped( + self, + bot_name: str, + final_status: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None + ) -> Optional[BotRun]: + """Mark a bot run as stopped and save final status.""" + stmt = select(BotRun).where( + and_( + BotRun.bot_name == bot_name, + or_(BotRun.run_status == "RUNNING", BotRun.run_status == "CREATED") + ) + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + bot_run = result.scalar_one_or_none() + + if bot_run: + bot_run.run_status = "STOPPED" if not error_message else "ERROR" + bot_run.stopped_at = datetime.utcnow() + bot_run.final_status = json.dumps(final_status) if final_status else None + bot_run.error_message = error_message + await self.session.flush() + await self.session.refresh(bot_run) + + return bot_run + + async def update_bot_run_archived(self, bot_name: str) -> Optional[BotRun]: + """Mark a bot run as archived.""" + stmt = select(BotRun).where( + BotRun.bot_name == bot_name + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + bot_run = result.scalar_one_or_none() + + if bot_run: + bot_run.deployment_status = "ARCHIVED" + bot_run.stopped_at = datetime.now(timezone.utc) + await self.session.flush() + await self.session.refresh(bot_run) + + return bot_run + + async def get_bot_runs( + self, + bot_name: Optional[str] = None, + account_name: Optional[str] = None, + strategy_type: Optional[str] = None, + strategy_name: Optional[str] = None, + run_status: Optional[str] = None, + deployment_status: Optional[str] = None, + limit: int = 100, + offset: int = 0 + ) -> List[BotRun]: + """Get bot runs with optional filters.""" + stmt = select(BotRun) + + conditions = [] + if bot_name: + conditions.append(BotRun.bot_name == bot_name) + if account_name: + conditions.append(BotRun.account_name == account_name) + if strategy_type: + conditions.append(BotRun.strategy_type == strategy_type) + if strategy_name: + conditions.append(BotRun.strategy_name == strategy_name) + if run_status: + conditions.append(BotRun.run_status == run_status) + if deployment_status: + conditions.append(BotRun.deployment_status == deployment_status) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(BotRun.deployed_at)).limit(limit).offset(offset) + + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_bot_run_by_id(self, bot_run_id: int) -> Optional[BotRun]: + """Get a specific bot run by ID.""" + stmt = select(BotRun).where(BotRun.id == bot_run_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_latest_bot_run(self, bot_name: str) -> Optional[BotRun]: + """Get the latest bot run for a specific bot.""" + stmt = select(BotRun).where( + BotRun.bot_name == bot_name + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_active_bot_runs(self) -> List[BotRun]: + """Get all currently active (running) bot runs.""" + stmt = select(BotRun).where( + and_( + BotRun.run_status == "RUNNING", + BotRun.deployment_status == "DEPLOYED" + ) + ).order_by(desc(BotRun.deployed_at)) + + result = await self.session.execute(stmt) + return result.scalars().all() + + async def get_bot_run_stats(self) -> Dict[str, Any]: + """Get statistics about bot runs.""" + # Total runs + total_stmt = select(func.count(BotRun.id)) + total_result = await self.session.execute(total_stmt) + total_runs = total_result.scalar() + + # Active runs + active_stmt = select(func.count(BotRun.id)).where( + and_( + BotRun.run_status == "RUNNING", + BotRun.deployment_status == "DEPLOYED" + ) + ) + active_result = await self.session.execute(active_stmt) + active_runs = active_result.scalar() + + # Runs by strategy type + strategy_stmt = select( + BotRun.strategy_type, + func.count(BotRun.id).label('count') + ).group_by(BotRun.strategy_type) + strategy_result = await self.session.execute(strategy_stmt) + strategy_counts = {row.strategy_type: row.count for row in strategy_result} + + # Runs by status + status_stmt = select( + BotRun.run_status, + func.count(BotRun.id).label('count') + ).group_by(BotRun.run_status) + status_result = await self.session.execute(status_stmt) + status_counts = {row.run_status: row.count for row in status_result} + + return { + "total_runs": total_runs, + "active_runs": active_runs, + "strategy_type_counts": strategy_counts, + "status_counts": status_counts + } \ No newline at end of file diff --git a/database/repositories/executor_repository.py b/database/repositories/executor_repository.py new file mode 100644 index 00000000..760ca39d --- /dev/null +++ b/database/repositories/executor_repository.py @@ -0,0 +1,457 @@ +""" +Repository for executor database operations. +""" +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Dict, List, Optional + +from sqlalchemy import and_, case, desc, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import ExecutorOrder, ExecutorRecord + + +class ExecutorRepository: + """Repository for ExecutorRecord and ExecutorOrder database operations.""" + + def __init__(self, session: AsyncSession): + self.session = session + + # ======================================== + # ExecutorRecord Operations + # ======================================== + + async def create_executor( + self, + executor_id: str, + executor_type: str, + account_name: str, + connector_name: str, + trading_pair: str, + config: Optional[str] = None, + status: str = "RUNNING", + controller_id: str = "main" + ) -> ExecutorRecord: + """Create a new executor record.""" + executor = ExecutorRecord( + executor_id=executor_id, + executor_type=executor_type, + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + controller_id=controller_id, + config=config, + status=status + ) + + self.session.add(executor) + await self.session.flush() + await self.session.refresh(executor) + return executor + + async def update_executor( + self, + executor_id: str, + status: Optional[str] = None, + close_type: Optional[str] = None, + net_pnl_quote: Optional[Decimal] = None, + net_pnl_pct: Optional[Decimal] = None, + cum_fees_quote: Optional[Decimal] = None, + filled_amount_quote: Optional[Decimal] = None, + final_state: Optional[str] = None + ) -> Optional[ExecutorRecord]: + """Update an executor record.""" + stmt = select(ExecutorRecord).where(ExecutorRecord.executor_id == executor_id) + result = await self.session.execute(stmt) + executor = result.scalar_one_or_none() + + if executor: + if status is not None: + executor.status = status + if close_type is not None: + executor.close_type = close_type + executor.closed_at = datetime.now(timezone.utc) + if net_pnl_quote is not None: + executor.net_pnl_quote = net_pnl_quote + if net_pnl_pct is not None: + executor.net_pnl_pct = net_pnl_pct + if cum_fees_quote is not None: + executor.cum_fees_quote = cum_fees_quote + if filled_amount_quote is not None: + executor.filled_amount_quote = filled_amount_quote + if final_state is not None: + executor.final_state = final_state + + await self.session.flush() + await self.session.refresh(executor) + + return executor + + async def get_executor_by_id(self, executor_id: str) -> Optional[ExecutorRecord]: + """Get an executor by ID.""" + stmt = select(ExecutorRecord).where(ExecutorRecord.executor_id == executor_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def get_executors( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + executor_type: Optional[str] = None, + status: Optional[str] = None, + controller_id: Optional[str] = None, + limit: int = 100, + offset: int = 0 + ) -> List[ExecutorRecord]: + """Get executors with optional filters.""" + stmt = select(ExecutorRecord) + + conditions = [] + if account_name: + conditions.append(ExecutorRecord.account_name == account_name) + if connector_name: + conditions.append(ExecutorRecord.connector_name == connector_name) + if trading_pair: + conditions.append(ExecutorRecord.trading_pair == trading_pair) + if executor_type: + conditions.append(ExecutorRecord.executor_type == executor_type) + if status: + conditions.append(ExecutorRecord.status == status) + if controller_id: + conditions.append(ExecutorRecord.controller_id == controller_id) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(ExecutorRecord.created_at)).limit(limit).offset(offset) + + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + async def get_active_executors( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None + ) -> List[ExecutorRecord]: + """Get all active (running) executors.""" + stmt = select(ExecutorRecord).where(ExecutorRecord.status == "RUNNING") + + if account_name: + stmt = stmt.where(ExecutorRecord.account_name == account_name) + if connector_name: + stmt = stmt.where(ExecutorRecord.connector_name == connector_name) + + stmt = stmt.order_by(desc(ExecutorRecord.created_at)) + + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + async def get_position_hold_executors( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + controller_id: Optional[str] = None + ) -> List[ExecutorRecord]: + """Get executors that closed with POSITION_HOLD (keep_position=True).""" + stmt = select(ExecutorRecord).where(ExecutorRecord.close_type == "POSITION_HOLD") + + conditions = [] + if account_name: + conditions.append(ExecutorRecord.account_name == account_name) + if connector_name: + conditions.append(ExecutorRecord.connector_name == connector_name) + if trading_pair: + conditions.append(ExecutorRecord.trading_pair == trading_pair) + if controller_id: + conditions.append(ExecutorRecord.controller_id == controller_id) + + if conditions: + stmt = stmt.where(and_(*conditions)) + + stmt = stmt.order_by(desc(ExecutorRecord.created_at)) + + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + async def get_executor_stats(self) -> Dict[str, Any]: + """Get statistics about executors.""" + # Total executors + total_stmt = select(func.count(ExecutorRecord.id)) + total_result = await self.session.execute(total_stmt) + total_executors = total_result.scalar() or 0 + + # Active executors + active_stmt = select(func.count(ExecutorRecord.id)).where( + ExecutorRecord.status == "RUNNING" + ) + active_result = await self.session.execute(active_stmt) + active_executors = active_result.scalar() or 0 + + # Total PnL + pnl_stmt = select(func.sum(ExecutorRecord.net_pnl_quote)) + pnl_result = await self.session.execute(pnl_stmt) + total_pnl = pnl_result.scalar() or Decimal("0") + + # Total volume + volume_stmt = select(func.sum(ExecutorRecord.filled_amount_quote)) + volume_result = await self.session.execute(volume_stmt) + total_volume = volume_result.scalar() or Decimal("0") + + # Executors by type + type_stmt = select( + ExecutorRecord.executor_type, + func.count(ExecutorRecord.id).label('count') + ).group_by(ExecutorRecord.executor_type) + type_result = await self.session.execute(type_stmt) + type_counts = {row.executor_type: row.count for row in type_result} + + # Executors by status + status_stmt = select( + ExecutorRecord.status, + func.count(ExecutorRecord.id).label('count') + ).group_by(ExecutorRecord.status) + status_result = await self.session.execute(status_stmt) + status_counts = {row.status: row.count for row in status_result} + + # Executors by connector + connector_stmt = select( + ExecutorRecord.connector_name, + func.count(ExecutorRecord.id).label('count') + ).group_by(ExecutorRecord.connector_name) + connector_result = await self.session.execute(connector_stmt) + connector_counts = {row.connector_name: row.count for row in connector_result} + + return { + "total_executors": total_executors, + "active_executors": active_executors, + "total_pnl_quote": float(total_pnl), + "total_volume_quote": float(total_volume), + "type_counts": type_counts, + "status_counts": status_counts, + "connector_counts": connector_counts + } + + async def get_performance_report( + self, + controller_id: Optional[str] = None + ) -> Dict[str, Any]: + """Get a performance report, optionally filtered by controller_id. + + Returns aggregate metrics: total executors, PnL, fees, volume, + win rate, per-executor PnL list (for Sharpe), and breakdown by type. + """ + base_filter = [] + if controller_id: + base_filter.append(ExecutorRecord.controller_id == controller_id) + + # --- Status counts --- + status_stmt = select( + ExecutorRecord.status, + func.count(ExecutorRecord.id).label("cnt"), + ).group_by(ExecutorRecord.status) + if base_filter: + status_stmt = status_stmt.where(and_(*base_filter)) + status_rows = await self.session.execute(status_stmt) + status_counts = {r.status: r.cnt for r in status_rows} + + total_executors = sum(status_counts.values()) + + # --- Aggregate PnL / fees / volume (completed only, excluding POSITION_HOLD to avoid double-counting) --- + completed_filter = base_filter + [ + ExecutorRecord.status != "RUNNING", + ExecutorRecord.close_type != "POSITION_HOLD", + ] + agg_stmt = select( + func.coalesce(func.sum(ExecutorRecord.net_pnl_quote), Decimal(0)).label("pnl"), + func.coalesce(func.sum(ExecutorRecord.cum_fees_quote), Decimal(0)).label("fees"), + func.coalesce(func.sum(ExecutorRecord.filled_amount_quote), Decimal(0)).label("vol"), + func.coalesce(func.avg(ExecutorRecord.net_pnl_pct), Decimal(0)).label("pnl_pct_avg"), + func.count(ExecutorRecord.id).label("completed_count"), + func.sum(case( + (ExecutorRecord.net_pnl_quote > 0, 1), + else_=0, + )).label("wins"), + ).where(and_(*completed_filter)) + agg_row = (await self.session.execute(agg_stmt)).one() + + completed_count = agg_row.completed_count or 0 + wins = agg_row.wins or 0 + win_rate = (wins / completed_count) if completed_count > 0 else 0.0 + + # --- Per-executor PnL list for Sharpe (excluding POSITION_HOLD) --- + pnl_list_stmt = select(ExecutorRecord.net_pnl_quote).where( + and_(*completed_filter) + ) + pnl_rows = await self.session.execute(pnl_list_stmt) + pnl_values = [float(r[0] or 0) for r in pnl_rows] + + # --- Breakdown by executor type (also excluding POSITION_HOLD to match aggregate totals) --- + type_stmt = select( + ExecutorRecord.executor_type, + func.count(ExecutorRecord.id).label("total"), + func.sum(case( + (ExecutorRecord.status != "RUNNING", 1), + else_=0, + )).label("completed"), + func.sum(case( + (ExecutorRecord.status == "RUNNING", 1), + else_=0, + )).label("running"), + func.coalesce(func.sum(ExecutorRecord.net_pnl_quote), Decimal(0)).label("pnl"), + func.coalesce(func.sum(ExecutorRecord.filled_amount_quote), Decimal(0)).label("vol"), + func.coalesce(func.sum(ExecutorRecord.cum_fees_quote), Decimal(0)).label("fees"), + ).where( + and_(*completed_filter) + ).group_by(ExecutorRecord.executor_type) + type_rows = await self.session.execute(type_stmt) + by_type = [ + { + "executor_type": r.executor_type, + "total": r.total, + "completed": r.completed or 0, + "running": r.running or 0, + "pnl_quote": float(r.pnl), + "volume_quote": float(r.vol), + "fees_quote": float(r.fees), + } + for r in type_rows + ] + + return { + "total_executors": total_executors, + "status_counts": status_counts, + "pnl_total_quote": float(agg_row.pnl), + "pnl_pct_avg": float(agg_row.pnl_pct_avg), + "fees_total_quote": float(agg_row.fees), + "volume_total_quote": float(agg_row.vol), + "win_rate": win_rate, + "pnl_values": pnl_values, + "by_type": by_type, + } + + # ======================================== + # ExecutorOrder Operations + # ======================================== + + async def create_executor_order( + self, + executor_id: str, + client_order_id: str, + order_type: str, + trade_type: str, + amount: Decimal, + price: Optional[Decimal] = None, + exchange_order_id: Optional[str] = None, + status: str = "SUBMITTED" + ) -> ExecutorOrder: + """Create a new executor order record.""" + order = ExecutorOrder( + executor_id=executor_id, + client_order_id=client_order_id, + order_type=order_type, + trade_type=trade_type, + amount=amount, + price=price, + exchange_order_id=exchange_order_id, + status=status + ) + + self.session.add(order) + await self.session.flush() + await self.session.refresh(order) + return order + + async def update_executor_order( + self, + client_order_id: str, + status: Optional[str] = None, + filled_amount: Optional[Decimal] = None, + average_fill_price: Optional[Decimal] = None, + exchange_order_id: Optional[str] = None + ) -> Optional[ExecutorOrder]: + """Update an executor order record.""" + stmt = select(ExecutorOrder).where(ExecutorOrder.client_order_id == client_order_id) + result = await self.session.execute(stmt) + order = result.scalar_one_or_none() + + if order: + if status is not None: + order.status = status + if filled_amount is not None: + order.filled_amount = filled_amount + if average_fill_price is not None: + order.average_fill_price = average_fill_price + if exchange_order_id is not None: + order.exchange_order_id = exchange_order_id + + await self.session.flush() + await self.session.refresh(order) + + return order + + async def get_executor_orders( + self, + executor_id: str, + status: Optional[str] = None + ) -> List[ExecutorOrder]: + """Get orders for an executor.""" + stmt = select(ExecutorOrder).where(ExecutorOrder.executor_id == executor_id) + + if status: + stmt = stmt.where(ExecutorOrder.status == status) + + stmt = stmt.order_by(desc(ExecutorOrder.created_at)) + + result = await self.session.execute(stmt) + return list(result.scalars().all()) + + async def get_order_by_client_id(self, client_order_id: str) -> Optional[ExecutorOrder]: + """Get an order by client order ID.""" + stmt = select(ExecutorOrder).where(ExecutorOrder.client_order_id == client_order_id) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + + async def cleanup_orphaned_executors( + self, + active_executor_ids: List[str], + close_type: str = "SYSTEM_CLEANUP" + ) -> int: + """ + Clean up orphaned executors - those marked as RUNNING but not in active memory. + Args: + active_executor_ids: List of executor IDs currently active in memory + close_type: Close type to set for cleaned up executors + Returns: + Number of executors cleaned up + """ + from sqlalchemy import update + + # Find executors that are RUNNING but not in the active list + conditions = [ExecutorRecord.status == "RUNNING"] + + if active_executor_ids: + conditions.append(~ExecutorRecord.executor_id.in_(active_executor_ids)) + + # First, get the count of orphaned executors for logging + count_stmt = select(func.count(ExecutorRecord.id)).where(and_(*conditions)) + count_result = await self.session.execute(count_stmt) + orphaned_count = count_result.scalar() or 0 + + if orphaned_count > 0: + # Update orphaned executors to TERMINATED status + update_stmt = ( + update(ExecutorRecord) + .where(and_(*conditions)) + .values( + status="TERMINATED", + close_type=close_type, + closed_at=datetime.now(timezone.utc) + ) + ) + + await self.session.execute(update_stmt) + await self.session.flush() + + return orphaned_count diff --git a/database/repositories/funding_repository.py b/database/repositories/funding_repository.py new file mode 100644 index 00000000..e9b8dd42 --- /dev/null +++ b/database/repositories/funding_repository.py @@ -0,0 +1,84 @@ +from datetime import datetime +from typing import Dict, List, Optional +from decimal import Decimal + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import FundingPayment + + +class FundingRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_funding_payment(self, funding_data: Dict) -> FundingPayment: + """Create a new funding payment record.""" + funding = FundingPayment(**funding_data) + self.session.add(funding) + await self.session.flush() # Get the ID + return funding + + async def get_funding_payments(self, account_name: str, connector_name: str = None, + trading_pair: str = None, limit: int = 100) -> List[FundingPayment]: + """Get funding payments with optional filters.""" + query = select(FundingPayment).where(FundingPayment.account_name == account_name) + + if connector_name: + query = query.where(FundingPayment.connector_name == connector_name) + if trading_pair: + query = query.where(FundingPayment.trading_pair == trading_pair) + + query = query.order_by(FundingPayment.timestamp.desc()).limit(limit) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_total_funding_fees(self, account_name: str, connector_name: str, + trading_pair: str) -> Dict: + """Get total funding fees for a specific trading pair.""" + query = select(FundingPayment).where( + FundingPayment.account_name == account_name, + FundingPayment.connector_name == connector_name, + FundingPayment.trading_pair == trading_pair + ) + + result = await self.session.execute(query) + payments = result.scalars().all() + + total_funding = Decimal('0') + payment_count = 0 + + for payment in payments: + total_funding += Decimal(str(payment.funding_payment)) + payment_count += 1 + + return { + "total_funding_fees": float(total_funding), + "payment_count": payment_count, + "fee_currency": payments[0].fee_currency if payments else None + } + + async def funding_payment_exists(self, funding_payment_id: str) -> bool: + """Check if a funding payment already exists.""" + result = await self.session.execute( + select(FundingPayment).where(FundingPayment.funding_payment_id == funding_payment_id) + ) + return result.scalar_one_or_none() is not None + + def to_dict(self, funding: FundingPayment) -> Dict: + """Convert FundingPayment model to dictionary format.""" + return { + "id": funding.id, + "funding_payment_id": funding.funding_payment_id, + "timestamp": funding.timestamp.isoformat(), + "account_name": funding.account_name, + "connector_name": funding.connector_name, + "trading_pair": funding.trading_pair, + "funding_rate": float(funding.funding_rate), + "funding_payment": float(funding.funding_payment), + "fee_currency": funding.fee_currency, + "position_size": float(funding.position_size) if funding.position_size else None, + "position_side": funding.position_side, + "exchange_funding_id": funding.exchange_funding_id, + } \ No newline at end of file diff --git a/database/repositories/gateway_clmm_repository.py b/database/repositories/gateway_clmm_repository.py new file mode 100644 index 00000000..af11b0df --- /dev/null +++ b/database/repositories/gateway_clmm_repository.py @@ -0,0 +1,435 @@ +from datetime import datetime, timezone +from typing import Dict, List, Optional, Set, Tuple +from decimal import Decimal + +from sqlalchemy import desc, select, distinct +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import GatewayCLMMPosition, GatewayCLMMEvent + + +class GatewayCLMMRepository: + def __init__(self, session: AsyncSession): + self.session = session + + # ============================================ + # Position Management + # ============================================ + + async def create_position(self, position_data: Dict) -> GatewayCLMMPosition: + """Create a new CLMM position record.""" + position = GatewayCLMMPosition(**position_data) + self.session.add(position) + await self.session.flush() + return position + + async def get_position_by_address(self, position_address: str) -> Optional[GatewayCLMMPosition]: + """Get a position by its address.""" + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.position_address == position_address) + ) + return result.scalar_one_or_none() + + async def update_position_liquidity( + self, + position_address: str, + base_token_amount: Decimal, + quote_token_amount: Decimal, + in_range: Optional[str] = None, + current_price: Optional[Decimal] = None + ) -> Optional[GatewayCLMMPosition]: + """Update position liquidity amounts and current price.""" + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.position_address == position_address) + ) + position = result.scalar_one_or_none() + if position: + position.base_token_amount = float(base_token_amount) + position.quote_token_amount = float(quote_token_amount) + if in_range is not None: + position.in_range = in_range + if current_price is not None: + position.current_price = float(current_price) + await self.session.flush() + return position + + async def update_position_fees( + self, + position_address: str, + base_fee_pending: Optional[Decimal] = None, + quote_fee_pending: Optional[Decimal] = None, + base_fee_collected: Optional[Decimal] = None, + quote_fee_collected: Optional[Decimal] = None + ) -> Optional[GatewayCLMMPosition]: + """Update position fee amounts.""" + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.position_address == position_address) + ) + position = result.scalar_one_or_none() + if position: + if base_fee_pending is not None: + position.base_fee_pending = float(base_fee_pending) + if quote_fee_pending is not None: + position.quote_fee_pending = float(quote_fee_pending) + if base_fee_collected is not None: + position.base_fee_collected = float(base_fee_collected) + if quote_fee_collected is not None: + position.quote_fee_collected = float(quote_fee_collected) + await self.session.flush() + return position + + async def close_position(self, position_address: str) -> Optional[GatewayCLMMPosition]: + """Mark position as closed.""" + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.position_address == position_address) + ) + position = result.scalar_one_or_none() + if position: + position.status = "CLOSED" + position.closed_at = datetime.utcnow() + await self.session.flush() + return position + + async def reopen_position(self, position_address: str) -> Optional[GatewayCLMMPosition]: + """ + Reopen a position that was incorrectly marked as closed. + + This is used when autodiscover finds a position that exists on-chain + but was marked as CLOSED in the database (e.g., due to a failed close transaction). + """ + result = await self.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.position_address == position_address) + ) + position = result.scalar_one_or_none() + if position and position.status == "CLOSED": + position.status = "OPEN" + position.closed_at = None + await self.session.flush() + return position + + async def get_positions( + self, + network: Optional[str] = None, + connector: Optional[str] = None, + wallet_address: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + position_addresses: Optional[List[str]] = None, + limit: int = 100, + offset: int = 0 + ) -> List[GatewayCLMMPosition]: + """Get positions with filtering and pagination.""" + query = select(GatewayCLMMPosition) + + # Apply filters + if network: + query = query.where(GatewayCLMMPosition.network == network) + if connector: + query = query.where(GatewayCLMMPosition.connector == connector) + if wallet_address: + query = query.where(GatewayCLMMPosition.wallet_address == wallet_address) + if trading_pair: + query = query.where(GatewayCLMMPosition.trading_pair == trading_pair) + if status: + query = query.where(GatewayCLMMPosition.status == status) + if position_addresses: + query = query.where(GatewayCLMMPosition.position_address.in_(position_addresses)) + + # Apply ordering and pagination + query = query.order_by(GatewayCLMMPosition.created_at.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_open_positions( + self, + network: Optional[str] = None, + wallet_address: Optional[str] = None + ) -> List[GatewayCLMMPosition]: + """Get all open positions.""" + return await self.get_positions( + network=network, + wallet_address=wallet_address, + status="OPEN", + limit=1000 + ) + + async def get_unique_wallet_configs(self) -> List[Dict]: + """ + Get unique combinations of connector/network/wallet from all positions. + + Returns: + List of dicts with keys: connector, network, wallet_address + This is useful for discovering which wallets to poll for positions. + """ + query = select( + distinct(GatewayCLMMPosition.connector), + GatewayCLMMPosition.network, + GatewayCLMMPosition.wallet_address + ).distinct() + + result = await self.session.execute(query) + rows = result.all() + + return [ + { + "connector": row[0], + "network": row[1], + "wallet_address": row[2] + } + for row in rows + ] + + async def get_position_addresses_set(self, status: Optional[str] = None) -> Set[str]: + """ + Get a set of position addresses in the database. + + Args: + status: Optional filter by status ("OPEN" or "CLOSED"). + If None, returns all positions. + + Returns: + Set of position addresses (useful for quick existence checks) + """ + query = select(GatewayCLMMPosition.position_address) + if status: + query = query.where(GatewayCLMMPosition.status == status) + result = await self.session.execute(query) + return {row[0] for row in result.all()} + + # ============================================ + # Event Management + # ============================================ + + async def create_event(self, event_data: Dict) -> GatewayCLMMEvent: + """Create a new CLMM event record.""" + event = GatewayCLMMEvent(**event_data) + self.session.add(event) + await self.session.flush() + return event + + async def get_event_by_tx_hash( + self, + transaction_hash: str, + event_type: Optional[str] = None + ) -> Optional[GatewayCLMMEvent]: + """Get an event by transaction hash.""" + query = select(GatewayCLMMEvent).where(GatewayCLMMEvent.transaction_hash == transaction_hash) + if event_type: + query = query.where(GatewayCLMMEvent.event_type == event_type) + + result = await self.session.execute(query) + return result.scalar_one_or_none() + + async def update_event_status( + self, + transaction_hash: str, + status: str, + error_message: Optional[str] = None, + gas_fee: Optional[Decimal] = None, + gas_token: Optional[str] = None + ) -> Optional[GatewayCLMMEvent]: + """Update event status after transaction confirmation.""" + result = await self.session.execute( + select(GatewayCLMMEvent).where(GatewayCLMMEvent.transaction_hash == transaction_hash) + ) + event = result.scalar_one_or_none() + if event: + event.status = status + if error_message: + event.error_message = error_message + if gas_fee is not None: + event.gas_fee = float(gas_fee) + if gas_token: + event.gas_token = gas_token + await self.session.flush() + return event + + async def get_position_events( + self, + position_address: str, + event_type: Optional[str] = None, + limit: int = 100 + ) -> List[GatewayCLMMEvent]: + """Get all events for a position.""" + # First get the position + position = await self.get_position_by_address(position_address) + if not position: + return [] + + # Then get its events + query = select(GatewayCLMMEvent).where(GatewayCLMMEvent.position_id == position.id) + + if event_type: + query = query.where(GatewayCLMMEvent.event_type == event_type) + + query = query.order_by(GatewayCLMMEvent.timestamp.desc()).limit(limit) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_pending_events(self, limit: int = 100) -> List[GatewayCLMMEvent]: + """Get events that are still pending confirmation.""" + query = select(GatewayCLMMEvent).where( + GatewayCLMMEvent.status == "SUBMITTED" + ).order_by(GatewayCLMMEvent.timestamp.desc()).limit(limit) + + result = await self.session.execute(query) + return result.scalars().all() + + # ============================================ + # Utilities + # ============================================ + + def position_to_dict(self, position: GatewayCLMMPosition) -> Dict: + """Convert GatewayCLMMPosition model to dictionary format with enhanced PnL calculation.""" + pnl_summary = None + + # Get prices for PnL calculation + entry_price = float(position.entry_price) if position.entry_price else None + current_price = float(position.current_price) if position.current_price else None + + # Calculate PnL if we have initial amounts and prices + if (position.initial_base_token_amount is not None and + position.initial_quote_token_amount is not None and + entry_price and entry_price > 0 and + current_price and current_price > 0): + + # Initial amounts + initial_base = float(position.initial_base_token_amount) + initial_quote = float(position.initial_quote_token_amount) + + # Current liquidity amounts + current_base = float(position.base_token_amount) + current_quote = float(position.quote_token_amount) + + # Total fees (collected + pending) + total_fees_base = float(position.base_fee_collected) + float(position.base_fee_pending) + total_fees_quote = float(position.quote_fee_collected) + float(position.quote_fee_pending) + + # Value calculations (all normalized to quote currency) + initial_value_quote = initial_base * entry_price + initial_quote + current_lp_value_quote = current_base * current_price + current_quote + total_fees_value_quote = total_fees_base * current_price + total_fees_quote + current_total_value_quote = current_lp_value_quote + total_fees_value_quote + + # HODL comparison: what if user just held initial tokens without LP + hodl_value_quote = initial_base * current_price + initial_quote + + # Impermanent loss (negative = loss due to LP vs holding) + impermanent_loss_quote = current_lp_value_quote - hodl_value_quote + + # Total P&L + total_pnl_quote = current_total_value_quote - initial_value_quote + total_pnl_pct = (total_pnl_quote / initial_value_quote * 100) if initial_value_quote > 0 else 0 + + # Price change + price_change_pct = ((current_price - entry_price) / entry_price * 100) if entry_price > 0 else 0 + + # Duration and APR estimate + duration_hours = 0 + fee_apr_estimate = None + if position.created_at: + # Use closed_at if closed, otherwise current time + end_time = position.closed_at if position.closed_at else datetime.now(timezone.utc) + # Handle timezone-naive datetimes + if position.created_at.tzinfo is None: + created_at = position.created_at.replace(tzinfo=timezone.utc) + else: + created_at = position.created_at + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + duration_seconds = (end_time - created_at).total_seconds() + duration_hours = duration_seconds / 3600 + + # Calculate fee APR if we have meaningful duration + if duration_seconds > 0 and initial_value_quote > 0: + duration_years = duration_seconds / (365.25 * 24 * 3600) + if duration_years > 0: + fee_apr_estimate = (total_fees_value_quote / initial_value_quote / duration_years * 100) + + pnl_summary = { + # Prices + "entry_price": round(entry_price, 8), + "current_price": round(current_price, 8), + "price_change_pct": round(price_change_pct, 4), + + # Initial state + "initial_base": round(initial_base, 8), + "initial_quote": round(initial_quote, 8), + "initial_value_quote": round(initial_value_quote, 8), + + # Current position (liquidity only, no fees) + "current_base": round(current_base, 8), + "current_quote": round(current_quote, 8), + "current_lp_value_quote": round(current_lp_value_quote, 8), + + # Fees earned + "total_fees_base": round(total_fees_base, 8), + "total_fees_quote": round(total_fees_quote, 8), + "total_fees_value_quote": round(total_fees_value_quote, 8), + + # HODL comparison + "hodl_value_quote": round(hodl_value_quote, 8), + + # Key metrics + "impermanent_loss_quote": round(impermanent_loss_quote, 8), + "current_total_value_quote": round(current_total_value_quote, 8), + "total_pnl_quote": round(total_pnl_quote, 8), + "total_pnl_pct": round(total_pnl_pct, 4), + + # Time metrics + "duration_hours": round(duration_hours, 2), + "fee_apr_estimate": round(fee_apr_estimate, 2) if fee_apr_estimate else None + } + + return { + "position_address": position.position_address, + "pool_address": position.pool_address, + "network": position.network, + "connector": position.connector, + "wallet_address": position.wallet_address, + "trading_pair": position.trading_pair, + "base_token": position.base_token, + "quote_token": position.quote_token, + "created_at": position.created_at.isoformat(), + "closed_at": position.closed_at.isoformat() if position.closed_at else None, + "status": position.status, + "lower_price": float(position.lower_price), + "upper_price": float(position.upper_price), + "lower_bin_id": position.lower_bin_id, + "upper_bin_id": position.upper_bin_id, + "entry_price": entry_price, + "current_price": current_price, + "percentage": float(position.percentage) if position.percentage is not None else None, + "initial_base_token_amount": float(position.initial_base_token_amount) if position.initial_base_token_amount is not None else None, + "initial_quote_token_amount": float(position.initial_quote_token_amount) if position.initial_quote_token_amount is not None else None, + "position_rent": float(position.position_rent) if position.position_rent is not None else None, + "base_token_amount": float(position.base_token_amount), + "quote_token_amount": float(position.quote_token_amount), + "in_range": position.in_range, + "base_fee_collected": float(position.base_fee_collected), + "quote_fee_collected": float(position.quote_fee_collected), + "base_fee_pending": float(position.base_fee_pending), + "quote_fee_pending": float(position.quote_fee_pending), + "pnl_summary": pnl_summary, + "last_updated": position.last_updated.isoformat(), + } + + def event_to_dict(self, event: GatewayCLMMEvent) -> Dict: + """Convert GatewayCLMMEvent model to dictionary format.""" + return { + "transaction_hash": event.transaction_hash, + "timestamp": event.timestamp.isoformat(), + "event_type": event.event_type, + "base_token_amount": float(event.base_token_amount) if event.base_token_amount else None, + "quote_token_amount": float(event.quote_token_amount) if event.quote_token_amount else None, + "base_fee_collected": float(event.base_fee_collected) if event.base_fee_collected else None, + "quote_fee_collected": float(event.quote_fee_collected) if event.quote_fee_collected else None, + "gas_fee": float(event.gas_fee) if event.gas_fee else None, + "gas_token": event.gas_token, + "status": event.status, + "error_message": event.error_message, + } diff --git a/database/repositories/gateway_swap_repository.py b/database/repositories/gateway_swap_repository.py new file mode 100644 index 00000000..57871fb8 --- /dev/null +++ b/database/repositories/gateway_swap_repository.py @@ -0,0 +1,167 @@ +from datetime import datetime +from typing import Dict, List, Optional +from decimal import Decimal + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import GatewaySwap + + +class GatewaySwapRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_swap(self, swap_data: Dict) -> GatewaySwap: + """Create a new swap record.""" + swap = GatewaySwap(**swap_data) + self.session.add(swap) + await self.session.flush() + return swap + + async def get_swap_by_tx_hash(self, transaction_hash: str) -> Optional[GatewaySwap]: + """Get a swap by its transaction hash.""" + result = await self.session.execute( + select(GatewaySwap).where(GatewaySwap.transaction_hash == transaction_hash) + ) + return result.scalar_one_or_none() + + async def update_swap_status( + self, + transaction_hash: str, + status: str, + error_message: Optional[str] = None, + gas_fee: Optional[Decimal] = None, + gas_token: Optional[str] = None + ) -> Optional[GatewaySwap]: + """Update swap status and optional metadata after transaction confirmation.""" + result = await self.session.execute( + select(GatewaySwap).where(GatewaySwap.transaction_hash == transaction_hash) + ) + swap = result.scalar_one_or_none() + if swap: + swap.status = status + if error_message: + swap.error_message = error_message + if gas_fee is not None: + swap.gas_fee = float(gas_fee) + if gas_token: + swap.gas_token = gas_token + await self.session.flush() + return swap + + async def get_swaps( + self, + network: Optional[str] = None, + connector: Optional[str] = None, + wallet_address: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, + offset: int = 0 + ) -> List[GatewaySwap]: + """Get swaps with filtering and pagination.""" + query = select(GatewaySwap) + + # Apply filters + if network: + query = query.where(GatewaySwap.network == network) + if connector: + query = query.where(GatewaySwap.connector == connector) + if wallet_address: + query = query.where(GatewaySwap.wallet_address == wallet_address) + if trading_pair: + query = query.where(GatewaySwap.trading_pair == trading_pair) + if status: + query = query.where(GatewaySwap.status == status) + if start_time: + start_dt = datetime.fromtimestamp(start_time) + query = query.where(GatewaySwap.timestamp >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time) + query = query.where(GatewaySwap.timestamp <= end_dt) + + # Apply ordering and pagination + query = query.order_by(GatewaySwap.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_pending_swaps(self, limit: int = 100) -> List[GatewaySwap]: + """Get swaps that are still pending confirmation.""" + query = select(GatewaySwap).where( + GatewaySwap.status == "SUBMITTED" + ).order_by(GatewaySwap.timestamp.desc()).limit(limit) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_swaps_summary( + self, + network: Optional[str] = None, + wallet_address: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None + ) -> Dict: + """Get swap summary statistics.""" + swaps = await self.get_swaps( + network=network, + wallet_address=wallet_address, + start_time=start_time, + end_time=end_time, + limit=10000 # Get all for summary + ) + + total_swaps = len(swaps) + confirmed_swaps = sum(1 for s in swaps if s.status == "CONFIRMED") + failed_swaps = sum(1 for s in swaps if s.status == "FAILED") + pending_swaps = sum(1 for s in swaps if s.status == "SUBMITTED") + + # Calculate total volume (in quote token) + total_volume = sum( + float(s.output_amount if s.side == "BUY" else s.input_amount) + for s in swaps if s.status == "CONFIRMED" + ) + + # Calculate total gas fees + total_gas_fees = sum( + float(s.gas_fee) for s in swaps + if s.gas_fee is not None and s.status == "CONFIRMED" + ) + + return { + "total_swaps": total_swaps, + "confirmed_swaps": confirmed_swaps, + "failed_swaps": failed_swaps, + "pending_swaps": pending_swaps, + "success_rate": confirmed_swaps / total_swaps if total_swaps > 0 else 0, + "total_volume": total_volume, + "total_gas_fees": total_gas_fees, + } + + def to_dict(self, swap: GatewaySwap) -> Dict: + """Convert GatewaySwap model to dictionary format.""" + return { + "transaction_hash": swap.transaction_hash, + "timestamp": swap.timestamp.isoformat(), + "network": swap.network, + "connector": swap.connector, + "wallet_address": swap.wallet_address, + "trading_pair": swap.trading_pair, + "base_token": swap.base_token, + "quote_token": swap.quote_token, + "side": swap.side, + "input_amount": float(swap.input_amount), + "output_amount": float(swap.output_amount), + "price": float(swap.price), + "slippage_pct": float(swap.slippage_pct) if swap.slippage_pct else None, + "gas_fee": float(swap.gas_fee) if swap.gas_fee else None, + "gas_token": swap.gas_token, + "status": swap.status, + "pool_address": swap.pool_address, + "quote_id": swap.quote_id, + "error_message": swap.error_message, + } diff --git a/database/repositories/order_repository.py b/database/repositories/order_repository.py new file mode 100644 index 00000000..27acdc7f --- /dev/null +++ b/database/repositories/order_repository.py @@ -0,0 +1,178 @@ +from datetime import datetime +from typing import Dict, List, Optional +from decimal import Decimal + +from sqlalchemy import desc, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import Order + + +class OrderRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_order(self, order_data: Dict) -> Order: + """Create a new order record.""" + order = Order(**order_data) + self.session.add(order) + await self.session.flush() # Get the ID + return order + + async def get_order_by_client_id(self, client_order_id: str) -> Optional[Order]: + """Get an order by its client order ID.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + return result.scalar_one_or_none() + + async def update_order_status(self, client_order_id: str, status: str, + error_message: Optional[str] = None) -> Optional[Order]: + """Update order status and optional error message.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + order = result.scalar_one_or_none() + if order: + order.status = status + if error_message: + order.error_message = error_message + await self.session.flush() + return order + + async def update_order_fill(self, client_order_id: str, filled_amount: Decimal, + average_fill_price: Decimal, fee_paid: Decimal = None, + fee_currency: str = None, exchange_order_id: str = None) -> Optional[Order]: + """Update order with fill information.""" + result = await self.session.execute( + select(Order).where(Order.client_order_id == client_order_id) + ) + order = result.scalar_one_or_none() + if order: + # Add to existing filled amount instead of replacing + previous_filled = Decimal(str(order.filled_amount or 0)) + order.filled_amount = float(previous_filled + filled_amount) + + # Update average price (simplified - use latest fill price) + order.average_fill_price = float(average_fill_price) + + # Add to existing fees + if fee_paid is not None: + previous_fee = Decimal(str(order.fee_paid or 0)) + order.fee_paid = float(previous_fee + fee_paid) + if fee_currency: + order.fee_currency = fee_currency + if exchange_order_id: + order.exchange_order_id = exchange_order_id + + # Update status based on total filled amount + total_filled = Decimal(str(order.filled_amount)) + if total_filled >= Decimal(str(order.amount)): + order.status = "FILLED" + elif total_filled > 0: + order.status = "PARTIALLY_FILLED" + + await self.session.flush() + return order + + async def get_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Order]: + """Get orders with filtering and pagination.""" + query = select(Order) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Order.trading_pair == trading_pair) + if status: + query = query.where(Order.status == status) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Order.created_at >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Order.created_at <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Order.created_at.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_active_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None) -> List[Order]: + """Get active orders (SUBMITTED, OPEN, PARTIALLY_FILLED, PENDING_CANCEL).""" + query = select(Order).where( + Order.status.in_(["SUBMITTED", "OPEN", "PARTIALLY_FILLED", "PENDING_CANCEL"]) + ) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Order.trading_pair == trading_pair) + + query = query.order_by(Order.created_at.desc()).limit(1000) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_orders_summary(self, account_name: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None) -> Dict: + """Get order summary statistics.""" + orders = await self.get_orders( + account_name=account_name, + start_time=start_time, + end_time=end_time, + limit=10000 # Get all for summary + ) + + total_orders = len(orders) + filled_orders = sum(1 for o in orders if o.status == "FILLED") + cancelled_orders = sum(1 for o in orders if o.status == "CANCELLED") + failed_orders = sum(1 for o in orders if o.status == "FAILED") + active_orders = sum(1 for o in orders if o.status in ["SUBMITTED", "OPEN", "PARTIALLY_FILLED"]) + + return { + "total_orders": total_orders, + "filled_orders": filled_orders, + "cancelled_orders": cancelled_orders, + "failed_orders": failed_orders, + "active_orders": active_orders, + "fill_rate": filled_orders / total_orders if total_orders > 0 else 0, + } + + def to_dict(self, order: Order) -> Dict: + """Convert Order model to dictionary format.""" + return { + "order_id": order.client_order_id, + "account_name": order.account_name, + "connector_name": order.connector_name, + "trading_pair": order.trading_pair, + "trade_type": order.trade_type, + "order_type": order.order_type, + "amount": float(order.amount), + "price": float(order.price) if order.price else None, + "status": order.status, + "filled_amount": float(order.filled_amount), + "average_fill_price": float(order.average_fill_price) if order.average_fill_price else None, + "fee_paid": float(order.fee_paid) if order.fee_paid else None, + "fee_currency": order.fee_currency, + "created_at": order.created_at.isoformat(), + "updated_at": order.updated_at.isoformat(), + "exchange_order_id": order.exchange_order_id, + "error_message": order.error_message, + } \ No newline at end of file diff --git a/database/repositories/trade_repository.py b/database/repositories/trade_repository.py new file mode 100644 index 00000000..f718a643 --- /dev/null +++ b/database/repositories/trade_repository.py @@ -0,0 +1,126 @@ +from datetime import datetime +from typing import Dict, List, Optional + +from sqlalchemy import desc, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import Trade, Order + + +class TradeRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def create_trade(self, trade_data: Dict) -> Optional[Trade]: + """Create a new trade record if it doesn't already exist. + + Returns the trade if created, or None if it already exists (idempotent). + Handles race conditions gracefully by catching IntegrityError. + """ + # Check if trade already exists + trade_id = trade_data.get("trade_id") + if trade_id: + existing = await self.get_trade_by_id(trade_id) + if existing: + return None # Already exists, skip silently + + trade = Trade(**trade_data) + self.session.add(trade) + try: + await self.session.flush() # Get the ID + return trade + except IntegrityError: + # Race condition: another concurrent insert succeeded first + await self.session.rollback() + return None + + async def get_trade_by_id(self, trade_id: str) -> Optional[Trade]: + """Get a trade by its trade_id.""" + query = select(Trade).where(Trade.trade_id == trade_id) + result = await self.session.execute(query) + return result.scalar_one_or_none() + + async def get_trades(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + trade_type: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Trade]: + """Get trades with filtering and pagination.""" + # Join trades with orders to get account information + query = select(Trade).join(Order, Trade.order_id == Order.id) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Trade.trading_pair == trading_pair) + if trade_type: + query = query.where(Trade.trade_type == trade_type) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Trade.timestamp >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Trade.timestamp <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Trade.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.scalars().all() + + async def get_trades_with_orders(self, account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + trade_type: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[tuple]: + """Get trades with their associated order information.""" + # Join trades with orders to get complete information + query = select(Trade, Order).join(Order, Trade.order_id == Order.id) + + # Apply filters + if account_name: + query = query.where(Order.account_name == account_name) + if connector_name: + query = query.where(Order.connector_name == connector_name) + if trading_pair: + query = query.where(Trade.trading_pair == trading_pair) + if trade_type: + query = query.where(Trade.trade_type == trade_type) + if start_time: + start_dt = datetime.fromtimestamp(start_time / 1000) + query = query.where(Trade.timestamp >= start_dt) + if end_time: + end_dt = datetime.fromtimestamp(end_time / 1000) + query = query.where(Trade.timestamp <= end_dt) + + # Apply ordering and pagination + query = query.order_by(Trade.timestamp.desc()) + query = query.limit(limit).offset(offset) + + result = await self.session.execute(query) + return result.all() # Returns tuples of (Trade, Order) + + def to_dict(self, trade: Trade, order: Optional[Order] = None) -> Dict: + """Convert Trade model to dictionary format.""" + return { + "trade_id": trade.trade_id, + "order_id": order.client_order_id if order else None, + "account_name": order.account_name if order else None, + "connector_name": order.connector_name if order else None, + "trading_pair": trade.trading_pair, + "trade_type": trade.trade_type, + "amount": float(trade.amount), + "price": float(trade.price), + "fee_paid": float(trade.fee_paid), + "fee_currency": trade.fee_currency, + "timestamp": trade.timestamp.isoformat(), + } \ No newline at end of file diff --git a/deps.py b/deps.py new file mode 100644 index 00000000..99e0f776 --- /dev/null +++ b/deps.py @@ -0,0 +1,61 @@ +from fastapi import Request +from services.bots_orchestrator import BotsOrchestrator +from services.accounts_service import AccountsService +from services.docker_service import DockerService +from services.gateway_service import GatewayService +from services.unified_connector_service import UnifiedConnectorService +from services.market_data_service import MarketDataService +from services.trading_service import TradingService +from services.executor_service import ExecutorService +from utils.bot_archiver import BotArchiver +from database import AsyncDatabaseManager + + +def get_bots_orchestrator(request: Request) -> BotsOrchestrator: + """Get BotsOrchestrator service from app state.""" + return request.app.state.bots_orchestrator + + +def get_accounts_service(request: Request) -> AccountsService: + """Get AccountsService from app state.""" + return request.app.state.accounts_service + + +def get_docker_service(request: Request) -> DockerService: + """Get DockerService from app state.""" + return request.app.state.docker_service + + +def get_gateway_service(request: Request) -> GatewayService: + """Get GatewayService from app state.""" + return request.app.state.gateway_service + + +def get_connector_service(request: Request) -> UnifiedConnectorService: + """Get UnifiedConnectorService from app state.""" + return request.app.state.connector_service + + +def get_market_data_service(request: Request) -> MarketDataService: + """Get MarketDataService from app state.""" + return request.app.state.market_data_service + + +def get_trading_service(request: Request) -> TradingService: + """Get TradingService from app state.""" + return request.app.state.trading_service + + +def get_executor_service(request: Request) -> ExecutorService: + """Get ExecutorService from app state.""" + return request.app.state.executor_service + + +def get_bot_archiver(request: Request) -> BotArchiver: + """Get BotArchiver from app state.""" + return request.app.state.bot_archiver + + +def get_database_manager(request: Request) -> AsyncDatabaseManager: + """Get AsyncDatabaseManager from app state.""" + return request.app.state.db_manager diff --git a/docker-compose.yml b/docker-compose.yml index 03e11633..c33f7ebb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,20 +1,28 @@ -version: "3.9" - services: - backend-api: - build: . + hummingbot-api: + container_name: hummingbot-api + image: hummingbot/hummingbot-api:latest ports: - "8000:8000" volumes: - - ./bots:/backend-api/bots + - ./bots:/hummingbot-api/bots - /var/run/docker.sock:/var/run/docker.sock env_file: - .env environment: + # Override specific values for Docker networking - BROKER_HOST=emqx - - BROKER_PORT=1883 + - DATABASE_URL=postgresql+asyncpg://hbot:hummingbot-api@postgres:5432/hummingbot_api + - GATEWAY_URL=http://host.docker.internal:15888 + extra_hosts: + # Map host.docker.internal to host gateway for Linux compatibility + # On macOS/Windows, Docker Desktop provides this automatically + # On Linux, this maps to the docker bridge gateway IP + - "host.docker.internal:host-gateway" networks: - emqx-bridge + depends_on: + - postgres emqx: container_name: hummingbot-broker image: emqx:5 @@ -46,6 +54,30 @@ services: interval: 5s timeout: 25s retries: 5 + postgres: + container_name: hummingbot-postgres + image: postgres:16 + restart: unless-stopped + environment: + # These variables automatically create the user and database on first initialization + - POSTGRES_DB=hummingbot_api + - POSTGRES_USER=hbot + - POSTGRES_PASSWORD=hummingbot-api + # Additional init parameters for better compatibility + - POSTGRES_INITDB_ARGS=--encoding=UTF8 + volumes: + - postgres-data:/var/lib/postgresql/data + # Init script as safety net - only runs on first initialization + - ./init-db.sql:/docker-entrypoint-initdb.d/init-db.sql:ro + ports: + - "5432:5432" + networks: + - emqx-bridge + healthcheck: + test: ["CMD-SHELL", "pg_isready -U hbot -d hummingbot_api"] + interval: 10s + timeout: 5s + retries: 5 networks: emqx-bridge: @@ -55,3 +87,4 @@ volumes: emqx-data: { } emqx-log: { } emqx-etc: { } + postgres-data: { } diff --git a/environment.yml b/environment.yml index 69b97efe..6100a57c 100644 --- a/environment.yml +++ b/environment.yml @@ -1,22 +1,35 @@ -name: backend-api +name: hummingbot-api channels: - conda-forge - defaults dependencies: - - python=3.10 + - python=3.12 - fastapi - uvicorn - - libcxx + - boto3 + - libcxx + - libta-lib>=0.6.4 + - python-dotenv + - pandas>=2.3.2 + - numba>=0.61.2 + - numpy>=2.2.6 + - pandas-ta>=0.4.71b + - ta-lib>=0.6.4 + - tqdm>=4.67.1 + - docker-py - pip - pip: - hummingbot - - numpy==1.26.4 - - git+https://github.com/felixfontein/docker-py - - python-dotenv - - boto3 - - python-multipart==0.0.12 - - PyYAML - - git+https://github.com/hummingbot/hbot-remote-client-py.git + - msgpack>=1.0.5 - flake8 - isort - pre-commit + - logfire + - logfire[fastapi] + - logfire[system-metrics] + - aiomqtt>=2.0.0 + - sqlalchemy>=2.0.0 + - asyncpg + - psycopg2-binary + - greenlet + - pydantic-settings diff --git a/init-db.sql b/init-db.sql new file mode 100644 index 00000000..81c3f0e9 --- /dev/null +++ b/init-db.sql @@ -0,0 +1,10 @@ +-- Safety net for PostgreSQL initialization +-- PostgreSQL auto-creates user/db from POSTGRES_USER, POSTGRES_DB env vars +-- This script only runs on first container initialization + +-- Ensure proper permissions on public schema +GRANT ALL ON SCHEMA public TO hbot; +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO hbot; +GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO hbot; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO hbot; +ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO hbot; diff --git a/main.py b/main.py index 5e790e6a..b53ede81 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,395 @@ +import logging +import secrets +from contextlib import asynccontextmanager +from typing import Annotated +from urllib.parse import urlparse + +import logfire from dotenv import load_dotenv -from fastapi import FastAPI -from routers import manage_accounts, manage_backtesting, manage_broker_messages, manage_docker, manage_files, manage_market_data +# Apply the patch before importing hummingbot components +from hummingbot.client.config import config_helpers +# Load environment variables early load_dotenv() -app = FastAPI() - -app.include_router(manage_docker.router) -app.include_router(manage_broker_messages.router) -app.include_router(manage_files.router) -app.include_router(manage_market_data.router) -app.include_router(manage_backtesting.router) -app.include_router(manage_accounts.router) + +VERSION = "1.0.1" + +# Monkey patch save_to_yml to prevent writes to library directory + + +def patched_save_to_yml(yml_path, cm): + """Patched version of save_to_yml that prevents writes to library directory""" + import logging + logger = logging.getLogger(__name__) + logger.debug(f"Skipping config write to {yml_path} (patched for API mode)") + # Do nothing - this prevents the original function from trying to write to the library directory + + +config_helpers.save_to_yml = patched_save_to_yml + +from fastapi import Depends, FastAPI, HTTPException, Request, status # noqa: E402 +from fastapi.exceptions import RequestValidationError # noqa: E402 +from fastapi.middleware.cors import CORSMiddleware # noqa: E402 +from fastapi.responses import JSONResponse # noqa: E402 +from fastapi.security import HTTPBasic, HTTPBasicCredentials # noqa: E402 +from hummingbot.client.config.client_config_map import GatewayConfigMap # noqa: E402 +from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger # noqa: E402 +from hummingbot.core.gateway.gateway_http_client import GatewayHttpClient # noqa: E402 +from hummingbot.core.rate_oracle.rate_oracle import RATE_ORACLE_SOURCES, RateOracle # noqa: E402 + +from config import settings # noqa: E402 +from database import AsyncDatabaseManager # noqa: E402 +from routers import ( # noqa: E402 + accounts, + archived_bots, + backtesting, + bot_orchestration, + connectors, + controllers, + docker, + executors, + gateway, + gateway_clmm, + gateway_proxy, + gateway_swap, + market_data, + portfolio, + rate_oracle, + scripts, + trading, +) +from services.accounts_service import AccountsService # noqa: E402 +from services.bots_orchestrator import BotsOrchestrator # noqa: E402 +from services.docker_service import DockerService # noqa: E402 +from services.executor_service import ExecutorService # noqa: E402 +from services.gateway_service import GatewayService # noqa: E402 +from services.market_data_service import MarketDataService # noqa: E402 +from services.trading_service import TradingService # noqa: E402 +from services.unified_connector_service import UnifiedConnectorService # noqa: E402 +from utils.bot_archiver import BotArchiver # noqa: E402 +from utils.security import BackendAPISecurity # noqa: E402 + +# Set up logging configuration +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +# Enable info logging for MQTT manager +logging.getLogger('services.mqtt_manager').setLevel(logging.INFO) + +# Get settings from Pydantic Settings +username = settings.security.username +password = settings.security.password +debug_mode = settings.security.debug_mode + +# Security setup +security = HTTPBasic() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + Lifespan context manager for the FastAPI application. + Handles startup and shutdown events. + """ + # Ensure password verification file exists + if BackendAPISecurity.new_password_required(): + # Create secrets manager with CONFIG_PASSWORD + secrets_manager = ETHKeyFileSecretManger(password=settings.security.config_password) + BackendAPISecurity.store_password_verification(secrets_manager) + logging.info("Created password verification file for master_account") + + # ========================================================================= + # 1. Infrastructure Setup + # ========================================================================= + + # Initialize GatewayHttpClient singleton + parsed_gateway_url = urlparse(settings.gateway.url) + gateway_config = GatewayConfigMap( + gateway_api_host=parsed_gateway_url.hostname or "localhost", + gateway_api_port=str(parsed_gateway_url.port or 15888), + gateway_use_ssl=parsed_gateway_url.scheme == "https" + ) + GatewayHttpClient.get_instance(gateway_config) + logging.info(f"Initialized GatewayHttpClient with URL: {settings.gateway.url}") + + # Initialize secrets manager and database + secrets_manager = ETHKeyFileSecretManger(password=settings.security.config_password) + db_manager = AsyncDatabaseManager(settings.database.url) + await db_manager.create_tables() + logging.info("Database initialized") + + # Read rate oracle configuration from conf_client.yml + from utils.file_system import FileSystemUtil + fs_util = FileSystemUtil() + + try: + conf_client_path = "credentials/master_account/conf_client.yml" + config_data = fs_util.read_yaml_file(conf_client_path) + + # Get rate_oracle_source configuration + rate_oracle_source_data = config_data.get("rate_oracle_source", {}) + source_name = rate_oracle_source_data.get("name", "binance") + + # Get global_token configuration + global_token_data = config_data.get("global_token", {}) + quote_token = global_token_data.get("global_token_name", "USDT") + + # Create rate source instance + if source_name in RATE_ORACLE_SOURCES: + rate_source = RATE_ORACLE_SOURCES[source_name]() + logging.info(f"Configured RateOracle with source: {source_name}, quote_token: {quote_token}") + else: + logging.warning(f"Unknown rate oracle source '{source_name}', defaulting to binance") + rate_source = RATE_ORACLE_SOURCES["binance"]() + source_name = "binance" + + # Initialize RateOracle with configured source and quote token + rate_oracle = RateOracle.get_instance() + rate_oracle.source = rate_source + rate_oracle.quote_token = quote_token + + except FileNotFoundError: + logging.warning("conf_client.yml not found, using default RateOracle configuration (binance, USDT)") + rate_oracle = RateOracle.get_instance() + except Exception as e: + logging.warning(f"Error reading conf_client.yml: {e}, using default RateOracle configuration") + rate_oracle = RateOracle.get_instance() + + # ========================================================================= + # 2. UnifiedConnectorService - Single source of truth for all connectors + # ========================================================================= + + connector_service = UnifiedConnectorService( + secrets_manager=secrets_manager, + db_manager=db_manager + ) + logging.info("UnifiedConnectorService initialized") + + # ========================================================================= + # 3. Services that depend on connector_service + # ========================================================================= + + # MarketDataService - candles, order books, prices + market_data_service = MarketDataService( + connector_service=connector_service, + rate_oracle=rate_oracle, + cleanup_interval=settings.market_data.cleanup_interval, + feed_timeout=settings.market_data.feed_timeout + ) + logging.info("MarketDataService initialized") + + # TradingService - order placement, positions, trading interfaces + trading_service = TradingService( + connector_service=connector_service, + market_data_service=market_data_service + ) + logging.info("TradingService initialized") + + # AccountsService - account management, balances, portfolio (simplified) + accounts_service = AccountsService( + account_update_interval=settings.app.account_update_interval, + gateway_url=settings.gateway.url + ) + # Inject services into AccountsService + accounts_service._connector_service = connector_service + accounts_service._market_data_service = market_data_service + accounts_service._trading_service = trading_service + logging.info("AccountsService initialized") + + # ========================================================================= + # 4. ExecutorService - depends on TradingService (NO circular dependency) + # ========================================================================= + + executor_service = ExecutorService( + trading_service=trading_service, + db_manager=db_manager, + default_account="master_account", + update_interval=1.0, + max_retries=10 + ) + logging.info("ExecutorService initialized") + # Ensure lp_executor is in the registry (workspace hummingbot may load after class definition) + try: + from hummingbot.strategy_v2.executors.lp_executor.data_types import LPExecutorConfig + from hummingbot.strategy_v2.executors.lp_executor.lp_executor import LPExecutor + print(f"[LP-FIX] imports OK. Registry before: {list(ExecutorService.EXECUTOR_REGISTRY.keys())}", flush=True) + ExecutorService.EXECUTOR_REGISTRY["lp_executor"] = (LPExecutor, LPExecutorConfig) + print(f"[LP-FIX] Registry after: {list(ExecutorService.EXECUTOR_REGISTRY.keys())}", flush=True) + except Exception as e: + import traceback + print(f"[LP-FIX] FAILED: {e}", flush=True) + traceback.print_exc() + + # ========================================================================= + # 5. Other Services + # ========================================================================= + + bots_orchestrator = BotsOrchestrator( + broker_host=settings.broker.host, + broker_port=settings.broker.port, + broker_username=settings.broker.username, + broker_password=settings.broker.password + ) + + docker_service = DockerService() + gateway_service = GatewayService() + bot_archiver = BotArchiver( + settings.aws.api_key, + settings.aws.secret_key, + settings.aws.s3_default_bucket_name + ) + + # ========================================================================= + # 6. Start services + # ========================================================================= + + # Initialize all trading connectors FIRST (before any service that might use them) + # This ensures OrdersRecorder is properly attached before any concurrent access + logging.info("Initializing all trading connectors...") + await connector_service.initialize_all_trading_connectors() + + bots_orchestrator.start() + market_data_service.start() + await market_data_service.warmup_rate_oracle() + executor_service.start() + await executor_service.cleanup_orphaned_executors() + await executor_service.recover_positions_from_db() + accounts_service.start() + + # ========================================================================= + # 7. Store services in app state + # ========================================================================= + + app.state.db_manager = db_manager + app.state.connector_service = connector_service + app.state.market_data_service = market_data_service + app.state.trading_service = trading_service + app.state.accounts_service = accounts_service + app.state.executor_service = executor_service + app.state.bots_orchestrator = bots_orchestrator + app.state.docker_service = docker_service + app.state.gateway_service = gateway_service + app.state.bot_archiver = bot_archiver + + logging.info("All services started successfully") + + yield + + # ========================================================================= + # Shutdown services + # ========================================================================= + + logging.info("Shutting down services...") + + bots_orchestrator.stop() + await accounts_service.stop() + await executor_service.stop() + market_data_service.stop() + await connector_service.stop_all() + docker_service.cleanup() + await db_manager.close() + + logging.info("All services stopped") + +# Initialize FastAPI with metadata and lifespan +app = FastAPI( + title="Hummingbot API", + description="API for managing Hummingbot trading instances", + version=VERSION, + lifespan=lifespan, +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Modify in production to specific origins + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """ + Custom handler for validation errors to log detailed error messages. + """ + # Build a readable error message from validation errors + error_messages = [] + for error in exc.errors(): + loc = " -> ".join(str(part) for part in error.get("loc", [])) + msg = error.get("msg", "Validation error") + error_messages.append(f"{loc}: {msg}") + + # Log the validation error with details + logging.warning( + f"Validation error on {request.method} {request.url.path}: {'; '.join(error_messages)}" + ) + + # Return standard FastAPI validation error response + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content={"detail": exc.errors()}, + ) + +logfire.configure(send_to_logfire="if-token-present", environment=settings.app.logfire_environment, + service_name="hummingbot-api") +logfire.instrument_fastapi(app) + + +def auth_user( + credentials: Annotated[HTTPBasicCredentials, Depends(security)], +): + """Authenticate user using HTTP Basic Auth""" + current_username_bytes = credentials.username.encode("utf8") + correct_username_bytes = f"{username}".encode("utf8") + is_correct_username = secrets.compare_digest( + current_username_bytes, correct_username_bytes + ) + current_password_bytes = credentials.password.encode("utf8") + correct_password_bytes = f"{password}".encode("utf8") + is_correct_password = secrets.compare_digest( + current_password_bytes, correct_password_bytes + ) + if not (is_correct_username and is_correct_password) and not debug_mode: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Basic"}, + ) + + return credentials.username + + +# Include all routers with authentication +app.include_router(docker.router, dependencies=[Depends(auth_user)]) +app.include_router(gateway.router, dependencies=[Depends(auth_user)]) +app.include_router(accounts.router, dependencies=[Depends(auth_user)]) +app.include_router(connectors.router, dependencies=[Depends(auth_user)]) +app.include_router(portfolio.router, dependencies=[Depends(auth_user)]) +app.include_router(trading.router, dependencies=[Depends(auth_user)]) +app.include_router(gateway_swap.router, dependencies=[Depends(auth_user)]) +app.include_router(gateway_clmm.router, dependencies=[Depends(auth_user)]) +app.include_router(bot_orchestration.router, dependencies=[Depends(auth_user)]) +app.include_router(controllers.router, dependencies=[Depends(auth_user)]) +app.include_router(scripts.router, dependencies=[Depends(auth_user)]) +app.include_router(market_data.router, dependencies=[Depends(auth_user)]) +app.include_router(rate_oracle.router, dependencies=[Depends(auth_user)]) +app.include_router(backtesting.router, dependencies=[Depends(auth_user)]) +app.include_router(archived_bots.router, dependencies=[Depends(auth_user)]) + +app.include_router(executors.router, dependencies=[Depends(auth_user)]) +app.include_router(gateway_proxy.router, dependencies=[Depends(auth_user)]) + + +@app.get("/") +async def root(): + """API root endpoint returning basic information.""" + return { + "name": "Hummingbot API", + "version": VERSION, + "status": "running", + } diff --git a/models.py b/models.py deleted file mode 100644 index 94aa10b3..00000000 --- a/models.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Any, Dict, Optional - -from pydantic import BaseModel - - -class HummingbotInstanceConfig(BaseModel): - instance_name: str - credentials_profile: str - image: str = "hummingbot/hummingbot:latest" - script: Optional[str] = None - script_config: Optional[str] = None - - -class ImageName(BaseModel): - image_name: str - - -class Script(BaseModel): - name: str - content: str - - -class ScriptConfig(BaseModel): - name: str - content: Dict[str, Any] # YAML content represented as a dictionary - - -class BotAction(BaseModel): - bot_name: str - - -class StartBotAction(BotAction): - log_level: str = None - script: str = None - conf: str = None - async_backend: bool = False - - -class StopBotAction(BotAction): - skip_order_cancellation: bool = False - async_backend: bool = False - - -class ImportStrategyAction(BotAction): - strategy: str - - -class ConfigureBotAction(BotAction): - params: dict - - -class ShortcutAction(BotAction): - params: list diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 00000000..62da9281 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,370 @@ +""" +Model definitions for the Backend API. + +Each model file corresponds to a router file with the same name. +Models are organized by functional domain to match the API structure. +""" + +# Account models +from .accounts import CredentialRequest, LeverageRequest, PositionModeRequest + +# Archived bots models +from .archived_bots import ( + ArchivedBotListResponse, + BotPerformanceResponse, + BotSummary, + DatabaseStatus, + ExecutorInfo, + ExecutorsResponse, + OrderDetail, + OrderHistoryResponse, + OrderStatus, + PerformanceMetrics, + TradeDetail, + TradeHistoryResponse, +) + +# Backtesting models +from .backtesting import BacktestingConfig + +# Bot orchestration models (bot lifecycle management) +from .bot_orchestration import ( + AllBotsStatusResponse, + BotAction, + BotHistoryRequest, + BotHistoryResponse, + BotStatus, + ConfigureBotAction, + ImportStrategyAction, + MQTTStatus, + ShortcutAction, + StartBotAction, + StopAndArchiveRequest, + StopAndArchiveResponse, + StopBotAction, + V2ControllerDeployment, +) + +# Connector models +from .connectors import ( + ConnectorConfigMapResponse, + ConnectorInfo, + ConnectorListResponse, + ConnectorOrderTypesResponse, + ConnectorTradingRulesResponse, + TradingRule, +) + +# Controller models +from .controllers import Controller, ControllerConfig, ControllerConfigResponse, ControllerResponse, ControllerType + +# Docker models +from .docker import DockerImage + +# Executor models +from .executors import ( + CreateExecutorRequest, + CreateExecutorResponse, + ExecutorDetailResponse, + ExecutorFilterRequest, + ExecutorResponse, + ExecutorsSummaryResponse, + StopExecutorRequest, + StopExecutorResponse, +) + +# Gateway models (consolidated) +from .gateway import ( + AddPoolRequest, + AddTokenRequest, + CreateWalletRequest, + GatewayBalanceRequest, + GatewayConfig, + GatewayStatus, + GatewayWalletCredential, + GatewayWalletInfo, + SendTransactionRequest, + ShowPrivateKeyRequest, +) + +# Gateway Trading models (Swap + CLMM only, AMM removed) +from .gateway_trading import ( # Swap models; CLMM models; Pool info models; Pool listing models + CLMMAddLiquidityRequest, + CLMMClosePositionRequest, + CLMMCollectFeesRequest, + CLMMCollectFeesResponse, + CLMMGetPositionInfoRequest, + CLMMOpenPositionRequest, + CLMMOpenPositionResponse, + CLMMPoolBin, + CLMMPoolInfoRequest, + CLMMPoolInfoResponse, + CLMMPoolListItem, + CLMMPoolListResponse, + CLMMPositionInfo, + CLMMPositionsOwnedRequest, + CLMMRemoveLiquidityRequest, + GetPoolInfoRequest, + PoolInfo, + SwapExecuteRequest, + SwapExecuteResponse, + SwapQuoteRequest, + SwapQuoteResponse, + TimeBasedMetrics, +) + +# Market data models +from .market_data import ( # New enhanced market data models; Trading pair management models + ActiveFeedInfo, + ActiveFeedsResponse, + AddTradingPairRequest, + CandleData, + CandlesResponse, + FundingInfoRequest, + FundingInfoResponse, + MarketDataSettings, + OrderBookLevel, + OrderBookQueryRequest, + OrderBookQueryResult, + OrderBookRequest, + OrderBookResponse, + PriceData, + PriceForQuoteVolumeRequest, + PriceForVolumeRequest, + PriceRequest, + PricesResponse, + QuoteVolumeForPriceRequest, + RemoveTradingPairRequest, + SupportedOrderTypesResponse, + TradingPairResponse, + TradingRulesResponse, + VolumeForPriceRequest, + VWAPForVolumeRequest, +) + +# Pagination models +from .pagination import PaginatedResponse, PaginationParams, TimeRangePaginationParams + +# Portfolio models +from .portfolio import ( + AccountDistribution, + AccountPortfolioState, + AccountsDistributionResponse, + ConnectorBalances, + HistoricalPortfolioState, + PortfolioDistributionResponse, + PortfolioHistoryFilters, + PortfolioStateResponse, + TokenBalance, + TokenDistribution, +) + +# Rate Oracle models +from .rate_oracle import ( + GlobalTokenConfig, + RateOracleConfig, + RateOracleConfigResponse, + RateOracleConfigUpdateRequest, + RateOracleConfigUpdateResponse, + RateOracleSourceConfig, + RateOracleSourceEnum, + RateRequest, + RateResponse, + SingleRateResponse, +) + +# Script models +from .scripts import Script, ScriptConfig, ScriptConfigResponse, ScriptResponse + +# Trading models +from .trading import ( + AccountBalance, + ActiveOrderFilterRequest, + ActiveOrdersResponse, + ConnectorBalance, + FundingPaymentFilterRequest, + OrderFilterRequest, + OrderInfo, + OrderSummary, + OrderTypesResponse, + PortfolioState, + PositionFilterRequest, + TokenInfo, + TradeFilterRequest, + TradeInfo, + TradeRequest, + TradeResponse, + TradingRulesInfo, +) + +__all__ = [ + # Bot orchestration models + "BotAction", + "StartBotAction", + "StopBotAction", + "ImportStrategyAction", + "ConfigureBotAction", + "ShortcutAction", + "BotStatus", + "BotHistoryRequest", + "BotHistoryResponse", + "MQTTStatus", + "AllBotsStatusResponse", + "StopAndArchiveRequest", + "StopAndArchiveResponse", + "V2ControllerDeployment", + # Trading models + "TradeRequest", + "TradeResponse", + "TokenInfo", + "ConnectorBalance", + "AccountBalance", + "PortfolioState", + "OrderInfo", + "ActiveOrdersResponse", + "OrderSummary", + "TradeInfo", + "TradingRulesInfo", + "OrderTypesResponse", + "OrderFilterRequest", + "ActiveOrderFilterRequest", + "PositionFilterRequest", + "FundingPaymentFilterRequest", + "TradeFilterRequest", + # Controller models + "ControllerType", + "Controller", + "ControllerResponse", + "ControllerConfig", + "ControllerConfigResponse", + # Script models + "Script", + "ScriptResponse", + "ScriptConfig", + "ScriptConfigResponse", + # Market data models + "CandleData", + "CandlesResponse", + "ActiveFeedInfo", + "ActiveFeedsResponse", + "MarketDataSettings", + "TradingRulesResponse", + "SupportedOrderTypesResponse", + # New enhanced market data models + "PriceRequest", + "PriceData", + "PricesResponse", + "FundingInfoRequest", + "FundingInfoResponse", + "OrderBookRequest", + "OrderBookLevel", + "OrderBookResponse", + "OrderBookQueryRequest", + "VolumeForPriceRequest", + "PriceForVolumeRequest", + "QuoteVolumeForPriceRequest", + "PriceForQuoteVolumeRequest", + "VWAPForVolumeRequest", + "OrderBookQueryResult", + # Trading pair management models + "AddTradingPairRequest", + "RemoveTradingPairRequest", + "TradingPairResponse", + # Account models + "LeverageRequest", + "PositionModeRequest", + "CredentialRequest", + # Docker models + "DockerImage", + # Gateway models + "GatewayConfig", + "GatewayStatus", + "CreateWalletRequest", + "ShowPrivateKeyRequest", + "SendTransactionRequest", + "GatewayWalletCredential", + "GatewayWalletInfo", + "GatewayBalanceRequest", + "AddPoolRequest", + "AddTokenRequest", + # Backtesting models + "BacktestingConfig", + # Pagination models + "PaginatedResponse", + "PaginationParams", + "TimeRangePaginationParams", + # Connector models + "ConnectorInfo", + "ConnectorConfigMapResponse", + "TradingRule", + "ConnectorTradingRulesResponse", + "ConnectorOrderTypesResponse", + "ConnectorListResponse", + # Gateway Trading models + "SwapQuoteRequest", + "SwapQuoteResponse", + "SwapExecuteRequest", + "SwapExecuteResponse", + "CLMMOpenPositionRequest", + "CLMMOpenPositionResponse", + "CLMMAddLiquidityRequest", + "CLMMRemoveLiquidityRequest", + "CLMMClosePositionRequest", + "CLMMCollectFeesRequest", + "CLMMCollectFeesResponse", + "CLMMPositionsOwnedRequest", + "CLMMPositionInfo", + "CLMMGetPositionInfoRequest", + "CLMMPoolInfoRequest", + "CLMMPoolBin", + "CLMMPoolInfoResponse", + "GetPoolInfoRequest", + "PoolInfo", + "TimeBasedMetrics", + "CLMMPoolListItem", + "CLMMPoolListResponse", + # Portfolio models + "TokenBalance", + "ConnectorBalances", + "AccountPortfolioState", + "PortfolioStateResponse", + "TokenDistribution", + "PortfolioDistributionResponse", + "AccountDistribution", + "AccountsDistributionResponse", + "HistoricalPortfolioState", + "PortfolioHistoryFilters", + # Archived bots models + "OrderStatus", + "DatabaseStatus", + "BotSummary", + "PerformanceMetrics", + "TradeDetail", + "OrderDetail", + "ExecutorInfo", + "ArchivedBotListResponse", + "BotPerformanceResponse", + "TradeHistoryResponse", + "OrderHistoryResponse", + "ExecutorsResponse", + # Rate Oracle models + "RateOracleSourceEnum", + "GlobalTokenConfig", + "RateOracleSourceConfig", + "RateOracleConfig", + "RateOracleConfigResponse", + "RateOracleConfigUpdateRequest", + "RateOracleConfigUpdateResponse", + "RateRequest", + "RateResponse", + "SingleRateResponse", + # Executor models + "CreateExecutorRequest", + "CreateExecutorResponse", + "StopExecutorRequest", + "StopExecutorResponse", + "ExecutorFilterRequest", + "ExecutorResponse", + "ExecutorDetailResponse", + "ExecutorsSummaryResponse", +] diff --git a/models/accounts.py b/models/accounts.py new file mode 100644 index 00000000..4f7b8c2a --- /dev/null +++ b/models/accounts.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel, Field +from typing import Dict, Any + + +class LeverageRequest(BaseModel): + """Request model for setting leverage on perpetual connectors""" + trading_pair: str = Field(description="Trading pair (e.g., BTC-USDT)") + leverage: int = Field(description="Leverage value (typically 1-125)", ge=1, le=125) + + +class PositionModeRequest(BaseModel): + """Request model for setting position mode on perpetual connectors""" + position_mode: str = Field(description="Position mode (HEDGE or ONEWAY)") + + +class CredentialRequest(BaseModel): + """Request model for adding connector credentials""" + credentials: Dict[str, Any] = Field(description="Connector credentials dictionary") \ No newline at end of file diff --git a/models/archived_bots.py b/models/archived_bots.py new file mode 100644 index 00000000..cea57962 --- /dev/null +++ b/models/archived_bots.py @@ -0,0 +1,134 @@ +""" +Pydantic models for the archived bots router. + +These models define the request/response schemas for archived bot analysis endpoints. +""" + +from typing import Dict, List, Optional, Any +from datetime import datetime +from pydantic import BaseModel, Field +from enum import Enum + + +class OrderStatus(str, Enum): + """Order status enumeration""" + OPEN = "OPEN" + FILLED = "FILLED" + CANCELLED = "CANCELLED" + FAILED = "FAILED" + + +class DatabaseStatus(BaseModel): + """Database status information""" + db_path: str = Field(description="Path to the database file") + status: Dict[str, Any] = Field(description="Database health status") + healthy: bool = Field(description="Whether the database is healthy") + + +class BotSummary(BaseModel): + """Summary information for an archived bot""" + bot_name: str = Field(description="Name of the bot") + start_time: Optional[datetime] = Field(default=None, description="Bot start time") + end_time: Optional[datetime] = Field(default=None, description="Bot end time") + total_trades: int = Field(default=0, description="Total number of trades") + total_orders: int = Field(default=0, description="Total number of orders") + markets: List[str] = Field(default_factory=list, description="List of traded markets") + strategies: List[str] = Field(default_factory=list, description="List of strategies used") + + +class PerformanceMetrics(BaseModel): + """Performance metrics for an archived bot""" + total_pnl: float = Field(description="Total profit and loss") + total_volume: float = Field(description="Total trading volume") + avg_return: float = Field(description="Average return per trade") + win_rate: float = Field(description="Percentage of winning trades") + sharpe_ratio: Optional[float] = Field(default=None, description="Sharpe ratio") + max_drawdown: Optional[float] = Field(default=None, description="Maximum drawdown") + total_trades: int = Field(description="Total number of trades") + + +class TradeDetail(BaseModel): + """Detailed trade information""" + id: Optional[int] = Field(default=None, description="Trade ID") + config_file_path: str = Field(description="Configuration file path") + strategy: str = Field(description="Strategy name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + base_asset: str = Field(description="Base asset") + quote_asset: str = Field(description="Quote asset") + timestamp: datetime = Field(description="Trade timestamp") + order_id: str = Field(description="Order ID") + trade_type: str = Field(description="Trade type (BUY/SELL)") + price: float = Field(description="Trade price") + amount: float = Field(description="Trade amount") + trade_fee: Dict[str, float] = Field(description="Trade fees") + exchange_trade_id: str = Field(description="Exchange trade ID") + leverage: Optional[int] = Field(default=None, description="Leverage used") + position: Optional[str] = Field(default=None, description="Position type") + + +class OrderDetail(BaseModel): + """Detailed order information""" + id: Optional[int] = Field(default=None, description="Order ID") + client_order_id: str = Field(description="Client order ID") + exchange_order_id: Optional[str] = Field(default=None, description="Exchange order ID") + trading_pair: str = Field(description="Trading pair") + status: OrderStatus = Field(description="Order status") + order_type: str = Field(description="Order type") + amount: float = Field(description="Order amount") + price: Optional[float] = Field(default=None, description="Order price") + creation_timestamp: datetime = Field(description="Order creation time") + last_update_timestamp: Optional[datetime] = Field(default=None, description="Last update time") + filled_amount: Optional[float] = Field(default=None, description="Filled amount") + leverage: Optional[int] = Field(default=None, description="Leverage used") + position: Optional[str] = Field(default=None, description="Position type") + + +class ExecutorInfo(BaseModel): + """Executor information""" + controller_id: str = Field(description="Controller ID") + timestamp: datetime = Field(description="Timestamp") + type: str = Field(description="Executor type") + controller_config: Dict[str, Any] = Field(description="Controller configuration") + net_pnl_flat: float = Field(description="Net PnL in flat terms") + net_pnl_pct: float = Field(description="Net PnL percentage") + total_executors: int = Field(description="Total number of executors") + total_amount: float = Field(description="Total amount") + total_spent: float = Field(description="Total spent") + + +class ArchivedBotListResponse(BaseModel): + """Response for listing archived bots""" + bots: List[str] = Field(description="List of archived bot database paths") + count: int = Field(description="Total number of archived bots") + + +class BotPerformanceResponse(BaseModel): + """Response for bot performance analysis""" + bot_name: str = Field(description="Bot name") + metrics: PerformanceMetrics = Field(description="Performance metrics") + period_start: Optional[datetime] = Field(default=None, description="Analysis period start") + period_end: Optional[datetime] = Field(default=None, description="Analysis period end") + + +class TradeHistoryResponse(BaseModel): + """Response for trade history""" + trades: List[TradeDetail] = Field(description="List of trades") + total: int = Field(description="Total number of trades") + page: int = Field(description="Current page") + page_size: int = Field(description="Page size") + + +class OrderHistoryResponse(BaseModel): + """Response for order history""" + orders: List[OrderDetail] = Field(description="List of orders") + total: int = Field(description="Total number of orders") + page: int = Field(description="Current page") + page_size: int = Field(description="Page size") + filtered_by_status: Optional[OrderStatus] = Field(default=None, description="Status filter applied") + + +class ExecutorsResponse(BaseModel): + """Response for executors information""" + executors: List[ExecutorInfo] = Field(description="List of executors") + total: int = Field(description="Total number of executors") \ No newline at end of file diff --git a/models/backtesting.py b/models/backtesting.py new file mode 100644 index 00000000..c3cb5bd5 --- /dev/null +++ b/models/backtesting.py @@ -0,0 +1,10 @@ +from typing import Dict, Union +from pydantic import BaseModel + + +class BacktestingConfig(BaseModel): + start_time: int = 1735689600 # 2025-01-01 00:00:00 + end_time: int = 1738368000 # 2025-02-01 00:00:00 + backtesting_resolution: str = "1m" + trade_cost: float = 0.0006 + config: Union[Dict, str] \ No newline at end of file diff --git a/models/bot_orchestration.py b/models/bot_orchestration.py new file mode 100644 index 00000000..d71252b0 --- /dev/null +++ b/models/bot_orchestration.py @@ -0,0 +1,112 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class BotAction(BaseModel): + """Base class for bot actions""" + bot_name: str = Field(description="Name of the bot instance to act upon") + + +class StartBotAction(BotAction): + """Action to start a bot""" + log_level: Optional[str] = Field(default=None, description="Logging level (DEBUG, INFO, WARNING, ERROR)") + script: Optional[str] = Field(default=None, description="Script name to run (without .py extension)") + conf: Optional[str] = Field(default=None, description="Configuration file name (without .yml extension)") + async_backend: bool = Field(default=False, description="Whether to run in async backend mode") + + +class StopBotAction(BotAction): + """Action to stop a bot""" + skip_order_cancellation: bool = Field(default=False, description="Whether to skip cancelling open orders when stopping") + async_backend: bool = Field(default=False, description="Whether to run in async backend mode") + + +class ImportStrategyAction(BotAction): + """Action to import a strategy for a bot""" + strategy: str = Field(description="Name of the strategy to import") + + +class ConfigureBotAction(BotAction): + """Action to configure bot parameters""" + params: dict = Field(description="Configuration parameters to update") + + +class ShortcutAction(BotAction): + """Action to execute bot shortcuts""" + params: list = Field(description="List of shortcut parameters") + + +class BotStatus(BaseModel): + """Status information for a bot""" + bot_name: str = Field(description="Bot name") + status: str = Field(description="Bot status (running, stopped, etc.)") + uptime: Optional[float] = Field(None, description="Bot uptime in seconds") + performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics") + + +class BotHistoryRequest(BaseModel): + """Request for bot trading history""" + bot_name: str = Field(description="Bot name") + days: int = Field(default=0, description="Number of days of history (0 for all)") + verbose: bool = Field(default=False, description="Include verbose information") + precision: Optional[int] = Field(None, description="Decimal precision for numbers") + timeout: float = Field(default=30.0, description="Request timeout in seconds") + + +class BotHistoryResponse(BaseModel): + """Response for bot trading history""" + bot_name: str = Field(description="Bot name") + history: Dict[str, Any] = Field(description="Trading history data") + status: str = Field(description="Response status") + + +class MQTTStatus(BaseModel): + """MQTT connection status""" + mqtt_connected: bool = Field(description="Whether MQTT is connected") + discovered_bots: List[str] = Field(description="List of discovered bots") + active_bots: List[str] = Field(description="List of active bots") + broker_host: str = Field(description="MQTT broker host") + broker_port: int = Field(description="MQTT broker port") + broker_username: Optional[str] = Field(None, description="MQTT broker username") + client_state: str = Field(description="MQTT client state") + + +class AllBotsStatusResponse(BaseModel): + """Response for all bots status""" + bots: List[BotStatus] = Field(description="List of bot statuses") + + +class StopAndArchiveRequest(BaseModel): + """Request for stopping and archiving a bot""" + skip_order_cancellation: bool = Field(default=True, description="Skip order cancellation") + async_backend: bool = Field(default=True, description="Use async backend") + archive_locally: bool = Field(default=True, description="Archive locally") + s3_bucket: Optional[str] = Field(None, description="S3 bucket for archiving") + timeout: float = Field(default=30.0, description="Operation timeout") + + +class StopAndArchiveResponse(BaseModel): + """Response for stop and archive operation""" + status: str = Field(description="Operation status") + message: str = Field(description="Status message") + details: Dict[str, Any] = Field(description="Operation details") + + +# Bot deployment models +class V2ControllerDeployment(BaseModel): + """Configuration for deploying a bot with controllers""" + instance_name: str = Field(description="Unique name for the bot instance") + credentials_profile: str = Field(description="Name of the credentials profile to use") + controllers_config: List[str] = Field( + description="List of controller configuration files to use (without .yml extension)" + ) + max_global_drawdown_quote: Optional[float] = Field( + default=None, description="Maximum allowed global drawdown in quote usually USDT" + ) + max_controller_drawdown_quote: Optional[float] = Field( + default=None, description="Maximum allowed per-controller drawdown in quote usually USDT" + ) + image: str = Field(default="hummingbot/hummingbot:latest", description="Docker image for the Hummingbot instance") + script_config: Optional[str] = Field(default=None, description="Generated script configuration file name") + headless: bool = Field(default=False, description="Run in headless mode (no UI)") diff --git a/models/connectors.py b/models/connectors.py new file mode 100644 index 00000000..eea431fe --- /dev/null +++ b/models/connectors.py @@ -0,0 +1,56 @@ +""" +Pydantic models for the connectors router. + +These models define the request/response schemas for connector-related endpoints. +""" + +from typing import Dict, List, Any, Optional +from pydantic import BaseModel, Field + + +class ConnectorInfo(BaseModel): + """Information about a connector""" + name: str = Field(description="Connector name") + is_perpetual: bool = Field(default=False, description="Whether the connector supports perpetual trading") + supported_order_types: Optional[List[str]] = Field(default=None, description="Supported order types") + + +class ConnectorConfigMapResponse(BaseModel): + """Response for connector configuration requirements""" + connector_name: str = Field(description="Name of the connector") + config_fields: List[str] = Field(description="List of required configuration fields") + + +class TradingRule(BaseModel): + """Trading rules for a specific trading pair""" + min_order_size: float = Field(description="Minimum order size") + max_order_size: float = Field(description="Maximum order size") + min_price_increment: float = Field(description="Minimum price increment") + min_base_amount_increment: float = Field(description="Minimum base amount increment") + min_quote_amount_increment: float = Field(description="Minimum quote amount increment") + min_notional_size: float = Field(description="Minimum notional size") + min_order_value: float = Field(description="Minimum order value") + max_price_significant_digits: float = Field(description="Maximum price significant digits") + supports_limit_orders: bool = Field(description="Whether limit orders are supported") + supports_market_orders: bool = Field(description="Whether market orders are supported") + buy_order_collateral_token: str = Field(description="Collateral token for buy orders") + sell_order_collateral_token: str = Field(description="Collateral token for sell orders") + + +class ConnectorTradingRulesResponse(BaseModel): + """Response for connector trading rules""" + connector: str = Field(description="Connector name") + trading_pairs: Optional[List[str]] = Field(default=None, description="Filtered trading pairs if provided") + rules: Dict[str, TradingRule] = Field(description="Trading rules by trading pair") + + +class ConnectorOrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +class ConnectorListResponse(BaseModel): + """Response for list of available connectors""" + connectors: List[str] = Field(description="List of available connector names") + count: int = Field(description="Total number of connectors") \ No newline at end of file diff --git a/models/controllers.py b/models/controllers.py new file mode 100644 index 00000000..a2eeff82 --- /dev/null +++ b/models/controllers.py @@ -0,0 +1,52 @@ +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field +from enum import Enum + + +class ControllerType(str, Enum): + """Types of controllers available""" + DIRECTIONAL_TRADING = "directional_trading" + MARKET_MAKING = "market_making" + GENERIC = "generic" + + +# Controller file operations +class Controller(BaseModel): + """Controller file content""" + content: str = Field(description="Controller source code") + type: Optional[ControllerType] = Field(None, description="Controller type (optional for flexibility)") + + +class ControllerResponse(BaseModel): + """Response for getting a controller""" + name: str = Field(description="Controller name") + type: str = Field(description="Controller type") + content: str = Field(description="Controller source code") + + +# Controller configuration operations +class ControllerConfig(BaseModel): + """Controller configuration""" + controller_name: str = Field(description="Controller name") + controller_type: str = Field(description="Controller type") + connector_name: Optional[str] = Field(None, description="Connector name") + trading_pair: Optional[str] = Field(None, description="Trading pair") + total_amount_quote: Optional[float] = Field(None, description="Total amount in quote currency") + + +class ControllerConfigResponse(BaseModel): + """Response for controller configuration with metadata""" + config_name: str = Field(description="Configuration name") + controller_name: str = Field(description="Controller name") + controller_type: str = Field(description="Controller type") + connector_name: Optional[str] = Field(None, description="Connector name") + trading_pair: Optional[str] = Field(None, description="Trading pair") + total_amount_quote: Optional[float] = Field(None, description="Total amount in quote currency") + error: Optional[str] = Field(None, description="Error message if config is malformed") + + +# Bot-specific controller configurations +class BotControllerConfig(BaseModel): + """Controller configuration for a specific bot""" + config_name: str = Field(description="Configuration name") + config_data: Dict[str, Any] = Field(description="Configuration data") \ No newline at end of file diff --git a/models/docker.py b/models/docker.py new file mode 100644 index 00000000..b18fb768 --- /dev/null +++ b/models/docker.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel, Field + + +class DockerImage(BaseModel): + image_name: str = Field(description="Docker image name with optional tag (e.g., 'hummingbot/hummingbot:latest')") \ No newline at end of file diff --git a/models/executors.py b/models/executors.py new file mode 100644 index 00000000..1c1f1941 --- /dev/null +++ b/models/executors.py @@ -0,0 +1,471 @@ +""" +Pydantic models for executor API endpoints. + +These models wrap Hummingbot's executor configuration types and provide +validation for the REST API. +""" +from datetime import datetime +from decimal import Decimal +from typing import Any, Dict, List, Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field, computed_field + +from .pagination import PaginationParams + +# ======================================== +# Position Hold for Aggregated Tracking +# ======================================== + + +class PositionHold(BaseModel): + """ + Tracks aggregated position from executors stopped with keep_position=True. + + Similar to hummingbot's PositionHold, this tracks: + - Separate buy/sell amounts for proper breakeven calculation + - Matched volume (realized PnL) vs unmatched volume (unrealized PnL) + - Aggregation across multiple executors on the same trading pair + """ + model_config = ConfigDict(arbitrary_types_allowed=True) + + trading_pair: str = Field(description="Trading pair (e.g., 'BTC-USDT')") + connector_name: str = Field(description="Connector name") + account_name: str = Field(description="Account name") + controller_id: str = Field(default="main", description="Controller that owns this position") + + # Buy side tracking + buy_amount_base: Decimal = Field(default=Decimal("0"), description="Total bought amount in base currency") + buy_amount_quote: Decimal = Field(default=Decimal("0"), description="Total spent on buys in quote currency") + + # Sell side tracking + sell_amount_base: Decimal = Field(default=Decimal("0"), description="Total sold amount in base currency") + sell_amount_quote: Decimal = Field(default=Decimal("0"), description="Total received from sells in quote currency") + + # Realized PnL from matched positions + realized_pnl_quote: Decimal = Field(default=Decimal("0"), description="Realized PnL from matched buy/sell pairs") + + # Tracking + executor_ids: List[str] = Field(default_factory=list, description="IDs of executors contributing to this position") + last_updated: Optional[datetime] = Field(default=None, description="Last update timestamp") + + @computed_field + @property + def net_amount_base(self) -> Decimal: + """Net position in base currency (positive = long, negative = short).""" + return self.buy_amount_base - self.sell_amount_base + + @computed_field + @property + def buy_breakeven_price(self) -> Optional[Decimal]: + """Average buy price (breakeven for long position).""" + if self.buy_amount_base > 0: + return self.buy_amount_quote / self.buy_amount_base + return None + + @computed_field + @property + def sell_breakeven_price(self) -> Optional[Decimal]: + """Average sell price (breakeven for short position).""" + if self.sell_amount_base > 0: + return self.sell_amount_quote / self.sell_amount_base + return None + + @computed_field + @property + def matched_amount_base(self) -> Decimal: + """Amount that has been matched (min of buy/sell).""" + return min(self.buy_amount_base, self.sell_amount_base) + + @computed_field + @property + def unmatched_amount_base(self) -> Decimal: + """Absolute unmatched position size.""" + return abs(self.net_amount_base) + + @computed_field + @property + def position_side(self) -> Optional[str]: + """Current position side: LONG, SHORT, or FLAT.""" + if self.net_amount_base > 0: + return "LONG" + elif self.net_amount_base < 0: + return "SHORT" + return "FLAT" + + def add_fill( + self, + side: str, + amount_base: Decimal, + amount_quote: Decimal, + executor_id: Optional[str] = None + ): + """ + Add a fill to the position tracking. + + Args: + side: "BUY" or "SELL" + amount_base: Amount in base currency + amount_quote: Amount in quote currency + executor_id: Optional executor ID to track + """ + if side.upper() == "BUY": + self.buy_amount_base += amount_base + self.buy_amount_quote += amount_quote + else: + self.sell_amount_base += amount_base + self.sell_amount_quote += amount_quote + + # Calculate realized PnL when we have matched volume + self._calculate_realized_pnl() + + if executor_id and executor_id not in self.executor_ids: + self.executor_ids.append(executor_id) + + self.last_updated = datetime.utcnow() + + def _calculate_realized_pnl(self): + """Calculate realized PnL from matched buy/sell pairs using FIFO.""" + matched = self.matched_amount_base + if matched > 0 and self.buy_amount_base > 0 and self.sell_amount_base > 0: + # Average prices + avg_buy = self.buy_amount_quote / self.buy_amount_base + avg_sell = self.sell_amount_quote / self.sell_amount_base + # Realized PnL = matched_amount * (avg_sell - avg_buy) + self.realized_pnl_quote = matched * (avg_sell - avg_buy) + + def get_unrealized_pnl(self, current_price: Decimal) -> Decimal: + """ + Calculate unrealized PnL for unmatched position. + + Args: + current_price: Current market price + + Returns: + Unrealized PnL in quote currency + """ + if self.net_amount_base > 0: + # Long position: profit if price goes up + avg_buy = self.buy_breakeven_price or Decimal("0") + return self.net_amount_base * (current_price - avg_buy) + elif self.net_amount_base < 0: + # Short position: profit if price goes down + avg_sell = self.sell_breakeven_price or Decimal("0") + return abs(self.net_amount_base) * (avg_sell - current_price) + return Decimal("0") + + def merge(self, other: "PositionHold"): + """Merge another PositionHold into this one.""" + self.buy_amount_base += other.buy_amount_base + self.buy_amount_quote += other.buy_amount_quote + self.sell_amount_base += other.sell_amount_base + self.sell_amount_quote += other.sell_amount_quote + + for eid in other.executor_ids: + if eid not in self.executor_ids: + self.executor_ids.append(eid) + + self._calculate_realized_pnl() + self.last_updated = datetime.utcnow() + + +class PositionHoldResponse(BaseModel): + """API response model for PositionHold.""" + trading_pair: str + connector_name: str + account_name: str + controller_id: str = Field(default="main", description="Controller that owns this position") + buy_amount_base: float + buy_amount_quote: float + sell_amount_base: float + sell_amount_quote: float + net_amount_base: float + buy_breakeven_price: Optional[float] + sell_breakeven_price: Optional[float] + matched_amount_base: float + unmatched_amount_base: float + position_side: Optional[str] + realized_pnl_quote: float + unrealized_pnl_quote: Optional[float] = None + executor_count: int + executor_ids: List[str] + last_updated: Optional[str] + + +class PositionsSummaryResponse(BaseModel): + """Summary of all held positions.""" + total_positions: int = Field(description="Number of active position holds") + total_realized_pnl: float = Field(description="Total realized PnL across all positions") + total_unrealized_pnl: Optional[float] = Field( + default=None, description="Total unrealized PnL (None if no rates available)" + ) + positions: List[PositionHoldResponse] = Field(description="List of position holds") + + +# ======================================== +# Executor Type Definitions +# ======================================== + +EXECUTOR_TYPES = Literal[ + "position_executor", + "grid_executor", + "dca_executor", + "arbitrage_executor", + "twap_executor", + "xemm_executor", + "order_executor", + "lp_executor" +] + + +# ======================================== +# API Request Models +# ======================================== + +class CreateExecutorRequest(BaseModel): + """Request to create a new executor.""" + model_config = ConfigDict( + json_schema_extra={ + "examples": [ + { + "summary": "Position Executor", + "description": "Create a position executor with triple barrier", + "value": { + "account_name": "master_account", + "executor_config": { + "type": "position_executor", + "connector_name": "binance_perpetual", + "trading_pair": "BTC-USDT", + "side": "BUY", + "amount": "0.01", + "leverage": 10, + "triple_barrier_config": { + "stop_loss": "0.02", + "take_profit": "0.04", + "time_limit": 3600 + } + } + } + }, + { + "summary": "LP Executor", + "description": "Create an LP position on a CLMM DEX (Meteora, Raydium)", + "value": { + "account_name": "master_account", + "executor_config": { + "type": "lp_executor", + "connector_name": "meteora/clmm", + "trading_pair": "SOL-USDC", + "pool_address": "HTvjzsfX3yU6BUodCjZ5vZkUrAxMDTrBs3CJaq43ashR", + "lower_price": "80", + "upper_price": "100", + "base_amount": "0", + "quote_amount": "10.0", + "side": 1, + "auto_close_above_range_seconds": None, + "auto_close_below_range_seconds": 300, + "extra_params": {"strategyType": 0}, + "keep_position": False + } + } + } + ] + } + ) + + account_name: Optional[str] = Field( + None, + description="Account name to use (defaults to master_account)" + ) + controller_id: str = Field( + default="main", + description="Controller ID that owns this executor (for per-agent isolation)" + ) + executor_config: Dict[str, Any] = Field( + ..., + description="Executor configuration. Must include 'type' field and executor-specific parameters." + ) + + +class StopExecutorRequest(BaseModel): + """Request to stop an executor.""" + keep_position: bool = Field( + default=False, + description="Whether to keep the position open (for position executors)" + ) + + +class ExecutorFilterRequest(PaginationParams): + """Request to filter and list executors.""" + account_names: Optional[List[str]] = Field( + None, + description="Filter by account names" + ) + connector_names: Optional[List[str]] = Field( + None, + description="Filter by connector names" + ) + trading_pairs: Optional[List[str]] = Field( + None, + description="Filter by trading pairs" + ) + executor_types: Optional[List[EXECUTOR_TYPES]] = Field( + None, + description="Filter by executor types" + ) + status: Optional[str] = Field( + None, + description="Filter by status (RUNNING, TERMINATED, etc.)" + ) + controller_ids: Optional[List[str]] = Field( + None, + description="Filter by controller IDs" + ) + + +# ======================================== +# API Response Models +# ======================================== + +class ExecutorResponse(BaseModel): + """Response for a single executor (summary view).""" + model_config = ConfigDict( + json_schema_extra={ + "example": { + "executor_id": "abc123...", + "executor_type": "position_executor", + "account_name": "master_account", + "connector_name": "binance_perpetual", + "trading_pair": "BTC-USDT", + "side": "BUY", + "status": "RUNNING", + "is_active": True, + "is_trading": True, + "timestamp": 1705315800.0, + "created_at": "2024-01-15T10:30:00Z", + "close_type": None, + "close_timestamp": None, + "controller_id": None, + "net_pnl_quote": 125.50, + "net_pnl_pct": 2.5, + "cum_fees_quote": 1.25, + "filled_amount_quote": 5000.0 + } + } + ) + + executor_id: str = Field(description="Unique executor identifier") + executor_type: Optional[str] = Field(description="Type of executor") + account_name: Optional[str] = Field(description="Account name") + connector_name: Optional[str] = Field(description="Connector name") + trading_pair: Optional[str] = Field(description="Trading pair") + side: Optional[str] = Field(None, description="Trade side (BUY/SELL) if applicable") + status: str = Field(description="Current status (RUNNING, TERMINATED, etc.)") + is_active: bool = Field(description="Whether the executor is active") + is_trading: bool = Field(description="Whether the executor has open trades") + timestamp: Optional[float] = Field(None, description="Creation timestamp (Unix)") + created_at: Optional[str] = Field(None, description="Creation timestamp (ISO format)") + close_type: Optional[str] = Field(None, description="How the executor was closed (if applicable)") + close_timestamp: Optional[float] = Field(None, description="Close timestamp (Unix)") + controller_id: Optional[str] = Field(None, description="ID of the controller that spawned this executor") + net_pnl_quote: float = Field(description="Net PnL in quote currency") + net_pnl_pct: float = Field(description="Net PnL percentage") + cum_fees_quote: float = Field(description="Cumulative fees in quote currency") + filled_amount_quote: float = Field(description="Total filled amount in quote currency") + error_count: int = Field(default=0, description="Number of ERROR-level log entries captured") + last_error: Optional[str] = Field(default=None, description="Most recent error message, if any") + + +class ExecutorDetailResponse(ExecutorResponse): + """Detailed response for a single executor.""" + config: Optional[Dict[str, Any]] = Field( + None, + description="Full executor configuration" + ) + custom_info: Optional[Dict[str, Any]] = Field( + None, + description="Executor-specific custom information" + ) + + +class CreateExecutorResponse(BaseModel): + """Response after creating an executor.""" + executor_id: str = Field(description="Unique executor identifier") + executor_type: str = Field(description="Type of executor created") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + controller_id: str = Field(default="main", description="Controller that owns this executor") + status: str = Field(description="Initial status") + created_at: str = Field(description="Creation timestamp (ISO format)") + + +class StopExecutorResponse(BaseModel): + """Response after stopping an executor.""" + executor_id: str = Field(description="Executor identifier") + status: str = Field(description="New status (usually 'stopping')") + keep_position: bool = Field(description="Whether position was kept open") + + +class ExecutorsSummaryResponse(BaseModel): + """Summary of active executors.""" + model_config = ConfigDict( + json_schema_extra={ + "example": { + "total_active": 5, + "total_pnl_quote": 1234.56, + "total_volume_quote": 50000.00, + "by_type": {"position_executor": 3, "grid_executor": 2}, + "by_connector": {"binance_perpetual": 4, "binance": 1}, + "by_status": {"RUNNING": 5} + } + } + ) + + total_active: int = Field(description="Number of active executors") + total_pnl_quote: float = Field(description="Total PnL across active executors") + total_volume_quote: float = Field(description="Total volume across active executors") + by_type: Dict[str, int] = Field(description="Executor count by type") + by_connector: Dict[str, int] = Field(description="Executor count by connector") + by_status: Dict[str, int] = Field(description="Executor count by status") + + +class ExecutorTypeBreakdown(BaseModel): + """Performance breakdown for a single executor type.""" + executor_type: str = Field(description="Executor type name") + total: int = Field(description="Total executors of this type") + completed: int = Field(description="Completed executors") + running: int = Field(description="Currently running executors") + pnl_quote: float = Field(description="Net PnL in quote currency") + volume_quote: float = Field(description="Total filled volume in quote currency") + fees_quote: float = Field(description="Cumulative fees in quote currency") + + +class PerformanceReportResponse(BaseModel): + """Performance report for executors, optionally filtered by controller_id.""" + controller_id: Optional[str] = Field(None, description="Controller ID filter (None = all)") + total_executors: int = Field(description="Total executor count") + by_status: Dict[str, int] = Field(description="Executor count by status") + pnl_total_quote: float = Field(description="Realized PnL from completed executors in quote currency") + unrealized_pnl_quote: float = Field(description="Unrealized PnL from active executors and position holds") + global_pnl_quote: float = Field(description="Global PnL (realized + unrealized)") + pnl_pct_avg: float = Field(description="Average PnL percentage across completed executors") + fees_total_quote: float = Field(description="Total cumulative fees in quote currency") + volume_total_quote: float = Field(description="Total filled volume in quote currency") + win_rate: float = Field(description="Win rate: fraction of completed executors with positive PnL") + sharpe_ratio: Optional[float] = Field(None, description="Sharpe ratio of PnL returns (null if <2 executors)") + by_type: List[ExecutorTypeBreakdown] = Field(description="Performance breakdown by executor type") + active_positions: int = Field(description="Number of active position holds") + + +class ExecutorLogEntry(BaseModel): + """A single log entry from an executor.""" + timestamp: str = Field(description="ISO-format timestamp") + level: str = Field(description="Log level (DEBUG, INFO, WARNING, ERROR)") + message: str = Field(description="Log message") + exc_info: Optional[str] = Field(default=None, description="Exception traceback if present") + + +class ExecutorLogsResponse(BaseModel): + """Response for executor log entries.""" + executor_id: str = Field(description="Executor identifier") + logs: List[ExecutorLogEntry] = Field(description="Log entries") + total_count: int = Field(description="Total number of log entries (before limit)") diff --git a/models/gateway.py b/models/gateway.py new file mode 100644 index 00000000..11b97132 --- /dev/null +++ b/models/gateway.py @@ -0,0 +1,99 @@ +from pydantic import BaseModel, Field +from typing import Optional, List + + +# ============================================ +# Container Management Models +# ============================================ + +class GatewayConfig(BaseModel): + """Configuration for Gateway container deployment""" + passphrase: str = Field(description="Gateway passphrase for configuration encryption") + image: str = Field(default="hummingbot/gateway:latest", description="Docker image for Gateway") + port: int = Field(default=15888, description="Port for Gateway API") + dev_mode: bool = Field(default=True, description="Enable development mode") + + +class GatewayStatus(BaseModel): + """Status information for Gateway instance""" + running: bool = Field(description="Whether Gateway container is running") + container_id: Optional[str] = Field(default=None, description="Container ID if running") + image: Optional[str] = Field(default=None, description="Image used for the container") + created_at: Optional[str] = Field(default=None, description="Container creation timestamp") + port: Optional[int] = Field(default=None, description="Port Gateway is running on") + + +# ============================================ +# Wallet Management Models +# ============================================ + +class CreateWalletRequest(BaseModel): + """Request to create a new wallet in Gateway""" + chain: str = Field(description="Blockchain chain (e.g., 'solana', 'ethereum')") + set_default: bool = Field(default=True, description="Set as default wallet for this chain") + + +class ShowPrivateKeyRequest(BaseModel): + """Request to show private key for a wallet""" + chain: str = Field(description="Blockchain chain (e.g., 'solana', 'ethereum')") + address: str = Field(description="Wallet address") + passphrase: str = Field(description="Gateway passphrase for decryption") + + +class SendTransactionRequest(BaseModel): + """Request to send a native token transaction""" + chain: str = Field(description="Blockchain chain (e.g., 'solana', 'ethereum')") + network: str = Field(description="Network (e.g., 'mainnet-beta', 'mainnet')") + address: str = Field(description="Sender wallet address") + to_address: str = Field(description="Recipient address") + amount: str = Field(description="Amount to send (in native token units)") + + +class GatewayWalletCredential(BaseModel): + """Credentials for connecting a Gateway wallet""" + chain: str = Field(description="Blockchain chain (e.g., 'solana', 'ethereum')") + private_key: str = Field(description="Wallet private key") + network: Optional[str] = Field(default=None, description="Network to use (defaults to chain's default)") + + +class GatewayWalletInfo(BaseModel): + """Information about a connected Gateway wallet""" + chain: str = Field(description="Blockchain chain") + address: str = Field(description="Wallet address") + network: str = Field(description="Network the wallet is configured for") + + +# ============================================ +# Pool and Token Management Models +# ============================================ + +class AddPoolRequest(BaseModel): + """Request to add a liquidity pool""" + connector_name: str = Field(description="DEX connector name (e.g., 'raydium', 'meteora')") + type: str = Field(description="Pool type ('clmm' or 'amm')") + network: str = Field(description="Network name (e.g., 'mainnet-beta')") + address: str = Field(description="Pool contract address") + base: str = Field(description="Base token symbol") + quote: str = Field(description="Quote token symbol") + base_address: str = Field(description="Base token contract address") + quote_address: str = Field(description="Quote token contract address") + fee_pct: Optional[float] = Field(default=None, description="Pool fee percentage (e.g., 0.25)") + + +class AddTokenRequest(BaseModel): + """Request to add a custom token to Gateway""" + address: str = Field(description="Token contract address") + symbol: str = Field(description="Token symbol") + name: Optional[str] = Field(default=None, description="Token name (defaults to symbol)") + decimals: int = Field(description="Number of decimals for the token") + + +# ============================================ +# Balance Query Models +# ============================================ + +class GatewayBalanceRequest(BaseModel): + """Request for Gateway wallet balances""" + account_name: str = Field(description="Account name") + chain: str = Field(description="Blockchain chain") + tokens: Optional[List[str]] = Field(default=None, description="List of token symbols to query (optional)") diff --git a/models/gateway_trading.py b/models/gateway_trading.py new file mode 100644 index 00000000..02d7c422 --- /dev/null +++ b/models/gateway_trading.py @@ -0,0 +1,335 @@ +""" +Models for Gateway DEX trading operations. +Supports swaps via routers (Jupiter, 0x) and CLMM liquidity positions (Meteora, Raydium, Uniswap V3). + +Note: AMM support has been removed. Use Router for simple swaps, CLMM for liquidity provision. +""" +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field +from decimal import Decimal + + +# ============================================ +# Swap Models (Router: Jupiter, 0x) +# ============================================ + +class SwapQuoteRequest(BaseModel): + """Request for swap price quote""" + connector: str = Field(description="DEX router connector (e.g., 'jupiter', '0x')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta', 'ethereum-mainnet')") + trading_pair: str = Field(description="Trading pair in BASE-QUOTE format (e.g., 'SOL-USDC')") + side: str = Field(description="Trade side: 'BUY' or 'SELL'") + amount: Decimal = Field(description="Amount to swap (in base token for SELL, quote token for BUY)") + slippage_pct: Optional[Decimal] = Field(default=1.0, description="Maximum slippage percentage (default: 1.0)") + + +class SwapQuoteResponse(BaseModel): + """Response with swap quote details""" + base: str = Field(description="Base token symbol") + quote: str = Field(description="Quote token symbol") + price: Decimal = Field(description="Quoted price (base/quote)") + amount: Decimal = Field(description="Amount specified in request (BUY: base amount to receive, SELL: base amount to sell)") + amount_in: Optional[Decimal] = Field(default=None, description="Actual input amount (BUY: quote to spend, SELL: base to sell)") + amount_out: Optional[Decimal] = Field(default=None, description="Actual output amount (BUY: base to receive, SELL: quote to receive)") + expected_amount: Optional[Decimal] = Field(default=None, description="Deprecated: use amount_out instead") + slippage_pct: Decimal = Field(description="Applied slippage percentage") + gas_estimate: Optional[Decimal] = Field(default=None, description="Estimated gas cost") + + +class SwapExecuteRequest(BaseModel): + """Request to execute a swap""" + connector: str = Field(description="DEX router connector (e.g., 'jupiter', '0x')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + trading_pair: str = Field(description="Trading pair (e.g., 'SOL-USDC')") + side: str = Field(description="Trade side: 'BUY' or 'SELL'") + amount: Decimal = Field(description="Amount to swap") + slippage_pct: Optional[Decimal] = Field(default=1.0, description="Maximum slippage percentage (default: 1.0)") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class SwapExecuteResponse(BaseModel): + """Response after executing swap""" + transaction_hash: str = Field(description="Transaction hash") + trading_pair: str = Field(description="Trading pair") + side: str = Field(description="Trade side") + amount: Decimal = Field(description="Amount swapped") + status: str = Field(default="submitted", description="Transaction status") + + +# ============================================ +# CLMM Liquidity Models (Meteora, Raydium, Uniswap V3) +# ============================================ + +class CLMMOpenPositionRequest(BaseModel): + """Request to open a new CLMM position with initial liquidity""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + pool_address: str = Field(description="Pool contract address") + + # Position range + lower_price: Decimal = Field(description="Lower price for position range") + upper_price: Decimal = Field(description="Upper price for position range") + + # Initial liquidity + base_token_amount: Optional[Decimal] = Field(default=None, description="Amount of base token to add") + quote_token_amount: Optional[Decimal] = Field(default=None, description="Amount of quote token to add") + slippage_pct: Optional[Decimal] = Field(default=1.0, description="Maximum slippage percentage (default: 1.0)") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + # Connector-specific parameters (e.g., strategyType for Meteora) + extra_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional connector-specific parameters") + + +class CLMMOpenPositionResponse(BaseModel): + """Response after opening a new CLMM position""" + transaction_hash: str = Field(description="Transaction hash") + position_address: str = Field(description="Address of the newly created position") + trading_pair: str = Field(description="Trading pair") + pool_address: str = Field(description="Pool address") + lower_price: Decimal = Field(description="Lower price bound") + upper_price: Decimal = Field(description="Upper price bound") + status: str = Field(default="submitted", description="Transaction status") + + +class CLMMAddLiquidityRequest(BaseModel): + """Request to add MORE liquidity to an EXISTING CLMM position""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + position_address: str = Field(description="Existing position address to add liquidity to") + base_token_amount: Optional[Decimal] = Field(default=None, description="Amount of base token to add") + quote_token_amount: Optional[Decimal] = Field(default=None, description="Amount of quote token to add") + slippage_pct: Optional[Decimal] = Field(default=1.0, description="Maximum slippage percentage (default: 1.0)") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class CLMMRemoveLiquidityRequest(BaseModel): + """Request to remove SOME liquidity from a CLMM position (partial removal)""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + position_address: str = Field(description="Position address to remove liquidity from") + percentage: Decimal = Field(description="Percentage of liquidity to remove (0-100)") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class CLMMClosePositionRequest(BaseModel): + """Request to CLOSE a CLMM position completely (removes all liquidity and closes position)""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + position_address: str = Field(description="Position address to close") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class CLMMCollectFeesRequest(BaseModel): + """Request to collect fees from a CLMM position""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + position_address: str = Field(description="Position address to collect fees from") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class CLMMCollectFeesResponse(BaseModel): + """Response after collecting fees""" + transaction_hash: str = Field(description="Transaction hash") + position_address: str = Field(description="Position address") + base_fee_collected: Optional[Decimal] = Field(default=None, description="Base token fees collected") + quote_fee_collected: Optional[Decimal] = Field(default=None, description="Quote token fees collected") + status: str = Field(default="submitted", description="Transaction status") + + +class CLMMPositionsOwnedRequest(BaseModel): + """Request to get all CLMM positions owned by a wallet for a specific pool""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + pool_address: str = Field(description="Pool contract address to filter positions") + wallet_address: Optional[str] = Field(default=None, description="Wallet address (optional, uses default if not provided)") + + +class CLMMPositionInfo(BaseModel): + """Information about a CLMM liquidity position""" + position_address: str = Field(description="Position address") + pool_address: str = Field(description="Pool address") + trading_pair: str = Field(description="Trading pair") + base_token: str = Field(description="Base token symbol") + quote_token: str = Field(description="Quote token symbol") + base_token_amount: Decimal = Field(description="Base token amount in position") + quote_token_amount: Decimal = Field(description="Quote token amount in position") + current_price: Decimal = Field(description="Current pool price") + lower_price: Decimal = Field(description="Lower price bound") + upper_price: Decimal = Field(description="Upper price bound") + base_fee_amount: Optional[Decimal] = Field(default=None, description="Base token uncollected fees") + quote_fee_amount: Optional[Decimal] = Field(default=None, description="Quote token uncollected fees") + lower_bin_id: Optional[int] = Field(default=None, description="Lower bin ID (Meteora)") + upper_bin_id: Optional[int] = Field(default=None, description="Upper bin ID (Meteora)") + in_range: bool = Field(description="Whether position is currently in range") + + +class CLMMGetPositionInfoRequest(BaseModel): + """Request to get detailed info about a specific CLMM position""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium', 'uniswap')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + position_address: str = Field(description="Position address to query") + + +class CLMMPoolInfoRequest(BaseModel): + """Request to get CLMM pool information by pool address""" + connector: str = Field(description="CLMM connector (e.g., 'meteora', 'raydium')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + pool_address: str = Field(description="Pool contract address") + + +class CLMMPoolBin(BaseModel): + """Individual bin in a CLMM pool (e.g., Meteora)""" + bin_id: int = Field(alias="binId", description="Bin identifier") + price: Decimal = Field(description="Price at this bin") + base_token_amount: Decimal = Field(alias="baseTokenAmount", description="Base token amount in bin") + quote_token_amount: Decimal = Field(alias="quoteTokenAmount", description="Quote token amount in bin") + + model_config = { + "populate_by_name": True, + "json_schema_extra": { + "example": { + "bin_id": -374, + "price": 0.47366592950616504, + "base_token_amount": 19656.740028, + "quote_token_amount": 18197.718539 + } + } + } + + +class CLMMPoolInfoResponse(BaseModel): + """Response with detailed CLMM pool information""" + address: str = Field(description="Pool address") + base_token_address: str = Field(alias="baseTokenAddress", description="Base token contract address") + quote_token_address: str = Field(alias="quoteTokenAddress", description="Quote token contract address") + bin_step: Optional[int] = Field(None, alias="binStep", description="Bin step (Meteora DLMM only)") + fee_pct: Decimal = Field(alias="feePct", description="Pool fee percentage") + price: Decimal = Field(description="Current pool price") + base_token_amount: Decimal = Field(alias="baseTokenAmount", description="Total base token liquidity") + quote_token_amount: Decimal = Field(alias="quoteTokenAmount", description="Total quote token liquidity") + active_bin_id: Optional[int] = Field(None, alias="activeBinId", description="Currently active bin ID (Meteora DLMM only)") + dynamic_fee_pct: Optional[Decimal] = Field(None, alias="dynamicFeePct", description="Dynamic fee percentage") + min_bin_id: Optional[int] = Field(None, alias="minBinId", description="Minimum bin ID (Meteora-specific)") + max_bin_id: Optional[int] = Field(None, alias="maxBinId", description="Maximum bin ID (Meteora-specific)") + bins: List[CLMMPoolBin] = Field(default_factory=list, description="List of bins with liquidity") + + model_config = { + "populate_by_name": True, + "json_schema_extra": { + "example": { + "address": "5hbf9JP8k5zdrZp9pokPypFQoBse5mGCmW6nqodurGcd", + "base_token_address": "METvsvVRapdj9cFLzq4Tr43xK4tAjQfwX76z3n6mWQL", + "quote_token_address": "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v", + "bin_step": 20, + "fee_pct": 0.2, + "price": 0.47366592950616504, + "base_token_amount": 8645709.142366, + "quote_token_amount": 1095942.335132, + "active_bin_id": -374, + "dynamic_fee_pct": 0.2, + "min_bin_id": -21835, + "max_bin_id": 21835, + "bins": [] + } + } + } + + +# ============================================ +# Pool Information Models +# ============================================ + +class GetPoolInfoRequest(BaseModel): + """Request to get pool information""" + connector: str = Field(description="DEX connector (e.g., 'meteora', 'raydium', 'jupiter')") + network: str = Field(description="Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta')") + trading_pair: str = Field(description="Trading pair (e.g., 'SOL-USDC')") + + +class PoolInfo(BaseModel): + """Information about a liquidity pool""" + type: str = Field(description="Pool type: 'clmm' or 'router'") + address: str = Field(description="Pool address") + trading_pair: str = Field(description="Trading pair") + base_token: str = Field(description="Base token symbol") + quote_token: str = Field(description="Quote token symbol") + current_price: Decimal = Field(description="Current pool price") + base_token_amount: Decimal = Field(description="Base token liquidity in pool") + quote_token_amount: Decimal = Field(description="Quote token liquidity in pool") + fee_pct: Decimal = Field(description="Pool fee percentage") + + # CLMM-specific + bin_step: Optional[int] = Field(default=None, description="Bin step (CLMM)") + active_bin_id: Optional[int] = Field(default=None, description="Active bin ID (CLMM)") + + +# ============================================ +# CLMM Pool Listing Models +# ============================================ + +class TimeBasedMetrics(BaseModel): + """Time-based metrics (volume, fees, fee-to-TVL ratio) for different time periods""" + min_30: Optional[Decimal] = Field(default=None, description="30 minute metric") + hour_1: Optional[Decimal] = Field(default=None, description="1 hour metric") + hour_2: Optional[Decimal] = Field(default=None, description="2 hour metric") + hour_4: Optional[Decimal] = Field(default=None, description="4 hour metric") + hour_12: Optional[Decimal] = Field(default=None, description="12 hour metric") + hour_24: Optional[Decimal] = Field(default=None, description="24 hour metric") + + +class CLMMPoolListItem(BaseModel): + """Individual pool item in CLMM pool listing""" + address: str = Field(description="Pool address") + name: str = Field(description="Pool name (e.g., 'SOL-USDC')") + trading_pair: str = Field(description="Trading pair derived from tokens") + mint_x: str = Field(description="Base token mint address") + mint_y: str = Field(description="Quote token mint address") + bin_step: int = Field(description="Bin step size") + current_price: Decimal = Field(description="Current pool price") + liquidity: str = Field(description="Total liquidity in pool") + reserve_x: str = Field(description="Base token reserves") + reserve_y: str = Field(description="Quote token reserves") + reserve_x_amount: Optional[Decimal] = Field(default=None, description="Base token reserves as decimal amount") + reserve_y_amount: Optional[Decimal] = Field(default=None, description="Quote token reserves as decimal amount") + + # Fee structure + base_fee_percentage: Optional[str] = Field(default=None, description="Base fee percentage") + max_fee_percentage: Optional[str] = Field(default=None, description="Maximum fee percentage") + protocol_fee_percentage: Optional[str] = Field(default=None, description="Protocol fee percentage") + + # APR/APY + apr: Optional[Decimal] = Field(default=None, description="Annual percentage rate") + apy: Optional[Decimal] = Field(default=None, description="Annual percentage yield") + farm_apr: Optional[Decimal] = Field(default=None, description="Farming annual percentage rate") + farm_apy: Optional[Decimal] = Field(default=None, description="Farming annual percentage yield") + + # Volume and fees + volume_24h: Optional[Decimal] = Field(default=None, description="24h trading volume") + fees_24h: Optional[Decimal] = Field(default=None, description="24h fees collected") + today_fees: Optional[Decimal] = Field(default=None, description="Today's fees collected") + cumulative_trade_volume: Optional[str] = Field(default=None, description="Cumulative trade volume") + cumulative_fee_volume: Optional[str] = Field(default=None, description="Cumulative fee volume") + + # Time-based metrics + volume: Optional[TimeBasedMetrics] = Field(default=None, description="Volume across different time periods") + fees: Optional[TimeBasedMetrics] = Field(default=None, description="Fees across different time periods") + fee_tvl_ratio: Optional[TimeBasedMetrics] = Field(default=None, description="Fee-to-TVL ratio across different time periods") + + # Rewards + reward_mint_x: Optional[str] = Field(default=None, description="Base token reward mint address") + reward_mint_y: Optional[str] = Field(default=None, description="Quote token reward mint address") + + # Metadata + tags: Optional[List[str]] = Field(default=None, description="Pool tags") + is_verified: bool = Field(default=False, description="Whether tokens are verified") + is_blacklisted: Optional[bool] = Field(default=None, description="Whether pool is blacklisted") + hide: Optional[bool] = Field(default=None, description="Whether pool should be hidden") + launchpad: Optional[str] = Field(default=None, description="Associated launchpad") + + +class CLMMPoolListResponse(BaseModel): + """Response with list of available CLMM pools""" + pools: List[CLMMPoolListItem] = Field(description="List of available pools") + total: int = Field(description="Total number of pools") + page: int = Field(description="Current page number") + limit: int = Field(description="Results per page") diff --git a/models/market_data.py b/models/market_data.py new file mode 100644 index 00000000..047f6b6f --- /dev/null +++ b/models/market_data.py @@ -0,0 +1,193 @@ +from datetime import datetime +from decimal import Decimal +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class CandleData(BaseModel): + """Single candle data point""" + timestamp: datetime = Field(description="Candle timestamp") + open: float = Field(description="Opening price") + high: float = Field(description="Highest price") + low: float = Field(description="Lowest price") + close: float = Field(description="Closing price") + volume: float = Field(description="Trading volume") + +class CandlesConfigRequest(BaseModel): + """ + The CandlesConfig class is a data class that stores the configuration of a Candle object. + It has the following attributes: + - connector: str + - trading_pair: str + - interval: str + - max_records: int + """ + connector_name: str + trading_pair: str + interval: str = "1m" + max_records: int = 500 + +class CandlesResponse(BaseModel): + """Response for candles data""" + candles: List[CandleData] = Field(description="List of candle data") + + +class ActiveFeedInfo(BaseModel): + """Information about an active market data feed""" + connector: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + interval: str = Field(description="Candle interval") + last_access: datetime = Field(description="Last access time") + expires_at: datetime = Field(description="Expiration time") + + +class ActiveFeedsResponse(BaseModel): + """Response for active market data feeds""" + feeds: List[ActiveFeedInfo] = Field(description="List of active feeds") + + +class MarketDataSettings(BaseModel): + """Market data configuration settings""" + cleanup_interval: int = Field(description="Cleanup interval in seconds") + feed_timeout: int = Field(description="Feed timeout in seconds") + description: str = Field(description="Settings description") + + +class TradingRulesResponse(BaseModel): + """Response for trading rules""" + trading_pairs: Dict[str, Dict[str, Any]] = Field(description="Trading rules by pair") + + +class SupportedOrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +# New models for enhanced market data functionality + +class PriceRequest(BaseModel): + """Request model for getting prices""" + connector_name: str = Field(description="Name of the connector") + trading_pairs: List[str] = Field(description="List of trading pairs to get prices for") + + +class PriceData(BaseModel): + """Price data for a trading pair""" + trading_pair: str = Field(description="Trading pair") + price: float = Field(description="Current price") + timestamp: float = Field(description="Price timestamp") + + +class PricesResponse(BaseModel): + """Response for prices data""" + connector: str = Field(description="Connector name") + prices: Dict[str, float] = Field(description="Trading pair to price mapping") + timestamp: float = Field(description="Response timestamp") + + +class FundingInfoRequest(BaseModel): + """Request model for getting funding info""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair to get funding info for") + + +class FundingInfoResponse(BaseModel): + """Response for funding info""" + trading_pair: str = Field(description="Trading pair") + funding_rate: Optional[float] = Field(description="Current funding rate") + next_funding_time: Optional[float] = Field(description="Next funding time timestamp") + mark_price: Optional[float] = Field(description="Mark price") + index_price: Optional[float] = Field(description="Index price") + + +class OrderBookRequest(BaseModel): + """Request model for getting order book data""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair") + depth: int = Field(default=10, ge=1, le=1000, description="Number of price levels to return") + + +class OrderBookLevel(BaseModel): + """Single order book level""" + price: float = Field(description="Price level") + amount: float = Field(description="Amount at this price level") + + +class OrderBookResponse(BaseModel): + """Response for order book data""" + trading_pair: str = Field(description="Trading pair") + bids: List[OrderBookLevel] = Field(description="Bid levels (highest to lowest)") + asks: List[OrderBookLevel] = Field(description="Ask levels (lowest to highest)") + timestamp: float = Field(description="Snapshot timestamp") + + +class OrderBookQueryRequest(BaseModel): + """Request model for order book queries""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair") + is_buy: bool = Field(description="True for buy side, False for sell side") + + +class VolumeForPriceRequest(OrderBookQueryRequest): + """Request model for getting volume at a specific price""" + price: float = Field(description="Price to query volume for") + + +class PriceForVolumeRequest(OrderBookQueryRequest): + """Request model for getting price for a specific volume""" + volume: float = Field(description="Volume to query price for") + + +class QuoteVolumeForPriceRequest(OrderBookQueryRequest): + """Request model for getting quote volume at a specific price""" + price: float = Field(description="Price to query quote volume for") + + +class PriceForQuoteVolumeRequest(OrderBookQueryRequest): + """Request model for getting price for a specific quote volume""" + quote_volume: float = Field(description="Quote volume to query price for") + + +class VWAPForVolumeRequest(OrderBookQueryRequest): + """Request model for getting VWAP for a specific volume""" + volume: float = Field(description="Volume to calculate VWAP for") + + +class OrderBookQueryResult(BaseModel): + """Response for order book query operations""" + trading_pair: str = Field(description="Trading pair") + is_buy: bool = Field(description="Query side (buy/sell)") + query_volume: Optional[float] = Field(default=None, description="Queried volume") + query_price: Optional[float] = Field(default=None, description="Queried price") + result_price: Optional[float] = Field(default=None, description="Resulting price") + result_volume: Optional[float] = Field(default=None, description="Resulting volume") + result_quote_volume: Optional[float] = Field(default=None, description="Resulting quote volume") + average_price: Optional[float] = Field(default=None, description="Average/VWAP price") + timestamp: float = Field(description="Query timestamp") + + +# Trading Pair Management Models + +class AddTradingPairRequest(BaseModel): + """Request model for adding a trading pair to order book tracking""" + connector_name: str = Field(description="Name of the connector (e.g., 'binance', 'binance_perpetual')") + trading_pair: str = Field(description="Trading pair to add (e.g., 'BTC-USDT')") + account_name: Optional[str] = Field(default=None, description="Optional account name for trading connector preference") + timeout: float = Field(default=30.0, ge=1.0, le=120.0, description="Timeout in seconds for order book initialization") + + +class RemoveTradingPairRequest(BaseModel): + """Request model for removing a trading pair from order book tracking""" + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair to remove") + account_name: Optional[str] = Field(default=None, description="Optional account name for trading connector preference") + + +class TradingPairResponse(BaseModel): + """Response model for trading pair management operations""" + success: bool = Field(description="Whether the operation succeeded") + connector_name: str = Field(description="Name of the connector") + trading_pair: str = Field(description="Trading pair that was added/removed") + message: str = Field(description="Status message") \ No newline at end of file diff --git a/models/pagination.py b/models/pagination.py new file mode 100644 index 00000000..32309218 --- /dev/null +++ b/models/pagination.py @@ -0,0 +1,37 @@ +from datetime import datetime +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field, ConfigDict + + +class PaginationParams(BaseModel): + """Common pagination parameters.""" + limit: int = Field(default=100, ge=1, le=1000, description="Number of items per page") + cursor: Optional[str] = Field(None, description="Cursor for next page") + + +class TimeRangePaginationParams(BaseModel): + """Time-based pagination parameters for trading endpoints using integer timestamps.""" + limit: int = Field(default=100, ge=1, le=1000, description="Number of items per page") + start_time: Optional[int] = Field(None, description="Start time as Unix timestamp in milliseconds") + end_time: Optional[int] = Field(None, description="End time as Unix timestamp in milliseconds") + cursor: Optional[str] = Field(None, description="Cursor for next page") + + +class PaginatedResponse(BaseModel): + """Generic paginated response.""" + model_config = ConfigDict( + json_schema_extra={ + "example": { + "data": [], + "pagination": { + "limit": 100, + "has_more": True, + "next_cursor": "2024-01-10T12:00:00", + "total_count": 500 + } + } + } + ) + + data: List[Dict[str, Any]] + pagination: Dict[str, Any] \ No newline at end of file diff --git a/models/portfolio.py b/models/portfolio.py new file mode 100644 index 00000000..d2d179fc --- /dev/null +++ b/models/portfolio.py @@ -0,0 +1,97 @@ +""" +Pydantic models for the portfolio router. + +These models define the request/response schemas for portfolio-related endpoints. +""" + +from typing import Dict, List, Optional, Any +from datetime import datetime +from pydantic import BaseModel, Field + + +class TokenBalance(BaseModel): + """Token balance information""" + token: str = Field(description="Token symbol") + units: float = Field(description="Number of units held") + price: float = Field(description="Current price per unit") + value: float = Field(description="Total value (units * price)") + available_units: float = Field(description="Available units (not locked in orders)") + + +class ConnectorBalances(BaseModel): + """Balances for a specific connector""" + connector_name: str = Field(description="Name of the connector") + balances: List[TokenBalance] = Field(description="List of token balances") + total_value: float = Field(description="Total value across all tokens") + + +class AccountPortfolioState(BaseModel): + """Portfolio state for a single account""" + account_name: str = Field(description="Name of the account") + connectors: Dict[str, List[TokenBalance]] = Field(description="Balances by connector") + total_value: float = Field(description="Total account value across all connectors") + last_updated: Optional[datetime] = Field(default=None, description="Last update timestamp") + + +class PortfolioStateResponse(BaseModel): + """Response for portfolio state endpoint""" + accounts: Dict[str, Dict[str, List[Dict[str, Any]]]] = Field( + description="Portfolio state by account and connector" + ) + total_portfolio_value: Optional[float] = Field(default=None, description="Total value across all accounts") + timestamp: datetime = Field(default_factory=datetime.utcnow, description="Response timestamp") + + +class TokenDistribution(BaseModel): + """Token distribution information""" + token: str = Field(description="Token symbol") + total_value: float = Field(description="Total value of this token") + total_units: float = Field(description="Total units of this token") + percentage: float = Field(description="Percentage of total portfolio") + accounts: Dict[str, Dict[str, Any]] = Field( + description="Breakdown by account and connector" + ) + + +class PortfolioDistributionResponse(BaseModel): + """Response for portfolio distribution endpoint""" + total_portfolio_value: float = Field(description="Total portfolio value") + token_count: int = Field(description="Number of unique tokens") + distribution: List[TokenDistribution] = Field(description="Token distribution list") + account_filter: str = Field( + default="all_accounts", + description="Applied account filter (all_accounts or specific accounts)" + ) + + +class AccountDistribution(BaseModel): + """Account distribution information""" + account: str = Field(description="Account name") + total_value: float = Field(description="Total value in this account") + percentage: float = Field(description="Percentage of total portfolio") + connectors: Dict[str, Dict[str, float]] = Field( + description="Value breakdown by connector" + ) + + +class AccountsDistributionResponse(BaseModel): + """Response for accounts distribution endpoint""" + total_portfolio_value: float = Field(description="Total portfolio value") + account_count: int = Field(description="Number of accounts") + distribution: List[AccountDistribution] = Field(description="Account distribution list") + + +class HistoricalPortfolioState(BaseModel): + """Historical portfolio state entry""" + timestamp: datetime = Field(description="State timestamp") + state: Dict[str, Dict[str, List[Dict[str, Any]]]] = Field( + description="Portfolio state snapshot" + ) + total_value: Optional[float] = Field(default=None, description="Total value at this point") + + +class PortfolioHistoryFilters(BaseModel): + """Filters applied to portfolio history query""" + account_names: Optional[List[str]] = Field(default=None, description="Filtered account names") + start_time: Optional[datetime] = Field(default=None, description="Start time filter") + end_time: Optional[datetime] = Field(default=None, description="End time filter") \ No newline at end of file diff --git a/models/rate_oracle.py b/models/rate_oracle.py new file mode 100644 index 00000000..3aaa2205 --- /dev/null +++ b/models/rate_oracle.py @@ -0,0 +1,114 @@ +""" +Pydantic models for the rate oracle router. + +These models define the request/response schemas for rate oracle configuration endpoints. +""" + +from typing import Optional, List, Dict +from enum import Enum +from pydantic import BaseModel, Field + + +class RateOracleSourceEnum(str, Enum): + """Available rate oracle sources.""" + BINANCE = "binance" + BINANCE_US = "binance_us" + COIN_GECKO = "coin_gecko" + COIN_CAP = "coin_cap" + KUCOIN = "kucoin" + ASCEND_EX = "ascend_ex" + GATE_IO = "gate_io" + COINBASE_ADVANCED_TRADE = "coinbase_advanced_trade" + CUBE = "cube" + DEXALOT = "dexalot" + HYPERLIQUID = "hyperliquid" + DERIVE = "derive" + TEGRO = "tegro" + + +class GlobalTokenConfig(BaseModel): + """Global token configuration for displaying values.""" + global_token_name: str = Field( + default="USDT", + description="The token to use as global quote (e.g., USDT, USD, BTC)" + ) + global_token_symbol: str = Field( + default="$", + description="Symbol to display for the global token" + ) + + +class RateOracleSourceConfig(BaseModel): + """Rate oracle source configuration.""" + name: RateOracleSourceEnum = Field( + default=RateOracleSourceEnum.BINANCE, + description="The rate oracle source to use for price data" + ) + + +class RateOracleConfig(BaseModel): + """Complete rate oracle configuration.""" + rate_oracle_source: RateOracleSourceConfig = Field( + default_factory=RateOracleSourceConfig, + description="Rate oracle source configuration" + ) + global_token: GlobalTokenConfig = Field( + default_factory=GlobalTokenConfig, + description="Global token configuration" + ) + + +class RateOracleConfigResponse(BaseModel): + """Response for rate oracle configuration GET endpoint.""" + rate_oracle_source: RateOracleSourceConfig = Field( + description="Current rate oracle source configuration" + ) + global_token: GlobalTokenConfig = Field( + description="Current global token configuration" + ) + available_sources: List[str] = Field( + description="List of available rate oracle sources" + ) + + +class RateOracleConfigUpdateRequest(BaseModel): + """Request model for updating rate oracle configuration.""" + rate_oracle_source: Optional[RateOracleSourceConfig] = Field( + default=None, + description="New rate oracle source configuration (optional)" + ) + global_token: Optional[GlobalTokenConfig] = Field( + default=None, + description="New global token configuration (optional)" + ) + + +class RateOracleConfigUpdateResponse(BaseModel): + """Response for rate oracle configuration update.""" + success: bool = Field(description="Whether the update was successful") + message: str = Field(description="Status message") + config: RateOracleConfig = Field(description="Updated configuration") + + +class RateRequest(BaseModel): + """Request for getting rates.""" + trading_pairs: List[str] = Field( + description="List of trading pairs to get rates for (e.g., ['BTC-USDT', 'ETH-USDT'])" + ) + + +class RateResponse(BaseModel): + """Response containing rates for trading pairs.""" + source: str = Field(description="Rate oracle source used") + quote_token: str = Field(description="Quote token used") + rates: Dict[str, Optional[float]] = Field( + description="Mapping of trading pairs to their rates (None if rate not found)" + ) + + +class SingleRateResponse(BaseModel): + """Response for a single trading pair rate.""" + trading_pair: str = Field(description="The trading pair") + rate: Optional[float] = Field(description="The rate (None if not found)") + source: str = Field(description="Rate oracle source used") + quote_token: str = Field(description="Quote token used") diff --git a/models/scripts.py b/models/scripts.py new file mode 100644 index 00000000..fd60b07d --- /dev/null +++ b/models/scripts.py @@ -0,0 +1,34 @@ +from typing import Dict, List, Optional, Any +from pydantic import BaseModel, Field + + +# Script file operations +class Script(BaseModel): + """Script file content""" + content: str = Field(description="Script source code") + + +class ScriptResponse(BaseModel): + """Response for getting a script""" + name: str = Field(description="Script name") + content: str = Field(description="Script source code") + + +# Script configuration operations +class ScriptConfig(BaseModel): + """Script configuration content""" + config_name: str = Field(description="Configuration name") + script_file_name: str = Field(description="Script file name") + controllers_config: List[str] = Field(default=[], description="List of controller configurations") + candles_config: List[Dict[str, Any]] = Field(default=[], description="Candles configuration") + markets: Dict[str, Any] = Field(default={}, description="Markets configuration") + + +class ScriptConfigResponse(BaseModel): + """Response for script configuration with metadata""" + config_name: str = Field(description="Configuration name") + script_file_name: str = Field(description="Script file name") + controllers_config: List[str] = Field(default=[], description="List of controller configurations") + candles_config: List[Dict[str, Any]] = Field(default=[], description="Candles configuration") + markets: Dict[str, Any] = Field(default={}, description="Markets configuration") + error: Optional[str] = Field(None, description="Error message if config is malformed") \ No newline at end of file diff --git a/models/trading.py b/models/trading.py new file mode 100644 index 00000000..a7ee0d7d --- /dev/null +++ b/models/trading.py @@ -0,0 +1,224 @@ +from typing import Dict, List, Optional, Any, Literal +from pydantic import BaseModel, Field, field_validator +from decimal import Decimal +from datetime import datetime +from hummingbot.core.data_type.common import OrderType, TradeType, PositionAction +from .pagination import PaginationParams, TimeRangePaginationParams + + +class TradeRequest(BaseModel): + """Request model for placing trades""" + account_name: str = Field(description="Name of the account to trade with") + connector_name: str = Field(description="Name of the connector/exchange") + trading_pair: str = Field(description="Trading pair (e.g., BTC-USDT)") + trade_type: Literal["BUY", "SELL"] = Field(description="Whether to buy or sell") + amount: Decimal = Field(description="Amount to trade", gt=0) + order_type: Literal["LIMIT", "MARKET", "LIMIT_MAKER"] = Field(default="LIMIT", description="Type of order") + price: Optional[Decimal] = Field(default=None, description="Price for limit orders") + position_action: Literal["OPEN", "CLOSE"] = Field(default="OPEN", description="Position action for perpetual contracts (OPEN/CLOSE)") + + @field_validator('trade_type') + @classmethod + def validate_trade_type(cls, v): + """Validate that trade_type is a valid TradeType enum name.""" + try: + return TradeType[v].name + except KeyError: + valid_types = [t.name for t in TradeType] + raise ValueError(f"Invalid trade_type '{v}'. Must be one of: {valid_types}") + + @field_validator('order_type') + @classmethod + def validate_order_type(cls, v): + """Validate that order_type is a valid OrderType enum name.""" + try: + return OrderType[v].name + except KeyError: + valid_types = [t.name for t in OrderType] + raise ValueError(f"Invalid order_type '{v}'. Must be one of: {valid_types}") + + @field_validator('position_action') + @classmethod + def validate_position_action(cls, v): + """Validate that position_action is a valid PositionAction enum name.""" + try: + return PositionAction[v].name + except KeyError: + valid_actions = [a.name for a in PositionAction] + raise ValueError(f"Invalid position_action '{v}'. Must be one of: {valid_actions}") + + +class TradeResponse(BaseModel): + """Response model for trade execution""" + order_id: str = Field(description="Client order ID assigned by the connector") + account_name: str = Field(description="Account used for the trade") + connector_name: str = Field(description="Connector used for the trade") + trading_pair: str = Field(description="Trading pair") + trade_type: str = Field(description="Trade type") + amount: Decimal = Field(description="Trade amount") + order_type: str = Field(description="Order type") + price: Optional[Decimal] = Field(description="Order price") + status: str = Field(default="submitted", description="Order status") + + +class TokenInfo(BaseModel): + """Information about a token balance""" + token: str = Field(description="Token symbol") + balance: Decimal = Field(description="Token balance") + value_usd: Optional[Decimal] = Field(None, description="USD value of the balance") + + +class ConnectorBalance(BaseModel): + """Balance information for a connector""" + connector_name: str = Field(description="Name of the connector") + tokens: List[TokenInfo] = Field(description="List of token balances") + + +class AccountBalance(BaseModel): + """Balance information for an account""" + account_name: str = Field(description="Name of the account") + connectors: List[ConnectorBalance] = Field(description="List of connector balances") + + +class PortfolioState(BaseModel): + """Complete portfolio state across all accounts""" + accounts: List[AccountBalance] = Field(description="List of account balances") + timestamp: datetime = Field(description="Timestamp of the portfolio state") + + +class OrderInfo(BaseModel): + """Information about an order""" + order_id: str = Field(description="Order ID") + client_order_id: str = Field(description="Client order ID") + account_name: str = Field(description="Account name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + order_type: str = Field(description="Order type") + trade_type: str = Field(description="Trade type (BUY/SELL)") + amount: Decimal = Field(description="Order amount") + price: Optional[Decimal] = Field(description="Order price") + filled_amount: Decimal = Field(description="Filled amount") + status: str = Field(description="Order status") + creation_timestamp: datetime = Field(description="Order creation time") + last_update_timestamp: datetime = Field(description="Last update time") + + +class ActiveOrdersResponse(BaseModel): + """Response for active orders""" + orders: Dict[str, OrderInfo] = Field(description="Dictionary of active orders") + + +class OrderSummary(BaseModel): + """Summary statistics for orders""" + total_orders: int = Field(description="Total number of orders") + filled_orders: int = Field(description="Number of filled orders") + cancelled_orders: int = Field(description="Number of cancelled orders") + fill_rate: float = Field(description="Order fill rate percentage") + total_volume_base: Decimal = Field(description="Total volume in base currency") + total_volume_quote: Decimal = Field(description="Total volume in quote currency") + avg_fill_time: Optional[float] = Field(description="Average fill time in seconds") + + +class TradeInfo(BaseModel): + """Information about a trade fill""" + trade_id: str = Field(description="Trade ID") + order_id: str = Field(description="Associated order ID") + account_name: str = Field(description="Account name") + connector_name: str = Field(description="Connector name") + trading_pair: str = Field(description="Trading pair") + trade_type: str = Field(description="Trade type (BUY/SELL)") + amount: Decimal = Field(description="Trade amount") + price: Decimal = Field(description="Trade price") + fee: Decimal = Field(description="Trade fee") + timestamp: datetime = Field(description="Trade timestamp") + + +class TradingRulesInfo(BaseModel): + """Trading rules for a trading pair""" + trading_pair: str = Field(description="Trading pair") + min_order_size: Decimal = Field(description="Minimum order size") + max_order_size: Optional[Decimal] = Field(description="Maximum order size") + min_price_increment: Decimal = Field(description="Minimum price increment") + min_base_amount_increment: Decimal = Field(description="Minimum base amount increment") + min_quote_amount_increment: Decimal = Field(description="Minimum quote amount increment") + + +class OrderTypesResponse(BaseModel): + """Response for supported order types""" + connector: str = Field(description="Connector name") + supported_order_types: List[str] = Field(description="List of supported order types") + + +class OrderFilterRequest(TimeRangePaginationParams): + """Request model for filtering orders with multiple criteria""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + status: Optional[str] = Field(default=None, description="Order status filter") + + +class ActiveOrderFilterRequest(PaginationParams): + """Request model for filtering active orders""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + + +class PositionFilterRequest(PaginationParams): + """Request model for filtering positions""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class FundingPaymentFilterRequest(TimeRangePaginationParams): + """Request model for filtering funding payments""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pair: Optional[str] = Field(default=None, description="Filter by trading pair") + + +class TradeFilterRequest(TimeRangePaginationParams): + """Request model for filtering trades""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + trading_pairs: Optional[List[str]] = Field(default=None, description="List of trading pairs to filter by") + trade_types: Optional[List[str]] = Field(default=None, description="List of trade types to filter by (BUY/SELL)") + + +class PortfolioStateFilterRequest(BaseModel): + """Request model for filtering portfolio state""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + skip_gateway: bool = Field(default=False, description="Skip Gateway wallet balance updates for faster CEX-only queries") + refresh: bool = Field(default=False, description="If True, refresh balances before returning. If False, return cached state") + + +class PortfolioHistoryFilterRequest(TimeRangePaginationParams): + """Request model for filtering portfolio history""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + interval: Optional[str] = Field( + default="5m", + description="Data sampling interval: 5m, 15m, 30m, 1h, 4h, 12h, 1d. Default is 5m (raw data)" + ) + + @field_validator('interval') + @classmethod + def validate_interval(cls, v): + """Validate that interval is a supported value.""" + valid_intervals = ["5m", "15m", "30m", "1h", "4h", "12h", "1d"] + if v not in valid_intervals: + raise ValueError(f"Invalid interval '{v}'. Must be one of: {valid_intervals}") + return v + + +class PortfolioDistributionFilterRequest(BaseModel): + """Request model for filtering portfolio distribution""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") + + +class AccountsDistributionFilterRequest(BaseModel): + """Request model for filtering accounts distribution""" + account_names: Optional[List[str]] = Field(default=None, description="List of account names to filter by") + connector_names: Optional[List[str]] = Field(default=None, description="List of connector names to filter by") \ No newline at end of file diff --git a/routers/accounts.py b/routers/accounts.py new file mode 100644 index 00000000..8d6de870 --- /dev/null +++ b/routers/accounts.py @@ -0,0 +1,221 @@ +from typing import Dict, List, Optional +from datetime import datetime + +from fastapi import APIRouter, HTTPException, Depends, Query +from starlette import status + +from services.accounts_service import AccountsService +from deps import get_accounts_service +from models import PaginatedResponse, GatewayWalletCredential, GatewayWalletInfo + +router = APIRouter(tags=["Accounts"], prefix="/accounts") + + +@router.get("/", response_model=List[str]) +async def list_accounts(accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get a list of all account names in the system. + + Returns: + List of account names + """ + return accounts_service.list_accounts() + + +@router.get("/{account_name}/credentials", response_model=List[str]) +async def list_account_credentials(account_name: str, + accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get a list of all connectors that have credentials configured for a specific account. + + Args: + account_name: Name of the account to list credentials for + + Returns: + List of connector names that have credentials configured + + Raises: + HTTPException: 404 if account not found + """ + try: + credentials = accounts_service.list_credentials(account_name) + # Remove .yml extension from filenames + return [cred.replace('.yml', '') for cred in credentials] + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/add-account", status_code=status.HTTP_201_CREATED) +async def add_account(account_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Create a new account with default configuration files. + + Args: + account_name: Name of the new account to create + + Returns: + Success message when account is created + + Raises: + HTTPException: 400 if account already exists + """ + try: + accounts_service.add_account(account_name) + return {"message": "Account added successfully."} + except FileExistsError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/delete-account") +async def delete_account(account_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Delete an account and all its associated credentials. + + Args: + account_name: Name of the account to delete + + Returns: + Success message when account is deleted + + Raises: + HTTPException: 400 if trying to delete master account, 404 if account not found + """ + try: + if account_name == "master_account": + raise HTTPException(status_code=400, detail="Cannot delete master account.") + await accounts_service.delete_account(account_name) + return {"message": "Account deleted successfully."} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/delete-credential/{account_name}/{connector_name}") +async def delete_credential(account_name: str, connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Delete a specific connector credential for an account. + + Args: + account_name: Name of the account + connector_name: Name of the connector to delete credentials for + + Returns: + Success message when credential is deleted + + Raises: + HTTPException: 404 if credential not found + """ + try: + await accounts_service.delete_credentials(account_name, connector_name) + return {"message": "Credential deleted successfully."} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("/add-credential/{account_name}/{connector_name}", status_code=status.HTTP_201_CREATED) +async def add_credential(account_name: str, connector_name: str, credentials: Dict, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Add or update connector credentials (API keys) for a specific account and connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector + credentials: Dictionary containing the connector credentials + + Returns: + Success message when credentials are added + + Raises: + HTTPException: 400 if there's an error adding the credentials + """ + try: + await accounts_service.add_credentials(account_name, connector_name, credentials) + return {"message": "Connector credentials added successfully."} + except Exception as e: + await accounts_service.delete_credentials(account_name, connector_name) + raise HTTPException(status_code=400, detail=str(e)) + + +# ============================================ +# Gateway Wallet Management Endpoints +# ============================================ + +@router.get("/gateway/wallets") +async def list_gateway_wallets(accounts_service: AccountsService = Depends(get_accounts_service)): + """ + List all wallets managed by Gateway. + Gateway manages its own encrypted wallet storage. + + Returns: + List of wallet information from Gateway + + Raises: + HTTPException: 503 if Gateway unavailable + """ + try: + wallets = await accounts_service.get_gateway_wallets() + return wallets + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/gateway/add-wallet", status_code=status.HTTP_201_CREATED) +async def add_gateway_wallet( + wallet_credential: GatewayWalletCredential, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Add a wallet to Gateway. Gateway handles encryption and storage internally. + + Args: + wallet_credential: Wallet credentials (chain and private_key) + + Returns: + Wallet information from Gateway including address + + Raises: + HTTPException: 503 if Gateway unavailable, 400 on validation error + """ + try: + result = await accounts_service.add_gateway_wallet( + chain=wallet_credential.chain, + private_key=wallet_credential.private_key + ) + return result + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/gateway/{chain}/{address}") +async def remove_gateway_wallet( + chain: str, + address: str, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Remove a wallet from Gateway. + + Args: + chain: Blockchain chain (e.g., 'solana', 'ethereum') + address: Wallet address to remove + + Returns: + Success message + + Raises: + HTTPException: 503 if Gateway unavailable + """ + try: + result = await accounts_service.remove_gateway_wallet(chain, address) + return result + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + diff --git a/routers/archived_bots.py b/routers/archived_bots.py new file mode 100644 index 00000000..38b42d3b --- /dev/null +++ b/routers/archived_bots.py @@ -0,0 +1,297 @@ +from typing import List, Optional +from fastapi import APIRouter, HTTPException, Query + +from utils.file_system import fs_util +from utils.hummingbot_database_reader import HummingbotDatabase + +router = APIRouter(tags=["Archived Bots"], prefix="/archived-bots") + + +@router.get("/", response_model=List[str]) +async def list_databases(): + """ + List all available database files in the system. + + Returns: + List of database file paths + """ + return fs_util.list_databases() + + +@router.get("/{db_path:path}/status") +async def get_database_status(db_path: str): + """ + Get status information for a specific database. + + Args: + db_path: Path to the database file + + Returns: + Database status including table health + """ + try: + db = HummingbotDatabase(db_path) + return { + "db_path": db_path, + "status": db.status, + "healthy": db.status["general_status"] + } + except Exception as e: + raise HTTPException(status_code=404, detail=f"Database not found or error: {str(e)}") + + +@router.get("/{db_path:path}/summary") +async def get_database_summary(db_path: str): + """ + Get a summary of database contents including basic statistics. + + Args: + db_path: Full path to the database file + + Returns: + Summary statistics of the database contents + """ + try: + db = HummingbotDatabase(db_path) + + # Get basic counts + orders = db.get_orders() + trades = db.get_trade_fills() + executors = db.get_executors_data() + positions = db.get_positions() + controllers = db.get_controllers_data() + + return { + "db_path": db_path, + "total_orders": len(orders), + "total_trades": len(trades), + "total_executors": len(executors), + "total_positions": len(positions), + "total_controllers": len(controllers), + "trading_pairs": orders["trading_pair"].unique().tolist() if len(orders) > 0 else [], + "exchanges": orders["connector_name"].unique().tolist() if len(orders) > 0 else [], + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error analyzing database: {str(e)}") + + +@router.get("/{db_path:path}/performance") +async def get_database_performance(db_path: str): + """ + Get trade-based performance analysis for a bot database. + + Args: + db_path: Full path to the database file + + Returns: + Trade-based performance metrics with rolling calculations + """ + try: + db = HummingbotDatabase(db_path) + + # Use new trade-based performance calculation + performance_data = db.calculate_trade_based_performance() + + if len(performance_data) == 0: + return { + "db_path": db_path, + "error": "No trades found in database", + "performance_data": [] + } + + # Convert to records for JSON response + performance_records = performance_data.fillna(0).to_dict('records') + + # Calculate summary statistics + final_row = performance_data.iloc[-1] if len(performance_data) > 0 else {} + summary = { + "total_trades": len(performance_data), + "final_net_pnl_quote": float(final_row.get('net_pnl_quote', 0)), + "final_realized_pnl_quote": float(final_row.get('realized_trade_pnl_quote', 0)), + "final_unrealized_pnl_quote": float(final_row.get('unrealized_trade_pnl_quote', 0)), + "total_fees_quote": float(performance_data['fees_quote'].sum()), + "total_volume_quote": float(performance_data['cum_volume_quote'].iloc[-1] if len(performance_data) > 0 else 0), + "final_net_position": float(final_row.get('net_position', 0)), + "trading_pairs": performance_data['trading_pair'].unique().tolist(), + "connector_names": performance_data['connector_name'].unique().tolist() + } + + return { + "db_path": db_path, + "summary": summary, + "performance_data": performance_records + } + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error calculating performance: {str(e)}") + + +@router.get("/{db_path:path}/trades") +async def get_database_trades( + db_path: str, + limit: int = Query(default=100, description="Limit number of trades returned"), + offset: int = Query(default=0, description="Offset for pagination") +): + """ + Get trade history from a database. + + Args: + db_path: Full path to the database file + limit: Maximum number of trades to return + offset: Offset for pagination + + Returns: + List of trades with pagination info + """ + try: + db = HummingbotDatabase(db_path) + trades = db.get_trade_fills() + + # Apply pagination + total_trades = len(trades) + trades_page = trades.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "trades": trades_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_trades, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_trades + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching trades: {str(e)}") + + +@router.get("/{db_path:path}/orders") +async def get_database_orders( + db_path: str, + limit: int = Query(default=100, description="Limit number of orders returned"), + offset: int = Query(default=0, description="Offset for pagination"), + status: Optional[str] = Query(default=None, description="Filter by order status") +): + """ + Get order history from a database. + + Args: + db_path: Full path to the database file + limit: Maximum number of orders to return + offset: Offset for pagination + status: Optional status filter + + Returns: + List of orders with pagination info + """ + try: + db = HummingbotDatabase(db_path) + orders = db.get_orders() + + # Apply status filter if provided + if status: + orders = orders[orders["last_status"] == status] + + # Apply pagination + total_orders = len(orders) + orders_page = orders.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "orders": orders_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_orders, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_orders + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching orders: {str(e)}") + + +@router.get("/{db_path:path}/executors") +async def get_database_executors(db_path: str): + """ + Get executor data from a database. + + Args: + db_path: Full path to the database file + + Returns: + List of executors with their configurations and results + """ + try: + db = HummingbotDatabase(db_path) + executors = db.get_executors_data() + + return { + "db_path": db_path, + "executors": executors.fillna(0).to_dict('records'), + "total": len(executors) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching executors: {str(e)}") + + +@router.get("/{db_path:path}/positions") +async def get_database_positions( + db_path: str, + limit: int = Query(default=100, description="Limit number of positions returned"), + offset: int = Query(default=0, description="Offset for pagination") +): + """ + Get position data from a database. + + Args: + db_path: Full path to the database file + limit: Maximum number of positions to return + offset: Offset for pagination + + Returns: + List of positions with pagination info + """ + try: + db = HummingbotDatabase(db_path) + positions = db.get_positions() + + # Apply pagination + total_positions = len(positions) + positions_page = positions.iloc[offset:offset + limit] + + return { + "db_path": db_path, + "positions": positions_page.fillna(0).to_dict('records'), + "pagination": { + "total": total_positions, + "limit": limit, + "offset": offset, + "has_more": offset + limit < total_positions + } + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching positions: {str(e)}") + + +@router.get("/{db_path:path}/controllers") +async def get_database_controllers(db_path: str): + """ + Get controller data from a database. + + Args: + db_path: Full path to the database file + + Returns: + List of controllers that were running with their configurations + """ + try: + db = HummingbotDatabase(db_path) + controllers = db.get_controllers_data() + + return { + "db_path": db_path, + "controllers": controllers.fillna(0).to_dict('records'), + "total": len(controllers) + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching controllers: {str(e)}") diff --git a/routers/backtesting.py b/routers/backtesting.py new file mode 100644 index 00000000..3d68ee9b --- /dev/null +++ b/routers/backtesting.py @@ -0,0 +1,55 @@ +from fastapi import APIRouter +from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory +from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase + +from config import settings +from models.backtesting import BacktestingConfig + +router = APIRouter(tags=["Backtesting"], prefix="/backtesting") +candles_factory = CandlesFactory() +backtesting_engine = BacktestingEngineBase() + + +@router.post("/run-backtesting") +async def run_backtesting(backtesting_config: BacktestingConfig): + """ + Run a backtesting simulation with the provided configuration. + + Args: + backtesting_config: Configuration for the backtesting including start/end time, + resolution, trade cost, and controller config + + Returns: + Dictionary containing executors, processed data, and results from the backtest + + Raises: + Returns error dictionary if backtesting fails + """ + try: + if isinstance(backtesting_config.config, str): + controller_config = backtesting_engine.get_controller_config_instance_from_yml( + config_path=backtesting_config.config, + controllers_conf_dir_path=settings.app.controllers_path, + controllers_module=settings.app.controllers_module + ) + else: + controller_config = backtesting_engine.get_controller_config_instance_from_dict( + config_data=backtesting_config.config, + controllers_module=settings.app.controllers_module + ) + backtesting_results = await backtesting_engine.run_backtesting( + controller_config=controller_config, trade_cost=backtesting_config.trade_cost, + start=int(backtesting_config.start_time), end=int(backtesting_config.end_time), + backtesting_resolution=backtesting_config.backtesting_resolution) + processed_data = backtesting_results["processed_data"]["features"].fillna(0) + executors_info = [e.to_dict() for e in backtesting_results["executors"]] + backtesting_results["processed_data"] = processed_data.to_dict() + results = backtesting_results["results"] + results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 + return { + "executors": executors_info, + "processed_data": backtesting_results["processed_data"], + "results": backtesting_results["results"], + } + except Exception as e: + return {"error": str(e)} diff --git a/routers/bot_orchestration.py b/routers/bot_orchestration.py new file mode 100644 index 00000000..5634d3eb --- /dev/null +++ b/routers/bot_orchestration.py @@ -0,0 +1,685 @@ +import asyncio +import logging +import os +from datetime import datetime + +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException + +from database import AsyncDatabaseManager, BotRunRepository +from deps import get_bot_archiver, get_bots_orchestrator, get_database_manager, get_docker_service +from models import StartBotAction, StopBotAction, V2ControllerDeployment +from services.bots_orchestrator import BotsOrchestrator +from services.docker_service import DockerService +from utils.bot_archiver import BotArchiver +from utils.file_system import fs_util + +# Create module-specific logger +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Bot Orchestration"], prefix="/bot-orchestration") + + +@router.get("/status") +def get_active_bots_status(bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get the status of all active bots. + + Args: + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with status and data containing all active bot statuses + """ + return {"status": "success", "data": bots_manager.get_all_bots_status()} + + +@router.get("/mqtt") +def get_mqtt_status(bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get MQTT connection status and discovered bots. + + Args: + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with MQTT connection status, discovered bots, and broker information + """ + mqtt_connected = bots_manager.mqtt_manager.is_connected + discovered_bots = bots_manager.mqtt_manager.get_discovered_bots() + active_bots = list(bots_manager.active_bots.keys()) + + # Check client state + client_state = "connected" if bots_manager.mqtt_manager.is_connected else "disconnected" + + return { + "status": "success", + "data": { + "mqtt_connected": mqtt_connected, + "discovered_bots": discovered_bots, + "active_bots": active_bots, + "broker_host": bots_manager.broker_host, + "broker_port": bots_manager.broker_port, + "broker_username": bots_manager.broker_username, + "client_state": client_state + } + } + + +@router.get("/{bot_name}/status") +def get_bot_status(bot_name: str, bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator)): + """ + Get the status of a specific bot. + + Args: + bot_name: Name of the bot to get status for + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with bot status information + + Raises: + HTTPException: 404 if bot not found + """ + response = bots_manager.get_bot_status(bot_name) + if not response: + raise HTTPException(status_code=404, detail="Bot not found") + return { + "status": "success", + "data": response + } + + +@router.get("/{bot_name}/history") +async def get_bot_history( + bot_name: str, + days: int = 0, + verbose: bool = False, + precision: int = None, + timeout: float = 30.0, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator) +): + """ + Get trading history for a bot with optional parameters. + + Args: + bot_name: Name of the bot to get history for + days: Number of days of history to retrieve (0 for all) + verbose: Whether to include verbose output + precision: Decimal precision for numerical values + timeout: Timeout in seconds for the operation + bots_manager: Bot orchestrator service dependency + + Returns: + Dictionary with bot trading history + """ + response = await bots_manager.get_bot_history( + bot_name, + days=days, + verbose=verbose, + precision=precision, + timeout=timeout + ) + return {"status": "success", "response": response} + + +@router.post("/start-bot") +async def start_bot( + action: StartBotAction, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Start a bot with the specified configuration. + + Args: + action: StartBotAction containing bot configuration parameters + bots_manager: Bot orchestrator service dependency + db_manager: Database manager dependency + + Returns: + Dictionary with status and response from bot start operation + """ + response = await bots_manager.start_bot( + action.bot_name, log_level=action.log_level, script=action.script, + conf=action.conf, async_backend=action.async_backend + ) + + # Bot run tracking simplified - only track deployment and stop times + + return {"status": "success", "response": response} + + +@router.post("/stop-bot") +async def stop_bot( + action: StopBotAction, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Stop a bot with the specified configuration. + + Args: + action: StopBotAction containing bot stop parameters + bots_manager: Bot orchestrator service dependency + db_manager: Database manager dependency + + Returns: + Dictionary with status and response from bot stop operation + """ + # Capture final status BEFORE stopping (performance data is cleared on stop) + final_status = None + try: + final_status = bots_manager.get_bot_status(action.bot_name) + logger.info(f"Captured final status for {action.bot_name} before stopping") + except Exception as e: + logger.warning(f"Failed to capture final status for {action.bot_name}: {e}") + + response = await bots_manager.stop_bot( + action.bot_name, skip_order_cancellation=action.skip_order_cancellation, + async_backend=action.async_backend + ) + + # Update bot run status to STOPPED if stop was successful + if response.get("success"): + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + action.bot_name, + final_status=final_status + ) + logger.info(f"Updated bot run status to STOPPED for {action.bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run status: {e}") + # Don't fail the stop operation if bot run update fails + + return {"status": "success", "response": response} + + +@router.get("/bot-runs") +async def get_bot_runs( + bot_name: str = None, + account_name: str = None, + strategy_type: str = None, + strategy_name: str = None, + run_status: str = None, + deployment_status: str = None, + limit: int = 100, + offset: int = 0, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get bot runs with optional filtering. + + Args: + bot_name: Filter by bot name + account_name: Filter by account name + strategy_type: Filter by strategy type (script or controller) + strategy_name: Filter by strategy name + run_status: Filter by run status (CREATED, RUNNING, STOPPED, ERROR) + deployment_status: Filter by deployment status (DEPLOYED, FAILED, ARCHIVED) + limit: Maximum number of results to return + offset: Number of results to skip + db_manager: Database manager dependency + + Returns: + List of bot runs with their details + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_runs = await bot_run_repo.get_bot_runs( + bot_name=bot_name, + account_name=account_name, + strategy_type=strategy_type, + strategy_name=strategy_name, + run_status=run_status, + deployment_status=deployment_status, + limit=limit, + offset=offset + ) + + # Convert bot runs to dictionaries for JSON serialization + runs_data = [] + for run in bot_runs: + run_dict = { + "id": run.id, + "bot_name": run.bot_name, + "instance_name": run.instance_name, + "deployed_at": run.deployed_at.isoformat() if run.deployed_at else None, + "stopped_at": run.stopped_at.isoformat() if run.stopped_at else None, + "strategy_type": run.strategy_type, + "strategy_name": run.strategy_name, + "config_name": run.config_name, + "account_name": run.account_name, + "image_version": run.image_version, + "deployment_status": run.deployment_status, + "run_status": run.run_status, + "deployment_config": run.deployment_config, + "final_status": run.final_status, + "error_message": run.error_message + } + runs_data.append(run_dict) + + return { + "status": "success", + "data": runs_data, + "total": len(runs_data), + "limit": limit, + "offset": offset + } + except Exception as e: + logger.error(f"Failed to get bot runs: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/bot-runs/{bot_run_id}") +async def get_bot_run_by_id( + bot_run_id: int, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get a specific bot run by ID. + + Args: + bot_run_id: ID of the bot run + db_manager: Database manager dependency + + Returns: + Bot run details + + Raises: + HTTPException: 404 if bot run not found + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + bot_run = await bot_run_repo.get_bot_run_by_id(bot_run_id) + + if not bot_run: + raise HTTPException(status_code=404, detail=f"Bot run {bot_run_id} not found") + + run_dict = { + "id": bot_run.id, + "bot_name": bot_run.bot_name, + "instance_name": bot_run.instance_name, + "deployed_at": bot_run.deployed_at.isoformat() if bot_run.deployed_at else None, + "stopped_at": bot_run.stopped_at.isoformat() if bot_run.stopped_at else None, + "strategy_type": bot_run.strategy_type, + "strategy_name": bot_run.strategy_name, + "config_name": bot_run.config_name, + "account_name": bot_run.account_name, + "image_version": bot_run.image_version, + "deployment_status": bot_run.deployment_status, + "run_status": bot_run.run_status, + "deployment_config": bot_run.deployment_config, + "final_status": bot_run.final_status, + "error_message": bot_run.error_message + } + + return {"status": "success", "data": run_dict} + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to get bot run {bot_run_id}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/bot-runs/stats") +async def get_bot_run_stats( + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get statistics about bot runs. + + Args: + db_manager: Database manager dependency + + Returns: + Bot run statistics + """ + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + stats = await bot_run_repo.get_bot_run_stats() + + return {"status": "success", "data": stats} + except Exception as e: + logger.error(f"Failed to get bot run stats: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +async def _background_stop_and_archive( + bot_name: str, + container_name: str, + bot_name_for_orchestrator: str, + skip_order_cancellation: bool, + archive_locally: bool, + s3_bucket: str, + bots_manager: BotsOrchestrator, + docker_manager: DockerService, + bot_archiver: BotArchiver, + db_manager: AsyncDatabaseManager +): + """Background task to handle the stop and archive process""" + try: + logger.info(f"Starting background stop-and-archive for {bot_name}") + + # Step 1: Capture bot final status before stopping (while bot is still running) + logger.info(f"Capturing final status for {bot_name_for_orchestrator}") + final_status = None + try: + final_status = bots_manager.get_bot_status(bot_name_for_orchestrator) + logger.info(f"Captured final status for {bot_name_for_orchestrator}: {final_status}") + except Exception as e: + logger.warning(f"Failed to capture final status for {bot_name_for_orchestrator}: {e}") + + # Step 2: Update bot run with stopped_at timestamp and final status before stopping + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + final_status=final_status + ) + logger.info(f"Updated bot run with stopped_at timestamp and final status for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run with stopped status: {e}") + # Continue with stop process even if database update fails + + # Step 3: Mark the bot as stopping, and stop the bot trading process + bots_manager.set_bot_stopping(bot_name_for_orchestrator) + logger.info(f"Stopping bot trading process for {bot_name_for_orchestrator}") + stop_response = await bots_manager.stop_bot( + bot_name_for_orchestrator, + skip_order_cancellation=skip_order_cancellation, + async_backend=True # Always use async for background tasks + ) + + if not stop_response or not stop_response.get("success", False): + error_msg = stop_response.get('error', 'Unknown error') if stop_response else 'No response from bot orchestrator' + logger.error(f"Failed to stop bot process: {error_msg}") + return + + # Step 4: Wait for graceful shutdown (15 seconds as requested) + logger.info(f"Waiting 15 seconds for bot {bot_name} to gracefully shutdown") + await asyncio.sleep(15) + + # Step 5: Stop the container with monitoring + max_retries = 10 + retry_interval = 2 + container_stopped = False + + for i in range(max_retries): + logger.info(f"Attempting to stop container {container_name} (attempt {i+1}/{max_retries})") + docker_manager.stop_container(container_name) + + # Check if container is already stopped + container_status = docker_manager.get_container_status(container_name) + if container_status.get("state", {}).get("status") == "exited": + container_stopped = True + logger.info(f"Container {container_name} is already stopped") + break + + await asyncio.sleep(retry_interval) + + if not container_stopped: + logger.error(f"Failed to stop container {container_name} after {max_retries} attempts") + return + + # Step 6: Archive the bot data + instance_dir = os.path.join('bots', 'instances', container_name) + logger.info(f"Archiving bot data from {instance_dir}") + + try: + if archive_locally: + bot_archiver.archive_locally(container_name, instance_dir) + else: + bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) + logger.info(f"Successfully archived bot data for {container_name}") + except Exception as e: + logger.error(f"Archive failed: {str(e)}") + # Continue with removal even if archive fails + + # Step 7: Remove the container + logging.info(f"Removing container {container_name}") + remove_response = docker_manager.remove_container(container_name, force=False) + + if not remove_response.get("success"): + # If graceful remove fails, try force remove + logging.warning("Graceful container removal failed, attempting force removal") + remove_response = docker_manager.remove_container(container_name, force=True) + + if remove_response.get("success"): + logging.info(f"Successfully completed stop-and-archive for bot {bot_name}") + + # Step 8: Update bot run deployment status to ARCHIVED + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_archived(bot_name) + logger.info(f"Updated bot run deployment status to ARCHIVED for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run to archived: {e}") + else: + logging.error(f"Failed to remove container {container_name}") + + # Update bot run with error status (but keep stopped_at timestamp from earlier) + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + error_message="Failed to remove container during archive process" + ) + logger.info(f"Updated bot run with error status for {bot_name}") + except Exception as e: + logger.error(f"Failed to update bot run with error: {e}") + + except Exception as e: + logging.error(f"Error in background stop-and-archive for {bot_name}: {str(e)}") + + # Update bot run with error status + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.update_bot_run_stopped( + bot_name, + error_message=str(e) + ) + logger.info(f"Updated bot run with error status for {bot_name}") + except Exception as db_error: + logger.error(f"Failed to update bot run with error: {db_error}") + finally: + # Always clear the stopping status when the background task completes + bots_manager.clear_bot_stopping(bot_name_for_orchestrator) + logger.info(f"Cleared stopping status for bot {bot_name}") + + # Remove bot from active_bots and clear all MQTT data + if bot_name_for_orchestrator in bots_manager.active_bots: + bots_manager.mqtt_manager.clear_bot_data(bot_name_for_orchestrator) + del bots_manager.active_bots[bot_name_for_orchestrator] + logger.info(f"Removed bot {bot_name_for_orchestrator} from active_bots and cleared MQTT data") + + +@router.post("/stop-and-archive-bot/{bot_name}") +async def stop_and_archive_bot( + bot_name: str, + background_tasks: BackgroundTasks, + skip_order_cancellation: bool = True, + archive_locally: bool = True, + s3_bucket: str = None, + bots_manager: BotsOrchestrator = Depends(get_bots_orchestrator), + docker_manager: DockerService = Depends(get_docker_service), + bot_archiver: BotArchiver = Depends(get_bot_archiver), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Gracefully stop a bot and archive its data in the background. + This initiates a background task that will: + 1. Stop the bot trading process via MQTT + 2. Wait 15 seconds for graceful shutdown + 3. Monitor and stop the Docker container + 4. Archive the bot data (locally or to S3) + 5. Remove the container + + Returns immediately with a success message while the process continues in the background. + """ + try: + # Step 1: Normalize bot name and container name + # Container name is now the same as bot name (no prefix added) + actual_bot_name = bot_name + container_name = bot_name + + logging.info(f"Normalized bot_name: {actual_bot_name}, container_name: {container_name}") + + # Step 2: Validate bot exists in active bots + active_bots = list(bots_manager.active_bots.keys()) + + # Check if bot exists in active bots (could be stored as either format) + bot_found = (actual_bot_name in active_bots) or (container_name in active_bots) + + if not bot_found: + return { + "status": "error", + "message": ( + f"Bot '{actual_bot_name}' not found in active bots. " + f"Active bots: {active_bots}. Cannot perform graceful shutdown." + ), + "details": { + "input_name": bot_name, + "actual_bot_name": actual_bot_name, + "container_name": container_name, + "active_bots": active_bots, + "reason": "Bot must be actively managed via MQTT for graceful shutdown" + } + } + + # Use the format that's actually stored in active bots + bot_name_for_orchestrator = container_name if container_name in active_bots else actual_bot_name + + # Add the background task + background_tasks.add_task( + _background_stop_and_archive, + bot_name=actual_bot_name, + container_name=container_name, + bot_name_for_orchestrator=bot_name_for_orchestrator, + skip_order_cancellation=skip_order_cancellation, + archive_locally=archive_locally, + s3_bucket=s3_bucket, + bots_manager=bots_manager, + docker_manager=docker_manager, + bot_archiver=bot_archiver, + db_manager=db_manager + ) + + return { + "status": "success", + "message": f"Stop and archive process started for bot {actual_bot_name}", + "details": { + "input_name": bot_name, + "actual_bot_name": actual_bot_name, + "container_name": container_name, + "process": ( + "The bot will be gracefully stopped, archived, and removed in the background. " + "This process typically takes 20-30 seconds." + ) + } + } + + except Exception as e: + logging.error(f"Error initiating stop_and_archive_bot for {bot_name}: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/deploy-v2-controllers") +async def deploy_v2_controllers( + deployment: V2ControllerDeployment, + docker_manager: DockerService = Depends(get_docker_service), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Deploy a V2 strategy with controllers by generating the script config and creating the instance. + This endpoint simplifies the deployment process for V2 controller strategies. + + Args: + deployment: V2ControllerDeployment configuration + docker_manager: Docker service dependency + + Returns: + Dictionary with deployment response and generated configuration details + + Raises: + HTTPException: 500 if deployment fails + """ + try: + # Generate unique script config filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + script_config_filename = f"{deployment.instance_name}-{timestamp}.yml" + # Use the same name with timestamp for the instance to ensure uniqueness + unique_instance_name = f"{deployment.instance_name}-{timestamp}" + + # Ensure controller config names have .yml extension + controllers_with_extension = [] + for controller in deployment.controllers_config: + if not controller.endswith('.yml'): + controllers_with_extension.append(f"{controller}.yml") + else: + controllers_with_extension.append(controller) + + # Create the script config content + # Note: candles_config and markets removed - they're optional and empty, + # and older hummingbot versions don't expect them in the config + script_config_content = { + "script_file_name": "v2_with_controllers.py", + "controllers_config": controllers_with_extension, + } + + # Add optional drawdown parameters if provided + if deployment.max_global_drawdown_quote is not None: + script_config_content["max_global_drawdown_quote"] = deployment.max_global_drawdown_quote + if deployment.max_controller_drawdown_quote is not None: + script_config_content["max_controller_drawdown_quote"] = deployment.max_controller_drawdown_quote + + # Save the script config to the scripts directory + scripts_dir = os.path.join("conf", "scripts") + + script_config_path = os.path.join(scripts_dir, script_config_filename) + fs_util.dump_dict_to_yaml(script_config_path, script_config_content) + + logging.info(f"Generated script config: {script_config_filename} with content: {script_config_content}") + + # Set generated config on the deployment and deploy + deployment.instance_name = unique_instance_name + deployment.script_config = script_config_filename + response = docker_manager.create_hummingbot_instance(deployment) + + if response.get("success"): + response["script_config_generated"] = script_config_filename + response["controllers_deployed"] = deployment.controllers_config + response["unique_instance_name"] = unique_instance_name + + # Track bot run if deployment was successful + try: + async with db_manager.get_session_context() as session: + bot_run_repo = BotRunRepository(session) + await bot_run_repo.create_bot_run( + bot_name=unique_instance_name, + instance_name=unique_instance_name, + strategy_type="controller", + strategy_name="v2_with_controllers", + account_name=deployment.credentials_profile, + config_name=script_config_filename, + image_version=deployment.image, + deployment_config=deployment.dict() + ) + logger.info(f"Created bot run record for controller deployment {unique_instance_name}") + except Exception as e: + logger.error(f"Failed to create bot run record: {e}") + # Don't fail the deployment if bot run creation fails + + return response + + except Exception as e: + logging.error(f"Error deploying V2 controllers: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/routers/connectors.py b/routers/connectors.py new file mode 100644 index 00000000..85d7af86 --- /dev/null +++ b/routers/connectors.py @@ -0,0 +1,119 @@ +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query, Request +from hummingbot.client.settings import AllConnectorSettings + +from deps import get_accounts_service +from services.accounts_service import AccountsService +from services.market_data_service import MarketDataService + +router = APIRouter(tags=["Connectors"], prefix="/connectors") + + +@router.get("/", response_model=List[str]) +async def available_connectors(): + """ + Get a list of all available connectors. + + Returns: + List of connector names supported by the system + """ + return list(AllConnectorSettings.get_connector_settings().keys()) + + +@router.get("/{connector_name}/config-map", response_model=Dict[str, dict]) +async def get_connector_config_map(connector_name: str, accounts_service: AccountsService = Depends(get_accounts_service)): + """ + Get configuration fields required for a specific connector with type information. + + Args: + connector_name: Name of the connector to get config map for + + Returns: + Dictionary mapping field names to their type information. + Each field contains: + - type: The expected data type (e.g., "str", "SecretStr", "int") + - required: Whether the field is required + """ + return accounts_service.get_connector_config_map(connector_name) + + +@router.get("/{connector_name}/trading-rules") +async def get_trading_rules( + request: Request, + connector_name: str, + trading_pairs: Optional[List[str]] = Query(default=None, description="Filter by specific trading pairs") +): + """ + Get trading rules for a connector, optionally filtered by trading pairs. + + This endpoint uses the MarketDataService to access non-trading connector instances, + which means no authentication or account setup is required. + + Args: + request: FastAPI request object + connector_name: Name of the connector (e.g., 'binance', 'binance_perpetual') + trading_pairs: Optional list of trading pairs to filter by (e.g., ['BTC-USDT', 'ETH-USDT']) + + Returns: + Dictionary mapping trading pairs to their trading rules + + Raises: + HTTPException: 404 if connector not found, 500 for other errors + """ + try: + market_data_service: MarketDataService = request.app.state.market_data_service + + # Get trading rules (filtered by trading pairs if provided) + rules = await market_data_service.get_trading_rules(connector_name, trading_pairs) + + if "error" in rules: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found or error: {rules['error']}") + + return rules + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error retrieving trading rules: {str(e)}") + + +@router.get("/{connector_name}/order-types") +async def get_supported_order_types(request: Request, connector_name: str): + """ + Get order types supported by a specific connector. + + This endpoint uses the MarketDataService to access non-trading connector instances, + which means no authentication or account setup is required. + + Args: + request: FastAPI request object + connector_name: Name of the connector (e.g., 'binance', 'binance_perpetual') + + Returns: + List of supported order types (LIMIT, MARKET, LIMIT_MAKER) + + Raises: + HTTPException: 404 if connector not found, 500 for other errors + """ + try: + market_data_service: MarketDataService = request.app.state.market_data_service + + # Access connector through UnifiedConnectorService + # This creates a data connector if it doesn't exist + try: + connector_instance = market_data_service.connector_service.get_data_connector(connector_name) + except (KeyError, ValueError) as e: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' not found: {str(e)}") + + # Get supported order types + if hasattr(connector_instance, 'supported_order_types'): + order_types = [order_type.name for order_type in connector_instance.supported_order_types()] + return {"connector": connector_name, "supported_order_types": order_types} + else: + raise HTTPException(status_code=404, detail=f"Connector '{connector_name}' does not support order types query") + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error retrieving order types: {str(e)}") diff --git a/routers/controllers.py b/routers/controllers.py new file mode 100644 index 00000000..feec75f6 --- /dev/null +++ b/routers/controllers.py @@ -0,0 +1,392 @@ +import json +from typing import Dict, List + +import yaml +from fastapi import APIRouter, HTTPException +from starlette import status + +from models import Controller, ControllerType +from utils.file_system import fs_util + +router = APIRouter(tags=["Controllers"], prefix="/controllers") + + +@router.get("/", response_model=Dict[str, List[str]]) +async def list_controllers(): + """ + List all controllers organized by type. + + Detects both single-file controllers (controller.py) and + package-style controllers (controller/controller.py). + + Returns: + Dictionary mapping controller types to lists of controller names + """ + result = {} + for controller_type in ControllerType: + controllers = [] + type_path = f'controllers/{controller_type.value}' + + try: + # Get single-file controllers (*.py files) + files = fs_util.list_files(type_path) + controllers.extend([ + f.replace('.py', '') for f in files + if f.endswith('.py') and f != "__init__.py" + ]) + + # Get package-style controllers (folders with same-named .py file inside) + folders = fs_util.list_folders(type_path) + for folder in folders: + if folder.startswith('__') or folder == 'examples': + continue + # Check if folder contains a .py file with the same name + try: + folder_files = fs_util.list_files(f'{type_path}/{folder}') + if f'{folder}.py' in folder_files: + controllers.append(folder) + except FileNotFoundError: + pass + + result[controller_type.value] = sorted(set(controllers)) + except FileNotFoundError: + result[controller_type.value] = [] + return result + + +# Controller Configuration endpoints (must come before controller type routes) +@router.get("/configs/", response_model=List[Dict]) +async def list_controller_configs(): + """ + List all controller configurations with metadata. + + Returns: + List of controller configuration objects with name, controller_name, controller_type, and other metadata + """ + try: + config_files = [f for f in fs_util.list_files('conf/controllers') if f.endswith('.yml')] + configs = [] + + for config_file in config_files: + config_name = config_file.replace('.yml', '') + try: + config = fs_util.read_yaml_file(f"conf/controllers/{config_file}") + configs.append(config) + except Exception as e: + # If config is malformed, still include it with basic info + configs.append({ + "id": config_name, + "controller_name": "error", + "controller_type": "error", + "error": str(e) + }) + + return configs + except FileNotFoundError: + return [] + + +@router.get("/configs/{config_name}", response_model=Dict) +async def get_controller_config(config_name: str): + """ + Get controller configuration by config name. + + Args: + config_name: Name of the configuration file to retrieve + + Returns: + Dictionary with controller configuration + + Raises: + HTTPException: 404 if configuration not found + """ + try: + config = fs_util.read_yaml_file(f"conf/controllers/{config_name}.yml") + return config + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.post("/configs/{config_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_controller_config(config_name: str, config: Dict): + """ + Create or update controller configuration. + + Args: + config_name: Name of the configuration file + config: Configuration dictionary to save + + Returns: + Success message when configuration is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + yaml_content = yaml.dump(config, default_flow_style=False) + fs_util.add_file('conf/controllers', f"{config_name}.yml", yaml_content, override=True) + return {"message": f"Configuration '{config_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/configs/{config_name}") +async def delete_controller_config(config_name: str): + """ + Delete controller configuration. + + Args: + config_name: Name of the configuration file to delete + + Returns: + Success message when configuration is deleted + + Raises: + HTTPException: 404 if configuration not found + """ + try: + fs_util.delete_file('conf/controllers', f"{config_name}.yml") + return {"message": f"Configuration '{config_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.get("/{controller_type}/{controller_name}", response_model=Dict[str, str]) +async def get_controller(controller_type: ControllerType, controller_name: str): + """ + Get controller content by type and name. + + Supports both single-file controllers (controller.py) and + package-style controllers (controller/controller.py). + + Args: + controller_type: Type of the controller + controller_name: Name of the controller + + Returns: + Dictionary with controller name, type, and content + + Raises: + HTTPException: 404 if controller not found + """ + # Try single-file first, then package-style + paths_to_try = [ + f"controllers/{controller_type.value}/{controller_name}.py", + f"controllers/{controller_type.value}/{controller_name}/{controller_name}.py", + ] + + for path in paths_to_try: + try: + content = fs_util.read_file(path) + return { + "name": controller_name, + "type": controller_type.value, + "content": content + } + except FileNotFoundError: + continue + + raise HTTPException( + status_code=404, + detail=f"Controller '{controller_name}' not found in '{controller_type.value}'" + ) + + +@router.post("/{controller_type}/{controller_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_controller(controller_type: ControllerType, controller_name: str, controller: Controller): + """ + Create or update a controller. + + If controller exists as a package (folder), updates the file inside. + Otherwise creates/updates as a single file. + + Args: + controller_type: Type of controller to create/update + controller_name: Name of the controller (from URL path) + controller: Controller object with content (and optional type for validation) + + Returns: + Success message when controller is saved + + Raises: + HTTPException: 400 if controller type mismatch or save error + """ + # If type is provided in body, validate it matches URL + if controller.type is not None and controller.type != controller_type: + raise HTTPException( + status_code=400, + detail=f"Controller type mismatch: URL has '{controller_type}', body has '{controller.type}'" + ) + + try: + type_path = f'controllers/{controller_type.value}' + package_path = f'{type_path}/{controller_name}' + + # Check if controller exists as a package (folder with same-named .py file) + if fs_util.path_exists(package_path): + fs_util.add_file(package_path, f"{controller_name}.py", controller.content, override=True) + else: + fs_util.add_file(type_path, f"{controller_name}.py", controller.content, override=True) + + return {"message": f"Controller '{controller_name}' saved successfully in '{controller_type.value}'"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/{controller_type}/{controller_name}") +async def delete_controller(controller_type: ControllerType, controller_name: str): + """ + Delete a controller. + + Handles both single-file and package-style controllers. + + Args: + controller_type: Type of the controller + controller_name: Name of the controller to delete + + Returns: + Success message when controller is deleted + + Raises: + HTTPException: 404 if controller not found + """ + type_path = f'controllers/{controller_type.value}' + + # Try single-file first + try: + fs_util.delete_file(type_path, f"{controller_name}.py") + return {"message": f"Controller '{controller_name}' deleted successfully from '{controller_type.value}'"} + except FileNotFoundError: + pass + + # Try package-style (delete entire folder) + try: + fs_util.delete_folder(type_path, controller_name) + return {"message": f"Controller '{controller_name}' deleted successfully from '{controller_type.value}'"} + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Controller '{controller_name}' not found in '{controller_type.value}'" + ) + + +@router.get("/{controller_type}/{controller_name}/config/template") +async def get_controller_config_template(controller_type: ControllerType, controller_name: str): + """ + Get controller configuration template with default values. + + Args: + controller_type: Type of the controller + controller_name: Name of the controller + + Returns: + Dictionary with configuration template and default values + + Raises: + HTTPException: 404 if controller configuration class not found + """ + config_class = fs_util.load_controller_config_class(controller_type.value, controller_name) + if config_class is None: + raise HTTPException( + status_code=404, + detail=f"Controller configuration class for '{controller_name}' not found" + ) + + # Extract fields and default values + config_fields = {name: {"default": field.default, + "type": field.annotation, + "required": field.required if hasattr(field, 'required') else False, + } for name, field in config_class.model_fields.items()} + return json.loads(json.dumps(config_fields, default=str)) + + +@router.post("/{controller_type}/{controller_name}/config/validate") +async def validate_controller_config(controller_type: ControllerType, controller_name: str, config: Dict): + """ + Validate controller configuration against the controller's config class. + + Args: + controller_type: Type of the controller + controller_name: Name of the controller + config: Configuration dictionary to validate + + Returns: + Success message if configuration is valid + + Raises: + HTTPException: 400 if validation fails + """ + config_class = fs_util.load_controller_config_class(controller_type.value, controller_name) + if config_class is None: + raise HTTPException( + status_code=404, + detail=f"Controller configuration class for '{controller_name}' not found" + ) + + try: + config_class(**config) # Validate by instantiating the model + return {"message": "Configuration is valid"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +# Bot-specific controller config endpoints +@router.get("/bots/{bot_name}/configs", response_model=List[Dict]) +async def get_bot_controller_configs(bot_name: str): + """ + Get all controller configurations for a specific bot. + + Args: + bot_name: Name of the bot to get configurations for + + Returns: + List of controller configurations for the bot + + Raises: + HTTPException: 404 if bot not found + """ + bots_config_path = f"instances/{bot_name}/conf/controllers" + if not fs_util.path_exists(bots_config_path): + raise HTTPException(status_code=404, detail=f"Bot '{bot_name}' not found") + + configs = [] + for controller_file in fs_util.list_files(bots_config_path): + if controller_file.endswith('.yml'): + config = fs_util.read_yaml_file(f"{bots_config_path}/{controller_file}") + config['_config_name'] = controller_file.replace('.yml', '') + configs.append(config) + return configs + + +@router.post("/bots/{bot_name}/{controller_name}/config") +async def update_bot_controller_config(bot_name: str, controller_name: str, config: Dict): + """ + Update controller configuration for a specific bot. + + Args: + bot_name: Name of the bot + controller_name: Name of the controller to update + config: Configuration dictionary to update with + + Returns: + Success message when configuration is updated + + Raises: + HTTPException: 404 if bot or controller not found, 400 if update error + """ + bots_config_path = f"instances/{bot_name}/conf/controllers" + if not fs_util.path_exists(bots_config_path): + raise HTTPException(status_code=404, detail=f"Bot '{bot_name}' not found") + + try: + current_config = fs_util.read_yaml_file(f"{bots_config_path}/{controller_name}.yml") + current_config.update(config) + fs_util.dump_dict_to_yaml(f"{bots_config_path}/{controller_name}.yml", current_config) + return {"message": f"Controller configuration for bot '{bot_name}' updated successfully"} + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Controller configuration '{controller_name}' not found for bot '{bot_name}'" + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/docker.py b/routers/docker.py new file mode 100644 index 00000000..7b0f8287 --- /dev/null +++ b/routers/docker.py @@ -0,0 +1,192 @@ +import os + +from fastapi import APIRouter, HTTPException, Depends + +from models import DockerImage +from utils.bot_archiver import BotArchiver +from services.docker_service import DockerService +from deps import get_docker_service, get_bot_archiver + +router = APIRouter(tags=["Docker"], prefix="/docker") + + +@router.get("/running") +async def is_docker_running(docker_service: DockerService = Depends(get_docker_service)): + """ + Check if Docker daemon is running. + + Args: + docker_service: Docker service dependency + + Returns: + Dictionary indicating if Docker is running + """ + return docker_service.is_docker_running() + + +@router.get("/available-images/") +async def available_images(image_name: str = None, docker_service: DockerService = Depends(get_docker_service)): + """ + Get available Docker images matching the specified name. + + Args: + image_name: Name pattern to search for in image tags + docker_service: Docker service dependency + + Returns: + Dictionary with list of available image tags + """ + available_images = docker_service.get_available_images() + if image_name: + return [tag for image in available_images["images"] for tag in image.tags if image_name in tag] + return [tag for tag in available_images["images"]] + + +@router.get("/active-containers") +async def active_containers(name_filter: str = None, docker_service: DockerService = Depends(get_docker_service)): + """ + Get all currently active (running) Docker containers. + + Args: + name_filter: Optional filter to match container names (case-insensitive) + docker_service: Docker service dependency + + Returns: + List of active container information + """ + return docker_service.get_active_containers(name_filter) + + +@router.get("/exited-containers") +async def exited_containers(name_filter: str = None, docker_service: DockerService = Depends(get_docker_service)): + """ + Get all exited (stopped) Docker containers. + + Args: + name_filter: Optional filter to match container names (case-insensitive) + docker_service: Docker service dependency + + Returns: + List of exited container information + """ + return docker_service.get_exited_containers(name_filter) + + +@router.post("/clean-exited-containers") +async def clean_exited_containers(docker_service: DockerService = Depends(get_docker_service)): + """ + Remove all exited Docker containers to free up space. + + Args: + docker_service: Docker service dependency + + Returns: + Response from cleanup operation + """ + return docker_service.clean_exited_containers() + + +@router.post("/remove-container/{container_name}") +async def remove_container(container_name: str, archive_locally: bool = True, s3_bucket: str = None, docker_service: DockerService = Depends(get_docker_service), bot_archiver: BotArchiver = Depends(get_bot_archiver)): + """ + Remove a Hummingbot container and optionally archive its bot data. + + NOTE: This endpoint only works with Hummingbot containers (names starting with 'hummingbot-') + as it archives bot-specific data from the bots/instances directory. + + Args: + container_name: Name of the Hummingbot container to remove + archive_locally: Whether to archive data locally (default: True) + s3_bucket: S3 bucket name for cloud archiving (optional) + docker_service: Docker service dependency + bot_archiver: Bot archiver service dependency + + Returns: + Response from container removal operation + + Raises: + HTTPException: 400 if container is not a Hummingbot container + HTTPException: 500 if archiving fails + """ + # Validate that this is a Hummingbot container + if not container_name.startswith("hummingbot-"): + raise HTTPException( + status_code=400, + detail=f"This endpoint only removes Hummingbot containers. Container '{container_name}' is not a Hummingbot container." + ) + + # Remove the container + response = docker_service.remove_container(container_name) + # Form the instance directory path correctly + instance_dir = os.path.join('bots', 'instances', container_name) + try: + # Archive the data + if archive_locally: + bot_archiver.archive_locally(container_name, instance_dir) + else: + bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + return response + + +@router.post("/stop-container/{container_name}") +async def stop_container(container_name: str, docker_service: DockerService = Depends(get_docker_service)): + """ + Stop a running Docker container. + + Args: + container_name: Name of the container to stop + docker_service: Docker service dependency + + Returns: + Response from container stop operation + """ + return docker_service.stop_container(container_name) + + +@router.post("/start-container/{container_name}") +async def start_container(container_name: str, docker_service: DockerService = Depends(get_docker_service)): + """ + Start a stopped Docker container. + + Args: + container_name: Name of the container to start + docker_service: Docker service dependency + + Returns: + Response from container start operation + """ + return docker_service.start_container(container_name) + + +@router.post("/pull-image/") +async def pull_image(image: DockerImage, docker_service: DockerService = Depends(get_docker_service)): + """ + Initiate Docker image pull as background task. + Returns immediately with task status for monitoring. + + Args: + image: DockerImage object containing the image name to pull + docker_service: Docker service dependency + + Returns: + Status of the pull operation initiation + """ + result = docker_service.pull_image_async(image.image_name) + return result + + +@router.get("/pull-status/") +async def get_pull_status(docker_service: DockerService = Depends(get_docker_service)): + """ + Get status of all pull operations. + + Args: + docker_service: Docker service dependency + + Returns: + Dictionary with all pull operations and their statuses + """ + return docker_service.get_all_pull_status() diff --git a/routers/executors.py b/routers/executors.py new file mode 100644 index 00000000..11ff01d7 --- /dev/null +++ b/routers/executors.py @@ -0,0 +1,687 @@ +""" +Executor Router - REST API endpoints for dynamic executor management. + +This router enables running Hummingbot executors directly via API +without Docker containers or full strategy setup. +""" +import logging +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException +from starlette import status + +from deps import get_executor_service, get_market_data_service +from models.executors import ( + CreateExecutorRequest, + CreateExecutorResponse, + ExecutorDetailResponse, + ExecutorFilterRequest, + ExecutorLogsResponse, + ExecutorsSummaryResponse, + PerformanceReportResponse, + PositionHoldResponse, + PositionsSummaryResponse, + StopExecutorRequest, + StopExecutorResponse, +) +from models.pagination import PaginatedResponse +from services.executor_service import ExecutorService +from services.market_data_service import MarketDataService + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Executors"], prefix="/executors") + + +@router.post("/", response_model=CreateExecutorResponse, status_code=status.HTTP_201_CREATED) +async def create_executor( + request: CreateExecutorRequest, + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Create and start a new executor. + + Supported executor types: + - **position_executor**: Single position with triple barrier (stop loss, take profit, time limit) + - **grid_executor**: Grid trading with multiple levels + - **dca_executor**: Dollar-cost averaging with multiple entry points + - **twap_executor**: Time-weighted average price execution + - **arbitrage_executor**: Cross-exchange arbitrage + - **xemm_executor**: Cross-exchange market making + - **order_executor**: Simple order execution + - **lp_executor**: Liquidity provider position on CLMM DEXs (Meteora, Raydium, etc.) + + The `executor_config` must include: + - `type`: One of the executor types above + - `connector_name`: Exchange connector (e.g., "binance", "binance_perpetual") + - `trading_pair`: Trading pair (e.g., "BTC-USDT") + - Additional type-specific configuration (see /executors/types/{type}/config for details) + + Returns the created executor ID and initial status. + """ + try: + result = await executor_service.create_executor( + executor_config=request.executor_config, + account_name=request.account_name, + controller_id=request.controller_id + ) + return CreateExecutorResponse(**result) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error creating executor: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error creating executor: {str(e)}") + + +@router.post("/search", response_model=PaginatedResponse) +async def list_executors( + filter_request: ExecutorFilterRequest, + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Get list of executors with optional filtering. + + Returns active executors from memory combined with completed executors from database. + + Filters: + - `account_names`: Filter by specific accounts + - `connector_names`: Filter by connectors + - `trading_pairs`: Filter by trading pairs + - `executor_types`: Filter by executor types + - `status`: Filter by status (RUNNING, TERMINATED, etc.) + + Returns paginated list of executor summaries. + """ + try: + # Get filtered executors (active from memory + completed from DB) + executors = await executor_service.get_executors( + account_name=filter_request.account_names[0] if filter_request.account_names else None, + connector_name=filter_request.connector_names[0] if filter_request.connector_names else None, + trading_pair=filter_request.trading_pairs[0] if filter_request.trading_pairs else None, + executor_type=filter_request.executor_types[0] if filter_request.executor_types else None, + status=filter_request.status, + controller_id=filter_request.controller_ids[0] if filter_request.controller_ids else None + ) + + # Apply additional multi-value filters + if filter_request.account_names and len(filter_request.account_names) > 1: + executors = [e for e in executors if e.get("account_name") in filter_request.account_names] + if filter_request.connector_names and len(filter_request.connector_names) > 1: + executors = [e for e in executors if e.get("connector_name") in filter_request.connector_names] + if filter_request.trading_pairs and len(filter_request.trading_pairs) > 1: + executors = [e for e in executors if e.get("trading_pair") in filter_request.trading_pairs] + if filter_request.executor_types and len(filter_request.executor_types) > 1: + executors = [e for e in executors if e.get("executor_type") in filter_request.executor_types] + if filter_request.controller_ids and len(filter_request.controller_ids) > 1: + executors = [e for e in executors if e.get("controller_id") in filter_request.controller_ids] + + # Apply cursor-based pagination + start_idx = 0 + if filter_request.cursor: + for i, ex in enumerate(executors): + if ex.get("executor_id") == filter_request.cursor: + start_idx = i + 1 + break + + end_idx = start_idx + filter_request.limit + page_data = executors[start_idx:end_idx] + has_more = end_idx < len(executors) + next_cursor = page_data[-1]["executor_id"] if page_data and has_more else None + + return PaginatedResponse( + data=page_data, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(executors) + } + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error listing executors: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error listing executors: {str(e)}") + + +@router.get("/summary", response_model=ExecutorsSummaryResponse) +async def get_executors_summary( + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Get summary statistics for all executors. + + Returns aggregate information including: + - Total active/completed executor counts + - Total PnL and volume + - Breakdown by executor type, connector, and status + """ + try: + summary = executor_service.get_summary() + return ExecutorsSummaryResponse(**summary) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting executor summary: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting summary: {str(e)}") + + +@router.get("/performance", response_model=PerformanceReportResponse) +async def get_performance_report( + controller_id: Optional[str] = None, + executor_service: ExecutorService = Depends(get_executor_service), + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Get a performance report for executors. + + Aggregates metrics from all completed executors (optionally filtered by controller_id): + - Realized PnL (from completed executors, excluding POSITION_HOLD close type) + - Unrealized PnL (from active executors + position holds) + - Global PnL (realized + unrealized) + - Fees and volume totals + - Win rate and Sharpe ratio + - Breakdown by executor type + - Active position count + + Query parameters: + - **controller_id**: Filter by controller ID (omit for all controllers) + """ + try: + report = await executor_service.get_performance_report( + controller_id=controller_id, + market_data_service=market_data_service + ) + return PerformanceReportResponse(**report) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error generating performance report: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error generating performance report: {str(e)}") + + +@router.get("/{executor_id}/logs", response_model=ExecutorLogsResponse) +async def get_executor_logs( + executor_id: str, + level: Optional[str] = None, + limit: int = 50, + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Get captured log entries for a specific executor. + + Returns log entries from the in-memory ring buffer. Only available for + active executors - logs are cleared when the executor completes. + + Query parameters: + - **level**: Filter by log level (ERROR, WARNING, INFO, DEBUG) + - **limit**: Maximum entries to return (default 50) + """ + try: + all_logs = executor_service.get_executor_logs(executor_id, level=level) + total_count = len(all_logs) + limited_logs = all_logs[-limit:] if limit else all_logs + + return ExecutorLogsResponse( + executor_id=executor_id, + logs=limited_logs, + total_count=total_count, + ) + except Exception as e: + logger.error(f"Error getting logs for executor {executor_id}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting executor logs: {str(e)}") + + +@router.get("/types/available") +async def get_available_executor_types(): + """ + Get list of available executor types with descriptions. + + Returns information about each supported executor type. + """ + return { + "executor_types": [ + { + "type": "position_executor", + "description": "Single position with triple barrier (stop loss, take profit, time limit)", + "use_case": "Directional trading with risk management" + }, + { + "type": "grid_executor", + "description": "Grid trading with multiple buy/sell levels", + "use_case": "Range-bound market trading" + }, + { + "type": "dca_executor", + "description": "Dollar-cost averaging with multiple entry points", + "use_case": "Gradual position building" + }, + { + "type": "twap_executor", + "description": "Time-weighted average price execution", + "use_case": "Large order execution with minimal market impact" + }, + { + "type": "arbitrage_executor", + "description": "Cross-exchange price arbitrage", + "use_case": "Exploiting price differences between exchanges" + }, + { + "type": "xemm_executor", + "description": "Cross-exchange market making", + "use_case": "Providing liquidity across exchanges" + }, + { + "type": "order_executor", + "description": "Simple order execution with retry logic", + "use_case": "Basic order placement with reliability" + }, + { + "type": "lp_executor", + "description": "LP position management for CLMM pools (Meteora, Raydium) ", + "use_case": "Automated liquidity provision with position tracking" + } + ] + } + + +@router.get("/{executor_id}", response_model=ExecutorDetailResponse) +async def get_executor( + executor_id: str, + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Get detailed information about a specific executor. + + Checks active executors in memory first, then falls back to database for completed executors. + + Returns full executor information including: + - Current status and PnL + - Full configuration + - Executor-specific custom information + """ + try: + executor = await executor_service.get_executor(executor_id) + + if not executor: + raise HTTPException(status_code=404, detail=f"Executor {executor_id} not found") + + return ExecutorDetailResponse(**executor) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting executor {executor_id}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting executor: {str(e)}") + + +@router.post("/{executor_id}/stop", response_model=StopExecutorResponse) +async def stop_executor( + executor_id: str, + request: StopExecutorRequest, + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Stop an active executor. + + Options: + - `keep_position`: If true, keeps any open position (for position executors). + If false, the executor will attempt to close all positions before stopping. + + Returns confirmation of the stop action. + """ + try: + result = await executor_service.stop_executor( + executor_id=executor_id, + keep_position=request.keep_position + ) + return StopExecutorResponse(**result) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error stopping executor {executor_id}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error stopping executor: {str(e)}") + + +# ======================================== +# Position Hold Endpoints +# ======================================== + +@router.get("/positions/summary", response_model=PositionsSummaryResponse) +async def get_positions_summary( + controller_id: Optional[str] = None, + executor_service: ExecutorService = Depends(get_executor_service), + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Get summary of all held positions from executors stopped with keep_position=True. + + Returns aggregate information including: + - Total number of active position holds + - Total realized PnL across all positions + - Total unrealized PnL (when market rates are available) + - List of all positions with breakeven prices and PnL + + Query parameters: + - **controller_id**: Filter positions by controller ID + """ + try: + positions = executor_service.get_positions_held(controller_id=controller_id) + total_realized_pnl = sum(float(p.realized_pnl_quote) for p in positions) + total_unrealized_pnl = None + position_responses = [] + + for p in positions: + unrealized_pnl = None + parts = p.trading_pair.split("-") + if len(parts) == 2: + base, quote = parts + rate = market_data_service.get_rate(base, quote) + if rate is not None: + unrealized_pnl = float(p.get_unrealized_pnl(rate)) + if total_unrealized_pnl is None: + total_unrealized_pnl = 0.0 + total_unrealized_pnl += unrealized_pnl + + position_responses.append(PositionHoldResponse( + trading_pair=p.trading_pair, + connector_name=p.connector_name, + account_name=p.account_name, + controller_id=p.controller_id, + buy_amount_base=float(p.buy_amount_base), + buy_amount_quote=float(p.buy_amount_quote), + sell_amount_base=float(p.sell_amount_base), + sell_amount_quote=float(p.sell_amount_quote), + net_amount_base=float(p.net_amount_base), + buy_breakeven_price=float(p.buy_breakeven_price) if p.buy_breakeven_price else None, + sell_breakeven_price=float(p.sell_breakeven_price) if p.sell_breakeven_price else None, + matched_amount_base=float(p.matched_amount_base), + unmatched_amount_base=float(p.unmatched_amount_base), + position_side=p.position_side, + realized_pnl_quote=float(p.realized_pnl_quote), + unrealized_pnl_quote=unrealized_pnl, + executor_count=len(p.executor_ids), + executor_ids=p.executor_ids, + last_updated=p.last_updated.isoformat() if p.last_updated else None + )) + + return PositionsSummaryResponse( + total_positions=len(positions), + total_realized_pnl=total_realized_pnl, + total_unrealized_pnl=total_unrealized_pnl, + positions=position_responses + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting positions summary: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting positions summary: {str(e)}") + + +@router.get("/positions/{connector_name}/{trading_pair}", response_model=PositionHoldResponse) +async def get_position_held( + connector_name: str, + trading_pair: str, + account_name: str = "master_account", + controller_id: str = "main", + executor_service: ExecutorService = Depends(get_executor_service), + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Get held position for a specific connector/trading pair. + + Returns the aggregated position from executors stopped with keep_position=True, + including breakeven prices, matched/unmatched volume, realized PnL, and unrealized PnL. + + Query parameters: + - **controller_id**: Controller ID (default "main") + """ + try: + position = executor_service.get_position_held( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + controller_id=controller_id + ) + + if not position: + raise HTTPException( + status_code=404, + detail=f"No position hold found for {connector_name}/{trading_pair}" + ) + + unrealized_pnl = None + parts = trading_pair.split("-") + if len(parts) == 2: + base, quote = parts + rate = market_data_service.get_rate(base, quote) + if rate is not None: + unrealized_pnl = float(position.get_unrealized_pnl(rate)) + + return PositionHoldResponse( + trading_pair=position.trading_pair, + connector_name=position.connector_name, + account_name=position.account_name, + controller_id=position.controller_id, + buy_amount_base=float(position.buy_amount_base), + buy_amount_quote=float(position.buy_amount_quote), + sell_amount_base=float(position.sell_amount_base), + sell_amount_quote=float(position.sell_amount_quote), + net_amount_base=float(position.net_amount_base), + buy_breakeven_price=float(position.buy_breakeven_price) if position.buy_breakeven_price else None, + sell_breakeven_price=float(position.sell_breakeven_price) if position.sell_breakeven_price else None, + matched_amount_base=float(position.matched_amount_base), + unmatched_amount_base=float(position.unmatched_amount_base), + position_side=position.position_side, + realized_pnl_quote=float(position.realized_pnl_quote), + unrealized_pnl_quote=unrealized_pnl, + executor_count=len(position.executor_ids), + executor_ids=position.executor_ids, + last_updated=position.last_updated.isoformat() if position.last_updated else None + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting position: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting position: {str(e)}") + + +@router.delete("/positions/{connector_name}/{trading_pair}") +async def clear_position_held( + connector_name: str, + trading_pair: str, + account_name: str = "master_account", + controller_id: str = "main", + executor_service: ExecutorService = Depends(get_executor_service) +): + """ + Clear a held position (after manual close or full exit). + + This removes the position from tracking but preserves historical data + in completed executors. + + Query parameters: + - **controller_id**: Controller ID (default "main") + """ + try: + cleared = executor_service.clear_position_held( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + controller_id=controller_id + ) + + if not cleared: + raise HTTPException( + status_code=404, + detail=f"No position hold found for {connector_name}/{trading_pair}" + ) + + return { + "message": f"Position hold for {connector_name}/{trading_pair} cleared", + "connector_name": connector_name, + "trading_pair": trading_pair + } + except HTTPException: + raise + except Exception as e: + logger.error(f"Error clearing position: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error clearing position: {str(e)}") + + +def _extract_field_info(schema: dict, definitions: dict) -> list: + """ + Extract field information from a JSON schema. + + Returns list of field dicts with: name, type, description, required, default, constraints + """ + fields = [] + properties = schema.get("properties", {}) + required_fields = set(schema.get("required", [])) + + for field_name, field_schema in properties.items(): + # Skip internal fields + if field_name.startswith("_"): + continue + + field_info = { + "name": field_name, + "required": field_name in required_fields, + } + + # Resolve $ref if present + if "$ref" in field_schema: + ref_path = field_schema["$ref"].split("/")[-1] + if ref_path in definitions: + field_schema = {**definitions[ref_path], **field_schema} + del field_schema["$ref"] + + # Handle anyOf (usually Optional types) + if "anyOf" in field_schema: + types = [] + for option in field_schema["anyOf"]: + if "$ref" in option: + ref_name = option["$ref"].split("/")[-1] + types.append(ref_name) + elif option.get("type") == "null": + field_info["required"] = False + else: + types.append(option.get("type", "any")) + field_info["type"] = types[0] if len(types) == 1 else f"Union[{', '.join(types)}]" + elif "allOf" in field_schema: + # Handle allOf (usually inheritance) + refs = [opt["$ref"].split("/")[-1] for opt in field_schema["allOf"] if "$ref" in opt] + field_info["type"] = refs[0] if refs else "object" + elif "enum" in field_schema: + field_info["type"] = "enum" + field_info["enum_values"] = field_schema["enum"] + elif "type" in field_schema: + field_info["type"] = field_schema["type"] + else: + field_info["type"] = "any" + + # Extract description + if "description" in field_schema: + field_info["description"] = field_schema["description"] + elif "title" in field_schema: + field_info["description"] = field_schema["title"] + + # Extract default value + if "default" in field_schema: + field_info["default"] = field_schema["default"] + + # Extract constraints + constraints = {} + if "minimum" in field_schema: + constraints["minimum"] = field_schema["minimum"] + if "maximum" in field_schema: + constraints["maximum"] = field_schema["maximum"] + if "exclusiveMinimum" in field_schema: + constraints["exclusive_minimum"] = field_schema["exclusiveMinimum"] + if "exclusiveMaximum" in field_schema: + constraints["exclusive_maximum"] = field_schema["exclusiveMaximum"] + if "minLength" in field_schema: + constraints["min_length"] = field_schema["minLength"] + if "maxLength" in field_schema: + constraints["max_length"] = field_schema["maxLength"] + if "pattern" in field_schema: + constraints["pattern"] = field_schema["pattern"] + if "ge" in field_schema: + constraints["ge"] = field_schema["ge"] + if "le" in field_schema: + constraints["le"] = field_schema["le"] + if "gt" in field_schema: + constraints["gt"] = field_schema["gt"] + if "lt" in field_schema: + constraints["lt"] = field_schema["lt"] + + if constraints: + field_info["constraints"] = constraints + + fields.append(field_info) + + return fields + + +@router.get("/types/{executor_type}/config") +async def get_executor_config_schema(executor_type: str): + """ + Get configuration schema for a specific executor type. + + Returns detailed information about each configuration field including: + - **name**: Field name + - **type**: Data type (str, int, Decimal, enum, etc.) + - **description**: Field description + - **required**: Whether the field is required + - **default**: Default value if any + - **constraints**: Validation constraints (min, max, pattern, etc.) + - **enum_values**: Possible values for enum types + + Also returns nested type definitions for complex fields. + """ + from services.executor_service import ExecutorService + + if executor_type not in ExecutorService.EXECUTOR_REGISTRY: + raise HTTPException( + status_code=404, + detail=f"Unknown executor type '{executor_type}'. Valid types: {list(ExecutorService.EXECUTOR_REGISTRY.keys())}" + ) + + _, config_class = ExecutorService.EXECUTOR_REGISTRY[executor_type] + + try: + # Get JSON schema from pydantic model + schema = config_class.model_json_schema() + definitions = schema.get("$defs", {}) + + # Extract field information + fields = _extract_field_info(schema, definitions) + + # Extract nested type definitions + nested_types = {} + for def_name, def_schema in definitions.items(): + if "properties" in def_schema: + nested_types[def_name] = { + "description": def_schema.get("description", def_schema.get("title", "")), + "fields": _extract_field_info(def_schema, definitions) + } + elif "enum" in def_schema: + nested_types[def_name] = { + "type": "enum", + "values": def_schema["enum"], + "description": def_schema.get("description", def_schema.get("title", "")) + } + + return { + "executor_type": executor_type, + "config_class": config_class.__name__, + "description": schema.get("description", schema.get("title", "")), + "fields": fields, + "nested_types": nested_types + } + + except Exception as e: + logger.error(f"Error extracting config schema for {executor_type}: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Error extracting config schema: {str(e)}" + ) diff --git a/routers/gateway.py b/routers/gateway.py new file mode 100644 index 00000000..8bca7309 --- /dev/null +++ b/routers/gateway.py @@ -0,0 +1,816 @@ +from fastapi import APIRouter, HTTPException, Depends, Query +from typing import Optional, Dict, List +import re + +from models import ( + GatewayConfig, + GatewayStatus, + AddPoolRequest, + AddTokenRequest, + CreateWalletRequest, + ShowPrivateKeyRequest, + SendTransactionRequest, +) +from services.gateway_service import GatewayService +from services.accounts_service import AccountsService +from deps import get_gateway_service, get_accounts_service + +router = APIRouter(tags=["Gateway"], prefix="/gateway") + + +def camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case""" + name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower() + + +def snake_to_camel(name: str) -> str: + """ + Convert snake_case to camelCase, handling common acronyms. + + Special cases: + - url -> URL + - cu -> CU (compute units) + - id -> ID + - api -> API + - rpc -> RPC + """ + # Map of acronyms that should be uppercase + acronyms = {'url', 'cu', 'id', 'api', 'rpc', 'uri'} + + components = name.split('_') + + # Process each component + result_parts = [components[0]] # First component stays lowercase + + for component in components[1:]: + if component.lower() in acronyms: + # Uppercase acronyms + result_parts.append(component.upper()) + else: + # Title case for normal words + result_parts.append(component.title()) + + return ''.join(result_parts) + + +def normalize_gateway_response(data: Dict) -> Dict: + """ + Normalize Gateway response data to Python conventions. + - Converts camelCase to snake_case + - Maps baseSymbol -> base, quoteSymbol -> quote + - Creates trading_pair field + """ + if isinstance(data, dict): + normalized = {} + for key, value in data.items(): + # Handle special mappings + if key == "baseSymbol": + normalized["base"] = value + elif key == "quoteSymbol": + normalized["quote"] = value + else: + # Convert to snake_case + new_key = camel_to_snake(key) + # Recursively normalize nested dicts/lists + if isinstance(value, dict): + normalized[new_key] = normalize_gateway_response(value) + elif isinstance(value, list): + normalized[new_key] = [normalize_gateway_response(item) if isinstance(item, dict) else item for item in value] + else: + normalized[new_key] = value + + # Create trading_pair if we have base and quote + if "base" in normalized and "quote" in normalized: + normalized["trading_pair"] = f"{normalized['base']}-{normalized['quote']}" + + return normalized + return data + + +# ============================================ +# Container Management +# ============================================ + +@router.get("/status", response_model=GatewayStatus) +async def get_gateway_status(gateway_service: GatewayService = Depends(get_gateway_service)): + """Get Gateway container status.""" + return gateway_service.get_status() + + +@router.post("/start") +async def start_gateway( + config: GatewayConfig, + gateway_service: GatewayService = Depends(get_gateway_service) +): + """Start Gateway container.""" + result = gateway_service.start(config) + if not result["success"]: + if "already running" in result["message"]: + raise HTTPException(status_code=400, detail=result["message"]) + raise HTTPException(status_code=500, detail=result["message"]) + return result + + +@router.post("/stop") +async def stop_gateway(gateway_service: GatewayService = Depends(get_gateway_service)): + """Stop Gateway container.""" + result = gateway_service.stop() + if not result["success"]: + if "not found" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + raise HTTPException(status_code=500, detail=result["message"]) + return result + + +@router.post("/restart") +async def restart_gateway( + config: Optional[GatewayConfig] = None, + gateway_service: GatewayService = Depends(get_gateway_service) +): + """ + Restart Gateway container. + + If config is provided, the container will be removed and recreated with new configuration. + If no config is provided, the container will be stopped and started with existing configuration. + """ + result = gateway_service.restart(config) + if not result["success"]: + if "not found" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + raise HTTPException(status_code=500, detail=result["message"]) + return result + + +@router.get("/logs") +async def get_gateway_logs( + tail: int = Query(default=100, ge=1, le=10000), + gateway_service: GatewayService = Depends(get_gateway_service) +): + """Get Gateway container logs.""" + result = gateway_service.get_logs(tail) + if not result["success"]: + if "not found" in result["message"]: + raise HTTPException(status_code=404, detail=result["message"]) + raise HTTPException(status_code=500, detail=result["message"]) + return result + + +# ============================================ +# Connectors +# ============================================ + +@router.get("/connectors") +async def list_connectors(accounts_service: AccountsService = Depends(get_accounts_service)) -> Dict: + """ + List all available DEX connectors with their configurations. + + Returns connector details including name, trading types, chain, and networks. + All fields normalized to snake_case. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client._request("GET", "config/connectors") + return normalize_gateway_response(result) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error listing connectors: {str(e)}") + + +@router.get("/connectors/{connector_name}") +async def get_connector_config( + connector_name: str, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Get configuration for a specific DEX connector. + + Args: + connector_name: Connector name (e.g., 'meteora', 'raydium') + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.get_config(connector_name) + return normalize_gateway_response(result) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error getting connector config: {str(e)}") + + +@router.post("/connectors/{connector_name}") +async def update_connector_config( + connector_name: str, + config_updates: Dict, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Update configuration for a DEX connector. + + Args: + connector_name: Connector name (e.g., 'meteora', 'raydium') + config_updates: Dict with path-value pairs to update. + Keys can be in snake_case (e.g., {"slippage_pct": 0.5}) + or camelCase (e.g., {"slippagePct": 0.5}) + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + results = [] + for path, value in config_updates.items(): + # Convert snake_case to camelCase if needed + camel_path = snake_to_camel(path) if '_' in path else path + result = await accounts_service.gateway_client.update_config(connector_name, camel_path, value) + results.append(result) + + return { + "success": True, + "message": f"Updated {len(results)} config parameter(s) for {connector_name}. Restart Gateway for changes to take effect.", + "restart_required": True, + "restart_endpoint": "POST /gateway/restart", + "results": results + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error updating connector config: {str(e)}") + + +# ============================================ +# Chains (Networks) and Tokens +# ============================================ + +@router.get("/chains") +async def list_chains(accounts_service: AccountsService = Depends(get_accounts_service)) -> Dict: + """ + List all available blockchain chains and their networks. + + This also serves as the networks list endpoint. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.get_chains() + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error listing chains: {str(e)}") + + +# ============================================ +# Pools +# ============================================ + +@router.get("/pools") +async def list_pools( + connector_name: str = Query(description="DEX connector (e.g., 'meteora', 'raydium')"), + network: str = Query(description="Network (e.g., 'mainnet-beta')"), + accounts_service: AccountsService = Depends(get_accounts_service) +) -> List[Dict]: + """ + List all liquidity pools for a connector and network. + + Returns normalized data with snake_case fields and trading_pair. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + pools = await accounts_service.gateway_client.get_pools(connector_name, network) + + if not pools: + raise HTTPException(status_code=400, detail=f"No pools found for {connector_name}/{network}") + + # Normalize each pool + normalized_pools = [normalize_gateway_response(pool) for pool in pools] + return normalized_pools + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error getting pools: {str(e)}") + + +@router.post("/pools") +async def add_pool( + pool_request: AddPoolRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Add a custom liquidity pool. + + Args: + pool_request: Pool details (connector, type, network, base, quote, address) + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.add_pool( + connector=pool_request.connector_name, + pool_type=pool_request.type, + network=pool_request.network, + address=pool_request.address, + base_symbol=pool_request.base, + quote_symbol=pool_request.quote, + base_token_address=pool_request.base_address, + quote_token_address=pool_request.quote_address, + fee_pct=pool_request.fee_pct + ) + + if result is None: + raise HTTPException(status_code=502, detail="Failed to add pool: Gateway returned no response") + + if "error" in result: + status = result.get("status", 400) + raise HTTPException(status_code=status, detail=f"Failed to add pool: {result.get('error')}") + + trading_pair = f"{pool_request.base}-{pool_request.quote}" + return { + "message": f"Pool {trading_pair} added to {pool_request.connector_name}/{pool_request.network}", + "trading_pair": trading_pair + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error adding pool: {str(e)}") + + +@router.delete("/pools/{address}") +async def delete_pool( + address: str, + connector_name: str = Query(description="DEX connector (e.g., 'meteora', 'raydium', 'uniswap')"), + network: str = Query(description="Network name (e.g., 'mainnet-beta', 'mainnet')"), + pool_type: str = Query(description="Pool type (e.g., 'clmm', 'amm')"), + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Delete a liquidity pool from Gateway's pool list. + + Args: + address: Pool contract address to remove + connector_name: DEX connector (e.g., 'meteora', 'raydium', 'uniswap') + network: Network name (e.g., 'mainnet-beta', 'mainnet') + pool_type: Pool type (e.g., 'clmm', 'amm') + + Example: DELETE /gateway/pools/2sf5NYcY...?connector_name=meteora&network=mainnet-beta&pool_type=clmm + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.delete_pool( + connector=connector_name, + network=network, + pool_type=pool_type, + address=address + ) + + if result is None: + raise HTTPException(status_code=400, detail="Failed to delete pool - no response from Gateway") + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to delete pool: {result.get('error')}") + + return { + "success": True, + "message": f"Pool {address} deleted from {connector_name}/{network}", + "pool_address": address, + "connector": connector_name, + "network": network + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error deleting pool: {str(e)}") + + +# ============================================ +# Networks (Primary Endpoints) +# ============================================ + +@router.get("/networks") +async def list_networks(accounts_service: AccountsService = Depends(get_accounts_service)) -> Dict: + """ + List all available networks across all chains. + + Returns a flattened list of network IDs in the format 'chain-network'. + This is the primary interface for network discovery. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + chains_result = await accounts_service.gateway_client.get_chains() + + # Flatten chain-network combinations into network IDs + networks = [] + if "chains" in chains_result and isinstance(chains_result["chains"], list): + for chain_item in chains_result["chains"]: + chain = chain_item.get("chain") + chain_networks = chain_item.get("networks", []) + for network in chain_networks: + network_id = f"{chain}-{network}" + networks.append({ + "network_id": network_id, + "chain": chain, + "network": network + }) + + return { + "networks": networks, + "count": len(networks) + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error listing networks: {str(e)}") + + +@router.get("/networks/{network_id}") +async def get_network_config( + network_id: str, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Get configuration for a specific network. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta', 'ethereum-mainnet') + + Example: GET /gateway/networks/solana-mainnet-beta + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.get_config(network_id) + return normalize_gateway_response(result) + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error getting network config: {str(e)}") + + +@router.post("/networks/{network_id}") +async def update_network_config( + network_id: str, + config_updates: Dict, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Update configuration for a specific network. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta') + config_updates: Dict with path-value pairs to update. + Keys can be in snake_case (e.g., {"node_url": "https://..."}) + or camelCase (e.g., {"nodeURL": "https://..."}) + + Example: POST /gateway/networks/solana-mainnet-beta + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + results = [] + for path, value in config_updates.items(): + # Convert snake_case to camelCase if needed + camel_path = snake_to_camel(path) if '_' in path else path + result = await accounts_service.gateway_client.update_config(network_id, camel_path, value) + results.append(result) + + return { + "success": True, + "message": f"Updated {len(results)} config parameter(s) for {network_id}. Restart Gateway for changes to take effect.", + "restart_required": True, + "restart_endpoint": "POST /gateway/restart", + "results": results + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error updating network config: {str(e)}") + + +@router.get("/networks/{network_id}/tokens") +async def get_network_tokens( + network_id: str, + search: Optional[str] = Query(default=None), + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Get available tokens for a network. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta') + search: Filter tokens by symbol or name + + Example: GET /gateway/networks/solana-mainnet-beta/tokens?search=USDC + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id into chain and network + parts = network_id.split('-', 1) + if len(parts) != 2: + raise HTTPException(status_code=400, detail=f"Invalid network_id format. Expected 'chain-network', got '{network_id}'") + + chain, network = parts + result = await accounts_service.gateway_client.get_tokens(chain, network) + + # Apply search filter + if search and "tokens" in result: + search_lower = search.lower() + result["tokens"] = [ + token for token in result["tokens"] + if search_lower in token.get("symbol", "").lower() or + search_lower in token.get("name", "").lower() + ] + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error getting network tokens: {str(e)}") + + +@router.post("/networks/{network_id}/tokens") +async def add_network_token( + network_id: str, + token_request: AddTokenRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Add a custom token to Gateway's token list for a specific network. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta', 'ethereum-mainnet') + token_request: Token details (address, symbol, name, decimals) + + Example: POST /gateway/networks/ethereum-mainnet/tokens + { + "address": "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48", + "symbol": "USDC", + "name": "USD Coin", + "decimals": 6 + } + + Note: After adding a token, restart Gateway for changes to take effect. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id into chain and network + parts = network_id.split('-', 1) + if len(parts) != 2: + raise HTTPException(status_code=400, detail=f"Invalid network_id format. Expected 'chain-network', got '{network_id}'") + + chain, network = parts + + # Use symbol as name if name is not provided + token_name = token_request.name if token_request.name else token_request.symbol + + result = await accounts_service.gateway_client.add_token( + chain=chain, + network=network, + address=token_request.address, + symbol=token_request.symbol, + name=token_name, + decimals=token_request.decimals + ) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to add token: {result.get('error')}") + + return { + "success": True, + "message": f"Token {token_request.symbol} added to {network_id}. Restart Gateway for changes to take effect.", + "restart_required": True, + "restart_endpoint": "POST /gateway/restart", + "token": { + "symbol": token_request.symbol, + "address": token_request.address, + "network_id": network_id + } + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error adding token: {str(e)}") + + +@router.delete("/networks/{network_id}/tokens/{token_address}") +async def delete_network_token( + network_id: str, + token_address: str, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Delete a custom token from Gateway's token list for a specific network. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta', 'ethereum-mainnet') + token_address: Token contract address to delete + + Example: DELETE /gateway/networks/solana-mainnet-beta/tokens/9QFfgxdSqH5zT7j6rZb1y6SZhw2aFtcQu2r6BuYpump + + Note: After deleting a token, restart Gateway for changes to take effect. + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id into chain and network + parts = network_id.split('-', 1) + if len(parts) != 2: + raise HTTPException(status_code=400, detail=f"Invalid network_id format. Expected 'chain-network', got '{network_id}'") + + chain, network = parts + + result = await accounts_service.gateway_client.delete_token( + chain=chain, + network=network, + token_address=token_address + ) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to delete token: {result.get('error')}") + + return { + "success": True, + "message": f"Token {token_address} deleted from {network_id}. Restart Gateway for changes to take effect.", + "restart_required": True, + "restart_endpoint": "POST /gateway/restart", + "token_address": token_address, + "network_id": network_id + } + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error deleting token: {str(e)}") + + +# ============================================ +# Wallet Management +# ============================================ + +@router.post("/wallets/create") +async def create_wallet( + request: CreateWalletRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Create a new wallet in Gateway. + + Args: + request: Contains chain and set_default flag + + Returns: + Dict with address and chain of the created wallet. + + Example: POST /gateway/wallets/create + { + "chain": "solana", + "set_default": true + } + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.create_wallet( + chain=request.chain, + set_default=request.set_default + ) + + if result is None: + raise HTTPException(status_code=502, detail="Failed to create wallet: Gateway returned no response") + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to create wallet: {result.get('error')}") + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error creating wallet: {str(e)}") + + +@router.post("/wallets/show-private-key") +async def show_private_key( + request: ShowPrivateKeyRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Show private key for a wallet. + + WARNING: This endpoint exposes sensitive information. Use with caution. + + Args: + request: Contains chain, address, and passphrase + + Returns: + Dict with privateKey field. + + Example: POST /gateway/wallets/show-private-key + { + "chain": "solana", + "address": "", + "passphrase": "" + } + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.show_private_key( + chain=request.chain, + address=request.address, + passphrase=request.passphrase + ) + + if result is None: + raise HTTPException(status_code=502, detail="Failed to retrieve private key: Gateway returned no response") + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to retrieve private key: {result.get('error')}") + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error retrieving private key: {str(e)}") + + +@router.post("/wallets/send") +async def send_transaction( + request: SendTransactionRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +) -> Dict: + """ + Send a native token transaction. + + Args: + request: Contains chain, network, sender address, recipient address, and amount + + Returns: + Dict with transaction signature/hash. + + Example: POST /gateway/wallets/send + { + "chain": "solana", + "network": "mainnet-beta", + "address": "", + "to_address": "", + "amount": "0.001" + } + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + result = await accounts_service.gateway_client.send_transaction( + chain=request.chain, + network=request.network, + address=request.address, + to_address=request.to_address, + amount=request.amount + ) + + if result is None: + raise HTTPException(status_code=502, detail="Failed to send transaction: Gateway returned no response") + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Failed to send transaction: {result.get('error')}") + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error sending transaction: {str(e)}") diff --git a/routers/gateway_clmm.py b/routers/gateway_clmm.py new file mode 100644 index 00000000..72292459 --- /dev/null +++ b/routers/gateway_clmm.py @@ -0,0 +1,1494 @@ +""" +Gateway CLMM Router - Handles DEX CLMM liquidity operations via Hummingbot Gateway. +Supports CLMM connectors (Meteora, Raydium, Uniswap V3) for concentrated liquidity positions. +""" +import asyncio +import logging +from typing import List, Optional +from decimal import Decimal +import aiohttp + +from fastapi import APIRouter, Depends, HTTPException, Query + +from deps import get_accounts_service, get_database_manager +from services.accounts_service import AccountsService +from database import AsyncDatabaseManager +from database.repositories import GatewayCLMMRepository +from models import ( + CLMMOpenPositionRequest, + CLMMOpenPositionResponse, + CLMMAddLiquidityRequest, + CLMMRemoveLiquidityRequest, + CLMMClosePositionRequest, + CLMMCollectFeesRequest, + CLMMCollectFeesResponse, + CLMMPositionsOwnedRequest, + CLMMPositionInfo, + CLMMPoolInfoResponse, + CLMMPoolListItem, + CLMMPoolListResponse, + TimeBasedMetrics, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Gateway CLMM"], prefix="/gateway") + + +async def fetch_meteora_pools( + page: int = 0, + limit: int = 50, + search_term: Optional[str] = None, + sort_key: Optional[str] = "volume", + order_by: Optional[str] = "desc", + include_unknown: bool = True +) -> Optional[dict]: + """ + Fetch available pools from Meteora API. + + Args: + page: Page number (default: 0) + limit: Results per page (default: 50) + search_term: Search term to filter pools + sort_key: Sort key (tvl, volume, feetvlratio, etc.) + order_by: Sort order (asc, desc) + include_unknown: Include pools with unverified tokens + + Returns: + Dictionary with pools from Meteora API, or None if failed + """ + try: + url = "https://dlmm-api.meteora.ag/pair/all_by_groups" + params = { + "page": page, + "limit": limit, + "include_unknown": str(include_unknown).lower() # Convert boolean to lowercase string + } + + if search_term: + params["search_term"] = search_term + if sort_key: + params["sort_key"] = sort_key + if order_by: + params["order_by"] = order_by + + async with aiohttp.ClientSession() as session: + async with session.get(url, params=params, headers={"accept": "application/json"}) as response: + response.raise_for_status() + data = await response.json() + return data + except aiohttp.ClientError as e: + logger.error(f"Failed to fetch pools from Meteora API: {e}") + return None + except Exception as e: + logger.error(f"Error fetching Meteora pools: {e}", exc_info=True) + return None + + +async def fetch_raydium_pool_info(pool_address: str) -> Optional[dict]: + """ + Fetch pool info from Raydium API. + + Args: + pool_address: Pool contract address + + Returns: + Dictionary with pool info from Raydium API, or None if failed + """ + try: + url = f"https://api-v3.raydium.io/pools/info/ids?ids={pool_address}" + async with aiohttp.ClientSession() as session: + async with session.get(url, headers={"accept": "application/json"}) as response: + response.raise_for_status() + data = await response.json() + + if not data.get("success"): + logger.error(f"Raydium API returned unsuccessful response: {data}") + return None + + # Extract the first pool from the data list + pools_data = data.get("data", []) + if not pools_data: + logger.error(f"Raydium API returned empty data for pool: {pool_address}") + return None + + # Return the pool data directly (not wrapped in data key) + return pools_data[0] + except aiohttp.ClientError as e: + logger.error(f"Failed to fetch pool info from Raydium API: {e}") + return None + except Exception as e: + logger.error(f"Error fetching Raydium pool info: {e}", exc_info=True) + return None + + +def transform_raydium_to_clmm_response(raydium_data: dict, pool_address: str) -> dict: + """ + Transform Raydium API response to match Gateway's CLMMPoolInfoResponse format. + + Args: + raydium_data: Pool data from Raydium API (pools/info/ids endpoint) + pool_address: Pool contract address + + Returns: + Dictionary matching Gateway's pool info structure + """ + # Extract token info + mint_a = raydium_data.get("mintA", {}) + mint_b = raydium_data.get("mintB", {}) + + base_token_address = mint_a.get("address", "") + quote_token_address = mint_b.get("address", "") + + # Get current price + current_price = Decimal(str(raydium_data.get("price", 0))) + + # Get token amounts + base_amount = Decimal(str(raydium_data.get("mintAmountA", 0))) + quote_amount = Decimal(str(raydium_data.get("mintAmountB", 0))) + + # Get fee rate (convert from decimal to percentage, e.g., 0.0025 -> 0.25%) + fee_rate = raydium_data.get("feeRate", 0.0025) + fee_pct = Decimal(str(fee_rate * 100)) + + # Check if this is a CLMM (Concentrated) pool + pool_type = raydium_data.get("type", "Standard") + is_clmm = pool_type == "Concentrated" + + # Return in Gateway-compatible format + return { + "address": pool_address, + "baseTokenAddress": base_token_address, + "quoteTokenAddress": quote_token_address, + "binStep": 1 if is_clmm else None, # CLMM pools have tick spacing + "feePct": fee_pct, + "price": current_price, + "baseTokenAmount": base_amount, + "quoteTokenAmount": quote_amount, + "activeBinId": None, # Not available from this endpoint + "dynamicFeePct": None, + "minBinId": None, + "maxBinId": None, + "bins": [] # Bin data not available from pool info endpoint + } + + +def get_transaction_status_from_response(gateway_response: dict) -> str: + """ + Determine transaction status from Gateway response. + + Gateway returns status field in the response: + - status: 1 = confirmed + - status: 0 = pending/submitted + + Returns: + "CONFIRMED" if status == 1 + "SUBMITTED" if status == 0 or not present + """ + status = gateway_response.get("status") + + # Status 1 means transaction is confirmed on-chain + if status == 1: + return "CONFIRMED" + + # Status 0 or missing means submitted but not confirmed yet + return "SUBMITTED" + + +def get_native_gas_token(chain: str) -> str: + """ + Get the native gas token symbol for a blockchain. + + Args: + chain: Blockchain name (e.g., 'solana', 'ethereum', 'polygon') + + Returns: + Gas token symbol (e.g., 'SOL', 'ETH', 'MATIC') + """ + gas_token_map = { + "solana": "SOL", + "ethereum": "ETH", + "polygon": "MATIC", + "avalanche": "AVAX", + "optimism": "ETH", + "arbitrum": "ETH", + "base": "ETH", + "bsc": "BNB", + "cronos": "CRO", + } + return gas_token_map.get(chain.lower(), "UNKNOWN") + + +async def _refresh_position_data(position, accounts_service: AccountsService, clmm_repo: GatewayCLMMRepository): + """ + Refresh position data from Gateway and update database. + + This updates: + - in_range status + - liquidity amounts + - pending fees + - position status (if closed externally) + """ + try: + # Get wallet address for the position + wallet_address = position.wallet_address + + # Get all positions for this pool and find our specific position + try: + positions_list = await accounts_service.gateway_client.clmm_positions_owned( + connector=position.connector, + chain_network=position.network, # position.network is already in 'chain-network' format + wallet_address=wallet_address, + pool_address=position.pool_address + ) + + # Find our specific position in the list + result = None + if isinstance(positions_list, list): + for pos in positions_list: + if pos.get("address") == position.position_address: + result = pos + break + + # If position not found, it was closed externally + if result is None: + logger.info(f"Position {position.position_address} not found on Gateway, marking as CLOSED") + await clmm_repo.close_position(position.position_address) + return + + except Exception as e: + # If we can't fetch positions, log error but don't mark as closed + logger.error(f"Error fetching position from Gateway: {e}") + return + + # Extract current state + current_price = Decimal(str(result.get("price", 0))) + lower_price = Decimal(str(result.get("lowerPrice", 0))) if result.get("lowerPrice") else Decimal("0") + upper_price = Decimal(str(result.get("upperPrice", 0))) if result.get("upperPrice") else Decimal("0") + + # Calculate in_range status + in_range = "UNKNOWN" + if current_price > 0 and lower_price > 0 and upper_price > 0: + if lower_price <= current_price <= upper_price: + in_range = "IN_RANGE" + else: + in_range = "OUT_OF_RANGE" + + # Extract token amounts + base_token_amount = Decimal(str(result.get("baseTokenAmount", 0))) + quote_token_amount = Decimal(str(result.get("quoteTokenAmount", 0))) + + # Check if position has been closed (zero liquidity) + if base_token_amount == 0 and quote_token_amount == 0: + logger.info(f"Position {position.position_address} has zero liquidity, marking as CLOSED") + await clmm_repo.close_position(position.position_address) + return + + # Update liquidity amounts, in_range status, and current price + await clmm_repo.update_position_liquidity( + position_address=position.position_address, + base_token_amount=base_token_amount, + quote_token_amount=quote_token_amount, + in_range=in_range, + current_price=current_price + ) + + # Update pending fees if available + base_fee_pending = Decimal(str(result.get("baseFeeAmount", 0))) + quote_fee_pending = Decimal(str(result.get("quoteFeeAmount", 0))) + + if base_fee_pending or quote_fee_pending: + await clmm_repo.update_position_fees( + position_address=position.position_address, + base_fee_pending=base_fee_pending, + quote_fee_pending=quote_fee_pending + ) + + logger.debug(f"Refreshed position {position.position_address}: price={current_price}, in_range={in_range}, " + f"base={base_token_amount}, quote={quote_token_amount}") + + except Exception as e: + logger.error(f"Error refreshing position {position.position_address}: {e}", exc_info=True) + raise + + +@router.get("/clmm/pool-info", response_model=CLMMPoolInfoResponse, response_model_by_alias=False) +async def get_clmm_pool_info( + connector: str, + network: str, + pool_address: str, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get detailed information about a CLMM pool by pool address. + + Args: + connector: CLMM connector (e.g., 'meteora', 'raydium') + network: Network ID in 'chain-network' format (e.g., 'solana-mainnet-beta') + pool_address: Pool contract address + + Example: + GET /gateway/clmm/pool-info?connector=meteora&network=solana-mainnet-beta&pool_address=2sf5NYcY4zUPXUSmG6f66mskb24t5F8S11pC1Nz5nQT3 + + Returns: + Pool information including liquidity, price, bins (for Meteora), etc. + All field names are returned in snake_case format. + + Note: + For Raydium connector, uses Raydium API directly instead of Gateway. + """ + try: + # Special handling for Raydium - use Raydium API directly (not Gateway) + if connector.lower() == "raydium": + logger.info(f"Using Raydium API directly for pool info: {pool_address}") + + # Fetch from Raydium API + raydium_data = await fetch_raydium_pool_info(pool_address) + if raydium_data is None: + raise HTTPException(status_code=503, detail="Failed to get pool info from Raydium API") + + # Check if this is a CLMM pool - Standard AMM pools are not supported on this endpoint + pool_type = raydium_data.get("type", "Standard") + if pool_type != "Concentrated": + raise HTTPException( + status_code=400, + detail=f"Pool {pool_address} is a Raydium {pool_type} AMM pool, not a CLMM pool. " + f"This endpoint only supports Concentrated Liquidity (CLMM) pools." + ) + + # Transform to Gateway-compatible format + result = transform_raydium_to_clmm_response(raydium_data, pool_address) + + # Parse into response model + return CLMMPoolInfoResponse(**result) + + # Default behavior for other connectors: use Gateway + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network_name = accounts_service.gateway_client.parse_network_id(network) + + # Get pool info from Gateway using the CLMM-specific endpoint + result = await accounts_service.gateway_client.clmm_pool_info( + connector=connector, + network=network_name, + pool_address=pool_address + ) + + if result is None: + raise HTTPException(status_code=503, detail="Failed to get pool info from Gateway") + + # Parse the camelCase Gateway response into snake_case Pydantic model + # The model's aliases will handle the conversion + return CLMMPoolInfoResponse(**result) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error getting CLMM pool info: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting CLMM pool info: {str(e)}") + + +@router.get("/clmm/pools", response_model=CLMMPoolListResponse) +async def get_clmm_pools( + connector: str, + page: int = Query(0, ge=0, description="Page number"), + limit: int = Query(50, ge=1, le=100, description="Results per page (max 100)"), + search_term: Optional[str] = Query(None, description="Search term to filter pools"), + sort_key: Optional[str] = Query("volume", description="Sort key (volume, tvl, etc.)"), + order_by: Optional[str] = Query("desc", description="Sort order (asc, desc)"), + include_unknown: bool = Query(True, description="Include pools with unverified tokens") +): + """ + Get list of available CLMM pools for a connector. + + Currently supports: meteora + + Args: + connector: CLMM connector (e.g., 'meteora') + page: Page number (default: 0) + limit: Results per page (default: 50, max: 100) + search_term: Search term to filter pools (optional) + sort_key: Sort by field (volume, tvl, feetvlratio, etc.) + order_by: Sort order (asc, desc) + include_unknown: Include pools with unverified tokens + + Example: + GET /gateway/clmm/pools?connector=meteora&search_term=SOL&limit=20 + + Returns: + List of available pools with trading pairs, addresses, liquidity, volume, APR, etc. + """ + try: + # Only support Meteora for now + if connector.lower() != "meteora": + raise HTTPException( + status_code=400, + detail=f"Pool listing not supported for connector '{connector}'. Currently only 'meteora' is supported." + ) + + # Fetch pools from Meteora API + logger.info(f"Fetching pools from Meteora API (page={page}, limit={limit}, search={search_term})") + meteora_data = await fetch_meteora_pools( + page=page, + limit=limit, + search_term=search_term, + sort_key=sort_key, + order_by=order_by, + include_unknown=include_unknown + ) + + if meteora_data is None: + raise HTTPException(status_code=503, detail="Failed to fetch pools from Meteora API") + + # Transform Meteora response to our format + pools = [] + groups = meteora_data.get("groups", []) + + for group in groups: + pairs = group.get("pairs", []) + for pair in pairs: + # Extract trading pair from name or construct from mints + name = pair.get("name", "") + trading_pair = name if name else f"{pair.get('mint_x', '')[:8]}-{pair.get('mint_y', '')[:8]}" + + # Helper function to safely convert dict metrics to TimeBasedMetrics + def to_time_metrics(data): + if not data: + return None + return TimeBasedMetrics( + min_30=Decimal(str(data.get("min_30"))) if data.get("min_30") is not None else None, + hour_1=Decimal(str(data.get("hour_1"))) if data.get("hour_1") is not None else None, + hour_2=Decimal(str(data.get("hour_2"))) if data.get("hour_2") is not None else None, + hour_4=Decimal(str(data.get("hour_4"))) if data.get("hour_4") is not None else None, + hour_12=Decimal(str(data.get("hour_12"))) if data.get("hour_12") is not None else None, + hour_24=Decimal(str(data.get("hour_24"))) if data.get("hour_24") is not None else None + ) + + pools.append(CLMMPoolListItem( + address=pair.get("address", ""), + name=name, + trading_pair=trading_pair, + mint_x=pair.get("mint_x", ""), + mint_y=pair.get("mint_y", ""), + bin_step=pair.get("bin_step", 0), + current_price=Decimal(str(pair.get("current_price", 0))), + liquidity=pair.get("liquidity", "0"), + reserve_x=pair.get("reserve_x", "0"), + reserve_y=pair.get("reserve_y", "0"), + reserve_x_amount=Decimal(str(pair.get("reserve_x_amount"))) if pair.get("reserve_x_amount") is not None else None, + reserve_y_amount=Decimal(str(pair.get("reserve_y_amount"))) if pair.get("reserve_y_amount") is not None else None, + + # Fee structure + base_fee_percentage=pair.get("base_fee_percentage"), + max_fee_percentage=pair.get("max_fee_percentage"), + protocol_fee_percentage=pair.get("protocol_fee_percentage"), + + # APR/APY + apr=Decimal(str(pair.get("apr", 0))) if pair.get("apr") is not None else None, + apy=Decimal(str(pair.get("apy", 0))) if pair.get("apy") is not None else None, + farm_apr=Decimal(str(pair.get("farm_apr"))) if pair.get("farm_apr") is not None else None, + farm_apy=Decimal(str(pair.get("farm_apy"))) if pair.get("farm_apy") is not None else None, + + # Volume and fees + volume_24h=Decimal(str(pair.get("trade_volume_24h", 0))) if pair.get("trade_volume_24h") is not None else None, + fees_24h=Decimal(str(pair.get("fees_24h", 0))) if pair.get("fees_24h") is not None else None, + today_fees=Decimal(str(pair.get("today_fees"))) if pair.get("today_fees") is not None else None, + cumulative_trade_volume=pair.get("cumulative_trade_volume"), + cumulative_fee_volume=pair.get("cumulative_fee_volume"), + + # Time-based metrics + volume=to_time_metrics(pair.get("volume")), + fees=to_time_metrics(pair.get("fees")), + fee_tvl_ratio=to_time_metrics(pair.get("fee_tvl_ratio")), + + # Rewards + reward_mint_x=pair.get("reward_mint_x"), + reward_mint_y=pair.get("reward_mint_y"), + + # Metadata + tags=pair.get("tags"), + is_verified=pair.get("is_verified", False), + is_blacklisted=pair.get("is_blacklisted"), + hide=pair.get("hide"), + launchpad=pair.get("launchpad") + )) + + total = meteora_data.get("total", len(pools)) + + return CLMMPoolListResponse( + pools=pools, + total=total, + page=page, + limit=limit + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting CLMM pools: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting CLMM pools: {str(e)}") + + +@router.post("/clmm/open", response_model=CLMMOpenPositionResponse) +async def open_clmm_position( + request: CLMMOpenPositionRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Open a NEW CLMM position with initial liquidity. + + Example: + connector: 'meteora' + network: 'solana-mainnet-beta' + pool_address: '2sf5NYcY4zUPXUSmG6f66mskb24t5F8S11pC1Nz5nQT3' + lower_price: 150 + upper_price: 250 + base_token_amount: 0.01 + quote_token_amount: 2 + slippage_pct: 1 + wallet_address: (optional) + extra_params: {"strategyType": 0} # Meteora-specific + + Returns: + Transaction hash and position address + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Get wallet address + wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( + chain=chain, + wallet_address=request.wallet_address + ) + + # Get pool info to extract trading pair for database + pool_info = await accounts_service.gateway_client.clmm_pool_info( + connector=request.connector, + network=network, + pool_address=request.pool_address + ) + + # Extract tokens from pool info + base_token_address = pool_info.get("baseTokenAddress", "") + quote_token_address = pool_info.get("quoteTokenAddress", "") + + # Extract entry price from pool info (current pool price at time of opening) + entry_price = float(pool_info.get("price", 0)) if pool_info.get("price") else None + if entry_price: + logger.info(f"Entry price for position: {entry_price}") + + # Store full token addresses in the database + base = base_token_address if base_token_address else "UNKNOWN" + quote = quote_token_address if quote_token_address else "UNKNOWN" + trading_pair = f"{base}-{quote}" + + # Open position + result = await accounts_service.gateway_client.clmm_open_position( + connector=request.connector, + network=network, + wallet_address=wallet_address, + pool_address=request.pool_address, + lower_price=float(request.lower_price), + upper_price=float(request.upper_price), + base_token_amount=float(request.base_token_amount) if request.base_token_amount else None, + quote_token_amount=float(request.quote_token_amount) if request.quote_token_amount else None, + slippage_pct=float(request.slippage_pct) if request.slippage_pct else 1.0, + extra_params=request.extra_params + ) + if not result: + raise HTTPException(status_code=404, detail=f"Failed to open CLMM position: {trading_pair}") + + transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") + + # Position address can be at root level or nested in data object + data = result.get("data", {}) + position_address = result.get("positionAddress") or result.get("position") or data.get("positionAddress") or data.get("position") + + # Extract position rent (SOL locked for position NFT) + position_rent = data.get("positionRent") + if position_rent: + logger.info(f"Position rent: {position_rent} SOL") + + if not transaction_hash: + raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") + if not position_address: + raise HTTPException(status_code=500, detail="No position address returned from Gateway") + + # Calculate percentage: (upper_price - lower_price) / lower_price + percentage = None + if request.lower_price and request.upper_price and request.lower_price > 0: + percentage = float((request.upper_price - request.lower_price) / request.lower_price) + logger.info(f"Position price range percentage: {percentage:.4f} ({percentage*100:.2f}%)") + + # Get transaction status from Gateway response + tx_status = get_transaction_status_from_response(result) + + # Extract gas fee from Gateway response + gas_fee = data.get("fee") + gas_token = get_native_gas_token(chain) + + # Store position and event in database + try: + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + + # Create position record + position_data = { + "position_address": position_address, + "pool_address": request.pool_address, + "network": request.network, + "connector": request.connector, + "wallet_address": wallet_address, + "trading_pair": trading_pair, + "base_token": base, + "quote_token": quote, + "status": "OPEN", + "lower_price": float(request.lower_price), + "upper_price": float(request.upper_price), + "percentage": percentage, + "entry_price": entry_price, # Pool price when position opened + "current_price": entry_price, # Same as entry at open time, updated by poller + "initial_base_token_amount": float(request.base_token_amount) if request.base_token_amount else 0, + "initial_quote_token_amount": float(request.quote_token_amount) if request.quote_token_amount else 0, + "position_rent": float(position_rent) if position_rent else None, + "base_token_amount": float(request.base_token_amount) if request.base_token_amount else 0, + "quote_token_amount": float(request.quote_token_amount) if request.quote_token_amount else 0, + "in_range": "UNKNOWN" # Will be updated by poller + } + + position = await clmm_repo.create_position(position_data) + logger.info(f"Recorded CLMM position in database: {position_address}") + + # Create OPEN event with polled status + event_data = { + "position_id": position.id, + "transaction_hash": transaction_hash, + "event_type": "OPEN", + "base_token_amount": float(request.base_token_amount) if request.base_token_amount else None, + "quote_token_amount": float(request.quote_token_amount) if request.quote_token_amount else None, + "gas_fee": float(gas_fee) if gas_fee else None, + "gas_token": gas_token, + "status": tx_status + } + + await clmm_repo.create_event(event_data) + logger.info(f"Recorded CLMM OPEN event in database: {transaction_hash} (status: {tx_status}, gas: {gas_fee} {gas_token})") + except Exception as db_error: + # Log but don't fail the operation - it was submitted successfully + logger.error(f"Error recording CLMM position in database: {db_error}", exc_info=True) + + return CLMMOpenPositionResponse( + transaction_hash=transaction_hash, + position_address=position_address, + trading_pair=trading_pair, + pool_address=request.pool_address, + lower_price=request.lower_price, + upper_price=request.upper_price, + status="submitted" + ) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error opening CLMM position: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error opening CLMM position: {str(e)}") + + +# @router.post("/clmm/add") +# async def add_liquidity_to_clmm_position( +# request: CLMMAddLiquidityRequest, +# accounts_service: AccountsService = Depends(get_accounts_service), +# db_manager: AsyncDatabaseManager = Depends(get_database_manager) +# ): +# """ +# Add MORE liquidity to an EXISTING CLMM position. +# +# Example: +# connector: 'meteora' +# network: 'solana-mainnet-beta' +# position_address: '...' +# base_token_amount: 0.5 +# quote_token_amount: 50.0 +# slippage_pct: 1 +# wallet_address: (optional) +# +# Returns: +# Transaction hash +# """ +# try: +# if not await accounts_service.gateway_client.ping(): +# raise HTTPException(status_code=503, detail="Gateway service is not available") +# +# # Parse network_id +# chain, network = accounts_service.gateway_client.parse_network_id(request.network) +# +# # Get wallet address +# wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( +# chain=chain, +# wallet_address=request.wallet_address +# ) +# +# # Add liquidity to existing position +# result = await accounts_service.gateway_client.clmm_add_liquidity( +# connector=request.connector, +# network=network, +# wallet_address=wallet_address, +# position_address=request.position_address, +# base_token_amount=float(request.base_token_amount) if request.base_token_amount else None, +# quote_token_amount=float(request.quote_token_amount) if request.quote_token_amount else None, +# slippage_pct=float(request.slippage_pct) if request.slippage_pct else 1.0 +# ) +# +# transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") +# if not transaction_hash: +# raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") +# +# # Get transaction status from Gateway response +# tx_status = get_transaction_status_from_response(result) +# +# # Extract gas fee from Gateway response +# data = result.get("data", {}) +# gas_fee = data.get("fee") +# gas_token = "SOL" if chain == "solana" else "ETH" if chain == "ethereum" else None +# +# # Store ADD_LIQUIDITY event in database +# try: +# async with db_manager.get_session_context() as session: +# clmm_repo = GatewayCLMMRepository(session) +# +# # Get position to link event +# position = await clmm_repo.get_position_by_address(request.position_address) +# if position: +# event_data = { +# "position_id": position.id, +# "transaction_hash": transaction_hash, +# "event_type": "ADD_LIQUIDITY", +# "base_token_amount": float(request.base_token_amount) if request.base_token_amount else None, +# "quote_token_amount": float(request.quote_token_amount) if request.quote_token_amount else None, +# "gas_fee": float(gas_fee) if gas_fee else None, +# "gas_token": gas_token, +# "status": tx_status +# } +# await clmm_repo.create_event(event_data) +# logger.info(f"Recorded CLMM ADD_LIQUIDITY event: {transaction_hash} (status: {tx_status}, gas: {gas_fee} {gas_token})") +# except Exception as db_error: +# logger.error(f"Error recording ADD_LIQUIDITY event: {db_error}", exc_info=True) +# +# return { +# "transaction_hash": transaction_hash, +# "position_address": request.position_address, +# "status": "submitted" +# } +# +# except HTTPException: +# raise +# except ValueError as e: +# raise HTTPException(status_code=400, detail=str(e)) +# except Exception as e: +# logger.error(f"Error adding liquidity to CLMM position: {e}", exc_info=True) +# raise HTTPException(status_code=500, detail=f"Error adding liquidity to CLMM position: {str(e)}") +# +# +# @router.post("/clmm/remove") +# async def remove_liquidity_from_clmm_position( +# request: CLMMRemoveLiquidityRequest, +# accounts_service: AccountsService = Depends(get_accounts_service), +# db_manager: AsyncDatabaseManager = Depends(get_database_manager) +# ): +# """ +# Remove SOME liquidity from a CLMM position (partial removal). +# +# Example: +# connector: 'meteora' +# network: 'solana-mainnet-beta' +# position_address: '...' +# percentage: 50 +# wallet_address: (optional) +# +# Returns: +# Transaction hash +# """ +# try: +# if not await accounts_service.gateway_client.ping(): +# raise HTTPException(status_code=503, detail="Gateway service is not available") +# +# # Parse network_id +# chain, network = accounts_service.gateway_client.parse_network_id(request.network) +# +# # Get wallet address +# wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( +# chain=chain, +# wallet_address=request.wallet_address +# ) +# +# # Remove liquidity +# result = await accounts_service.gateway_client.clmm_remove_liquidity( +# connector=request.connector, +# network=network, +# wallet_address=wallet_address, +# position_address=request.position_address, +# percentage=float(request.percentage) +# ) +# +# transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") +# if not transaction_hash: +# raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") +# +# # Get transaction status from Gateway response +# tx_status = get_transaction_status_from_response(result) +# +# # Extract gas fee from Gateway response +# data = result.get("data", {}) +# gas_fee = data.get("fee") +# gas_token = "SOL" if chain == "solana" else "ETH" if chain == "ethereum" else None +# +# # Store REMOVE_LIQUIDITY event in database +# try: +# async with db_manager.get_session_context() as session: +# clmm_repo = GatewayCLMMRepository(session) +# +# # Get position to link event +# position = await clmm_repo.get_position_by_address(request.position_address) +# if position: +# event_data = { +# "position_id": position.id, +# "transaction_hash": transaction_hash, +# "event_type": "REMOVE_LIQUIDITY", +# "percentage": float(request.percentage), +# "gas_fee": float(gas_fee) if gas_fee else None, +# "gas_token": gas_token, +# "status": tx_status +# } +# await clmm_repo.create_event(event_data) +# logger.info(f"Recorded CLMM REMOVE_LIQUIDITY event: {transaction_hash} (status: {tx_status}, gas: {gas_fee} {gas_token})") +# except Exception as db_error: +# logger.error(f"Error recording REMOVE_LIQUIDITY event: {db_error}", exc_info=True) +# +# return { +# "transaction_hash": transaction_hash, +# "position_address": request.position_address, +# "percentage": float(request.percentage), +# "status": "submitted" +# } +# +# except HTTPException: +# raise +# except ValueError as e: +# raise HTTPException(status_code=400, detail=str(e)) +# except Exception as e: +# logger.error(f"Error removing liquidity from CLMM position: {e}", exc_info=True) +# raise HTTPException(status_code=500, detail=f"Error removing liquidity from CLMM position: {str(e)}") +# + +@router.post("/clmm/close", response_model=CLMMCollectFeesResponse) +async def close_clmm_position( + request: CLMMClosePositionRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + CLOSE a CLMM position completely (removes all liquidity and collects pending fees). + + Example: + connector: 'meteora' + network: 'solana-mainnet-beta' + position_address: '...' + wallet_address: (optional) + + Returns: + Transaction hash and collected fee amounts + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Get pool_address and wallet_address from database + pool_address = None + wallet_address = None + + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + db_position = await clmm_repo.get_position_by_address(request.position_address) + if db_position: + pool_address = db_position.pool_address + wallet_address = db_position.wallet_address + + # If not in database, use default wallet + if not wallet_address: + wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( + chain=chain, + wallet_address=request.wallet_address + ) + + # If no pool_address from database, we can't query Gateway + if not pool_address: + raise HTTPException( + status_code=404, + detail=f"Position {request.position_address} not found in database. Pool address is required." + ) + + # Fetch pending fees and current price BEFORE closing (Gateway doesn't always return these in response) + base_fee_to_collect = Decimal("0") + quote_fee_to_collect = Decimal("0") + close_price = None + + try: + positions_list = await accounts_service.gateway_client.clmm_positions_owned( + connector=request.connector, + chain_network=request.network, # request.network is already in 'chain-network' format + wallet_address=wallet_address, + pool_address=pool_address + ) + + # Find our specific position and get pending fees and current price + if positions_list and isinstance(positions_list, list): + for pos in positions_list: + if pos and pos.get("address") == request.position_address: + base_fee_to_collect = Decimal(str(pos.get("baseFeeAmount", 0))) + quote_fee_to_collect = Decimal(str(pos.get("quoteFeeAmount", 0))) + close_price = float(pos.get("price", 0)) if pos.get("price") else None + logger.info(f"Before closing: price={close_price}, pending fees base={base_fee_to_collect}, quote={quote_fee_to_collect}") + break + else: + logger.warning(f"Could not find position {request.position_address} in positions_owned response") + except Exception as e: + logger.warning(f"Could not fetch position state before closing: {e}", exc_info=True) + + # Close position + result = await accounts_service.gateway_client.clmm_close_position( + connector=request.connector, + network=network, + wallet_address=wallet_address, + position_address=request.position_address + ) + + transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") + if not transaction_hash: + raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") + + # Get transaction status from Gateway response + tx_status = get_transaction_status_from_response(result) + + # Extract gas fee from Gateway response + data = result.get("data", {}) + gas_fee = data.get("fee") + gas_token = get_native_gas_token(chain) + + # Try to extract collected amounts from Gateway response, fallback to pre-fetched amounts + base_fee_from_response = data.get("baseFeeAmountCollected") + quote_fee_from_response = data.get("quoteFeeAmountCollected") + + # Use response values if available, otherwise use pre-fetched values + base_fee_collected = Decimal(str(base_fee_from_response)) if base_fee_from_response is not None else base_fee_to_collect + quote_fee_collected = Decimal(str(quote_fee_from_response)) if quote_fee_from_response is not None else quote_fee_to_collect + + logger.info(f"Collected fees on close: base={base_fee_collected}, quote={quote_fee_collected}") + + # Store CLOSE event in database and update position + try: + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + + # Get position to link event + position = await clmm_repo.get_position_by_address(request.position_address) + if position: + # Create event record + event_data = { + "position_id": position.id, + "transaction_hash": transaction_hash, + "event_type": "CLOSE", + "base_fee_collected": float(base_fee_collected) if base_fee_collected else None, + "quote_fee_collected": float(quote_fee_collected) if quote_fee_collected else None, + "gas_fee": float(gas_fee) if gas_fee else None, + "gas_token": gas_token, + "status": tx_status + } + await clmm_repo.create_event(event_data) + logger.info(f"Recorded CLMM CLOSE event: {transaction_hash} (status: {tx_status}, gas: {gas_fee} {gas_token})") + + # Update position: add to collected, reset pending to 0, mark as CLOSED + new_base_collected = Decimal(str(position.base_fee_collected)) + base_fee_collected + new_quote_collected = Decimal(str(position.quote_fee_collected)) + quote_fee_collected + + await clmm_repo.update_position_fees( + position_address=request.position_address, + base_fee_collected=new_base_collected, + quote_fee_collected=new_quote_collected, + base_fee_pending=Decimal("0"), + quote_fee_pending=Decimal("0") + ) + + # Update current_price with close price + if close_price: + await clmm_repo.update_position_liquidity( + position_address=request.position_address, + base_token_amount=Decimal(str(position.base_token_amount)), + quote_token_amount=Decimal(str(position.quote_token_amount)), + current_price=Decimal(str(close_price)) + ) + + # Verify position is actually closed by checking if it still exists on Gateway + # Gateway returns 500 (or 404) when position doesn't exist + try: + await asyncio.sleep(2) # Wait for transaction to propagate + + verify_result = await accounts_service.gateway_client.clmm_position_info( + connector=request.connector, + chain_network=request.network, + position_address=request.position_address + ) + + # If we get an error response (404 or 500), position is closed + if verify_result and isinstance(verify_result, dict) and "error" in verify_result: + status_code = verify_result.get("status") + if status_code in (404, 500): + await clmm_repo.close_position(request.position_address) + logger.info(f"Position {request.position_address} verified as closed (Gateway returned {status_code})") + else: + logger.warning(f"Unexpected error verifying position close: {verify_result}") + elif verify_result and "address" in verify_result: + # Position still exists - might be a failed close or delayed propagation + logger.warning(f"Position {request.position_address} still exists after close transaction. Will be handled by poller.") + else: + logger.debug(f"Could not verify position close status, will be handled by poller") + + except Exception as verify_error: + logger.warning(f"Error verifying position close: {verify_error}. Will be handled by poller.") + + logger.info(f"Updated position {request.position_address}: collected fees updated, pending fees reset to 0.") + except Exception as db_error: + logger.error(f"Error recording CLOSE event: {db_error}", exc_info=True) + + return CLMMCollectFeesResponse( + transaction_hash=transaction_hash, + position_address=request.position_address, + base_fee_collected=Decimal(str(base_fee_collected)) if base_fee_collected else None, + quote_fee_collected=Decimal(str(quote_fee_collected)) if quote_fee_collected else None, + status="submitted" + ) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error closing CLMM position: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error closing CLMM position: {str(e)}") + + +@router.post("/clmm/collect-fees", response_model=CLMMCollectFeesResponse) +async def collect_fees_from_clmm_position( + request: CLMMCollectFeesRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Collect accumulated fees from a CLMM liquidity position. + + Example: + connector: 'meteora' + network: 'solana-mainnet-beta' + position_address: '...' + wallet_address: (optional) + + Returns: + Transaction hash and collected fee amounts + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Get pool_address and wallet_address from database + pool_address = None + wallet_address = None + + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + db_position = await clmm_repo.get_position_by_address(request.position_address) + if db_position: + pool_address = db_position.pool_address + wallet_address = db_position.wallet_address + + # If not in database, use default wallet + if not wallet_address: + wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( + chain=chain, + wallet_address=request.wallet_address + ) + + # If no pool_address from database, we can't query Gateway + if not pool_address: + raise HTTPException( + status_code=404, + detail=f"Position {request.position_address} not found in database. Pool address is required." + ) + + # Fetch pending fees BEFORE collecting (Gateway doesn't always return collected amounts in response) + base_fee_to_collect = Decimal("0") + quote_fee_to_collect = Decimal("0") + + try: + positions_list = await accounts_service.gateway_client.clmm_positions_owned( + connector=request.connector, + chain_network=request.network, # request.network is already in 'chain-network' format + wallet_address=wallet_address, + pool_address=pool_address + ) + + # Find our specific position and get pending fees + if positions_list and isinstance(positions_list, list): + for pos in positions_list: + if pos and pos.get("address") == request.position_address: + base_fee_to_collect = Decimal(str(pos.get("baseFeeAmount", 0))) + quote_fee_to_collect = Decimal(str(pos.get("quoteFeeAmount", 0))) + logger.info(f"Pending fees before collection: base={base_fee_to_collect}, quote={quote_fee_to_collect}") + break + else: + logger.warning(f"Could not find position {request.position_address} in positions_owned response") + except Exception as e: + logger.warning(f"Could not fetch pending fees before collection: {e}", exc_info=True) + + # Collect fees + result = await accounts_service.gateway_client.clmm_collect_fees( + connector=request.connector, + network=network, + wallet_address=wallet_address, + position_address=request.position_address + ) + + if not result: + raise HTTPException(status_code=500, detail="No response from Gateway collect-fees endpoint") + + transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") + if not transaction_hash: + raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") + + # Get transaction status from Gateway response + tx_status = get_transaction_status_from_response(result) + + # Try to extract collected amounts from Gateway response, fallback to pre-fetched amounts + data = result.get("data", {}) + base_fee_from_response = data.get("baseFeeAmountCollected") + quote_fee_from_response = data.get("quoteFeeAmountCollected") + + # Use response values if available, otherwise use pre-fetched values + base_fee_collected = Decimal(str(base_fee_from_response)) if base_fee_from_response is not None else base_fee_to_collect + quote_fee_collected = Decimal(str(quote_fee_from_response)) if quote_fee_from_response is not None else quote_fee_to_collect + + # Extract gas fee from Gateway response + gas_fee = data.get("fee") + gas_token = get_native_gas_token(chain) + + logger.info(f"Collected fees: base={base_fee_collected}, quote={quote_fee_collected}") + + # Store COLLECT_FEES event in database and update position + try: + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + + # Get position to link event + position = await clmm_repo.get_position_by_address(request.position_address) + if position: + # Create event record + event_data = { + "position_id": position.id, + "transaction_hash": transaction_hash, + "event_type": "COLLECT_FEES", + "base_fee_collected": float(base_fee_collected) if base_fee_collected else None, + "quote_fee_collected": float(quote_fee_collected) if quote_fee_collected else None, + "gas_fee": float(gas_fee) if gas_fee else None, + "gas_token": gas_token, + "status": tx_status + } + await clmm_repo.create_event(event_data) + logger.info(f"Recorded CLMM COLLECT_FEES event: {transaction_hash} (status: {tx_status}, gas: {gas_fee} {gas_token})") + + # Update position: add to collected, reset pending to 0 + new_base_collected = Decimal(str(position.base_fee_collected)) + base_fee_collected + new_quote_collected = Decimal(str(position.quote_fee_collected)) + quote_fee_collected + + await clmm_repo.update_position_fees( + position_address=request.position_address, + base_fee_collected=new_base_collected, + quote_fee_collected=new_quote_collected, + base_fee_pending=Decimal("0"), + quote_fee_pending=Decimal("0") + ) + logger.info(f"Updated position {request.position_address}: collected fees updated, pending fees reset to 0") + except Exception as db_error: + logger.error(f"Error recording COLLECT_FEES event: {db_error}", exc_info=True) + + return CLMMCollectFeesResponse( + transaction_hash=transaction_hash, + position_address=request.position_address, + base_fee_collected=Decimal(str(base_fee_collected)) if base_fee_collected else None, + quote_fee_collected=Decimal(str(quote_fee_collected)) if quote_fee_collected else None, + status="submitted" + ) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error collecting fees: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error collecting fees: {str(e)}") + + +@router.post("/clmm/positions_owned", response_model=List[CLMMPositionInfo]) +async def get_clmm_positions_owned( + request: CLMMPositionsOwnedRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get all CLMM liquidity positions owned by a wallet for a specific pool. + + Example: + connector: 'meteora' + network: 'solana-mainnet-beta' + pool_address: '2sf5NYcY4zUPXUSmG6f66mskb24t5F8S11pC1Nz5nQT3' + wallet_address: (optional, uses default if not provided) + + Returns: + List of CLMM position information for the specified pool + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Get wallet address + wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( + chain=chain, + wallet_address=request.wallet_address + ) + + # Get positions for the specified pool + result = await accounts_service.gateway_client.clmm_positions_owned( + connector=request.connector, + chain_network=request.network, # request.network is already in 'chain-network' format + wallet_address=wallet_address, + pool_address=request.pool_address + ) + + if result is None: + raise HTTPException(status_code=500, detail="Failed to get positions from Gateway") + + # Gateway returns a list directly + positions_data = result if isinstance(result, list) else [] + positions = [] + + for pos in positions_data: + # Extract token addresses (Gateway returns addresses, not symbols) + base_token_address = pos.get("baseTokenAddress", "") + quote_token_address = pos.get("quoteTokenAddress", "") + + # Use short addresses as symbols for now + base_token = base_token_address[-8:] if base_token_address else "" + quote_token = quote_token_address[-8:] if quote_token_address else "" + trading_pair = f"{base_token}-{quote_token}" if base_token and quote_token else "" + + current_price = Decimal(str(pos.get("price", 0))) + lower_price = Decimal(str(pos.get("lowerPrice", 0))) if pos.get("lowerPrice") else Decimal("0") + upper_price = Decimal(str(pos.get("upperPrice", 0))) if pos.get("upperPrice") else Decimal("0") + + # Determine if position is in range + in_range = False + if current_price > 0 and lower_price > 0 and upper_price > 0: + in_range = lower_price <= current_price <= upper_price + + positions.append(CLMMPositionInfo( + position_address=pos.get("address", ""), + pool_address=pos.get("poolAddress", ""), + trading_pair=trading_pair, + base_token=base_token, + quote_token=quote_token, + base_token_amount=Decimal(str(pos.get("baseTokenAmount", 0))), + quote_token_amount=Decimal(str(pos.get("quoteTokenAmount", 0))), + current_price=current_price, + lower_price=lower_price, + upper_price=upper_price, + base_fee_amount=Decimal(str(pos.get("baseFeeAmount", 0))) if pos.get("baseFeeAmount") else None, + quote_fee_amount=Decimal(str(pos.get("quoteFeeAmount", 0))) if pos.get("quoteFeeAmount") else None, + lower_bin_id=pos.get("lowerBinId"), + upper_bin_id=pos.get("upperBinId"), + in_range=in_range + )) + + return positions + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error getting CLMM positions owned: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting CLMM positions owned: {str(e)}") + + +@router.get("/clmm/positions/{position_address}/events") +async def get_clmm_position_events( + position_address: str, + event_type: Optional[str] = None, + limit: int = 100, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get event history for a CLMM position. + + Args: + position_address: Position NFT address + event_type: Filter by event type (OPEN, ADD_LIQUIDITY, REMOVE_LIQUIDITY, COLLECT_FEES, CLOSE) + limit: Max events to return + + Returns: + List of position events + """ + try: + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + events = await clmm_repo.get_position_events( + position_address=position_address, + event_type=event_type, + limit=limit + ) + + return { + "data": [clmm_repo.event_to_dict(event) for event in events], + "total_count": len(events) + } + + except Exception as e: + logger.error(f"Error getting position events: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting position events: {str(e)}") + + +@router.post("/clmm/positions/search") +async def search_clmm_positions( + network: Optional[str] = None, + connector: Optional[str] = None, + wallet_address: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + position_addresses: Optional[List[str]] = Query(None), + limit: int = 50, + offset: int = 0, + refresh: bool = False, + db_manager: AsyncDatabaseManager = Depends(get_database_manager), + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Search CLMM positions with filters. + + Args: + network: Filter by network (e.g., 'solana-mainnet-beta') + connector: Filter by connector (e.g., 'meteora') + wallet_address: Filter by wallet address + trading_pair: Filter by trading pair (e.g., 'SOL-USDC') + status: Filter by status (OPEN, CLOSED) + position_addresses: Filter by specific position addresses (list of addresses) + limit: Max results (default 50, max 1000) + offset: Pagination offset + refresh: If True, refresh position data from Gateway before returning (default False) + + Returns: + Paginated list of positions + """ + try: + # Validate limit + if limit > 1000: + limit = 1000 + + # Optionally refresh position data from Gateway first + if refresh and await accounts_service.gateway_client.ping(): + # Get positions to refresh + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + positions_to_refresh = await clmm_repo.get_positions( + network=network, + connector=connector, + wallet_address=wallet_address, + trading_pair=trading_pair, + status=status, + position_addresses=position_addresses, + limit=limit, + offset=offset + ) + + # Extract position addresses and details before closing session + position_details = [ + { + "position_address": pos.position_address, + "pool_address": pos.pool_address, + "connector": pos.connector, + "network": pos.network, + "wallet_address": pos.wallet_address + } + for pos in positions_to_refresh + ] + + # Refresh each position in a separate session + logger.info(f"Refreshing {len(position_details)} positions from Gateway") + for pos_detail in position_details: + try: + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + # Get position again in this session + position = await clmm_repo.get_position_by_address(pos_detail["position_address"]) + if position: + await _refresh_position_data(position, accounts_service, clmm_repo) + except Exception as e: + logger.warning(f"Failed to refresh position {pos_detail['position_address']}: {e}") + # Continue with other positions even if one fails + + # Get final results after refresh + async with db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + positions = await clmm_repo.get_positions( + network=network, + connector=connector, + wallet_address=wallet_address, + trading_pair=trading_pair, + status=status, + position_addresses=position_addresses, + limit=limit, + offset=offset + ) + + # Get total count for pagination + has_more = len(positions) == limit + + return { + "data": [clmm_repo.position_to_dict(pos) for pos in positions], + "pagination": { + "limit": limit, + "offset": offset, + "has_more": has_more, + "total_count": len(positions) + offset if not has_more else None + } + } + + except Exception as e: + logger.error(f"Error searching CLMM positions: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error searching CLMM positions: {str(e)}") + + diff --git a/routers/gateway_proxy.py b/routers/gateway_proxy.py new file mode 100644 index 00000000..d735c137 --- /dev/null +++ b/routers/gateway_proxy.py @@ -0,0 +1,131 @@ +""" +Gateway Proxy Router + +Catch-all router that forwards requests to Gateway server unchanged. +Dashboard calls /api/gateway-proxy/* and this router forwards to Gateway at localhost:15888/*. + +This allows the dashboard to access all Gateway endpoints through the API without +needing each endpoint to be explicitly defined. + +Examples: + GET /api/gateway-proxy/wallet -> GET localhost:15888/wallet + POST /api/gateway-proxy/wallet/add -> POST localhost:15888/wallet/add + GET /api/gateway-proxy/config -> GET localhost:15888/config + GET /api/gateway-proxy/trading/clmm/positions-owned -> GET localhost:15888/trading/clmm/positions-owned +""" + +import json +import logging + +import aiohttp +from fastapi import APIRouter, Depends, HTTPException, Request, Response +from fastapi.responses import JSONResponse + +from deps import get_accounts_service +from services.accounts_service import AccountsService + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Gateway Proxy"], prefix="/gateway-proxy") + + +@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"]) +async def forward_to_gateway( + path: str, + request: Request, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Forward request to Gateway server unchanged. + + This catch-all route forwards any request to /api/gateway-proxy/* to the Gateway server. + The request body, headers, and query parameters are passed through unchanged. + The response from Gateway is returned unchanged. + + Examples: + GET /api/gateway-proxy/wallet -> GET localhost:15888/wallet + POST /api/gateway-proxy/wallet/add -> POST localhost:15888/wallet/add + GET /api/gateway-proxy/config -> GET localhost:15888/config + """ + gateway_client = accounts_service.gateway_client + gateway_url = gateway_client.base_url + + # Build target URL + target_url = f"{gateway_url}/{path}" + + # Get query parameters + query_params = dict(request.query_params) + + # Get request body if present + body = None + if request.method in ["POST", "PUT", "PATCH", "DELETE"]: + try: + body = await request.json() + except Exception: + # No JSON body or invalid JSON - that's OK for some requests + body = None + + try: + # Get or create aiohttp session + session = await gateway_client._get_session() + + # Forward the request + async with session.request( + method=request.method, + url=target_url, + params=query_params if query_params else None, + json=body if body else None, + ) as response: + # Read response body + response_body = await response.read() + + # Try to parse as JSON, otherwise return as-is + content_type = response.headers.get("Content-Type", "") + + if "application/json" in content_type: + try: + json_body = json.loads(response_body) + return JSONResponse( + content=json_body, + status_code=response.status, + ) + except Exception: + pass + + # Return raw response + return Response( + content=response_body, + status_code=response.status, + media_type=content_type or "application/octet-stream", + ) + + except aiohttp.ClientError as e: + logger.error(f"Gateway proxy error: {e}") + raise HTTPException( + status_code=503, + detail=f"Gateway service unavailable: {str(e)}" + ) + except Exception as e: + logger.error(f"Gateway proxy error: {e}") + raise HTTPException( + status_code=500, + detail=f"Gateway proxy error: {str(e)}" + ) + + +# Also expose the root endpoint for health checks +@router.get("") +async def gateway_root( + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Gateway health check. + Forwards to Gateway root endpoint to check if it's online. + """ + gateway_client = accounts_service.gateway_client + result = await gateway_client._request("GET", "") + if result is None: + raise HTTPException(status_code=503, detail="Gateway service unavailable") + if "error" in result: + raise HTTPException(status_code=result.get("status", 500), detail=result["error"]) + return result diff --git a/routers/gateway_swap.py b/routers/gateway_swap.py new file mode 100644 index 00000000..335d9fd6 --- /dev/null +++ b/routers/gateway_swap.py @@ -0,0 +1,369 @@ +""" +Gateway Swap Router - Handles DEX swap operations via Hummingbot Gateway. +Supports Router connectors (Jupiter, 0x) for token swaps. +""" +import logging +from typing import Optional +from decimal import Decimal + +from fastapi import APIRouter, Depends, HTTPException + +from deps import get_accounts_service, get_database_manager +from services.accounts_service import AccountsService +from database import AsyncDatabaseManager +from database.repositories import GatewaySwapRepository +from models import ( + SwapQuoteRequest, + SwapQuoteResponse, + SwapExecuteRequest, + SwapExecuteResponse, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Gateway Swaps"], prefix="/gateway") + + +def get_transaction_status_from_response(gateway_response: dict) -> str: + """ + Determine transaction status from Gateway response. + + Gateway returns status field in the response: + - status: 1 = confirmed + - status: 0 = pending/submitted + + Returns: + "CONFIRMED" if status == 1 + "SUBMITTED" if status == 0 or not present + """ + status = gateway_response.get("status") + + # Status 1 means transaction is confirmed on-chain + if status == 1: + return "CONFIRMED" + + # Status 0 or missing means submitted but not confirmed yet + return "SUBMITTED" + + +@router.post("/swap/quote", response_model=SwapQuoteResponse) +async def get_swap_quote( + request: SwapQuoteRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get a price quote for a swap via router (Jupiter, 0x). + + Example: + connector: 'jupiter' + network: 'solana-mainnet-beta' + trading_pair: 'SOL-USDC' + side: 'BUY' + amount: 1 + slippage_pct: 1 + + Returns: + Quote with price, expected output amount, and gas estimate + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Parse trading pair + base, quote = request.trading_pair.split("-") + + # Get quote from Gateway + result = await accounts_service.gateway_client.quote_swap( + connector=request.connector, + network=network, + base_asset=base, + quote_asset=quote, + amount=float(request.amount), + side=request.side, + slippage_pct=float(request.slippage_pct) if request.slippage_pct else 1.0, + pool_address=None + ) + + # Extract amounts from Gateway response (snake_case for consistency) + amount_in_raw = result.get("amountIn") or result.get("amount_in") + amount_out_raw = result.get("amountOut") or result.get("amount_out") + + amount_in = Decimal(str(amount_in_raw)) if amount_in_raw else None + amount_out = Decimal(str(amount_out_raw)) if amount_out_raw else None + + # Extract gas estimate (try both camelCase and snake_case) + gas_estimate = result.get("gasEstimate") or result.get("gas_estimate") + gas_estimate_value = Decimal(str(gas_estimate)) if gas_estimate else None + + return SwapQuoteResponse( + base=base, + quote=quote, + price=Decimal(str(result.get("price", 0))), + amount=request.amount, + amount_in=amount_in, + amount_out=amount_out, + expected_amount=amount_out, # Deprecated, kept for backward compatibility + slippage_pct=request.slippage_pct or Decimal("1.0"), + gas_estimate=gas_estimate_value + ) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error getting swap quote: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting swap quote: {str(e)}") + + +@router.post("/swap/execute", response_model=SwapExecuteResponse) +async def execute_swap( + request: SwapExecuteRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Execute a swap transaction via router (Jupiter, 0x). + + Example: + connector: 'jupiter' + network: 'solana-mainnet-beta' + trading_pair: 'SOL-USDC' + side: 'BUY' + amount: 1 + slippage_pct: 1 + wallet_address: (optional, uses default if not provided) + + Returns: + Transaction hash and swap details + """ + try: + if not await accounts_service.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + # Parse network_id + chain, network = accounts_service.gateway_client.parse_network_id(request.network) + + # Get wallet address + wallet_address = await accounts_service.gateway_client.get_wallet_address_or_default( + chain=chain, + wallet_address=request.wallet_address + ) + + # Parse trading pair + base, quote = request.trading_pair.split("-") + + # Execute swap + result = await accounts_service.gateway_client.execute_swap( + connector=request.connector, + network=network, + wallet_address=wallet_address, + base_asset=base, + quote_asset=quote, + amount=float(request.amount), + side=request.side, + slippage_pct=float(request.slippage_pct) if request.slippage_pct else 1.0 + ) + if not result: + raise HTTPException(status_code=500, detail="Gateway service is not able to execute swap") + transaction_hash = result.get("signature") or result.get("txHash") or result.get("hash") + if not transaction_hash: + raise HTTPException(status_code=500, detail="No transaction hash returned from Gateway") + + # Extract swap data from Gateway response + # Gateway returns amounts nested under 'data' object + data = result.get("data", {}) + amount_in_raw = data.get("amountIn") + amount_out_raw = data.get("amountOut") + + # Use amounts from Gateway response, fallback to request amount if not available + input_amount = Decimal(str(amount_in_raw)) if amount_in_raw is not None else request.amount + output_amount = Decimal(str(amount_out_raw)) if amount_out_raw is not None else Decimal("0") + + # Calculate price from actual swap amounts + # Price = output / input (how much quote you get/pay per base) + price = output_amount / input_amount if input_amount > 0 else Decimal("0") + + # Get transaction status from Gateway response + tx_status = get_transaction_status_from_response(result) + + # Store swap in database + try: + async with db_manager.get_session_context() as session: + swap_repo = GatewaySwapRepository(session) + + swap_data = { + "transaction_hash": transaction_hash, + "network": request.network, + "connector": request.connector, + "wallet_address": wallet_address, + "trading_pair": request.trading_pair, + "base_token": base, + "quote_token": quote, + "side": request.side, + "input_amount": float(input_amount), + "output_amount": float(output_amount), + "price": float(price), + "slippage_pct": float(request.slippage_pct) if request.slippage_pct else 1.0, + "status": tx_status, + "pool_address": result.get("poolAddress") or result.get("pool_address") + } + + await swap_repo.create_swap(swap_data) + logger.info(f"Recorded swap in database: {transaction_hash} (status: {tx_status})") + except Exception as db_error: + # Log but don't fail the swap - it was submitted successfully + logger.error(f"Error recording swap in database: {db_error}", exc_info=True) + + return SwapExecuteResponse( + transaction_hash=transaction_hash, + trading_pair=request.trading_pair, + side=request.side, + amount=request.amount, + status="submitted" + ) + + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Error executing swap: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error executing swap: {str(e)}") + + +@router.get("/swaps/{transaction_hash}/status") +async def get_swap_status( + transaction_hash: str, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get status of a specific swap by transaction hash. + + Args: + transaction_hash: Transaction hash of the swap + + Returns: + Swap details including current status + """ + try: + async with db_manager.get_session_context() as session: + swap_repo = GatewaySwapRepository(session) + swap = await swap_repo.get_swap_by_tx_hash(transaction_hash) + + if not swap: + raise HTTPException(status_code=404, detail=f"Swap not found: {transaction_hash}") + + return swap_repo.to_dict(swap) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting swap status: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting swap status: {str(e)}") + + +@router.post("/swaps/search") +async def search_swaps( + network: Optional[str] = None, + connector: Optional[str] = None, + wallet_address: Optional[str] = None, + trading_pair: Optional[str] = None, + status: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + limit: int = 50, + offset: int = 0, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Search swap history with filters. + + Args: + network: Filter by network (e.g., 'solana-mainnet-beta') + connector: Filter by connector (e.g., 'jupiter') + wallet_address: Filter by wallet address + trading_pair: Filter by trading pair (e.g., 'SOL-USDC') + status: Filter by status (SUBMITTED, CONFIRMED, FAILED) + start_time: Start timestamp (unix seconds) + end_time: End timestamp (unix seconds) + limit: Max results (default 50, max 1000) + offset: Pagination offset + + Returns: + Paginated list of swaps + """ + try: + # Validate limit + if limit > 1000: + limit = 1000 + + async with db_manager.get_session_context() as session: + swap_repo = GatewaySwapRepository(session) + swaps = await swap_repo.get_swaps( + network=network, + connector=connector, + wallet_address=wallet_address, + trading_pair=trading_pair, + status=status, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + + # Get total count for pagination (simplified - actual count would need separate query) + has_more = len(swaps) == limit + + return { + "data": [swap_repo.to_dict(swap) for swap in swaps], + "pagination": { + "limit": limit, + "offset": offset, + "has_more": has_more, + "total_count": len(swaps) + offset if not has_more else None + } + } + + except Exception as e: + logger.error(f"Error searching swaps: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error searching swaps: {str(e)}") + + +@router.get("/swaps/summary") +async def get_swaps_summary( + network: Optional[str] = None, + wallet_address: Optional[str] = None, + start_time: Optional[int] = None, + end_time: Optional[int] = None, + db_manager: AsyncDatabaseManager = Depends(get_database_manager) +): + """ + Get swap summary statistics. + + Args: + network: Filter by network + wallet_address: Filter by wallet address + start_time: Start timestamp (unix seconds) + end_time: End timestamp (unix seconds) + + Returns: + Summary statistics including volume, fees, success rate + """ + try: + async with db_manager.get_session_context() as session: + swap_repo = GatewaySwapRepository(session) + summary = await swap_repo.get_swaps_summary( + network=network, + wallet_address=wallet_address, + start_time=start_time, + end_time=end_time + ) + return summary + + except Exception as e: + logger.error(f"Error getting swaps summary: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Error getting swaps summary: {str(e)}") diff --git a/routers/manage_accounts.py b/routers/manage_accounts.py deleted file mode 100644 index 7751f7ce..00000000 --- a/routers/manage_accounts.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Dict, List - -from fastapi import APIRouter, HTTPException -from hummingbot.client.settings import AllConnectorSettings -from starlette import status - -from services.accounts_service import AccountsService -from utils.file_system import FileSystemUtil - -router = APIRouter(tags=["Manage Credentials"]) -file_system = FileSystemUtil(base_path="bots/credentials") -accounts_service = AccountsService() - - -@router.on_event("startup") -async def startup_event(): - accounts_service.start_update_account_state_loop() - - -@router.on_event("shutdown") -async def shutdown_event(): - accounts_service.stop_update_account_state_loop() - - -@router.get("/accounts-state", response_model=Dict[str, Dict[str, List[Dict]]]) -async def get_all_accounts_state(): - return accounts_service.get_accounts_state() - - -@router.get("/account-state-history", response_model=List[Dict]) -async def get_account_state_history(): - """ - Get the historical state of all accounts. - """ - try: - history = accounts_service.load_account_state_history() - return history - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/available-connectors", response_model=List[str]) -async def available_connectors(): - return list(AllConnectorSettings.get_connector_settings().keys()) - - -@router.get("/connector-config-map/{connector_name}", response_model=List[str]) -async def get_connector_config_map(connector_name: str): - return accounts_service.get_connector_config_map(connector_name) - - -@router.get("/all-connectors-config-map", response_model=Dict[str, List[str]]) -async def get_all_connectors_config_map(): - all_config_maps = {} - for connector in list(AllConnectorSettings.get_connector_settings().keys()): - all_config_maps[connector] = accounts_service.get_connector_config_map(connector) - return all_config_maps - - -@router.get("/list-accounts", response_model=List[str]) -async def list_accounts(): - return accounts_service.list_accounts() - - -@router.get("/list-credentials/{account_name}", response_model=List[str]) -async def list_credentials(account_name: str): - try: - return accounts_service.list_credentials(account_name) - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/add-account", status_code=status.HTTP_201_CREATED) -async def add_account(account_name: str): - try: - accounts_service.add_account(account_name) - return {"message": "Credential added successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/delete-account") -async def delete_account(account_name: str): - try: - if account_name == "master_account": - raise HTTPException(status_code=400, detail="Cannot delete master account.") - accounts_service.delete_account(account_name) - return {"message": "Credential deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/delete-credential/{account_name}/{connector_name}") -async def delete_credential(account_name: str, connector_name: str): - try: - accounts_service.delete_credentials(account_name, connector_name) - return {"message": "Credential deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) - - -@router.post("/add-connector-keys/{account_name}/{connector_name}", status_code=status.HTTP_201_CREATED) -async def add_connector_keys(account_name: str, connector_name: str, keys: Dict): - try: - await accounts_service.add_connector_keys(account_name, connector_name, keys) - return {"message": "Connector keys added successfully."} - except Exception as e: - accounts_service.delete_credentials(account_name, connector_name) - raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/manage_backtesting.py b/routers/manage_backtesting.py deleted file mode 100644 index 1c0f9eac..00000000 --- a/routers/manage_backtesting.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Dict, Union - -from fastapi import APIRouter -from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory -from hummingbot.strategy_v2.backtesting.backtesting_engine_base import BacktestingEngineBase -from hummingbot.strategy_v2.backtesting.controllers_backtesting.directional_trading_backtesting import ( - DirectionalTradingBacktesting, -) -from hummingbot.strategy_v2.backtesting.controllers_backtesting.market_making_backtesting import MarketMakingBacktesting -from pydantic import BaseModel - -from config import CONTROLLERS_MODULE, CONTROLLERS_PATH - -router = APIRouter(tags=["Market Backtesting"]) -candles_factory = CandlesFactory() -directional_trading_backtesting = DirectionalTradingBacktesting() -market_making_backtesting = MarketMakingBacktesting() - -BACKTESTING_ENGINES = { - "directional_trading": directional_trading_backtesting, - "market_making": market_making_backtesting -} - - -class BacktestingConfig(BaseModel): - start_time: int = 1672542000 # 2023-01-01 00:00:00 - end_time: int = 1672628400 # 2023-01-01 23:59:00 - backtesting_resolution: str = "1m" - trade_cost: float = 0.0006 - config: Union[Dict, str] - - -@router.post("/run-backtesting") -async def run_backtesting(backtesting_config: BacktestingConfig): - try: - if isinstance(backtesting_config.config, str): - controller_config = BacktestingEngineBase.get_controller_config_instance_from_yml( - config_path=backtesting_config.config, - controllers_conf_dir_path=CONTROLLERS_PATH, - controllers_module=CONTROLLERS_MODULE - ) - else: - controller_config = BacktestingEngineBase.get_controller_config_instance_from_dict( - config_data=backtesting_config.config, - controllers_module=CONTROLLERS_MODULE - ) - backtesting_engine = BACKTESTING_ENGINES.get(controller_config.controller_type) - if not backtesting_engine: - raise ValueError(f"Backtesting engine for controller type {controller_config.controller_type} not found.") - backtesting_results = await backtesting_engine.run_backtesting( - controller_config=controller_config, trade_cost=backtesting_config.trade_cost, - start=int(backtesting_config.start_time), end=int(backtesting_config.end_time), - backtesting_resolution=backtesting_config.backtesting_resolution) - processed_data = backtesting_results["processed_data"]["features"].fillna(0) - executors_info = [e.to_dict() for e in backtesting_results["executors"]] - backtesting_results["processed_data"] = processed_data.to_dict() - results = backtesting_results["results"] - results["sharpe_ratio"] = results["sharpe_ratio"] if results["sharpe_ratio"] is not None else 0 - return { - "executors": executors_info, - "processed_data": backtesting_results["processed_data"], - "results": backtesting_results["results"], - } - except Exception as e: - return {"error": str(e)} diff --git a/routers/manage_broker_messages.py b/routers/manage_broker_messages.py deleted file mode 100644 index e31d1f3f..00000000 --- a/routers/manage_broker_messages.py +++ /dev/null @@ -1,64 +0,0 @@ -from fastapi import APIRouter, HTTPException - -from config import BROKER_HOST, BROKER_PASSWORD, BROKER_PORT, BROKER_USERNAME -from models import ImportStrategyAction, StartBotAction, StopBotAction -from services.bots_orchestrator import BotsManager - -# Initialize the scheduler -router = APIRouter(tags=["Manage Broker Messages"]) -bots_manager = BotsManager(broker_host=BROKER_HOST, broker_port=BROKER_PORT, broker_username=BROKER_USERNAME, - broker_password=BROKER_PASSWORD) - - -@router.on_event("startup") -async def startup_event(): - bots_manager.start_update_active_bots_loop() - - -@router.on_event("shutdown") -async def shutdown_event(): - # Shutdown the scheduler on application exit - bots_manager.stop_update_active_bots_loop() - - -@router.get("/get-active-bots-status") -def get_active_bots_status(): - """Returns the cached status of all active bots.""" - return {"status": "success", "data": bots_manager.get_all_bots_status()} - - -@router.get("/get-bot-status/{bot_name}") -def get_bot_status(bot_name: str): - response = bots_manager.get_bot_status(bot_name) - if not response: - raise HTTPException(status_code=404, detail="Bot not found") - return { - "status": "success", - "data": response - } - - -@router.get("/get-bot-history/{bot_name}") -def get_bot_history(bot_name: str): - response = bots_manager.get_bot_history(bot_name) - return {"status": "success", "response": response} - - -@router.post("/start-bot") -def start_bot(action: StartBotAction): - response = bots_manager.start_bot(action.bot_name, log_level=action.log_level, script=action.script, - conf=action.conf, async_backend=action.async_backend) - return {"status": "success", "response": response} - - -@router.post("/stop-bot") -def stop_bot(action: StopBotAction): - response = bots_manager.stop_bot(action.bot_name, skip_order_cancellation=action.skip_order_cancellation, - async_backend=action.async_backend) - return {"status": "success", "response": response} - - -@router.post("/import-strategy") -def import_strategy(action: ImportStrategyAction): - response = bots_manager.import_strategy_for_bot(action.bot_name, action.strategy) - return {"status": "success", "response": response} diff --git a/routers/manage_docker.py b/routers/manage_docker.py deleted file mode 100644 index 9769cff3..00000000 --- a/routers/manage_docker.py +++ /dev/null @@ -1,84 +0,0 @@ -import logging -import os - -from fastapi import APIRouter, HTTPException - -from models import HummingbotInstanceConfig, ImageName -from services.bot_archiver import BotArchiver -from services.docker_service import DockerManager - -router = APIRouter(tags=["Docker Management"]) -docker_manager = DockerManager() -bot_archiver = BotArchiver(os.environ.get("AWS_API_KEY"), os.environ.get("AWS_SECRET_KEY"), - os.environ.get("S3_DEFAULT_BUCKET_NAME")) - - -@router.get("/is-docker-running") -async def is_docker_running(): - return {"is_docker_running": docker_manager.is_docker_running()} - - -@router.get("/available-images/{image_name}") -async def available_images(image_name: str): - available_images = docker_manager.get_available_images() - image_tags = [tag for image in available_images["images"] for tag in image.tags if image_name in tag] - return {"available_images": image_tags} - - -@router.get("/active-containers") -async def active_containers(): - return docker_manager.get_active_containers() - - -@router.get("/exited-containers") -async def exited_containers(): - return docker_manager.get_exited_containers() - - -@router.post("/clean-exited-containers") -async def clean_exited_containers(): - return docker_manager.clean_exited_containers() - - -@router.post("/remove-container/{container_name}") -async def remove_container(container_name: str, archive_locally: bool = True, s3_bucket: str = None): - # Remove the container - response = docker_manager.remove_container(container_name) - # Form the instance directory path correctly - instance_dir = os.path.join('bots', 'instances', container_name) - try: - # Archive the data - if archive_locally: - bot_archiver.archive_locally(container_name, instance_dir) - else: - bot_archiver.archive_and_upload(container_name, instance_dir, bucket_name=s3_bucket) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - return response - - -@router.post("/stop-container/{container_name}") -async def stop_container(container_name: str): - return docker_manager.stop_container(container_name) - - -@router.post("/start-container/{container_name}") -async def start_container(container_name: str): - return docker_manager.start_container(container_name) - - -@router.post("/create-hummingbot-instance") -async def create_hummingbot_instance(config: HummingbotInstanceConfig): - logging.info(f"Creating hummingbot instance with config: {config}") - response = docker_manager.create_hummingbot_instance(config) - return response - - -@router.post("/pull-image/") -async def pull_image(image: ImageName): - try: - result = docker_manager.pull_image(image.image_name) - return result - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) diff --git a/routers/manage_files.py b/routers/manage_files.py deleted file mode 100644 index 127fbfd8..00000000 --- a/routers/manage_files.py +++ /dev/null @@ -1,161 +0,0 @@ -import json -from typing import Dict, List - -import yaml -from fastapi import APIRouter, File, HTTPException, UploadFile -from starlette import status - -from models import Script, ScriptConfig -from utils.file_system import FileSystemUtil - -router = APIRouter(tags=["Files Management"]) - -file_system = FileSystemUtil() - - -@router.get("/list-scripts", response_model=List[str]) -async def list_scripts(): - return file_system.list_files('scripts') - - -@router.get("/list-scripts-configs", response_model=List[str]) -async def list_scripts_configs(): - return file_system.list_files('conf/scripts') - - -@router.get("/script-config/{script_name}", response_model=dict) -async def get_script_config(script_name: str): - """ - Retrieves the configuration parameters for a given script. - :param script_name: The name of the script. - :return: JSON containing the configuration parameters. - """ - config_class = file_system.load_script_config_class(script_name) - if config_class is None: - raise HTTPException(status_code=404, detail="Script configuration class not found") - - # Extracting fields and default values - config_fields = {field.name: field.default for field in config_class.__fields__.values()} - return json.loads(json.dumps(config_fields, default=str)) # Handling non-serializable types like Decimal - - -@router.get("/list-controllers", response_model=dict) -async def list_controllers(): - directional_trading_controllers = [file for file in file_system.list_files('controllers/directional_trading') if - file != "__init__.py"] - market_making_controllers = [file for file in file_system.list_files('controllers/market_making') if - file != "__init__.py"] - return {"directional_trading": directional_trading_controllers, "market_making": market_making_controllers} - - -@router.get("/list-controllers-configs", response_model=List[str]) -async def list_controllers_configs(): - return file_system.list_files('conf/controllers') - - -@router.get("/controller-config/{controller_name}", response_model=dict) -async def get_controller_config(controller_name: str): - config = file_system.read_yaml_file(f"bots/conf/controllers/{controller_name}.yml") - return config - - -@router.get("/all-controller-configs", response_model=List[dict]) -async def get_all_controller_configs(): - configs = [] - for controller in file_system.list_files('conf/controllers'): - config = file_system.read_yaml_file(f"bots/conf/controllers/{controller}") - configs.append(config) - return configs - - -@router.get("/all-controller-configs/bot/{bot_name}", response_model=List[dict]) -async def get_all_controller_configs_for_bot(bot_name: str): - configs = [] - bots_config_path = f"instances/{bot_name}/conf/controllers" - if not file_system.path_exists(bots_config_path): - raise HTTPException(status_code=400, detail="Bot not found.") - for controller in file_system.list_files(bots_config_path): - config = file_system.read_yaml_file(f"bots/{bots_config_path}/{controller}") - configs.append(config) - return configs - - -@router.post("/update-controller-config/bot/{bot_name}/{controller_id}") -async def update_controller_config(bot_name: str, controller_id: str, config: Dict): - bots_config_path = f"instances/{bot_name}/conf/controllers" - if not file_system.path_exists(bots_config_path): - raise HTTPException(status_code=400, detail="Bot not found.") - current_config = file_system.read_yaml_file(f"bots/{bots_config_path}/{controller_id}.yml") - current_config.update(config) - file_system.dump_dict_to_yaml(f"bots/{bots_config_path}/{controller_id}.yml", current_config) - return {"message": "Controller configuration updated successfully."} - - -@router.post("/add-script", status_code=status.HTTP_201_CREATED) -async def add_script(script: Script, override: bool = False): - try: - file_system.add_file('scripts', script.name + '.py', script.content, override) - return {"message": "Script added successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-script") -async def upload_script(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('scripts', config_file.filename, contents.decode(), override) - return {"message": "Script uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/add-script-config", status_code=status.HTTP_201_CREATED) -async def add_script_config(config: ScriptConfig): - try: - yaml_content = yaml.dump(config.content) - - file_system.add_file('conf/scripts', config.name + '.yml', yaml_content, override=True) - return {"message": "Script configuration uploaded successfully."} - except Exception as e: # Consider more specific exception handling - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-script-config") -async def upload_script_config(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('conf/scripts', config_file.filename, contents.decode(), override) - return {"message": "Script configuration uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/add-controller-config", status_code=status.HTTP_201_CREATED) -async def add_controller_config(config: ScriptConfig): - try: - yaml_content = yaml.dump(config.content) - - file_system.add_file('conf/controllers', config.name + '.yml', yaml_content, override=True) - return {"message": "Controller configuration uploaded successfully."} - except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/upload-controller-config") -async def upload_controller_config(config_file: UploadFile = File(...), override: bool = False): - try: - contents = await config_file.read() - file_system.add_file('conf/controllers', config_file.filename, contents.decode(), override) - return {"message": "Controller configuration uploaded successfully."} - except FileExistsError as e: - raise HTTPException(status_code=400, detail=str(e)) - - -@router.post("/delete-controller-config", status_code=status.HTTP_200_OK) -async def delete_controller_config(config_name: str): - try: - file_system.delete_file('conf/controllers', config_name) - return {"message": f"Controller configuration {config_name} deleted successfully."} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) diff --git a/routers/manage_market_data.py b/routers/manage_market_data.py deleted file mode 100644 index 13529f6a..00000000 --- a/routers/manage_market_data.py +++ /dev/null @@ -1,45 +0,0 @@ -import asyncio - -from fastapi import APIRouter -from hummingbot.data_feed.candles_feed.candles_factory import CandlesConfig, CandlesFactory -from pydantic import BaseModel - -router = APIRouter(tags=["Market Data"]) -candles_factory = CandlesFactory() - - -class HistoricalCandlesConfig(BaseModel): - connector_name: str = "binance_perpetual" - trading_pair: str = "BTC-USDT" - interval: str = "3m" - start_time: int = 1672542000 - end_time: int = 1672628400 - - -@router.post("/real-time-candles") -async def get_candles(candles_config: CandlesConfig): - try: - candles = candles_factory.get_candle(candles_config) - candles.start() - while not candles.ready: - await asyncio.sleep(1) - df = candles.candles_df - candles.stop() - df.drop_duplicates(subset=["timestamp"], inplace=True) - return df - except Exception as e: - return {"error": str(e)} - - -@router.post("/historical-candles") -async def get_historical_candles(config: HistoricalCandlesConfig): - try: - candles_config = CandlesConfig( - connector=config.connector_name, - trading_pair=config.trading_pair, - interval=config.interval - ) - candles = candles_factory.get_candle(candles_config) - return await candles.get_historical_candles(config=config) - except Exception as e: - return {"error": str(e)} diff --git a/routers/market_data.py b/routers/market_data.py new file mode 100644 index 00000000..da0b9de9 --- /dev/null +++ b/routers/market_data.py @@ -0,0 +1,674 @@ +import asyncio +import logging +import time + +from fastapi import APIRouter, Request, HTTPException, Depends +from hummingbot.data_feed.candles_feed.data_types import HistoricalCandlesConfig, CandlesConfig +from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory, UnsupportedConnectorException + +from config import settings +from models.market_data import CandlesConfigRequest +from services.market_data_service import MarketDataService +from models import ( + PriceRequest, PricesResponse, FundingInfoRequest, FundingInfoResponse, + OrderBookRequest, OrderBookResponse, OrderBookLevel, + VolumeForPriceRequest, PriceForVolumeRequest, QuoteVolumeForPriceRequest, + PriceForQuoteVolumeRequest, VWAPForVolumeRequest, OrderBookQueryResult, + AddTradingPairRequest, RemoveTradingPairRequest, TradingPairResponse +) +from deps import get_market_data_service + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["Market Data"], prefix="/market-data") + + +@router.post("/candles") +async def get_candles(request: Request, candles_config: CandlesConfigRequest): + """ + Get real-time candles data for a specific trading pair. + + This endpoint uses the MarketDataProvider to get or create a candles feed that will + automatically start and maintain real-time updates. Subsequent requests with the same + configuration will reuse the existing feed for up-to-date data. + + Args: + request: FastAPI request object + candles_config: Configuration for the candles including connector, trading_pair, interval, and max_records + + Returns: + Real-time candles data or error message + """ + available = list(CandlesFactory._candles_map.keys()) + if candles_config.connector_name not in CandlesFactory._candles_map: + raise HTTPException( + status_code=400, + detail=f"Unsupported connector '{candles_config.connector_name}'. " + f"Available connectors: {available}" + ) + + if "-" not in candles_config.trading_pair: + raise HTTPException( + status_code=400, + detail=f"Invalid trading pair format '{candles_config.trading_pair}'. " + f"Expected format: BASE-QUOTE (e.g., BTC-USDT)" + ) + + try: + market_data_service: MarketDataService = request.app.state.market_data_service + + candles_cfg = CandlesConfig( + connector=candles_config.connector_name, trading_pair=candles_config.trading_pair, + interval=candles_config.interval, max_records=candles_config.max_records) + candles_feed = market_data_service.get_candles_feed(candles_cfg) + + # Wait for the candles feed to be ready with a timeout + timeout = settings.market_data.candles_ready_timeout + start = time.time() + while not candles_feed.ready: + if time.time() - start > timeout: + # Clean up the stale feed so it doesn't stay cached + market_data_service.stop_candle_feed(candles_cfg) + raise HTTPException( + status_code=504, + detail=f"Candle feed for {candles_config.connector_name} " + f"{candles_config.trading_pair} did not become ready within " + f"{timeout}s. The trading pair may not exist on this exchange." + ) + await asyncio.sleep(0.1) + + df = candles_feed.candles_df + + if df is not None and not df.empty: + df = df.tail(candles_config.max_records) + df = df.drop_duplicates(subset=["timestamp"], keep="last") + return df.to_dict(orient="records") + else: + raise HTTPException(status_code=404, detail="No candles data available") + + except HTTPException: + raise + except UnsupportedConnectorException as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error(f"Unexpected error fetching candles: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal error fetching candles: {str(e)}") + + +@router.post("/historical-candles") +async def get_historical_candles(request: Request, config: HistoricalCandlesConfig): + """ + Get historical candles data for a specific trading pair. + + Args: + config: Configuration for historical candles including connector, trading pair, interval, start and end time + + Returns: + Historical candles data or error message + """ + available = list(CandlesFactory._candles_map.keys()) + if config.connector_name not in CandlesFactory._candles_map: + raise HTTPException( + status_code=400, + detail=f"Unsupported connector '{config.connector_name}'. " + f"Available connectors: {available}" + ) + + if "-" not in config.trading_pair: + raise HTTPException( + status_code=400, + detail=f"Invalid trading pair format '{config.trading_pair}'. " + f"Expected format: BASE-QUOTE (e.g., BTC-USDT)" + ) + + try: + market_data_service: MarketDataService = request.app.state.market_data_service + + candles_config = CandlesConfig( + connector=config.connector_name, + trading_pair=config.trading_pair, + interval=config.interval + ) + + candles = market_data_service.get_candles_feed(candles_config) + + timeout = settings.market_data.candles_ready_timeout + historical_data = await asyncio.wait_for( + candles.get_historical_candles(config=config), + timeout=timeout + ) + + if historical_data is not None and not historical_data.empty: + return historical_data.to_dict(orient="records") + else: + raise HTTPException(status_code=404, detail="No historical data available") + + except HTTPException: + raise + except asyncio.TimeoutError: + raise HTTPException( + status_code=504, + detail=f"Historical candles request for {config.connector_name} " + f"{config.trading_pair} timed out after " + f"{settings.market_data.candles_ready_timeout}s. " + f"The trading pair may not exist or the time range may be too large." + ) + except UnsupportedConnectorException as e: + raise HTTPException(status_code=400, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=400, detail=f"Invalid parameters: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error fetching historical candles: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal error fetching historical candles: {str(e)}") + + +@router.get("/active-feeds") +async def get_active_feeds(request: Request): + """ + Get information about currently active market data feeds. + + Args: + request: FastAPI request object to access application state + + Returns: + Dictionary with active feeds information including last access times and expiration + """ + try: + market_data_service: MarketDataService = request.app.state.market_data_service + return market_data_service.get_active_feeds_info() + except Exception as e: + return {"error": str(e)} + + +@router.get("/settings") +async def get_market_data_settings(): + """ + Get current market data settings for debugging. + + Returns: + Dictionary with current market data configuration including cleanup and timeout settings + """ + from config import settings + return { + "cleanup_interval": settings.market_data.cleanup_interval, + "feed_timeout": settings.market_data.feed_timeout, + "description": "cleanup_interval: seconds between cleanup runs, feed_timeout: seconds before unused feeds expire" + } + + +@router.get("/available-candle-connectors") +async def get_available_candle_connectors(): + """ + Get list of available connectors that support candle data feeds. + + Returns: + List of connector names that can be used for fetching candle data + """ + return list(CandlesFactory._candles_map.keys()) + + +# Enhanced Market Data Endpoints + +@router.post("/prices", response_model=PricesResponse) +async def get_prices( + request: PriceRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get current prices for specified trading pairs from a connector. + + Args: + request: Price request with connector name and trading pairs + market_data_manager: Injected market data feed manager + + Returns: + Current prices for the specified trading pairs + + Raises: + HTTPException: 500 if there's an error fetching prices + """ + try: + prices = await market_data_manager.get_prices( + request.connector_name, + request.trading_pairs + ) + + if "error" in prices: + raise HTTPException(status_code=500, detail=prices["error"]) + + return PricesResponse( + connector=request.connector_name, + prices=prices, + timestamp=time.time() + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching prices: {str(e)}") + + +@router.post("/funding-info", response_model=FundingInfoResponse) +async def get_funding_info( + request: FundingInfoRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get funding information for a perpetual trading pair. + + Args: + request: Funding info request with connector name and trading pair + market_data_manager: Injected market data feed manager + + Returns: + Funding information including rates, timestamps, and prices + + Raises: + HTTPException: 400 for non-perpetual connectors, 500 for other errors + """ + try: + if "_perpetual" not in request.connector_name.lower(): + raise HTTPException(status_code=400, detail="Funding info is only available for perpetual trading pairs.") + funding_info = await market_data_manager.get_funding_info( + request.connector_name, + request.trading_pair + ) + + if "error" in funding_info: + if "not supported" in funding_info["error"]: + raise HTTPException(status_code=400, detail=funding_info["error"]) + else: + raise HTTPException(status_code=500, detail=funding_info["error"]) + + return FundingInfoResponse(**funding_info) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching funding info: {str(e)}") + + +@router.post("/order-book", response_model=OrderBookResponse) +async def get_order_book( + request: OrderBookRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get order book snapshot with specified depth. + + Args: + request: Order book request with connector, trading pair, and depth + market_data_manager: Injected market data feed manager + + Returns: + Order book snapshot with bids and asks + + Raises: + HTTPException: 500 if there's an error fetching order book + """ + try: + order_book_data = await market_data_manager.get_order_book_data( + request.connector_name, + request.trading_pair, + request.depth + ) + + if "error" in order_book_data: + raise HTTPException(status_code=500, detail=order_book_data["error"]) + + # Convert to response format - data comes as [price, amount] lists + bids = [OrderBookLevel(price=bid[0], amount=bid[1]) for bid in order_book_data["bids"]] + asks = [OrderBookLevel(price=ask[0], amount=ask[1]) for ask in order_book_data["asks"]] + + return OrderBookResponse( + trading_pair=order_book_data["trading_pair"], + bids=bids, + asks=asks, + timestamp=order_book_data["timestamp"] + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching order book: {str(e)}") + + +# Order Book Query Endpoints + +@router.post("/order-book/price-for-volume", response_model=OrderBookQueryResult) +async def get_price_for_volume( + request: PriceForVolumeRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get the price required to fill a specific volume on the order book. + + Args: + request: Request with connector, trading pair, volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with price and volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + volume=request.volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/volume-for-price", response_model=OrderBookQueryResult) +async def get_volume_for_price( + request: VolumeForPriceRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get the volume available at a specific price level on the order book. + + Args: + request: Request with connector, trading pair, price, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + price=request.price + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/price-for-quote-volume", response_model=OrderBookQueryResult) +async def get_price_for_quote_volume( + request: PriceForQuoteVolumeRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get the price required to fill a specific quote volume on the order book. + + Args: + request: Request with connector, trading pair, quote volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with price and volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + quote_volume=request.quote_volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/quote-volume-for-price", response_model=OrderBookQueryResult) +async def get_quote_volume_for_price( + request: QuoteVolumeForPriceRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get the quote volume available at a specific price level on the order book. + + Args: + request: Request with connector, trading pair, price, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with quote volume information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + quote_price=request.price + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +@router.post("/order-book/vwap-for-volume", response_model=OrderBookQueryResult) +async def get_vwap_for_volume( + request: VWAPForVolumeRequest, + market_data_manager: MarketDataService = Depends(get_market_data_service) +): + """ + Get the VWAP (Volume Weighted Average Price) for a specific volume on the order book. + + Args: + request: Request with connector, trading pair, volume, and side + market_data_manager: Injected market data feed manager + + Returns: + Order book query result with VWAP information + """ + try: + result = await market_data_manager.get_order_book_query_result( + request.connector_name, + request.trading_pair, + request.is_buy, + vwap_volume=request.volume + ) + + if "error" in result: + raise HTTPException(status_code=500, detail=result["error"]) + + return OrderBookQueryResult(**result) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error in order book query: {str(e)}") + + +# Trading Pair Management Endpoints + +@router.post("/trading-pair/add", response_model=TradingPairResponse) +async def add_trading_pair( + request: AddTradingPairRequest, + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Initialize order book for a trading pair. + + This endpoint dynamically adds a trading pair to a connector's order book tracker. + It uses the best available connector (trading connectors are preferred over data connectors). + + Args: + request: Request with connector name, trading pair, optional account name, and timeout + + Returns: + TradingPairResponse with success status and message + + Raises: + HTTPException: 500 if initialization fails + """ + try: + success = await market_data_service.initialize_order_book( + connector_name=request.connector_name, + trading_pair=request.trading_pair, + account_name=request.account_name, + timeout=request.timeout + ) + + if success: + return TradingPairResponse( + success=True, + connector_name=request.connector_name, + trading_pair=request.trading_pair, + message=f"Order book initialized for {request.trading_pair}" + ) + else: + raise HTTPException( + status_code=500, + detail=f"Failed to initialize order book for {request.trading_pair}" + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error initializing order book: {str(e)}" + ) + + +@router.post("/trading-pair/remove", response_model=TradingPairResponse) +async def remove_trading_pair( + request: RemoveTradingPairRequest, + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Remove a trading pair from order book tracking. + + This endpoint removes a trading pair from a connector's order book tracker, + cleaning up resources for pairs that are no longer needed. + + Args: + request: Request with connector name, trading pair, and optional account name + + Returns: + TradingPairResponse with success status and message + + Raises: + HTTPException: 500 if removal fails + """ + try: + success = await market_data_service.remove_trading_pair( + connector_name=request.connector_name, + trading_pair=request.trading_pair, + account_name=request.account_name + ) + + if success: + return TradingPairResponse( + success=True, + connector_name=request.connector_name, + trading_pair=request.trading_pair, + message=f"Trading pair {request.trading_pair} removed" + ) + else: + return TradingPairResponse( + success=False, + connector_name=request.connector_name, + trading_pair=request.trading_pair, + message=f"Trading pair {request.trading_pair} not found or already removed" + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error removing trading pair: {str(e)}" + ) + + +# Order Book Tracker Diagnostics Endpoints + +@router.get("/order-book/diagnostics/{connector_name}") +async def get_order_book_diagnostics( + connector_name: str, + account_name: str = None, + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Get diagnostics for a connector's order book tracker. + + Returns detailed information about the order book tracker status including: + - Task status (running/crashed) + - WebSocket connection status + - Metrics (messages processed, latency, etc.) + - Current order book state + + Args: + connector_name: The connector to diagnose (e.g., "binance") + account_name: Optional account name for trading connectors + + Returns: + Diagnostic information dictionary + """ + try: + diagnostics = market_data_service.get_order_book_tracker_diagnostics( + connector_name=connector_name, + account_name=account_name + ) + return diagnostics + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error getting diagnostics: {str(e)}" + ) + + +@router.post("/order-book/restart/{connector_name}") +async def restart_order_book_tracker( + connector_name: str, + account_name: str = None, + market_data_service: MarketDataService = Depends(get_market_data_service) +): + """ + Restart the order book tracker for a connector. + + Use this endpoint when the order book is stale (WebSocket disconnected). + This will: + 1. Stop the existing order book tracker + 2. Restart it with the same trading pairs + 3. Wait for the WebSocket to reconnect + + Args: + connector_name: The connector to restart (e.g., "binance") + account_name: Optional account name for trading connectors + + Returns: + Restart status with success/failure and trading pairs + """ + try: + result = await market_data_service.restart_order_book_tracker( + connector_name=connector_name, + account_name=account_name + ) + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error restarting order book tracker: {str(e)}" + ) + + diff --git a/routers/portfolio.py b/routers/portfolio.py new file mode 100644 index 00000000..671bd07f --- /dev/null +++ b/routers/portfolio.py @@ -0,0 +1,303 @@ +from typing import Dict, List, Optional +from datetime import datetime + +from fastapi import APIRouter, HTTPException, Depends, Query + +from models.trading import ( + PortfolioStateFilterRequest, + PortfolioHistoryFilterRequest, + PortfolioDistributionFilterRequest, +) +from services.accounts_service import AccountsService +from deps import get_accounts_service +from models import PaginatedResponse + +router = APIRouter(tags=["Portfolio"], prefix="/portfolio") + + +@router.post("/state", response_model=Dict[str, Dict[str, List[Dict]]]) +async def get_portfolio_state( + filter_request: PortfolioStateFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get the current state of all or filtered accounts portfolio. + + Args: + filter_request: JSON payload with filtering criteria including: + - account_names: Optional list of account names to filter by + - connector_names: Optional list of connector names to filter by + - skip_gateway: If True, skip Gateway wallet balance updates for faster CEX-only queries + - refresh: If True, refresh balances before returning. If False (default), return cached state + + Returns: + Dict containing account states with connector balances and token information + """ + # Only refresh balances if explicitly requested + if filter_request.refresh: + await accounts_service.update_account_state( + skip_gateway=filter_request.skip_gateway, + account_names=filter_request.account_names, + connector_names=filter_request.connector_names + ) + + all_states = accounts_service.get_accounts_state() + + # Apply account name filter first + if filter_request.account_names: + filtered_states = {} + for account_name in filter_request.account_names: + if account_name in all_states: + filtered_states[account_name] = all_states[account_name] + all_states = filtered_states + + # Apply connector filter if specified + if filter_request.connector_names: + for account_name, account_data in all_states.items(): + # Filter connectors directly (they are at the top level of account_data) + filtered_connectors = {} + for connector_name in filter_request.connector_names: + if connector_name in account_data: + filtered_connectors[connector_name] = account_data[connector_name] + # Replace account_data with only filtered connectors + all_states[account_name] = filtered_connectors + + return all_states + + +@router.post("/history", response_model=PaginatedResponse) +async def get_portfolio_history( + filter_request: PortfolioHistoryFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get the historical state of all or filtered accounts portfolio with pagination and interval sampling. + + The interval parameter allows you to control data granularity: + - 5m: Raw data (default, collected every 5 minutes) + - 15m: One data point every 15 minutes + - 30m: One data point every 30 minutes + - 1h: One data point every hour + - 4h: One data point every 4 hours + - 12h: One data point every 12 hours + - 1d: One data point every day + + Using larger intervals significantly reduces response size and improves performance. + + Args: + filter_request: JSON payload with filtering criteria (account_names, connector_names, + start_time, end_time, limit, cursor, interval) + + Returns: + Paginated response with historical portfolio data sampled at the requested interval + """ + try: + # Convert integer timestamps to datetime objects + start_time_dt = datetime.fromtimestamp(filter_request.start_time / 1000) if filter_request.start_time else None + end_time_dt = datetime.fromtimestamp(filter_request.end_time / 1000) if filter_request.end_time else None + + if not filter_request.account_names: + # Get history for all accounts + data, next_cursor, has_more = await accounts_service.load_account_state_history( + limit=filter_request.limit, + cursor=filter_request.cursor, + start_time=start_time_dt, + end_time=end_time_dt, + interval=filter_request.interval + ) + else: + # Get history for specific accounts - need to aggregate + all_data = [] + for account_name in filter_request.account_names: + acc_data, _, _ = await accounts_service.get_account_state_history( + account_name=account_name, + limit=filter_request.limit, + cursor=filter_request.cursor, + start_time=start_time_dt, + end_time=end_time_dt, + interval=filter_request.interval + ) + all_data.extend(acc_data) + + # Sort by timestamp and apply pagination + all_data.sort(key=lambda x: x.get("timestamp", ""), reverse=True) + + # Apply limit + data = all_data[:filter_request.limit] + has_more = len(all_data) > filter_request.limit + next_cursor = data[-1]["timestamp"] if data and has_more else None + + # Apply connector filter to the data if specified + if filter_request.connector_names: + for item in data: + for account_name, account_data in item.items(): + if isinstance(account_data, dict) and "connectors" in account_data: + filtered_connectors = {} + for connector_name in filter_request.connector_names: + if connector_name in account_data["connectors"]: + filtered_connectors[connector_name] = account_data["connectors"][connector_name] + account_data["connectors"] = filtered_connectors + + return PaginatedResponse( + data=data, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "current_cursor": filter_request.cursor, + "filters": { + "account_names": filter_request.account_names, + "connector_names": filter_request.connector_names, + "start_time": filter_request.start_time, + "end_time": filter_request.end_time, + "interval": filter_request.interval + } + } + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/distribution") +async def get_portfolio_distribution( + filter_request: PortfolioDistributionFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get portfolio distribution by tokens with percentages across all or filtered accounts. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Dictionary with token distribution including percentages, values, and breakdown by accounts/connectors + """ + if not filter_request.account_names: + # Get distribution for all accounts + distribution = accounts_service.get_portfolio_distribution() + elif len(filter_request.account_names) == 1: + # Single account - use existing method + distribution = accounts_service.get_portfolio_distribution(filter_request.account_names[0]) + else: + # Multiple accounts - need to aggregate + aggregated_distribution = { + "tokens": {}, + "total_value": 0, + "token_count": 0, + "accounts": {} + } + + for account_name in filter_request.account_names: + account_dist = accounts_service.get_portfolio_distribution(account_name) + + # Skip if account doesn't exist or has error + if account_dist.get("error") or account_dist.get("token_count", 0) == 0: + continue + + # Aggregate token data + for token, token_data in account_dist.get("tokens", {}).items(): + if token not in aggregated_distribution["tokens"]: + aggregated_distribution["tokens"][token] = { + "token": token, + "value": 0, + "percentage": 0, + "accounts": {} + } + + aggregated_distribution["tokens"][token]["value"] += token_data.get("value", 0) + + # Copy account-specific data + for acc_name, acc_data in token_data.get("accounts", {}).items(): + aggregated_distribution["tokens"][token]["accounts"][acc_name] = acc_data + + aggregated_distribution["total_value"] += account_dist.get("total_value", 0) + aggregated_distribution["accounts"][account_name] = account_dist.get("accounts", {}).get(account_name, {}) + + # Recalculate percentages + total_value = aggregated_distribution["total_value"] + if total_value > 0: + for token_data in aggregated_distribution["tokens"].values(): + token_data["percentage"] = (token_data["value"] / total_value) * 100 + + aggregated_distribution["token_count"] = len(aggregated_distribution["tokens"]) + + distribution = aggregated_distribution + + # Apply connector filter if specified + if filter_request.connector_names: + filtered_distribution = [] + filtered_total_value = 0 + + for token_data in distribution.get("distribution", []): + filtered_token = { + "token": token_data["token"], + "total_value": 0, + "total_units": 0, + "percentage": 0, + "accounts": {} + } + + # Filter each account's connectors + for account_name, account_data in token_data.get("accounts", {}).items(): + if "connectors" in account_data: + filtered_connectors = {} + account_value = 0 + account_units = 0 + + # Only include specified connectors + for connector_name in filter_request.connector_names: + if connector_name in account_data["connectors"]: + filtered_connectors[connector_name] = account_data["connectors"][connector_name] + account_value += account_data["connectors"][connector_name].get("value", 0) + account_units += account_data["connectors"][connector_name].get("units", 0) + + # Only include account if it has matching connectors + if filtered_connectors: + filtered_token["accounts"][account_name] = { + "value": round(account_value, 6), + "units": account_units, + "percentage": 0, # Will be recalculated later + "connectors": filtered_connectors + } + + filtered_token["total_value"] += account_value + filtered_token["total_units"] += account_units + + # Only include token if it has values after filtering + if filtered_token["total_value"] > 0: + filtered_distribution.append(filtered_token) + filtered_total_value += filtered_token["total_value"] + + # Recalculate percentages after filtering + if filtered_total_value > 0: + for token_data in filtered_distribution: + token_data["percentage"] = round((token_data["total_value"] / filtered_total_value) * 100, 4) + # Update account percentages + for account_data in token_data["accounts"].values(): + account_data["percentage"] = round((account_data["value"] / filtered_total_value) * 100, 4) + + # Sort by value (descending) + filtered_distribution.sort(key=lambda x: x["total_value"], reverse=True) + + # Update the distribution + distribution = { + "total_portfolio_value": round(filtered_total_value, 6), + "token_count": len(filtered_distribution), + "distribution": filtered_distribution, + "account_filter": distribution.get("account_filter", "filtered") + } + + return distribution + + +@router.get("/accounts-distribution") +async def get_accounts_distribution( + accounts_service: AccountsService = Depends(get_accounts_service) +): + """ + Get portfolio distribution by accounts with percentages. + + Returns: + Dictionary with account distribution including percentages, values, and breakdown by connectors + """ + return accounts_service.get_account_distribution() diff --git a/routers/rate_oracle.py b/routers/rate_oracle.py new file mode 100644 index 00000000..36838a3d --- /dev/null +++ b/routers/rate_oracle.py @@ -0,0 +1,332 @@ +""" +Rate Oracle router for managing rate oracle configuration and retrieving rates. + +Provides CRUD endpoints for rate_oracle_source and global_token configuration, +with persistence to conf_client.yml. +""" + +from typing import List +from decimal import Decimal + +from fastapi import APIRouter, Request, HTTPException +from hummingbot.core.rate_oracle.rate_oracle import RateOracle, RATE_ORACLE_SOURCES + +from models.rate_oracle import ( + RateOracleConfig, + RateOracleConfigResponse, + RateOracleConfigUpdateRequest, + RateOracleConfigUpdateResponse, + RateOracleSourceConfig, + GlobalTokenConfig, + RateRequest, + RateResponse, + SingleRateResponse, +) +from utils.file_system import FileSystemUtil + +router = APIRouter(tags=["Rate Oracle"], prefix="/rate-oracle") + +# Path to conf_client.yml relative to the FileSystemUtil base_path ("bots") +CONF_CLIENT_PATH = "credentials/master_account/conf_client.yml" + + +def get_rate_oracle(request: Request) -> RateOracle: + """Get RateOracle instance from the market data service.""" + return request.app.state.market_data_service.rate_oracle + + +def get_file_system_util() -> FileSystemUtil: + """Get FileSystemUtil instance.""" + return FileSystemUtil() + + +@router.get("/sources", response_model=List[str]) +async def get_available_sources(): + """ + Get list of all available rate oracle sources. + + Returns: + List of available source names that can be configured + """ + return list(RATE_ORACLE_SOURCES.keys()) + + +@router.get("/config", response_model=RateOracleConfigResponse) +async def get_rate_oracle_config(request: Request): + """ + Get current rate oracle configuration. + + Returns the current rate_oracle_source and global_token settings, + along with the list of available sources. + + Returns: + Current rate oracle configuration and available sources + """ + try: + fs_util = get_file_system_util() + + # Read current config from file + config_data = fs_util.read_yaml_file(CONF_CLIENT_PATH) + + # Extract rate_oracle_source + rate_oracle_source_data = config_data.get("rate_oracle_source", {}) + source_name = rate_oracle_source_data.get("name", "binance") + + # Extract global_token + global_token_data = config_data.get("global_token", {}) + global_token_name = global_token_data.get("global_token_name", "USDT") + global_token_symbol = global_token_data.get("global_token_symbol", "$") + + return RateOracleConfigResponse( + rate_oracle_source=RateOracleSourceConfig(name=source_name), + global_token=GlobalTokenConfig( + global_token_name=global_token_name, + global_token_symbol=global_token_symbol + ), + available_sources=list(RATE_ORACLE_SOURCES.keys()) + ) + + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Configuration file not found: {CONF_CLIENT_PATH}" + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error reading configuration: {str(e)}" + ) + + +@router.put("/config", response_model=RateOracleConfigUpdateResponse) +async def update_rate_oracle_config( + request: Request, + update_request: RateOracleConfigUpdateRequest +): + """ + Update rate oracle configuration. + + Updates rate_oracle_source and/or global_token settings. Changes are: + 1. Applied to the running RateOracle instance immediately + 2. Persisted to conf_client.yml + + Args: + update_request: Configuration updates to apply + + Returns: + Updated configuration with success status + """ + try: + fs_util = get_file_system_util() + rate_oracle = get_rate_oracle(request) + + # Read current config + config_data = fs_util.read_yaml_file(CONF_CLIENT_PATH) + + # Track if we made changes + changes_made = [] + + # Update rate_oracle_source if provided + if update_request.rate_oracle_source is not None: + new_source_name = update_request.rate_oracle_source.name.value + + # Validate source exists + if new_source_name not in RATE_ORACLE_SOURCES: + raise HTTPException( + status_code=400, + detail=f"Invalid rate oracle source: {new_source_name}. " + f"Available sources: {list(RATE_ORACLE_SOURCES.keys())}" + ) + + # Update config data + if "rate_oracle_source" not in config_data: + config_data["rate_oracle_source"] = {} + config_data["rate_oracle_source"]["name"] = new_source_name + + # Update running RateOracle instance + new_source_class = RATE_ORACLE_SOURCES[new_source_name] + rate_oracle.source = new_source_class() + + changes_made.append(f"rate_oracle_source updated to {new_source_name}") + + # Update global_token if provided + if update_request.global_token is not None: + if "global_token" not in config_data: + config_data["global_token"] = {} + + if update_request.global_token.global_token_name is not None: + config_data["global_token"]["global_token_name"] = update_request.global_token.global_token_name + # Update RateOracle quote token + rate_oracle.quote_token = update_request.global_token.global_token_name + changes_made.append(f"global_token_name updated to {update_request.global_token.global_token_name}") + + if update_request.global_token.global_token_symbol is not None: + config_data["global_token"]["global_token_symbol"] = update_request.global_token.global_token_symbol + changes_made.append(f"global_token_symbol updated to {update_request.global_token.global_token_symbol}") + + # Persist changes to file + if changes_made: + fs_util.dump_dict_to_yaml(CONF_CLIENT_PATH, config_data) + + # Build response + current_source = config_data.get("rate_oracle_source", {}).get("name", "binance") + current_global_token = config_data.get("global_token", {}) + + return RateOracleConfigUpdateResponse( + success=True, + message="; ".join(changes_made) if changes_made else "No changes made", + config=RateOracleConfig( + rate_oracle_source=RateOracleSourceConfig(name=current_source), + global_token=GlobalTokenConfig( + global_token_name=current_global_token.get("global_token_name", "USDT"), + global_token_symbol=current_global_token.get("global_token_symbol", "$") + ) + ) + ) + + except HTTPException: + raise + except FileNotFoundError: + raise HTTPException( + status_code=404, + detail=f"Configuration file not found: {CONF_CLIENT_PATH}" + ) + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error updating configuration: {str(e)}" + ) + + +@router.post("/rates", response_model=RateResponse) +async def get_rates(request: Request, rate_request: RateRequest): + """ + Get rates for specified trading pairs. + + Uses the configured rate oracle source to fetch current rates. + + Args: + rate_request: List of trading pairs to get rates for + + Returns: + Rates for the requested trading pairs + """ + try: + rate_oracle = get_rate_oracle(request) + + rates = {} + for pair in rate_request.trading_pairs: + try: + rate = rate_oracle.get_pair_rate(pair) + rates[pair] = float(rate) if rate and rate != Decimal("0") else None + except Exception: + rates[pair] = None + + return RateResponse( + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token, + rates=rates + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching rates: {str(e)}" + ) + + +@router.get("/rate/{trading_pair}", response_model=SingleRateResponse) +async def get_single_rate(request: Request, trading_pair: str): + """ + Get rate for a single trading pair. + + Args: + trading_pair: Trading pair in format BASE-QUOTE (e.g., BTC-USDT) + + Returns: + Rate for the specified trading pair + """ + try: + rate_oracle = get_rate_oracle(request) + + rate = rate_oracle.get_pair_rate(trading_pair) + rate_value = float(rate) if rate and rate != Decimal("0") else None + + return SingleRateResponse( + trading_pair=trading_pair, + rate=rate_value, + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching rate for {trading_pair}: {str(e)}" + ) + + +@router.get("/rate-async/{trading_pair}", response_model=SingleRateResponse) +async def get_rate_async(request: Request, trading_pair: str): + """ + Get rate for a trading pair using async fetch (direct from exchange). + + This bypasses the cached prices and fetches directly from the source. + Useful when cached data may be stale or not yet initialized. + + Args: + trading_pair: Trading pair in format BASE-QUOTE (e.g., BTC-USDT) + + Returns: + Rate for the specified trading pair + """ + try: + rate_oracle = get_rate_oracle(request) + + rate = await rate_oracle.rate_async(trading_pair) + rate_value = float(rate) if rate and rate != Decimal("0") else None + + return SingleRateResponse( + trading_pair=trading_pair, + rate=rate_value, + source=rate_oracle.source.name, + quote_token=rate_oracle.quote_token + ) + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching async rate for {trading_pair}: {str(e)}" + ) + + +@router.get("/prices") +async def get_cached_prices(request: Request): + """ + Get all cached prices from the rate oracle. + + Returns the complete price dictionary that the rate oracle has fetched + from its configured source. + + Returns: + Dictionary of all cached prices + """ + try: + rate_oracle = get_rate_oracle(request) + + prices = rate_oracle.prices + # Convert Decimal to float for JSON serialization + float_prices = {pair: float(price) for pair, price in prices.items()} + + return { + "source": rate_oracle.source.name, + "quote_token": rate_oracle.quote_token, + "prices_count": len(float_prices), + "prices": float_prices + } + + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error fetching cached prices: {str(e)}" + ) diff --git a/routers/scripts.py b/routers/scripts.py new file mode 100644 index 00000000..c4c18409 --- /dev/null +++ b/routers/scripts.py @@ -0,0 +1,214 @@ +import json +import yaml +from typing import Dict, List + +from fastapi import APIRouter, HTTPException +from starlette import status + +from models import Script, ScriptConfig +from utils.file_system import fs_util + +router = APIRouter(tags=["Scripts"], prefix="/scripts") + + +@router.get("/", response_model=List[str]) +async def list_scripts(): + """ + List all available scripts. + + Returns: + List of script names (without .py extension) + """ + return [f.replace('.py', '') for f in fs_util.list_files('scripts') if f.endswith('.py')] + + +# Script Configuration endpoints (must come before script name routes) +@router.get("/configs/", response_model=List[Dict]) +async def list_script_configs(): + """ + List all script configurations with metadata. + + Returns: + List of script configuration objects with name, script_file_name, and other metadata + """ + try: + config_files = [f for f in fs_util.list_files('conf/scripts') if f.endswith('.yml')] + configs = [] + + for config_file in config_files: + config_name = config_file.replace('.yml', '') + try: + config = fs_util.read_yaml_file(f"conf/scripts/{config_file}") + configs.append({ + "config_name": config_name, + "script_file_name": config.get("script_file_name", "unknown"), + "controllers_config": config.get("controllers_config", []), + "candles_config": config.get("candles_config", []), + "markets": config.get("markets", {}) + }) + except Exception as e: + # If config is malformed, still include it with basic info + configs.append({ + "config_name": config_name, + "script_file_name": "error", + "error": str(e) + }) + + return configs + except FileNotFoundError: + return [] + + +@router.get("/configs/{config_name}", response_model=Dict) +async def get_script_config(config_name: str): + """ + Get script configuration by config name. + + Args: + config_name: Name of the configuration file to retrieve + + Returns: + Dictionary with script configuration + + Raises: + HTTPException: 404 if configuration not found + """ + try: + config = fs_util.read_yaml_file(f"conf/scripts/{config_name}.yml") + return config + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.post("/configs/{config_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_script_config(config_name: str, config: Dict): + """ + Create or update script configuration. + + Args: + config_name: Name of the configuration file + config: Configuration dictionary to save + + Returns: + Success message when configuration is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + yaml_content = yaml.dump(config, default_flow_style=False) + fs_util.add_file('conf/scripts', f"{config_name}.yml", yaml_content, override=True) + return {"message": f"Configuration '{config_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/configs/{config_name}") +async def delete_script_config(config_name: str): + """ + Delete script configuration. + + Args: + config_name: Name of the configuration file to delete + + Returns: + Success message when configuration is deleted + + Raises: + HTTPException: 404 if configuration not found + """ + try: + fs_util.delete_file('conf/scripts', f"{config_name}.yml") + return {"message": f"Configuration '{config_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Configuration '{config_name}' not found") + + +@router.get("/{script_name}", response_model=Dict[str, str]) +async def get_script(script_name: str): + """ + Get script content by name. + + Args: + script_name: Name of the script to retrieve + + Returns: + Dictionary with script name and content + + Raises: + HTTPException: 404 if script not found + """ + try: + content = fs_util.read_file(f"scripts/{script_name}.py") + return { + "name": script_name, + "content": content + } + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Script '{script_name}' not found") + + +@router.post("/{script_name}", status_code=status.HTTP_201_CREATED) +async def create_or_update_script(script_name: str, script: Script): + """ + Create or update a script. + + Args: + script_name: Name of the script (from URL path) + script: Script object with content + + Returns: + Success message when script is saved + + Raises: + HTTPException: 400 if save error occurs + """ + try: + fs_util.add_file('scripts', f"{script_name}.py", script.content, override=True) + return {"message": f"Script '{script_name}' saved successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/{script_name}") +async def delete_script(script_name: str): + """ + Delete a script. + + Args: + script_name: Name of the script to delete + + Returns: + Success message when script is deleted + + Raises: + HTTPException: 404 if script not found + """ + try: + fs_util.delete_file('scripts', f"{script_name}.py") + return {"message": f"Script '{script_name}' deleted successfully"} + except FileNotFoundError: + raise HTTPException(status_code=404, detail=f"Script '{script_name}' not found") + + +@router.get("/{script_name}/config/template", response_model=Dict) +async def get_script_config_template(script_name: str): + """ + Get script configuration template with default values. + + Args: + script_name: Name of the script to get template for + + Returns: + Dictionary with configuration template and default values + + Raises: + HTTPException: 404 if script configuration class not found + """ + config_class = fs_util.load_script_config_class(script_name) + if config_class is None: + raise HTTPException(status_code=404, detail=f"Script configuration class for '{script_name}' not found") + + # Extract fields and default values + config_fields = {name: field.default for name, field in config_class.model_fields.items()} + return json.loads(json.dumps(config_fields, default=str)) \ No newline at end of file diff --git a/routers/trading.py b/routers/trading.py new file mode 100644 index 00000000..9589f184 --- /dev/null +++ b/routers/trading.py @@ -0,0 +1,451 @@ +import logging +import math + +from typing import Dict, List, Optional + +from fastapi import APIRouter, Depends, HTTPException + +# Create module-specific logger +logger = logging.getLogger(__name__) +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType +from pydantic import BaseModel +from starlette import status + +from deps import get_accounts_service, get_connector_service +from models import ( + ActiveOrderFilterRequest, + FundingPaymentFilterRequest, + OrderFilterRequest, + PaginatedResponse, + PositionFilterRequest, + TradeFilterRequest, + TradeRequest, + TradeResponse, +) +from models.accounts import LeverageRequest, PositionModeRequest +from services.accounts_service import AccountsService + +router = APIRouter(tags=["Trading"], prefix="/trading") + + +# Trade Execution +@router.post("/orders", response_model=TradeResponse, status_code=status.HTTP_201_CREATED) +async def place_trade( + trade_request: TradeRequest, + accounts_service: AccountsService = Depends(get_accounts_service), +): + """ + Place a buy or sell order using a specific account and connector. + + Args: + trade_request: Trading request with account, connector, trading pair, type, amount, etc. + accounts_service: Injected accounts service + + Returns: + TradeResponse with order ID and trading details + + Raises: + HTTPException: 400 for invalid parameters, 404 for account/connector not found, 500 for trade execution errors + """ + try: + # Convert string names to enum instances + trade_type_enum = TradeType[trade_request.trade_type] + order_type_enum = OrderType[trade_request.order_type] + position_action_enum = PositionAction[trade_request.position_action] + + order_id = await accounts_service.place_trade( + account_name=trade_request.account_name, + connector_name=trade_request.connector_name, + trading_pair=trade_request.trading_pair, + trade_type=trade_type_enum, + amount=trade_request.amount, + order_type=order_type_enum, + price=trade_request.price, + position_action=position_action_enum, + ) + + return TradeResponse( + order_id=order_id, + account_name=trade_request.account_name, + connector_name=trade_request.connector_name, + trading_pair=trade_request.trading_pair, + trade_type=trade_request.trade_type, + amount=trade_request.amount, + order_type=trade_request.order_type, + price=trade_request.price, + status="submitted", + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Unexpected error placing trade: {str(e)}") + + +@router.post("/{account_name}/{connector_name}/orders/{client_order_id}/cancel") +async def cancel_order( + account_name: str, + connector_name: str, + client_order_id: str, + accounts_service: AccountsService = Depends(get_accounts_service), +): + """ + Cancel a specific order by its client order ID. + + Args: + account_name: Name of the account + connector_name: Name of the connector + client_order_id: Client order ID to cancel + trading_pair: Trading pair for the order + accounts_service: Injected accounts service + + Returns: + Success message with cancelled order ID + + Raises: + HTTPException: 404 if account/connector not found, 500 for cancellation errors + """ + try: + cancelled_order_id = await accounts_service.cancel_order( + account_name=account_name, connector_name=connector_name, client_order_id=client_order_id + ) + return {"message": f"Order cancellation initiated for {cancelled_order_id}"} + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error cancelling order: {str(e)}") + + +@router.post("/positions", response_model=PaginatedResponse) +async def get_positions( + filter_request: PositionFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + connector_service = Depends(get_connector_service) +): + """ + Get current positions across all or filtered perpetual connectors. + + This endpoint fetches real-time position data directly from the connectors, + including unrealized PnL, leverage, funding fees, and margin information. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with position data and pagination metadata + + Raises: + HTTPException: 500 if there's an error fetching positions + """ + try: + all_positions = [] + all_connectors = connector_service.get_all_trading_connectors() + + # Filter accounts + accounts_to_check = filter_request.account_names if filter_request.account_names else list(all_connectors.keys()) + + for account_name in accounts_to_check: + if account_name in all_connectors: + # Filter connectors + connectors_to_check = ( + filter_request.connector_names + if filter_request.connector_names + else list(all_connectors[account_name].keys()) + ) + + for connector_name in connectors_to_check: + # Only fetch positions from perpetual connectors + if connector_name in all_connectors[account_name] and "_perpetual" in connector_name: + try: + positions = await accounts_service.get_account_positions(account_name, connector_name) + # Add cursor-friendly identifier to each position + for position in positions: + position["_cursor_id"] = f"{account_name}:{connector_name}:{position.get('trading_pair', '')}" + all_positions.extend(positions) + except Exception as e: + # Log error but continue with other connectors + import logging + + logger.warning(f"Failed to get positions for {account_name}/{connector_name}: {e}") + + # Sort by cursor_id for consistent pagination + all_positions.sort(key=lambda x: x.get("_cursor_id", "")) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the position after the cursor + for i, position in enumerate(all_positions): + if position.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_positions = all_positions[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_positions) + next_cursor = page_positions[-1].get("_cursor_id") if page_positions and has_more else None + + # Clean up cursor_id from response data + for position in page_positions: + position.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_positions, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_positions), + }, + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching positions: {str(e)}") + + +# Active Orders Management - Real-time from connectors +@router.post("/orders/active", response_model=PaginatedResponse) +async def get_active_orders( + filter_request: ActiveOrderFilterRequest, + connector_service = Depends(get_connector_service) +): + """ + Get active (in-flight) orders across all or filtered accounts and connectors. + + This endpoint fetches real-time active orders directly from the connectors' in_flight_orders property, + providing current order status, fill amounts, and other live order data. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with active order data and pagination metadata + + Raises: + HTTPException: 500 if there's an error fetching orders + """ + try: + all_active_orders = [] + all_connectors = connector_service.get_all_trading_connectors() + + # Use filter request values + accounts_to_check = filter_request.account_names if filter_request.account_names else list(all_connectors.keys()) + + for account_name in accounts_to_check: + if account_name in all_connectors: + # Filter connectors + connectors_to_check = ( + filter_request.connector_names + if filter_request.connector_names + else list(all_connectors[account_name].keys()) + ) + + for connector_name in connectors_to_check: + if connector_name in all_connectors[account_name]: + try: + connector = all_connectors[account_name][connector_name] + # Get in-flight orders directly from connector + in_flight_orders = connector.in_flight_orders + + for client_order_id, order in in_flight_orders.items(): + # Apply trading pair filter if specified + if filter_request.trading_pairs and order.trading_pair not in filter_request.trading_pairs: + continue + + # Convert to standardized format to match orders search response + standardized_order = _standardize_in_flight_order_response(order, account_name, connector_name) + standardized_order["_cursor_id"] = client_order_id # Use client_order_id as cursor + all_active_orders.append(standardized_order) + + except Exception as e: + # Log error but continue with other connectors + import logging + + logger.warning(f"Failed to get active orders for {account_name}/{connector_name}: {e}") + + # Sort by cursor_id for consistent pagination + all_active_orders.sort(key=lambda x: x.get("_cursor_id", "")) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the order after the cursor + for i, order in enumerate(all_active_orders): + if order.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_orders = all_active_orders[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_active_orders) + next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None + + # Clean up cursor_id from response data + for order in page_orders: + order.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_orders, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_active_orders), + }, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching active orders: {str(e)}") + + +# Historical Order Management - From registry/database +@router.post("/orders/search", response_model=PaginatedResponse) +async def get_orders( + filter_request: OrderFilterRequest, + accounts_service: AccountsService = Depends(get_accounts_service), + connector_service = Depends(get_connector_service) +): + """ + Get historical order data across all or filtered accounts from the database/registry. + + Args: + filter_request: JSON payload with filtering criteria + + Returns: + Paginated response with historical order data and pagination metadata + """ + try: + all_orders = [] + + # Determine which accounts to query + if filter_request.account_names: + accounts_to_check = filter_request.account_names + else: + # Get all accounts + all_connectors = connector_service.get_all_trading_connectors() + accounts_to_check = list(all_connectors.keys()) + + # Collect orders from all specified accounts + for account_name in accounts_to_check: + try: + orders = await accounts_service.get_orders( + account_name=account_name, + connector_name=( + filter_request.connector_names[0] + if filter_request.connector_names and len(filter_request.connector_names) == 1 + else None + ), + trading_pair=( + filter_request.trading_pairs[0] + if filter_request.trading_pairs and len(filter_request.trading_pairs) == 1 + else None + ), + status=filter_request.status, + start_time=filter_request.start_time, + end_time=filter_request.end_time, + limit=filter_request.limit * 2, # Get more for filtering + offset=0, + ) + # Add cursor-friendly identifier to each order + for order in orders: + order["_cursor_id"] = f"{order.get('timestamp', 0)}:{order.get('client_order_id', '')}" + all_orders.extend(orders) + except Exception as e: + # import logging + logger.warning(f"Failed to get orders for {account_name}: {e}") + + # Apply filters for multiple values + if filter_request.connector_names and len(filter_request.connector_names) > 1: + all_orders = [order for order in all_orders if order.get("connector_name") in filter_request.connector_names] + if filter_request.trading_pairs and len(filter_request.trading_pairs) > 1: + all_orders = [order for order in all_orders if order.get("trading_pair") in filter_request.trading_pairs] + + # Sort by timestamp (most recent first) and then by cursor_id for consistency + all_orders.sort(key=lambda x: (x.get("timestamp", 0), x.get("_cursor_id", "")), reverse=True) + + # Apply cursor-based pagination + start_index = 0 + if filter_request.cursor: + # Find the order after the cursor + for i, order in enumerate(all_orders): + if order.get("_cursor_id") == filter_request.cursor: + start_index = i + 1 + break + + # Get page of results + end_index = start_index + filter_request.limit + page_orders = all_orders[start_index:end_index] + + # Determine next cursor and has_more + has_more = end_index < len(all_orders) + next_cursor = page_orders[-1].get("_cursor_id") if page_orders and has_more else None + + # Clean up cursor_id from response data + for order in page_orders: + order.pop("_cursor_id", None) + + return PaginatedResponse( + data=page_orders, + pagination={ + "limit": filter_request.limit, + "has_more": has_more, + "next_cursor": next_cursor, + "total_count": len(all_orders), + }, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error fetching orders: {str(e)}") + + +def _standardize_in_flight_order_response(order, account_name: str, connector_name: str) -> dict: + from hummingbot.core.data_type.in_flight_order import OrderState + status_mapping = { + OrderState.PENDING_CREATE: "SUBMITTED", + OrderState.OPEN: "OPEN", + OrderState.PENDING_CANCEL: "PENDING_CANCEL", + OrderState.CANCELED: "CANCELLED", + OrderState.PARTIALLY_FILLED: "PARTIALLY_FILLED", + OrderState.FILLED: "FILLED", + OrderState.FAILED: "FAILED", + OrderState.PENDING_APPROVAL: "SUBMITTED", + OrderState.APPROVED: "SUBMITTED", + OrderState.CREATED: "SUBMITTED", + OrderState.COMPLETED: "FILLED", + } + status = status_mapping.get(order.current_state, "SUBMITTED") + from datetime import datetime, timezone + def _safe_iso_timestamp(timestamp_value): + if timestamp_value is None: + return None + try: + timestamp_float = float(timestamp_value) + except (TypeError, ValueError): + return None + if not math.isfinite(timestamp_float): + return None + return datetime.fromtimestamp(timestamp_float, tz=timezone.utc).isoformat() + created_at = _safe_iso_timestamp(getattr(order, "creation_timestamp", None)) + updated_at = _safe_iso_timestamp(getattr(order, "last_update_timestamp", None)) or created_at + return { + "order_id": order.client_order_id, + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": order.trading_pair, + "trade_type": order.trade_type.name, + "order_type": order.order_type.name, + "amount": float(order.amount) if order.amount and not math.isnan(float(order.amount)) else 0, + "price": float(order.price) if order.price and not math.isnan(float(order.price)) else None, + "status": status, + "filled_amount": float(getattr(order, "executed_amount_base", 0) or 0) if not math.isnan(float(getattr(order, "executed_amount_base", 0) or 0)) else 0, + "average_fill_price": float(getattr(order, "last_executed_price", 0)) if getattr(order, "last_executed_price", None) and not math.isnan(float(getattr(order, "last_executed_price", 0))) else None, + "fee_paid": float(getattr(order, "cumulative_fee_paid_quote", 0)) if getattr(order, "cumulative_fee_paid_quote", None) and not math.isnan(float(getattr(order, "cumulative_fee_paid_quote", 0))) else None, + "fee_currency": None, + "created_at": created_at, + "updated_at": updated_at, + "exchange_order_id": order.exchange_order_id, + "error_message": None, + } \ No newline at end of file diff --git a/services/__init__.py b/services/__init__.py index e69de29b..2bd037e2 100644 --- a/services/__init__.py +++ b/services/__init__.py @@ -0,0 +1,9 @@ +from .accounts_service import AccountsService +from .bots_orchestrator import BotsOrchestrator +from .docker_service import DockerService + +__all__ = [ + "AccountsService", + "BotsOrchestrator", + "DockerService", +] \ No newline at end of file diff --git a/services/accounts_service.py b/services/accounts_service.py index be05f583..bb987737 100644 --- a/services/accounts_service.py +++ b/services/accounts_service.py @@ -1,22 +1,436 @@ import asyncio -import json import logging -from datetime import datetime, timedelta +import time +from datetime import datetime, timezone from decimal import Decimal -from typing import Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Set from fastapi import HTTPException -from hummingbot.client.config.client_config_map import ClientConfigMap from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger -from hummingbot.client.config.config_helpers import ClientConfigAdapter, ReadOnlyClientConfigAdapter, get_connector_class -from hummingbot.client.settings import AllConnectorSettings +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType -from config import BANNED_TOKENS, CONFIG_PASSWORD -from utils.file_system import FileSystemUtil -from utils.models import BackendAPIConfigAdapter -from utils.security import BackendAPISecurity +from config import settings +from database import AccountRepository, AsyncDatabaseManager, FundingRepository, OrderRepository, TradeRepository +from services.gateway_client import GatewayClient +from services.gateway_transaction_poller import GatewayTransactionPoller +from utils.file_system import fs_util -file_system = FileSystemUtil() +# Create module-specific logger +logger = logging.getLogger(__name__) + + +class AccountTradingInterface: + """ + ScriptStrategyBase-compatible interface for executor trading. + + This class provides the exact interface that Hummingbot executors expect + from a strategy object, backed by AccountsService resources. + + IMPORTANT: This class does NOT maintain its own connector cache. Instead, it + uses the shared ConnectorManager via AccountsService which is the single source + of truth for all connector instances. + + Executors use the following interface from strategy: + - current_timestamp: float property + - buy(connector_name, trading_pair, amount, order_type, price, position_action) -> str + - sell(connector_name, trading_pair, amount, order_type, price, position_action) -> str + - cancel(connector_name, trading_pair, order_id) -> str + - get_active_orders(connector_name) -> List + + ExecutorBase also accesses: + - connectors: Dict[str, ConnectorBase] (accessed directly in ExecutorBase.__init__) + """ + + def __init__( + self, + accounts_service: 'AccountsService', + account_name: str + ): + """ + Initialize AccountTradingInterface. + + Args: + accounts_service: AccountsService instance for connector access + account_name: Account to use for connectors + """ + self._accounts_service = accounts_service + self._account_name = account_name + + # Track active markets (connector_name -> set of trading_pairs) + self._markets: Dict[str, Set[str]] = {} + + # Timestamp tracking + self._current_timestamp: float = time.time() + + # Lock for async operations + self._lock = asyncio.Lock() + + @property + def account_name(self) -> str: + """Return the account name for this trading interface.""" + return self._account_name + + @property + def connectors(self) -> Dict[str, ConnectorBase]: + """ + Return connectors for this account from the connector service. + + This returns the actual connectors that are already initialized and running, + avoiding any duplicate caching or connector management. + """ + if not self._accounts_service._connector_service: + return {} + all_connectors = self._accounts_service._connector_service.get_all_trading_connectors() + return all_connectors.get(self._account_name, {}) + + @property + def markets(self) -> Dict[str, Set[str]]: + """Return active markets configuration.""" + return self._markets + + @property + def current_timestamp(self) -> float: + """Return current timestamp (updated by control loop).""" + return self._current_timestamp + + def update_timestamp(self): + """Update the current timestamp. Called by ExecutorService control loop.""" + self._current_timestamp = time.time() + + async def ensure_connector(self, connector_name: str) -> ConnectorBase: + """ + Ensure connector is loaded and available. + + This method uses the connector service which already caches connectors. + It also ensures the MarketDataProvider has access to the connector for + order book initialization. + + Args: + connector_name: Name of the connector + + Returns: + The connector instance + """ + # Get connector from connector service (already cached there) + connector = await self._accounts_service._connector_service.get_trading_connector( + self._account_name, + connector_name + ) + return connector + + async def add_market( + self, + connector_name: str, + trading_pair: str, + order_book_timeout: float = 10.0 + ): + """ + Add a trading pair to active markets with full order book support. + + This method ensures: + 1. Connector is loaded + 2. Order book is initialized and has valid data + 3. Rate sources are initialized for price feeds + + Args: + connector_name: Name of the connector + trading_pair: Trading pair to add + order_book_timeout: Timeout in seconds to wait for order book data + """ + await self.ensure_connector(connector_name) + + if connector_name not in self._markets: + self._markets[connector_name] = set() + + # Check if already tracking this pair + if trading_pair in self._markets[connector_name]: + logger.debug(f"Market {connector_name}/{trading_pair} already active") + return + + self._markets[connector_name].add(trading_pair) + + # Get connector and its order book tracker + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not available. Check credentials.") + tracker = connector.order_book_tracker + + # Check if order book already exists, if not initialize it dynamically + if trading_pair in tracker.order_books: + logger.debug(f"Order book already exists for {connector_name}/{trading_pair}") + else: + logger.debug(f"Order book not found for {connector_name}/{trading_pair}, initializing dynamically") + market_data_service = self._accounts_service._market_data_service + if market_data_service: + try: + success = await market_data_service.initialize_order_book( + connector_name, trading_pair, + account_name=self._account_name, + timeout=order_book_timeout + ) + if not success: + logger.warning(f"Order book for {connector_name}/{trading_pair} not ready after timeout") + except Exception as e: + logger.warning(f"Exception initializing order book: {e}") + + # Register the trading pair with the connector + self._register_trading_pair_with_connector(connector, trading_pair) + + async def _wait_for_order_book_ready( + self, + tracker, + trading_pair: str, + timeout: float = 30.0 + ) -> bool: + """ + Wait for an order book to have valid data. + + Args: + tracker: Order book tracker instance + trading_pair: Trading pair to wait for + timeout: Maximum time to wait in seconds + + Returns: + True if order book is ready, False if timeout + """ + import asyncio + waited = 0 + interval = 0.5 + while waited < timeout: + if trading_pair in tracker.order_books: + ob = tracker.order_books[trading_pair] + try: + bids, asks = ob.snapshot + if len(bids) > 0 and len(asks) > 0: + logger.info(f"Order book for {trading_pair} is ready with {len(bids)} bids and {len(asks)} asks") + return True + except Exception: + pass + await asyncio.sleep(interval) + waited += interval + logger.warning(f"Timeout waiting for {trading_pair} order book to be ready") + return False + + def _register_trading_pair_with_connector( + self, + connector: ConnectorBase, + trading_pair: str + ): + """ + Register a trading pair with the connector's internal structures. + + This is needed for methods like get_order_book() to work properly. + Different connector types may store trading pairs differently. + + Args: + connector: The connector instance + trading_pair: Trading pair to register + """ + if trading_pair not in connector._trading_pairs: + connector._trading_pairs.append(trading_pair) + logger.debug(f"Registered {trading_pair} with connector {type(connector).__name__}") + + async def remove_market( + self, + connector_name: str, + trading_pair: str, + remove_order_book: bool = True + ): + """ + Remove a trading pair from active markets and optionally cleanup order book. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair to remove + remove_order_book: Whether to remove the order book (default True) + """ + if connector_name not in self._markets: + return + + self._markets[connector_name].discard(trading_pair) + if not self._markets[connector_name]: + del self._markets[connector_name] + + # Remove order book if requested + if remove_order_book: + market_data_service = self._accounts_service._market_data_service + if market_data_service: + try: + success = await market_data_service.remove_trading_pair( + connector_name, + trading_pair, + account_name=self._account_name + ) + if success: + logger.info(f"Removed order book for {connector_name}/{trading_pair}") + else: + logger.debug(f"Order book for {trading_pair} was not being tracked") + except Exception as e: + logger.warning(f"Failed to remove order book for {trading_pair}: {e}") + + # ======================================== + # ScriptStrategyBase-compatible methods + # These are called by executors via self._strategy.method() + # ======================================== + + def buy( + self, + connector_name: str, + trading_pair: str, + amount: Decimal, + order_type: OrderType, + price: Decimal = Decimal("NaN"), + position_action: PositionAction = PositionAction.NIL + ) -> str: + """ + Place a buy order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + amount: Order amount in base currency + order_type: Type of order (LIMIT, MARKET, etc.) + price: Order price (for limit orders) + position_action: Position action for perpetuals + + Returns: + Client order ID + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + connector._set_current_timestamp(time.time()) + + return connector.buy( + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price, + position_action=position_action + ) + + def sell( + self, + connector_name: str, + trading_pair: str, + amount: Decimal, + order_type: OrderType, + price: Decimal = Decimal("NaN"), + position_action: PositionAction = PositionAction.NIL + ) -> str: + """ + Place a sell order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + amount: Order amount in base currency + order_type: Type of order (LIMIT, MARKET, etc.) + price: Order price (for limit orders) + position_action: Position action for perpetuals + + Returns: + Client order ID + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + connector._set_current_timestamp(time.time()) + + return connector.sell( + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price, + position_action=position_action + ) + + def cancel( + self, + connector_name: str, + trading_pair: str, + order_id: str + ) -> str: + """ + Cancel an order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + order_id: Client order ID to cancel + + Returns: + Client order ID that was cancelled + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + + return connector.cancel(trading_pair=trading_pair, client_order_id=order_id) + + def get_active_orders(self, connector_name: str) -> List: + """ + Get active orders for a connector. + + Args: + connector_name: Name of the connector + + Returns: + List of active in-flight orders + """ + connector = self.connectors.get(connector_name) + if not connector: + return [] + return list(connector.in_flight_orders.values()) + + # ======================================== + # Additional helper methods + # ======================================== + + def get_connector(self, connector_name: str) -> Optional[ConnectorBase]: + """ + Get a connector by name from the shared ConnectorManager. + + Args: + connector_name: Name of the connector + + Returns: + The connector instance or None if not loaded + """ + return self.connectors.get(connector_name) + + def is_connector_loaded(self, connector_name: str) -> bool: + """ + Check if a connector is loaded in the shared ConnectorManager. + + Args: + connector_name: Name of the connector + + Returns: + True if connector is loaded + """ + return connector_name in self.connectors + + def get_all_trading_pairs(self) -> Dict[str, Set[str]]: + """ + Get all active trading pairs by connector. + + Returns: + Dictionary mapping connector names to sets of trading pairs + """ + return {k: v.copy() for k, v in self._markets.items()} + + async def cleanup(self): + """ + Cleanup resources. Called when shutting down. + + Note: This does NOT clean up connectors since they are managed by the + shared ConnectorManager, not by AccountTradingInterface. + """ + # Clear only local state (markets tracking) + self._markets.clear() + logger.info(f"AccountTradingInterface cleanup completed for account {self._account_name}") class AccountsService: @@ -25,286 +439,517 @@ class AccountsService: to initialize all the connectors that are connected to each account, keep track of the balances of each account and update the balances of each account. """ + default_quotes = { + "hyperliquid": "USD", + "hyperliquid_perpetual": "USDC", + "xrpl": "RLUSD", + "kraken": "USD", + } + gateway_default_pricing_connector = { + "ethereum": "uniswap/router", + "solana": "jupiter/router", + } + potential_wrapped_tokens = ["ETH", "SOL", "BNB", "POL", "AVAX", "FTM", "ONE", "GLMR", "MOVR"] + + # Cache for storing last successful prices by trading pair + _last_known_prices = {} def __init__(self, - update_account_state_interval_minutes: int = 1, + account_update_interval: int = 5, default_quote: str = "USDT", - account_history_file: str = "account_state_history.json", - account_history_dump_interval_minutes: int = 1): - # TODO: Add database to store the balances of each account each time it is updated. - self.secrets_manager = ETHKeyFileSecretManger(CONFIG_PASSWORD) - self.accounts = {} + gateway_url: str = "http://localhost:15888"): + """ + Initialize the AccountsService. + + Args: + account_update_interval: How often to update account states in minutes (default: 5) + default_quote: Default quote currency for trading pairs (default: "USDT") + gateway_url: URL for Gateway service (default: "http://localhost:15888") + """ + self.secrets_manager = ETHKeyFileSecretManger(settings.security.config_password) self.accounts_state = {} - self.account_state_update_event = asyncio.Event() - self.initialize_accounts() - self.update_account_state_interval = update_account_state_interval_minutes * 60 + self.update_account_state_interval = account_update_interval * 60 + self.order_status_poll_interval = 60 # Poll order status every 1 minute self.default_quote = default_quote - self.history_file = account_history_file - self.account_history_dump_interval = account_history_dump_interval_minutes self._update_account_state_task: Optional[asyncio.Task] = None - self._dump_account_state_task: Optional[asyncio.Task] = None + self._order_status_polling_task: Optional[asyncio.Task] = None + + # Database setup for account states and orders + self.db_manager = AsyncDatabaseManager(settings.database.url) + self._db_initialized = False + + # Services injected from main.py + self._connector_service = None # UnifiedConnectorService + self._market_data_service = None # MarketDataService + self._trading_service = None # TradingService + + # Initialize Gateway client + self.gateway_client = GatewayClient(gateway_url) + # Initialize Gateway transaction poller + self.gateway_tx_poller = GatewayTransactionPoller( + db_manager=self.db_manager, + gateway_client=self.gateway_client, + poll_interval=10, # Poll every 10 seconds for transactions + position_poll_interval=60, # Poll every 1 minute for positions + max_retry_age=3600 # Stop retrying after 1 hour + ) + self._gateway_poller_started = False + + # Trading interfaces per account (for executor use) + self._trading_interfaces: Dict[str, AccountTradingInterface] = {} + + def get_trading_interface(self, account_name: str) -> AccountTradingInterface: + """ + Get or create a trading interface for the specified account. + + This interface provides ScriptStrategyBase-compatible methods + that executors can use for trading operations. + + Args: + account_name: Account to get trading interface for + + Returns: + AccountTradingInterface instance for the account + """ + if account_name not in self._trading_interfaces: + self._trading_interfaces[account_name] = AccountTradingInterface( + accounts_service=self, + account_name=account_name + ) + return self._trading_interfaces[account_name] + + async def ensure_db_initialized(self): + """Ensure database is initialized before using it.""" + if not self._db_initialized: + await self.db_manager.create_tables() + self._db_initialized = True + def get_accounts_state(self): return self.accounts_state - def get_default_market(self, token): - return f"{token}-{self.default_quote}" + def get_default_market(self, token: str, connector_name: str) -> str: + if token.startswith("LD") and token != "LDO": + # These tokens are staked in binance earn + token = token[2:] + quote = self.default_quotes.get(connector_name, self.default_quote) + return f"{token}-{quote}" - def start_update_account_state_loop(self): + def start(self): """ - Start the loop that updates the balances of all the accounts at a fixed interval. + Start the loop that updates the account state at a fixed interval. + Note: Balance updates are now handled by manual connector state updates. :return: """ + # Start the update loop which will call check_all_connectors self._update_account_state_task = asyncio.create_task(self.update_account_state_loop()) - self._dump_account_state_task = asyncio.create_task(self.dump_account_state_loop()) - def stop_update_account_state_loop(self): + # Start order status polling loop (every 1 minute) + self._order_status_polling_task = asyncio.create_task(self.order_status_polling_loop()) + logger.info("Order status polling started (1 minute interval)") + + # Start Gateway transaction poller + if not self._gateway_poller_started: + asyncio.create_task(self._start_gateway_poller()) + self._gateway_poller_started = True + logger.info("Gateway transaction poller startup initiated") + + async def _start_gateway_poller(self): + """Start the Gateway transaction poller (async helper).""" + try: + await self.gateway_tx_poller.start() + logger.info("Gateway transaction poller started successfully") + except Exception as e: + logger.error(f"Error starting Gateway transaction poller: {e}", exc_info=True) + + async def stop(self): """ - Stop the loop that updates the balances of all the accounts at a fixed interval. - :return: + Stop all accounts service tasks and cleanup resources. + This is the main cleanup method that should be called during application shutdown. """ + logger.info("Stopping AccountsService...") + + # Stop the account state update loop if self._update_account_state_task: self._update_account_state_task.cancel() - if self._dump_account_state_task: - self._dump_account_state_task.cancel() - self._update_account_state_task = None - self._dump_account_state_task = None + self._update_account_state_task = None + logger.info("Stopped account state update loop") + + # Stop the order status polling loop + if self._order_status_polling_task: + self._order_status_polling_task.cancel() + self._order_status_polling_task = None + logger.info("Stopped order status polling loop") + + # Stop Gateway transaction poller + if self._gateway_poller_started: + try: + await self.gateway_tx_poller.stop() + logger.info("Gateway transaction poller stopped") + self._gateway_poller_started = False + except Exception as e: + logger.error(f"Error stopping Gateway transaction poller: {e}", exc_info=True) + + # Cleanup trading interfaces + for interface in self._trading_interfaces.values(): + await interface.cleanup() + self._trading_interfaces.clear() + logger.info("Cleaned up trading interfaces") + + # Stop all connectors through the connector service + if self._connector_service: + await self._connector_service.stop_all() + + logger.info("AccountsService stopped successfully") + + async def _refresh_and_get_tokens_info(self, connector, connector_name: str, account_name: str) -> List[Dict]: + """Refresh connector state from exchange, then get token info with prices. + + Combines the connector state refresh and token info retrieval into a + single awaitable so both can run in parallel across all connectors. + """ + if self._connector_service: + try: + await self._connector_service._update_connector_state(connector, connector_name, account_name) + except Exception as e: + logger.error(f"Error refreshing {connector_name}, using stale data: {e}") + return await self._get_connector_tokens_info(connector, connector_name) async def update_account_state_loop(self): """ - The loop that updates the balances of all the accounts at a fixed interval. - :return: + The loop that updates the account state at a fixed interval. + Performs connector state refresh + token info retrieval in a single parallel pass. """ while True: try: await self.check_all_connectors() - await self.update_balances() - await self.update_trading_rules() - await self.update_account_state() + + # Single parallel pass: refresh connector state + get token info + gateway + all_connectors = self._connector_service.get_all_trading_connectors() if self._connector_service else {} + tasks = [] + task_meta = [] # (account_name, connector_name) + + for account_name, connectors in all_connectors.items(): + if account_name not in self.accounts_state: + self.accounts_state[account_name] = {} + for connector_name, connector in connectors.items(): + tasks.append(self._refresh_and_get_tokens_info(connector, connector_name, account_name)) + task_meta.append((account_name, connector_name)) + + has_connector_tasks = len(tasks) > 0 + tasks.append(self._update_gateway_balances()) + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Process connector results (last result is always gateway) + connector_results = results[:-1] if has_connector_tasks else [] + for (account_name, connector_name), result in zip(task_meta, connector_results): + if isinstance(result, Exception): + logger.error(f"Error updating {connector_name} in {account_name}: {result}") + self.accounts_state[account_name][connector_name] = [] + else: + self.accounts_state[account_name][connector_name] = result + + gw_result = results[-1] + if isinstance(gw_result, Exception): + logger.error(f"Error updating gateway balances: {gw_result}") + + await self.dump_account_state() except Exception as e: - logging.error(f"Error updating account state: {e}") + logger.error(f"Error updating account state: {e}") finally: await asyncio.sleep(self.update_account_state_interval) - async def dump_account_state_loop(self): + async def order_status_polling_loop(self): """ - The loop that dumps the current account state to a file at fixed intervals. - :return: + Sync order state to database for all connectors at a frequent interval (1 minute). + + The connector's built-in _lost_orders_update_polling_loop already polls the exchange. + This loop just syncs that state to our database and cleans up closed orders. """ - await self.account_state_update_event.wait() while True: try: - await self.dump_account_state() + if self._connector_service: + await self._connector_service.sync_all_orders_to_database() except Exception as e: - logging.error(f"Error dumping account state: {e}") + logger.error(f"Error syncing order state to database: {e}") finally: - now = datetime.now() - next_log_time = (now + timedelta(minutes=self.account_history_dump_interval)).replace(second=0, - microsecond=0) - next_log_time = next_log_time - timedelta( - minutes=next_log_time.minute % self.account_history_dump_interval) - sleep_duration = (next_log_time - now).total_seconds() - await asyncio.sleep(sleep_duration) + await asyncio.sleep(self.order_status_poll_interval) async def dump_account_state(self): """ - Dump the current account state to a JSON file. Create it if the file not exists. + Save the current account state to the database. + All account/connector combinations from the same snapshot will use the same timestamp. :return: """ - timestamp = datetime.now().isoformat() - state_to_dump = {"timestamp": timestamp, "state": self.accounts_state} - if not file_system.path_exists(path=f"data/{self.history_file}"): - file_system.add_file(directory="data", file_name=self.history_file, content=json.dumps(state_to_dump) + "\n") - else: - file_system.append_to_file(directory="data", file_name=self.history_file, content=json.dumps(state_to_dump) + "\n") + await self.ensure_db_initialized() + + try: + # Generate a single timestamp for this entire snapshot + snapshot_timestamp = datetime.now(timezone.utc) + + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + + # Save each account-connector combination with the same timestamp + for account_name, connectors in self.accounts_state.items(): + for connector_name, tokens_info in connectors.items(): + if tokens_info: # Only save if there's token data + await repository.save_account_state(account_name, connector_name, tokens_info, snapshot_timestamp) + + except Exception as e: + logger.error(f"Error saving account state to database: {e}") + # Re-raise the exception since we no longer have a fallback + raise - def load_account_state_history(self): + async def load_account_state_history(self, + limit: Optional[int] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + interval: str = "5m"): """ - Load the account state history from the JSON file. - :return: List of account states with timestamps. + Load the account state history from the database with pagination and interval sampling. + + Args: + limit: Maximum number of records to return + cursor: Cursor for pagination + start_time: Start time filter + end_time: End time filter + interval: Sampling interval (5m, 15m, 30m, 1h, 4h, 12h, 1d) + + :return: Tuple of (data, next_cursor, has_more). """ - history = [] + await self.ensure_db_initialized() + try: - with open("bots/data/" + self.history_file, "r") as file: - for line in file: - if line.strip(): # Check if the line is not empty - history.append(json.loads(line)) - except FileNotFoundError: - logging.warning("No account state history file found.") - return history + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_state_history( + limit=limit, + cursor=cursor, + start_time=start_time, + end_time=end_time, + interval=interval + ) + except Exception as e: + logger.error(f"Error loading account state history from database: {e}") + # Return empty result since we no longer have a fallback + return [], None, False async def check_all_connectors(self): """ - Check all avaialble credentials for all accounts and see if the connectors are created. - :return: + Check all available credentials for all accounts and ensure connectors are initialized. + This method is idempotent - it only initializes missing connectors. """ for account_name in self.list_accounts(): - for connector_name in self.list_credentials(account_name): - try: - connector_name = connector_name.split(".")[0] - if account_name not in self.accounts or connector_name not in self.accounts[account_name]: - self.initialize_connector(account_name, connector_name) - except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") + await self._ensure_account_connectors_initialized(account_name) - def initialize_accounts(self): - """ - Initialize all the connectors that are connected to each account. - :return: + async def _ensure_account_connectors_initialized(self, account_name: str): """ - for account_name in self.list_accounts(): - self.accounts[account_name] = {} - for connector_name in self.list_credentials(account_name): - try: - connector_name = connector_name.split(".")[0] - connector = self.get_connector(account_name, connector_name) - self.accounts[account_name][connector_name] = connector - except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") + Ensure all connectors for a specific account are initialized. + This delegates to the connector service for actual initialization. - def initialize_account(self, account_name: str): - """ - Initialize all the connectors that are connected to the specified account. - :param account_name: The name of the account. - :return: + :param account_name: The name of the account to initialize connectors for. """ - for connector_name in self.list_credentials(account_name): + if not self._connector_service: + return + + # Initialize missing connectors + for connector_name in self._connector_service.list_available_credentials(account_name): try: - connector_name = connector_name.split(".")[0] - self.initialize_connector(account_name, connector_name) + # Only initialize if connector doesn't exist + if not self._connector_service.is_trading_connector_initialized(account_name, connector_name): + # Get connector will now handle all initialization + await self._connector_service.get_trading_connector(account_name, connector_name) except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") - - def initialize_connector(self, account_name: str, connector_name: str): - """ - Initialize the specified connector for the specified account. - :param account_name: The name of the account. - :param connector_name: The name of the connector. - :return: - """ - if account_name not in self.accounts: - self.accounts[account_name] = {} - try: - connector = self.get_connector(account_name, connector_name) - self.accounts[account_name][connector_name] = connector - except Exception as e: - logging.error(f"Error initializing connector {connector_name}: {e}") + logger.error(f"Error initializing connector {connector_name} for account {account_name}: {e}") - async def update_balances(self): - tasks = [] - for account_name, connectors in self.accounts.items(): - for connector_instance in connectors.values(): - tasks.append(self._safe_update_balances(connector_instance)) - await asyncio.gather(*tasks) + async def update_account_state( + self, + skip_gateway: bool = False, + account_names: Optional[List[str]] = None, + connector_names: Optional[List[str]] = None + ): + """Update account state for filtered connectors and optionally Gateway wallets. - async def _safe_update_balances(self, connector_instance): - try: - await connector_instance._update_balances() - except Exception as e: - logging.error(f"Error updating balances for connector {connector_instance}: {e}") + Args: + skip_gateway: If True, skip Gateway wallet balance updates for faster CEX-only queries. + account_names: If provided, only update these accounts. If None, update all accounts. + connector_names: If provided, only update these connectors. If None, update all connectors. + For Gateway, this filters by chain-network (e.g., 'solana-mainnet-beta'). + """ + all_connectors = self._connector_service.get_all_trading_connectors() if self._connector_service else {} - async def update_trading_rules(self): + # Prepare parallel tasks tasks = [] - for account_name, connectors in self.accounts.items(): - for connector_instance in connectors.values(): - tasks.append(self._safe_update_trading_rules(connector_instance)) - await asyncio.gather(*tasks) + task_meta = [] # (account_name, connector_name) - async def _safe_update_trading_rules(self, connector_instance): - try: - await connector_instance._update_trading_rules() - except Exception as e: - logging.error(f"Error updating trading rules for connector {connector_instance}: {e}") + for account_name, connectors in all_connectors.items(): + # Filter by account_names if specified + if account_names and account_name not in account_names: + continue - async def update_account_state(self): - for account_name, connectors in self.accounts.items(): if account_name not in self.accounts_state: self.accounts_state[account_name] = {} for connector_name, connector in connectors.items(): - tokens_info = [] - try: - balances = [{"token": key, "units": value} for key, value in connector.get_all_balances().items() if - value != Decimal("0") and key not in BANNED_TOKENS] - unique_tokens = [balance["token"] for balance in balances] - trading_pairs = [self.get_default_market(token) for token in unique_tokens if "USD" not in token] - last_traded_prices = await self._safe_get_last_traded_prices(connector, trading_pairs) - for balance in balances: - token = balance["token"] - if "USD" in token: - price = Decimal("1") - else: - market = self.get_default_market(balance["token"]) - price = Decimal(last_traded_prices.get(market, 0)) - tokens_info.append({ - "token": balance["token"], - "units": float(balance["units"]), - "price": float(price), - "value": float(price * balance["units"]), - "available_units": float(connector.get_available_balance(balance["token"])) - }) - self.account_state_update_event.set() - except Exception as e: - logging.error( - f"Error updating balances for connector {connector_name} in account {account_name}: {e}") - self.accounts_state[account_name][connector_name] = tokens_info + # Filter by connector_names if specified + if connector_names and connector_name not in connector_names: + continue + + tasks.append(self._get_connector_tokens_info(connector, connector_name)) + task_meta.append((account_name, connector_name)) + + # Execute connectors + gateway in parallel (unless skip_gateway is True) + if skip_gateway: + results = await asyncio.gather(*tasks, return_exceptions=True) + else: + # Pass connector_names filter to gateway for chain-network filtering + results = await asyncio.gather( + *tasks, + self._update_gateway_balances(chain_networks=connector_names), + return_exceptions=True + ) + # Remove gateway result from processing (it handles its own state internally) + results = results[:-1] + + # Process results + for (account_name, connector_name), result in zip(task_meta, results): + if isinstance(result, Exception): + logger.error(f"Error updating balances for connector {connector_name} in account {account_name}: {result}") + self.accounts_state[account_name][connector_name] = [] + else: + self.accounts_state[account_name][connector_name] = result + + async def _get_connector_tokens_info(self, connector, connector_name: str) -> List[Dict]: + """Get token info from a connector instance using RateOracle cached prices. + + Tries the RateOracle (instant, in-memory) first for each token. + Only falls back to a batch exchange call for tokens the oracle can't price. + """ + balances = [{"token": key, "units": value} for key, value in connector.get_all_balances().items() if + value != Decimal("0") and key not in settings.banned_tokens] + + tokens_info = [] + missing_pairs = [] # trading pairs the oracle can't price + missing_indices = [] # indices into tokens_info that need patching + + for balance in balances: + token = balance["token"] + if "USD" in token: + price = Decimal("1") + else: + # Try RateOracle first (instant, cached) + rate = None + if self._market_data_service: + rate = self._market_data_service.get_rate(token, "USDT") + if rate and rate > 0: + price = rate + else: + # Queue for fallback batch fetch from exchange + market = self.get_default_market(token, connector_name) + missing_pairs.append(market) + missing_indices.append(len(tokens_info)) + price = None # resolved below + + tokens_info.append({ + "token": token, + "units": float(balance["units"]), + "price": float(price) if price is not None else 0.0, + "value": float(price * balance["units"]) if price is not None else 0.0, + "available_units": float(connector.get_available_balance(token)) + }) + + # Batch-fetch only the missing prices from the exchange + if missing_pairs: + fallback_prices = await self._safe_get_last_traded_prices(connector, missing_pairs) + for pair_idx, info_idx in enumerate(missing_indices): + market = missing_pairs[pair_idx] + price = Decimal(str(fallback_prices.get(market, 0))) + tokens_info[info_idx]["price"] = float(price) + tokens_info[info_idx]["value"] = float(price * Decimal(str(tokens_info[info_idx]["units"]))) + + return tokens_info + + async def _safe_get_last_traded_prices(self, connector, trading_pairs, timeout=10): + """Safely get last traded prices with timeout and error handling. + Fetches each pair individually via gather so one bad pair doesn't kill the rest.""" + + async def _fetch_single(pair): + return pair, await connector._get_last_traded_price(trading_pair=pair) - async def _safe_get_last_traded_prices(self, connector, trading_pairs, timeout=5): try: - # TODO: Fix OKX connector to return the markets in Hummingbot format. - last_traded = await asyncio.wait_for(connector.get_last_traded_prices(trading_pairs=trading_pairs), timeout=timeout) - if connector.name == "okx_perpetual": - return {pair.strip("-SWAP"): value for pair, value in last_traded.items()} - return last_traded + results = await asyncio.wait_for( + asyncio.gather(*[_fetch_single(p) for p in trading_pairs], return_exceptions=True), + timeout=timeout, + ) except asyncio.TimeoutError: - logging.error(f"Timeout getting last traded prices for trading pairs {trading_pairs}") - return {pair: Decimal("0") for pair in trading_pairs} - except Exception as e: - logging.error(f"Error getting last traded prices for trading pairs {trading_pairs}: {e}") - return {pair: Decimal("0") for pair in trading_pairs} + logger.error(f"Timeout getting last traded prices for trading pairs {trading_pairs}") + return self._get_fallback_prices(trading_pairs) - @staticmethod - def get_connector_config_map(connector_name: str): + last_traded = {} + for result in results: + if isinstance(result, Exception): + logger.warning(f"Failed to get price for a pair: {result}") + continue + pair, price = result + if price and price > 0: + self._last_known_prices[pair] = price + last_traded[pair] = price + + # Fill in fallbacks for any pairs that failed + for pair in trading_pairs: + if pair not in last_traded: + if pair in self._last_known_prices: + last_traded[pair] = self._last_known_prices[pair] + logger.info(f"Using cached price {self._last_known_prices[pair]} for {pair}") + else: + last_traded[pair] = Decimal("0") + logger.warning(f"No cached price available for {pair}, using 0") + + return last_traded + + def _get_fallback_prices(self, trading_pairs): + """Get fallback prices using cached values, only setting to 0 if no previous price exists.""" + fallback_prices = {} + for pair in trading_pairs: + if pair in self._last_known_prices: + fallback_prices[pair] = self._last_known_prices[pair] + logger.info(f"Using cached price {self._last_known_prices[pair]} for {pair}") + else: + fallback_prices[pair] = Decimal("0") + logger.warning(f"No cached price available for {pair}, using 0") + return fallback_prices + + def get_connector_config_map(self, connector_name: str): """ Get the connector config map for the specified connector. :param connector_name: The name of the connector. :return: The connector config map. """ - connector_config = BackendAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) - return [key for key in connector_config.hb_config.__fields__.keys() if key != "connector"] + from services.unified_connector_service import UnifiedConnectorService + return UnifiedConnectorService.get_connector_config_map(connector_name) - async def add_connector_keys(self, account_name: str, connector_name: str, keys: dict): - BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) - connector_config = BackendAPIConfigAdapter(AllConnectorSettings.get_connector_config_keys(connector_name)) - for key, value in keys.items(): - setattr(connector_config, key, value) - BackendAPISecurity.update_connector_keys(account_name, connector_config) - new_connector = self.get_connector(account_name, connector_name) - await new_connector._update_balances() - self.accounts[account_name][connector_name] = new_connector - await self.update_account_state() - await self.dump_account_state() - - def get_connector(self, account_name: str, connector_name: str): + async def add_credentials(self, account_name: str, connector_name: str, credentials: dict): """ - Get the connector object for the specified account and connector. + Add or update connector credentials and initialize the connector with validation. + :param account_name: The name of the account. :param connector_name: The name of the connector. - :return: The connector object. - """ - BackendAPISecurity.login_account(account_name=account_name, secrets_manager=self.secrets_manager) - client_config_map = ClientConfigAdapter(ClientConfigMap()) - conn_setting = AllConnectorSettings.get_connector_settings()[connector_name] - keys = BackendAPISecurity.api_keys(connector_name) - read_only_config = ReadOnlyClientConfigAdapter.lock_config(client_config_map) - init_params = conn_setting.conn_init_parameters( - trading_pairs=[], - trading_required=True, - api_keys=keys, - client_config_map=read_only_config, - ) - connector_class = get_connector_class(connector_name) - connector = connector_class(**init_params) - return connector + :param credentials: Dictionary containing the connector credentials. + :raises Exception: If credentials are invalid or connector cannot be initialized. + """ + if not self._connector_service: + raise HTTPException(status_code=500, detail="Connector service not initialized") + + try: + # Update the connector keys (this saves the credentials to file and validates them) + connector = await self._connector_service.update_connector_keys(account_name, connector_name, credentials) + + await self.update_account_state() + except Exception as e: + logger.error(f"Error adding connector credentials for account {account_name}: {e}") + await self.delete_credentials(account_name, connector_name) + raise e @staticmethod def list_accounts(): @@ -312,33 +957,42 @@ def list_accounts(): List all the accounts that are connected to the trading system. :return: List of accounts. """ - return file_system.list_folders('credentials') + return fs_util.list_folders('credentials') - def list_credentials(self, account_name: str): + @staticmethod + def list_credentials(account_name: str): """ List all the credentials that are connected to the specified account. :param account_name: The name of the account. :return: List of credentials. """ try: - return [file for file in file_system.list_files(f'credentials/{account_name}/connectors') if + return [file for file in fs_util.list_files(f'credentials/{account_name}/connectors') if file.endswith('.yml')] except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) - def delete_credentials(self, account_name: str, connector_name: str): + async def delete_credentials(self, account_name: str, connector_name: str): """ Delete the credentials of the specified connector for the specified account. :param account_name: :param connector_name: :return: """ - if file_system.path_exists(f"credentials/{account_name}/connectors/{connector_name}.yml"): - file_system.delete_file(directory=f"credentials/{account_name}/connectors", file_name=f"{connector_name}.yml") - if connector_name in self.accounts[account_name]: - self.accounts[account_name].pop(connector_name) - if connector_name in self.accounts_state[account_name]: - self.accounts_state[account_name].pop(connector_name) + # Delete credentials file if it exists + if fs_util.path_exists(f"credentials/{account_name}/connectors/{connector_name}.yml"): + fs_util.delete_file(directory=f"credentials/{account_name}/connectors", file_name=f"{connector_name}.yml") + + # Always perform cleanup regardless of file existence + if self._connector_service: + # Stop the connector if it's running + await self._connector_service.stop_trading_connector(account_name, connector_name) + # Clear the connector from cache + self._connector_service.clear_trading_connector(account_name, connector_name) + + # Remove from account state + if account_name in self.accounts_state and connector_name in self.accounts_state[account_name]: + self.accounts_state[account_name].pop(connector_name) def add_account(self, account_name: str): """ @@ -346,22 +1000,1037 @@ def add_account(self, account_name: str): :param account_name: :return: """ - if account_name in self.accounts: + # Check if account already exists by looking at folders + if account_name in self.list_accounts(): raise HTTPException(status_code=400, detail="Account already exists.") + files_to_copy = ["conf_client.yml", "conf_fee_overrides.yml", "hummingbot_logs.yml", ".password_verification"] - file_system.create_folder('credentials', account_name) - file_system.create_folder(f'credentials/{account_name}', "connectors") + fs_util.create_folder('credentials', account_name) + fs_util.create_folder(f'credentials/{account_name}', "connectors") for file in files_to_copy: - file_system.copy_file(f"credentials/master_account/{file}", f"credentials/{account_name}/{file}") - self.accounts[account_name] = {} + fs_util.copy_file(f"credentials/master_account/{file}", f"credentials/{account_name}/{file}") + + # Initialize account state self.accounts_state[account_name] = {} - def delete_account(self, account_name: str): + async def delete_account(self, account_name: str): """ Delete the specified account. :param account_name: :return: """ - file_system.delete_folder('credentials', account_name) - self.accounts.pop(account_name) - self.accounts_state.pop(account_name) + # Stop all connectors for this account + if self._connector_service: + for connector_name in self._connector_service.list_account_connectors(account_name): + await self._connector_service.stop_trading_connector(account_name, connector_name) + # Clear all connectors for this account from cache + self._connector_service.clear_trading_connector(account_name) + + # Delete account folder + fs_util.delete_folder('credentials', account_name) + + # Remove from account state + if account_name in self.accounts_state: + self.accounts_state.pop(account_name) + + async def get_account_current_state(self, account_name: str) -> Dict[str, List[Dict]]: + """ + Get current state for a specific account from database. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_current_state(account_name) + except Exception as e: + logger.error(f"Error getting account current state: {e}") + # Fallback to in-memory state + return self.accounts_state.get(account_name, {}) + + async def get_account_state_history(self, + account_name: str, + limit: Optional[int] = None, + cursor: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + interval: str = "5m"): + """ + Get historical state for a specific account with pagination and interval sampling. + + Args: + account_name: Account name to filter by + limit: Maximum number of records to return + cursor: Cursor for pagination + start_time: Start time filter + end_time: End time filter + interval: Sampling interval (5m, 15m, 30m, 1h, 4h, 12h, 1d) + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_account_state_history( + account_name=account_name, + connector_name=connector_name, + limit=limit, + cursor=cursor, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error getting connector state history: {e}") + return [], None, False + + async def get_all_unique_tokens(self) -> List[str]: + """ + Get all unique tokens across all accounts and connectors. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_all_unique_tokens() + except Exception as e: + logger.error(f"Error getting unique tokens: {e}") + # Fallback to in-memory state + tokens = set() + for account_data in self.accounts_state.values(): + for connector_data in account_data.values(): + for token_info in connector_data: + tokens.add(token_info.get("token")) + return sorted(list(tokens)) + + async def get_token_current_state(self, token: str) -> List[Dict]: + """ + Get current state of a specific token across all accounts. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_token_current_state(token) + except Exception as e: + logger.error(f"Error getting token current state: {e}") + return [] + + async def get_portfolio_value(self, account_name: Optional[str] = None) -> Dict[str, any]: + """ + Get total portfolio value, optionally filtered by account. + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + repository = AccountRepository(session) + return await repository.get_portfolio_value(account_name) + except Exception as e: + logger.error(f"Error getting portfolio value: {e}") + # Fallback to in-memory calculation + portfolio = {"accounts": {}, "total_value": 0} + + accounts_to_process = [account_name] if account_name else self.accounts_state.keys() + + for acc_name in accounts_to_process: + account_value = 0 + if acc_name in self.accounts_state: + for connector_data in self.accounts_state[acc_name].values(): + for token_info in connector_data: + account_value += token_info.get("value", 0) + portfolio["accounts"][acc_name] = account_value + portfolio["total_value"] += account_value + + return portfolio + + def get_portfolio_distribution(self, account_name: Optional[str] = None) -> Dict[str, any]: + """ + Get portfolio distribution by tokens with percentages. + """ + try: + # Get accounts to process + accounts_to_process = [account_name] if account_name else list(self.accounts_state.keys()) + + # Aggregate all tokens across accounts and connectors + token_values = {} + total_value = 0 + + for acc_name in accounts_to_process: + if acc_name in self.accounts_state: + for connector_name, connector_data in self.accounts_state[acc_name].items(): + for token_info in connector_data: + token = token_info.get("token", "") + value = token_info.get("value", 0) + + if token not in token_values: + token_values[token] = { + "token": token, + "total_value": 0, + "total_units": 0, + "accounts": {} + } + + token_values[token]["total_value"] += value + token_values[token]["total_units"] += token_info.get("units", 0) + total_value += value + + # Track by account + if acc_name not in token_values[token]["accounts"]: + token_values[token]["accounts"][acc_name] = { + "value": 0, + "units": 0, + "connectors": {} + } + + token_values[token]["accounts"][acc_name]["value"] += value + token_values[token]["accounts"][acc_name]["units"] += token_info.get("units", 0) + + # Track by connector within account + if connector_name not in token_values[token]["accounts"][acc_name]["connectors"]: + token_values[token]["accounts"][acc_name]["connectors"][connector_name] = { + "value": 0, + "units": 0 + } + + token_values[token]["accounts"][acc_name]["connectors"][connector_name]["value"] += value + token_values[token]["accounts"][acc_name]["connectors"][connector_name]["units"] += token_info.get("units", 0) + + # Calculate percentages + distribution = [] + for token_data in token_values.values(): + percentage = (token_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + token_dist = { + "token": token_data["token"], + "total_value": round(token_data["total_value"], 6), + "total_units": token_data["total_units"], + "percentage": round(percentage, 4), + "accounts": {} + } + + # Add account-level percentages + for acc_name, acc_data in token_data["accounts"].items(): + acc_percentage = (acc_data["value"] / total_value * 100) if total_value > 0 else 0 + token_dist["accounts"][acc_name] = { + "value": round(acc_data["value"], 6), + "units": acc_data["units"], + "percentage": round(acc_percentage, 4), + "connectors": {} + } + + # Add connector-level data + for conn_name, conn_data in acc_data["connectors"].items(): + token_dist["accounts"][acc_name]["connectors"][conn_name] = { + "value": round(conn_data["value"], 6), + "units": conn_data["units"] + } + + distribution.append(token_dist) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "token_count": len(distribution), + "distribution": distribution, + "account_filter": account_name if account_name else "all_accounts" + } + + except Exception as e: + logger.error(f"Error calculating portfolio distribution: {e}") + return { + "total_portfolio_value": 0, + "token_count": 0, + "distribution": [], + "account_filter": account_name if account_name else "all_accounts", + "error": str(e) + } + + def get_account_distribution(self) -> Dict[str, any]: + """ + Get portfolio distribution by accounts with percentages. + """ + try: + account_values = {} + total_value = 0 + + for acc_name, account_data in self.accounts_state.items(): + account_value = 0 + connector_values = {} + + for connector_name, connector_data in account_data.items(): + connector_value = 0 + for token_info in connector_data: + value = token_info.get("value", 0) + connector_value += value + account_value += value + + connector_values[connector_name] = round(connector_value, 6) + + account_values[acc_name] = { + "total_value": round(account_value, 6), + "connectors": connector_values + } + total_value += account_value + + # Calculate percentages + distribution = [] + for acc_name, acc_data in account_values.items(): + percentage = (acc_data["total_value"] / total_value * 100) if total_value > 0 else 0 + + connector_dist = {} + for conn_name, conn_value in acc_data["connectors"].items(): + conn_percentage = (conn_value / total_value * 100) if total_value > 0 else 0 + connector_dist[conn_name] = { + "value": conn_value, + "percentage": round(conn_percentage, 4) + } + + distribution.append({ + "account": acc_name, + "total_value": acc_data["total_value"], + "percentage": round(percentage, 4), + "connectors": connector_dist + }) + + # Sort by value (descending) + distribution.sort(key=lambda x: x["total_value"], reverse=True) + + return { + "total_portfolio_value": round(total_value, 6), + "account_count": len(distribution), + "distribution": distribution + } + + except Exception as e: + logger.error(f"Error calculating account distribution: {e}") + return { + "total_portfolio_value": 0, + "account_count": 0, + "distribution": [], + "error": str(e) + } + + async def place_trade(self, account_name: str, connector_name: str, trading_pair: str, + trade_type: TradeType, amount: Decimal, order_type: OrderType = OrderType.LIMIT, + price: Optional[Decimal] = None, position_action: PositionAction = PositionAction.OPEN) -> str: + """ + Place a trade using the specified account and connector. + + Args: + account_name: Name of the account to trade with + connector_name: Name of the connector/exchange + trading_pair: Trading pair (e.g., BTC-USDT) + trade_type: "BUY" or "SELL" + amount: Amount to trade + order_type: "LIMIT", "MARKET", or "LIMIT_MAKER" + price: Price for limit orders (required for LIMIT and LIMIT_MAKER) + position_action: Position action for perpetual contracts (OPEN/CLOSE) + + Returns: + Client order ID assigned by the connector + + Raises: + HTTPException: If account, connector not found, or trade fails + """ + # Validate account exists + if account_name not in self.list_accounts(): + raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found") + + if not self._connector_service: + raise HTTPException(status_code=500, detail="Connector service not initialized") + + connector = await self._connector_service.get_trading_connector(account_name, connector_name) + + # Validate price for limit orders + if order_type in [OrderType.LIMIT, OrderType.LIMIT_MAKER] and price is None: + raise HTTPException(status_code=400, detail="Price is required for LIMIT and LIMIT_MAKER orders") + + # Check if trading rules are loaded + if not connector.trading_rules: + raise HTTPException( + status_code=503, + detail=f"Trading rules not yet loaded for {connector_name}. Please try again in a moment." + ) + + # Validate trading pair and get trading rule + if trading_pair not in connector.trading_rules: + available_pairs = list(connector.trading_rules.keys())[:10] # Show first 10 + more_text = f" (and {len(connector.trading_rules) - 10} more)" if len(connector.trading_rules) > 10 else "" + raise HTTPException( + status_code=400, + detail=f"Trading pair '{trading_pair}' not supported on {connector_name}. " + f"Available pairs: {available_pairs}{more_text}" + ) + + trading_rule = connector.trading_rules[trading_pair] + + # Validate order type is supported + if order_type not in connector.supported_order_types(): + supported_types = [ot.name for ot in connector.supported_order_types()] + raise HTTPException(status_code=400, detail=f"Order type '{order_type.name}' not supported. Supported types: {supported_types}") + + # Quantize amount according to trading rules + quantized_amount = connector.quantize_order_amount(trading_pair, amount) + + # Validate minimum order size + if quantized_amount < trading_rule.min_order_size: + raise HTTPException( + status_code=400, + detail=f"Order amount {quantized_amount} is below minimum order size {trading_rule.min_order_size} for {trading_pair}" + ) + + # Calculate and validate notional size + if order_type in [OrderType.LIMIT, OrderType.LIMIT_MAKER]: + quantized_price = connector.quantize_order_price(trading_pair, price) + notional_size = quantized_price * quantized_amount + else: + # For market orders without price, get current market price for validation + if self._market_data_service: + try: + prices = await self._market_data_service.get_prices(connector_name, [trading_pair]) + if trading_pair in prices and "error" not in prices: + price = Decimal(str(prices[trading_pair])) + except Exception as e: + logger.error(f"Error getting market price for {trading_pair}: {e}") + notional_size = price * quantized_amount if price else Decimal("0") + + if notional_size < trading_rule.min_notional_size: + raise HTTPException( + status_code=400, + detail=f"Order notional value {notional_size} is below minimum notional size {trading_rule.min_notional_size} for {trading_pair}. " + f"Increase the amount or price to meet the minimum requirement." + ) + + + + try: + connector._set_current_timestamp(time.time()) + # Place the order using the connector with quantized values + # (position_action will be ignored by non-perpetual connectors) + if trade_type == TradeType.BUY: + order_id = connector.buy( + trading_pair=trading_pair, + amount=quantized_amount, + order_type=order_type, + price=price or Decimal("1"), + position_action=position_action + ) + else: + order_id = connector.sell( + trading_pair=trading_pair, + amount=quantized_amount, + order_type=order_type, + price=price or Decimal("1"), + position_action=position_action + ) + + logger.info(f"Placed {trade_type} order for {amount} {trading_pair} on {connector_name} (Account: {account_name}). Order ID: {order_id}") + return order_id + + except HTTPException: + # Re-raise HTTP exceptions as-is + raise + except Exception as e: + logger.error(f"Failed to place {trade_type} order: {e}") + raise HTTPException(status_code=500, detail=f"Failed to place trade: {str(e)}") + + async def get_connector_instance(self, account_name: str, connector_name: str): + """ + Get a connector instance for direct access. + + Args: + account_name: Name of the account + connector_name: Name of the connector + + Returns: + Connector instance + + Raises: + HTTPException: If account or connector not found + """ + if account_name not in self.list_accounts(): + raise HTTPException(status_code=404, detail=f"Account '{account_name}' not found") + + if not self._connector_service: + raise HTTPException(status_code=500, detail="Connector service not initialized") + + return await self._connector_service.get_trading_connector(account_name, connector_name) + + async def _get_perpetual_connector(self, account_name: str, connector_name: str): + """ + Get a perpetual connector instance with validation. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Perpetual connector instance + + Raises: + HTTPException: If connector is not perpetual or not found + """ + if "_perpetual" not in connector_name: + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' is not a perpetual connector") + return await self.get_connector_instance(account_name, connector_name) + + async def get_active_orders(self, account_name: str, connector_name: str) -> Dict[str, any]: + """ + Get active orders for a specific connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector + + Returns: + Dictionary of active orders + """ + connector = await self.get_connector_instance(account_name, connector_name) + return {order_id: order.to_json() for order_id, order in connector.in_flight_orders.items()} + + async def cancel_order(self, account_name: str, connector_name: str, client_order_id: str) -> str: + """ + Cancel an active order. + + Args: + account_name: Name of the account + connector_name: Name of the connector + client_order_id: Client order ID to cancel + + Returns: + Client order ID that was cancelled + + Raises: + HTTPException: 404 if order not found, 500 if cancellation fails + """ + connector = await self.get_connector_instance(account_name, connector_name) + + # Check if order exists in in-flight orders + if client_order_id not in connector.in_flight_orders: + raise HTTPException(status_code=404, detail=f"Order '{client_order_id}' not found in active orders") + + try: + result = connector.cancel(trading_pair="NA", client_order_id=client_order_id) + logger.info(f"Initiated cancellation for order {client_order_id} on {connector_name} (Account: {account_name})") + return result + except Exception as e: + logger.error(f"Failed to initiate cancellation for order {client_order_id}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to initiate order cancellation: {str(e)}") + + async def set_leverage(self, account_name: str, connector_name: str, + trading_pair: str, leverage: int) -> Dict[str, str]: + """ + Set leverage for a specific trading pair on a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + trading_pair: Trading pair to set leverage for + leverage: Leverage value (typically 1-125) + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, '_execute_set_leverage'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support leverage setting") + + try: + await connector._execute_set_leverage(trading_pair, leverage) + message = f"Leverage for {trading_pair} set to {leverage} on {connector_name}" + logger.info(f"Set leverage for {trading_pair} to {leverage} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set leverage for {trading_pair} to {leverage}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set leverage: {str(e)}") + + async def set_position_mode(self, account_name: str, connector_name: str, + position_mode: PositionMode) -> Dict[str, str]: + """ + Set position mode for a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + position_mode: PositionMode.HEDGE or PositionMode.ONEWAY + + Returns: + Dictionary with success status and message + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + # Check if the requested position mode is supported + supported_modes = connector.supported_position_modes() + if position_mode not in supported_modes: + supported_values = [mode.value for mode in supported_modes] + raise HTTPException( + status_code=400, + detail=f"Position mode '{position_mode.value}' not supported. Supported modes: {supported_values}" + ) + + try: + # Try to call the method - it might be sync or async + result = connector.set_position_mode(position_mode) + # If it's a coroutine, await it + if asyncio.iscoroutine(result): + await result + + message = f"Position mode set to {position_mode.value} on {connector_name}" + logger.info(f"Set position mode to {position_mode.value} on {connector_name} (Account: {account_name})") + return {"status": "success", "message": message} + + except Exception as e: + logger.error(f"Failed to set position mode to {position_mode.value}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to set position mode: {str(e)}") + + async def get_position_mode(self, account_name: str, connector_name: str) -> Dict[str, str]: + """ + Get current position mode for a perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + Dictionary with current position mode + + Raises: + HTTPException: If account/connector not found, not perpetual, or operation fails + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, 'position_mode'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position mode") + + try: + current_mode = connector.position_mode + return { + "position_mode": current_mode.value if current_mode else "UNKNOWN", + "connector": connector_name, + "account": account_name + } + + except Exception as e: + logger.error(f"Failed to get position mode: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get position mode: {str(e)}") + + async def get_orders(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, status: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get order history using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + orders = await order_repo.get_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + status=status, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + return [order_repo.to_dict(order) for order in orders] + except Exception as e: + logger.error(f"Error getting orders: {e}") + return [] + + async def get_active_orders_history(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None) -> List[Dict]: + """Get active orders from database using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + orders = await order_repo.get_active_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + return [order_repo.to_dict(order) for order in orders] + except Exception as e: + logger.error(f"Error getting active orders: {e}") + return [] + + async def get_orders_summary(self, account_name: Optional[str] = None, start_time: Optional[int] = None, + end_time: Optional[int] = None) -> Dict: + """Get order summary statistics using OrderRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + return await order_repo.get_orders_summary( + account_name=account_name, + start_time=start_time, + end_time=end_time + ) + except Exception as e: + logger.error(f"Error getting orders summary: {e}") + return { + "total_orders": 0, + "filled_orders": 0, + "cancelled_orders": 0, + "failed_orders": 0, + "active_orders": 0, + "fill_rate": 0, + } + + async def get_trades(self, account_name: Optional[str] = None, connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, trade_type: Optional[str] = None, + start_time: Optional[int] = None, end_time: Optional[int] = None, + limit: int = 100, offset: int = 0) -> List[Dict]: + """Get trade history using TradeRepository.""" + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + trade_repo = TradeRepository(session) + trade_order_pairs = await trade_repo.get_trades_with_orders( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + trade_type=trade_type, + start_time=start_time, + end_time=end_time, + limit=limit, + offset=offset + ) + return [trade_repo.to_dict(trade, order) for trade, order in trade_order_pairs] + except Exception as e: + logger.error(f"Error getting trades: {e}") + return [] + + async def get_account_positions(self, account_name: str, connector_name: str) -> List[Dict]: + """ + Get current positions for a specific perpetual connector. + + Args: + account_name: Name of the account + connector_name: Name of the connector (must be perpetual) + + Returns: + List of position dictionaries + + Raises: + HTTPException: If account/connector not found or not perpetual + """ + connector = await self._get_perpetual_connector(account_name, connector_name) + + if not hasattr(connector, 'account_positions'): + raise HTTPException(status_code=400, detail=f"Connector '{connector_name}' does not support position tracking") + + try: + # Force position update to ensure current market prices are used + await connector._update_positions() + + positions = [] + raw_positions = connector.account_positions + + for trading_pair, position_info in raw_positions.items(): + # Convert position data to dict format + position_dict = { + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": position_info.trading_pair, + "side": position_info.position_side.name if hasattr(position_info, 'position_side') else "UNKNOWN", + "amount": float(position_info.amount) if hasattr(position_info, 'amount') else 0.0, + "entry_price": float(position_info.entry_price) if hasattr(position_info, 'entry_price') else None, + "unrealized_pnl": float(position_info.unrealized_pnl) if hasattr(position_info, 'unrealized_pnl') else None, + "leverage": float(position_info.leverage) if hasattr(position_info, 'leverage') else None, + } + + # Only include positions with non-zero amounts + if position_dict["amount"] != 0: + positions.append(position_dict) + + return positions + + except Exception as e: + logger.error(f"Failed to get positions for {connector_name}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get positions: {str(e)}") + + async def get_funding_payments(self, account_name: str, connector_name: str = None, + trading_pair: str = None, limit: int = 100) -> List[Dict]: + """ + Get funding payment history for an account. + + Args: + account_name: Name of the account + connector_name: Optional connector name filter + trading_pair: Optional trading pair filter + limit: Maximum number of records to return + + Returns: + List of funding payment dictionaries + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + funding_repo = FundingRepository(session) + funding_payments = await funding_repo.get_funding_payments( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + limit=limit + ) + return [funding_repo.to_dict(payment) for payment in funding_payments] + + except Exception as e: + logger.error(f"Error getting funding payments: {e}") + return [] + + async def get_total_funding_fees(self, account_name: str, connector_name: str, + trading_pair: str) -> Dict: + """ + Get total funding fees for a specific trading pair. + + Args: + account_name: Name of the account + connector_name: Name of the connector + trading_pair: Trading pair to get fees for + + Returns: + Dictionary with total funding fees information + """ + await self.ensure_db_initialized() + + try: + async with self.db_manager.get_session_context() as session: + funding_repo = FundingRepository(session) + return await funding_repo.get_total_funding_fees( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair + ) + + except Exception as e: + logger.error(f"Error getting total funding fees: {e}") + return { + "total_funding_fees": 0, + "payment_count": 0, + "fee_currency": None, + "error": str(e) + } + + # ============================================ + # Gateway Wallet Management Methods + # ============================================ + + async def _update_gateway_balances(self, chain_networks: Optional[List[str]] = None): + """Update Gateway wallet balances in master_account state. + + Only queries the defaultWallet on each network in defaultNetworks for each chain. + This is more efficient than querying all wallets on all networks. + + Args: + chain_networks: If provided, only update these chain-network combinations + (e.g., ['solana-mainnet-beta', 'ethereum-mainnet']). + If None, update all defaultNetworks for each chain. + """ + try: + # Check if Gateway is available + if not await self.gateway_client.ping(): + logger.debug("Gateway service is not available, skipping wallet balance update") + return + + # Get all available chains + chains_result = await self.gateway_client.get_chains() + if not chains_result or "chains" not in chains_result: + logger.error("Could not get chains from Gateway") + return + + known_chains = {c["chain"] for c in chains_result["chains"]} + + # Ensure master_account exists in accounts_state + if "master_account" not in self.accounts_state: + self.accounts_state["master_account"] = {} + + # Collect all balance query tasks for parallel execution + balance_tasks = [] + task_metadata = [] # Store (chain, network, address) for each task + + # For each chain, get its config with defaultWallet and defaultNetworks + for chain_info in chains_result["chains"]: + chain = chain_info["chain"] + networks = chain_info.get("networks", []) + + if not networks: + logger.debug(f"Chain '{chain}' has no networks configured, skipping") + continue + + # Get merged config using chain-network namespace (e.g., solana-mainnet-beta) + # This returns both chain-level fields (defaultWallet, defaultNetworks) and network fields + first_network = networks[0] + try: + config = await self.gateway_client.get_config(f"{chain}-{first_network}") + except Exception as e: + logger.warning(f"Could not get config for '{chain}-{first_network}': {e}") + continue + + default_wallet = config.get("defaultWallet") + default_networks = config.get("defaultNetworks", []) + + if not default_wallet: + logger.debug(f"Chain '{chain}' missing defaultWallet, skipping") + continue + + # Skip placeholder wallet addresses (e.g., "ethereum-default-wallet", "solana-default-wallet") + if default_wallet.endswith("-default-wallet"): + logger.debug(f"Chain '{chain}' has placeholder defaultWallet '{default_wallet}', skipping") + continue + + if not default_networks: + # Fall back to defaultNetwork (singular) if defaultNetworks not set + default_network = config.get("defaultNetwork") + if default_network: + default_networks = [default_network] + else: + logger.debug(f"Chain '{chain}' missing defaultNetworks, skipping") + continue + + # Create balance tasks for each default network + for network in default_networks: + chain_network_key = f"{chain}-{network}" + + # Filter by chain_networks if specified + if chain_networks and chain_network_key not in chain_networks: + continue + + balance_tasks.append(self.get_gateway_balances(chain, default_wallet, network=network)) + task_metadata.append((chain, network, default_wallet)) + + # Build set of active chain-network keys + active_chain_networks = {f"{chain}-{network}" for chain, network, _ in task_metadata} + + # Execute all balance queries in parallel + if balance_tasks: + results = await asyncio.gather(*balance_tasks, return_exceptions=True) + + # Process results + for result, (chain, network, address) in zip(results, task_metadata): + chain_network = f"{chain}-{network}" + + if isinstance(result, Exception): + logger.error(f"Error updating Gateway balances for {chain}-{network} wallet {address}: {result}") + # Store empty list for error state + self.accounts_state["master_account"][chain_network] = [] + elif result: + # Only store if there are actual balances (non-empty list) + self.accounts_state["master_account"][chain_network] = result + else: + # Store empty list to indicate we checked this network + self.accounts_state["master_account"][chain_network] = [] + + # Only remove stale keys if we're doing a full update (no filter) + # When filtering, we don't want to remove keys that weren't in the filter + if not chain_networks: + # Remove stale gateway chain-network keys (default network/wallet changed or no longer configured) + # Gateway keys follow pattern: chain-network (e.g., "solana-mainnet-beta", "ethereum-mainnet") + stale_keys = [] + for key in self.accounts_state["master_account"]: + # Check if key looks like a gateway chain-network (contains hyphen and matches chain pattern) + if "-" in key and key not in active_chain_networks: + # Verify it's a gateway key by checking if chain part matches known chains + chain_part = key.split("-")[0] + if chain_part in known_chains: + stale_keys.append(key) + + for key in stale_keys: + logger.info(f"Removing stale Gateway balance data for {key} (no longer default network )") + del self.accounts_state["master_account"][key] + + except Exception as e: + logger.error(f"Error updating Gateway balances: {e}") + + async def get_gateway_wallets(self) -> List[Dict]: + """ + Get all wallets from Gateway. Gateway manages its own encrypted wallets. + + Returns: + List of wallet information from Gateway + """ + if not await self.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + try: + wallets = await self.gateway_client.get_wallets() + return wallets + except Exception as e: + logger.error(f"Error getting Gateway wallets: {e}") + raise HTTPException(status_code=500, detail=f"Failed to get wallets: {str(e)}") + + async def add_gateway_wallet(self, chain: str, private_key: str) -> Dict: + """ + Add a wallet to Gateway. Gateway handles encryption internally. + + Args: + chain: Blockchain chain (e.g., 'solana', 'ethereum') + private_key: Wallet private key + + Returns: + Dictionary with wallet information from Gateway + """ + if not await self.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + try: + result = await self.gateway_client.add_wallet(chain, private_key, set_default=True) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") + + logger.info(f"Added {chain} wallet {result.get('address')} to Gateway") + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error adding Gateway wallet: {e}") + raise HTTPException(status_code=500, detail=f"Failed to add wallet: {str(e)}") + + async def remove_gateway_wallet(self, chain: str, address: str) -> Dict: + """ + Remove a wallet from Gateway. + + Args: + chain: Blockchain chain + address: Wallet address to remove + + Returns: + Success message + """ + if not await self.gateway_client.ping(): + raise HTTPException(status_code=503, detail="Gateway service is not available") + + try: + result = await self.gateway_client.remove_wallet(chain, address) + + if "error" in result: + raise HTTPException(status_code=400, detail=f"Gateway error: {result['error']}") + + logger.info(f"Removed {chain} wallet {address} from Gateway") + return {"success": True, "message": f"Successfully removed {chain} wallet"} + + except HTTPException: + raise + except Exception as e: + logger.errow" \ No newline at end of file diff --git a/services/bots_orchestrator.py b/services/bots_orchestrator.py index 24ae656f..85622f14 100644 --- a/services/bots_orchestrator.py +++ b/services/bots_orchestrator.py @@ -1,173 +1,348 @@ import asyncio -from collections import deque +import logging from typing import Optional +import re import docker -from hbotrc import BotCommands -from hbotrc.listener import BotListener -from hbotrc.spec import TopicSpecs - - -class HummingbotPerformanceListener(BotListener): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - topic_prefix = TopicSpecs.PREFIX.format( - namespace=self._ns, - instance_id=self._bot_id - ) - self._performance_topic = f'{topic_prefix}/performance' - self._bot_performance = {} - self._bot_error_logs = deque(maxlen=100) - self._bot_general_logs = deque(maxlen=100) - self.performance_report_sub = None - - def get_bot_performance(self): - return self._bot_performance - - def get_bot_error_logs(self): - return list(self._bot_error_logs) - - def get_bot_general_logs(self): - return list(self._bot_general_logs) - - def _init_endpoints(self): - super()._init_endpoints() - self.performance_report_sub = self.create_subscriber(topic=self._performance_topic, - on_message=self._update_bot_performance) - - def _update_bot_performance(self, msg): - for controller_id, performance_report in msg.items(): - self._bot_performance[controller_id] = performance_report - - def _on_log(self, log): - if log.level_name == "ERROR": - self._bot_error_logs.append(log) - else: - self._bot_general_logs.append(log) - def stop(self): - super().stop() - self._bot_performance = {} +from utils.mqtt_manager import MQTTManager + +logger = logging.getLogger(__name__) + + +# HummingbotPerformanceListener class is no longer needed +# All functionality is now handled by MQTTManager -class BotsManager: +class BotsOrchestrator: + """Orchestrates Hummingbot instances using Docker and MQTT communication.""" + def __init__(self, broker_host, broker_port, broker_username, broker_password): self.broker_host = broker_host self.broker_port = broker_port self.broker_username = broker_username self.broker_password = broker_password + + # Initialize Docker client self.docker_client = docker.from_env() + + # Initialize MQTT manager + self.mqtt_manager = MQTTManager(host=broker_host, port=broker_port, username=broker_username, password=broker_password) + + # Active bots tracking self.active_bots = {} self._update_bots_task: Optional[asyncio.Task] = None + + # Track bots that are currently being stopped and archived + self.stopping_bots = set() + + # MQTT manager will be started asynchronously later @staticmethod def hummingbot_containers_fiter(container): + """Filter for Hummingbot containers based on image name pattern.""" try: - return "hummingbot" in container.name and "broker" not in container.name + # Get the image name (first tag if available, otherwise the image ID) + image_name = container.image.tags[0] if container.image.tags else str(container.image) + pattern = r'.+/hummingbot:' + return bool(re.match(pattern, image_name)) except Exception: return False - def get_active_containers(self): - return [container.name for container in self.docker_client.containers.list() - if container.status == 'running' and self.hummingbot_containers_fiter(container)] + async def get_active_containers(self): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self._sync_get_active_containers) + + def _sync_get_active_containers(self): + return [ + container.name + for container in self.docker_client.containers.list() + if container.status == "running" and self.hummingbot_containers_fiter(container) + ] - def start_update_active_bots_loop(self): - self._update_bots_task = asyncio.create_task(self.update_active_bots()) + def start(self): + """Start the loop that monitors active bots.""" + # Start MQTT manager and update loop in async context + self._update_bots_task = asyncio.create_task(self._start_async()) - def stop_update_active_bots_loop(self): + async def _start_async(self): + """Start MQTT manager and update loop asynchronously.""" + logger.info("Starting MQTT manager...") + await self.mqtt_manager.start() + + # Then start the update loop + await self.update_active_bots() + + def stop(self): + """Stop the active bots monitoring loop.""" if self._update_bots_task: self._update_bots_task.cancel() self._update_bots_task = None - async def update_active_bots(self, sleep_time=1): + # Stop MQTT manager asynchronously + asyncio.create_task(self.mqtt_manager.stop()) + + async def update_active_bots(self, sleep_time=1.0): + """Monitor and update active bots list using both Docker and MQTT discovery.""" while True: - active_hbot_containers = self.get_active_containers() - # Remove bots that are no longer active - for bot in list(self.active_bots): - if bot not in active_hbot_containers: - del self.active_bots[bot] - - # Add new bots or update existing ones - for bot in active_hbot_containers: - if bot not in self.active_bots: - hbot_listener = HummingbotPerformanceListener(host=self.broker_host, port=self.broker_port, - username=self.broker_username, - password=self.broker_password, - bot_id=bot) - hbot_listener.start() - self.active_bots[bot] = { - "bot_name": bot, - "broker_client": BotCommands(host=self.broker_host, port=self.broker_port, - username=self.broker_username, password=self.broker_password, - bot_id=bot), - "broker_listener": hbot_listener, - } + try: + # Get bots from Docker containers + docker_bots = await self.get_active_containers() + + # Get bots from MQTT messages (auto-discovered) + mqtt_bots = self.mqtt_manager.get_discovered_bots(timeout_seconds=30) # 30 second timeout + + # Combine both sources + all_active_bots = set([bot for bot in docker_bots + mqtt_bots if not self.is_bot_stopping(bot)]) + + # Remove bots that are no longer active + for bot_name in list(self.active_bots): + if bot_name not in all_active_bots: + self.mqtt_manager.clear_bot_data(bot_name) + del self.active_bots[bot_name] + + # Add new bots + for bot_name in all_active_bots: + if bot_name not in self.active_bots: + self.active_bots[bot_name] = { + "bot_name": bot_name, + "status": "connected", + "source": "docker" if bot_name in docker_bots else "mqtt", + } + # Subscribe to this specific bot's topics + await self.mqtt_manager.subscribe_to_bot(bot_name) + + except Exception as e: + logger.error(f"Error in update_active_bots: {e}", exc_info=True) + await asyncio.sleep(sleep_time) # Interact with a specific bot - def start_bot(self, bot_name, **kwargs): - if bot_name in self.active_bots: - self.active_bots[bot_name]["broker_listener"].start() - return self.active_bots[bot_name]["broker_client"].start(**kwargs) + async def start_bot(self, bot_name, **kwargs): + """ + Start a bot with optional script. + Maintains backward compatibility with kwargs. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create StartCommandMessage.Request format + data = { + "log_level": kwargs.get("log_level"), + "script": kwargs.get("script"), + "conf": kwargs.get("conf"), + "is_quickstart": kwargs.get("is_quickstart", False), + "async_backend": kwargs.get("async_backend", True), + } + + success = await self.mqtt_manager.publish_command(bot_name, "start", data) + return {"success": success} + + async def stop_bot(self, bot_name, **kwargs): + """ + Stop a bot. + Maintains backward compatibility with kwargs. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} - def stop_bot(self, bot_name, **kwargs): - if bot_name in self.active_bots: - self.active_bots[bot_name]["broker_listener"].stop() - return self.active_bots[bot_name]["broker_client"].stop(**kwargs) + # Create StopCommandMessage.Request format + data = { + "skip_order_cancellation": kwargs.get("skip_order_cancellation", False), + "async_backend": kwargs.get("async_backend", True), + } - def import_strategy_for_bot(self, bot_name, strategy, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].import_strategy(strategy, **kwargs) + success = await self.mqtt_manager.publish_command(bot_name, "stop", data) - def configure_bot(self, bot_name, params, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].config(params, **kwargs) + # Clear performance data after stop command to immediately reflect stopped status + if success: + self.mqtt_manager.clear_bot_controller_reports(bot_name) - def get_bot_history(self, bot_name, **kwargs): - if bot_name in self.active_bots: - return self.active_bots[bot_name]["broker_client"].history(**kwargs) + return {"success": success} + + async def import_strategy_for_bot(self, bot_name, strategy, **kwargs): + """ + Import a strategy configuration for a bot. + Maintains backward compatibility. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create ImportCommandMessage.Request format + data = {"strategy": strategy} + success = await self.mqtt_manager.publish_command(bot_name, "import_strategy", data) + return {"success": success} + + async def configure_bot(self, bot_name, params, **kwargs): + """ + Configure bot parameters. + Maintains backward compatibility. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create ConfigCommandMessage.Request format + data = {"params": params} + success = await self.mqtt_manager.publish_command(bot_name, "config", data) + return {"success": success} + + async def get_bot_history(self, bot_name, **kwargs): + """ + Request bot trading history and wait for the response. + Maintains backward compatibility. + """ + if bot_name not in self.active_bots: + logger.warning(f"Bot {bot_name} not found in active bots") + return {"success": False, "message": f"Bot {bot_name} not found"} + + # Create HistoryCommandMessage.Request format + data = { + "days": kwargs.get("days", 0), + "verbose": kwargs.get("verbose", False), + "precision": kwargs.get("precision"), + "async_backend": kwargs.get("async_backend", False), + } + + # Use the new RPC method to wait for response + timeout = kwargs.get("timeout", 30.0) # Default 30 second timeout + response = await self.mqtt_manager.publish_command_and_wait(bot_name, "history", data, timeout=timeout) + + if response is None: + return { + "success": False, + "message": f"No response received from {bot_name} within {timeout} seconds", + "timeout": True, + } + + return {"success": True, "data": response} @staticmethod - def determine_controller_performance(controllers_performance): - cleaned_performance = {} - for controller, performance in controllers_performance.items(): + def determine_controller_performance(controller_reports): + """Process controller reports and extract performance and custom_info. + + Args: + controller_reports: Dict with controller_id as key and report dict as value. + New format: Each report contains 'performance' and 'custom_info' keys. + Old format: Report contains performance metrics directly (backward compatible). + + Returns: + Dict with cleaned controller data including status, performance, and custom_info. + """ + cleaned_data = {} + for controller_id, report in controller_reports.items(): try: - # Check if all the metrics are numeric - _ = sum(metric for key, metric in performance.items() if key != "close_type_counts") - cleaned_performance[controller] = { + # Support both new format (nested) and old format (flat) + # New format: {"performance": {...}, "custom_info": {...}} + # Old format: {...performance metrics directly...} + if "performance" in report: + # New format with nested structure + performance = report.get("performance", {}) + custom_info = report.get("custom_info", {}) + else: + # Old format - metrics are directly in the report + performance = report + custom_info = {} + + # Validate performance metrics are numeric (skip known non-numeric fields) + non_numeric_fields = ("positions_summary", "close_type_counts") + _ = sum( + metric for key, metric in performance.items() + if key not in non_numeric_fields and isinstance(metric, (int, float)) + ) + + cleaned_data[controller_id] = { "status": "running", - "performance": performance + "performance": performance, + "custom_info": custom_info } except Exception as e: - cleaned_performance[controller] = { + # Handle both formats in error case too + if "performance" in report: + perf = report.get("performance", {}) + info = report.get("custom_info", {}) + else: + perf = report + info = {} + cleaned_data[controller_id] = { "status": "error", - "error": f"Some metrics are not numeric, check logs and restart controller: {e}", + "error": f"Error processing controller data: {e}", + "performance": perf, + "custom_info": info } - return cleaned_performance + return cleaned_data def get_all_bots_status(self): + # TODO: improve logic of bots state management + """Get status information for all active bots.""" all_bots_status = {} - for bot in self.active_bots: - all_bots_status[bot] = self.get_bot_status(bot) + for bot in [bot for bot in self.active_bots if not self.is_bot_stopping(bot)]: + status = self.get_bot_status(bot) + status["source"] = self.active_bots[bot].get("source", "unknown") + all_bots_status[bot] = status return all_bots_status def get_bot_status(self, bot_name): - if bot_name in self.active_bots: - try: - broker_listner = self.active_bots[bot_name]["broker_listener"] - controllers_performance = broker_listner.get_bot_performance() - performance = self.determine_controller_performance(controllers_performance) - error_logs = broker_listner.get_bot_error_logs() - general_logs = broker_listner.get_bot_general_logs() - status = "running" if len(performance) > 0 else "stopped" - return { - "status": status, - "performance": performance, - "error_logs": error_logs, - "general_logs": general_logs - } - except Exception as e: + """ + Get status information for a specific bot. + """ + if bot_name not in self.active_bots: + return {"status": "not_found", "error": f"Bot {bot_name} not found"} + + try: + # Check if bot is currently being stopped and archived + if bot_name in self.stopping_bots: return { - "status": "error", - "error": str(e) + "status": "stopping", + "message": "Bot is currently being stopped and archived", + "performance": {}, + "error_logs": [], + "general_logs": [], + "recently_active": False, } + + # Get data from MQTT manager + controller_reports = self.mqtt_manager.get_bot_controller_reports(bot_name) + performance = self.determine_controller_performance(controller_reports) + error_logs = self.mqtt_manager.get_bot_error_logs(bot_name) + general_logs = self.mqtt_manager.get_bot_logs(bot_name) + + # Check if bot has sent recent messages (within last 30 seconds) + discovered_bots = self.mqtt_manager.get_discovered_bots(timeout_seconds=30) + recently_active = bot_name in discovered_bots + + # Determine status based on performance data and recent activity + if len(performance) > 0 and recently_active: + status = "running" + elif len(performance) > 0 and not recently_active: + status = "idle" # Has performance data but no recent activity + else: + status = "stopped" + + return { + "status": status, + "performance": performance, + "error_logs": error_logs, + "general_logs": general_logs, + "recently_active": recently_active, + } + except Exception as e: + return {"status": "error", "error": str(e)} + + def set_bot_stopping(self, bot_name: str): + """Mark a bot as currently being stopped and archived.""" + self.stopping_bots.add(bot_name) + logger.info(f"Marked bot {bot_name} as stopping") + + def clear_bot_stopping(self, bot_name: str): + """Clear the stopping status for a bot.""" + self.stopping_bots.discard(bot_name) + logger.info(f"Cleared stopping status for bot {bot_name}") + + def is_bot_stopping(self, bot_name: str) -> bool: + """Check if a bot is currently being stopped.""" + return bot_name in self.stopping_bots + diff --git a/services/docker_service.py b/services/docker_service.py index 2232893c..a474499f 100644 --- a/services/docker_service.py +++ b/services/docker_service.py @@ -1,31 +1,65 @@ import logging import os import shutil +import threading +import time +from typing import Dict import docker from docker.errors import DockerException from docker.types import LogConfig -from models import HummingbotInstanceConfig -from utils.file_system import FileSystemUtil +from config import settings +from models import V2ControllerDeployment +from utils.file_system import fs_util -file_system = FileSystemUtil() +# Create module-specific logger +logger = logging.getLogger(__name__) -class DockerManager: +class DockerService: + # Class-level configuration for cleanup + PULL_STATUS_MAX_AGE_SECONDS = 3600 # Keep status for 1 hour + PULL_STATUS_MAX_ENTRIES = 100 # Maximum number of entries to keep + CLEANUP_INTERVAL_SECONDS = 300 # Run cleanup every 5 minutes + def __init__(self): self.SOURCE_PATH = os.getcwd() + self._pull_status: Dict[str, Dict] = {} + self._cleanup_thread = None + self._stop_cleanup = threading.Event() + try: self.client = docker.from_env() + # Start background cleanup thread + self._start_cleanup_thread() except DockerException as e: - logging.error(f"It was not possible to connect to Docker. Please make sure Docker is running. Error: {e}") + logger.error(f"It was not possible to connect to Docker. Please make sure Docker is running. Error: {e}") - def get_active_containers(self): + def get_active_containers(self, name_filter: str = None): try: - containers_info = [{"id": container.id, "name": container.name, "status": container.status} for - container in self.client.containers.list(filters={"status": "running"}) if - "hummingbot" in container.name and "broker" not in container.name] - return {"active_instances": containers_info} + all_containers = self.client.containers.list(filters={"status": "running"}) + if name_filter: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers if name_filter.lower() in container.name.lower() + ] + else: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers + ] + return containers_info except DockerException as e: return str(e) @@ -38,16 +72,42 @@ def get_available_images(self): def pull_image(self, image_name): try: - self.client.images.pull(image_name) + return self.client.images.pull(image_name) except DockerException as e: return str(e) - def get_exited_containers(self): + def pull_image_sync(self, image_name): + """Synchronous pull operation for background tasks""" try: - containers_info = [{"id": container.id, "name": container.name, "status": container.status} for - container in self.client.containers.list(filters={"status": "exited"}) if - "hummingbot" in container.name and "broker" not in container.name] - return {"exited_instances": containers_info} + result = self.client.images.pull(image_name) + return {"success": True, "image": image_name, "result": str(result)} + except DockerException as e: + return {"success": False, "error": str(e)} + + def get_exited_containers(self, name_filter: str = None): + try: + all_containers = self.client.containers.list(filters={"status": "exited"}, all=True) + if name_filter: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers if name_filter.lower() in container.name.lower() + ] + else: + containers_info = [ + { + "id": container.id, + "name": container.name, + "status": container.status, + "image": container.image.tags[0] if container.image.tags else container.image.id[:12] + } + for container in all_containers + ] + return containers_info except DockerException as e: return str(e) @@ -78,6 +138,21 @@ def start_container(self, container_name): except DockerException as e: return str(e) + def get_container_status(self, container_name): + """Get the status of a container""" + try: + container = self.client.containers.get(container_name) + return { + "success": True, + "state": { + "status": container.status, + "running": container.status == "running", + "exit_code": getattr(container.attrs.get("State", {}), "ExitCode", None) + } + } + except DockerException as e: + return {"success": False, "message": str(e)} + def remove_container(self, container_name, force=True): try: container = self.client.containers.get(container_name) @@ -86,9 +161,9 @@ def remove_container(self, container_name, force=True): except DockerException as e: return {"success": False, "message": str(e)} - def create_hummingbot_instance(self, config: HummingbotInstanceConfig): + def create_hummingbot_instance(self, config: V2ControllerDeployment): bots_path = os.environ.get('BOTS_PATH', self.SOURCE_PATH) # Default to 'SOURCE_PATH' if BOTS_PATH is not set - instance_name = f"hummingbot-{config.instance_name}" + instance_name = config.instance_name instance_dir = os.path.join("bots", 'instances', instance_name) if not os.path.exists(instance_dir): os.makedirs(instance_dir) @@ -97,51 +172,102 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): # Copy credentials to instance directory source_credentials_dir = os.path.join("bots", 'credentials', config.credentials_profile) - script_config_dir = os.path.join("bots", 'conf', 'scripts') - controllers_config_dir = os.path.join("bots", 'conf', 'controllers') destination_credentials_dir = os.path.join(instance_dir, 'conf') - destination_scripts_config_dir = os.path.join(instance_dir, 'conf', 'scripts') - destination_controllers_config_dir = os.path.join(instance_dir, 'conf', 'controllers') # Remove the destination directory if it already exists if os.path.exists(destination_credentials_dir): shutil.rmtree(destination_credentials_dir) - # Copy the entire contents of source_credentials_dir to destination_credentials_dir + # Copy the entire contents of source_credentials_dir to destination_credentials_dir shutil.copytree(source_credentials_dir, destination_credentials_dir) - shutil.copytree(script_config_dir, destination_scripts_config_dir) - shutil.copytree(controllers_config_dir, destination_controllers_config_dir) - conf_file_path = f"{instance_dir}/conf/conf_client.yml" - client_config = FileSystemUtil.read_yaml_file(conf_file_path) + + # Copy specific script config and referenced controllers if provided + if config.script_config: + script_config_dir = os.path.join("bots", 'conf', 'scripts') + controllers_config_dir = os.path.join("bots", 'conf', 'controllers') + destination_scripts_config_dir = os.path.join(instance_dir, 'conf', 'scripts') + destination_controllers_config_dir = os.path.join(instance_dir, 'conf', 'controllers') + + os.makedirs(destination_scripts_config_dir, exist_ok=True) + + # Copy the specific script config file + source_script_config_file = os.path.join(script_config_dir, config.script_config) + destination_script_config_file = os.path.join(destination_scripts_config_dir, config.script_config) + + if os.path.exists(source_script_config_file): + shutil.copy2(source_script_config_file, destination_script_config_file) + + # Load the script config to find referenced controllers + try: + # Path relative to fs_util base_path (which is "bots") + script_config_relative_path = f"conf/scripts/{config.script_config}" + script_config_content = fs_util.read_yaml_file(script_config_relative_path) + controllers_list = script_config_content.get('controllers_config', []) + + # If there are controllers referenced, copy them + if controllers_list: + os.makedirs(destination_controllers_config_dir, exist_ok=True) + + for controller_file in controllers_list: + source_controller_file = os.path.join(controllers_config_dir, controller_file) + destination_controller_file = os.path.join( + destination_controllers_config_dir, controller_file + ) + + if os.path.exists(source_controller_file): + shutil.copy2(source_controller_file, destination_controller_file) + logger.info(f"Copied controller config: {controller_file}") + else: + logger.warning( + f"Controller config file {controller_file} not found in {controllers_config_dir}" + ) + + except Exception as e: + logger.error(f"Error reading script config file {config.script_config}: {e}") + else: + logger.warning(f"Script config file {config.script_config} not found in {script_config_dir}") + # Path relative to fs_util base_path (which is "bots") + conf_file_path = f"instances/{instance_name}/conf/conf_client.yml" + client_config = fs_util.read_yaml_file(conf_file_path) client_config['instance_id'] = instance_name - FileSystemUtil.dump_dict_to_yaml(conf_file_path, client_config) + fs_util.dump_dict_to_yaml(conf_file_path, client_config) # Set up Docker volumes + instance_conf = os.path.abspath(os.path.join(bots_path, instance_dir, 'conf')) + instance_connectors = os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'connectors')) + instance_scripts = os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'scripts')) + instance_controllers = os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'controllers')) + instance_data = os.path.abspath(os.path.join(bots_path, instance_dir, 'data')) + instance_logs = os.path.abspath(os.path.join(bots_path, instance_dir, 'logs')) + shared_scripts = os.path.abspath(os.path.join(bots_path, "bots", 'scripts')) + shared_controllers = os.path.abspath(os.path.join(bots_path, "bots", 'controllers')) + volumes = { - os.path.abspath(os.path.join(bots_path, instance_dir, 'conf')): {'bind': '/home/hummingbot/conf', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'connectors')): {'bind': '/home/hummingbot/conf/connectors', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'scripts')): {'bind': '/home/hummingbot/conf/scripts', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, instance_dir, 'conf', 'controllers')): {'bind': '/home/hummingbot/conf/controllers', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, instance_dir, 'data')): {'bind': '/home/hummingbot/data', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, instance_dir, 'logs')): {'bind': '/home/hummingbot/logs', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, "bots", 'scripts')): {'bind': '/home/hummingbot/scripts', 'mode': 'rw'}, - os.path.abspath(os.path.join(bots_path, "bots", 'controllers')): {'bind': '/home/hummingbot/controllers', 'mode': 'rw'}, + instance_conf: {'bind': '/home/hummingbot/conf', 'mode': 'rw'}, + instance_connectors: {'bind': '/home/hummingbot/conf/connectors', 'mode': 'rw'}, + instance_scripts: {'bind': '/home/hummingbot/conf/scripts', 'mode': 'rw'}, + instance_controllers: {'bind': '/home/hummingbot/conf/controllers', 'mode': 'rw'}, + instance_data: {'bind': '/home/hummingbot/data', 'mode': 'rw'}, + instance_logs: {'bind': '/home/hummingbot/logs', 'mode': 'rw'}, + shared_scripts: {'bind': '/home/hummingbot/scripts', 'mode': 'rw'}, + shared_controllers: {'bind': '/home/hummingbot/controllers', 'mode': 'rw'}, } # Set up environment variables environment = {} - password = os.environ.get('CONFIG_PASSWORD', "a") + password = settings.security.config_password if password: environment["CONFIG_PASSWORD"] = password - if config.script: + if config.script_config: if password: - environment['CONFIG_FILE_NAME'] = config.script - if config.script_config: - environment['SCRIPT_CONFIG'] = config.script_config + environment['SCRIPT_CONFIG'] = config.script_config else: return {"success": False, "message": "Password not provided. We cannot start the bot without a password."} + if config.headless: + environment["HEADLESS_MODE"] = "true" + log_config = LogConfig( type="json-file", config={ @@ -163,3 +289,144 @@ def create_hummingbot_instance(self, config: HummingbotInstanceConfig): return {"success": True, "message": f"Instance {instance_name} created successfully."} except docker.errors.DockerException as e: return {"success": False, "message": str(e)} + + def _start_cleanup_thread(self): + """Start the background cleanup thread""" + if self._cleanup_thread is None or not self._cleanup_thread.is_alive(): + self._cleanup_thread = threading.Thread(target=self._periodic_cleanup, daemon=True) + self._cleanup_thread.start() + logger.info("Started Docker pull status cleanup thread") + + def _periodic_cleanup(self): + """Periodically clean up old pull status entries""" + while not self._stop_cleanup.is_set(): + try: + self._cleanup_old_pull_status() + except Exception as e: + logger.error(f"Error in cleanup thread: {e}") + + # Wait for the next cleanup interval + self._stop_cleanup.wait(self.CLEANUP_INTERVAL_SECONDS) + + def _cleanup_old_pull_status(self): + """Remove old entries to prevent memory growth""" + current_time = time.time() + to_remove = [] + + # Find entries older than max age + for image_name, status_info in self._pull_status.items(): + # Skip ongoing pulls + if status_info["status"] == "pulling": + continue + + # Check age of completed/failed operations + end_time = status_info.get("completed_at") or status_info.get("failed_at") + if end_time and (current_time - end_time > self.PULL_STATUS_MAX_AGE_SECONDS): + to_remove.append(image_name) + + # Remove old entries + for image_name in to_remove: + del self._pull_status[image_name] + logger.info(f"Cleaned up old pull status for {image_name}") + + # If still over limit, remove oldest completed/failed entries + if len(self._pull_status) > self.PULL_STATUS_MAX_ENTRIES: + completed_entries = [ + (name, info) for name, info in self._pull_status.items() + if info["status"] in ["completed", "failed"] + ] + # Sort by end time (oldest first) + completed_entries.sort( + key=lambda x: x[1].get("completed_at") or x[1].get("failed_at") or 0 + ) + + # Remove oldest entries to get under limit + excess_count = len(self._pull_status) - self.PULL_STATUS_MAX_ENTRIES + for i in range(min(excess_count, len(completed_entries))): + del self._pull_status[completed_entries[i][0]] + logger.info(f"Cleaned up excess pull status for {completed_entries[i][0]}") + + def pull_image_async(self, image_name: str): + """Start pulling a Docker image asynchronously with status tracking""" + # Check if pull is already in progress + if image_name in self._pull_status: + current_status = self._pull_status[image_name] + if current_status["status"] == "pulling": + return { + "message": f"Pull already in progress for {image_name}", + "status": "in_progress", + "started_at": current_status["started_at"], + "image_name": image_name + } + + # Start the pull in a background thread + threading.Thread(target=self._pull_image_with_tracking, args=(image_name,), daemon=True).start() + + return { + "message": f"Pull started for {image_name}", + "status": "started", + "image_name": image_name + } + + def _pull_image_with_tracking(self, image_name: str): + """Background task to pull Docker image with status tracking""" + try: + self._pull_status[image_name] = { + "status": "pulling", + "started_at": time.time(), + "progress": "Starting pull..." + } + + # Use the synchronous pull method + result = self.pull_image_sync(image_name) + + if result.get("success"): + self._pull_status[image_name] = { + "status": "completed", + "started_at": self._pull_status[image_name]["started_at"], + "completed_at": time.time(), + "result": result + } + else: + self._pull_status[image_name] = { + "status": "failed", + "started_at": self._pull_status[image_name]["started_at"], + "failed_at": time.time(), + "error": result.get("error", "Unknown error") + } + except Exception as e: + self._pull_status[image_name] = { + "status": "failed", + "started_at": self._pull_status[image_name].get("started_at", time.time()), + "failed_at": time.time(), + "error": str(e) + } + + def get_all_pull_status(self): + """Get status of all pull operations""" + operations = {} + for image_name, status_info in self._pull_status.items(): + status_copy = status_info.copy() + + # Add duration for each operation + start_time = status_copy.get("started_at") + if start_time: + if status_copy["status"] == "pulling": + status_copy["duration_seconds"] = round(time.time() - start_time, 2) + elif "completed_at" in status_copy: + status_copy["duration_seconds"] = round(status_copy["completed_at"] - start_time, 2) + elif "failed_at" in status_copy: + status_copy["duration_seconds"] = round(status_copy["failed_at"] - start_time, 2) + + operations[image_name] = status_copy + + return { + "pull_operations": operations, + "total_operations": len(operations) + } + + def cleanup(self): + """Clean up resources when shutting down""" + self._stop_cleanup.set() + if self._cleanup_thread: + self._cleanup_thread.join(timeout=1) diff --git a/services/executor_service.py b/services/executor_service.py new file mode 100644 index 00000000..62bd7ddb --- /dev/null +++ b/services/executor_service.py @@ -0,0 +1,1165 @@ +""" +ExecutorService manages executor lifecycle and orchestration. +This service enables running Hummingbot executors directly via API +without Docker containers or full strategy setup. +""" +import asyncio +import json +import logging +from datetime import datetime, timezone +from decimal import Decimal +from enum import Enum +from typing import Any, Dict, List, Optional, Type + +from fastapi import HTTPException +from hummingbot.strategy_v2.executors.arbitrage_executor.arbitrage_executor import ArbitrageExecutor +from hummingbot.strategy_v2.executors.arbitrage_executor.data_types import ArbitrageExecutorConfig +from hummingbot.strategy_v2.executors.data_types import ExecutorConfigBase +from hummingbot.strategy_v2.executors.dca_executor.data_types import DCAExecutorConfig +from hummingbot.strategy_v2.executors.dca_executor.dca_executor import DCAExecutor +from hummingbot.strategy_v2.executors.executor_base import ExecutorBase +from hummingbot.strategy_v2.executors.grid_executor.data_types import GridExecutorConfig +from hummingbot.strategy_v2.executors.grid_executor.grid_executor import GridExecutor +from hummingbot.strategy_v2.executors.lp_executor.data_types import LPExecutorConfig +from hummingbot.strategy_v2.executors.lp_executor.lp_executor import LPExecutor +from hummingbot.strategy_v2.executors.order_executor.data_types import OrderExecutorConfig +from hummingbot.strategy_v2.executors.order_executor.order_executor import OrderExecutor +from hummingbot.strategy_v2.executors.position_executor.data_types import PositionExecutorConfig +from hummingbot.strategy_v2.executors.position_executor.position_executor import PositionExecutor +from hummingbot.strategy_v2.executors.twap_executor.data_types import TWAPExecutorConfig +from hummingbot.strategy_v2.executors.twap_executor.twap_executor import TWAPExecutor +from hummingbot.strategy_v2.executors.xemm_executor.data_types import XEMMExecutorConfig +from hummingbot.strategy_v2.executors.xemm_executor.xemm_executor import XEMMExecutor +from hummingbot.strategy_v2.models.executors import CloseType, TrackedOrder + +from database import AsyncDatabaseManager +from models.executors import PositionHold +from services.trading_service import AccountTradingInterface, TradingService +from utils.executor_log_capture import ExecutorLogCapture, current_executor_id + +logger = logging.getLogger(__name__) + + +def _json_default(obj): + """JSON serializer for objects not serializable by default.""" + if isinstance(obj, Decimal): + return float(obj) + if isinstance(obj, Enum): + return obj.name + if isinstance(obj, TrackedOrder): + return { + "order_id": obj.order_id, + "price": float(obj.price) if obj.price else None, + "executed_amount_base": float(obj.executed_amount_base) if obj.executed_amount_base else 0.0, + "executed_amount_quote": float(obj.executed_amount_quote) if obj.executed_amount_quote else 0.0, + "is_filled": obj.is_filled if hasattr(obj, 'is_filled') else False, + "is_open": obj.is_open if hasattr(obj, 'is_open') else False, + } + # Handle Pydantic models + if hasattr(obj, 'model_dump'): + return obj.model_dump(mode='json') + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +class ExecutorService: + """ + Service for managing trading executors without Docker containers. + + This service provides: + - Dynamic executor creation for any market/connector + - Executor lifecycle management (start, stop, cleanup) + - Real-time executor status monitoring + - Database persistence of executor state and history + """ + + # Mapping of executor type strings to (executor_class, config_class) + EXECUTOR_REGISTRY: Dict[str, tuple[Type[ExecutorBase], Type[ExecutorConfigBase]]] = { + "position_executor": (PositionExecutor, PositionExecutorConfig), + "grid_executor": (GridExecutor, GridExecutorConfig), + "dca_executor": (DCAExecutor, DCAExecutorConfig), + "arbitrage_executor": (ArbitrageExecutor, ArbitrageExecutorConfig), + "twap_executor": (TWAPExecutor, TWAPExecutorConfig), + "xemm_executor": (XEMMExecutor, XEMMExecutorConfig), + "order_executor": (OrderExecutor, OrderExecutorConfig), + "lp_executor": (LPExecutor, LPExecutorConfig), + } + + def __init__( + self, + trading_service: TradingService, + db_manager: AsyncDatabaseManager, + default_account: str = "master_account", + update_interval: float = 1.0, + max_retries: int = 10 + ): + """ + Initialize ExecutorService. + + Args: + trading_service: TradingService for trading operations and interfaces + db_manager: AsyncDatabaseManager for persistence + default_account: Default account to use + update_interval: Executor update interval in seconds + max_retries: Maximum retries for executor operations + """ + self._trading_service = trading_service + self.db_manager = db_manager + self.default_account = default_account + self.update_interval = update_interval + self.max_retries = max_retries + + # Trading interfaces per account (lazy initialized via TradingService) + self._trading_interfaces: Dict[str, AccountTradingInterface] = {} + + # Active executors: executor_id -> executor instance + self._active_executors: Dict[str, ExecutorBase] = {} + + # Executor metadata: executor_id -> metadata dict + self._executor_metadata: Dict[str, Dict[str, Any]] = {} + + # Position holds: key = "account_name|connector_name|trading_pair" + # Tracks aggregated positions from executors stopped with keep_position=True + self._positions_held: Dict[str, PositionHold] = {} + + # Executor log capture + self._log_capture = ExecutorLogCapture() + self._log_capture.install() + + # Control loop task + self._control_loop_task: Optional[asyncio.Task] = None + self._is_running = False + + def start(self): + """Start the executor service control loop.""" + if not self._is_running: + self._is_running = True + self._control_loop_task = asyncio.create_task(self._control_loop()) + logger.info("ExecutorService started") + + async def recover_positions_from_db(self): + """ + Recover position holds from database on startup. + + This loads executors that closed with POSITION_HOLD (keep_position=True) + and reconstructs the _positions_held tracking from their final state. + """ + if not self.db_manager: + return + + try: + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + position_hold_executors = await repo.get_position_hold_executors() + + for executor_record in position_hold_executors: + # Build position key + controller_id = getattr(executor_record, "controller_id", "main") or "main" + position_key = self._get_position_key( + executor_record.account_name, + executor_record.connector_name, + executor_record.trading_pair, + controller_id + ) + + # Initialize position if needed + if position_key not in self._positions_held: + self._positions_held[position_key] = PositionHold( + trading_pair=executor_record.trading_pair, + connector_name=executor_record.connector_name, + account_name=executor_record.account_name, + controller_id=controller_id, + ) + + position = self._positions_held[position_key] + + # Try to extract fill data from final_state + if executor_record.final_state: + try: + final_state = json.loads(executor_record.final_state) + + # Process held_position_orders (most accurate source) + held_orders = final_state.get("held_position_orders", []) + if held_orders: + buy_filled_base = Decimal("0") + buy_filled_quote = Decimal("0") + sell_filled_base = Decimal("0") + sell_filled_quote = Decimal("0") + + for order in held_orders: + if isinstance(order, dict): + trade_type = order.get("trade_type", "BUY") + exec_base = Decimal(str(order.get("executed_amount_base", 0))) + exec_quote = Decimal(str(order.get("executed_amount_quote", 0))) + + if trade_type == "BUY": + buy_filled_base += exec_base + buy_filled_quote += exec_quote + else: + sell_filled_base += exec_base + sell_filled_quote += exec_quote + + # Add fills using proper method + if buy_filled_base > 0: + position.add_fill("BUY", buy_filled_base, buy_filled_quote, executor_record.executor_id) + if sell_filled_base > 0: + position.add_fill("SELL", sell_filled_base, sell_filled_quote, executor_record.executor_id) + + logger.debug( + f"Recovered position from {executor_record.executor_id}: " + f"buy={buy_filled_base} base, sell={sell_filled_base} base" + ) + + except (json.JSONDecodeError, TypeError) as e: + logger.debug(f"Could not parse final_state for {executor_record.executor_id}: {e}") + + if self._positions_held: + logger.info(f"Recovered {len(self._positions_held)} position holds from database") + + except Exception as e: + logger.error(f"Error recovering positions from database: {e}", exc_info=True) + + async def cleanup_orphaned_executors(self): + """ + Clean up orphaned executors from database on startup. + + Identifies executors marked as RUNNING in the database but not present + in memory (i.e., from previous API sessions that were terminated). + """ + if not self.db_manager: + logger.debug("No database manager available, skipping orphaned executor cleanup") + return + + try: + # Get list of currently active executor IDs in memory + active_executor_ids = list(self._active_executors.keys()) + + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + # Clean up orphaned executors + cleaned_count = await repo.cleanup_orphaned_executors( + active_executor_ids=active_executor_ids, + close_type="SYSTEM_CLEANUP" + ) + + if cleaned_count > 0: + logger.info(f"Cleaned up {cleaned_count} orphaned executors from database") + else: + logger.debug("No orphaned executors found in database") + + except Exception as e: + logger.error(f"Error cleaning up orphaned executors: {e}", exc_info=True) + + async def stop(self): + """Stop the executor service and all active executors.""" + self._is_running = False + + if self._control_loop_task: + self._control_loop_task.cancel() + try: + await self._control_loop_task + except asyncio.CancelledError: + pass + self._control_loop_task = None + + # Stop all active executors + for executor_id in list(self._active_executors.keys()): + try: + executor = self._active_executors.get(executor_id) + if executor: + executor.stop() + except Exception as e: + logger.error(f"Error stopping executor {executor_id}: {e}") + + # Clear active executors + self._active_executors.clear() + self._executor_metadata.clear() + + # Cleanup trading interfaces + for trading_interface in self._trading_interfaces.values(): + await trading_interface.cleanup() + self._trading_interfaces.clear() + + logger.info("ExecutorService stopped") + + async def _control_loop(self): + """Main control loop that updates all active executors.""" + while self._is_running: + try: + # Update timestamps for all trading interfaces via TradingService + self._trading_service.update_all_timestamps() + + # Check for completed executors + completed_ids = [] + for executor_id, executor in self._active_executors.items(): + if executor.is_closed: + completed_ids.append(executor_id) + + # Handle completed executors + for executor_id in completed_ids: + await self._handle_executor_completion(executor_id) + + except Exception as e: + logger.error(f"Error in executor control loop: {e}", exc_info=True) + + await asyncio.sleep(self.update_interval) + + def _get_trading_interface(self, account_name: str) -> AccountTradingInterface: + """Get or create an AccountTradingInterface for the account.""" + if account_name not in self._trading_interfaces: + self._trading_interfaces[account_name] = self._trading_service.get_trading_interface(account_name) + return self._trading_interfaces[account_name] + + async def create_executor( + self, + executor_config: Dict[str, Any], + account_name: Optional[str] = None, + controller_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + Create and start a new executor. + + Args: + executor_config: Executor configuration dictionary (must include 'type') + account_name: Account to use (defaults to master_account) + + Returns: + Dictionary with executor_id and initial status + """ + account = account_name or self.default_account + + # Get executor type from config + executor_type = executor_config.get("type") + if not executor_type: + raise HTTPException( + status_code=400, + detail="executor_config must include 'type' field" + ) + + # Validate executor type + if executor_type not in self.EXECUTOR_REGISTRY: + raise HTTPException( + status_code=400, + detail=f"Invalid executor type '{executor_type}'. Valid types: {list(self.EXECUTOR_REGISTRY.keys())}" + ) + + # Get trading interface for this account + trading_interface = self._get_trading_interface(account) + + # Extract connector and trading pair from config + connector_name = executor_config.get("connector_name") + trading_pair = executor_config.get("trading_pair") + if not connector_name: + raise HTTPException(status_code=400, detail="connector_name is required in executor_config") + if not trading_pair: + raise HTTPException(status_code=400, detail="trading_pair is required in executor_config") + + # Ensure connector and market are ready + await trading_interface.add_market(connector_name, trading_pair) + + # Set timestamp if not provided (required for time-based features like time_limit) + if "timestamp" not in executor_config or executor_config["timestamp"] is None: + executor_config["timestamp"] = trading_interface.current_timestamp + + # Create typed executor config + executor_class, config_class = self.EXECUTOR_REGISTRY[executor_type] + try: + typed_config = config_class(**executor_config) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Invalid executor config: {str(e)}" + ) + + # Create the executor instance + try: + executor = executor_class( + strategy=trading_interface, + config=typed_config, + update_interval=self.update_interval, + max_retries=self.max_retries + ) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to create executor: {str(e)}" + ) + + # Store executor and metadata + executor_id = typed_config.id + controller_id = controller_id or getattr(typed_config, "controller_id", "main") or "main" + self._active_executors[executor_id] = executor + self._executor_metadata[executor_id] = { + "account_name": account, + "connector_name": connector_name, + "trading_pair": trading_pair, + "executor_type": executor_type, + "controller_id": controller_id, + "created_at": datetime.now(timezone.utc), + "config": executor_config + } + + # Set ContextVar so the asyncio Task created by start() inherits it + token = current_executor_id.set(executor_id) + executor.start() + current_executor_id.reset(token) + + # Persist to database + await self._persist_executor_created(executor_id, executor) + + # Capture created_at before potential cleanup + created_at = self._executor_metadata[executor_id]["created_at"].isoformat() + + # Check if executor terminated immediately (e.g., insufficient balance) + # If so, handle completion now rather than waiting for control loop + if executor.is_closed: + await self._handle_executor_completion(executor_id) + + logger.info(f"Created {executor_type} executor {executor_id} for {connector_name}/{trading_pair}") + + return { + "executor_id": executor_id, + "executor_type": executor_type, + "connector_name": connector_name, + "trading_pair": trading_pair, + "controller_id": controller_id, + "status": executor.status.name, + "created_at": created_at + } + + async def get_executors( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + executor_type: Optional[str] = None, + status: Optional[str] = None, + controller_id: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Get list of executors with optional filtering. + + Combines active executors from memory with completed executors from database. + + Args: + account_name: Filter by account name + connector_name: Filter by connector name + trading_pair: Filter by trading pair + executor_type: Filter by executor type + status: Filter by status + controller_id: Filter by controller ID + + Returns: + List of executor information dictionaries + """ + result = [] + + # Process active executors from memory + for executor_id, executor in self._active_executors.items(): + metadata = self._executor_metadata.get(executor_id, {}) + + # Apply filters + if account_name and metadata.get("account_name") != account_name: + continue + if connector_name and metadata.get("connector_name") != connector_name: + continue + if trading_pair and metadata.get("trading_pair") != trading_pair: + continue + if executor_type and metadata.get("executor_type") != executor_type: + continue + if status and executor.status.name != status: + continue + if controller_id and metadata.get("controller_id", "main") != controller_id: + continue + + result.append(self._format_executor_info(executor_id, executor)) + + # Get completed executors from database + if self.db_manager: + try: + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + db_executors = await repo.get_executors( + account_name=account_name, + connector_name=connector_name, + trading_pair=trading_pair, + executor_type=executor_type, + status=status, + controller_id=controller_id + ) + + for record in db_executors: + # Skip if already in active executors (safety check) + if record.executor_id not in self._active_executors: + result.append(self._format_db_record(record)) + except Exception as e: + logger.error(f"Error fetching executors from database: {e}") + + return result + + async def get_executor(self, executor_id: str) -> Optional[Dict[str, Any]]: + """ + Get detailed information about a specific executor. + + Checks active executors in memory first, then falls back to database. + + Args: + executor_id: The executor ID + + Returns: + Detailed executor information or None if not found + """ + # Check active executors first (memory) + executor = self._active_executors.get(executor_id) + if executor: + return self._format_executor_info(executor_id, executor) + + # Fallback to database for completed executors + if self.db_manager: + try: + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + record = await repo.get_executor_by_id(executor_id) + if record: + return self._format_db_record(record) + except Exception as e: + logger.error(f"Error fetching executor from database: {e}") + + return None + + def get_executor_logs( + self, + executor_id: str, + level: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[dict]: + """ + Get captured log entries for an executor. + + Only available for active executors (logs are cleared on completion). + + Args: + executor_id: The executor ID + level: Optional filter by level (ERROR, WARNING, INFO, DEBUG) + limit: Maximum number of entries to return + + Returns: + List of log entry dicts + """ + return self._log_capture.get_logs(executor_id, level=level, limit=limit) + + async def stop_executor( + self, + executor_id: str, + keep_position: bool = False + ) -> Dict[str, Any]: + """ + Stop an active executor. + + Args: + executor_id: The executor ID to stop + keep_position: Whether to keep the position open + + Returns: + Dictionary with stop confirmation + """ + executor = self._active_executors.get(executor_id) + if not executor: + raise HTTPException(status_code=404, detail=f"Executor {executor_id} not found") + + if executor.is_closed: + raise HTTPException(status_code=400, detail=f"Executor {executor_id} is already closed") + + # Trigger early stop + try: + executor.early_stop(keep_position=keep_position) + except Exception as e: + logger.error(f"Error stopping executor {executor_id}: {e}") + raise HTTPException(status_code=500, detail=f"Error stopping executor: {str(e)}") + + logger.info(f"Initiated stop for executor {executor_id} (keep_position={keep_position})") + + return { + "executor_id": executor_id, + "status": "stopping", + "keep_position": keep_position + } + + async def _handle_executor_completion(self, executor_id: str): + """Handle cleanup when an executor completes.""" + executor = self._active_executors.get(executor_id) + if not executor: + return + + metadata = self._executor_metadata.get(executor_id, {}) + + # Check if this is a POSITION_HOLD close type (keep_position=True) + if executor.close_type == CloseType.POSITION_HOLD: + await self._aggregate_position_hold(executor_id, executor, metadata) + + # Persist final state to database + await self._persist_executor_completed(executor_id, executor) + + # Remove from active executors + del self._active_executors[executor_id] + if executor_id in self._executor_metadata: + del self._executor_metadata[executor_id] + + # Clean up captured logs + self._log_capture.clear(executor_id) + + close_type = executor.close_type.name if executor.close_type else "UNKNOWN" + logger.info(f"Executor {executor_id} completed with close_type: {close_type}") + + def _format_executor_info( + self, + executor_id: str, + executor: ExecutorBase + ) -> Dict[str, Any]: + """Format executor information for API response.""" + metadata = self._executor_metadata.get(executor_id, {}) + executor_type = metadata.get("executor_type") + + # Get executor_info and serialize + executor_info = executor.executor_info + result = json.loads(json.dumps(executor_info.model_dump(), default=_json_default)) + + # Add metadata + result["executor_id"] = executor_id + result["executor_type"] = executor_type + result["account_name"] = metadata.get("account_name") + result["created_at"] = metadata.get("created_at").isoformat() if metadata.get("created_at") else None + + if metadata.get("connector_name"): + result["connector_name"] = metadata.get("connector_name") + if metadata.get("trading_pair"): + result["trading_pair"] = metadata.get("trading_pair") + result["controller_id"] = metadata.get("controller_id", "main") + + # Read status/close_type directly from executor + result["status"] = executor.status.name + result["close_type"] = executor.close_type.name if executor.close_type else None + result["is_active"] = not executor.is_closed + + # For grid executors, filter out heavy fields from custom_info + if executor_type == "grid_executor" and result.get("custom_info"): + heavy_fields = {"levels_by_state", "filled_orders", "failed_orders", "canceled_orders"} + result["custom_info"] = {k: v for k, v in result["custom_info"].items() if k not in heavy_fields} + + # Add log capture info + result["error_count"] = self._log_capture.get_error_count(executor_id) + result["last_error"] = self._log_capture.get_last_error(executor_id) + + return result + + def _format_db_record(self, record) -> Dict[str, Any]: + """Format a database ExecutorRecord for API response.""" + return { + "executor_id": record.executor_id, + "executor_type": record.executor_type, + "account_name": record.account_name, + "connector_name": record.connector_name, + "trading_pair": record.trading_pair, + "side": None, + "status": record.status, + "close_type": record.close_type, + "is_active": record.status == "RUNNING", + "is_trading": False, + "timestamp": None, + "created_at": record.created_at.isoformat() if record.created_at else None, + "close_timestamp": record.closed_at.timestamp() if record.closed_at else None, + "closed_at": record.closed_at.isoformat() if record.closed_at else None, + "controller_id": record.controller_id or "main", + "net_pnl_quote": float(record.net_pnl_quote) if record.net_pnl_quote else 0.0, + "net_pnl_pct": float(record.net_pnl_pct) if record.net_pnl_pct else 0.0, + "cum_fees_quote": float(record.cum_fees_quote) if record.cum_fees_quote else 0.0, + "filled_amount_quote": float(record.filled_amount_quote) if record.filled_amount_quote else 0.0, + "config": json.loads(record.config) if record.config else None, + "custom_info": json.loads(record.final_state) if record.final_state else None, + } + + def get_summary(self) -> Dict[str, Any]: + """ + Get summary statistics for active executors. + + Returns: + Dictionary with aggregate statistics for active executors only. + """ + executors = [] + + # Get active executors from memory + for executor_id, executor in self._active_executors.items(): + executors.append(self._format_executor_info(executor_id, executor)) + + active_count = len(executors) + total_pnl = sum(e.get("net_pnl_quote", 0) for e in executors) + total_volume = sum(e.get("filled_amount_quote", 0) for e in executors) + + by_type: Dict[str, int] = {} + by_connector: Dict[str, int] = {} + by_status: Dict[str, int] = {} + + for e in executors: + ex_type = e.get("executor_type", "unknown") + connector = e.get("connector_name", "unknown") + status = e.get("status", "unknown") + + by_type[ex_type] = by_type.get(ex_type, 0) + 1 + by_connector[connector] = by_connector.get(connector, 0) + 1 + by_status[status] = by_status.get(status, 0) + 1 + + return { + "total_active": active_count, + "total_pnl_quote": total_pnl, + "total_volume_quote": total_volume, + "by_type": by_type, + "by_connector": by_connector, + "by_status": by_status + } + + async def get_performance_report( + self, + controller_id: Optional[str] = None, + market_data_service=None + ) -> Dict[str, Any]: + """ + Generate a performance report aggregating executor metrics. + + Combines database aggregations (completed executors) with in-memory + active executor and position hold unrealized PnL. + Excludes POSITION_HOLD close_type from realized PnL to avoid double-counting. + + Args: + controller_id: Filter by controller ID (None = all) + market_data_service: MarketDataService for position hold unrealized PnL + + Returns: + Dictionary with performance metrics ready for PerformanceReportResponse. + """ + import math + + report: Dict[str, Any] = { + "controller_id": controller_id, + "total_executors": 0, + "by_status": {}, + "pnl_total_quote": 0.0, + "unrealized_pnl_quote": 0.0, + "global_pnl_quote": 0.0, + "pnl_pct_avg": 0.0, + "fees_total_quote": 0.0, + "volume_total_quote": 0.0, + "win_rate": 0.0, + "sharpe_ratio": None, + "by_type": [], + "active_positions": 0, + } + + if self.db_manager: + try: + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + db_data = await repo.get_performance_report(controller_id=controller_id) + + report["total_executors"] = db_data["total_executors"] + report["by_status"] = db_data["status_counts"] + report["pnl_total_quote"] = db_data["pnl_total_quote"] + report["pnl_pct_avg"] = db_data["pnl_pct_avg"] + report["fees_total_quote"] = db_data["fees_total_quote"] + report["volume_total_quote"] = db_data["volume_total_quote"] + report["win_rate"] = db_data["win_rate"] + report["by_type"] = db_data["by_type"] + + # Sharpe ratio: mean(pnl) / std(pnl), requires >= 2 values + pnl_values = db_data.get("pnl_values", []) + if len(pnl_values) >= 2: + mean_pnl = sum(pnl_values) / len(pnl_values) + variance = sum((v - mean_pnl) ** 2 for v in pnl_values) / (len(pnl_values) - 1) + std_pnl = math.sqrt(variance) + if std_pnl > 0: + report["sharpe_ratio"] = round(mean_pnl / std_pnl, 4) + + except Exception as e: + logger.error(f"Error generating performance report: {e}", exc_info=True) + + # --- Unrealized PnL from active executors --- + unrealized_pnl = 0.0 + for executor_id, executor in self._active_executors.items(): + metadata = self._executor_metadata.get(executor_id, {}) + if controller_id and metadata.get("controller_id", "main") != controller_id: + continue + try: + unrealized_pnl += float(executor.executor_info.net_pnl_quote) + except Exception: + pass + + # --- Unrealized PnL from position holds --- + positions = self.get_positions_held(controller_id=controller_id) + report["active_positions"] = len(positions) + + if market_data_service: + for p in positions: + parts = p.trading_pair.split("-") + if len(parts) == 2: + base, quote = parts + rate = market_data_service.get_rate(base, quote) + if rate is not None: + unrealized_pnl += float(p.get_unrealized_pnl(rate)) + + report["unrealized_pnl_quote"] = round(unrealized_pnl, 8) + report["global_pnl_quote"] = round(report["pnl_total_quote"] + unrealized_pnl, 8) + + return report + + async def _persist_executor_created(self, executor_id: str, executor: ExecutorBase): + """Persist executor creation to database.""" + if not self.db_manager: + return + + try: + metadata = self._executor_metadata.get(executor_id, {}) + + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + await repo.create_executor( + executor_id=executor_id, + executor_type=metadata.get("executor_type"), + account_name=metadata.get("account_name"), + connector_name=metadata.get("connector_name"), + trading_pair=metadata.get("trading_pair"), + config=json.dumps(metadata.get("config", {}), default=_json_default), + status=executor.status.name, + controller_id=metadata.get("controller_id", "main") + ) + + logger.debug(f"Persisted executor {executor_id} creation to database") + + except Exception as e: + logger.error(f"Error persisting executor creation: {e}") + + async def _persist_executor_completed(self, executor_id: str, executor: ExecutorBase): + """Persist executor completion to database.""" + if not self.db_manager: + return + + try: + # Read status/close_type directly from executor (most reliable) + status_name = executor.status.name + close_type = executor.close_type.name if executor.close_type else None + + # Get PnL values from executor_info + try: + executor_info = executor.executor_info + net_pnl_quote = executor_info.net_pnl_quote + net_pnl_pct = executor_info.net_pnl_pct + cum_fees_quote = executor_info.cum_fees_quote + filled_amount_quote = executor_info.filled_amount_quote + except Exception as e: + logger.debug(f"Error accessing executor_info for persistence: {e}") + net_pnl_quote = Decimal("0") + net_pnl_pct = Decimal("0") + cum_fees_quote = Decimal("0") + filled_amount_quote = Decimal("0") + + # Get custom_info directly from executor to avoid Pydantic serialization issues + # with TrackedOrder and other complex types + custom_info = executor.get_custom_info() + # Serialize custom_info, fallback to None if serialization fails + final_state_json = None + metadata = self._executor_metadata.get(executor_id, {}) + executor_type = metadata.get("executor_type") + if executor_type == "grid_executor": + heavy_fields = { + "levels_by_state", + "filled_orders", + "failed_orders", + "canceled_orders", + } + custom_info = {k: v for k, v in custom_info.items() if k not in heavy_fields} + + try: + final_state_json = json.dumps(custom_info, default=_json_default) + except Exception as e: + logger.warning(f"Failed to serialize custom_info for {executor_id}: {e}") + # Try a simpler serialization without complex objects + try: + simple_info = {k: v for k, v in custom_info.items() + if isinstance(v, (str, int, float, bool, list, dict, type(None)))} + final_state_json = json.dumps(simple_info) + except Exception: + final_state_json = None + + async with self.db_manager.get_session_context() as session: + from database.repositories.executor_repository import ExecutorRepository + repo = ExecutorRepository(session) + + await repo.update_executor( + executor_id=executor_id, + status=status_name, + close_type=close_type, + net_pnl_quote=net_pnl_quote, + net_pnl_pct=net_pnl_pct, + cum_fees_quote=cum_fees_quote, + filled_amount_quote=filled_amount_quote, + final_state=final_state_json + ) + + logger.debug(f"Persisted executor {executor_id} completion to database") + + except Exception as e: + logger.error(f"Error persisting executor completion: {e}") + + # ======================================== + # Position Hold Tracking Methods + # ======================================== + + def _get_position_key( + self, + account_name: str, + connector_name: str, + trading_pair: str, + controller_id: str = "main" + ) -> str: + """Generate a unique key for position tracking.""" + return f"{account_name}|{connector_name}|{trading_pair}|{controller_id}" + + async def _aggregate_position_hold( + self, + executor_id: str, + executor: ExecutorBase, + metadata: Dict[str, Any] + ): + """ + Aggregate position data from an executor stopped with keep_position=True. + + This extracts the filled amounts from the executor and adds them to + the aggregated position tracking. + """ + account_name = metadata.get("account_name", self.default_account) + connector_name = metadata.get("connector_name", "") + trading_pair = metadata.get("trading_pair", "") + controller_id = metadata.get("controller_id", "main") + + if not connector_name or not trading_pair: + logger.warning(f"Cannot aggregate position for executor {executor_id}: missing connector/pair info") + return + + position_key = self._get_position_key(account_name, connector_name, trading_pair, controller_id) + + # Get or create position hold + if position_key not in self._positions_held: + self._positions_held[position_key] = PositionHold( + trading_pair=trading_pair, + connector_name=connector_name, + account_name=account_name, + controller_id=controller_id + ) + + position = self._positions_held[position_key] + + # Extract filled amounts from executor + try: + # Try to get executor info + try: + executor_info = executor.executor_info + custom_info = executor_info.custom_info or {} + except Exception: + custom_info = executor.get_custom_info() if hasattr(executor, 'get_custom_info') else {} + + # Get side from config or custom_info + config = metadata.get("config", {}) + side = config.get("side", custom_info.get("side", "BUY")) + + # Extract filled amounts - try different sources + filled_amount_base = Decimal("0") + filled_amount_quote = Decimal("0") + + # Try from executor attributes directly + if hasattr(executor, 'filled_amount_base'): + filled_amount_base = Decimal(str(executor.filled_amount_base or 0)) + if hasattr(executor, 'filled_amount_quote'): + filled_amount_quote = Decimal(str(executor.filled_amount_quote or 0)) + + # Fallback to custom_info + if filled_amount_base == 0 and custom_info: + filled_amount_base = Decimal(str(custom_info.get("filled_amount_base", 0))) + if filled_amount_quote == 0 and custom_info: + filled_amount_quote = Decimal(str(custom_info.get("filled_amount_quote", 0))) + + # Check for held_position_orders (used by grid_executor, position_executor, etc.) + held_orders = custom_info.get("held_position_orders", []) if custom_info else [] + + if held_orders: + buy_filled_base = Decimal("0") + buy_filled_quote = Decimal("0") + sell_filled_base = Decimal("0") + sell_filled_quote = Decimal("0") + + for order in held_orders: + if isinstance(order, dict): + trade_type = order.get("trade_type", "BUY") + exec_base = Decimal(str(order.get("executed_amount_base", 0))) + exec_quote = Decimal(str(order.get("executed_amount_quote", 0))) + + if trade_type == "BUY": + buy_filled_base += exec_base + buy_filled_quote += exec_quote + else: + sell_filled_base += exec_base + sell_filled_quote += exec_quote + + # Add buy and sell fills separately + if buy_filled_base > 0: + position.add_fill("BUY", buy_filled_base, buy_filled_quote, executor_id) + if sell_filled_base > 0: + position.add_fill("SELL", sell_filled_base, sell_filled_quote, executor_id) + + logger.info( + f"Aggregated executor {executor_id} to position {position_key}: " + f"buy={buy_filled_base} base, sell={sell_filled_base} base" + ) + + elif filled_amount_base > 0: + # For non-grid executors with a single side + position.add_fill(side, filled_amount_base, filled_amount_quote, executor_id) + logger.info( + f"Aggregated executor {executor_id} to position {position_key}: " + f"{side} {filled_amount_base} base @ {filled_amount_quote} quote" + ) + else: + logger.debug(f"Executor {executor_id} has no filled amounts to aggregate") + + except Exception as e: + logger.error(f"Error aggregating position for executor {executor_id}: {e}", exc_info=True) + + def get_positions_held( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None, + trading_pair: Optional[str] = None, + controller_id: Optional[str] = None + ) -> List[PositionHold]: + """ + Get held positions with optional filtering. + + Args: + account_name: Filter by account name + connector_name: Filter by connector name + trading_pair: Filter by trading pair + controller_id: Filter by controller ID + + Returns: + List of PositionHold objects matching the filters + """ + positions = [] + + for position in self._positions_held.values(): + # Apply filters + if account_name and position.account_name != account_name: + continue + if connector_name and position.connector_name != connector_name: + continue + if trading_pair and position.trading_pair != trading_pair: + continue + if controller_id and position.controller_id != controller_id: + continue + + # Only include positions with actual volume + if position.buy_amount_base > 0 or position.sell_amount_base > 0: + positions.append(position) + + return positions + + def get_position_held( + self, + account_name: str, + connector_name: str, + trading_pair: str, + controller_id: str = "main" + ) -> Optional[PositionHold]: + """ + Get a specific held position. + + Args: + account_name: Account name + connector_name: Connector name + trading_pair: Trading pair + controller_id: Controller ID + + Returns: + PositionHold or None if not found + """ + position_key = self._get_position_key(account_name, connector_name, trading_pair, controller_id) + return self._positions_held.get(position_key) + + def clear_position_held( + self, + account_name: str, + connector_name: str, + trading_pair: str, + controller_id: str = "main" + ) -> bool: + """ + Clear a specific held position (after manual close or full exit). + + Args: + account_name: Account name + connector_name: Connector name + trading_pair: Trading pair + controller_id: Controller ID + + Returns: + True if cleared, False if not found + """ + position_key = self._get_position_key(account_name, connector_name, trading_pair, controller_id) + if position_key in self._positions_held: + del self._positions_held[position_key] + logger.info(f"Cleared position hold for {position_key}") + return True + return False + + def get_positions_summary(self) -> Dict[str, Any]: + """ + Get summary of all held positions. + + Returns: + Dictionary with total positions, PnL, and position list + """ + positions = self.get_positions_held() + total_realized_pnl = sum(float(p.realized_pnl_quote) for p in positions) + + return { + "total_positions": len(positions), + "total_realized_pnl": total_realized_pnl, + "positions": [ + { + "trading_pair": p.trading_pair, + "connector_name": p.connector_name, + "account_name": p.account_name, + "buy_amount_base": float(p.buy_amount_base), + "buy_amount_quote": float(p.buy_amount_quote), + "sell_amount_base": float(p.sell_amount_base), + "sell_amount_quote": float(p.sell_amount_quote), + "net_amount_base": float(p.net_amount_base), + "buy_breakeven_price": float(p.buy_breakeven_price) if p.buy_breakeven_price else None, + "sell_breakeven_price": float(p.sell_breakeven_price) if p.sell_breakeven_price else None, + "matched_amount_base": float(p.matched_amount_base), + "unmatched_amount_base": float(p.unmatched_amount_base), + "position_side": p.position_side, + "realized_pnl_quote": float(p.realized_pnl_quote), + "executor_count": len(p.executor_ids), + "executor_ids": p.executor_ids, + "last_updated": p.last_updated.isoformat() if p.last_updated else None + } + for p in positions + ] + } diff --git a/services/funding_recorder.py b/services/funding_recorder.py new file mode 100644 index 00000000..9560939c --- /dev/null +++ b/services/funding_recorder.py @@ -0,0 +1,147 @@ +import asyncio +import logging +from datetime import datetime +from decimal import Decimal, InvalidOperation +from typing import Dict, Optional + +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder +from hummingbot.core.event.events import MarketEvent, FundingPaymentCompletedEvent + +from database import AsyncDatabaseManager, FundingRepository + + +class FundingRecorder: + """ + Records funding payment events and associates them with position data. + Follows the same pattern as OrdersRecorder for consistency. + """ + + def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connector_name: str): + self.db_manager = db_manager + self.account_name = account_name + self.connector_name = connector_name + self._connector: Optional[ConnectorBase] = None + self.logger = logging.getLogger(__name__) + + # Create event forwarder for funding payments + self._funding_payment_forwarder = SourceInfoEventForwarder(self._did_funding_payment) + + # Event pairs mapping events to forwarders + self._event_pairs = [ + (MarketEvent.FundingPaymentCompleted, self._funding_payment_forwarder), + ] + + def start(self, connector: ConnectorBase): + """Start recording funding payments for the given connector""" + # Idempotency guard: prevent double-registration of listeners + if self._connector is not None: + self.logger.warning(f"FundingRecorder already started for {self.account_name}/{self.connector_name}, ignoring duplicate start") + return + + self._connector = connector + + # Subscribe to funding payment events + for event, forwarder in self._event_pairs: + connector.add_listener(event, forwarder) + + self.logger.info(f"FundingRecorder started for {self.account_name}/{self.connector_name}") + + async def stop(self): + """Stop recording funding payments""" + if self._connector: + for event, forwarder in self._event_pairs: + self._connector.remove_listener(event, forwarder) + self.logger.info(f"FundingRecorder stopped for {self.account_name}/{self.connector_name}") + + def _did_funding_payment(self, event_tag: int, market: ConnectorBase, event: FundingPaymentCompletedEvent): + """Handle funding payment events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_funding_payment(event)) + except Exception as e: + self.logger.error(f"Error in _did_funding_payment: {e}") + + async def _handle_funding_payment(self, event: FundingPaymentCompletedEvent): + """Handle funding payment events""" + # Get current position data if available + position_data = None + if self._connector and hasattr(self._connector, 'account_positions'): + try: + positions = self._connector.account_positions + if positions: + for position in positions.values(): + if position.trading_pair == event.trading_pair: + position_data = { + "size": float(position.amount), + "side": position.position_side.name if hasattr(position.position_side, 'name') else str(position.position_side), + } + break + except Exception as e: + self.logger.warning(f"Could not get position data for funding payment: {e}") + + # Record the funding payment + await self.record_funding_payment(event, self.account_name, self.connector_name, position_data) + + async def record_funding_payment(self, event: FundingPaymentCompletedEvent, + account_name: str, connector_name: str, + position_data: Optional[Dict] = None): + """ + Record a funding payment event with optional position association. + + Args: + event: FundingPaymentCompletedEvent from Hummingbot + account_name: Account name + connector_name: Connector name + position_data: Optional position data at time of payment + """ + try: + # Validate and convert funding data + funding_rate = Decimal(str(event.funding_rate)) + funding_payment = Decimal(str(event.amount)) + + # Create funding payment record + funding_data = { + "funding_payment_id": f"{connector_name}_{event.trading_pair}_{event.timestamp.timestamp()}", + "timestamp": event.timestamp, + "account_name": account_name, + "connector_name": connector_name, + "trading_pair": event.trading_pair, + "funding_rate": float(funding_rate), + "funding_payment": float(funding_payment), + "fee_currency": getattr(event, 'fee_currency', 'USDT'), # Default to USDT if not provided + "exchange_funding_id": getattr(event, 'exchange_funding_id', None), + } + + # Add position data if provided + if position_data: + funding_data.update({ + "position_size": float(position_data.get("size", 0)), + "position_side": position_data.get("side"), + }) + + # Save to database + async with self.db_manager.get_session() as session: + funding_repo = FundingRepository(session) + + # Check if funding payment already exists + if await funding_repo.funding_payment_exists(funding_data["funding_payment_id"]): + self.logger.info(f"Funding payment {funding_data['funding_payment_id']} already exists, skipping") + return + + funding_payment = await funding_repo.create_funding_payment(funding_data) + await session.commit() + + self.logger.info( + f"Recorded funding payment for {account_name}/{connector_name}: " + f"{event.trading_pair} - Rate: {funding_rate}, Payment: {funding_payment} " + f"{funding_data['fee_currency']}" + ) + + return funding_payment + + except (ValueError, InvalidOperation) as e: + self.logger.error(f"Error processing funding payment for {event.trading_pair}: {e}, skipping update") + return + except Exception as e: + self.logger.error(f"Unexpected error recording funding payment: {e}") + return \ No newline at end of file diff --git a/services/gateway_client.py b/services/gateway_client.py new file mode 100644 index 00000000..278448a1 --- /dev/null +++ b/services/gateway_client.py @@ -0,0 +1,615 @@ +import logging +from decimal import Decimal +from typing import Dict, List, Optional + +import aiohttp + +logger = logging.getLogger(__name__) + + +class GatewayClient: + """ + Simplified Gateway HTTP client for API integration. + Provides essential functionality for wallet management and balance queries. + """ + + def __init__(self, base_url: str = "http://localhost:15888"): + self.base_url = base_url + self._session: Optional[aiohttp.ClientSession] = None + + @staticmethod + def parse_network_id(network_id: str) -> tuple[str, str]: + """ + Parse network_id in format 'chain-network' into (chain, network). + + Examples: + 'solana-mainnet-beta' -> ('solana', 'mainnet-beta') + 'ethereum-mainnet' -> ('ethereum', 'mainnet') + """ + parts = network_id.split('-', 1) + if len(parts) != 2: + raise ValueError(f"Invalid network_id format. Expected 'chain-network', got '{network_id}'") + return parts[0], parts[1] + + async def get_wallet_address_or_default(self, chain: str, wallet_address: Optional[str] = None) -> str: + """Get wallet address - use provided or get default for chain""" + if wallet_address: + return wallet_address + + default_wallet = await self.get_default_wallet_address(chain) + if not default_wallet: + raise ValueError(f"No wallet configured for chain '{chain}'") + # Skip placeholder wallet addresses (e.g., "ethereum-default-wallet", "solana-default-wallet") + if default_wallet.endswith("-default-wallet"): + raise ValueError(f"No valid wallet configured for chain '{chain}' (found placeholder: {default_wallet})") + return default_wallet + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp session""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def close(self): + """Close the aiohttp session""" + if self._session and not self._session.closed: + await self._session.close() + + async def _request(self, method: str, path: str, params: Dict = None, json: Dict = None) -> Optional[Dict]: + """Make HTTP request to Gateway""" + session = await self._get_session() + url = f"{self.base_url}/{path}" + + try: + if method == "GET": + async with session.get(url, params=params) as response: + if not response.ok: + error_body = await self._get_error_body(response) + logger.warning(f"Gateway request failed: {method} {url} - {response.status} - {error_body}") + return {"error": error_body, "status": response.status} + return await response.json() + elif method == "POST": + async with session.post(url, json=json) as response: + if not response.ok: + error_body = await self._get_error_body(response) + logger.warning(f"Gateway request failed: {method} {url} - {response.status} - {error_body}") + return {"error": error_body, "status": response.status} + return await response.json() + elif method == "DELETE": + async with session.delete(url, params=params, json=json) as response: + if not response.ok: + error_body = await self._get_error_body(response) + logger.warning(f"Gateway request failed: {method} {url} - {response.status} - {error_body}") + return {"error": error_body, "status": response.status} + return await response.json() + except aiohttp.ClientError as e: + logger.debug(f"Gateway request error: {method} {url} - {e}") + return None + except Exception as e: + logger.debug(f"Gateway request failed: {method} {url} - {e}") + raise + + async def _get_error_body(self, response: aiohttp.ClientResponse) -> str: + """Extract error message from response body""" + try: + data = await response.json() + if isinstance(data, dict): + return data.get("message") or data.get("error") or str(data) + return str(data) + except Exception: + try: + return await response.text() + except Exception: + return f"HTTP {response.status}" + + async def ping(self) -> bool: + """Check if Gateway is online""" + try: + response = await self._request("GET", "") + return response.get("status") == "ok" + except Exception: + return False + + async def get_wallets(self) -> List[Dict]: + """Get all connected wallets""" + return await self._request("GET", "wallet") + + async def get_default_wallet_address(self, chain: str) -> Optional[str]: + """Get default wallet address for a chain from Gateway config""" + try: + config = await self._request("GET", "config", params={"namespace": chain}) + return config.get("defaultWallet") + except Exception as e: + logger.error(f"Error getting default wallet for chain {chain}: {e}") + return None + + async def get_all_wallet_addresses(self, chain: Optional[str] = None) -> Dict[str, List[str]]: + """ + Get all wallet addresses, optionally filtered by chain. + + Args: + chain: Optional chain filter (e.g., 'solana', 'ethereum'). + If not provided, returns wallets for all chains. + + Returns: + Dict mapping chain name to list of wallet addresses. + Example: {"solana": ["addr1", "addr2"], "ethereum": ["addr3"]} + """ + try: + wallets = await self.get_wallets() + if wallets is None: + return {} + + result = {} + for wallet in wallets: + wallet_chain = wallet.get("chain") + if chain and wallet_chain != chain: + continue + + addresses = wallet.get("walletAddresses", []) + if addresses and wallet_chain: + result[wallet_chain] = addresses + + return result + except Exception as e: + logger.error(f"Error getting all wallet addresses: {e}") + return {} + + async def add_wallet(self, chain: str, private_key: str, set_default: bool = True) -> Dict: + """Add a wallet to Gateway""" + return await self._request("POST", "wallet/add", json={ + "chain": chain, + "privateKey": private_key, + "setDefault": set_default + }) + + async def create_wallet(self, chain: str, set_default: bool = True) -> Dict: + """Create a new wallet in Gateway""" + return await self._request("POST", "wallet/create", json={ + "chain": chain, + "setDefault": set_default + }) + + async def show_private_key(self, chain: str, address: str, passphrase: str) -> Dict: + """Show private key for a wallet""" + return await self._request("POST", "wallet/show-private-key", json={ + "chain": chain, + "address": address, + "passphrase": passphrase + }) + + async def send_transaction( + self, + chain: str, + network: str, + address: str, + to_address: str, + amount: str + ) -> Dict: + """Send a native token transaction""" + return await self._request("POST", "wallet/send", json={ + "chain": chain, + "network": network, + "address": address, + "toAddress": to_address, + "amount": amount + }) + + async def remove_wallet(self, chain: str, address: str) -> Dict: + """Remove a wallet from Gateway""" + return await self._request("DELETE", "wallet/remove", json={ + "chain": chain, + "address": address + }) + + async def get_balances(self, chain: str, network: str, address: str, tokens: Optional[List[str]] = None) -> Dict: + """Get token balances for a wallet""" + return await self._request("POST", f"chains/{chain}/balances", json={ + "network": network, + "address": address, + "tokens": tokens if tokens is not None else [] + }) + + async def get_chains(self) -> Dict: + """Get available chains""" + return await self._request("GET", "config/chains") + + async def get_default_network(self, chain: str) -> Optional[str]: + """Get default network for a chain""" + try: + config = await self._request("GET", "config", params={"namespace": chain}) + return config.get("defaultNetwork") + except Exception: + return None + + async def get_tokens(self, chain: str, network: str) -> Dict: + """Get available tokens for a chain/network""" + return await self._request("GET", "tokens", params={ + "chain": chain, + "network": network + }) + + async def add_token(self, chain: str, network: str, address: str, symbol: str, name: str, decimals: int) -> Dict: + """Add a custom token to Gateway's token list""" + return await self._request("POST", "tokens", json={ + "chain": chain, + "network": network, + "token": { + "address": address, + "symbol": symbol, + "name": name, + "decimals": decimals + } + }) + + async def delete_token(self, chain: str, network: str, token_address: str) -> Dict: + """Delete a custom token from Gateway's token list""" + return await self._request("DELETE", f"tokens/{token_address}", params={ + "chain": chain, + "network": network + }) + + async def get_config(self, namespace: str) -> Dict: + """Get configuration for a specific namespace (connector or chain-network)""" + return await self._request("GET", "config", params={"namespace": namespace}) + + async def update_config(self, namespace: str, path: str, value: any) -> Dict: + """Update a configuration value for a namespace""" + return await self._request("POST", "config/update", json={ + "namespace": namespace, + "path": path, + "value": value + }) + + async def get_pools(self, connector: str, network: str) -> List[Dict]: + """Get pools for a connector and network""" + return await self._request("GET", "pools", params={ + "connector": connector, + "network": network + }) + + async def add_pool( + self, + connector: str, + pool_type: str, + network: str, + address: str, + base_symbol: str, + quote_symbol: str, + base_token_address: str, + quote_token_address: str, + fee_pct: Optional[float] = None + ) -> Dict: + """Add a new pool""" + payload = { + "connector": connector, + "type": pool_type.lower(), # Gateway expects lowercase (amm, clmm) + "network": network, + "address": address, + "baseSymbol": base_symbol, + "quoteSymbol": quote_symbol, + "baseTokenAddress": base_token_address, + "quoteTokenAddress": quote_token_address + } + if fee_pct is not None: + payload["feePct"] = fee_pct + return await self._request("POST", "pools", json=payload) + + async def delete_pool(self, connector: str, network: str, pool_type: str, address: str) -> Dict: + """Delete a pool from Gateway's pool list""" + return await self._request("DELETE", f"pools/{address}", params={ + "connector": connector, + "network": network, + "type": pool_type.lower() # Gateway expects lowercase (amm, clmm) + }) + + async def pool_info(self, connector: str, network: str, pool_address: str) -> Dict: + """Get detailed information about a specific pool""" + return await self._request("POST", "clmm/liquidity/pool", json={ + "connector": connector, + "network": network, + "poolAddress": pool_address + }) + + # ============================================ + # Swap Operations + # ============================================ + + async def quote_swap( + self, + connector: str, + network: str, + base_asset: str, + quote_asset: str, + amount: float, + side: str, + slippage_pct: Optional[float] = None, + pool_address: Optional[str] = None + ) -> Dict: + """Get a quote for a swap""" + payload = { + "network": network, + "baseToken": base_asset, + "quoteToken": quote_asset, + "amount": str(amount), + "side": side.upper() + } + if slippage_pct is not None: + payload["slippagePct"] = slippage_pct + if pool_address: + payload["poolAddress"] = pool_address + + return await self._request("GET", f"connectors/{connector}/router/quote-swap", params=payload) + + async def execute_swap( + self, + connector: str, + network: str, + wallet_address: str, + base_asset: str, + quote_asset: str, + amount: float, + side: str, + slippage_pct: Optional[float] = None + ) -> Dict: + """Execute a swap""" + payload = { + "network": network, + "walletAddress": wallet_address, + "baseToken": base_asset, + "quoteToken": quote_asset, + "amount": str(amount), + "side": side.upper() + } + if slippage_pct is not None: + payload["slippagePct"] = slippage_pct + + return await self._request("POST", f"connectors/{connector}/router/execute-swap", json=payload) + + async def execute_quote( + self, + connector: str, + network: str, + wallet_address: str, + quote_id: str + ) -> Dict: + """Execute a previously obtained quote""" + return await self._request("POST", f"connectors/{connector}/router/execute-quote", json={ + "network": network, + "address": wallet_address, + "quoteId": quote_id + }) + + # ============================================ + # Liquidity Operations - CLMM (Concentrated Liquidity) + # ============================================ + + async def clmm_open_position( + self, + connector: str, + network: str, + wallet_address: str, + pool_address: str, + lower_price: float, + upper_price: float, + base_token_amount: Optional[float] = None, + quote_token_amount: Optional[float] = None, + slippage_pct: Optional[float] = None, + extra_params: Optional[Dict] = None + ) -> Dict: + """Open a NEW CLMM position with initial liquidity""" + payload = { + "network": network, + "walletAddress": wallet_address, + "poolAddress": pool_address, + "lowerPrice": lower_price, + "upperPrice": upper_price + } + if base_token_amount is not None: + payload["baseTokenAmount"] = str(base_token_amount) + if quote_token_amount is not None: + payload["quoteTokenAmount"] = str(quote_token_amount) + if slippage_pct is not None: + payload["slippagePct"] = slippage_pct + + # Add any connector-specific parameters + if extra_params: + payload.update(extra_params) + + return await self._request("POST", f"connectors/{connector}/clmm/open-position", json=payload) + + async def clmm_add_liquidity( + self, + connector: str, + network: str, + wallet_address: str, + position_address: str, + base_token_amount: Optional[float] = None, + quote_token_amount: Optional[float] = None, + slippage_pct: Optional[float] = None + ) -> Dict: + """Add more liquidity to an existing CLMM position""" + payload = { + "connector": connector, + "network": network, + "address": wallet_address, + "positionAddress": position_address + } + if base_token_amount is not None: + payload["baseTokenAmount"] = str(base_token_amount) + if quote_token_amount is not None: + payload["quoteTokenAmount"] = str(quote_token_amount) + if slippage_pct is not None: + payload["slippagePct"] = slippage_pct + + return await self._request("POST", "clmm/liquidity/add", json=payload) + + async def clmm_close_position( + self, + connector: str, + network: str, + wallet_address: str, + position_address: str + ) -> Dict: + """Close a CLMM position completely""" + return await self._request("POST", f"connectors/{connector}/clmm/close-position", json={ + "network": network, + "walletAddress": wallet_address, + "positionAddress": position_address + }) + + async def clmm_remove_liquidity( + self, + connector: str, + network: str, + wallet_address: str, + position_address: str, + percentage: float + ) -> Dict: + """Remove liquidity from a CLMM position (partial)""" + return await self._request("POST", "clmm/liquidity/remove", json={ + "connector": connector, + "network": network, + "address": wallet_address, + "positionAddress": position_address, + "percentage": percentage + }) + + async def clmm_position_info( + self, + connector: str, + chain_network: str, + position_address: str + ) -> Dict: + """ + Get CLMM position information including pending fees. + + Note: Gateway returns 500 instead of 404 when position doesn't exist (is closed). + Callers should treat 500 errors as "position not found/closed". + """ + # Validate required parameters + if not connector: + raise ValueError("connector is required for clmm_position_info") + if not chain_network: + raise ValueError("chain_network is required for clmm_position_info") + if not position_address: + raise ValueError("position_address is required for clmm_position_info") + + params = { + "connector": connector, + "chainNetwork": chain_network, + "positionAddress": position_address + } + return await self._request("GET", "trading/clmm/position-info", params=params) + + async def clmm_positions_owned( + self, + connector: str, + chain_network: str, + wallet_address: str, + pool_address: Optional[str] = None + ) -> List[Dict]: + """ + Get CLMM positions owned by a wallet. + + Args: + connector: CLMM connector (e.g., 'meteora', 'raydium') + chain_network: Chain and network in format 'chain-network' (e.g., 'solana-mainnet-beta') + wallet_address: Wallet address to query + pool_address: Optional pool address to filter positions. + If not provided, returns ALL positions across all pools. + + Returns: + List of position dictionaries with fields like: + - address: Position NFT address + - poolAddress: Pool address + - baseTokenAddress, quoteTokenAddress + - baseTokenAmount, quoteTokenAmount + - baseFeeAmount, quoteFeeAmount + - lowerBinId, upperBinId + - lowerPrice, upperPrice, price + """ + params = { + "connector": connector, + "chainNetwork": chain_network, + "walletAddress": wallet_address, + } + + # Only add poolAddress if specified (allows fetching all positions) + if pool_address: + params["poolAddress"] = pool_address + + return await self._request("GET", "trading/clmm/positions-owned", params=params) + + async def clmm_collect_fees( + self, + connector: str, + network: str, + wallet_address: str, + position_address: str + ) -> Dict: + """Collect accumulated fees from a CLMM position""" + return await self._request("POST", f"connectors/{connector}/clmm/collect-fees", json={ + "network": network, + "address": wallet_address, + "positionAddress": position_address + }) + + async def clmm_pool_info( + self, + connector: str, + network: str, + pool_address: str + ) -> Dict: + """Get detailed CLMM pool information by pool address""" + return await self._request("GET", f"connectors/{connector}/clmm/pool-info", params={ + "network": network, + "poolAddress": pool_address + }) + + # ============================================ + # Transaction Polling + # ============================================ + + async def poll_transaction( + self, + network_id: str, + tx_hash: str, + wallet_address: Optional[str] = None + ) -> Optional[Dict]: + """ + Poll transaction status on blockchain. + + Args: + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta', 'ethereum-mainnet') + tx_hash: Transaction hash/signature + wallet_address: Optional wallet address for verification + + Returns: + Transaction status dict with fields: + - txStatus: 1 for confirmed, 0 for failed/pending + - fee: Transaction fee amount + - txData: Full transaction data including meta.err + Returns None if Gateway is unavailable or request fails. + """ + try: + # Split network_id into chain and network + parts = network_id.split('-', 1) + if len(parts) != 2: + logger.error(f"Invalid network_id format: {network_id}. Expected 'chain-network'") + return None + + chain, network = parts + + payload = { + "network": network, + "signature": tx_hash + } + if wallet_address: + payload["walletAddress"] = wallet_address + + return await self._request("POST", f"chains/{chain}/poll", json=payload) + except Exception as e: + logger.error(f"Error polling transaction {tx_hash}: {e}") + return None + diff --git a/services/gateway_service.py b/services/gateway_service.py new file mode 100644 index 00000000..d7378838 --- /dev/null +++ b/services/gateway_service.py @@ -0,0 +1,361 @@ +import logging +import os +import platform +import shutil +from typing import Optional, Dict + +import docker +from docker.errors import DockerException +from docker.types import LogConfig + +from models.gateway import GatewayConfig, GatewayStatus + +# Create module-specific logger +logger = logging.getLogger(__name__) + + +class GatewayService: + """ + Service for managing the Hummingbot Gateway Docker container. + Ensures only one Gateway instance can exist at a time. + """ + + GATEWAY_CONTAINER_NAME = "gateway" + GATEWAY_DIR = "gateway-files" + + def __init__(self): + self.SOURCE_PATH = os.getcwd() + # Use BOTS_PATH if set (for Docker), otherwise use SOURCE_PATH (for local) + self.BOTS_PATH = os.environ.get('BOTS_PATH', self.SOURCE_PATH) + try: + self.client = docker.from_env() + except DockerException as e: + logger.error(f"Failed to connect to Docker. Error: {e}") + raise + + def _ensure_gateway_directories(self): + """Create necessary directories for Gateway if they don't exist""" + # Gateway files are at root level, same as bots directory + gateway_base = os.path.join(self.BOTS_PATH, self.GATEWAY_DIR) + + conf_dir = os.path.join(gateway_base, "conf") + logs_dir = os.path.join(gateway_base, "logs") + + os.makedirs(conf_dir, exist_ok=True) + os.makedirs(logs_dir, exist_ok=True) + + return { + "base": gateway_base, + "conf": conf_dir, + "logs": logs_dir + } + + def _get_gateway_container(self) -> Optional[docker.models.containers.Container]: + """Get the Gateway container if it exists""" + try: + return self.client.containers.get(self.GATEWAY_CONTAINER_NAME) + except docker.errors.NotFound: + return None + except DockerException as e: + logger.error(f"Error getting Gateway container: {e}") + return None + + def get_status(self) -> GatewayStatus: + """Get the current status of the Gateway container""" + container = self._get_gateway_container() + + if container is None: + return GatewayStatus( + running=False, + container_id=None, + image=None, + created_at=None, + port=None + ) + + # Extract port from container configuration + port = None + if container.status == "running": + # Check if using host networking + network_mode = container.attrs.get("HostConfig", {}).get("NetworkMode", "") + if network_mode == "host": + # Host networking: Gateway uses port 15888 directly + port = 15888 + else: + # Bridge networking: Extract from port mappings + ports = container.attrs.get("NetworkSettings", {}).get("Ports", {}) + if "15888/tcp" in ports and ports["15888/tcp"]: + port = int(ports["15888/tcp"][0]["HostPort"]) + + return GatewayStatus( + running=container.status == "running", + container_id=container.id, + image=container.image.tags[0] if container.image.tags else container.image.id[:12], + created_at=container.attrs.get("Created"), + port=port + ) + + def start(self, config: GatewayConfig) -> Dict[str, any]: + """ + Start the Gateway container. + If a container already exists, it will be stopped and removed before creating a new one. + """ + # Check if Gateway is already running + existing_container = self._get_gateway_container() + if existing_container: + if existing_container.status == "running": + return { + "success": False, + "message": f"Gateway is already running. Use stop first or restart to update configuration." + } + else: + # Remove stopped container + logger.info("Removing stopped Gateway container") + existing_container.remove(force=True) + + # Ensure directories exist + dirs = self._ensure_gateway_directories() + + # Set up volumes - use BOTS_PATH which contains the HOST path + volumes = { + os.path.join(self.BOTS_PATH, self.GATEWAY_DIR, "conf"): {'bind': '/home/gateway/conf', 'mode': 'rw'}, + os.path.join(self.BOTS_PATH, self.GATEWAY_DIR, "logs"): {'bind': '/home/gateway/logs', 'mode': 'rw'}, + } + + # Set up environment variables + environment = { + "GATEWAY_PASSPHRASE": config.passphrase, + "DEV": str(config.dev_mode).lower(), + } + + # Configure logging + log_config = LogConfig( + type="json-file", + config={ + 'max-size': '10m', + 'max-file': "5", + } + ) + + # Detect platform and configure networking + # Native Linux: Use host networking (works natively) + # Docker Desktop (macOS/Windows) or containerized: Use bridge networking + system_platform = platform.system() + + # Check if running inside Docker container (Docker Desktop or containerized API) + in_container = os.path.exists('/.dockerenv') or os.path.exists('/run/.containerenv') + + # Only use host networking on native Linux (not inside a container) + use_host_network = system_platform == "Linux" and not in_container + + if use_host_network: + logger.info("Detected native Linux - using host network mode for Gateway") + else: + logger.info(f"Detected {system_platform} (in_container={in_container}) - using bridge networking for Gateway") + + try: + # Build container configuration + container_config = { + "image": config.image, + "name": self.GATEWAY_CONTAINER_NAME, + "volumes": volumes, + "environment": environment, + "detach": True, + "restart_policy": {"Name": "always"}, + "log_config": log_config, + } + + if use_host_network: + # Linux: Use host networking + container_config["network_mode"] = "host" + else: + # macOS/Windows: Use bridge networking with port mapping + container_config["ports"] = {'15888/tcp': config.port} + + container = self.client.containers.run(**container_config) + + # On macOS/Windows, connect to emqx-bridge network if it exists + if not use_host_network: + possible_networks = ["hummingbot-api_emqx-bridge", "emqx-bridge"] + for net in possible_networks: + try: + network = self.client.networks.get(net) + network.connect(container) + logger.info(f"Connected Gateway to {net} network") + break + except docker.errors.NotFound: + continue + + logger.info(f"Gateway container started successfully: {container.id}") + return { + "success": True, + "message": f"Gateway started successfully", + "container_id": container.id, + "port": config.port + } + + except DockerException as e: + logger.error(f"Failed to start Gateway container: {e}") + return { + "success": False, + "message": f"Failed to start Gateway: {str(e)}" + } + + def stop(self) -> Dict[str, any]: + """Stop the Gateway container""" + container = self._get_gateway_container() + + if container is None: + return { + "success": False, + "message": "Gateway container not found" + } + + try: + if container.status == "running": + container.stop() + logger.info("Gateway container stopped") + return { + "success": True, + "message": "Gateway stopped successfully" + } + except DockerException as e: + logger.error(f"Failed to stop Gateway container: {e}") + return { + "success": False, + "message": f"Failed to stop Gateway: {str(e)}" + } + + def restart(self, config: Optional[GatewayConfig] = None) -> Dict[str, any]: + """ + Restart the Gateway container. + If config is provided, the container will be recreated with the new configuration. + """ + container = self._get_gateway_container() + + if container is None: + if config: + # No existing container, just start with new config + return self.start(config) + else: + return { + "success": False, + "message": "Gateway container not found. Use start with configuration to create one." + } + + if config: + # Stop and remove existing container, then start with new config + try: + container.remove(force=True) + logger.info("Removed existing Gateway container for restart with new config") + except DockerException as e: + logger.error(f"Failed to remove Gateway container: {e}") + return { + "success": False, + "message": f"Failed to remove existing container: {str(e)}" + } + return self.start(config) + else: + # Simple restart of existing container + try: + container.restart() + logger.info("Gateway container restarted") + return { + "success": True, + "message": "Gateway restarted successfully" + } + except DockerException as e: + logger.error(f"Failed to restart Gateway container: {e}") + return { + "success": False, + "message": f"Failed to restart Gateway: {str(e)}" + } + + def remove(self, remove_data: bool = False) -> Dict[str, any]: + """ + Remove the Gateway container and optionally its data. + + Args: + remove_data: If True, also remove the gateway-files directory + """ + container = self._get_gateway_container() + + if container is None: + if remove_data: + # No container, but try to remove data if requested + gateway_dir = os.path.join(self.SOURCE_PATH, self.GATEWAY_DIR) + if os.path.exists(gateway_dir): + try: + shutil.rmtree(gateway_dir) + logger.info(f"Removed Gateway data directory: {gateway_dir}") + return { + "success": True, + "message": "Gateway data removed (no container was found)" + } + except Exception as e: + logger.error(f"Failed to remove Gateway data: {e}") + return { + "success": False, + "message": f"Failed to remove Gateway data: {str(e)}" + } + return { + "success": False, + "message": "Gateway container not found" + } + + try: + # Remove container + container.remove(force=True) + logger.info("Gateway container removed") + + # Remove data if requested + if remove_data: + gateway_dir = os.path.join(self.SOURCE_PATH, self.GATEWAY_DIR) + if os.path.exists(gateway_dir): + shutil.rmtree(gateway_dir) + logger.info(f"Removed Gateway data directory: {gateway_dir}") + return { + "success": True, + "message": "Gateway container and data removed successfully" + } + + return { + "success": True, + "message": "Gateway container removed successfully" + } + + except DockerException as e: + logger.error(f"Failed to remove Gateway container: {e}") + return { + "success": False, + "message": f"Failed to remove Gateway: {str(e)}" + } + except Exception as e: + logger.error(f"Failed to remove Gateway data: {e}") + return { + "success": False, + "message": f"Gateway container removed but failed to remove data: {str(e)}" + } + + def get_logs(self, tail: int = 100) -> Dict[str, any]: + """Get logs from the Gateway container""" + container = self._get_gateway_container() + + if container is None: + return { + "success": False, + "message": "Gateway container not found" + } + + try: + logs = container.logs(tail=tail, timestamps=True).decode('utf-8') + return { + "success": True, + "logs": logs + } + except DockerException as e: + logger.error(f"Failed to get Gateway logs: {e}") + return { + "success": False, + "message": f"Failed to get logs: {str(e)}" + } diff --git a/services/gateway_transaction_poller.py b/services/gateway_transaction_poller.py new file mode 100644 index 00000000..991fa44d --- /dev/null +++ b/services/gateway_transaction_poller.py @@ -0,0 +1,802 @@ +""" +Gateway Transaction Poller + +This service polls blockchain transactions to confirm Gateway swap and CLMM operations. +Unlike CEX connectors that emit events, DEX transactions require active polling until confirmation. + +Additionally polls CLMM position state to keep database in sync with on-chain state. +""" +import asyncio +import logging +from typing import Optional, Dict, List +from datetime import datetime, timedelta, timezone +from decimal import Decimal + +from sqlalchemy import select +from sqlalchemy.orm import selectinload + +from database import AsyncDatabaseManager +from database.repositories import GatewaySwapRepository, GatewayCLMMRepository +from database.models import GatewayCLMMEvent, GatewayCLMMPosition +from services.gateway_client import GatewayClient + +logger = logging.getLogger(__name__) + + +class GatewayTransactionPoller: + """ + Polls Gateway for transaction status updates and position state. + + - Transaction polling: Confirms pending swap/CLMM transactions + - Position polling: Updates CLMM position state (in_range, liquidity, fees) + + Unlike CEX connectors that emit events when orders fill, DEX transactions + need to be polled until they are confirmed on-chain or fail. + """ + + def __init__( + self, + db_manager: AsyncDatabaseManager, + gateway_client: GatewayClient, + poll_interval: int = 10, # Poll every 10 seconds for transactions + position_poll_interval: int = 300, # Poll every 5 minutes for positions + max_retry_age: int = 3600 # Stop retrying after 1 hour + ): + self.db_manager = db_manager + self.gateway_client = gateway_client + self.poll_interval = poll_interval + self.position_poll_interval = position_poll_interval + self.max_retry_age = max_retry_age + self._running = False + self._poll_task: Optional[asyncio.Task] = None + self._position_poll_task: Optional[asyncio.Task] = None + self._last_position_poll: Optional[datetime] = None + + async def start(self): + """Start the polling service.""" + if self._running: + logger.warning("GatewayTransactionPoller already running") + return + + self._running = True + self._poll_task = asyncio.create_task(self._poll_loop()) + self._position_poll_task = asyncio.create_task(self._position_poll_loop()) + logger.info(f"GatewayTransactionPoller started (tx_poll={self.poll_interval}s, pos_poll={self.position_poll_interval}s)") + + async def stop(self): + """Stop the polling service.""" + if not self._running: + return + + self._running = False + + # Cancel transaction polling task + if self._poll_task: + self._poll_task.cancel() + try: + await self._poll_task + except asyncio.CancelledError: + pass + + # Cancel position polling task + if self._position_poll_task: + self._position_poll_task.cancel() + try: + await self._position_poll_task + except asyncio.CancelledError: + pass + + logger.info("GatewayTransactionPoller stopped") + + async def _poll_loop(self): + """Main polling loop.""" + while self._running: + try: + await self._poll_pending_transactions() + except Exception as e: + logger.error(f"Error in poll loop: {e}", exc_info=True) + + # Wait before next poll + try: + await asyncio.sleep(self.poll_interval) + except asyncio.CancelledError: + break + + async def _poll_pending_transactions(self): + """Poll all pending transactions and update their status.""" + try: + async with self.db_manager.get_session_context() as session: + swap_repo = GatewaySwapRepository(session) + clmm_repo = GatewayCLMMRepository(session) + + # Get pending swaps + pending_swaps = await swap_repo.get_pending_swaps(limit=100) + logger.debug(f"Found {len(pending_swaps)} pending swaps") + + for swap in pending_swaps: + # Skip if too old (likely failed without proper error) + age = (datetime.now(timezone.utc) - swap.timestamp).total_seconds() + if age > self.max_retry_age: + logger.warning(f"Swap {swap.transaction_hash} exceeded max retry age, marking as FAILED") + await swap_repo.update_swap_status( + transaction_hash=swap.transaction_hash, + status="FAILED", + error_message="Transaction confirmation timeout" + ) + continue + + # Poll transaction status + await self._poll_swap_transaction(swap, swap_repo) + + # Get pending CLMM events + pending_events = await clmm_repo.get_pending_events(limit=100) + logger.debug(f"Found {len(pending_events)} pending CLMM events") + + for event in pending_events: + # Skip if too old + age = (datetime.now(timezone.utc) - event.timestamp).total_seconds() + if age > self.max_retry_age: + logger.warning(f"CLMM event {event.transaction_hash} exceeded max retry age, marking as FAILED") + await clmm_repo.update_event_status( + transaction_hash=event.transaction_hash, + status="FAILED", + error_message="Transaction confirmation timeout" + ) + continue + + # Poll transaction status + await self._poll_clmm_event_transaction(event, clmm_repo) + + except Exception as e: + logger.error(f"Error polling pending transactions: {e}", exc_info=True) + + async def _poll_swap_transaction(self, swap, swap_repo: GatewaySwapRepository): + """Poll a specific swap transaction status.""" + try: + # Parse network into chain and network + parts = swap.network.split('-', 1) + if len(parts) != 2: + logger.error(f"Invalid network format for swap {swap.transaction_hash}: {swap.network}") + return + + chain, network = parts + + # Check transaction status on Gateway/blockchain + # Note: This is a placeholder - actual implementation depends on Gateway API + status_result = await self._check_transaction_status( + chain=chain, + network=network, + tx_hash=swap.transaction_hash + ) + + if status_result: + if status_result["status"] == "CONFIRMED": + logger.info(f"Swap transaction confirmed: {swap.transaction_hash}") + await swap_repo.update_swap_status( + transaction_hash=swap.transaction_hash, + status="CONFIRMED", + gas_fee=Decimal(str(status_result.get("gas_fee", 0))) if status_result.get("gas_fee") else None, + gas_token=status_result.get("gas_token") + ) + elif status_result["status"] == "FAILED": + logger.warning(f"Swap transaction failed: {swap.transaction_hash}") + await swap_repo.update_swap_status( + transaction_hash=swap.transaction_hash, + status="FAILED", + error_message=status_result.get("error_message", "Transaction failed on-chain") + ) + # If status is still pending, do nothing and retry later + + except Exception as e: + logger.error(f"Error polling swap transaction {swap.transaction_hash}: {e}") + + async def _poll_clmm_event_transaction(self, event, clmm_repo: GatewayCLMMRepository): + """Poll a specific CLMM event transaction status.""" + try: + # Get the position by ID from the event's position_id foreign key + result = await clmm_repo.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.id == event.position_id) + ) + position = result.scalar_one_or_none() + + if not position: + logger.error(f"Position not found for CLMM event {event.transaction_hash}") + return + + # Parse network + parts = position.network.split('-', 1) + if len(parts) != 2: + logger.error(f"Invalid network format for CLMM event {event.transaction_hash}: {position.network}") + return + + chain, network = parts + + # Check transaction status + status_result = await self._check_transaction_status( + chain=chain, + network=network, + tx_hash=event.transaction_hash + ) + + if status_result: + if status_result["status"] == "CONFIRMED": + logger.info(f"CLMM event transaction confirmed: {event.transaction_hash}") + await clmm_repo.update_event_status( + transaction_hash=event.transaction_hash, + status="CONFIRMED", + gas_fee=Decimal(str(status_result.get("gas_fee", 0))) if status_result.get("gas_fee") else None, + gas_token=status_result.get("gas_token") + ) + + # Update position state based on event type + await self._update_position_from_event(event, clmm_repo) + + elif status_result["status"] == "FAILED": + logger.warning(f"CLMM event transaction failed: {event.transaction_hash}") + await clmm_repo.update_event_status( + transaction_hash=event.transaction_hash, + status="FAILED", + error_message=status_result.get("error_message", "Transaction failed on-chain") + ) + + except Exception as e: + logger.error(f"Error polling CLMM event transaction {event.transaction_hash}: {e}") + + async def _update_position_from_event(self, event, clmm_repo: GatewayCLMMRepository): + """Update CLMM position state based on confirmed event.""" + try: + # Get position by ID using the existing clmm_repo session + result = await clmm_repo.session.execute( + select(GatewayCLMMPosition).where(GatewayCLMMPosition.id == event.position_id) + ) + position = result.scalar_one_or_none() + + if not position: + logger.error(f"Position not found for event {event.id}") + return + + if event.event_type == "CLOSE": + await clmm_repo.close_position(position.position_address) + + elif event.event_type == "COLLECT_FEES": + # Add collected fees to cumulative total + if event.base_fee_collected or event.quote_fee_collected: + new_base_collected = float(position.base_fee_collected or 0) + float(event.base_fee_collected or 0) + new_quote_collected = float(position.quote_fee_collected or 0) + float(event.quote_fee_collected or 0) + + await clmm_repo.update_position_fees( + position_address=position.position_address, + base_fee_collected=Decimal(str(new_base_collected)), + quote_fee_collected=Decimal(str(new_quote_collected)), + base_fee_pending=Decimal("0"), + quote_fee_pending=Decimal("0") + ) + + except Exception as e: + logger.error(f"Error updating position from event: {e}", exc_info=True) + + async def _check_transaction_status( + self, + chain: str, + network: str, + tx_hash: str + ) -> Optional[Dict]: + """ + Check transaction status on blockchain via Gateway. + + Returns: + Dict with status, gas_fee, gas_token, and error_message if available. + None if transaction not yet confirmed or pending. + """ + try: + # Check if Gateway is available + if not await self.gateway_client.ping(): + logger.warning("Gateway not available for transaction polling") + return None + + # Reconstruct network_id from chain and network + network_id = f"{chain}-{network}" + + # Poll transaction status from Gateway + result = await self.gateway_client.poll_transaction( + network_id=network_id, + tx_hash=tx_hash + ) + + # Check if we got a valid response + if result is None or not isinstance(result, dict): + logger.warning(f"Invalid response from Gateway for transaction {tx_hash} on {network_id}: {result}") + return None + + logger.debug(f"Polled transaction {tx_hash} on {network_id}: txStatus={result.get('txStatus')}") + + # Parse the response with defensive checks + tx_status = result.get("txStatus") + tx_data = result.get("txData") or {} + meta = tx_data.get("meta") if isinstance(tx_data, dict) else {} + error = meta.get("err") if isinstance(meta, dict) else None + + # Determine gas token based on chain + gas_token = { + "solana": "SOL", + "ethereum": "ETH", + "arbitrum": "ETH", + "optimism": "ETH", + "polygon": "MATIC", + "avalanche": "AVAX" + }.get(chain, "UNKNOWN") + + # Transaction is confirmed if txStatus == 1 and no error + if tx_status == 1 and error is None: + return { + "status": "CONFIRMED", + "gas_fee": result.get("fee", 0), + "gas_token": gas_token, + "error_message": None + } + + # Transaction failed if there's an error + if error is not None: + error_msg = str(error) if error else "Transaction failed on-chain" + return { + "status": "FAILED", + "gas_fee": result.get("fee", 0), + "gas_token": gas_token, + "error_message": error_msg + } + + # Transaction still pending (txStatus == 0 or not finalized) + return None + + except Exception as e: + logger.error(f"Error checking transaction status for {tx_hash}: {e}") + return None + + async def poll_transaction_once(self, tx_hash: str, network_id: str, wallet_address: Optional[str] = None) -> Optional[Dict]: + """ + Poll a specific transaction once (useful for immediate status checks). + + Args: + tx_hash: Transaction hash + network_id: Network ID in format 'chain-network' (e.g., 'solana-mainnet-beta') + wallet_address: Optional wallet address for verification + + Returns: + Transaction status dict or None if pending + """ + parts = network_id.split('-', 1) + if len(parts) != 2: + logger.error(f"Invalid network format: {network_id}") + return None + + chain, network = parts + return await self._check_transaction_status(chain, network, tx_hash) + + # ============================================ + # Position State Polling & Discovery + # ============================================ + + # Supported CLMM connectors and their default networks + SUPPORTED_CLMM_CONFIGS = [ + {"connector": "meteora", "chain": "solana", "network": "mainnet-beta"}, + # Add more connectors as they become supported: + # {"connector": "raydium", "chain": "solana", "network": "mainnet-beta"}, + # {"connector": "uniswap", "chain": "ethereum", "network": "mainnet"}, + ] + + async def _position_poll_loop(self): + """Position state polling loop (runs less frequently).""" + while self._running: + try: + # Check if it's time to poll positions + now = datetime.now(timezone.utc) + if self._last_position_poll is None or \ + (now - self._last_position_poll).total_seconds() >= self.position_poll_interval: + await self._poll_and_discover_positions() + self._last_position_poll = now + + # Sleep for a short time to avoid busy waiting + await asyncio.sleep(10) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in position poll loop: {e}", exc_info=True) + await asyncio.sleep(10) + + async def _poll_and_discover_positions(self): + """ + Main position polling method that: + 1. Discovers new positions from Gateway (created via UI or other means) + 2. Updates all open positions with latest state + """ + try: + # Check if Gateway is available + if not await self.gateway_client.ping(): + logger.debug("Gateway not available, skipping position polling") + return + + # Step 1: Discover new positions from Gateway + discovered_count = await self._discover_positions_from_gateway() + if discovered_count > 0: + logger.info(f"Discovered {discovered_count} new positions from Gateway") + + # Step 2: Update all open positions + await self._update_all_open_positions() + + except Exception as e: + logger.error(f"Error in position poll and discovery: {e}", exc_info=True) + + async def _discover_positions_from_gateway(self) -> int: + """ + Discover positions from Gateway that aren't tracked in the database, + and reopen positions that were incorrectly marked as closed. + + This allows tracking positions created directly via UI or other means, + not just those created through the API. + + Also corrects data inconsistencies where a position was marked CLOSED + in the database but is still OPEN on-chain (e.g., due to a failed close + transaction). + + Returns: + Number of newly discovered + reopened positions + """ + discovered_count = 0 + reopened_count = 0 + + try: + # Get all wallet addresses for supported chains + wallet_addresses_by_chain = await self.gateway_client.get_all_wallet_addresses() + if not wallet_addresses_by_chain: + logger.debug("No wallets configured in Gateway, skipping position discovery") + return 0 + + # Get existing position addresses from database (for quick existence check) + async with self.db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + # Get OPEN positions (to skip - already tracked correctly) + open_positions = await clmm_repo.get_position_addresses_set(status="OPEN") + # Get CLOSED positions (to potentially reopen if still on-chain) + closed_positions = await clmm_repo.get_position_addresses_set(status="CLOSED") + + # Poll each supported connector/chain/wallet combination + for config in self.SUPPORTED_CLMM_CONFIGS: + connector = config["connector"] + chain = config["chain"] + network = config["network"] + + # Get wallet addresses for this chain + wallet_addresses = wallet_addresses_by_chain.get(chain, []) + if not wallet_addresses: + continue + + for wallet_address in wallet_addresses: + try: + # Fetch ALL positions for this wallet (no pool filter) + chain_network = f"{chain}-{network}" + gateway_positions = await self.gateway_client.clmm_positions_owned( + connector=connector, + chain_network=chain_network, + wallet_address=wallet_address, + pool_address=None # Get all positions across all pools + ) + + if not gateway_positions or not isinstance(gateway_positions, list): + continue + + # Process each position + for pos_data in gateway_positions: + position_address = pos_data.get("address") + if not position_address: + continue + + # Skip if already tracked as OPEN + if position_address in open_positions: + continue + + # Check if position was incorrectly marked as CLOSED + if position_address in closed_positions: + # Position exists on-chain but is CLOSED in DB → reopen it + async with self.db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + reopened = await clmm_repo.reopen_position(position_address) + if reopened: + reopened_count += 1 + # Move from closed to open set for this run + closed_positions.discard(position_address) + open_positions.add(position_address) + logger.warning(f"Reopened position {position_address} - " + f"was CLOSED in DB but still exists on-chain") + continue + + # Create new position in database + new_position = await self._create_discovered_position( + pos_data=pos_data, + connector=connector, + chain=chain, + network=network, + wallet_address=wallet_address + ) + + if new_position: + discovered_count += 1 + open_positions.add(position_address) + logger.info(f"Discovered new position: {position_address} " + f"(pool: {pos_data.get('poolAddress', 'unknown')[:16]}...)") + + except Exception as e: + logger.warning(f"Error discovering positions for {connector}/{chain}/{wallet_address}: {e}") + continue + + except Exception as e: + logger.error(f"Error in position discovery: {e}", exc_info=True) + + if reopened_count > 0: + logger.info(f"Position discovery complete: {discovered_count} new, {reopened_count} reopened") + + return discovered_count + reopened_count + + async def _create_discovered_position( + self, + pos_data: Dict, + connector: str, + chain: str, + network: str, + wallet_address: str + ) -> Optional[GatewayCLMMPosition]: + """ + Create a database record for a discovered position. + + These positions were created externally (e.g., via UI) and are being + discovered by the poller. + """ + try: + position_address = pos_data.get("address") + pool_address = pos_data.get("poolAddress", "") + + # Extract token addresses + base_token_address = pos_data.get("baseTokenAddress", "") + quote_token_address = pos_data.get("quoteTokenAddress", "") + + # Use full addresses as tokens (consistent with API-created positions) + base_token = base_token_address if base_token_address else "UNKNOWN" + quote_token = quote_token_address if quote_token_address else "UNKNOWN" + trading_pair = f"{base_token}-{quote_token}" + + # Extract price data + current_price = float(pos_data.get("price", 0)) + lower_price = float(pos_data.get("lowerPrice", 0)) + upper_price = float(pos_data.get("upperPrice", 0)) + + # Extract liquidity amounts + base_token_amount = float(pos_data.get("baseTokenAmount", 0)) + quote_token_amount = float(pos_data.get("quoteTokenAmount", 0)) + + # Extract fee data + base_fee_pending = float(pos_data.get("baseFeeAmount", 0)) + quote_fee_pending = float(pos_data.get("quoteFeeAmount", 0)) + + # Extract bin IDs (for Meteora) + lower_bin_id = pos_data.get("lowerBinId") + upper_bin_id = pos_data.get("upperBinId") + + # Calculate in_range status + in_range = "UNKNOWN" + if current_price > 0 and lower_price > 0 and upper_price > 0: + if lower_price <= current_price <= upper_price: + in_range = "IN_RANGE" + else: + in_range = "OUT_OF_RANGE" + + # Calculate percentage: (upper_price - lower_price) / lower_price + percentage = None + if lower_price > 0: + percentage = (upper_price - lower_price) / lower_price + + # Network in unified format + network_id = f"{chain}-{network}" + + # Create position in database + async with self.db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + + position_data = { + "position_address": position_address, + "pool_address": pool_address, + "network": network_id, + "connector": connector, + "wallet_address": wallet_address, + "trading_pair": trading_pair, + "base_token": base_token, + "quote_token": quote_token, + "status": "OPEN", + "lower_price": lower_price, + "upper_price": upper_price, + "lower_bin_id": lower_bin_id, + "upper_bin_id": upper_bin_id, + "entry_price": current_price, # Best available estimate + "current_price": current_price, + "percentage": percentage, + # For discovered positions, we don't know initial amounts + # Use current amounts as initial (best estimate) + "initial_base_token_amount": base_token_amount, + "initial_quote_token_amount": quote_token_amount, + "base_token_amount": base_token_amount, + "quote_token_amount": quote_token_amount, + "in_range": in_range, + "base_fee_pending": base_fee_pending, + "quote_fee_pending": quote_fee_pending, + "base_fee_collected": 0, + "quote_fee_collected": 0, + } + + position = await clmm_repo.create_position(position_data) + + # Create a DISCOVERED event to mark this position was auto-discovered + event_data = { + "position_id": position.id, + "transaction_hash": f"discovered_{position_address[:16]}", # Synthetic tx hash + "event_type": "DISCOVERED", + "base_token_amount": base_token_amount, + "quote_token_amount": quote_token_amount, + "status": "CONFIRMED" # No actual transaction to confirm + } + await clmm_repo.create_event(event_data) + + return position + + except Exception as e: + logger.error(f"Error creating discovered position {pos_data.get('address')}: {e}", exc_info=True) + return None + + async def _update_all_open_positions(self): + """Update state for all open positions from Gateway.""" + try: + async with self.db_manager.get_session_context() as session: + clmm_repo = GatewayCLMMRepository(session) + + # Get all open positions + open_positions = await clmm_repo.get_open_positions() + if not open_positions: + logger.debug("No open CLMM positions to update") + return + + logger.info(f"Updating {len(open_positions)} open CLMM positions") + + # Update each position within the same session + for position in open_positions: + try: + await self._refresh_position_state(position, clmm_repo) + except Exception as e: + logger.warning(f"Failed to update position {position.position_address}: {e}") + continue + + except Exception as e: + logger.error(f"Error updating open positions: {e}", exc_info=True) + + # Legacy method name for backwards compatibility + async def _poll_open_positions(self): + """Poll all open CLMM positions and update their state. (Legacy wrapper)""" + await self._poll_and_discover_positions() + + async def _refresh_position_state(self, position: GatewayCLMMPosition, clmm_repo: GatewayCLMMRepository): + """ + Refresh a single position's state from Gateway. + + Updates: + - in_range status + - liquidity amounts + - pending fees + - position status (if closed externally) + """ + try: + # Validate position has required fields + if not position.position_address: + logger.error(f"Position ID {position.id} has no position_address, skipping refresh") + return + if not position.wallet_address: + logger.error(f"Position {position.position_address} has no wallet_address, skipping refresh") + return + if not position.connector: + logger.error(f"Position {position.position_address} has no connector, skipping refresh") + return + if not position.network: + logger.error(f"Position {position.position_address} has no network, skipping refresh") + return + + # Get individual position info from Gateway (includes pending fees) + try: + result = await self.gateway_client.clmm_position_info( + connector=position.connector, + chain_network=position.network, # position.network is already in 'chain-network' format + position_address=position.position_address + ) + + # Check for Gateway errors + if result is None: + logger.debug(f"Gateway connection error for position {position.position_address}, skipping update") + return + + if not isinstance(result, dict): + logger.warning(f"Unexpected response type for position {position.position_address}: {type(result)}") + return + + # Check if Gateway returned an error response + if "error" in result: + status_code = result.get("status") + + # Gateway returns 500 instead of 404 when position doesn't exist (closed) + # Treat any error (404 or 500) on position-info as "position closed" + if status_code in (404, 500): + logger.info(f"Position {position.position_address} not found on Gateway (status: {status_code}), marking as CLOSED") + await clmm_repo.close_position(position.position_address) + return + # Other errors → skip update, don't close + logger.debug(f"Gateway error for position {position.position_address}: {result.get('error')} (status: {status_code})") + return + + # Validate response has required fields + if "address" not in result: + logger.warning(f"Invalid response for position {position.position_address}, missing 'address' field") + return + + except Exception as e: + logger.warning(f"Error fetching position {position.position_address} from Gateway: {e}") + return + + # Extract current state + current_price = Decimal(str(result.get("price", 0))) + lower_price = Decimal(str(result.get("lowerPrice", 0))) if result.get("lowerPrice") else Decimal("0") + upper_price = Decimal(str(result.get("upperPrice", 0))) if result.get("upperPrice") else Decimal("0") + + # Calculate in_range status + in_range = "UNKNOWN" + if current_price > 0 and lower_price > 0 and upper_price > 0: + if lower_price <= current_price <= upper_price: + in_range = "IN_RANGE" + else: + in_range = "OUT_OF_RANGE" + + # Extract token amounts - validate they exist in response + base_amount_raw = result.get("baseTokenAmount") + quote_amount_raw = result.get("quoteTokenAmount") + + # If amounts are missing or None, skip update (don't assume zero) + if base_amount_raw is None or quote_amount_raw is None: + logger.warning(f"Position {position.position_address} missing token amounts in response, skipping update") + return + + base_token_amount = Decimal(str(base_amount_raw)) + quote_token_amount = Decimal(str(quote_amount_raw)) + + # If Gateway confirms zero liquidity, position was closed externally + if base_token_amount == 0 and quote_token_amount == 0: + logger.info(f"Position {position.position_address} has zero liquidity, marking as CLOSED") + await clmm_repo.close_position(position.position_address) + return + + # Update liquidity amounts, in_range status, and current price + await clmm_repo.update_position_liquidity( + position_address=position.position_address, + base_token_amount=base_token_amount, + quote_token_amount=quote_token_amount, + in_range=in_range, + current_price=current_price + ) + + # Update pending fees (always update to keep in sync with on-chain state) + base_fee_pending = Decimal(str(result.get("baseFeeAmount", 0))) + quote_fee_pending = Decimal(str(result.get("quoteFeeAmount", 0))) + + await clmm_repo.update_position_fees( + position_address=position.position_address, + base_fee_pending=base_fee_pending, + quote_fee_pending=quote_fee_pending + ) + + logger.debug(f"Refreshed position {position.position_address}: price={current_price}, in_range={in_range}, " + f"base={base_token_amount}, quote={quote_token_amount}, " + f"base_fee={base_fee_pending}, quote_fee={quote_fee_pending}") + + except Exception as e: + logger.error(f"Error refreshing position state {position.position_address}: {e}", exc_info=True) diff --git a/services/market_data_service.py b/services/market_data_service.py new file mode 100644 index 00000000..048e1c11 --- /dev/null +++ b/services/market_data_service.py @@ -0,0 +1,758 @@ +""" +Market Data Service - Centralized market data access with proper connector integration. + +This service provides access to market data (candles, order books, prices, trading rules) +using the UnifiedConnectorService to ensure proper connector usage. +""" +import asyncio +import time +import logging +from typing import Dict, Optional, List, Any, Tuple +from decimal import Decimal +from enum import Enum + +from hummingbot.core.rate_oracle.rate_oracle import RateOracle +from hummingbot.data_feed.candles_feed.data_types import CandlesConfig +from hummingbot.data_feed.candles_feed.candles_factory import CandlesFactory, UnsupportedConnectorException + + +logger = logging.getLogger(__name__) + + +class FeedType(Enum): + """Types of market data feeds that can be managed.""" + CANDLES = "candles" + ORDER_BOOK = "order_book" + TRADES = "trades" + TICKER = "ticker" + + +class MarketDataService: + """ + Centralized market data service using UnifiedConnectorService. + + This service manages: + - Candles feeds with automatic lifecycle management + - Order book access via UnifiedConnectorService + - Price and trading rules queries + - Feed cleanup for unused data streams + """ + + def __init__( + self, + connector_service: "UnifiedConnectorService", + rate_oracle: RateOracle, + cleanup_interval: int = 300, + feed_timeout: int = 600 + ): + """ + Initialize the MarketDataService. + + Args: + connector_service: UnifiedConnectorService for connector access + rate_oracle: RateOracle instance for price conversions + cleanup_interval: How often to run cleanup (seconds, default: 5 minutes) + feed_timeout: How long to keep unused feeds alive (seconds, default: 10 minutes) + """ + self._connector_service = connector_service + self._rate_oracle = rate_oracle + self._cleanup_interval = cleanup_interval + self._feed_timeout = feed_timeout + + # Candle feeds management + self._candle_feeds: Dict[str, Any] = {} + self._last_access_times: Dict[str, float] = {} + self._feed_configs: Dict[str, Tuple[FeedType, Any]] = {} + + # Background tasks + self._cleanup_task: Optional[asyncio.Task] = None + self._is_running = False + + logger.info("MarketDataService initialized") + + # ==================== Lifecycle ==================== + + def start(self): + """Start the market data service.""" + if not self._is_running: + self._is_running = True + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + self._rate_oracle.start() + logger.info( + f"MarketDataService started with cleanup_interval={self._cleanup_interval}s, " + f"feed_timeout={self._feed_timeout}s" + ) + + async def warmup_rate_oracle(self): + """Eagerly fetch prices so the oracle cache is populated before the first portfolio query.""" + try: + prices = await self._rate_oracle._source.get_prices(quote_token=self._rate_oracle.quote_token) + self._rate_oracle._prices.update(prices) + logger.info(f"RateOracle warmed up with {len(prices)} prices") + except Exception as e: + logger.warning(f"RateOracle warmup failed: {e}") + + def stop(self): + """Stop the market data service and cleanup all feeds.""" + self._is_running = False + + if self._cleanup_task: + self._cleanup_task.cancel() + self._cleanup_task = None + + # Stop all candle feeds + for feed_key, feed in self._candle_feeds.items(): + try: + feed.stop() + except Exception as e: + logger.error(f"Error stopping candle feed {feed_key}: {e}") + + self._candle_feeds.clear() + self._last_access_times.clear() + self._feed_configs.clear() + + logger.info("MarketDataService stopped") + + # ==================== Order Book Access ==================== + + async def initialize_order_book( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None, + timeout: float = 30.0 + ) -> bool: + """ + Initialize an order book for a trading pair. + + Uses the UnifiedConnectorService to get the best available connector + (prefers trading connectors which already have order book trackers running). + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair (e.g., "SOL-FDUSD") + account_name: Optional account name for trading connector preference + timeout: Timeout for waiting for order book to be ready + + Returns: + True if order book is ready, False otherwise + """ + return await self._connector_service.initialize_order_book( + connector_name=connector_name, + trading_pair=trading_pair, + account_name=account_name, + timeout=timeout + ) + + async def remove_trading_pair( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None + ) -> bool: + """ + Remove a trading pair from order book tracking. + + Cleans up order book resources for a trading pair that is no longer needed. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair to remove + account_name: Optional account name for trading connector preference + + Returns: + True if successfully removed, False otherwise + """ + # Clean up our local tracking for this feed + feed_key = self._generate_feed_key(FeedType.ORDER_BOOK, connector_name, trading_pair) + self._last_access_times.pop(feed_key, None) + self._feed_configs.pop(feed_key, None) + + return await self._connector_service.remove_trading_pair( + connector_name=connector_name, + trading_pair=trading_pair, + account_name=account_name + ) + + def get_order_book(self, connector_name: str, trading_pair: str, account_name: Optional[str] = None): + """ + Get order book for a trading pair. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + account_name: Optional account name for trading connector preference + + Returns: + OrderBook instance or None + """ + feed_key = self._generate_feed_key(FeedType.ORDER_BOOK, connector_name, trading_pair) + self._last_access_times[feed_key] = time.time() + self._feed_configs[feed_key] = (FeedType.ORDER_BOOK, (connector_name, trading_pair)) + + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if connector and hasattr(connector, 'order_book_tracker'): + tracker = connector.order_book_tracker + if tracker and trading_pair in tracker.order_books: + return tracker.order_books[trading_pair] + + logger.warning(f"No order book found for {connector_name}/{trading_pair}") + return None + + def get_order_book_snapshot( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None + ) -> Optional[Tuple]: + """ + Get order book snapshot (bids, asks DataFrames). + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + account_name: Optional account name for trading connector preference + + Returns: + Tuple of (bids_df, asks_df) or None + """ + order_book = self.get_order_book(connector_name, trading_pair, account_name) + if order_book: + try: + return order_book.snapshot + except Exception as e: + logger.error(f"Error getting order book snapshot: {e}") + return None + + async def get_order_book_data( + self, + connector_name: str, + trading_pair: str, + depth: int = 10, + account_name: Optional[str] = None + ) -> Dict: + """ + Get order book data as a dictionary. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + depth: Number of bid/ask levels to return + account_name: Optional account name for trading connector preference + + Returns: + Dictionary with bids, asks, and metadata + """ + try: + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if not connector: + return {"error": f"No connector available for {connector_name}"} + + # Try to get from existing order book tracker + if hasattr(connector, 'order_book_tracker') and connector.order_book_tracker: + tracker = connector.order_book_tracker + if trading_pair in tracker.order_books: + order_book = tracker.order_books[trading_pair] + snapshot = order_book.snapshot + + return { + "trading_pair": trading_pair, + "bids": snapshot[0].head(depth)[["price", "amount"]].values.tolist(), + "asks": snapshot[1].head(depth)[["price", "amount"]].values.tolist(), + "timestamp": time.time() + } + + # Fallback to getting fresh order book from data source + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + orderbook_ds = connector._orderbook_ds + order_book = await orderbook_ds.get_new_order_book(trading_pair) + snapshot = order_book.snapshot + + return { + "trading_pair": trading_pair, + "bids": snapshot[0].head(depth)[["price", "amount"]].values.tolist(), + "asks": snapshot[1].head(depth)[["price", "amount"]].values.tolist(), + "timestamp": time.time() + } + + return {"error": f"Order book not available for {connector_name}/{trading_pair}"} + + except Exception as e: + logger.error(f"Error getting order book data for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + async def get_order_book_query_result( + self, + connector_name: str, + trading_pair: str, + is_buy: bool, + account_name: Optional[str] = None, + **kwargs + ) -> Dict: + """ + Query order book for price/volume calculations. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + is_buy: True for buy side, False for sell side + account_name: Optional account name + **kwargs: Query parameters (volume, price, quote_volume, etc.) + + Returns: + Query result dictionary + """ + try: + current_time = time.time() + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if not connector: + return {"error": f"No connector available for {connector_name}"} + + # Get order book + order_book = None + if hasattr(connector, 'order_book_tracker') and connector.order_book_tracker: + tracker = connector.order_book_tracker + if trading_pair in tracker.order_books: + order_book = tracker.order_books[trading_pair] + + if not order_book and hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + order_book = await connector._orderbook_ds.get_new_order_book(trading_pair) + + if not order_book: + return {"error": f"No order book available for {connector_name}/{trading_pair}"} + + # Process query + if 'volume' in kwargs: + result = order_book.get_price_for_volume(is_buy, kwargs['volume']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_volume": kwargs['volume'], + "result_price": float(result.result_price) if result.result_price else None, + "result_volume": float(result.result_volume) if result.result_volume else None, + "timestamp": current_time + } + + elif 'price' in kwargs: + result = order_book.get_volume_for_price(is_buy, kwargs['price']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_price": kwargs['price'], + "result_volume": float(result.result_volume) if result.result_volume else None, + "result_price": float(result.result_price) if result.result_price else None, + "timestamp": current_time + } + + elif 'vwap_volume' in kwargs: + result = order_book.get_vwap_for_volume(is_buy, kwargs['vwap_volume']) + return { + "trading_pair": trading_pair, + "is_buy": is_buy, + "query_volume": kwargs['vwap_volume'], + "average_price": float(result.result_price) if result.result_price else None, + "result_volume": float(result.result_volume) if result.result_volume else None, + "timestamp": current_time + } + + else: + return {"error": "Invalid query parameters"} + + except Exception as e: + logger.error(f"Error in order book query for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + # ==================== Candles ==================== + + @staticmethod + def validate_connector(connector_name: str) -> None: + if connector_name not in CandlesFactory._candles_map: + raise UnsupportedConnectorException(connector_name) + + def get_candles_feed(self, config: CandlesConfig): + """ + Get or create a candles feed. + + Args: + config: CandlesConfig for the desired feed + + Returns: + Candle feed instance + """ + feed_key = self._generate_feed_key( + FeedType.CANDLES, config.connector, config.trading_pair, config.interval + ) + + self._last_access_times[feed_key] = time.time() + self._feed_configs[feed_key] = (FeedType.CANDLES, config) + + if feed_key not in self._candle_feeds: + self.validate_connector(config.connector) + feed = CandlesFactory.get_candle(config) + feed.start() + self._candle_feeds[feed_key] = feed + logger.info(f"Created candle feed: {feed_key}") + + return self._candle_feeds[feed_key] + + def get_candles_df( + self, + connector_name: str, + trading_pair: str, + interval: str, + max_records: int = 500 + ): + """ + Get candles dataframe. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + interval: Candle interval + max_records: Maximum number of records + + Returns: + Pandas DataFrame with candle data + """ + config = CandlesConfig( + connector=connector_name, + trading_pair=trading_pair, + interval=interval, + max_records=max_records + ) + + feed = self.get_candles_feed(config) + return feed.candles_df + + def stop_candle_feed(self, config: CandlesConfig): + """Stop a specific candle feed.""" + feed_key = self._generate_feed_key( + FeedType.CANDLES, config.connector, config.trading_pair, config.interval + ) + + if feed_key in self._candle_feeds: + try: + self._candle_feeds[feed_key].stop() + del self._candle_feeds[feed_key] + logger.info(f"Stopped candle feed: {feed_key}") + except Exception as e: + logger.error(f"Error stopping candle feed {feed_key}: {e}") + + # ==================== Prices ==================== + + async def get_prices( + self, + connector_name: str, + trading_pairs: List[str], + account_name: Optional[str] = None + ) -> Dict[str, float]: + """ + Get current prices for trading pairs. + + Args: + connector_name: Exchange connector name + trading_pairs: List of trading pairs + account_name: Optional account name for trading connector preference + + Returns: + Dictionary mapping trading pairs to prices + """ + try: + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if not connector: + return {"error": f"No connector available for {connector_name}"} + + prices = await connector.get_last_traded_prices(trading_pairs) + return {pair: float(price) for pair, price in prices.items()} + + except Exception as e: + logger.error(f"Error getting prices for {connector_name}: {e}") + return {"error": str(e)} + + def get_rate(self, base: str, quote: str = "USDT") -> Optional[Decimal]: + """ + Get exchange rate from rate oracle. + + Args: + base: Base currency + quote: Quote currency (default: USD) + + Returns: + Exchange rate or None + """ + try: + return self._rate_oracle.get_pair_rate(f"{base}-{quote}") + except Exception as e: + logger.debug(f"Rate not available for {base}-{quote}: {e}") + return None + + # ==================== Trading Rules ==================== + + async def get_trading_rules( + self, + connector_name: str, + trading_pairs: Optional[List[str]] = None, + account_name: Optional[str] = None + ) -> Dict[str, Dict]: + """ + Get trading rules for trading pairs. + + Args: + connector_name: Exchange connector name + trading_pairs: List of trading pairs (None for all) + account_name: Optional account name + + Returns: + Dictionary mapping trading pairs to their rules + """ + try: + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if not connector: + return {"error": f"No connector available for {connector_name}"} + + # Ensure trading rules are loaded + if not connector.trading_rules or len(connector.trading_rules) == 0: + await connector._update_trading_rules() + + result = {} + rules_to_process = trading_pairs if trading_pairs else connector.trading_rules.keys() + + for trading_pair in rules_to_process: + if trading_pair in connector.trading_rules: + rule = connector.trading_rules[trading_pair] + result[trading_pair] = { + "min_order_size": float(rule.min_order_size), + "max_order_size": float(rule.max_order_size) if rule.max_order_size else None, + "min_price_increment": float(rule.min_price_increment), + "min_base_amount_increment": float(rule.min_base_amount_increment), + "min_quote_amount_increment": float(rule.min_quote_amount_increment), + "min_notional_size": float(rule.min_notional_size), + "min_order_value": float(rule.min_order_value), + "max_price_significant_digits": float(rule.max_price_significant_digits), + "supports_limit_orders": rule.supports_limit_orders, + "supports_market_orders": rule.supports_market_orders, + "buy_order_collateral_token": rule.buy_order_collateral_token, + "sell_order_collateral_token": rule.sell_order_collateral_token, + } + elif trading_pairs: + result[trading_pair] = {"error": f"Trading pair {trading_pair} not found"} + + return result + + except Exception as e: + logger.error(f"Error getting trading rules for {connector_name}: {e}") + return {"error": str(e)} + + # ==================== Funding Info ==================== + + async def get_funding_info( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None + ) -> Dict: + """ + Get funding information for perpetual trading pairs. + + Args: + connector_name: Exchange connector name + trading_pair: Trading pair + account_name: Optional account name + + Returns: + Dictionary with funding information + """ + try: + connector = self._connector_service.get_best_connector_for_market( + connector_name, account_name + ) + + if not connector: + return {"error": f"No connector available for {connector_name}"} + + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + orderbook_ds = connector._orderbook_ds + funding_info = await orderbook_ds.get_funding_info(trading_pair) + + if funding_info: + return { + "trading_pair": trading_pair, + "funding_rate": float(funding_info.rate) if funding_info.rate else None, + "next_funding_time": float(funding_info.next_funding_utc_timestamp) if funding_info.next_funding_utc_timestamp else None, + "mark_price": float(funding_info.mark_price) if funding_info.mark_price else None, + "index_price": float(funding_info.index_price) if funding_info.index_price else None, + } + else: + return {"error": f"No funding info available for {trading_pair}"} + else: + return {"error": f"Funding info not supported for {connector_name}"} + + except Exception as e: + logger.error(f"Error getting funding info for {connector_name}/{trading_pair}: {e}") + return {"error": str(e)} + + # ==================== Feed Management ==================== + + def get_active_feeds_info(self) -> Dict[str, dict]: + """Get information about active feeds.""" + current_time = time.time() + result = {} + + for feed_key, last_access in self._last_access_times.items(): + feed_type, config = self._feed_configs.get(feed_key, (None, None)) + result[feed_key] = { + "feed_type": feed_type.value if feed_type else "unknown", + "last_access_time": last_access, + "seconds_since_access": current_time - last_access, + "will_expire_in": max(0, self._feed_timeout - (current_time - last_access)), + "config": str(config) + } + + return result + + def manually_cleanup_feed( + self, + feed_type: FeedType, + connector: str, + trading_pair: str, + interval: str = None + ): + """Manually cleanup a specific feed.""" + feed_key = self._generate_feed_key(feed_type, connector, trading_pair, interval) + + if feed_key in self._feed_configs: + try: + if feed_type == FeedType.CANDLES and feed_key in self._candle_feeds: + self._candle_feeds[feed_key].stop() + del self._candle_feeds[feed_key] + + del self._last_access_times[feed_key] + del self._feed_configs[feed_key] + logger.info(f"Manually cleaned up feed: {feed_key}") + except Exception as e: + logger.error(f"Error manually cleaning up feed {feed_key}: {e}") + else: + logger.warning(f"Feed not found for cleanup: {feed_key}") + + # ==================== Internal ==================== + + async def _cleanup_loop(self): + """Background task to cleanup unused feeds.""" + while self._is_running: + try: + await self._cleanup_unused_feeds() + await asyncio.sleep(self._cleanup_interval) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in cleanup loop: {e}", exc_info=True) + await asyncio.sleep(self._cleanup_interval) + + async def _cleanup_unused_feeds(self): + """Clean up feeds that haven't been accessed within timeout.""" + current_time = time.time() + feeds_to_remove = [] + + for feed_key, last_access_time in self._last_access_times.items(): + if current_time - last_access_time > self._feed_timeout: + feeds_to_remove.append(feed_key) + + for feed_key in feeds_to_remove: + try: + feed_type, config = self._feed_configs[feed_key] + + if feed_type == FeedType.CANDLES and feed_key in self._candle_feeds: + self._candle_feeds[feed_key].stop() + del self._candle_feeds[feed_key] + + del self._last_access_times[feed_key] + del self._feed_configs[feed_key] + + logger.info(f"Cleaned up unused {feed_type.value} feed: {feed_key}") + + except Exception as e: + logger.error(f"Error cleaning up feed {feed_key}: {e}", exc_info=True) + + if feeds_to_remove: + logger.info(f"Cleaned up {len(feeds_to_remove)} unused market data feeds") + + def _generate_feed_key( + self, + feed_type: FeedType, + connector: str, + trading_pair: str, + interval: str = None + ) -> str: + """Generate a unique key for a market data feed.""" + if interval: + return f"{feed_type.value}_{connector}_{trading_pair}_{interval}" + return f"{feed_type.value}_{connector}_{trading_pair}" + + # ==================== Properties ==================== + + @property + def rate_oracle(self) -> RateOracle: + """Get the rate oracle instance.""" + return self._rate_oracle + + @property + def connector_service(self) -> "UnifiedConnectorService": + """Get the connector service instance.""" + return self._connector_service + + # ==================== Order Book Tracker Diagnostics ==================== + + def get_order_book_tracker_diagnostics( + self, + connector_name: str, + account_name: Optional[str] = None + ) -> Dict: + """ + Get diagnostics for a connector's order book tracker. + + Args: + connector_name: Exchange connector name + account_name: Optional account name for trading connector preference + + Returns: + Dictionary with diagnostic information + """ + return self._connector_service.get_order_book_tracker_diagnostics( + connector_name=connector_name, + account_name=account_name + ) + + async def restart_order_book_tracker( + self, + connector_name: str, + account_name: Optional[str] = None + ) -> Dict: + """ + Restart the order book tracker for a connector. + + Args: + connector_name: Exchange connector name + account_name: Optional account name for trading connector preference + + Returns: + Dictionary with restart status + """ + return await self._connector_service.restart_order_book_tracker( + connector_name=connector_name, + account_name=account_name + ) diff --git a/services/orders_recorder.py b/services/orders_recorder.py new file mode 100644 index 00000000..04b41bce --- /dev/null +++ b/services/orders_recorder.py @@ -0,0 +1,463 @@ +import asyncio +import logging +import math +import time +from datetime import datetime +from decimal import Decimal, InvalidOperation +from typing import Any, Optional, Union + +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.event.event_forwarder import SourceInfoEventForwarder +from hummingbot.core.event.events import BuyOrderCreatedEvent, MarketEvent, OrderFilledEvent, SellOrderCreatedEvent, TradeType + +from database import AsyncDatabaseManager, OrderRepository, TradeRepository + +# Initialize logger +logger = logging.getLogger(__name__) + + +class OrdersRecorder: + """ + Custom orders recorder that mimics Hummingbot's MarketsRecorder functionality + but uses our AsyncDatabaseManager for storage. + """ + + def __init__(self, db_manager: AsyncDatabaseManager, account_name: str, connector_name: str): + self.db_manager = db_manager + self.account_name = account_name + self.connector_name = connector_name + self._connector: Optional[ConnectorBase] = None + + # Create event forwarders similar to MarketsRecorder + self._create_order_forwarder = SourceInfoEventForwarder(self._did_create_order) + self._fill_order_forwarder = SourceInfoEventForwarder(self._did_fill_order) + self._cancel_order_forwarder = SourceInfoEventForwarder(self._did_cancel_order) + self._fail_order_forwarder = SourceInfoEventForwarder(self._did_fail_order) + self._complete_order_forwarder = SourceInfoEventForwarder(self._did_complete_order) + + # Event pairs mapping events to forwarders + self._event_pairs = [ + (MarketEvent.BuyOrderCreated, self._create_order_forwarder), + (MarketEvent.SellOrderCreated, self._create_order_forwarder), + (MarketEvent.OrderFilled, self._fill_order_forwarder), + (MarketEvent.OrderCancelled, self._cancel_order_forwarder), + (MarketEvent.OrderFailure, self._fail_order_forwarder), + (MarketEvent.BuyOrderCompleted, self._complete_order_forwarder), + (MarketEvent.SellOrderCompleted, self._complete_order_forwarder), + ] + + def start(self, connector: ConnectorBase): + """Start recording orders for the given connector""" + # Idempotency guard: prevent double-registration of listeners + if self._connector is not None: + logger.warning(f"OrdersRecorder already started for {self.account_name}/{self.connector_name}, ignoring duplicate start") + return + + self._connector = connector + + # Subscribe to order events using the same pattern as MarketsRecorder + for event, forwarder in self._event_pairs: + connector.add_listener(event, forwarder) + logger.info(f"OrdersRecorder: Added listener for {event} with forwarder {forwarder}") + + # Debug: Check if listeners were actually added + if hasattr(connector, '_event_listeners'): + listeners = connector._event_listeners.get(event, []) + logger.info(f"OrdersRecorder: Event {event} now has {len(listeners)} listeners") + for i, listener in enumerate(listeners): + logger.info(f"OrdersRecorder: Listener {i}: {listener}") + + logger.info(f"OrdersRecorder started for {self.account_name}/{self.connector_name} with {len(self._event_pairs)} event listeners") + + # Debug: Print connector info + logger.info(f"OrdersRecorder: Connector type: {type(connector)}") + logger.info(f"OrdersRecorder: Connector name: {getattr(connector, 'name', 'unknown')}") + logger.info(f"OrdersRecorder: Connector ready: {getattr(connector, 'ready', 'unknown')}") + + # Test if forwarders are callable + for event, forwarder in self._event_pairs: + if callable(forwarder): + logger.info(f"OrdersRecorder: Forwarder for {event} is callable") + else: + logger.error(f"OrdersRecorder: Forwarder for {event} is NOT callable: {type(forwarder)}") + + async def stop(self): + """Stop recording orders""" + if self._connector: + # Remove all event listeners + for event, forwarder in self._event_pairs: + self._connector.remove_listener(event, forwarder) + + logger.info(f"OrdersRecorder stopped for {self.account_name}/{self.connector_name}") + + def _extract_error_message(self, event) -> str: + """Extract error message from various possible event attributes.""" + # Try different possible attribute names for error messages + for attr_name in ['error_message', 'message', 'reason', 'failure_reason', 'error']: + if hasattr(event, attr_name): + error_value = getattr(event, attr_name) + if error_value: + return str(error_value) + + # If no error message found, create a descriptive one + return f"Order failed: {event.__class__.__name__}" + + def _did_create_order(self, event_tag: int, market: ConnectorBase, event: Union[BuyOrderCreatedEvent, SellOrderCreatedEvent]): + """Handle order creation events - called by SourceInfoEventForwarder""" + logger.info(f"OrdersRecorder: _did_create_order called for order {getattr(event, 'order_id', 'unknown')}") + try: + # Determine trade type from event + trade_type = TradeType.BUY if isinstance(event, BuyOrderCreatedEvent) else TradeType.SELL + logger.info(f"OrdersRecorder: Creating task to handle order created - {trade_type} order") + asyncio.create_task(self._handle_order_created(event, trade_type)) + except Exception as e: + logger.error(f"Error in _did_create_order: {e}") + + def _did_fill_order(self, event_tag: int, market: ConnectorBase, event: OrderFilledEvent): + """Handle order fill events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_order_filled(event)) + except Exception as e: + logger.error(f"Error in _did_fill_order: {e}") + + def _did_cancel_order(self, event_tag: int, market: ConnectorBase, event: Any): + """Handle order cancel events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_order_cancelled(event)) + except Exception as e: + logger.error(f"Error in _did_cancel_order: {e}") + + def _did_fail_order(self, event_tag: int, market: ConnectorBase, event: Any): + """Handle order failure events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_order_failed(event)) + except Exception as e: + logger.error(f"Error in _did_fail_order: {e}") + + def _did_complete_order(self, event_tag: int, market: ConnectorBase, event: Any): + """Handle order completion events - called by SourceInfoEventForwarder""" + try: + asyncio.create_task(self._handle_order_completed(event)) + except Exception as e: + logger.error(f"Error in _did_complete_order: {e}") + + async def _handle_order_created(self, event: Union[BuyOrderCreatedEvent, SellOrderCreatedEvent], trade_type: TradeType): + """Handle order creation events""" + logger.info(f"OrdersRecorder: _handle_order_created started for order {event.order_id}") + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + + # Check if order already exists first + existing_order = await order_repo.get_order_by_client_id(event.order_id) + if existing_order: + logger.info(f"OrdersRecorder: Order {event.order_id} already exists with status {existing_order.status}") + + # Update exchange_order_id if we have it now and it was missing + exchange_order_id = getattr(event, 'exchange_order_id', None) + if exchange_order_id and not existing_order.exchange_order_id: + existing_order.exchange_order_id = exchange_order_id + logger.info(f"OrdersRecorder: Updated exchange_order_id to {exchange_order_id} for order {event.order_id}") + + # Update status if it's still in PENDING_CREATE or similar early state + if existing_order.status in ["PENDING_CREATE", "PENDING", "SUBMITTED"]: + existing_order.status = "OPEN" + logger.info(f"OrdersRecorder: Updated status to OPEN for order {event.order_id}") + + await session.flush() + return + + order_data = { + "client_order_id": event.order_id, + "account_name": self.account_name, + "connector_name": self.connector_name, + "trading_pair": event.trading_pair, + "trade_type": trade_type.name, + "order_type": event.type.name if hasattr(event, 'type') else 'UNKNOWN', + "amount": float(event.amount), + "price": float(event.price) if event.price else None, + "status": "OPEN", + "exchange_order_id": getattr(event, 'exchange_order_id', None) + } + await order_repo.create_order(order_data) + + logger.info(f"OrdersRecorder: Successfully recorded order created: {event.order_id}") + except Exception as e: + logger.error(f"OrdersRecorder: Error recording order created: {e}") + + async def _handle_order_filled(self, event: OrderFilledEvent): + """Handle order fill events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + trade_repo = TradeRepository(session) + + # Calculate fees + trade_fee_paid = 0 + trade_fee_currency = None + + if event.trade_fee: + try: + base_asset, quote_asset = event.trading_pair.split("-") + fee_in_quote = event.trade_fee.fee_amount_in_token( + trading_pair=event.trading_pair, + price=event.price, + order_amount=event.amount, + token=quote_asset, + exchange=self._connector, + ) + trade_fee_paid = float(fee_in_quote) + trade_fee_currency = quote_asset + except Exception as e: + logger.warning(f"Primary fee calculation failed: {e}. Attempting fallback...") + try: + base_asset, quote_asset = event.trading_pair.split("-") + fallback_fee = await self._calculate_fee_fallback( + trade_fee=event.trade_fee, + base_asset=base_asset, + quote_asset=quote_asset, + fill_price=event.price, + order_amount=event.amount, + ) + if fallback_fee is not None: + trade_fee_paid = float(fallback_fee) + trade_fee_currency = quote_asset + logger.info(f"Fallback fee calculation succeeded: {trade_fee_paid} {trade_fee_currency}") + else: + logger.error(f"Fallback fee calculation returned None for {event.order_id}") + trade_fee_paid = 0 + trade_fee_currency = None + except Exception as fallback_err: + logger.error(f"Fallback fee calculation also failed: {fallback_err}") + trade_fee_paid = 0 + trade_fee_currency = None + # Update order with fill information (handle potential NaN values like Hummingbot does) + try: + filled_amount = Decimal(str(event.amount)) + average_fill_price = Decimal(str(event.price)) + fee_paid_decimal = Decimal(str(trade_fee_paid)) if trade_fee_paid else None + + order = await order_repo.update_order_fill( + client_order_id=event.order_id, + filled_amount=filled_amount, + average_fill_price=average_fill_price, + fee_paid=fee_paid_decimal, + fee_currency=trade_fee_currency + ) + except (ValueError, InvalidOperation) as e: + logger.error(f"Error processing order fill for {event.order_id}: {e}, skipping update") + return + + # Create trade record using validated values + if order: + try: + # Validate all values before creating trade record + validated_timestamp = event.timestamp if event.timestamp and not math.isnan(event.timestamp) else time.time() + validated_fee = trade_fee_paid if trade_fee_paid and not math.isnan(trade_fee_paid) else 0 + + # Use exchange_trade_id if available (unique per fill), fallback to generated id + exchange_trade_id = getattr(event, 'exchange_trade_id', None) + if exchange_trade_id: + trade_id = f"{event.order_id}_{exchange_trade_id}" + else: + # Fallback: include amount to differentiate partial fills at same timestamp + trade_id = f"{event.order_id}_{validated_timestamp}_{float(filled_amount)}" + + trade_data = { + "order_id": order.id, + "trade_id": trade_id, + "timestamp": datetime.fromtimestamp(validated_timestamp), + "trading_pair": event.trading_pair, + "trade_type": event.trade_type.name, + "amount": float(filled_amount), # Use validated amount + "price": float(average_fill_price), # Use validated price + "fee_paid": validated_fee, + "fee_currency": trade_fee_currency + } + result = await trade_repo.create_trade(trade_data) + if result is None: + logger.debug(f"Trade {trade_id} already exists, skipping duplicate") + except (ValueError, TypeError) as e: + logger.error(f"Error creating trade record for {event.order_id}: {e}") + logger.error(f"Trade data that failed: timestamp={event.timestamp}, amount={event.amount}, price={event.price}, fee={trade_fee_paid}") + + logger.debug(f"Recorded order fill: {event.order_id} - {event.amount} @ {event.price}") + except Exception as e: + logger.error(f"Error recording order fill: {e}") + + async def _handle_order_cancelled(self, event: Any): + """Handle order cancellation events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + await order_repo.update_order_status( + client_order_id=event.order_id, + status="CANCELLED" + ) + + logger.debug(f"Recorded order cancelled: {event.order_id}") + except Exception as e: + logger.error(f"Error recording order cancellation: {e}") + + def _get_order_details_from_connector(self, order_id: str) -> Optional[dict]: + """Try to get order details from connector's tracked orders""" + try: + if self._connector and hasattr(self._connector, 'in_flight_orders'): + in_flight_order = self._connector.in_flight_orders.get(order_id) + if in_flight_order: + return { + "trading_pair": in_flight_order.trading_pair, + "trade_type": in_flight_order.trade_type.name, + "order_type": in_flight_order.order_type.name, + "amount": float(in_flight_order.amount), + "price": float(in_flight_order.price) if in_flight_order.price else None + } + except Exception as e: + logger.error(f"Error getting order details from connector: {e}") + return None + + async def _fetch_conversion_rate(self, from_token: str, to_token: str) -> Optional[Decimal]: + """Fetch the conversion rate between two tokens using the connector's REST API. + Tries direct pair first, then inverse pair.""" + if not self._connector: + return None + try: + direct_pair = f"{from_token}-{to_token}" + price = await asyncio.wait_for( + self._connector._get_last_traded_price(trading_pair=direct_pair), + timeout=5.0, + ) + if price and price > 0: + return Decimal(str(price)) + except Exception: + pass + try: + inverse_pair = f"{to_token}-{from_token}" + price = await asyncio.wait_for( + self._connector._get_last_traded_price(trading_pair=inverse_pair), + timeout=5.0, + ) + if price and price > 0: + return Decimal(1) / Decimal(str(price)) + except Exception: + pass + return None + + async def _calculate_fee_fallback( + self, + trade_fee, + base_asset: str, + quote_asset: str, + fill_price: Decimal, + order_amount: Decimal, + ) -> Optional[Decimal]: + """Manually compute the trade fee in quote asset when the primary method fails.""" + fee_amount = Decimal(0) + + # Handle percent component + if trade_fee.percent and trade_fee.percent != Decimal(0): + fee_amount += (fill_price * order_amount) * trade_fee.percent + + # Handle flat_fees component + for flat_fee in trade_fee.flat_fees: + if flat_fee.token == quote_asset: + fee_amount += flat_fee.amount + elif flat_fee.token == base_asset: + fee_amount += flat_fee.amount * fill_price + else: + rate = await self._fetch_conversion_rate(flat_fee.token, quote_asset) + if rate is not None: + fee_amount += flat_fee.amount * rate + else: + logger.error( + f"Could not fetch conversion rate for {flat_fee.token} -> {quote_asset}" + ) + return None + + return fee_amount + + async def _handle_order_failed(self, event: Any): + """Handle order failure events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + + # Check if order exists, if not try to get details from connector's tracked orders + existing_order = await order_repo.get_order_by_client_id(event.order_id) + if existing_order: + # Extract error message from various possible attributes + error_msg = self._extract_error_message(event) + + # Update existing order with failure status and error message + await order_repo.update_order_status( + client_order_id=event.order_id, + status="FAILED", + error_message=error_msg + ) + logger.info(f"Updated existing order {event.order_id} to FAILED status") + else: + # Try to get order details from connector's tracked orders + order_details = self._get_order_details_from_connector(event.order_id) + if order_details: + logger.info(f"Retrieved order details from connector for {event.order_id}: {order_details}") + + # Create order record as FAILED with available details + if order_details: + order_data = { + "client_order_id": event.order_id, + "account_name": self.account_name, + "connector_name": self.connector_name, + "trading_pair": order_details["trading_pair"], + "trade_type": order_details["trade_type"], + "order_type": order_details["order_type"], + "amount": order_details["amount"], + "price": order_details["price"], + "status": "FAILED", + "error_message": self._extract_error_message(event) + } + else: + # Fallback with minimal details + order_data = { + "client_order_id": event.order_id, + "account_name": self.account_name, + "connector_name": self.connector_name, + "trading_pair": "UNKNOWN", + "trade_type": "UNKNOWN", + "order_type": "UNKNOWN", + "amount": 0.0, + "price": None, + "status": "FAILED", + "error_message": self._extract_error_message(event) + } + + try: + await order_repo.create_order(order_data) + logger.info(f"Created failed order record for {event.order_id}") + except Exception as create_error: + # If creation fails due to duplicate key, try to update existing order + if "duplicate key" in str(create_error).lower() or "unique constraint" in str(create_error).lower(): + logger.info(f"Order {event.order_id} already exists, updating status to FAILED") + await order_repo.update_order_status( + client_order_id=event.order_id, + status="FAILED", + error_message=self._extract_error_message(event) + ) + else: + raise create_error + + except Exception as e: + logger.error(f"Error recording order failure: {e}") + + async def _handle_order_completed(self, event: Any): + """Handle order completion events""" + try: + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + order = await order_repo.get_order_by_client_id(event.order_id) + if order: + order.status = "FILLED" + order.exchange_order_id = getattr(event, 'exchange_order_id', None) + + logger.debug(f"Recorded order completed: {event.order_id}") + except Exception as e: + logger.error(f"Error recording order completion: {e}") \ No newline at end of file diff --git a/services/trading_service.py b/services/trading_service.py new file mode 100644 index 00000000..cb35b1f5 --- /dev/null +++ b/services/trading_service.py @@ -0,0 +1,634 @@ +""" +Trading Service - Centralized trading operations with executor-compatible interface. + +This service provides trading operations (buy, sell, cancel) using the +UnifiedConnectorService for connector management. +""" +import logging +import time +from decimal import Decimal +from typing import Dict, List, Optional, Set, TYPE_CHECKING + +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.core.data_type.common import OrderType, TradeType, PositionAction + +if TYPE_CHECKING: + from services.unified_connector_service import UnifiedConnectorService + from services.market_data_service import MarketDataService + + +logger = logging.getLogger(__name__) + + +class AccountTradingInterface: + """ + ScriptStrategyBase-compatible interface for executor trading. + + This class provides the exact interface that Hummingbot executors expect + from a strategy object, backed by UnifiedConnectorService. + + Executors use the following interface from strategy: + - current_timestamp: float property + - buy(connector_name, trading_pair, amount, order_type, price, position_action) -> str + - sell(connector_name, trading_pair, amount, order_type, price, position_action) -> str + - cancel(connector_name, trading_pair, order_id) -> str + - get_active_orders(connector_name) -> List + + ExecutorBase also accesses: + - connectors: Dict[str, ConnectorBase] (accessed directly in ExecutorBase.__init__) + """ + + def __init__( + self, + connector_service: "UnifiedConnectorService", + market_data_service: "MarketDataService", + account_name: str + ): + """ + Initialize AccountTradingInterface. + + Args: + connector_service: UnifiedConnectorService for connector access + market_data_service: MarketDataService for order book operations + account_name: Account to use for connectors + """ + self._connector_service = connector_service + self._market_data_service = market_data_service + self._account_name = account_name + + # Track active markets (connector_name -> set of trading_pairs) + self._markets: Dict[str, Set[str]] = {} + + # Timestamp tracking + self._current_timestamp: float = time.time() + + logger.info(f"AccountTradingInterface created for account: {account_name}") + + @property + def account_name(self) -> str: + """Return the account name for this trading interface.""" + return self._account_name + + @property + def connectors(self) -> Dict[str, ConnectorBase]: + """ + Return connectors for this account from the UnifiedConnectorService. + + This returns the actual connectors that are already initialized and running. + """ + return self._connector_service.get_account_connectors(self._account_name) + + @property + def markets(self) -> Dict[str, Set[str]]: + """Return active markets configuration.""" + return self._markets + + @property + def current_timestamp(self) -> float: + """Return current timestamp (updated by control loop).""" + return self._current_timestamp + + def update_timestamp(self): + """Update the current timestamp. Called by ExecutorService control loop.""" + self._current_timestamp = time.time() + + async def ensure_connector(self, connector_name: str) -> ConnectorBase: + """ + Ensure connector is loaded and available. + + Args: + connector_name: Name of the connector + + Returns: + The connector instance + """ + return await self._connector_service.get_trading_connector( + self._account_name, + connector_name + ) + + async def add_market( + self, + connector_name: str, + trading_pair: str, + order_book_timeout: float = 30.0 + ): + """ + Add a trading pair to active markets with full order book support. + + This method ensures: + 1. Connector is loaded + 2. Order book is initialized and has valid data + 3. Rate sources are initialized for price feeds + + Args: + connector_name: Name of the connector + trading_pair: Trading pair to add + order_book_timeout: Timeout in seconds to wait for order book data + """ + await self.ensure_connector(connector_name) + + if connector_name not in self._markets: + self._markets[connector_name] = set() + + # Check if already tracking this pair AND order book is ready + if trading_pair in self._markets[connector_name]: + # Verify order book actually has data before returning early + connector = self.connectors.get(connector_name) + if connector and hasattr(connector, 'order_book_tracker'): + tracker = connector.order_book_tracker + if trading_pair in tracker.order_books: + try: + ob = tracker.order_books[trading_pair] + bids, asks = ob.snapshot + if len(bids) > 0 and len(asks) > 0: + logger.debug(f"Market {connector_name}/{trading_pair} already active with valid order book") + return + except Exception: + pass + # Order book not ready, need to re-initialize + logger.info(f"Market {connector_name}/{trading_pair} tracked but order book not ready, re-initializing") + + self._markets[connector_name].add(trading_pair) + + # Get connector from our account's connectors + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not available. Check credentials.") + + # Initialize order book via MarketDataService (uses best available connector) + logger.info(f"Initializing order book for {connector_name}/{trading_pair}") + success = await self._market_data_service.initialize_order_book( + connector_name=connector_name, + trading_pair=trading_pair, + account_name=self._account_name, + timeout=order_book_timeout + ) + + if not success: + raise ValueError(f"Failed to initialize order book for {connector_name}/{trading_pair}") + + logger.info(f"Order book initialized successfully for {connector_name}/{trading_pair}") + + # Register trading pair with connector + self._register_trading_pair_with_connector(connector, trading_pair) + + logger.info(f"Market {connector_name}/{trading_pair} added to trading interface") + + async def remove_market( + self, + connector_name: str, + trading_pair: str, + remove_order_book: bool = True + ): + """ + Remove a trading pair from active markets. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair to remove + remove_order_book: Whether to remove the order book (default True) + """ + if connector_name not in self._markets: + return + + self._markets[connector_name].discard(trading_pair) + if not self._markets[connector_name]: + del self._markets[connector_name] + + # Remove order book via MarketDataService + if remove_order_book: + try: + await self._market_data_service.remove_trading_pair( + connector_name=connector_name, + trading_pair=trading_pair, + account_name=self._account_name + ) + except Exception as e: + logger.warning(f"Failed to remove order book for {connector_name}/{trading_pair}: {e}") + + logger.info(f"Removed market {connector_name}/{trading_pair}") + + def _register_trading_pair_with_connector( + self, + connector: ConnectorBase, + trading_pair: str + ): + """ + Register a trading pair with the connector's internal structures. + + Args: + connector: The connector instance (ExchangePyBase) + trading_pair: Trading pair to register + """ + if trading_pair not in connector._trading_pairs: + connector._trading_pairs.append(trading_pair) + logger.debug(f"Registered {trading_pair} with connector {type(connector).__name__}") + + # ======================================== + # ScriptStrategyBase-compatible methods + # These are called by executors via self._strategy.method() + # ======================================== + + def buy( + self, + connector_name: str, + trading_pair: str, + amount: Decimal, + order_type: OrderType, + price: Decimal = Decimal("NaN"), + position_action: PositionAction = PositionAction.NIL + ) -> str: + """ + Place a buy order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + amount: Order amount in base currency + order_type: Type of order (LIMIT, MARKET, etc.) + price: Order price (for limit orders) + position_action: Position action for perpetuals + + Returns: + Client order ID + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + connector._set_current_timestamp(time.time()) + + return connector.buy( + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price, + position_action=position_action + ) + + def sell( + self, + connector_name: str, + trading_pair: str, + amount: Decimal, + order_type: OrderType, + price: Decimal = Decimal("NaN"), + position_action: PositionAction = PositionAction.NIL + ) -> str: + """ + Place a sell order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + amount: Order amount in base currency + order_type: Type of order (LIMIT, MARKET, etc.) + price: Order price (for limit orders) + position_action: Position action for perpetuals + + Returns: + Client order ID + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + connector._set_current_timestamp(time.time()) + + return connector.sell( + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price, + position_action=position_action + ) + + def cancel( + self, + connector_name: str, + trading_pair: str, + order_id: str + ) -> str: + """ + Cancel an order. + + Args: + connector_name: Name of the connector + trading_pair: Trading pair + order_id: Client order ID to cancel + + Returns: + Client order ID that was cancelled + """ + connector = self.connectors.get(connector_name) + if not connector: + raise ValueError(f"Connector {connector_name} not loaded. Call ensure_connector first.") + + return connector.cancel(trading_pair=trading_pair, client_order_id=order_id) + + def get_active_orders(self, connector_name: str) -> List: + """ + Get active orders for a connector. + + Args: + connector_name: Name of the connector + + Returns: + List of active in-flight orders + """ + connector = self.connectors.get(connector_name) + if not connector: + return [] + return list(connector.in_flight_orders.values()) + + # ======================================== + # Additional helper methods + # ======================================== + + def get_connector(self, connector_name: str) -> Optional[ConnectorBase]: + """ + Get a connector by name. + + Args: + connector_name: Name of the connector + + Returns: + The connector instance or None if not loaded + """ + return self.connectors.get(connector_name) + + def is_connector_loaded(self, connector_name: str) -> bool: + """ + Check if a connector is loaded. + + Args: + connector_name: Name of the connector + + Returns: + True if connector is loaded + """ + return connector_name in self.connectors + + def get_all_trading_pairs(self) -> Dict[str, Set[str]]: + """ + Get all active trading pairs by connector. + + Returns: + Dictionary mapping connector names to sets of trading pairs + """ + return {k: v.copy() for k, v in self._markets.items()} + + async def cleanup(self): + """ + Cleanup resources. Called when shutting down. + """ + self._markets.clear() + logger.info(f"AccountTradingInterface cleanup completed for account {self._account_name}") + + +class TradingService: + """ + Centralized trading service using UnifiedConnectorService. + + This service manages: + - Trading interfaces for each account (executor-compatible) + - Order placement and cancellation + - Position management for perpetuals + """ + + def __init__( + self, + connector_service: "UnifiedConnectorService", + market_data_service: "MarketDataService" + ): + """ + Initialize the TradingService. + + Args: + connector_service: UnifiedConnectorService for connector access + market_data_service: MarketDataService for order book operations + """ + self._connector_service = connector_service + self._market_data_service = market_data_service + + # Trading interfaces per account (for executor use) + self._trading_interfaces: Dict[str, AccountTradingInterface] = {} + + logger.info("TradingService initialized") + + # ==================== Trading Interface ==================== + + def get_trading_interface(self, account_name: str) -> AccountTradingInterface: + """ + Get or create a trading interface for the specified account. + + This interface provides ScriptStrategyBase-compatible methods + that executors can use for trading operations. + + Args: + account_name: Account to get trading interface for + + Returns: + AccountTradingInterface instance for the account + """ + if account_name not in self._trading_interfaces: + self._trading_interfaces[account_name] = AccountTradingInterface( + connector_service=self._connector_service, + market_data_service=self._market_data_service, + account_name=account_name + ) + return self._trading_interfaces[account_name] + + def get_all_trading_interfaces(self) -> Dict[str, AccountTradingInterface]: + """Get all active trading interfaces.""" + return self._trading_interfaces.copy() + + # ==================== Direct Trading Operations ==================== + + async def place_order( + self, + account_name: str, + connector_name: str, + trading_pair: str, + trade_type: TradeType, + amount: Decimal, + order_type: OrderType, + price: Optional[Decimal] = None, + position_action: PositionAction = PositionAction.NIL + ) -> str: + """ + Place an order on an exchange. + + Args: + account_name: Account to use + connector_name: Exchange connector name + trading_pair: Trading pair + trade_type: BUY or SELL + amount: Order amount + order_type: LIMIT, MARKET, etc. + price: Order price (required for LIMIT orders) + position_action: Position action for perpetuals + + Returns: + Client order ID + """ + interface = self.get_trading_interface(account_name) + await interface.ensure_connector(connector_name) + + if trade_type == TradeType.BUY: + return interface.buy( + connector_name=connector_name, + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price if price else Decimal("NaN"), + position_action=position_action + ) + else: + return interface.sell( + connector_name=connector_name, + trading_pair=trading_pair, + amount=amount, + order_type=order_type, + price=price if price else Decimal("NaN"), + position_action=position_action + ) + + async def cancel_order( + self, + account_name: str, + connector_name: str, + trading_pair: str, + order_id: str + ) -> str: + """ + Cancel an order. + + Args: + account_name: Account name + connector_name: Exchange connector name + trading_pair: Trading pair + order_id: Client order ID to cancel + + Returns: + Client order ID that was cancelled + """ + interface = self.get_trading_interface(account_name) + return interface.cancel(connector_name, trading_pair, order_id) + + def get_active_orders( + self, + account_name: str, + connector_name: str + ) -> List: + """ + Get active orders for an account/connector. + + Args: + account_name: Account name + connector_name: Exchange connector name + + Returns: + List of active orders + """ + interface = self.get_trading_interface(account_name) + return interface.get_active_orders(connector_name) + + # ==================== Position Management ==================== + + async def get_positions( + self, + account_name: str, + connector_name: str + ) -> Dict: + """ + Get positions for a perpetual connector. + + Args: + account_name: Account name + connector_name: Exchange connector name + + Returns: + Dictionary of positions + """ + connector = await self._connector_service.get_trading_connector( + account_name, connector_name + ) + + if hasattr(connector, 'account_positions'): + return { + str(pos.trading_pair): { + "trading_pair": pos.trading_pair, + "position_side": pos.position_side.name, + "unrealized_pnl": float(pos.unrealized_pnl), + "entry_price": float(pos.entry_price), + "amount": float(pos.amount), + "leverage": pos.leverage + } + for pos in connector.account_positions.values() + } + return {} + + async def set_leverage( + self, + account_name: str, + connector_name: str, + trading_pair: str, + leverage: int + ) -> bool: + """ + Set leverage for a trading pair on a perpetual connector. + + Args: + account_name: Account name + connector_name: Exchange connector name + trading_pair: Trading pair + leverage: Leverage value + + Returns: + True if successful + """ + connector = await self._connector_service.get_trading_connector( + account_name, connector_name + ) + + if hasattr(connector, 'set_leverage'): + try: + await connector.set_leverage(trading_pair, leverage) + logger.info(f"Set leverage to {leverage}x for {trading_pair} on {connector_name}") + return True + except Exception as e: + logger.error(f"Error setting leverage: {e}") + return False + return False + + # ==================== Lifecycle ==================== + + async def stop(self): + """Stop all trading interfaces and cleanup resources.""" + logger.info("Stopping TradingService...") + + for account_name, interface in self._trading_interfaces.items(): + try: + await interface.cleanup() + except Exception as e: + logger.error(f"Error cleaning up interface for {account_name}: {e}") + + self._trading_interfaces.clear() + logger.info("TradingService stopped") + + def update_all_timestamps(self): + """Update timestamps for all trading interfaces. Called by executor control loop.""" + for interface in self._trading_interfaces.values(): + interface.update_timestamp() + + # ==================== Properties ==================== + + @property + def connector_service(self) -> "UnifiedConnectorService": + """Get the connector service instance.""" + return self._connector_service + + @property + def market_data_service(self) -> "MarketDataService": + """Get the market data service instance.""" + return self._market_data_service diff --git a/services/unified_connector_service.py b/services/unified_connector_service.py new file mode 100644 index 00000000..9d697f0a --- /dev/null +++ b/services/unified_connector_service.py @@ -0,0 +1,1375 @@ +""" +UnifiedConnectorService - Single source of truth for all connector instances. + +This service consolidates connector management from: +- ConnectorManager (trading connectors) +- MarketDataProvider._non_trading_connectors (data-only connectors) + +Key features: +- Trading connectors: authenticated, per-account, with order tracking +- Data connectors: non-authenticated, shared, for public market data +- get_best_connector_for_market(): prefers trading connector (has order book tracker) +""" +import asyncio +import logging +import time +from decimal import Decimal +from typing import Dict, List, Optional + +from hummingbot.client.config.config_crypt import ETHKeyFileSecretManger +from hummingbot.client.config.config_helpers import ClientConfigAdapter, api_keys_from_connector_config_map, get_connector_class +from hummingbot.client.settings import AllConnectorSettings +from hummingbot.connector.connector_base import ConnectorBase +from hummingbot.connector.connector_metrics_collector import TradeVolumeMetricCollector +from hummingbot.connector.exchange_py_base import ExchangePyBase +from hummingbot.connector.gateway.gateway_lp import GatewayLp +from hummingbot.connector.perpetual_derivative_py_base import PerpetualDerivativePyBase +from hummingbot.core.data_type.common import OrderType, PositionAction, PositionMode, TradeType +from hummingbot.core.data_type.in_flight_order import InFlightOrder, OrderState +from hummingbot.core.rate_oracle.rate_oracle import RateOracle +from hummingbot.core.utils.async_utils import safe_ensure_future + +from utils.file_system import fs_util +from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter +from utils.security import BackendAPISecurity + +logger = logging.getLogger(__name__) + + +class UnifiedConnectorService: + """ + Single source of truth for ALL connector instances. + + Manages two types of connectors: + 1. Trading connectors: authenticated, per-account, with full trading capabilities + 2. Data connectors: non-authenticated, shared, for public market data only + + The key method `get_best_connector_for_market()` ensures that order book + operations use the trading connector when available (which already has + order_book_tracker running), falling back to data connector otherwise. + """ + + METRICS_ACTIVATION_INTERVAL = Decimal("900") # 15 minutes + METRICS_VALUATION_TOKEN = "USDT" + + def __init__(self, secrets_manager: ETHKeyFileSecretManger, db_manager=None): + self.secrets_manager = secrets_manager + self.db_manager = db_manager + + # Trading connectors: account_name -> connector_name -> ConnectorBase + self._trading_connectors: Dict[str, Dict[str, ConnectorBase]] = {} + + # Data-only connectors: connector_name -> ConnectorBase (shared, non-authenticated) + self._data_connectors: Dict[str, ConnectorBase] = {} + self._data_connectors_started: Dict[str, bool] = {} + + # Order and funding recorders (for trading connectors) + self._orders_recorders: Dict[str, any] = {} + self._funding_recorders: Dict[str, any] = {} + self._metrics_collectors: Dict[str, TradeVolumeMetricCollector] = {} + + # Locks to prevent race conditions in connector creation + self._connector_locks: Dict[str, asyncio.Lock] = {} + + # Connector settings cache + self._conn_settings = AllConnectorSettings.get_connector_settings() + + def _is_perpetual_connector(self, connector: ConnectorBase) -> bool: + """Check if connector is a perpetual derivative connector. + + Args: + connector: The connector instance to check + + Returns: + True if perpetual connector, False otherwise + """ + return isinstance(connector, PerpetualDerivativePyBase) + + # ========================================================================= + # Trading Connector Management (authenticated, per-account) + # ========================================================================= + + async def get_trading_connector( + self, + account_name: str, + connector_name: str + ) -> ConnectorBase: + """ + Get or create an authenticated trading connector for a specific account. + + Trading connectors have: + - API key authentication + - Order tracking (OrdersRecorder) + - Funding tracking for perpetuals (FundingRecorder) + - Metrics collection + - Full trading capabilities + + Args: + account_name: The account name + connector_name: The connector name (e.g., "binance", "binance_perpetual") + + Returns: + Initialized trading connector + """ + cache_key = f"{account_name}:{connector_name}" + + # Create lock for this cache key if it doesn't exist + if cache_key not in self._connector_locks: + self._connector_locks[cache_key] = asyncio.Lock() + + # Use lock to prevent race conditions during connector creation + async with self._connector_locks[cache_key]: + if account_name not in self._trading_connectors: + self._trading_connectors[account_name] = {} + + if connector_name not in self._trading_connectors[account_name]: + connector = await self._create_and_initialize_trading_connector( + account_name, connector_name + ) + self._trading_connectors[account_name][connector_name] = connector + + return self._trading_connectors[account_name][connector_name] + + def get_all_trading_connectors(self) -> Dict[str, Dict[str, ConnectorBase]]: + """ + Get all trading connectors organized by account. + + Returns: + Dict mapping account_name -> connector_name -> ConnectorBase + """ + return self._trading_connectors + + def get_account_connectors(self, account_name: str) -> Dict[str, ConnectorBase]: + """ + Get all connectors for a specific account. + + Args: + account_name: Account name + + Returns: + Dict mapping connector_name -> ConnectorBase for this account + """ + return self._trading_connectors.get(account_name, {}) + + def is_trading_connector_initialized( + self, + account_name: str, + connector_name: str + ) -> bool: + """Check if a trading connector is already initialized.""" + return ( + account_name in self._trading_connectors and + connector_name in self._trading_connectors[account_name] + ) + + # ========================================================================= + # Data Connector Management (non-authenticated, shared) + # ========================================================================= + + def get_data_connector(self, connector_name: str) -> ConnectorBase: + """ + Get or create a non-authenticated data connector for public market data. + + Data connectors: + - No API keys required (public endpoints only) + - Shared across accounts + - Used for: trading rules, prices, order books, candles + - NOT used for: trading, balance queries + + Args: + connector_name: The connector name + + Returns: + Non-authenticated connector instance + """ + if connector_name not in self._data_connectors: + self._data_connectors[connector_name] = self._create_data_connector( + connector_name + ) + return self._data_connectors[connector_name] + + async def ensure_data_connector_started( + self, + connector_name: str, + trading_pair: str + ) -> bool: + """ + Ensure a data connector's network is started with at least one trading pair. + + This is needed because exchanges close WebSocket connections without subscriptions. + + Args: + connector_name: The connector name + trading_pair: Initial trading pair to subscribe to + + Returns: + True if started successfully + """ + if self._data_connectors_started.get(connector_name, False): + return True + + connector = self.get_data_connector(connector_name) + + try: + # Add trading pair before starting network + if trading_pair not in connector._trading_pairs: + connector._trading_pairs.append(trading_pair) + + # Start network + await connector.start_network() + self._data_connectors_started[connector_name] = True + logger.info(f"Started data connector: {connector_name} with pair {trading_pair}") + + # Wait for order book tracker to be ready + max_wait = 30 + waited = 0 + tracker = connector.order_book_tracker + while waited < max_wait: + if tracker._order_book_stream_listener_task is not None: + await asyncio.sleep(2.0) + break + await asyncio.sleep(0.5) + waited += 0.5 + + return True + + except Exception as e: + logger.error(f"Error starting data connector {connector_name}: {e}") + return False + + # ========================================================================= + # Best Connector Selection (THE KEY FIX) + # ========================================================================= + + def get_best_connector_for_market( + self, + connector_name: str, + account_name: Optional[str] = None + ) -> Optional[ConnectorBase]: + """ + Get the best available connector for market operations (order books, prices). + + CRITICAL: This method ensures order book initialization uses the correct + connector. It prefers trading connectors because they already have + order_book_tracker running with WebSocket connections. + + Priority: + 1. Specific account's trading connector (if account_name provided) + 2. Any trading connector for this connector_name + 3. Data connector (creates new if needed) + + Args: + connector_name: The connector name + account_name: Optional account to prefer + + Returns: + Best available connector for market operations + """ + # 1. Try specific account's trading connector + if account_name: + trading = self._trading_connectors.get(account_name, {}).get(connector_name) + if trading: + logger.debug( + f"Using trading connector for {connector_name} " + f"(account: {account_name})" + ) + return trading + + # 2. Try ANY trading connector for this connector_name + for acc_name, acc_connectors in self._trading_connectors.items(): + if connector_name in acc_connectors: + logger.debug( + f"Using trading connector for {connector_name} " + f"(found in account: {acc_name})" + ) + return acc_connectors[connector_name] + + # 3. Fall back to data connector + logger.debug(f"Using data connector for {connector_name} (no trading connector)") + return self.get_data_connector(connector_name) + + # ========================================================================= + # Order Book Initialization + # ========================================================================= + + async def initialize_order_book( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None, + timeout: float = 30.0 + ) -> bool: + """ + Initialize order book for a trading pair using the best available connector. + + This method: + 1. Gets the best connector (prefers trading over data) + 2. Adds trading pair to order book tracker + 3. Waits for order book to have valid data + + Args: + connector_name: The connector name + trading_pair: The trading pair + account_name: Optional account to prefer + timeout: Timeout in seconds + + Returns: + True if order book initialized successfully + """ + connector = self.get_best_connector_for_market(connector_name, account_name) + + if not connector: + logger.error(f"No connector available for {connector_name}") + return False + + # Gateway/AMM connectors don't have order book trackers - skip initialization + if not hasattr(connector, 'order_book_tracker') or connector.order_book_tracker is None: + logger.info(f"Connector {connector_name} doesn't have order book tracker (AMM/Gateway) - skipping") + return True + + tracker = connector.order_book_tracker + + # Check if already initialized + if trading_pair in tracker.order_books: + ob = tracker.order_books[trading_pair] + try: + bids, asks = ob.snapshot + if len(bids) > 0 and len(asks) > 0: + logger.info(f"Order book for {trading_pair} already initialized") + return True + except Exception: + pass + + # For data connectors, ensure network is started + if connector_name in self._data_connectors: + if not self._data_connectors_started.get(connector_name, False): + success = await self.ensure_data_connector_started( + connector_name, trading_pair + ) + if not success: + return False + # Wait for order book after starting + return await self._wait_for_order_book(tracker, trading_pair, timeout) + else: + # Connector started, dynamically add trading pair + success = await self._add_trading_pair_to_tracker( + connector, trading_pair + ) + if not success: + return False + + # For trading connectors, dynamically add trading pair + else: + success = await self._add_trading_pair_to_tracker(connector, trading_pair) + if not success: + return False + + # Wait for order book to have data + return await self._wait_for_order_book(tracker, trading_pair, timeout) + + def _is_tracker_running(self, tracker) -> bool: + """Check if the order book tracker is running.""" + if not tracker: + return False + task = tracker._order_book_stream_listener_task + if task and not task.done(): + return True + task = tracker._init_order_books_task + if task and not task.done(): + return True + return False + + async def _add_trading_pair_to_tracker( + self, + connector: ExchangePyBase, + trading_pair: str + ) -> bool: + """Add a trading pair to connector's order book tracker. + + ExchangePyBase connectors have: + - order_book_tracker with _trading_pairs, start(), _orderbook_ds + - add_trading_pair() for dynamic addition + + Approach: + 1. If tracker is running, use connector.add_trading_pair() + 2. Otherwise, register the pair and start the tracker + """ + try: + # Safety check - gateway/AMM connectors don't have order book trackers + if not hasattr(connector, 'order_book_tracker') or connector.order_book_tracker is None: + logger.debug(f"Connector {type(connector).__name__} doesn't have order book tracker") + return True + + tracker = connector.order_book_tracker + + # Case 1: Tracker is already running and ready + if self._is_tracker_running(tracker) and tracker.ready: + if trading_pair in tracker.order_books: + logger.debug(f"Order book for {trading_pair} already exists") + return True + + logger.info(f"Adding {trading_pair} to running tracker") + result = await connector.add_trading_pair(trading_pair) + if result: + logger.info(f"Successfully added {trading_pair}") + return True + logger.warning(f"add_trading_pair() returned False for {trading_pair}") + + # Case 2: Tracker not running - start it with this trading pair + else: + logger.info(f"Starting order book tracker for {type(connector).__name__} with {trading_pair}") + + # Register the trading pair before starting tracker + if trading_pair not in tracker._trading_pairs: + tracker._trading_pairs.append(trading_pair) + + tracker.start() + try: + await asyncio.wait_for(tracker.wait_ready(), timeout=30.0) + logger.info(f"Order book tracker ready for {type(connector).__name__}") + except asyncio.TimeoutError: + logger.warning("Timeout waiting for tracker to be ready") + + if trading_pair in tracker.order_books: + logger.info(f"Order book for {trading_pair} initialized") + return True + + # Fallback: Get order book snapshot directly via REST + logger.info(f"Fallback order book initialization for {trading_pair}") + try: + order_book = await connector._orderbook_ds.get_new_order_book(trading_pair) + tracker.order_books[trading_pair] = order_book + if trading_pair not in tracker._trading_pairs: + tracker._trading_pairs.append(trading_pair) + logger.info(f"Initialized order book for {trading_pair} via REST fallback") + return True + except Exception as e: + logger.error(f"Fallback order book initialization failed: {e}") + + logger.error(f"Failed to add {trading_pair} to order book tracker") + return False + + except Exception as e: + logger.error(f"Error adding trading pair {trading_pair}: {e}", exc_info=True) + return False + + async def remove_trading_pair( + self, + connector_name: str, + trading_pair: str, + account_name: Optional[str] = None + ) -> bool: + """ + Remove a trading pair from a connector's order book tracker. + + This method cleans up order book resources for a trading pair that is + no longer needed. Useful for: + - Executor cleanup when stopping + - Memory management for unused pairs + - Account cleanup operations + + Args: + connector_name: The connector name + trading_pair: The trading pair to remove + account_name: Optional account to target specific trading connector + + Returns: + True if successfully removed, False otherwise + """ + connector = self.get_best_connector_for_market(connector_name, account_name) + + if not connector: + logger.warning(f"No connector available for {connector_name} to remove {trading_pair}") + return False + + return await self._remove_trading_pair_from_tracker(connector, trading_pair) + + async def _remove_trading_pair_from_tracker( + self, + connector: ExchangePyBase, + trading_pair: str + ) -> bool: + """Remove a trading pair from connector's order book tracker. + + ExchangePyBase.remove_trading_pair() handles: + - Order book cleanup via order_book_tracker + - Funding info cleanup for perpetual connectors + """ + try: + result = await connector.remove_trading_pair(trading_pair) + if result: + logger.info(f"Removed trading pair {trading_pair}") + return True + + # Fallback: Manual removal from tracker (if connector has one) + if not hasattr(connector, 'order_book_tracker') or connector.order_book_tracker is None: + return True # No tracker to clean up for AMM/Gateway connectors + tracker = connector.order_book_tracker + if trading_pair in tracker.order_books: + del tracker.order_books[trading_pair] + if trading_pair in tracker._trading_pairs: + tracker._trading_pairs.remove(trading_pair) + logger.info(f"Removed trading pair {trading_pair} via manual fallback") + return True + + logger.warning(f"Trading pair {trading_pair} not found") + return False + + except Exception as e: + logger.error(f"Error removing trading pair {trading_pair}: {e}") + return False + + async def _wait_for_websocket_ready( + self, + connector: ExchangePyBase, + timeout: float = 10.0 + ) -> bool: + """Wait for the order book data source WebSocket to be connected.""" + data_source = connector._orderbook_ds + waited = 0 + interval = 0.2 + + while waited < timeout: + if data_source._ws_assistant is not None: + logger.debug(f"WebSocket ready for {type(connector).__name__}") + return True + await asyncio.sleep(interval) + waited += interval + + logger.warning(f"Timeout waiting for WebSocket connection on {type(connector).__name__}") + return False + + async def _wait_for_order_book( + self, + tracker, + trading_pair: str, + timeout: float + ) -> bool: + """Wait for order book to have valid bid/ask data.""" + waited = 0 + interval = 0.5 + + while waited < timeout: + if trading_pair in tracker.order_books: + ob = tracker.order_books[trading_pair] + try: + bids, asks = ob.snapshot + if len(bids) > 0 and len(asks) > 0: + logger.info( + f"Order book for {trading_pair} ready with " + f"{len(bids)} bids and {len(asks)} asks" + ) + return True + except Exception: + pass + await asyncio.sleep(interval) + waited += interval + + logger.warning(f"Timeout waiting for {trading_pair} order book") + return False + + # ========================================================================= + # Trading Connector Creation (internal) + # ========================================================================= + + async def _create_and_initialize_trading_connector( + self, + account_name: str, + connector_name: str + ) -> ConnectorBase: + """Create and fully initialize a trading connector.""" + # Authenticate and create connector + connector = self._create_trading_connector(account_name, connector_name) + + # Initialize symbol map and trading rules + await connector._initialize_trading_pair_symbol_map() + await connector._update_trading_rules() + await connector._update_balances() + + # Perpetual-specific setup + if self._is_perpetual_connector(connector): + if PositionMode.HEDGE in connector.supported_position_modes(): + connector.set_position_mode(PositionMode.HEDGE) + await connector._update_positions() + + # Load existing orders from database + if self.db_manager: + await self._load_existing_orders(connector, account_name, connector_name) + + # Setup order and funding recorders + cache_key = f"{account_name}:{connector_name}" + if self.db_manager and cache_key not in self._orders_recorders: + from services.orders_recorder import OrdersRecorder + orders_recorder = OrdersRecorder(self.db_manager, account_name, connector_name) + orders_recorder.start(connector) + self._orders_recorders[cache_key] = orders_recorder + + if self._is_perpetual_connector(connector): + from services.funding_recorder import FundingRecorder + funding_recorder = FundingRecorder(self.db_manager, account_name, connector_name) + funding_recorder.start(connector) + self._funding_recorders[cache_key] = funding_recorder + + # Initialize metrics + self._initialize_metrics(connector, account_name, connector_name, cache_key) + + # Start network tasks + await self._start_connector_network(connector) + + # Only update order status for orders loaded from DB (balances, rules, positions + # were already fetched above — no need to repeat via _update_connector_state) + if connector.in_flight_orders: + try: + connector._set_current_timestamp(time.time()) + await connector._update_order_status() + except Exception as e: + logger.error(f"Error updating initial order status for {connector_name}: {e}") + + logger.info(f"Initialized trading connector {connector_name} for {account_name}") + return connector + + def _create_trading_connector( + self, + account_name: str, + connector_name: str + ) -> ConnectorBase: + """Create a trading connector with API keys. + + For gateway connectors (containing '/'), creates a GatewayLp connector + which auto-detects chain/network and uses the default wallet. + """ + BackendAPISecurity.login_account( + account_name=account_name, + secrets_manager=self.secrets_manager + ) + + # Gateway connectors (e.g., 'meteora/clmm', 'raydium/clmm') are not in AllConnectorSettings + # They use GatewayLp which auto-detects chain/network from gateway config + if '/' in connector_name: + logger.info(f"Creating gateway connector: {connector_name}") + # GatewayLp handles chain/network auto-detection and default wallet lookup + # via start_network() call + return GatewayLp( + connector_name=connector_name, + trading_pairs=[], + trading_required=True, + ) + + conn_setting = self._conn_settings[connector_name] + keys = BackendAPISecurity.api_keys(connector_name) + + init_params = conn_setting.conn_init_parameters( + trading_pairs=[], + trading_required=True, + api_keys=keys, + ) + + connector_class = get_connector_class(connector_name) + return connector_class(**init_params) + + def _create_data_connector(self, connector_name: str) -> ConnectorBase: + """Create a non-authenticated data connector.""" + conn_setting = self._conn_settings.get(connector_name) + if not conn_setting: + raise ValueError(f"Connector {connector_name} not found") + + # Get config keys but don't use real API keys + connector_config = AllConnectorSettings.get_connector_config_keys(connector_name) + if getattr(connector_config, "use_auth_for_public_endpoints", False): + api_keys = api_keys_from_connector_config_map( + ClientConfigAdapter(connector_config) + ) + elif connector_config is not None: + api_keys = { + key: "" + for key in connector_config.__class__.model_fields.keys() + if key != "connector" + } + else: + api_keys = {} + + init_params = conn_setting.conn_init_parameters( + trading_pairs=[], + trading_required=False, + api_keys=api_keys, + ) + + connector_class = get_connector_class(connector_name) + connector = connector_class(**init_params) + + logger.info(f"Created data connector: {connector_name}") + return connector + + # ========================================================================= + # Network and State Management + # ========================================================================= + + async def _start_connector_network(self, connector: ConnectorBase): + """Start connector network tasks.""" + try: + await self._stop_connector_network(connector) + + # Gateway/AMM connectors use start_network() instead of individual polling tasks + if hasattr(connector, '_trading_rules_polling_loop'): + connector._trading_rules_polling_task = safe_ensure_future( + connector._trading_rules_polling_loop() + ) + if hasattr(connector, '_trading_fees_polling_loop'): + connector._trading_fees_polling_task = safe_ensure_future( + connector._trading_fees_polling_loop() + ) + if hasattr(connector, '_create_user_stream_tracker_task'): + connector._user_stream_tracker_task = connector._create_user_stream_tracker_task() + if hasattr(connector, '_user_stream_event_listener'): + connector._user_stream_event_listener_task = safe_ensure_future( + connector._user_stream_event_listener() + ) + if hasattr(connector, '_lost_orders_update_polling_loop'): + connector._lost_orders_update_task = safe_ensure_future( + connector._lost_orders_update_polling_loop() + ) + + # For gateway connectors, call start_network() which handles chain/network detection + if hasattr(connector, 'start_network') and not hasattr(connector, '_trading_rules_polling_loop'): + await connector.start_network() + + # NOTE: Order book tracker is started lazily when first trading pair is added + # (in _add_trading_pair_to_tracker). Starting it here with no subscriptions + # causes exchanges like Binance to immediately disconnect (close code 1008). + + logger.debug("Started network tasks for connector") + + except Exception as e: + logger.error(f"Error starting connector network: {e}") + raise + + async def _stop_connector_network(self, connector: ConnectorBase): + """Stop connector network tasks.""" + tasks = [ + '_trading_rules_polling_task', + '_trading_fees_polling_task', + '_status_polling_task', + '_user_stream_tracker_task', + '_user_stream_event_listener_task', + '_lost_orders_update_task', + ] + + for task_name in tasks: + task = getattr(connector, task_name, None) + if task: + task.cancel() + setattr(connector, task_name, None) + + # Stop the order book tracker (if connector has one - AMM/Gateway connectors don't) + if hasattr(connector, 'order_book_tracker') and connector.order_book_tracker: + connector.order_book_tracker.stop() + + # For gateway connectors, call stop_network() + if hasattr(connector, 'stop_network'): + await connector.stop_network() + + async def _update_connector_state( + self, + connector: ConnectorBase, + connector_name: str, + account_name: str = None + ): + """Update connector state (balances, positions, orders). + + Note: Trading rules are NOT refreshed here — the background + _trading_rules_polling_loop() (started in _start_connector_network) + already handles that. + """ + try: + connector._set_current_timestamp(time.time()) + await connector._update_balances() + + if self._is_perpetual_connector(connector): + await connector._update_positions() + + if connector.in_flight_orders: + await connector._update_order_status() + if account_name: + await self._sync_orders_to_database( + connector, account_name, connector_name + ) + + except Exception as e: + logger.error(f"Error updating connector state: {e}") + + async def update_all_trading_connector_states(self): + """Update state for all trading connectors in parallel.""" + tasks = [] + task_keys = [] + for account_name, connectors in self._trading_connectors.items(): + for connector_name, connector in connectors.items(): + tasks.append(self._update_connector_state(connector, connector_name, account_name)) + task_keys.append(f"{account_name}/{connector_name}") + if tasks: + results = await asyncio.gather(*tasks, return_exceptions=True) + for key, result in zip(task_keys, results): + if isinstance(result, Exception): + logger.error(f"Error updating {key}: {result}") + + async def initialize_all_trading_connectors(self): + """ + Initialize all trading connectors for all accounts at startup. + + This ensures that: + 1. All connectors are ready to use immediately + 2. Existing orders from database are loaded into in_flight_orders + 3. Order tracking and cancellation work without needing manual initialization + """ + # Get list of all accounts + accounts = fs_util.list_folders('credentials') + + total_initialized = 0 + for account_name in accounts: + # Get all connector credentials for this account + connector_names = self.list_available_credentials(account_name) + + for connector_name in connector_names: + try: + logger.info(f"Initializing connector: {account_name}/{connector_name}") + await self.get_trading_connector(account_name, connector_name) + total_initialized += 1 + except Exception as e: + logger.error(f"Failed to initialize {account_name}/{connector_name}: {e}") + # Continue with other connectors even if one fails + continue + + logger.info(f"Initialized {total_initialized} trading connectors across {len(accounts)} accounts") + + # ========================================================================= + # Order Management + # ========================================================================= + + async def _load_existing_orders( + self, + connector: ConnectorBase, + account_name: str, + connector_name: str + ): + """Load existing orders from database into connector.""" + try: + from database import OrderRepository + + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + active_orders = await order_repo.get_active_orders( + account_name=account_name, + connector_name=connector_name + ) + + for order_record in active_orders: + try: + in_flight_order = self._convert_db_order_to_in_flight(order_record) + connector.in_flight_orders[in_flight_order.client_order_id] = in_flight_order + except Exception as e: + logger.error(f"Error loading order {order_record.client_order_id}: {e}") + + logger.info( + f"Loaded {len(connector.in_flight_orders)} orders for " + f"{account_name}/{connector_name}" + ) + + except Exception as e: + logger.error(f"Error loading orders from database: {e}") + + async def _sync_orders_to_database( + self, + connector: ConnectorBase, + account_name: str, + connector_name: str + ): + """Sync connector's in_flight_orders state to database.""" + if not self.db_manager: + return + + terminal_states = [ + OrderState.FILLED, OrderState.CANCELED, + OrderState.FAILED, OrderState.COMPLETED + ] + orders_to_remove = [] + + for client_order_id, order in list(connector.in_flight_orders.items()): + try: + from database import OrderRepository + + async with self.db_manager.get_session_context() as session: + order_repo = OrderRepository(session) + db_order = await order_repo.get_order_by_client_id(client_order_id) + + if db_order: + new_status = self._map_order_state_to_status(order.current_state) + if db_order.status != new_status: + await order_repo.update_order_status(client_order_id, new_status) + + if order.current_state in terminal_states: + orders_to_remove.append(client_order_id) + + except Exception as e: + logger.error(f"Error syncing order {client_order_id}: {e}") + + for order_id in orders_to_remove: + connector.in_flight_orders.pop(order_id, None) + + async def sync_all_orders_to_database(self): + """ + Sync connector's in_flight_orders state to database for all trading connectors. + + The connector's built-in polling already updates in_flight_orders from the exchange. + This method syncs that state to our database and cleans up closed orders. + """ + for account_name, connectors in self._trading_connectors.items(): + for connector_name, connector in connectors.items(): + try: + if not connector.in_flight_orders: + continue + await self._sync_orders_to_database(connector, account_name, connector_name) + logger.debug(f"Synced order state to DB for {account_name}/{connector_name}") + except Exception as e: + logger.error(f"Error syncing order state for {account_name}/{connector_name}: {e}") + + def _convert_db_order_to_in_flight(self, order_record) -> InFlightOrder: + """Convert database order to InFlightOrder.""" + status_mapping = { + "SUBMITTED": OrderState.PENDING_CREATE, + "OPEN": OrderState.OPEN, + "PARTIALLY_FILLED": OrderState.PARTIALLY_FILLED, + "FILLED": OrderState.FILLED, + "CANCELLED": OrderState.CANCELED, + "FAILED": OrderState.FAILED, + } + + order_state = status_mapping.get(order_record.status, OrderState.PENDING_CREATE) + + try: + order_type = OrderType[order_record.order_type] + except (KeyError, ValueError): + order_type = OrderType.LIMIT + + try: + trade_type = TradeType[order_record.trade_type] + except (KeyError, ValueError): + trade_type = TradeType.BUY + + creation_timestamp = ( + order_record.created_at.timestamp() + if order_record.created_at else time.time() + ) + + in_flight_order = InFlightOrder( + client_order_id=order_record.client_order_id, + trading_pair=order_record.trading_pair, + order_type=order_type, + trade_type=trade_type, + amount=Decimal(str(order_record.amount)), + creation_timestamp=creation_timestamp, + price=Decimal(str(order_record.price)) if order_record.price else None, + exchange_order_id=order_record.exchange_order_id, + initial_state=order_state, + leverage=1, + position=PositionAction.NIL, + ) + + in_flight_order.current_state = order_state + if order_record.filled_amount: + in_flight_order.executed_amount_base = Decimal(str(order_record.filled_amount)) + + return in_flight_order + + def _map_order_state_to_status(self, order_state: OrderState) -> str: + """Map OrderState to database status string.""" + mapping = { + OrderState.PENDING_CREATE: "SUBMITTED", + OrderState.OPEN: "OPEN", + OrderState.PENDING_CANCEL: "PENDING_CANCEL", + OrderState.CANCELED: "CANCELLED", + OrderState.PARTIALLY_FILLED: "PARTIALLY_FILLED", + OrderState.FILLED: "FILLED", + OrderState.FAILED: "FAILED", + OrderState.PENDING_APPROVAL: "SUBMITTED", + OrderState.APPROVED: "SUBMITTED", + OrderState.CREATED: "SUBMITTED", + OrderState.COMPLETED: "FILLED", + } + return mapping.get(order_state, "SUBMITTED") + + # ========================================================================= + # Metrics + # ========================================================================= + + def _initialize_metrics( + self, + connector: ConnectorBase, + account_name: str, + connector_name: str, + cache_key: str + ): + """Initialize trade volume metrics collector.""" + if cache_key in self._metrics_collectors: + return + + if "_paper_trade" in connector_name: + return + + try: + instance_id = f"{account_name}_hbotapi" + rate_provider = RateOracle.get_instance() + + metrics_collector = TradeVolumeMetricCollector( + connector=connector, + activation_interval=self.METRICS_ACTIVATION_INTERVAL, + rate_provider=rate_provider, + instance_id=instance_id, + valuation_token=self.METRICS_VALUATION_TOKEN + ) + metrics_collector.start() + self._metrics_collectors[cache_key] = metrics_collector + + except Exception as e: + logger.warning(f"Failed to init metrics for {connector_name}: {e}") + + # ========================================================================= + # Credentials and Configuration + # ========================================================================= + + async def update_connector_keys( + self, + account_name: str, + connector_name: str, + keys: dict + ) -> ConnectorBase: + """Update API keys and recreate connector.""" + if not BackendAPISecurity.login_account( + account_name=account_name, + secrets_manager=self.secrets_manager + ): + raise ValueError(f"Failed to authenticate for {account_name}") + + connector_config = HummingbotAPIConfigAdapter( + AllConnectorSettings.get_connector_config_keys(connector_name) + ) + + for key, value in keys.items(): + setattr(connector_config, key, value) + + BackendAPISecurity.update_connector_keys(account_name, connector_config) + BackendAPISecurity.decrypt_all(account_name=account_name) + + # Properly stop old connector (stops recorders, network tasks, cleans up caches) + await self.stop_trading_connector(account_name, connector_name) + + # Create new connector with fresh recorders + return await self.get_trading_connector(account_name, connector_name) + + def clear_trading_connector( + self, + account_name: Optional[str] = None, + connector_name: Optional[str] = None + ): + """Clear trading connector from cache.""" + if account_name and connector_name: + if account_name in self._trading_connectors: + self._trading_connectors[account_name].pop(connector_name, None) + elif account_name: + self._trading_connectors.pop(account_name, None) + else: + self._trading_connectors.clear() + + def list_account_connectors(self, account_name: str) -> List[str]: + """List initialized connectors for an account.""" + return list(self._trading_connectors.get(account_name, {}).keys()) + + def list_available_credentials(self, account_name: str) -> List[str]: + """List connector credentials available for an account.""" + try: + files = fs_util.list_files(f"credentials/{account_name}/connectors") + return [f.replace(".yml", "") for f in files if f.endswith(".yml")] + except FileNotFoundError: + return [] + + @staticmethod + def get_connector_config_map(connector_name: str): + """Get connector config field info.""" + from typing import Literal, get_args, get_origin + + connector_config = HummingbotAPIConfigAdapter( + AllConnectorSettings.get_connector_config_keys(connector_name) + ) + fields_info = {} + + for key, field in connector_config.hb_config.model_fields.items(): + if key == "connector": + continue + + field_type = field.annotation + type_name = getattr(field_type, "__name__", str(field_type)) + allowed_values = None + + origin = get_origin(field_type) + args = get_args(field_type) + + if origin is Literal: + type_name = "Literal" + allowed_values = list(args) + elif origin is not None: + if type(None) in args: + actual_types = [arg for arg in args if arg is not type(None)] + if actual_types: + inner_type = actual_types[0] + inner_origin = get_origin(inner_type) + inner_args = get_args(inner_type) + if inner_origin is Literal: + type_name = "Literal" + allowed_values = list(inner_args) + else: + type_name = getattr(inner_type, "__name__", str(inner_type)) + else: + type_name = str(field_type) + + field_info = {"type": type_name, "required": field.is_required()} + if allowed_values is not None: + field_info["allowed_values"] = allowed_values + fields_info[key] = field_info + + return fields_info + + # ========================================================================= + # Cleanup + # ========================================================================= + + async def stop_trading_connector(self, account_name: str, connector_name: str): + """Stop a trading connector and its services.""" + cache_key = f"{account_name}:{connector_name}" + + # Stop recorders + if cache_key in self._orders_recorders: + try: + await self._orders_recorders[cache_key].stop() + del self._orders_recorders[cache_key] + except Exception as e: + logger.error(f"Error stopping orders recorder: {e}") + + if cache_key in self._funding_recorders: + try: + await self._funding_recorders[cache_key].stop() + del self._funding_recorders[cache_key] + except Exception as e: + logger.error(f"Error stopping funding recorder: {e}") + + if cache_key in self._metrics_collectors: + try: + self._metrics_collectors[cache_key].stop() + del self._metrics_collectors[cache_key] + except Exception as e: + logger.error(f"Error stopping metrics: {e}") + + # Stop connector network + if account_name in self._trading_connectors: + connector = self._trading_connectors[account_name].get(connector_name) + if connector: + await self._stop_connector_network(connector) + del self._trading_connectors[account_name][connector_name] + + logger.info(f"Stopped trading connector {account_name}/{connector_name}") + + async def stop_all(self): + """Stop all connectors and services.""" + # Stop all trading connectors + for account_name, connectors in list(self._trading_connectors.items()): + for connector_name in list(connectors.keys()): + await self.stop_trading_connector(account_name, connector_name) + + # Stop data connectors + for connector_name, connector in self._data_connectors.items(): + try: + await connector.stop_network() + except Exception as e: + logger.error(f"Error stopping data connector {connector_name}: {e}") + + self._data_connectors.clear() + self._data_connectors_started.clear() + + logger.info("Stopped all connectors") + + # ========================================================================= + # Order Book Tracker Diagnostics & Restart + # ========================================================================= + + def get_order_book_tracker_diagnostics( + self, + connector_name: str, + account_name: Optional[str] = None + ) -> Dict: + """Get diagnostics for a connector's order book tracker. + + Returns information about: + - Whether the tracker is running + - Task status (alive/crashed) + - Metrics (diffs processed, last update, etc.) + - WebSocket status + + Args: + connector_name: The connector to diagnose + account_name: Optional account for trading connector + + Returns: + Dictionary with diagnostic information + """ + connector = self.get_best_connector_for_market(connector_name, account_name) + + if not connector: + return {"error": f"No connector found for {connector_name}"} + + diagnostics = { + "connector_type": type(connector).__name__, + "connector_name": connector_name, + "has_order_book_tracker": False, + "tracker_ready": False, + "tasks": {}, + "trading_pairs": [], + "order_books": {}, + "metrics": None, + "websocket_status": "unknown", + } + + if not hasattr(connector, 'order_book_tracker') or not connector.order_book_tracker: + return diagnostics + + tracker = connector.order_book_tracker + diagnostics["has_order_book_tracker"] = True + diagnostics["tracker_ready"] = tracker.ready if hasattr(tracker, 'ready') else False + + # Get trading pairs + if hasattr(tracker, '_trading_pairs'): + diagnostics["trading_pairs"] = list(tracker._trading_pairs) if isinstance(tracker._trading_pairs, (list, set)) else [] + + # Check task status + task_names = [ + '_order_book_stream_listener_task', + '_order_book_diff_listener_task', + '_order_book_trade_listener_task', + '_order_book_snapshot_listener_task', + '_order_book_diff_router_task', + '_order_book_snapshot_router_task', + '_init_order_books_task', + '_emit_trade_event_task', + ] + + for task_name in task_names: + task = getattr(tracker, task_name, None) + if task is not None: + diagnostics["tasks"][task_name] = { + "exists": True, + "done": task.done(), + "cancelled": task.cancelled(), + "exception": str(task.exception()) if task.done() and not task.cancelled() and task.exception() else None, + } + else: + diagnostics["tasks"][task_name] = {"exists": False} + + # Check order books + if hasattr(tracker, 'order_books'): + for trading_pair, order_book in tracker.order_books.items(): + try: + bids, asks = order_book.snapshot + best_bid = float(bids.iloc[0]['price']) if len(bids) > 0 else None + best_ask = float(asks.iloc[0]['price']) if len(asks) > 0 else None + diagnostics["order_books"][trading_pair] = { + "best_bid": best_bid, + "best_ask": best_ask, + "bid_count": len(bids), + "ask_count": len(asks), + "snapshot_uid": order_book.snapshot_uid if hasattr(order_book, 'snapshot_uid') else None, + "last_diff_uid": order_book.last_diff_uid if hasattr(order_book, 'last_diff_uid') else None, + } + except Exception as e: + diagnostics["order_books"][trading_pair] = {"error": str(e)} + + # Get metrics if available + if hasattr(tracker, 'metrics'): + try: + diagnostics["metrics"] = tracker.metrics.to_dict() + except Exception as e: + diagnostics["metrics"] = {"error": str(e)} + + # Check WebSocket status + if hasattr(connector, '_orderbook_ds') and connector._orderbook_ds: + data_source = connector._orderbook_ds + if hasattr(data_source, '_ws_assistant') and data_source._ws_assistant is not None: + diagnostics["websocket_status"] = "connected" + else: + diagnostics["websocket_status"] = "not_connected" + + return diagnostics + + async def restart_order_book_tracker( + self, + connector_name: str, + account_name: Optional[str] = None + ) -> Dict: + """Restart the order book tracker for a connector. + + This method: + 1. Stops the existing order book tracker + 2. Restarts it with the same trading pairs + + Args: + connector_name: The connector to restart + account_name: Optional account for trading connector + + Returns: + Dictionary with restart status + """ + connector = self.get_best_connector_for_market(connector_name, account_name) + + if not connector: + return {"success": False, "error": f"No connector found for {connector_name}"} + + # Gateway/AMM connectors don't have order book trackers + if not hasattr(connector, 'order_book_tracker') or connector.order_book_tracker is None: + return {"success": False, "error": f"Connector {connector_name} doesn't have order book tracker (AMM/Gateway)"} + + tracker = connector.order_book_tracker + trading_pairs = list(tracker._trading_pairs) + + if not trading_pairs: + return {"success": False, "error": "No trading pairs to restart"} + + try: + # Stop the tracker + logger.info(f"Stopping order book tracker for {connector_name}...") + tracker.stop() + + # Wait a moment for cleanup + await asyncio.sleep(0.5) + + # Re-add trading pairs to tracker before restarting + tracker._trading_pairs.clear() + for tp in trading_pairs: + tracker._trading_pairs.append(tp) + + # Restart the tracker + logger.info(f"Restarting order book tracker for {connector_name} with pairs: {trading_pairs}") + tracker.start() + + # Wait for initialization + try: + await asyncio.wait_for(tracker.wait_ready(), timeout=30.0) + except asyncio.TimeoutError: + logger.warning("Timeout waiting for tracker to be ready, continuing anyway...") + + # Wait for WebSocket to be ready + await self._wait_for_websocket_ready(connector, timeout=10.0) + + return { + "success": True, + "message": f"Order book tracker restarted for {connector_name}", + "trading_pairs": trading_pairs, + } + + except Exception as e: + logger.error(f"Error restarting order book tracker: {e}", exc_info=True) + return {"success": False, "error": str(e)} diff --git a/set_environment.sh b/set_environment.sh deleted file mode 100644 index ee59d61d..00000000 --- a/set_environment.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# Create or overwrite .env file -echo "Setting up .env file for the project... -By default, the current working directory will be used as the BOTS_PATH and the CONFIG_PASSWORD will be set to 'a'." - -# Asking for CONFIG_PASSWORD and BOTS_PATH -CONFIG_PASSWORD=a -BOTS_PATH=$(pwd) - -# Write to .env file -echo "CONFIG_PASSWORD=$CONFIG_PASSWORD" > .env -echo "BOTS_PATH=$BOTS_PATH" >> .env diff --git a/setup.sh b/setup.sh new file mode 100755 index 00000000..95128ed2 --- /dev/null +++ b/setup.sh @@ -0,0 +1,419 @@ +#!/bin/bash +# Hummingbot API Setup - Creates .env with sensible defaults (Mac/Linux/WSL2) +# - On Linux (apt-based): installs build deps (gcc, build-essential) +# - Ensures Docker + Docker Compose are available (auto-installs on Linux via get.docker.com) +# - Idempotent: safe to run multiple times, skips already-completed steps +# - Verbose output: shows all installation progress directly +# - Fixed: Removed apt-get upgrade, uses /dev/tty for prompts + +set -euo pipefail + +echo "Hummingbot API Setup" +echo "" + +# -------------------------- +# State Tracking Variables +# -------------------------- +APT_CACHE_UPDATED=false +DOCKER_ALREADY_PRESENT=false +COMPOSE_ALREADY_PRESENT=false + +has_cmd() { command -v "$1" >/dev/null 2>&1; } + +resolve_script_dir() { + local src="${BASH_SOURCE[0]}" + while [ -h "$src" ]; do + local dir + dir="$(cd -P "$(dirname "$src")" >/dev/null 2>&1 && pwd)" + src="$(readlink "$src")" + [[ "$src" != /* ]] && src="$dir/$src" + done + cd -P "$(dirname "$src")" >/dev/null 2>&1 && pwd +} + +SCRIPT_DIR="$(resolve_script_dir)" + +# -------------------------- +# OS / Environment Detection +# -------------------------- +OS="$(uname -s || true)" +ARCH="$(uname -m || true)" + +is_linux() { [[ "${OS}" == "Linux" ]]; } +is_macos() { [[ "${OS}" == "Darwin" ]]; } + +docker_ok() { has_cmd docker; } + +docker_compose_ok() { + if has_cmd docker && docker compose version >/dev/null 2>&1; then + return 0 + fi + if has_cmd docker-compose && docker-compose version >/dev/null 2>&1; then + return 0 + fi + return 1 +} + +need_sudo_or_die() { + if ! has_cmd sudo; then + echo "ERROR: 'sudo' is required for dependency installation on this system." + echo "Please install sudo (or run as root) and re-run this script." + exit 1 + fi +} + +# -------------------------- +# APT Cache Management (Linux) +# -------------------------- +safe_apt_update() { + # Only run apt-get update once per script execution + if [ "$APT_CACHE_UPDATED" = false ]; then + echo "[INFO] Updating apt cache..." + sudo env DEBIAN_FRONTEND=noninteractive apt-get update + APT_CACHE_UPDATED=true + fi +} + +# -------------------------- +# Package Check Utilities +# -------------------------- +is_package_installed() { + # Check if a Debian package is installed + # Usage: is_package_installed package-name + dpkg -l "$1" 2>/dev/null | grep -q "^ii" +} + +# -------------------------- +# Linux Dependencies +# -------------------------- +install_linux_build_deps() { + if has_cmd apt-get; then + # Check if build dependencies are already installed + if is_package_installed build-essential && has_cmd gcc; then + echo "[OK] Build dependencies (gcc, build-essential) already installed. Skipping." + return 0 + fi + + need_sudo_or_die + echo "[INFO] Installing build dependencies (gcc, build-essential)..." + + safe_apt_update + + # REMOVED: apt-get upgrade -y + # This was causing failures due to system-wide package upgrades + # apt-get install will get the latest available versions anyway + + sudo env DEBIAN_FRONTEND=noninteractive apt-get install -y gcc build-essential + + echo "[OK] Build dependencies installed." + else + echo "[WARN] Detected Linux, but 'apt-get' is not available. Skipping build dependency install." + fi +} + +ensure_curl_on_linux() { + if has_cmd curl; then + echo "[OK] curl is already installed." + return 0 + fi + + if has_cmd apt-get; then + need_sudo_or_die + echo "[INFO] Installing curl (required for Docker install script)..." + safe_apt_update + sudo env DEBIAN_FRONTEND=noninteractive apt-get install -y curl ca-certificates + echo "[OK] curl installed." + return 0 + fi + + echo "[WARN] curl is not installed and apt-get is unavailable. Please install curl and re-run." + return 1 +} + +# -------------------------- +# Docker Install / Validation +# -------------------------- +check_user_in_docker_group() { + # Check if current user is already in docker group + if [[ "${EUID}" -eq 0 ]]; then + # Running as root, no need for docker group + return 0 + fi + + if has_cmd getent && getent group docker >/dev/null 2>&1; then + if id -nG "$USER" 2>/dev/null | grep -qw docker; then + return 0 + fi + fi + + return 1 +} + +add_user_to_docker_group() { + # Only add user to docker group if not already a member + if check_user_in_docker_group; then + echo "[OK] User '$USER' is already in the 'docker' group." + return 0 + fi + + if has_cmd getent && getent group docker >/dev/null 2>&1; then + if [[ "${EUID}" -ne 0 ]]; then + echo "[INFO] Adding current user to 'docker' group (may require re-login)..." + sudo usermod -aG docker "$USER" >/dev/null 2>&1 || true + echo "[OK] User added to docker group. You may need to log out and back in for this to take effect." + fi + fi +} + +install_docker_linux() { + need_sudo_or_die + ensure_curl_on_linux + + echo "[INFO] Docker not found. Installing Docker using get.docker.com script..." + curl -fsSL https://get.docker.com -o get-docker.sh + sudo sh get-docker.sh + rm -f get-docker.sh + + if has_cmd systemctl; then + if systemctl is-system-running >/dev/null 2>&1; then + echo "[INFO] Enabling and starting Docker service..." + sudo systemctl enable docker 2>/dev/null || true + sudo systemctl start docker 2>/dev/null || true + fi + fi + + add_user_to_docker_group +} + +ensure_docker_and_compose() { + if is_linux; then + # Check Docker installation + if docker_ok; then + echo "[OK] Docker already installed: $(docker --version 2>/dev/null || echo 'version unknown')" + DOCKER_ALREADY_PRESENT=true + + # Even if Docker is installed, ensure user is in docker group + add_user_to_docker_group + else + # Check if Docker binary exists but isn't in PATH + if [ -x "/usr/bin/docker" ] || [ -x "/usr/local/bin/docker" ]; then + echo "[INFO] Docker found but not in current PATH. Adding to PATH..." + export PATH="/usr/bin:/usr/local/bin:$PATH" + + if docker_ok; then + echo "[OK] Docker is now accessible: $(docker --version 2>/dev/null || echo 'version unknown')" + DOCKER_ALREADY_PRESENT=true + add_user_to_docker_group + else + install_docker_linux + fi + else + install_docker_linux + fi + fi + + # Verify Docker is actually working + if ! docker_ok; then + echo "ERROR: Docker installation did not succeed or 'docker' is still not on PATH." + echo " Try opening a new shell and re-running, or verify Docker installation." + exit 1 + fi + + # Check Docker Compose installation + if docker_compose_ok; then + echo "[OK] Docker Compose already available" + COMPOSE_ALREADY_PRESENT=true + + # Show which version we detected + if docker compose version >/dev/null 2>&1; then + echo "[OK] Using Docker Compose plugin: $(docker compose version 2>/dev/null || echo 'version unknown')" + else + echo "[OK] Using standalone docker-compose: $(docker-compose version 2>/dev/null || echo 'version unknown')" + fi + else + # Try to install docker-compose-plugin + if has_cmd apt-get; then + # Check if plugin package is already installed but not working + if is_package_installed docker-compose-plugin; then + echo "[WARN] docker-compose-plugin package is installed but not functioning properly." + echo "[INFO] Attempting to reinstall docker-compose-plugin..." + need_sudo_or_die + safe_apt_update + sudo env DEBIAN_FRONTEND=noninteractive apt-get install --reinstall -y docker-compose-plugin || true + else + need_sudo_or_die + echo "[INFO] Docker Compose not found. Attempting to install docker-compose-plugin..." + safe_apt_update + sudo env DEBIAN_FRONTEND=noninteractive apt-get install -y docker-compose-plugin || true + fi + fi + fi + + # Final verification of Docker Compose + if ! docker_compose_ok; then + echo "ERROR: Docker Compose is not available." + echo " Expected either 'docker compose' (v2) or 'docker-compose' (v1)." + echo " On Ubuntu/Debian, try: sudo apt-get install -y docker-compose-plugin" + exit 1 + fi + + elif is_macos; then + if ! docker_ok || ! docker_compose_ok; then + echo "ERROR: Docker and/or Docker Compose not found on macOS." + echo " Install Docker Desktop for Mac (Apple Silicon or Intel) and re-run this script." + echo " After installation, ensure 'docker' works in this terminal (you may need a new shell)." + exit 1 + fi + + echo "[OK] Docker detected: $(docker --version 2>/dev/null || echo 'version unknown')" + if docker compose version >/dev/null 2>&1; then + echo "[OK] Docker Compose detected: $(docker compose version 2>/dev/null || echo 'version unknown')" + else + echo "[OK] Docker Compose detected: $(docker-compose version 2>/dev/null || echo 'version unknown')" + fi + + else + echo "[WARN] Unsupported/unknown OS '${OS}'. Proceeding without installing OS-level dependencies." + if ! docker_ok || ! docker_compose_ok; then + echo "ERROR: Docker and/or Docker Compose not found." + exit 1 + fi + + echo "[OK] Docker detected: $(docker --version 2>/dev/null || echo 'version unknown')" + if docker compose version >/dev/null 2>&1; then + echo "[OK] Docker Compose detected: $(docker compose version 2>/dev/null || echo 'version unknown')" + else + echo "[OK] Docker Compose detected: $(docker-compose version 2>/dev/null || echo 'version unknown')" + fi + fi +} + +# -------------------------- +# Pull Hummingbot Docker Image +# -------------------------- +pull_hummingbot_image() { + echo "[INFO] Pulling latest Hummingbot image (hummingbot/hummingbot:latest)..." + if docker pull hummingbot/hummingbot:latest; then + echo "[OK] Hummingbot image pulled successfully." + else + echo "[WARN] Could not pull hummingbot/hummingbot:latest (network issue?). You may need to run 'docker pull hummingbot/hummingbot:latest' manually." + fi +} + +# -------------------------- +# Pre-flight (deps + docker) +# -------------------------- +echo "[INFO] OS=${OS} ARCH=${ARCH}" + +if is_linux; then + install_linux_build_deps +fi + +ensure_docker_and_compose + +# Show summary of what was done +echo "" +if [ "$DOCKER_ALREADY_PRESENT" = true ] && [ "$COMPOSE_ALREADY_PRESENT" = true ]; then + echo "[OK] All dependencies were already installed. No changes made." +elif [ "$DOCKER_ALREADY_PRESENT" = true ]; then + echo "[OK] Docker was already installed. Docker Compose has been set up." +elif [ "$COMPOSE_ALREADY_PRESENT" = true ]; then + echo "[OK] Docker has been installed. Docker Compose was already available." +else + echo "[OK] Docker and Docker Compose have been installed." +fi + +echo "" + +# Always pull latest Hummingbot image (first install and upgrade) +pull_hummingbot_image + +echo "" + +# -------------------------- +# Existing .env creation flow +# -------------------------- +if [ -f ".env" ]; then + echo ".env file already exists. Skipping setup." + echo "" + + # Ensure sentinel file exists + if [ ! -f ".setup-complete" ]; then + touch .setup-complete + fi + + exit 0 +fi + +# Clear screen before prompting user (only if running interactively) +if [[ -t 0 ]] && [[ -c /dev/tty ]]; then + if has_cmd clear; then + clear + else + printf "\033c" + fi +fi + +echo "Hummingbot API Setup" +echo "" + +# Use /dev/tty for prompts to work correctly when called from parent scripts +if [[ -c /dev/tty ]] && [[ -r /dev/tty ]]; then + read -p "API username [default: admin]: " USERNAME < /dev/tty +else + read -p "API username [default: admin]: " USERNAME +fi +USERNAME=${USERNAME:-admin} + +if [[ -c /dev/tty ]] && [[ -r /dev/tty ]]; then + read -p "API password [default: admin]: " PASSWORD < /dev/tty +else + read -p "API password [default: admin]: " PASSWORD +fi +PASSWORD=${PASSWORD:-admin} + +if [[ -c /dev/tty ]] && [[ -r /dev/tty ]]; then + read -p "Config password [default: admin]: " CONFIG_PASSWORD < /dev/tty +else + read -p "Config password [default: admin]: " CONFIG_PASSWORD +fi +CONFIG_PASSWORD=${CONFIG_PASSWORD:-admin} + +cat > .env << EOF +# Hummingbot API Configuration +USERNAME=$USERNAME +PASSWORD=$PASSWORD +CONFIG_PASSWORD=$CONFIG_PASSWORD +DEBUG_MODE=false + +# MQTT Broker +BROKER_HOST=localhost +BROKER_PORT=1883 +BROKER_USERNAME=admin +BROKER_PASSWORD=password + +# Database (auto-configured by docker-compose) +DATABASE_URL=postgresql+asyncpg://hbot:hummingbot-api@localhost:5432/hummingbot_api + +# Gateway (optional) +GATEWAY_URL=http://localhost:15888 +GATEWAY_PASSPHRASE=admin + +# Paths +BOTS_PATH=$(pwd) +EOF + +touch .setup-complete + +echo "" +echo ".env created successfully!" +echo "" +echo "Next steps:" +echo "" +echo "Option 1: Start all services with Docker (recommended)" +echo " make deploy" +echo "" +echo "Option 2: Run API locally (dev mode)" +echo " make install # Creates the conda environment - Note: Please install the latest Anaconda version manually" +echo " make run # Run API" +echo "" diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/test_gateway_lp_executor.py b/test/test_gateway_lp_executor.py new file mode 100644 index 00000000..944721e4 --- /dev/null +++ b/test/test_gateway_lp_executor.py @@ -0,0 +1,294 @@ +""" +Tests for Gateway LP Executor functionality. + +Tests the following fixes: +1. KeyError: 'meteora/clmm' - Gateway connectors should use GatewayLp directly +2. Script config staging compatibility - candles_config and markets removed + +Run with: pytest test/test_gateway_lp_executor.py -v +""" +import inspect +import os +from unittest.mock import MagicMock, patch + +import pytest + +# Skip all tests if hummingbot not installed +pytest.importorskip("hummingbot") + + +class TestGatewayConnectorFix: + """Tests for Fix 1: KeyError 'meteora/clmm' resolution.""" + + def test_gateway_lp_import(self): + """GatewayLp should be importable from hummingbot.""" + from hummingbot.connector.gateway.gateway_lp import GatewayLp + assert GatewayLp is not None + + def test_gateway_lp_instantiation(self): + """GatewayLp should instantiate with meteora/clmm connector name.""" + from hummingbot.connector.gateway.gateway_lp import GatewayLp + + connector = GatewayLp( + connector_name="meteora/clmm", + trading_pairs=[], + trading_required=True, + ) + assert connector.connector_name == "meteora/clmm" + assert connector.name == "meteora/clmm" + + def test_gateway_lp_has_required_methods(self): + """GatewayLp should have methods required by LP executor.""" + from hummingbot.connector.gateway.gateway_lp import GatewayLp + + connector = GatewayLp( + connector_name="meteora/clmm", + trading_pairs=[], + trading_required=True, + ) + + required_methods = [ + "get_position_info", + "_clmm_add_liquidity", + "create_market_order_id", + "start_network", + "stop_network", + ] + + for method in required_methods: + assert hasattr(connector, method), f"Missing method: {method}" + + def test_gateway_detection_in_unified_connector_service(self): + """_create_trading_connector should detect gateway connectors.""" + from services.unified_connector_service import UnifiedConnectorService + + source = inspect.getsource(UnifiedConnectorService._create_trading_connector) + + # Check gateway detection logic exists + assert "'/' in connector_name" in source, "Gateway detection condition not found" + assert "GatewayLp(" in source, "GatewayLp instantiation not found" + + def test_gateway_connector_names_detected(self): + """Gateway connector names (with /) should be detected correctly.""" + gateway_connectors = [ + "meteora/clmm", + "raydium/clmm", + "uniswap/amm", + "jupiter/router", + "orca/whirlpool", + ] + + regular_connectors = [ + "binance", + "binance_perpetual", + "kucoin", + "gate_io", + ] + + for name in gateway_connectors: + assert "/" in name, f"{name} should be detected as gateway" + + for name in regular_connectors: + assert "/" not in name, f"{name} should NOT be detected as gateway" + + @pytest.mark.asyncio + async def test_create_trading_connector_for_gateway(self): + """_create_trading_connector should return GatewayLp for gateway connectors.""" + from hummingbot.connector.gateway.gateway_lp import GatewayLp + + from services.unified_connector_service import UnifiedConnectorService + + # Create a minimal service instance + service = UnifiedConnectorService.__new__(UnifiedConnectorService) + service._conn_settings = {} + service.secrets_manager = MagicMock() + + # Mock BackendAPISecurity + with patch("services.unified_connector_service.BackendAPISecurity") as mock_security: + mock_security.login_account = MagicMock() + + connector = service._create_trading_connector( + account_name="master_account", + connector_name="meteora/clmm" + ) + + assert isinstance(connector, GatewayLp) + assert connector.connector_name == "meteora/clmm" + + +class TestScriptConfigFix: + """Tests for Fix 2: Script config staging compatibility.""" + + def test_script_config_no_candles_config(self): + """Script config should not include candles_config.""" + from routers.bot_orchestration import deploy_v2_controllers + + source = inspect.getsource(deploy_v2_controllers) + + # Find the script_config_content dict + import re + match = re.search(r"script_config_content\s*=\s*\{([^}]+)\}", source, re.DOTALL) + assert match, "script_config_content not found in deploy_v2_controllers" + + config_str = match.group(1) + assert "candles_config" not in config_str, "candles_config should not be in script_config_content" + + def test_script_config_no_markets(self): + """Script config should not include markets.""" + from routers.bot_orchestration import deploy_v2_controllers + + source = inspect.getsource(deploy_v2_controllers) + + import re + match = re.search(r"script_config_content\s*=\s*\{([^}]+)\}", source, re.DOTALL) + assert match, "script_config_content not found in deploy_v2_controllers" + + config_str = match.group(1) + assert '"markets"' not in config_str, "markets should not be in script_config_content" + + def test_script_config_has_required_fields(self): + """Script config should have script_file_name and controllers_config.""" + from routers.bot_orchestration import deploy_v2_controllers + + source = inspect.getsource(deploy_v2_controllers) + + import re + match = re.search(r"script_config_content\s*=\s*\{([^}]+)\}", source, re.DOTALL) + assert match, "script_config_content not found in deploy_v2_controllers" + + config_str = match.group(1) + assert "script_file_name" in config_str, "script_file_name should be in script_config_content" + assert "controllers_config" in config_str, "controllers_config should be in script_config_content" + + +class TestLPExecutorRegistry: + """Tests for LP executor type registration.""" + + def test_lp_executor_type_exists(self): + """lp_executor should be a valid executor type.""" + # EXECUTOR_TYPES is a Literal type, get its args + import typing + + from models.executors import EXECUTOR_TYPES + if hasattr(typing, "get_args"): + types = typing.get_args(EXECUTOR_TYPES) + else: + types = EXECUTOR_TYPES.__args__ + + assert "lp_executor" in types, "lp_executor should be in EXECUTOR_TYPES" + + def test_lp_executor_config_importable(self): + """LPExecutorConfig should be importable from hummingbot.""" + from hummingbot.strategy_v2.executors.lp_executor.data_types import LPExecutorConfig + assert LPExecutorConfig is not None + + def test_lp_executor_importable(self): + """LPExecutor should be importable from hummingbot.""" + from hummingbot.strategy_v2.executors.lp_executor.lp_executor import LPExecutor + assert LPExecutor is not None + + +class TestGatewayIntegration: + """Integration tests that require Gateway to be running. + + These tests are skipped if Gateway is not available. + Run with: pytest test/test_gateway_lp_executor.py -v -m integration + """ + + @pytest.fixture + def gateway_url(self): + return os.environ.get("GATEWAY_URL", "http://localhost:15888") + + @pytest.fixture + def api_url(self): + return os.environ.get("API_URL", "http://localhost:8000") + + @pytest.fixture + def api_auth(self): + return ( + os.environ.get("API_USER", "admin"), + os.environ.get("API_PASSWORD", "admin") + ) + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_gateway_status(self, api_url, api_auth): + """Check Gateway status via API.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"{api_url}/gateway/status", + auth=aiohttp.BasicAuth(*api_auth) + ) as response: + assert response.status == 200 + data = await response.json() + # Gateway may or may not be running + assert "running" in data + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_lp_executor_types_available(self, api_url, api_auth): + """Verify lp_executor is in available types.""" + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get( + f"{api_url}/executors/types/available", + auth=aiohttp.BasicAuth(*api_auth) + ) as response: + assert response.status == 200 + data = await response.json() + + types = [t["type"] for t in data["executor_types"]] + assert "lp_executor" in types, "lp_executor should be available" + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_create_lp_executor_no_keyerror(self, api_url, api_auth): + """Creating LP executor should not raise KeyError for meteora/clmm. + + This test verifies the fix for the KeyError: 'meteora/clmm' issue. + The request may fail due to Gateway not running, but should NOT fail + with KeyError. + """ + import aiohttp + + payload = { + "account_name": "master_account", + "executor_config": { + "type": "lp_executor", + "connector_name": "meteora/clmm", + "trading_pair": "SOL-USDC", + "pool_address": "BGm1av58oGcsQJehL9WXBFXF7D27vZsKefj4xJKD5Y", + "lower_price": "84", + "upper_price": "84.8", + "base_amount": "0.03555", + "quote_amount": "3", + "side": 1, + } + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{api_url}/executors/", + json=payload, + auth=aiohttp.BasicAuth(*api_auth) + ) as response: + data = await response.json() + + # Should NOT be KeyError + if response.status != 200: + error_detail = data.get("detail", "") + assert "KeyError" not in str(error_detail), \ + f"Should not have KeyError, got: {error_detail}" + assert "'meteora/clmm'" not in str(error_detail) or "KeyError" not in str(error_detail), \ + f"Should not have KeyError for meteora/clmm, got: {error_detail}" + + # Expected error when Gateway is not running + if "Cannot connect" in str(error_detail) or "Gateway" in str(error_detail): + pytest.skip("Gateway not running - this is expected") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_order_timestamp_tracking.py b/test/test_order_timestamp_tracking.py new file mode 100644 index 00000000..dd7c2b12 --- /dev/null +++ b/test/test_order_timestamp_tracking.py @@ -0,0 +1,88 @@ +import ast +from pathlib import Path + + +ROOT = Path("/root/hummingbot-stack/upstream/hummingbot-api") + + +def _load_module_ast(relative_path: str) -> ast.Module: + return ast.parse((ROOT / relative_path).read_text()) + + +def _find_class(module: ast.Module, class_name: str) -> ast.ClassDef: + for node in module.body: + if isinstance(node, ast.ClassDef) and node.name == class_name: + return node + raise AssertionError(f"class {class_name} not found") + + +def _find_method(class_node: ast.ClassDef, method_name: str) -> ast.FunctionDef: + for node in class_node.body: + if isinstance(node, ast.FunctionDef) and node.name == method_name: + return node + raise AssertionError(f"method {method_name} not found in {class_node.name}") + + +def _find_async_method(class_node: ast.ClassDef, method_name: str) -> ast.AsyncFunctionDef: + for node in class_node.body: + if isinstance(node, ast.AsyncFunctionDef) and node.name == method_name: + return node + raise AssertionError(f"async method {method_name} not found in {class_node.name}") + + +def _call_name(call: ast.Call) -> str | None: + func = call.func + if isinstance(func, ast.Attribute): + return func.attr + if isinstance(func, ast.Name): + return func.id + return None + + +def _body_contains_timestamp_then_delegate(method_node: ast.AST, delegate_name: str) -> bool: + seen_set_timestamp = False + for node in ast.walk(method_node): + if not isinstance(node, ast.Call): + continue + name = _call_name(node) + if name == "_set_current_timestamp": + seen_set_timestamp = True + if name == delegate_name and seen_set_timestamp: + return True + return False + + +def test_accounts_interface_buy_sets_timestamp_before_buy_delegate(): + module = _load_module_ast("services/accounts_service.py") + class_node = _find_class(module, "AccountTradingInterface") + method_node = _find_method(class_node, "buy") + assert _body_contains_timestamp_then_delegate(method_node, "buy") + + +def test_accounts_interface_sell_sets_timestamp_before_sell_delegate(): + module = _load_module_ast("services/accounts_service.py") + class_node = _find_class(module, "AccountTradingInterface") + method_node = _find_method(class_node, "sell") + assert _body_contains_timestamp_then_delegate(method_node, "sell") + + +def test_accounts_service_place_trade_sets_timestamp_before_order_submission(): + module = _load_module_ast("services/accounts_service.py") + class_node = _find_class(module, "AccountsService") + method_node = _find_async_method(class_node, "place_trade") + assert _body_contains_timestamp_then_delegate(method_node, "buy") + assert _body_contains_timestamp_then_delegate(method_node, "sell") + + +def test_trading_interface_buy_sets_timestamp_before_buy_delegate(): + module = _load_module_ast("services/trading_service.py") + class_node = _find_class(module, "AccountTradingInterface") + method_node = _find_method(class_node, "buy") + assert _body_contains_timestamp_then_delegate(method_node, "buy") + + +def test_trading_interface_sell_sets_timestamp_before_sell_delegate(): + module = _load_module_ast("services/trading_service.py") + class_node = _find_class(module, "AccountTradingInterface") + method_node = _find_method(class_node, "sell") + assert _body_contains_timestamp_then_delegate(method_node, "sell") diff --git a/test/test_trading_router.py b/test/test_trading_router.py new file mode 100644 index 00000000..fcc8670a --- /dev/null +++ b/test/test_trading_router.py @@ -0,0 +1,76 @@ +import ast +import math +import sys +from decimal import Decimal +from enum import Enum +from pathlib import Path +from types import SimpleNamespace + + +SOURCE_PATH = Path("/root/hummingbot-stack/upstream/hummingbot-api/routers/trading.py") + + +class OrderState(Enum): + PENDING_CREATE = 0 + OPEN = 1 + PENDING_CANCEL = 2 + CANCELED = 3 + PARTIALLY_FILLED = 4 + FILLED = 5 + FAILED = 6 + PENDING_APPROVAL = 7 + APPROVED = 8 + CREATED = 9 + COMPLETED = 10 + + +def _load_standardizer(): + source = SOURCE_PATH.read_text() + module = ast.parse(source) + target = None + for node in module.body: + if isinstance(node, ast.FunctionDef) and node.name == "_standardize_in_flight_order_response": + target = node + break + if target is None: + raise AssertionError("standardizer function not found") + + func_source = ast.get_source_segment(source, target) + fake_module = SimpleNamespace(OrderState=OrderState) + sys.modules.setdefault("hummingbot", SimpleNamespace()) + sys.modules.setdefault("hummingbot.core", SimpleNamespace()) + sys.modules.setdefault("hummingbot.core.data_type", SimpleNamespace()) + sys.modules["hummingbot.core.data_type.in_flight_order"] = fake_module + namespace = {"math": math} + exec(func_source, namespace) + return namespace["_standardize_in_flight_order_response"] + + +def test_standardize_in_flight_order_response_handles_nan_timestamps(): + standardize = _load_standardizer() + order = SimpleNamespace( + client_order_id="cid-1", + trading_pair="BTC-USD", + trade_type=SimpleNamespace(name="BUY"), + order_type=SimpleNamespace(name="LIMIT"), + amount=Decimal("0.02"), + price=Decimal("40000"), + current_state=OrderState.PENDING_CREATE, + executed_amount_base=Decimal("0"), + last_executed_price=None, + cumulative_fee_paid_quote=None, + creation_timestamp=float("nan"), + last_update_timestamp=float("nan"), + exchange_order_id=None, + ) + + result = standardize( + order=order, + account_name="master_account", + connector_name="hyperliquid_perpetual_testnet", + ) + + assert result["created_at"] is None + assert result["updated_at"] is None + assert result["price"] == 40000.0 + assert result["filled_amount"] == 0 diff --git a/services/bot_archiver.py b/utils/bot_archiver.py similarity index 100% rename from services/bot_archiver.py rename to utils/bot_archiver.py diff --git a/utils/executor_log_capture.py b/utils/executor_log_capture.py new file mode 100644 index 00000000..d38093f0 --- /dev/null +++ b/utils/executor_log_capture.py @@ -0,0 +1,155 @@ +""" +Executor log capture via in-memory ring buffer. + +Uses Python's contextvars to attribute log records to specific executor instances, +even though executors share class-level loggers. When executor.start() creates an +asyncio Task, the Task inherits the current context - so a ContextVar set before +start() persists for that executor's entire lifetime. +""" +import logging +import traceback +from collections import deque +from contextvars import ContextVar +from datetime import datetime, timezone +from typing import Dict, List, Optional + +# ContextVar that identifies which executor is running in the current async task. +# Set before executor.start() so the spawned Task inherits it. +current_executor_id: ContextVar[Optional[str]] = ContextVar("current_executor_id", default=None) + + +class ExecutorLogHandler(logging.Handler): + """ + Custom logging handler that routes log records to per-executor ring buffers. + + Reads current_executor_id from contextvars to determine which executor + produced the log record. Unattributed records go to a global buffer. + """ + + def __init__(self, capture: "ExecutorLogCapture"): + super().__init__() + self._capture = capture + + def emit(self, record: logging.LogRecord): + try: + entry = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "message": self.format(record), + } + + if record.exc_info and record.exc_info[1] is not None: + entry["exc_info"] = "".join(traceback.format_exception(*record.exc_info)) + + executor_id = current_executor_id.get() + if executor_id is not None: + self._capture._append_log(executor_id, entry) + else: + self._capture._append_global(entry) + except Exception: + self.handleError(record) + + +class ExecutorLogCapture: + """ + Singleton-style class that manages per-executor log ring buffers. + + Usage: + capture = ExecutorLogCapture() + capture.install() # attaches handler to executor loggers + + # Before executor.start(): + token = current_executor_id.set(executor_id) + executor.start() + current_executor_id.reset(token) + + # Later: + logs = capture.get_logs(executor_id) + """ + + def __init__(self, per_executor_max: int = 50, global_max: int = 200): + self._per_executor_max = per_executor_max + self._global_max = global_max + self._logs: Dict[str, deque] = {} + self._global_logs: deque = deque(maxlen=global_max) + self._handler: Optional[ExecutorLogHandler] = None + + def install(self): + """Attach the log handler to the hummingbot executor logger hierarchy.""" + if self._handler is not None: + return + + self._handler = ExecutorLogHandler(self) + self._handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)s - %(message)s") + self._handler.setFormatter(formatter) + + # Attach to the parent logger for all executors + logger = logging.getLogger("hummingbot.strategy_v2.executors") + logger.setLevel(logging.INFO) + logger.addHandler(self._handler) + + def uninstall(self): + """Remove the log handler.""" + if self._handler is None: + return + + logger = logging.getLogger("hummingbot.strategy_v2.executors") + logger.removeHandler(self._handler) + self._handler = None + + def _append_log(self, executor_id: str, entry: dict): + if executor_id not in self._logs: + self._logs[executor_id] = deque(maxlen=self._per_executor_max) + self._logs[executor_id].append(entry) + + def _append_global(self, entry: dict): + self._global_logs.append(entry) + + def get_logs( + self, + executor_id: str, + level: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[dict]: + """Get log entries for a specific executor.""" + buf = self._logs.get(executor_id) + if buf is None: + return [] + + logs = list(buf) + if level: + level_upper = level.upper() + logs = [e for e in logs if e["level"] == level_upper] + if limit: + logs = logs[-limit:] + return logs + + def get_error_count(self, executor_id: str) -> int: + """Get count of ERROR-level logs for an executor.""" + buf = self._logs.get(executor_id) + if buf is None: + return 0 + return sum(1 for e in buf if e["level"] == "ERROR") + + def get_last_error(self, executor_id: str) -> Optional[str]: + """Get the most recent ERROR message for an executor, or None.""" + buf = self._logs.get(executor_id) + if buf is None: + return None + for entry in reversed(buf): + if entry["level"] == "ERROR": + return entry["message"] + return None + + def get_global_logs(self, level: Optional[str] = None) -> List[dict]: + """Get unattributed (global) log entries.""" + logs = list(self._global_logs) + if level: + level_upper = level.upper() + logs = [e for e in logs if e["level"] == level_upper] + return logs + + def clear(self, executor_id: str): + """Remove logs for a specific executor.""" + self._logs.pop(executor_id, None) diff --git a/utils/file_system.py b/utils/file_system.py index d93bfece..c88a39a2 100644 --- a/utils/file_system.py +++ b/utils/file_system.py @@ -5,36 +5,66 @@ import shutil import sys from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Type import yaml from hummingbot.client.config.config_data_types import BaseClientModel from hummingbot.client.config.config_helpers import ClientConfigAdapter +from hummingbot.strategy_v2.controllers.controller_base import ControllerConfigBase +from hummingbot.strategy_v2.controllers.directional_trading_controller_base import DirectionalTradingControllerConfigBase +from hummingbot.strategy_v2.controllers.market_making_controller_base import MarketMakingControllerConfigBase + +# Create module-specific logger +logger = logging.getLogger(__name__) class FileSystemUtil: """ FileSystemUtil provides utility functions for file and directory management, as well as dynamic loading of script configurations. + + All file operations are performed relative to the base_path unless an absolute path is provided. + Implements singleton pattern to ensure the same instance is reused. """ + _instance = None base_path: str = "bots" # Default base path + def __new__(cls, base_path: Optional[str] = None): + if cls._instance is None: + cls._instance = super(FileSystemUtil, cls).__new__(cls) + cls._instance.base_path = base_path if base_path else "bots" + return cls._instance + def __init__(self, base_path: Optional[str] = None): """ Initializes the FileSystemUtil with a base path. :param base_path: The base directory path for file operations. """ - if base_path: - self.base_path = base_path + # Singleton pattern - instance already configured in __new__ + pass + + def _get_full_path(self, path: str) -> str: + """ + Get the full path by combining base_path with relative path. + :param path: Relative or absolute path. + :return: Full absolute path. + """ + return path if os.path.isabs(path) else os.path.join(self.base_path, path) def list_files(self, directory: str) -> List[str]: """ Lists all files in a given directory. :param directory: The directory to list files from. :return: List of file names in the directory. + :raises FileNotFoundError: If the directory does not exist. + :raises PermissionError: If access is denied to the directory. """ excluded_files = ["__init__.py", "__pycache__", ".DS_Store", ".dockerignore", ".gitignore"] - dir_path = os.path.join(self.base_path, directory) + dir_path = self._get_full_path(directory) + if not os.path.exists(dir_path): + raise FileNotFoundError(f"Directory '{directory}' not found") + if not os.path.isdir(dir_path): + raise NotADirectoryError(f"Path '{directory}' is not a directory") return [f for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f)) and f not in excluded_files] def list_folders(self, directory: str) -> List[str]: @@ -42,62 +72,97 @@ def list_folders(self, directory: str) -> List[str]: Lists all folders in a given directory. :param directory: The directory to list folders from. :return: List of folder names in the directory. + :raises FileNotFoundError: If the directory does not exist. + :raises PermissionError: If access is denied to the directory. """ - dir_path = os.path.join(self.base_path, directory) + dir_path = self._get_full_path(directory) + if not os.path.exists(dir_path): + raise FileNotFoundError(f"Directory '{directory}' not found") + if not os.path.isdir(dir_path): + raise NotADirectoryError(f"Path '{directory}' is not a directory") return [d for d in os.listdir(dir_path) if os.path.isdir(os.path.join(dir_path, d))] - def create_folder(self, directory: str, folder_name: str): + def create_folder(self, directory: str, folder_name: str) -> None: """ Creates a folder in a specified directory. :param directory: The directory to create the folder in. :param folder_name: The name of the folder to be created. + :raises PermissionError: If permission is denied to create the folder. + :raises OSError: If there's an OS-level error creating the folder. """ - folder_path = os.path.join(self.base_path, directory, folder_name) + if not folder_name or '/' in folder_name or '\\' in folder_name: + raise ValueError(f"Invalid folder name: '{folder_name}'") + folder_path = self._get_full_path(os.path.join(directory, folder_name)) os.makedirs(folder_path, exist_ok=True) - def copy_folder(self, src: str, dest: str): + def copy_folder(self, src: str, dest: str) -> None: """ Copies a folder to a new location. :param src: The source folder to copy. :param dest: The destination folder to copy to. + :raises FileNotFoundError: If source folder doesn't exist. + :raises PermissionError: If permission is denied. """ - src_path = os.path.join(self.base_path, src) - dest_path = os.path.join(self.base_path, dest) - os.makedirs(dest_path, exist_ok=True) - for item in os.listdir(src_path): - s = os.path.join(src_path, item) - d = os.path.join(dest_path, item) - if os.path.isdir(s): - self.copy_folder(s, d) - else: - shutil.copy2(s, d) + src_path = self._get_full_path(src) + dest_path = self._get_full_path(dest) + + if not os.path.exists(src_path): + raise FileNotFoundError(f"Source folder '{src}' not found") + if not os.path.isdir(src_path): + raise NotADirectoryError(f"Source path '{src}' is not a directory") + + shutil.copytree(src_path, dest_path, dirs_exist_ok=True) - def copy_file(self, src: str, dest: str): + def copy_file(self, src: str, dest: str) -> None: """ Copies a file to a new location. :param src: The source file to copy. :param dest: The destination file to copy to. + :raises FileNotFoundError: If source file doesn't exist. + :raises PermissionError: If permission is denied. """ - src_path = os.path.join(self.base_path, src) - dest_path = os.path.join(self.base_path, dest) + src_path = self._get_full_path(src) + dest_path = self._get_full_path(dest) + + if not os.path.exists(src_path): + raise FileNotFoundError(f"Source file '{src}' not found") + if os.path.isdir(src_path): + raise IsADirectoryError(f"Source path '{src}' is a directory, not a file") + + # Ensure destination directory exists + dest_dir = os.path.dirname(dest_path) + os.makedirs(dest_dir, exist_ok=True) + shutil.copy2(src_path, dest_path) - def delete_folder(self, directory: str, folder_name: str): + def delete_folder(self, directory: str, folder_name: str) -> None: """ Deletes a folder in a specified directory. :param directory: The directory to delete the folder from. :param folder_name: The name of the folder to be deleted. + :raises FileNotFoundError: If folder doesn't exist. + :raises PermissionError: If permission is denied. """ - folder_path = os.path.join(self.base_path, directory, folder_name) + folder_path = self._get_full_path(os.path.join(directory, folder_name)) + if not os.path.exists(folder_path): + raise FileNotFoundError(f"Folder '{folder_name}' not found in '{directory}'") + if not os.path.isdir(folder_path): + raise NotADirectoryError(f"Path '{folder_name}' is not a directory") shutil.rmtree(folder_path) - def delete_file(self, directory: str, file_name: str): + def delete_file(self, directory: str, file_name: str) -> None: """ Deletes a file in a specified directory. :param directory: The directory to delete the file from. :param file_name: The name of the file to be deleted. + :raises FileNotFoundError: If file doesn't exist. + :raises PermissionError: If permission is denied. """ - file_path = os.path.join(self.base_path, directory, file_name) + file_path = self._get_full_path(os.path.join(directory, file_name)) + if not os.path.exists(file_path): + raise FileNotFoundError(f"File '{file_name}' not found in '{directory}'") + if os.path.isdir(file_path): + raise IsADirectoryError(f"Path '{file_name}' is a directory, not a file") os.remove(file_path) def path_exists(self, path: str) -> bool: @@ -106,56 +171,101 @@ def path_exists(self, path: str) -> bool: :param path: The path to check. :return: True if the path exists, False otherwise. """ - return os.path.exists(os.path.join(self.base_path, path)) + return os.path.exists(self._get_full_path(path)) - def add_file(self, directory: str, file_name: str, content: str, override: bool = False): + def add_file(self, directory: str, file_name: str, content: str, override: bool = False) -> None: """ Adds a file to a specified directory. :param directory: The directory to add the file to. :param file_name: The name of the file to be added. :param content: The content to be written to the file. :param override: If True, override the file if it exists. + :raises ValueError: If file_name is invalid. + :raises FileExistsError: If file exists and override is False. + :raises PermissionError: If permission is denied to write the file. """ - file_path = os.path.join(self.base_path, directory, file_name) + if not file_name or '/' in file_name or '\\' in file_name: + raise ValueError(f"Invalid file name: '{file_name}'") + + dir_path = self._get_full_path(directory) + os.makedirs(dir_path, exist_ok=True) + + file_path = os.path.join(dir_path, file_name) if not override and os.path.exists(file_path): raise FileExistsError(f"File '{file_name}' already exists in '{directory}'.") - with open(file_path, 'w') as file: + + with open(file_path, 'w', encoding='utf-8') as file: file.write(content) - def append_to_file(self, directory: str, file_name: str, content: str): + def append_to_file(self, directory: str, file_name: str, content: str) -> None: """ Appends content to a specified file. :param directory: The directory containing the file. :param file_name: The name of the file to append to. :param content: The content to append to the file. + :raises FileNotFoundError: If file doesn't exist. + :raises PermissionError: If permission is denied. """ - file_path = os.path.join(self.base_path, directory, file_name) - with open(file_path, 'a') as file: + file_path = self._get_full_path(os.path.join(directory, file_name)) + if not os.path.exists(file_path): + raise FileNotFoundError(f"File '{file_name}' not found in '{directory}'") + if os.path.isdir(file_path): + raise IsADirectoryError(f"Path '{file_name}' is a directory, not a file") + + with open(file_path, 'a', encoding='utf-8') as file: file.write(content) - @staticmethod - def dump_dict_to_yaml(filename, data_dict): + def read_file(self, file_path: str) -> str: + """ + Reads the content of a file. + :param file_path: The relative path to the file from base_path. + :return: The content of the file as a string. + :raises FileNotFoundError: If the file does not exist. + :raises PermissionError: If access is denied to the file. + :raises IsADirectoryError: If the path points to a directory. + """ + full_path = self._get_full_path(file_path) + if not os.path.exists(full_path): + raise FileNotFoundError(f"File '{file_path}' not found") + if os.path.isdir(full_path): + raise IsADirectoryError(f"Path '{file_path}' is a directory, not a file") + + with open(full_path, 'r', encoding='utf-8') as file: + return file.read() + + def dump_dict_to_yaml(self, filename: str, data_dict: dict) -> None: """ Dumps a dictionary to a YAML file. + :param filename: The file to dump the dictionary into (relative to base_path). :param data_dict: The dictionary to dump. - :param filename: The file to dump the dictionary into. + :raises PermissionError: If permission is denied to write the file. """ - with open(filename, 'w') as file: - yaml.dump(data_dict, file) + file_path = self._get_full_path(filename) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, 'w', encoding='utf-8') as file: + yaml.dump(data_dict, file, default_flow_style=False, allow_unicode=True) - @staticmethod - def read_yaml_file(file_path): + def read_yaml_file(self, file_path: str) -> dict: """ Reads a YAML file and returns the data as a dictionary. - :param file_path: The path to the YAML file. + :param file_path: The path to the YAML file (relative to base_path or absolute). :return: Dictionary containing the YAML file data. + :raises FileNotFoundError: If the file doesn't exist. + :raises yaml.YAMLError: If the YAML is invalid. """ - with open(file_path, 'r') as file: - data = yaml.safe_load(file) - return data + full_path = self._get_full_path(file_path) if not os.path.isabs(file_path) else file_path + if not os.path.exists(full_path): + raise FileNotFoundError(f"YAML file '{file_path}' not found") + + with open(full_path, 'r', encoding='utf-8') as file: + try: + data = yaml.safe_load(file) + return data if data is not None else {} + except yaml.YAMLError as e: + raise yaml.YAMLError(f"Invalid YAML in file '{file_path}': {e}") @staticmethod - def load_script_config_class(script_name): + def load_script_config_class(script_name: str) -> Optional[Type[BaseClientModel]]: """ Dynamically loads a script's configuration class. :param script_name: The name of the script file (without the '.py' extension). @@ -173,30 +283,188 @@ def load_script_config_class(script_name): for _, cls in inspect.getmembers(script_module, inspect.isclass): if issubclass(cls, BaseClientModel) and cls is not BaseClientModel: return cls - except Exception as e: - print(f"Error loading script class: {e}") # Handle or log the error appropriately + except (ImportError, AttributeError, ModuleNotFoundError) as e: + logger.warning(f"Error loading script class for '{script_name}': {e}") return None @staticmethod - def ensure_file_and_dump_text(file_path, text): + def load_controller_config_class(controller_type: str, controller_name: str) -> Optional[Type]: """ - Ensures that the directory for the file exists, then dumps the dictionary to a YAML file. - :param file_path: The file path to dump the dictionary into. - :param text: The text to dump. + Dynamically loads a controller's configuration class. + Supports both single-file controllers (controller.py) and + package-style controllers (controller/controller.py). + :param controller_type: The type of the controller. + :param controller_name: The name of the controller file (without the '.py' extension). + :return: The configuration class from the controller, or None if not found. """ - os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, "w") as f: + controller_name = controller_name.replace('.py', '') + + # Try single-file first: bots.controllers.{type}.{name} + # Then package-style: bots.controllers.{type}.{name}.{name} + module_paths = [ + f"bots.controllers.{controller_type}.{controller_name}", + f"bots.controllers.{controller_type}.{controller_name}.{controller_name}", + ] + + for module_name in module_paths: + try: + if module_name not in sys.modules: + script_module = importlib.import_module(module_name) + else: + script_module = importlib.reload(sys.modules[module_name]) + + # Find the subclass of controller config base in the module + for _, cls in inspect.getmembers(script_module, inspect.isclass): + is_directional = (issubclass(cls, DirectionalTradingControllerConfigBase) + and cls is not DirectionalTradingControllerConfigBase) + is_market_making = (issubclass(cls, MarketMakingControllerConfigBase) + and cls is not MarketMakingControllerConfigBase) + is_generic = (issubclass(cls, ControllerConfigBase) + and cls is not ControllerConfigBase) + if is_directional or is_market_making or is_generic: + return cls + except (ImportError, AttributeError, ModuleNotFoundError): + continue + + logger.warning(f"Could not load controller class for '{controller_type}.{controller_name}'") + return None + + def ensure_file_and_dump_text(self, file_path: str, text: str) -> None: + """ + Ensures that the directory for the file exists, then writes text to a file. + :param file_path: The file path to write to (relative to base_path or absolute). + :param text: The text to write. + :raises PermissionError: If permission is denied. + """ + full_path = self._get_full_path(file_path) if not os.path.isabs(file_path) else file_path + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding='utf-8') as f: f.write(text) - @staticmethod - # TODO: make paths relative - def get_connector_keys_path(account_name: str, connector_name: str) -> Path: - return Path(f"bots/credentials/{account_name}/connectors/{connector_name}.yml") + def get_connector_keys_path(self, account_name: str, connector_name: str) -> Path: + """ + Get the path to connector credentials file. + :param account_name: Name of the account. + :param connector_name: Name of the connector. + :return: Path to the connector credentials file. + """ + return Path("credentials") / account_name / "connectors" / f"{connector_name}.yml" - def save_model_to_yml(yml_path: Path, cm: ClientConfigAdapter): + def save_model_to_yml(self, yml_path: str, cm: ClientConfigAdapter) -> None: + """ + Save a ClientConfigAdapter model to a YAML file. + :param yml_path: Path to the YAML file (relative to base_path or absolute). + :param cm: The ClientConfigAdapter to save. + :raises PermissionError: If permission is denied to write the file. + """ try: + full_path = self._get_full_path(yml_path) cm_yml_str = cm.generate_yml_output_str_with_comments() - with open(yml_path, "w", encoding="utf-8") as outfile: + os.makedirs(os.path.dirname(full_path), exist_ok=True) + with open(full_path, "w", encoding="utf-8") as outfile: outfile.write(cm_yml_str) except Exception as e: - logging.error("Error writing configs: %s" % (str(e),), exc_info=True) + logger.error(f"Error writing configs to '{yml_path}': {e}", exc_info=True) + raise + + def get_base_path(self) -> str: + """ + Returns the base path for file operations + :return: The base path string + """ + return self.base_path + + def get_directory_creation_time(self, path): + """ + Get the creation time of a directory + :param path: The path to the directory + :return: ISO formatted creation time string or None if directory doesn't exist + """ + import datetime + import os + + full_path = self._get_full_path(path) + if not os.path.exists(full_path): + return None + + # Get creation time (platform dependent) + try: + # For Unix systems, use stat + creation_time = os.stat(full_path).st_ctime + # Convert to datetime + return datetime.datetime.fromtimestamp(creation_time).isoformat() + except Exception: + # Fallback + return "unknown" + + def list_directories(self, path): + """ + List all directories within a given path + :param path: The path to list directories from + :return: List of directory names + """ + import os + + full_path = self._get_full_path(path) + if not os.path.exists(full_path): + return [] + + try: + # Return only directories + return [d for d in os.listdir(full_path) if os.path.isdir(os.path.join(full_path, d))] + except Exception: + return [] + + def list_databases(self) -> List[str]: + """ + Lists all database files in archived instances + :return: List of database file paths + """ + try: + archived_instances = self.list_folders("archived") + except FileNotFoundError: + return [] + + archived_databases = [] + for archived_instance in archived_instances: + db_path = self._get_full_path(os.path.join("archived", archived_instance, "data")) + try: + if os.path.exists(db_path): + archived_databases.extend([ + os.path.join(db_path, db_file) + for db_file in os.listdir(db_path) + if db_file.endswith(".sqlite") + ]) + except (OSError, PermissionError) as e: + logger.warning(f"Error accessing database path '{db_path}': {e}") + return archived_databases + + def list_checkpoints(self, full_path: bool = False) -> List[str]: + """ + Lists all checkpoint database files + :param full_path: If True, return full paths, otherwise just filenames + :return: List of checkpoint database files + """ + dir_path = self._get_full_path("data") + if not os.path.exists(dir_path): + return [] + + try: + files = os.listdir(dir_path) + checkpoint_files = [ + f for f in files + if (os.path.isfile(os.path.join(dir_path, f)) + and f.startswith("checkpoint") + and f.endswith(".sqlite")) + ] + + if full_path: + return [os.path.join(dir_path, f) for f in checkpoint_files] + else: + return checkpoint_files + except (OSError, PermissionError) as e: + logger.warning(f"Error listing checkpoints in '{dir_path}': {e}") + return [] + + +fs_util = FileSystemUtil() diff --git a/utils/models.py b/utils/hummingbot_api_config_adapter.py similarity index 95% rename from utils/models.py rename to utils/hummingbot_api_config_adapter.py index 1ecb977b..4dce67a3 100644 --- a/utils/models.py +++ b/utils/hummingbot_api_config_adapter.py @@ -4,11 +4,11 @@ from pydantic import SecretStr -class BackendAPIConfigAdapter(ClientConfigAdapter): +class HummingbotAPIConfigAdapter(ClientConfigAdapter): def _encrypt_secrets(self, conf_dict: Dict[str, Any]): from utils.security import BackendAPISecurity for attr, value in conf_dict.items(): - attr_type = self._hb_config.__fields__[attr].type_ + attr_type = self._hb_config.model_fields[attr].annotation if attr_type == SecretStr: clear_text_value = value.get_secret_value() if isinstance(value, SecretStr) else value conf_dict[attr] = BackendAPISecurity.secrets_manager.encrypt_secret_value(attr, clear_text_value) diff --git a/utils/hummingbot_database_reader.py b/utils/hummingbot_database_reader.py new file mode 100644 index 00000000..6a57d260 --- /dev/null +++ b/utils/hummingbot_database_reader.py @@ -0,0 +1,308 @@ +import os +import pandas as pd +import json +from typing import List, Dict, Any + +from hummingbot.core.data_type.common import TradeType +from hummingbot.strategy_v2.models.base import RunnableStatus +from hummingbot.strategy_v2.models.executors import CloseType +from hummingbot.strategy_v2.models.executors_info import ExecutorInfo +from sqlalchemy import create_engine, insert, text, MetaData, Table, Column, VARCHAR, INT, FLOAT, Integer, String, Float +from sqlalchemy.orm import sessionmaker + + +class HummingbotDatabase: + def __init__(self, db_path: str): + self.db_name = os.path.basename(db_path) + self.db_path = db_path + self.db_path = f'sqlite:///{os.path.join(db_path)}' + self.engine = create_engine(self.db_path, connect_args={'check_same_thread': False}) + self.session_maker = sessionmaker(bind=self.engine) + + @staticmethod + def _get_table_status(table_loader): + try: + data = table_loader() + return "Correct" if len(data) > 0 else f"Error - No records matched" + except Exception as e: + return f"Error - {str(e)}" + + @property + def status(self): + trade_fill_status = self._get_table_status(self.get_trade_fills) + orders_status = self._get_table_status(self.get_orders) + order_status_status = self._get_table_status(self.get_order_status) + executors_status = self._get_table_status(self.get_executors_data) + controller_status = self._get_table_status(self.get_controllers_data) + positions_status = self._get_table_status(self.get_positions) + general_status = all(status == "Correct" for status in + [trade_fill_status, orders_status, order_status_status, executors_status, controller_status, positions_status]) + status = {"db_name": self.db_name, + "db_path": self.db_path, + "trade_fill": trade_fill_status, + "orders": orders_status, + "order_status": order_status_status, + "executors": executors_status, + "controllers": controller_status, + "positions": positions_status, + "general_status": general_status + } + return status + + def get_orders(self): + with self.session_maker() as session: + query = "SELECT * FROM 'Order'" + orders = pd.read_sql_query(text(query), session.connection()) + orders["amount"] = orders["amount"] / 1e6 + orders["price"] = orders["price"] / 1e6 + orders.rename(columns={"market": "connector_name", "symbol": "trading_pair"}, inplace=True) + return orders + + def get_trade_fills(self): + groupers = ["config_file_path", "connector_name", "trading_pair"] + float_cols = ["amount", "price", "trade_fee_in_quote"] + with self.session_maker() as session: + query = "SELECT * FROM TradeFill" + trade_fills = pd.read_sql_query(text(query), session.connection()) + trade_fills.rename(columns={"market": "connector_name", "symbol": "trading_pair"}, inplace=True) + trade_fills[float_cols] = trade_fills[float_cols] / 1e6 + trade_fills["cum_fees_in_quote"] = trade_fills.groupby(groupers)["trade_fee_in_quote"].cumsum() + trade_fills["trade_fee"] = trade_fills.groupby(groupers)["cum_fees_in_quote"].diff() + return trade_fills + + def get_order_status(self): + with self.session_maker() as session: + query = "SELECT * FROM OrderStatus" + order_status = pd.read_sql_query(text(query), session.connection()) + return order_status + + def get_executors_data(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Executors" + executors = pd.read_sql_query(text(query), session.connection()) + return executors + + def get_controllers_data(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Controllers" + controllers = pd.read_sql_query(text(query), session.connection()) + return controllers + + def get_positions(self) -> pd.DataFrame: + with self.session_maker() as session: + query = "SELECT * FROM Position" + positions = pd.read_sql_query(text(query), session.connection()) + # Convert decimal fields from stored format (divide by 1e6) + decimal_cols = ["volume_traded_quote", "amount", "breakeven_price", "unrealized_pnl_quote", "cum_fees_quote"] + positions[decimal_cols] = positions[decimal_cols] / 1e6 + return positions + + def calculate_trade_based_performance(self) -> pd.DataFrame: + """ + Calculate trade-based performance metrics using vectorized pandas operations. + + Returns: + DataFrame with rolling performance metrics calculated per trading pair. + """ + # Get trade fills data + trades = self.get_trade_fills() + + if len(trades) == 0: + return pd.DataFrame() + + # Sort by timestamp to ensure proper rolling calculation + trades = trades.sort_values(['trading_pair', 'connector_name', 'timestamp']).copy() + + # Create buy/sell indicator columns + trades['is_buy'] = (trades['trade_type'].str.upper() == 'BUY').astype(int) + trades['is_sell'] = (trades['trade_type'].str.upper() == 'SELL').astype(int) + + # Calculate buy and sell amounts and values vectorized + trades['buy_amount'] = trades['amount'] * trades['is_buy'] + trades['sell_amount'] = trades['amount'] * trades['is_sell'] + trades['buy_value'] = trades['price'] * trades['amount'] * trades['is_buy'] + trades['sell_value'] = trades['price'] * trades['amount'] * trades['is_sell'] + + # Group by trading_pair and connector_name for rolling calculations + grouper = ['trading_pair', 'connector_name'] + + # Calculate cumulative volumes and values + trades['buy_volume'] = trades.groupby(grouper)['buy_amount'].cumsum() + trades['sell_volume'] = trades.groupby(grouper)['sell_amount'].cumsum() + trades['buy_value_cum'] = trades.groupby(grouper)['buy_value'].cumsum() + trades['sell_value_cum'] = trades.groupby(grouper)['sell_value'].cumsum() + + # Calculate average prices (avoid division by zero) + trades['buy_avg_price'] = trades['buy_value_cum'] / trades['buy_volume'].replace(0, pd.NA) + trades['sell_avg_price'] = trades['sell_value_cum'] / trades['sell_volume'].replace(0, pd.NA) + + # Forward fill average prices within each group to handle NaN values + trades['buy_avg_price'] = trades.groupby(grouper)['buy_avg_price'].ffill().fillna(0) + trades['sell_avg_price'] = trades.groupby(grouper)['sell_avg_price'].ffill().fillna(0) + + # Calculate net position + trades['net_position'] = trades['buy_volume'] - trades['sell_volume'] + + # Calculate realized PnL + trades['realized_trade_pnl_pct'] = ( + (trades['sell_avg_price'] - trades['buy_avg_price']) / trades['buy_avg_price'] + ).fillna(0) + + # Matched volume for realized PnL (minimum of buy and sell volumes) + trades['matched_volume'] = pd.concat([trades['buy_volume'], trades['sell_volume']], axis=1).min(axis=1) + trades['realized_trade_pnl_quote'] = trades['realized_trade_pnl_pct'] * trades['matched_volume'] * trades['buy_avg_price'] + + # Calculate unrealized PnL based on position direction + # For long positions (net_position > 0): use current price vs buy_avg_price + # For short positions (net_position < 0): use sell_avg_price vs current price + trades['unrealized_trade_pnl_pct'] = 0.0 + + # Long positions + long_mask = trades['net_position'] > 0 + trades.loc[long_mask, 'unrealized_trade_pnl_pct'] = ( + (trades.loc[long_mask, 'price'] - trades.loc[long_mask, 'buy_avg_price']) / + trades.loc[long_mask, 'buy_avg_price'] + ).fillna(0) + + # Short positions + short_mask = trades['net_position'] < 0 + trades.loc[short_mask, 'unrealized_trade_pnl_pct'] = ( + (trades.loc[short_mask, 'sell_avg_price'] - trades.loc[short_mask, 'price']) / + trades.loc[short_mask, 'sell_avg_price'] + ).fillna(0) + + # Calculate unrealized PnL in quote currency + trades['unrealized_trade_pnl_quote'] = 0.0 + + # Long positions: use buy_avg_price as reference + long_mask = trades['net_position'] > 0 + trades.loc[long_mask, 'unrealized_trade_pnl_quote'] = ( + trades.loc[long_mask, 'unrealized_trade_pnl_pct'] * + trades.loc[long_mask, 'net_position'].abs() * + trades.loc[long_mask, 'buy_avg_price'] + ) + + # Short positions: use sell_avg_price as reference + short_mask = trades['net_position'] < 0 + trades.loc[short_mask, 'unrealized_trade_pnl_quote'] = ( + trades.loc[short_mask, 'unrealized_trade_pnl_pct'] * + trades.loc[short_mask, 'net_position'].abs() * + trades.loc[short_mask, 'sell_avg_price'] + ) + + # Fees are already in trade_fee_in_quote column + trades['fees_quote'] = trades['trade_fee_in_quote'] + + # Calculate net PnL + trades['net_pnl_quote'] = ( + trades['realized_trade_pnl_quote'] + + trades['unrealized_trade_pnl_quote'] - + trades['fees_quote'] + ) + + # Calculate cumulative volume in quote currency + trades['volume_quote'] = trades['price'] * trades['amount'] + trades['cum_volume_quote'] = trades.groupby(grouper)['volume_quote'].cumsum() + + # Select and return relevant columns + result_columns = [ + 'timestamp', 'price', 'amount', 'trade_type', 'trading_pair', 'connector_name', + 'buy_avg_price', 'buy_volume', 'sell_avg_price', 'sell_volume', + 'net_position', 'realized_trade_pnl_pct', 'realized_trade_pnl_quote', + 'unrealized_trade_pnl_pct', 'unrealized_trade_pnl_quote', + 'fees_quote', 'net_pnl_quote', 'volume_quote', 'cum_volume_quote' + ] + + return trades[result_columns].sort_values('timestamp') + + + +class PerformanceDataSource: + def __init__(self, executors_dict: Dict[str, Any]): + self.executors_dict = executors_dict + + @property + def executors_df(self): + executors = pd.DataFrame(self.executors_dict) + executors["custom_info"] = executors["custom_info"].apply( + lambda x: json.loads(x) if isinstance(x, str) else x) + executors["config"] = executors["config"].apply(lambda x: json.loads(x) if isinstance(x, str) else x) + executors["timestamp"] = executors["timestamp"].apply(lambda x: self.ensure_timestamp_in_seconds(x)) + executors["close_timestamp"] = executors["close_timestamp"].apply( + lambda x: self.ensure_timestamp_in_seconds(x)) + executors["trading_pair"] = executors["config"].apply(lambda x: x["trading_pair"]) + executors["exchange"] = executors["config"].apply(lambda x: x["connector_name"]) + executors["level_id"] = executors["config"].apply(lambda x: x.get("level_id")) + executors["bep"] = executors["custom_info"].apply(lambda x: x["current_position_average_price"]) + executors["order_ids"] = executors["custom_info"].apply(lambda x: x.get("order_ids")) + executors["close_price"] = executors["custom_info"].apply(lambda x: x.get("close_price", x["current_position_average_price"])) + executors["sl"] = executors["config"].apply(lambda x: x.get("stop_loss")).fillna(0) + executors["tp"] = executors["config"].apply(lambda x: x.get("take_profit")).fillna(0) + executors["tl"] = executors["config"].apply(lambda x: x.get("time_limit")).fillna(0) + return executors + + @property + def executor_info_list(self) -> List[ExecutorInfo]: + executors = self.apply_special_data_types(self.executors_df) + executor_values = [] + for index, row in executors.iterrows(): + executor_to_append = ExecutorInfo( + id=row["id"], + timestamp=row["timestamp"], + type=row["type"], + close_timestamp=row["close_timestamp"], + close_type=row["close_type"], + status=row["status"], + config=row["config"], + net_pnl_pct=row["net_pnl_pct"], + net_pnl_quote=row["net_pnl_quote"], + cum_fees_quote=row["cum_fees_quote"], + filled_amount_quote=row["filled_amount_quote"], + is_active=row["is_active"], + is_trading=row["is_trading"], + custom_info=row["custom_info"], + controller_id=row["controller_id"] + ) + executor_to_append.custom_info["side"] = row["side"] + executor_values.append(executor_to_append) + return executor_values + + def apply_special_data_types(self, executors): + executors["status"] = executors["status"].apply(lambda x: self.get_enum_by_value(RunnableStatus, int(x))) + executors["side"] = executors["config"].apply(lambda x: self.get_enum_by_value(TradeType, int(x["side"]))) + executors["close_type"] = executors["close_type"].apply(lambda x: self.get_enum_by_value(CloseType, int(x))) + executors["close_type_name"] = executors["close_type"].apply(lambda x: x.name) + executors["datetime"] = pd.to_datetime(executors.timestamp, unit="s") + executors["close_datetime"] = pd.to_datetime(executors["close_timestamp"], unit="s") + return executors + + @staticmethod + def get_enum_by_value(enum_class, value): + for member in enum_class: + if member.value == value: + return member + raise ValueError(f"No enum member with value {value}") + + @staticmethod + def ensure_timestamp_in_seconds(timestamp: float) -> float: + """ + Ensure the given timestamp is in seconds. + Args: + - timestamp (int): The input timestamp which could be in seconds, milliseconds, or microseconds. + Returns: + - int: The timestamp in seconds. + Raises: + - ValueError: If the timestamp is not in a recognized format. + """ + timestamp_int = int(float(timestamp)) + if timestamp_int >= 1e18: # Nanoseconds + return timestamp_int / 1e9 + elif timestamp_int >= 1e15: # Microseconds + return timestamp_int / 1e6 + elif timestamp_int >= 1e12: # Milliseconds + return timestamp_int / 1e3 + elif timestamp_int >= 1e9: # Seconds + return timestamp_int + else: + raise ValueError( + "Timestamp is not in a recognized format. Must be in seconds, milliseconds, microseconds or nanoseconds.") \ No newline at end of file diff --git a/utils/mqtt_manager.py b/utils/mqtt_manager.py new file mode 100644 index 00000000..3495eadb --- /dev/null +++ b/utils/mqtt_manager.py @@ -0,0 +1,560 @@ +import asyncio +import json +import logging +import time +from collections import defaultdict, deque +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, Optional, Set + +import aiomqtt + +logger = logging.getLogger(__name__) + + +class MQTTManager: + """ + Manages MQTT connections and message handling for Hummingbot bot communication. + Uses asyncio-mqtt (aiomqtt) for asynchronous MQTT operations. + """ + + def __init__(self, host: str, port: int, username: str, password: str): + self.host = host + self.port = port + self.username = username + self.password = password + + # Message handlers by topic pattern + self._handlers: Dict[str, Callable] = {} + + # Bot data storage - stores full controller reports (performance + custom_info) + self._bot_controller_reports: Dict[str, Dict] = defaultdict(dict) + self._bot_logs: Dict[str, deque] = defaultdict(lambda: deque(maxlen=100)) + self._bot_error_logs: Dict[str, deque] = defaultdict(lambda: deque(maxlen=100)) + + # Auto-discovered bots + self._discovered_bots: Dict[str, float] = {} # bot_id: last_seen_timestamp + + # Message deduplication tracking + self._processed_messages: Dict[str, float] = {} # message_hash: timestamp + self._message_ttl = 300 # 5 minutes TTL for processed messages + + # Connection state + self._connected = False + self._reconnect_interval = 5 # seconds + self._client: Optional[aiomqtt.Client] = None + self._tasks: Set[asyncio.Task] = set() + + # RPC response tracking + self._pending_responses: Dict[str, asyncio.Future] = {} # reply_to_topic: future + + # Subscriptions to restore on reconnect + self._subscriptions = [ + ("hbot/+/log", 1), # Log messages + ("hbot/+/notify", 1), # Notifications + ("hbot/+/status_updates", 1), # Status updates + ("hbot/+/events", 1), # Internal events + ("hbot/+/hb", 1), # Heartbeats + ("hbot/+/performance", 1), # Performance metrics + ("hbot/+/external/event/+", 1), # External events + ("hummingbot-api/response/+", 1), # RPC responses to our reply_to topics + ] + + if username: + logger.info(f"MQTT client configured for user: {username}") + else: + logger.info("MQTT client configured without authentication") + + @asynccontextmanager + async def _get_client(self): + """Get MQTT client for a single connection attempt.""" + client_id = f"hummingbot-api-{int(time.time())}" + + # Create client with credentials if provided + if self.username and self.password: + client = aiomqtt.Client( + hostname=self.host, + port=self.port, + username=self.username, + password=self.password, + identifier=client_id, + keepalive=60, + ) + else: + client = aiomqtt.Client(hostname=self.host, port=self.port, identifier=client_id, keepalive=60) + + async with client: + self._connected = True + logger.info(f"✓ Connected to MQTT broker at {self.host}:{self.port}") + + # Subscribe to topics + for topic, qos in self._subscriptions: + await client.subscribe(topic, qos=qos) + yield client + + # Cleanup on exit + self._connected = False + + async def _handle_messages(self): + """Main message handling loop with reconnection.""" + while True: + try: + async with self._get_client() as client: + self._client = client + async for message in client.messages: + await self._process_message(message) + except aiomqtt.MqttError as error: + logger.error(f'MQTT disconnected during message iteration: "{error}". Reconnecting...') + await asyncio.sleep(self._reconnect_interval) + except Exception as e: + logger.error(f"Unexpected error in message handler: {e}. Reconnecting...") + await asyncio.sleep(self._reconnect_interval) + + async def _process_message(self, message): + """Process incoming MQTT message.""" + try: + topic = str(message.topic) + + # Check if this is an RPC response to our hummingbot-api + if topic.startswith("hummingbot-api/response/"): + await self._handle_rpc_response(topic, message) + return + + topic_parts = topic.split("/") + + # Check if this matches namespace/instance_id/channel pattern + if len(topic_parts) >= 3: + namespace, bot_id, channel = topic_parts[0], topic_parts[1], "/".join(topic_parts[2:]) + # Only process if it's the expected namespace + if namespace == "hbot": + # Auto-discover bot + self._discovered_bots[bot_id] = time.time() + # Parse message + try: + data = json.loads(message.payload.decode("utf-8")) + except json.JSONDecodeError: + data = message.payload.decode("utf-8") + + # Route to appropriate handler based on Hummingbot's topics + if channel == "log": + await self._handle_log(bot_id, data) + elif channel == "notify": + await self._handle_notify(bot_id, data) + elif channel == "status_updates": + await self._handle_status(bot_id, data) + elif channel == "hb": # heartbeat + await self._handle_heartbeat(bot_id, data) + elif channel == "events": + await self._handle_events(bot_id, data) + elif channel == "performance": + await self._handle_performance(bot_id, data) + elif channel.startswith("response/"): + await self._handle_command_response(bot_id, channel, data) + elif channel.startswith("external/event/"): + await self._handle_external_event(bot_id, channel, data) + elif channel in ["history", "start", "stop", "config", "import_strategy"]: + # These are command channels - responses should come on response/* topics + logger.debug(f"Command channel '{channel}' for bot {bot_id} - waiting for response") + else: + logger.info(f"Unknown channel '{channel}' for bot {bot_id}") + + # Call custom handlers + for pattern, handler in self._handlers.items(): + if self._match_topic(pattern, topic): + if asyncio.iscoroutinefunction(handler): + await handler(bot_id, channel, data) + else: + # Run sync handler in executor + await asyncio.get_event_loop().run_in_executor(None, handler, bot_id, channel, data) + except Exception as e: + logger.error(f"Error processing message from {message.topic}: {e}", exc_info=True) + + def _match_topic(self, pattern: str, topic: str) -> bool: + """Check if topic matches pattern (supports + wildcard).""" + pattern_parts = pattern.split("/") + topic_parts = topic.split("/") + + if len(pattern_parts) != len(topic_parts): + return False + + for p, t in zip(pattern_parts, topic_parts): + if p != "+" and p != t: + return False + return True + + async def _handle_performance(self, bot_id: str, data: Any): + """Handle performance updates. + + Expected data structure from Hummingbot: + { + "controller_id": { + "performance": { ... performance metrics ... }, + "custom_info": { ... custom controller data ... } + } + } + """ + if isinstance(data, dict): + for controller_id, controller_report in data.items(): + if bot_id not in self._bot_controller_reports: + self._bot_controller_reports[bot_id] = {} + self._bot_controller_reports[bot_id][controller_id] = controller_report + + async def _handle_log(self, bot_id: str, data: Any): + """Handle log messages with deduplication.""" + # Create a unique message identifier for deduplication + if isinstance(data, dict): + level = data.get("level_name") or data.get("levelname") or data.get("level", "INFO") + message = data.get("msg") or data.get("message", "") + timestamp = data.get("timestamp") or data.get("time") or time.time() + + # Create hash for deduplication (bot_id + message + timestamp within 1 second) + message_hash = f"{bot_id}:{message}:{int(timestamp)}" + elif isinstance(data, str): + message = data + timestamp = time.time() + level = "INFO" + + # Create hash for string messages + message_hash = f"{bot_id}:{message}:{int(timestamp)}" + else: + return # Skip invalid data + + # Check for duplicates + current_time = time.time() + if message_hash in self._processed_messages: + # Skip duplicate message + logger.debug(f"Skipping duplicate log message from {bot_id}: {message[:50]}...") + return + + # Clean up old message hashes (older than TTL) + expired_hashes = [h for h, t in self._processed_messages.items() if current_time - t > self._message_ttl] + for h in expired_hashes: + del self._processed_messages[h] + + # Record this message as processed + self._processed_messages[message_hash] = current_time + + # Process the message + if isinstance(data, dict): + # Normalize the log entry + log_entry = { + "level_name": level, + "msg": message, + "timestamp": timestamp, + **data, # Include all original fields + } + + if level.upper() == "ERROR": + self._bot_error_logs[bot_id].append(log_entry) + else: + self._bot_logs[bot_id].append(log_entry) + elif isinstance(data, str): + # Handle plain string logs + log_entry = {"level_name": "INFO", "msg": data, "timestamp": timestamp} + self._bot_logs[bot_id].append(log_entry) + + async def _handle_notify(self, bot_id: str, data: Any): + """Handle notification messages.""" + # Store notifications if needed + + async def _handle_status(self, bot_id: str, data: Any): + """Handle status updates.""" + # Store latest status + + async def _handle_heartbeat(self, bot_id: str, data: Any): + """Handle heartbeat messages.""" + self._discovered_bots[bot_id] = time.time() # Update last seen + + async def _handle_events(self, bot_id: str, data: Any): + """Handle internal events.""" + # Process events as needed + + async def _handle_external_event(self, bot_id: str, channel: str, data: Any): + """Handle external events.""" + event_type = channel.split("/")[-1] + + async def _handle_rpc_response(self, topic: str, message): + """Handle RPC responses on hummingbot-api/response/* topics.""" + try: + # Parse the response data + try: + data = json.loads(message.payload.decode("utf-8")) + except json.JSONDecodeError: + data = message.payload.decode("utf-8") + + # Check if we have a pending response for this topic + if topic in self._pending_responses: + future = self._pending_responses.pop(topic) + if not future.done(): + future.set_result(data) + else: + logger.warning(f"No pending RPC response found for topic: {topic}") + + except Exception as e: + logger.error(f"Error handling RPC response on {topic}: {e}", exc_info=True) + + async def _handle_command_response(self, bot_id: str, channel: str, data: Any): + """Handle command responses (legacy - keeping for backward compatibility).""" + # Extract command from response channel (e.g., response/start/1234567890 or response/history) + channel_parts = channel.split("/") + if len(channel_parts) >= 2: + command = channel_parts[1] + + async def start(self): + """Start the MQTT client.""" + try: + # Create and store the main message handling task + task = asyncio.create_task(self._handle_messages()) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + + logger.info("MQTT client started") + + # Wait a bit for connection to establish + for i in range(10): + if self._connected: + logger.info("MQTT connection established successfully") + break + await asyncio.sleep(0.5) + else: + logger.warning("MQTT connection not established after 5 seconds") + + except Exception as e: + logger.error(f"Failed to start MQTT client: {e}", exc_info=True) + + async def stop(self): + """Stop the MQTT client.""" + self._connected = False + + # Cancel all running tasks + for task in self._tasks: + task.cancel() + + # Wait for all tasks to complete + await asyncio.gather(*self._tasks, return_exceptions=True) + + logger.info("MQTT client stopped") + + async def publish_command_and_wait( + self, bot_id: str, command: str, data: Dict[str, Any], timeout: float = 30.0, qos: int = 1 + ) -> Optional[Any]: + """ + Publish a command to a bot and wait for the response. + + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data + :param timeout: Timeout in seconds to wait for response + :param qos: Quality of Service level + :return: Response data if received, None if timeout or error + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return None + + # Generate unique reply_to topic + timestamp = int(time.time() * 1000) + reply_to_topic = f"hummingbot-api/response/{timestamp}" + + # Create a future to track the response using the reply_to topic as key + future = asyncio.Future() + self._pending_responses[reply_to_topic] = future + + try: + # Send the command with custom reply_to + success = await self._publish_command_with_reply_to(bot_id, command, data, reply_to_topic, qos) + if not success: + self._pending_responses.pop(reply_to_topic, None) + return None + + # Wait for response with timeout + try: + response = await asyncio.wait_for(future, timeout=timeout) + return response + except asyncio.TimeoutError: + logger.warning(f"⏰ Timeout waiting for response from {bot_id} for command '{command}' on {reply_to_topic}") + self._pending_responses.pop(reply_to_topic, None) + return None + + except Exception as e: + logger.error(f"Error sending command and waiting for response: {e}") + self._pending_responses.pop(reply_to_topic, None) + return None + + async def _publish_command_with_reply_to( + self, bot_id: str, command: str, data: Dict[str, Any], reply_to: str, qos: int = 1 + ) -> bool: + """ + Publish a command to a bot with custom reply_to topic. + + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data + :param reply_to: Custom reply_to topic + :param qos: Quality of Service level + :return: True if published successfully + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return False + + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Use the correct topic for each command + topic = f"hbot/{mqtt_bot_id}/{command}" + + # Create the full RPC message structure with custom reply_to + message = { + "header": { + "timestamp": int(time.time() * 1000), # Milliseconds + "reply_to": reply_to, # Custom reply_to topic + "msg_id": int(time.time() * 1000), + "node_id": "hummingbot-api", + "agent": "hummingbot-api", + "properties": {}, + }, + "data": data or {}, + } + + try: + await self._client.publish(topic, payload=json.dumps(message), qos=qos) + return True + except Exception as e: + logger.error(f"Failed to publish command to {bot_id}: {e}") + return False + + async def publish_command(self, bot_id: str, command: str, data: Dict[str, Any], qos: int = 1) -> bool: + """ + Publish a command to a bot using proper RPCMessage Request format. + + :param bot_id: The bot instance ID + :param command: The command to send + :param data: Command data (should match the specific CommandMessage.Request structure) + :param qos: Quality of Service level + :return: True if published successfully + """ + if not self._connected or not self._client: + logger.error("Not connected to MQTT broker") + return False + + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Use the correct topic for each command + topic = f"hbot/{mqtt_bot_id}/{command}" + + # Create the full RPC message structure as expected by commlib + # Based on RPCClient._prepare_request method + message = { + "header": { + "timestamp": int(time.time() * 1000), # Milliseconds + "reply_to": f"hummingbot-api-response-{int(time.time() * 1000)}", # Unique response topic + "msg_id": int(time.time() * 1000), + "node_id": "hummingbot-api", + "agent": "hummingbot-api", + "properties": {}, + }, + "data": data or {}, + } + + try: + await self._client.publish(topic, payload=json.dumps(message), qos=qos) + return True + except Exception as e: + logger.error(f"Failed to publish command to {bot_id}: {e}") + return False + + def add_handler(self, topic_pattern: str, handler: Callable): + """ + Add a custom message handler for a topic pattern. + + :param topic_pattern: Topic pattern (supports + wildcard) + :param handler: Callback function(bot_id, channel, data) - can be sync or async + """ + self._handlers[topic_pattern] = handler + + def remove_handler(self, topic_pattern: str): + """Remove a message handler.""" + self._handlers.pop(topic_pattern, None) + + def get_bot_controller_reports(self, bot_id: str) -> Dict[str, Any]: + """Get controller reports for a bot. + + Returns: + Dict with controller_id as key and report dict as value. + Each report contains 'performance' and 'custom_info' keys. + """ + return self._bot_controller_reports.get(bot_id, {}) + + def get_bot_logs(self, bot_id: str) -> list: + """Get recent logs for a bot.""" + return list(self._bot_logs.get(bot_id, [])) + + def get_bot_error_logs(self, bot_id: str) -> list: + """Get recent error logs for a bot.""" + return list(self._bot_error_logs.get(bot_id, [])) + + def clear_bot_data(self, bot_id: str): + """Clear stored data for a bot.""" + self._bot_controller_reports.pop(bot_id, None) + self._bot_logs.pop(bot_id, None) + self._bot_error_logs.pop(bot_id, None) + self._discovered_bots.pop(bot_id, None) + + def clear_bot_controller_reports(self, bot_id: str): + """Clear only controller report data for a bot (useful when bot is stopped).""" + self._bot_controller_reports.pop(bot_id, None) + + @property + def is_connected(self) -> bool: + """Check if connected to MQTT broker.""" + return self._connected + + def get_discovered_bots(self, timeout_seconds: int = 300) -> list: + """Get list of auto-discovered bots. + + :param timeout_seconds: Consider bots inactive after this many seconds without messages + :return: List of active bot IDs + """ + current_time = time.time() + active_bots = [ + bot_id for bot_id, last_seen in self._discovered_bots.items() if current_time - last_seen < timeout_seconds + ] + return active_bots + + async def subscribe_to_bot(self, bot_id: str): + """Subscribe to all topics for a specific bot.""" + if self._connected and self._client: + # Convert dots to slashes for MQTT topic + mqtt_bot_id = bot_id.replace(".", "/") + + # Subscribe to all topics for this specific bot + topic = f"hbot/{mqtt_bot_id}/#" + await self._client.subscribe(topic, qos=1) + else: + logger.warning(f"Cannot subscribe to bot {bot_id} - not connected to MQTT") + + +if __name__ == "__main__": + # Example usage + import sys + + # For Windows compatibility + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + logging.basicConfig(level=logging.INFO) + + async def main(): + mqtt_manager = MQTTManager(host="localhost", port=1883, username="", password="") + + await mqtt_manager.start() + + try: + # Keep running to listen for messages + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + await mqtt_manager.stop() + + asyncio.run(main()) diff --git a/utils/security.py b/utils/security.py index 6d71bd9e..2395a5b9 100644 --- a/utils/security.py +++ b/utils/security.py @@ -3,7 +3,6 @@ from hummingbot.client.config.config_crypt import PASSWORD_VERIFICATION_WORD, BaseSecretsManager from hummingbot.client.config.config_helpers import ( ClientConfigAdapter, - _load_yml_data_into_map, connector_name_from_file, get_connector_hb_config, read_yml_file, @@ -11,19 +10,20 @@ ) from hummingbot.client.config.security import Security -from config import PASSWORD_VERIFICATION_PATH -from utils.file_system import FileSystemUtil -from utils.models import BackendAPIConfigAdapter +from config import settings +from utils.file_system import fs_util +from utils.hummingbot_api_config_adapter import HummingbotAPIConfigAdapter class BackendAPISecurity(Security): - fs_util = FileSystemUtil(base_path="bots/credentials") - @classmethod def login_account(cls, account_name: str, secrets_manager: BaseSecretsManager) -> bool: if not cls.validate_password(secrets_manager): return False cls.secrets_manager = secrets_manager + # Also set on parent Security class for hummingbot's ClientConfigAdapter methods + # that access Security.secrets_manager directly + Security.secrets_manager = secrets_manager cls.decrypt_all(account_name=account_name) return True @@ -31,10 +31,10 @@ def login_account(cls, account_name: str, secrets_manager: BaseSecretsManager) - def decrypt_all(cls, account_name: str = "master_account"): cls._secure_configs.clear() cls._decryption_done.clear() - encrypted_files = [file for file in cls.fs_util.list_files(directory=f"{account_name}/connectors") if + encrypted_files = [file for file in fs_util.list_files(directory=f"credentials/{account_name}/connectors") if file.endswith(".yml")] for file in encrypted_files: - path = Path(cls.fs_util.base_path + f"/{account_name}/connectors/" + file) + path = Path(fs_util.base_path + f"/credentials/{account_name}/connectors/" + file) cls.decrypt_connector_config(path) cls._decryption_done.set() @@ -44,36 +44,33 @@ def decrypt_connector_config(cls, file_path: Path): cls._secure_configs[connector_name] = cls.load_connector_config_map_from_file(file_path) @classmethod - def load_connector_config_map_from_file(cls, yml_path: Path) -> BackendAPIConfigAdapter: + def load_connector_config_map_from_file(cls, yml_path: Path) -> HummingbotAPIConfigAdapter: config_data = read_yml_file(yml_path) connector_name = connector_name_from_file(yml_path) - hb_config = get_connector_hb_config(connector_name) - config_map = BackendAPIConfigAdapter(hb_config) - _load_yml_data_into_map(config_data, config_map) + hb_config = get_connector_hb_config(connector_name).model_validate(config_data) + config_map = HummingbotAPIConfigAdapter(hb_config) + config_map.decrypt_all_secure_data() return config_map @classmethod def update_connector_keys(cls, account_name: str, connector_config: ClientConfigAdapter): connector_name = connector_config.connector - file_path = cls.fs_util.get_connector_keys_path(account_name=account_name, connector_name=connector_name) + file_path = fs_util.get_connector_keys_path(account_name=account_name, connector_name=connector_name) cm_yml_str = connector_config.generate_yml_output_str_with_comments() - cls.fs_util.ensure_file_and_dump_text(file_path, cm_yml_str) + fs_util.ensure_file_and_dump_text(str(file_path), cm_yml_str) update_connector_hb_config(connector_config) cls._secure_configs[connector_name] = connector_config @staticmethod def new_password_required() -> bool: - return not PASSWORD_VERIFICATION_PATH.exists() - - @staticmethod - def store_password_verification(secrets_manager: BaseSecretsManager): - encrypted_word = secrets_manager.encrypt_secret_value(PASSWORD_VERIFICATION_WORD, PASSWORD_VERIFICATION_WORD) - FileSystemUtil.ensure_file_and_dump_text(PASSWORD_VERIFICATION_PATH, encrypted_word) + full_path = fs_util._get_full_path(settings.app.password_verification_path) + return not Path(full_path).exists() @staticmethod def validate_password(secrets_manager: BaseSecretsManager) -> bool: valid = False - with open(PASSWORD_VERIFICATION_PATH, "r") as f: + full_path = fs_util._get_full_path(settings.app.password_verification_path) + with open(full_path, "r") as f: encrypted_word = f.read() try: decrypted_word = secrets_manager.decrypt_secret_value(PASSWORD_VERIFICATION_WORD, encrypted_word) @@ -82,3 +79,8 @@ def validate_password(secrets_manager: BaseSecretsManager) -> bool: if str(e) != "MAC mismatch": raise e return valid + + @staticmethod + def store_password_verification(secrets_manager: BaseSecretsManager): + encrypted_word = secrets_manager.encrypt_secret_value(PASSWORD_VERIFICATION_WORD, PASSWORD_VERIFICATION_WORD) + fs_util.ensure_file_and_dump_text(settings.app.password_verification_path, encrypted_word)