diff --git a/roboclaw/account/__init__.py b/roboclaw/account/__init__.py new file mode 100644 index 00000000..262c8040 --- /dev/null +++ b/roboclaw/account/__init__.py @@ -0,0 +1,14 @@ +"""Account credit ledger for Evo Studio billing.""" + +from .ledger import AccountLedger, BillingRecord, PaymentOrder, Wallet +from .training_billing import apply_service_fee_cents, estimate_training_hold_cents, hourly_cost_from_params + +__all__ = [ + "AccountLedger", + "BillingRecord", + "PaymentOrder", + "Wallet", + "apply_service_fee_cents", + "estimate_training_hold_cents", + "hourly_cost_from_params", +] diff --git a/roboclaw/account/ledger.py b/roboclaw/account/ledger.py new file mode 100644 index 00000000..2be25e12 --- /dev/null +++ b/roboclaw/account/ledger.py @@ -0,0 +1,663 @@ +"""Small persistent credit ledger for account billing.""" + +from __future__ import annotations + +import json +import threading +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal +from uuid import uuid4 + +LedgerKind = Literal["admin_recharge", "payment_recharge", "dataset_reward", "freeze", "settle", "release"] +PaymentOrderStatus = Literal["pending", "paid", "cancelled"] + + +@dataclass(frozen=True) +class Wallet: + username: str + balance_cents: int = 0 + frozen_cents: int = 0 + reward_points: int = 0 + updated_at: str = "" + + @property + def available_cents(self) -> int: + return self.balance_cents - self.frozen_cents + + def to_dict(self) -> dict[str, Any]: + return { + "username": self.username, + "balanceCents": self.balance_cents, + "frozenBalanceCents": self.frozen_cents, + "availableBalanceCents": self.available_cents, + "creditPoints": self.reward_points, + # Backward-compatible aliases while the app migrates to balance/credit point names. + "creditCents": self.balance_cents, + "frozenCreditCents": self.frozen_cents, + "availableCreditCents": self.available_cents, + "rewardPoints": self.reward_points, + "frozenCents": self.frozen_cents, + "availableCents": self.available_cents, + "updatedAt": self.updated_at, + } + + +@dataclass(frozen=True) +class BillingRecord: + record_id: str + username: str + kind: LedgerKind + amount_cents: int + balance_after_cents: int + frozen_after_cents: int + reward_points_after: int = 0 + reason: str = "" + task_name: str = "" + job_id: str = "" + created_at: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "recordId": self.record_id, + "username": self.username, + "kind": self.kind, + "amountCents": self.amount_cents, + "balanceAfterCents": self.balance_after_cents, + "frozenBalanceAfterCents": self.frozen_after_cents, + "creditPointsAfter": self.reward_points_after, + # Backward-compatible aliases. + "frozenAfterCents": self.frozen_after_cents, + "creditAfterCents": self.balance_after_cents, + "frozenCreditAfterCents": self.frozen_after_cents, + "rewardPointsAfter": self.reward_points_after, + "reason": self.reason, + "taskName": self.task_name, + "jobId": self.job_id, + "createdAt": self.created_at, + } + + +@dataclass(frozen=True) +class PaymentOrder: + order_id: str + username: str + amount_cents: int + bonus_points: int = 0 + provider: str = "mock" + status: PaymentOrderStatus = "pending" + provider_order_id: str = "" + pay_url: str = "" + payee_name: str = "" + payee_account: str = "" + reason: str = "credit topup" + created_at: str = "" + paid_at: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "orderId": self.order_id, + "username": self.username, + "amountCents": self.amount_cents, + "bonusPoints": self.bonus_points, + "provider": self.provider, + "status": self.status, + "providerOrderId": self.provider_order_id, + "payUrl": self.pay_url, + "payeeName": self.payee_name, + "payeeAccount": self.payee_account, + "reason": self.reason, + "createdAt": self.created_at, + "paidAt": self.paid_at, + } + + +class AccountLedger: + """File-backed wallet ledger. + + This is intentionally small and swappable: production can replace it with + MySQL/RDS while preserving route-level semantics. + """ + + def __init__(self, path: Path | None = None) -> None: + self.path = path or Path.home() / ".roboclaw" / "account_ledger.json" + self._lock = threading.Lock() + + def wallet(self, username: str) -> Wallet: + username = _clean_username(username) + with self._lock: + state = self._load() + return self._wallet_from_state(state, username) + + def records(self, username: str = "", *, limit: int = 50) -> list[BillingRecord]: + with self._lock: + state = self._load() + records = [_record_from_payload(item) for item in state.get("records", [])] + if username: + records = [record for record in records if record.username == username] + return records[-max(limit, 0) :][::-1] + + def orders(self, username: str = "", *, limit: int = 50) -> list[PaymentOrder]: + with self._lock: + state = self._load() + orders = [_order_from_payload(item) for item in state.get("paymentOrders", [])] + if username: + orders = [order for order in orders if order.username == username] + return orders[-max(limit, 0) :][::-1] + + def create_topup_order( + self, + username: str, + amount_cents: int, + *, + bonus_points: int = 0, + provider: str = "mock", + payee_name: str = "", + payee_account: str = "", + reason: str = "credit topup", + ) -> PaymentOrder: + if amount_cents <= 0: + raise ValueError("amount_cents must be positive") + if bonus_points < 0: + raise ValueError("bonus_points must be non-negative") + username = _clean_username(username) + provider = (provider or "mock").strip() + if not provider: + raise ValueError("provider is required") + with self._lock: + state = self._load() + order_id = uuid4().hex + order = PaymentOrder( + order_id=order_id, + username=username, + amount_cents=amount_cents, + bonus_points=bonus_points, + provider=provider, + status="pending", + provider_order_id=f"{provider}_{order_id}", + pay_url=f"roboclaw://pay/{provider}/{order_id}", + payee_name=payee_name, + payee_account=payee_account, + reason=reason, + created_at=_now(), + ) + state.setdefault("paymentOrders", []).append(order.to_dict()) + state.setdefault("wallets", {}).setdefault(username, self._wallet_from_state(state, username).to_dict()) + self._save(state) + return order + + def complete_topup_order( + self, + order_id: str, + *, + provider_order_id: str = "", + ) -> tuple[PaymentOrder, Wallet, BillingRecord | None]: + order_id = order_id.strip() + if not order_id: + raise ValueError("order_id is required") + with self._lock: + state = self._load() + orders = state.setdefault("paymentOrders", []) + for index, payload in enumerate(orders): + order = _order_from_payload(payload) + if order.order_id != order_id: + continue + if order.status == "paid": + wallet = self._wallet_from_state(state, order.username) + return order, wallet, None + if order.status != "pending": + raise ValueError(f"cannot complete {order.status} order") + paid_order = PaymentOrder( + order_id=order.order_id, + username=order.username, + amount_cents=order.amount_cents, + bonus_points=order.bonus_points, + provider=order.provider, + status="paid", + provider_order_id=provider_order_id or order.provider_order_id, + pay_url=order.pay_url, + payee_name=order.payee_name, + payee_account=order.payee_account, + reason=order.reason, + created_at=order.created_at, + paid_at=_now(), + ) + wallet = self._wallet_from_state(state, order.username) + wallet = Wallet( + username=wallet.username, + balance_cents=wallet.balance_cents + order.amount_cents, + frozen_cents=wallet.frozen_cents, + reward_points=wallet.reward_points + order.bonus_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "payment_recharge", + order.amount_cents, + reason=f"{order.provider} payment recharge", + job_id=order.order_id, + ) + orders[index] = paid_order.to_dict() + self._save_wallet(state, wallet) + self._save(state) + return paid_order, wallet, record + raise ValueError("payment order not found") + + def grant_dataset_reward( + self, + username: str, + dataset_id: str, + reward_points: int, + *, + reason: str = "dataset upload reward", + ) -> tuple[Wallet, BillingRecord, bool]: + if reward_points <= 0: + raise ValueError("reward_points must be positive") + username = _clean_username(username) + dataset_id = dataset_id.strip() + if not dataset_id: + raise ValueError("dataset_id is required") + with self._lock: + state = self._load() + for payload in state.get("records", []): + record = _record_from_payload(payload) + if record.kind == "dataset_reward" and record.username == username and record.job_id == dataset_id: + return self._wallet_from_state(state, username), record, False + wallet = self._wallet_from_state(state, username) + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents, + frozen_cents=wallet.frozen_cents, + reward_points=wallet.reward_points + reward_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "dataset_reward", + reward_points, + reason=reason, + job_id=dataset_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record, True + + def admin_recharge(self, username: str, amount_cents: int, *, reason: str = "admin recharge") -> tuple[Wallet, BillingRecord]: + if amount_cents <= 0: + raise ValueError("amount_cents must be positive") + username = _clean_username(username) + with self._lock: + state = self._load() + wallet = self._wallet_from_state(state, username) + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents + amount_cents, + frozen_cents=wallet.frozen_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + record = self._append_record(state, wallet, "admin_recharge", amount_cents, reason=reason) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record + + def freeze( + self, + username: str, + amount_cents: int, + *, + reason: str = "freeze credits", + task_name: str = "", + job_id: str = "", + ) -> tuple[Wallet, BillingRecord]: + if amount_cents <= 0: + raise ValueError("amount_cents must be positive") + username = _clean_username(username) + with self._lock: + state = self._load() + wallet = self._wallet_from_state(state, username) + if wallet.available_cents < amount_cents: + raise ValueError("insufficient available balance") + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents, + frozen_cents=wallet.frozen_cents + amount_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "freeze", + amount_cents, + reason=reason, + task_name=task_name, + job_id=job_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record + + def settle( + self, + username: str, + amount_cents: int, + *, + reason: str = "settle credits", + task_name: str = "", + job_id: str = "", + ) -> tuple[Wallet, BillingRecord]: + if amount_cents <= 0: + raise ValueError("amount_cents must be positive") + username = _clean_username(username) + with self._lock: + state = self._load() + wallet = self._wallet_from_state(state, username) + if wallet.frozen_cents < amount_cents: + raise ValueError("settle amount exceeds frozen balance") + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents - amount_cents, + frozen_cents=wallet.frozen_cents - amount_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "settle", + -amount_cents, + reason=reason, + task_name=task_name, + job_id=job_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record + + def release( + self, + username: str, + amount_cents: int, + *, + reason: str = "release frozen balance", + task_name: str = "", + job_id: str = "", + ) -> tuple[Wallet, BillingRecord]: + if amount_cents <= 0: + raise ValueError("amount_cents must be positive") + username = _clean_username(username) + with self._lock: + state = self._load() + wallet = self._wallet_from_state(state, username) + if wallet.frozen_cents < amount_cents: + raise ValueError("release amount exceeds frozen balance") + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents, + frozen_cents=wallet.frozen_cents - amount_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "release", + amount_cents, + reason=reason, + task_name=task_name, + job_id=job_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record + + def settle_job( + self, + username: str, + job_id: str, + charge_cents: int, + *, + reason: str = "settle training job", + task_name: str = "", + ) -> tuple[Wallet, BillingRecord, BillingRecord | None]: + if charge_cents <= 0: + raise ValueError("charge_cents must be positive") + username = _clean_username(username) + job_id = job_id.strip() + if not job_id: + raise ValueError("job_id is required") + with self._lock: + state = self._load() + outstanding = self._job_frozen_cents(state, username=username, job_id=job_id) + if outstanding <= 0: + raise ValueError("job has no frozen balance") + if charge_cents > outstanding: + raise ValueError("charge amount exceeds frozen balance for job") + wallet = self._wallet_from_state(state, username) + if wallet.frozen_cents < charge_cents: + raise ValueError("settle amount exceeds frozen balance") + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents - charge_cents, + frozen_cents=wallet.frozen_cents - charge_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + settle_record = self._append_record( + state, + wallet, + "settle", + -charge_cents, + reason=reason, + task_name=task_name, + job_id=job_id, + ) + release_record: BillingRecord | None = None + remainder = outstanding - charge_cents + if remainder: + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents, + frozen_cents=wallet.frozen_cents - remainder, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + release_record = self._append_record( + state, + wallet, + "release", + remainder, + reason="release unused training balance", + task_name=task_name, + job_id=job_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, settle_record, release_record + + def reassign_job_hold( + self, + username: str, + old_job_id: str, + new_job_id: str, + ) -> BillingRecord: + username = _clean_username(username) + old_job_id = old_job_id.strip() + new_job_id = new_job_id.strip() + if not old_job_id or not new_job_id: + raise ValueError("old_job_id and new_job_id are required") + with self._lock: + state = self._load() + records = state.setdefault("records", []) + for index in range(len(records) - 1, -1, -1): + record = _record_from_payload(records[index]) + if record.username == username and record.job_id == old_job_id and record.kind == "freeze": + updated = BillingRecord( + record_id=record.record_id, + username=record.username, + kind=record.kind, + amount_cents=record.amount_cents, + balance_after_cents=record.balance_after_cents, + frozen_after_cents=record.frozen_after_cents, + reward_points_after=record.reward_points_after, + reason=record.reason, + task_name=record.task_name, + job_id=new_job_id, + created_at=record.created_at, + ) + records[index] = updated.to_dict() + self._save(state) + return updated + raise ValueError("frozen job hold not found") + + def release_job_hold( + self, + username: str, + job_id: str, + *, + reason: str = "release job hold", + task_name: str = "", + ) -> tuple[Wallet, BillingRecord]: + username = _clean_username(username) + job_id = job_id.strip() + if not job_id: + raise ValueError("job_id is required") + with self._lock: + state = self._load() + amount_cents = self._job_frozen_cents(state, username=username, job_id=job_id) + if amount_cents <= 0: + raise ValueError("job has no frozen balance") + wallet = self._wallet_from_state(state, username) + if wallet.frozen_cents < amount_cents: + raise ValueError("release amount exceeds frozen balance") + wallet = Wallet( + username=username, + balance_cents=wallet.balance_cents, + frozen_cents=wallet.frozen_cents - amount_cents, + reward_points=wallet.reward_points, + updated_at=_now(), + ) + record = self._append_record( + state, + wallet, + "release", + amount_cents, + reason=reason, + task_name=task_name, + job_id=job_id, + ) + self._save_wallet(state, wallet) + self._save(state) + return wallet, record + + def _append_record( + self, + state: dict[str, Any], + wallet: Wallet, + kind: LedgerKind, + amount_cents: int, + *, + reason: str = "", + task_name: str = "", + job_id: str = "", + ) -> BillingRecord: + record = BillingRecord( + record_id=uuid4().hex, + username=wallet.username, + kind=kind, + amount_cents=amount_cents, + balance_after_cents=wallet.balance_cents, + frozen_after_cents=wallet.frozen_cents, + reward_points_after=wallet.reward_points, + reason=reason, + task_name=task_name, + job_id=job_id, + created_at=_now(), + ) + state.setdefault("records", []).append(record.to_dict()) + return record + + def _job_frozen_cents(self, state: dict[str, Any], *, username: str, job_id: str) -> int: + outstanding = 0 + for payload in state.get("records", []): + record = _record_from_payload(payload) + if record.username != username or record.job_id != job_id: + continue + if record.kind == "freeze": + outstanding += record.amount_cents + elif record.kind == "settle": + outstanding -= abs(record.amount_cents) + elif record.kind == "release": + outstanding -= record.amount_cents + return max(outstanding, 0) + + def _wallet_from_state(self, state: dict[str, Any], username: str) -> Wallet: + payload = state.setdefault("wallets", {}).get(username) or {} + return Wallet( + username=username, + balance_cents=int(payload.get("balanceCents", payload.get("creditCents", 0)) or 0), + frozen_cents=int(payload.get("frozenBalanceCents", payload.get("frozenCreditCents", payload.get("frozenCents", 0))) or 0), + reward_points=int(payload.get("creditPoints", payload.get("rewardPoints", 0)) or 0), + updated_at=str(payload.get("updatedAt") or ""), + ) + + def _save_wallet(self, state: dict[str, Any], wallet: Wallet) -> None: + state.setdefault("wallets", {})[wallet.username] = wallet.to_dict() + + def _load(self) -> dict[str, Any]: + if not self.path.is_file(): + return {"wallets": {}, "records": [], "paymentOrders": []} + return json.loads(self.path.read_text(encoding="utf-8")) + + def _save(self, state: dict[str, Any]) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text(json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") + + +def _clean_username(username: str) -> str: + value = username.strip() + if not value: + raise ValueError("username is required") + return value + + +def _record_from_payload(payload: dict[str, Any]) -> BillingRecord: + return BillingRecord( + record_id=str(payload.get("recordId") or ""), + username=str(payload.get("username") or ""), + kind=str(payload.get("kind") or "freeze"), # type: ignore[arg-type] + amount_cents=int(payload.get("amountCents", 0) or 0), + balance_after_cents=int(payload.get("balanceAfterCents", 0) or 0), + frozen_after_cents=int(payload.get("frozenAfterCents", 0) or 0), + reward_points_after=int(payload.get("creditPointsAfter", payload.get("rewardPointsAfter", 0)) or 0), + reason=str(payload.get("reason") or ""), + task_name=str(payload.get("taskName") or ""), + job_id=str(payload.get("jobId") or ""), + created_at=str(payload.get("createdAt") or ""), + ) + + +def _order_from_payload(payload: dict[str, Any]) -> PaymentOrder: + return PaymentOrder( + order_id=str(payload.get("orderId") or ""), + username=str(payload.get("username") or ""), + amount_cents=int(payload.get("amountCents", 0) or 0), + bonus_points=int(payload.get("bonusPoints", 0) or 0), + provider=str(payload.get("provider") or "mock"), + status=str(payload.get("status") or "pending"), # type: ignore[arg-type] + provider_order_id=str(payload.get("providerOrderId") or ""), + pay_url=str(payload.get("payUrl") or ""), + payee_name=str(payload.get("payeeName") or ""), + payee_account=str(payload.get("payeeAccount") or ""), + reason=str(payload.get("reason") or ""), + created_at=str(payload.get("createdAt") or ""), + paid_at=str(payload.get("paidAt") or ""), + ) + + +def _now() -> str: + return datetime.now(tz=timezone.utc).isoformat() diff --git a/roboclaw/account/training_billing.py b/roboclaw/account/training_billing.py new file mode 100644 index 00000000..9dd4e3df --- /dev/null +++ b/roboclaw/account/training_billing.py @@ -0,0 +1,60 @@ +"""Training billing helpers for cloud compute jobs.""" + +from __future__ import annotations + +from math import ceil +from typing import Any, Mapping + +DEFAULT_SERVICE_FEE_BPS = 1_000 +DEFAULT_MIN_BILLABLE_MINUTES = 60 + + +def estimate_training_hold_cents( + *, + hourly_cost_cents: int, + service_fee_bps: int = DEFAULT_SERVICE_FEE_BPS, + min_billable_minutes: int = DEFAULT_MIN_BILLABLE_MINUTES, +) -> int: + """Estimate the upfront balance hold for a cloud training job.""" + if hourly_cost_cents <= 0: + raise ValueError("hourly_cost_cents must be positive") + if service_fee_bps < 0: + raise ValueError("service_fee_bps must be non-negative") + if min_billable_minutes <= 0: + raise ValueError("min_billable_minutes must be positive") + provider_cost = ceil(hourly_cost_cents * min_billable_minutes / 60) + return apply_service_fee_cents(provider_cost, service_fee_bps=service_fee_bps) + + +def apply_service_fee_cents(provider_cost_cents: int, *, service_fee_bps: int = DEFAULT_SERVICE_FEE_BPS) -> int: + """Convert provider cost to user-facing charge with service fee.""" + if provider_cost_cents <= 0: + raise ValueError("provider_cost_cents must be positive") + if service_fee_bps < 0: + raise ValueError("service_fee_bps must be non-negative") + return ceil(provider_cost_cents * (10_000 + service_fee_bps) / 10_000) + + +def hourly_cost_from_params(params: Mapping[str, Any]) -> int: + """Read hourly provider cost from a training payload. + + The value is provider cost in cents before service fee. Accepted aliases keep + the route compatible with EVO_Train and older AutoDL experiments. + """ + for key in ( + "hourlyCostCents", + "costHourlyCents", + "estimatedHourlyCostCents", + "hourlyPriceCents", + "firstHourCostCents", + ): + value = params.get(key) + if value in (None, ""): + continue + try: + parsed = int(value) + except (TypeError, ValueError): + continue + if parsed > 0: + return parsed + return 0 diff --git a/roboclaw/http/routes/__init__.py b/roboclaw/http/routes/__init__.py index f5931bde..364df9e4 100644 --- a/roboclaw/http/routes/__init__.py +++ b/roboclaw/http/routes/__init__.py @@ -29,6 +29,7 @@ def register_all_routes( from roboclaw.http.routes.train import register_train_routes from roboclaw.http.routes.train_cloud import register_train_cloud_routes from roboclaw.http.routes.vla_rl import register_vla_rl_routes + from roboclaw.http.routes.account import register_account_routes from roboclaw.http.routes.infer import register_infer_routes from roboclaw.http.routes.hub import register_hub_routes from roboclaw.http.routes.chat_uploads import register_chat_upload_routes @@ -47,6 +48,7 @@ def register_all_routes( register_train_routes(app, service) register_train_cloud_routes(app, service) register_vla_rl_routes(app, service) + register_account_routes(app) register_infer_routes(app, service) register_hub_routes(app, service) diff --git a/roboclaw/http/routes/account.py b/roboclaw/http/routes/account.py new file mode 100644 index 00000000..01a90b0a --- /dev/null +++ b/roboclaw/http/routes/account.py @@ -0,0 +1,241 @@ +"""Account balance and contribution credit routes.""" + +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +from typing import Any + +from fastapi import FastAPI, Header, HTTPException +from pydantic import BaseModel, Field + +from roboclaw.account import AccountLedger + +_ledger: AccountLedger | None = None +_ADMIN_TOKEN_ENV = "EVO_STUDIO_ADMIN_TOKEN" + + +class RechargeRequest(BaseModel): + username: str + amount_cents: int + reason: str = "admin recharge" + + +class TopupOrderRequest(BaseModel): + username: str + amount_cents: int = Field(..., description="Cash balance top-up in cents. 100 cents = 1 CNY.") + bonus_points: int = Field( + default=0, + description="Small non-cash credit points granted as a top-up bonus, e.g. 5-20.", + ) + provider: str = "mock" + reason: str = "credit topup" + + +class CompleteTopupOrderRequest(BaseModel): + order_id: str + provider_order_id: str = "" + + +class DatasetRewardRequest(BaseModel): + username: str + dataset_id: str + reward_points: int = Field( + ..., + description="Small non-cash contribution credit points for an accepted dataset, e.g. 10-100.", + ) + reason: str = "dataset upload reward" + + +class BillingAmountRequest(BaseModel): + username: str + amount_cents: int + reason: str = "" + task_name: str = "" + job_id: str = "" + + +def payment_config() -> dict[str, Any]: + provider = os.environ.get("EVO_STUDIO_PAYMENT_PROVIDER", "mock").strip() or "mock" + payee_name = os.environ.get("EVO_STUDIO_PAYEE_NAME", "").strip() + payee_account = os.environ.get("EVO_STUDIO_PAYEE_ACCOUNT", "").strip() + return { + "provider": provider, + "payeeName": payee_name, + "payeeAccount": payee_account, + "configured": bool(payee_name and payee_account and provider != "mock"), + } + + +def _require_admin_token(token: str) -> None: + expected = os.environ.get(_ADMIN_TOKEN_ENV, "").strip() + if not expected: + raise HTTPException( + status_code=503, + detail=f"{_ADMIN_TOKEN_ENV} is required for admin billing routes", + ) + if token.strip() != expected: + raise HTTPException(status_code=403, detail="invalid admin token") + + +def get_ledger() -> AccountLedger: + global _ledger + if _ledger is None: + _ledger = AccountLedger() + return _ledger + + +def set_ledger_for_tests(ledger: AccountLedger | None) -> None: + global _ledger + _ledger = ledger + + +def register_account_routes(app: FastAPI) -> None: + @app.get("/api/account/payment-config") + async def account_payment_config() -> dict[str, Any]: + return payment_config() + + @app.get("/api/account/balance") + async def account_balance(username: str) -> dict[str, Any]: + try: + wallet = await asyncio.to_thread(get_ledger().wallet, username) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict()} + + @app.get("/api/account/billing-records") + async def billing_records(username: str = "", limit: int = 50) -> dict[str, Any]: + records = await asyncio.to_thread(get_ledger().records, username, limit=limit) + return {"records": [record.to_dict() for record in records]} + + @app.get("/api/account/topup-orders") + async def account_topup_orders(username: str = "", limit: int = 50) -> dict[str, Any]: + orders = await asyncio.to_thread(get_ledger().orders, username, limit=limit) + return {"orders": [order.to_dict() for order in orders]} + + @app.post("/api/account/topup-orders") + async def create_topup_order(body: TopupOrderRequest) -> dict[str, Any]: + config = payment_config() + try: + order = await asyncio.to_thread( + get_ledger().create_topup_order, + body.username, + body.amount_cents, + bonus_points=body.bonus_points, + provider=body.provider or config["provider"], + payee_name=config["payeeName"], + payee_account=config["payeeAccount"], + reason=body.reason, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"order": order.to_dict(), "paymentConfig": config} + + @app.post("/api/account/topup-orders/complete") + async def complete_topup_order(body: CompleteTopupOrderRequest) -> dict[str, Any]: + try: + order, wallet, record = await asyncio.to_thread( + get_ledger().complete_topup_order, + body.order_id, + provider_order_id=body.provider_order_id, + ) + except ValueError as exc: + raise HTTPException(status_code=404 if "not found" in str(exc) else 400, detail=str(exc)) from exc + return { + "order": order.to_dict(), + "wallet": wallet.to_dict(), + "record": record.to_dict() if record else None, + } + + @app.post("/api/account/rewards/dataset-upload") + async def grant_dataset_upload_reward(body: DatasetRewardRequest) -> dict[str, Any]: + try: + wallet, record, granted = await asyncio.to_thread( + get_ledger().grant_dataset_reward, + body.username, + body.dataset_id, + body.reward_points, + reason=body.reason, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict(), "record": record.to_dict(), "granted": granted} + + @app.post("/api/admin/account/recharge") + async def admin_account_recharge( + body: RechargeRequest, + x_roboclaw_admin_token: str = Header(default=""), + ) -> dict[str, Any]: + _require_admin_token(x_roboclaw_admin_token) + try: + wallet, record = await asyncio.to_thread( + get_ledger().admin_recharge, + body.username, + body.amount_cents, + reason=body.reason, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict(), "record": record.to_dict()} + + @app.post("/api/billing/freeze") + async def billing_freeze( + body: BillingAmountRequest, + x_roboclaw_admin_token: str = Header(default=""), + ) -> dict[str, Any]: + _require_admin_token(x_roboclaw_admin_token) + try: + wallet, record = await asyncio.to_thread( + get_ledger().freeze, + body.username, + body.amount_cents, + reason=body.reason or "freeze training balance", + task_name=body.task_name, + job_id=body.job_id, + ) + except ValueError as exc: + raise HTTPException(status_code=409 if "insufficient" in str(exc) else 400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict(), "record": record.to_dict()} + + @app.post("/api/billing/settle") + async def billing_settle( + body: BillingAmountRequest, + x_roboclaw_admin_token: str = Header(default=""), + ) -> dict[str, Any]: + _require_admin_token(x_roboclaw_admin_token) + try: + wallet, record = await asyncio.to_thread( + get_ledger().settle, + body.username, + body.amount_cents, + reason=body.reason or "settle training balance", + task_name=body.task_name, + job_id=body.job_id, + ) + except ValueError as exc: + raise HTTPException(status_code=409 if "exceeds" in str(exc) else 400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict(), "record": record.to_dict()} + + @app.post("/api/billing/release") + async def billing_release( + body: BillingAmountRequest, + x_roboclaw_admin_token: str = Header(default=""), + ) -> dict[str, Any]: + _require_admin_token(x_roboclaw_admin_token) + try: + wallet, record = await asyncio.to_thread( + get_ledger().release, + body.username, + body.amount_cents, + reason=body.reason or "release frozen training balance", + task_name=body.task_name, + job_id=body.job_id, + ) + except ValueError as exc: + raise HTTPException(status_code=409 if "exceeds" in str(exc) else 400, detail=str(exc)) from exc + return {"wallet": wallet.to_dict(), "record": record.to_dict()} + + +def ledger_for_path(path: Path) -> AccountLedger: + return AccountLedger(path) diff --git a/roboclaw/http/routes/train_cloud.py b/roboclaw/http/routes/train_cloud.py index a420a3ed..d57af495 100644 --- a/roboclaw/http/routes/train_cloud.py +++ b/roboclaw/http/routes/train_cloud.py @@ -2,14 +2,18 @@ from __future__ import annotations +import asyncio from typing import Any from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field +from roboclaw.account import AccountLedger, apply_service_fee_cents, estimate_training_hold_cents, hourly_cost_from_params from roboclaw.embodied.service import EmbodiedService from roboclaw.training import TrainingPlanSpec, TrainingService, TrainingStartSpec, TrainingStopSpec +_ledger: AccountLedger | None = None + class CloudTrainStartRequest(BaseModel): dataset_name: str = "" @@ -22,6 +26,11 @@ class CloudTrainStartRequest(BaseModel): sku_id: str = "" image_id: str = "" task_name: str = "" + hourly_cost_cents: int = Field( + default=0, + description="Provider hourly compute cost in cents before service fee.", + ) + service_fee_bps: int = Field(default=1_000, description="Service fee in basis points. 1000 = 10%.") class CloudTrainStopRequest(BaseModel): @@ -29,6 +38,14 @@ class CloudTrainStopRequest(BaseModel): username: str = "" +class CloudTrainBillingSettleRequest(BaseModel): + username: str + job_id: str + provider_cost_cents: int = Field(..., description="Actual provider compute cost in cents before service fee.") + service_fee_bps: int = Field(default=1_000, description="Service fee in basis points. 1000 = 10%.") + task_name: str = "" + + class CloudTrainPlanRequest(BaseModel): username: str = "" message: str = "" @@ -51,11 +68,42 @@ def _bridge_error_status(exc: RuntimeError) -> int: return 503 if "bridge is not enabled" in str(exc).lower() else 502 +def get_ledger() -> AccountLedger: + global _ledger + if _ledger is None: + _ledger = AccountLedger() + return _ledger + + +def set_ledger_for_tests(ledger: AccountLedger | None) -> None: + global _ledger + _ledger = ledger + + def register_train_cloud_routes(app: FastAPI, service: EmbodiedService) -> None: training = TrainingService(service) @app.post("/api/train/cloud/start") async def train_cloud_start(body: CloudTrainStartRequest) -> dict[str, Any]: + username = body.username.strip() + hourly_cost_cents = body.hourly_cost_cents or hourly_cost_from_params(body.params) + hold_cents = 0 + freeze_record = None + if username and hourly_cost_cents: + try: + hold_cents = estimate_training_hold_cents( + hourly_cost_cents=hourly_cost_cents, + service_fee_bps=body.service_fee_bps, + ) + _wallet, freeze_record = get_ledger().freeze( + username, + hold_cents, + reason="cloud training first-hour hold", + task_name=body.task_name or body.dataset_name, + job_id=body.task_name or body.dataset_name or "pending-cloud-train", + ) + except ValueError as exc: + raise HTTPException(status_code=409 if "insufficient" in str(exc) else 400, detail=str(exc)) from exc try: result = await training.start( TrainingStartSpec( @@ -72,8 +120,62 @@ async def train_cloud_start(body: CloudTrainStartRequest) -> dict[str, Any]: ) ) except RuntimeError as exc: + if username and freeze_record is not None: + try: + get_ledger().release_job_hold( + username, + freeze_record.job_id, + reason="release hold after cloud training start failure", + task_name=body.task_name or body.dataset_name, + ) + except ValueError: + pass raise HTTPException(status_code=_bridge_error_status(exc), detail=str(exc)) from exc - return result.to_dict() + payload = result.to_dict() + if username and freeze_record is not None: + job_id = payload.get("job_id") or freeze_record.job_id + if job_id != freeze_record.job_id: + try: + freeze_record = get_ledger().reassign_job_hold( + username, + freeze_record.job_id, + str(job_id), + ) + except ValueError: + pass + payload["billing"] = { + "holdCents": hold_cents, + "hourlyCostCents": hourly_cost_cents, + "serviceFeeBps": body.service_fee_bps, + "record": freeze_record.to_dict(), + } + return payload + + @app.post("/api/train/cloud/billing/settle") + async def train_cloud_billing_settle(body: CloudTrainBillingSettleRequest) -> dict[str, Any]: + try: + charge_cents = apply_service_fee_cents( + body.provider_cost_cents, + service_fee_bps=body.service_fee_bps, + ) + wallet, settle_record, release_record = await asyncio.to_thread( + get_ledger().settle_job, + body.username, + body.job_id, + charge_cents, + reason="cloud training final settlement", + task_name=body.task_name, + ) + except ValueError as exc: + raise HTTPException(status_code=409 if "exceeds" in str(exc) or "no frozen" in str(exc) else 400, detail=str(exc)) from exc + return { + "wallet": wallet.to_dict(), + "chargeCents": charge_cents, + "providerCostCents": body.provider_cost_cents, + "serviceFeeBps": body.service_fee_bps, + "settleRecord": settle_record.to_dict(), + "releaseRecord": release_record.to_dict() if release_record else None, + } @app.post("/api/train/cloud/stop") async def train_cloud_stop(body: CloudTrainStopRequest) -> dict[str, Any]: diff --git a/tests/test_account_ledger.py b/tests/test_account_ledger.py new file mode 100644 index 00000000..75f4cec2 --- /dev/null +++ b/tests/test_account_ledger.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +from unittest.mock import patch + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from roboclaw.account import AccountLedger +from roboclaw.http.routes.account import register_account_routes, set_ledger_for_tests + + +def test_account_ledger_recharge_freeze_settle_release(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + + wallet, recharge = ledger.admin_recharge("pearl", 10_000) + assert wallet.balance_cents == 10_000 + assert wallet.available_cents == 10_000 + assert recharge.kind == "admin_recharge" + + wallet, frozen = ledger.freeze("pearl", 4_000, task_name="train-1", job_id="job-1") + assert wallet.balance_cents == 10_000 + assert wallet.frozen_cents == 4_000 + assert wallet.available_cents == 6_000 + assert frozen.kind == "freeze" + + wallet, released = ledger.release("pearl", 1_000) + assert wallet.balance_cents == 10_000 + assert wallet.frozen_cents == 3_000 + assert wallet.available_cents == 7_000 + assert released.kind == "release" + + wallet, settled = ledger.settle("pearl", 3_000) + assert wallet.balance_cents == 7_000 + assert wallet.frozen_cents == 0 + assert wallet.available_cents == 7_000 + assert settled.kind == "settle" + assert settled.amount_cents == -3_000 + + records = ledger.records("pearl") + assert [record.kind for record in records] == ["settle", "release", "freeze", "admin_recharge"] + + +def test_account_ledger_topup_order_auto_recharges_once(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + + order = ledger.create_topup_order( + "pearl", + 5_000, + bonus_points=5, + provider="mockpay", + payee_name="Evo Studio", + payee_account="merchant-001", + ) + + assert order.status == "pending" + assert order.bonus_points == 5 + assert order.payee_name == "Evo Studio" + assert order.payee_account == "merchant-001" + assert order.pay_url == f"roboclaw://pay/mockpay/{order.order_id}" + assert ledger.wallet("pearl").balance_cents == 0 + + paid_order, wallet, record = ledger.complete_topup_order(order.order_id, provider_order_id="txn-1") + + assert paid_order.status == "paid" + assert paid_order.payee_account == "merchant-001" + assert paid_order.provider_order_id == "txn-1" + assert wallet.balance_cents == 5_000 + assert wallet.reward_points == 5 + assert record is not None + assert record.kind == "payment_recharge" + assert record.job_id == order.order_id + + paid_order_2, wallet_2, record_2 = ledger.complete_topup_order(order.order_id) + assert paid_order_2.status == "paid" + assert wallet_2.balance_cents == 5_000 + assert wallet_2.reward_points == 5 + assert record_2 is None + + +def test_account_ledger_dataset_reward_is_idempotent(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + + wallet, record, granted = ledger.grant_dataset_reward("pearl", "dataset-1", 15) + + assert granted is True + assert wallet.available_cents == 0 + assert wallet.reward_points == 15 + assert record.kind == "dataset_reward" + assert record.amount_cents == 15 + assert record.reward_points_after == 15 + assert record.job_id == "dataset-1" + + wallet_2, record_2, granted_2 = ledger.grant_dataset_reward("pearl", "dataset-1", 15) + assert granted_2 is False + assert wallet_2.available_cents == 0 + assert wallet_2.reward_points == 15 + assert record_2.record_id == record.record_id + + +def test_account_ledger_reassigns_pending_training_hold(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + ledger.admin_recharge("pearl", 10_000) + _wallet, hold = ledger.freeze("pearl", 990, job_id="pending-cloud-train") + + updated = ledger.reassign_job_hold("pearl", "pending-cloud-train", "cloud-job-1") + + assert updated.record_id == hold.record_id + assert updated.job_id == "cloud-job-1" + records = ledger.records("pearl") + assert records[0].kind == "freeze" + assert records[0].job_id == "cloud-job-1" + assert ledger.wallet("pearl").frozen_cents == 990 + + +def test_account_ledger_releases_only_matching_job_hold(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + ledger.admin_recharge("pearl", 10_000) + ledger.freeze("pearl", 990, job_id="job-a") + ledger.freeze("pearl", 990, job_id="job-b") + + wallet, release = ledger.release_job_hold("pearl", "job-a") + + assert release.amount_cents == 990 + assert release.job_id == "job-a" + assert wallet.frozen_cents == 990 + records = ledger.records("pearl") + assert [record.job_id for record in records if record.kind == "freeze"] == ["job-b", "job-a"] + + +def test_account_ledger_rejects_insufficient_balance(tmp_path) -> None: + ledger = AccountLedger(tmp_path / "ledger.json") + ledger.admin_recharge("pearl", 100) + + try: + ledger.freeze("pearl", 200) + except ValueError as exc: + assert "insufficient" in str(exc) + else: + raise AssertionError("freeze should fail") + + +def test_account_routes_flow(tmp_path) -> None: + headers = {"X-Roboclaw-Admin-Token": "admin-test"} + from roboclaw.http.routes import account as account_routes + + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + with patch.dict(account_routes.os.environ, {"EVO_STUDIO_ADMIN_TOKEN": "admin-test"}): + recharge = client.post( + "/api/admin/account/recharge", + json={"username": "pearl", "amount_cents": 10_000, "reason": "test topup"}, + headers=headers, + ) + assert recharge.status_code == 200 + assert recharge.json()["wallet"]["availableBalanceCents"] == 10_000 + assert recharge.json()["wallet"]["availableCents"] == 10_000 + + freeze = client.post( + "/api/billing/freeze", + json={"username": "pearl", "amount_cents": 4_000, "task_name": "train-1"}, + headers=headers, + ) + assert freeze.status_code == 200 + assert freeze.json()["wallet"]["frozenCents"] == 4_000 + + settle = client.post( + "/api/billing/settle", + json={"username": "pearl", "amount_cents": 2_500, "task_name": "train-1"}, + headers=headers, + ) + assert settle.status_code == 200 + assert settle.json()["wallet"]["balanceCents"] == 7_500 + assert settle.json()["wallet"]["frozenCents"] == 1_500 + + balance = client.get("/api/account/balance", params={"username": "pearl"}) + assert balance.status_code == 200 + assert balance.json()["wallet"]["availableCents"] == 6_000 + + records = client.get("/api/account/billing-records", params={"username": "pearl"}) + assert records.status_code == 200 + assert [record["kind"] for record in records.json()["records"]] == ["settle", "freeze", "admin_recharge"] + + set_ledger_for_tests(None) + + +def test_account_routes_topup_order_flow(tmp_path) -> None: + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + order_response = client.post( + "/api/account/topup-orders", + json={"username": "pearl", "amount_cents": 8_000, "bonus_points": 8, "provider": "mockpay"}, + ) + assert order_response.status_code == 200 + order = order_response.json()["order"] + assert order["status"] == "pending" + assert "paymentConfig" in order_response.json() + + balance = client.get("/api/account/balance", params={"username": "pearl"}) + assert balance.json()["wallet"]["availableCents"] == 0 + + complete_response = client.post( + "/api/account/topup-orders/complete", + json={"order_id": order["orderId"], "provider_order_id": "txn-2"}, + ) + assert complete_response.status_code == 200 + assert complete_response.json()["order"]["status"] == "paid" + assert complete_response.json()["wallet"]["availableBalanceCents"] == 8_000 + assert complete_response.json()["wallet"]["creditPoints"] == 8 + assert complete_response.json()["record"]["kind"] == "payment_recharge" + + orders = client.get("/api/account/topup-orders", params={"username": "pearl"}) + assert orders.status_code == 200 + assert orders.json()["orders"][0]["providerOrderId"] == "txn-2" + + set_ledger_for_tests(None) + + +def test_account_routes_payment_config_from_env(tmp_path, monkeypatch) -> None: + monkeypatch.setenv("EVO_STUDIO_PAYMENT_PROVIDER", "alipay") + monkeypatch.setenv("EVO_STUDIO_PAYEE_NAME", "Evo Studio") + monkeypatch.setenv("EVO_STUDIO_PAYEE_ACCOUNT", "merchant-001") + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + config = client.get("/api/account/payment-config") + assert config.status_code == 200 + assert config.json()["configured"] is True + assert config.json()["payeeAccount"] == "merchant-001" + + order = client.post( + "/api/account/topup-orders", + json={"username": "pearl", "amount_cents": 1_000, "provider": "alipay"}, + ) + assert order.status_code == 200 + assert order.json()["order"]["payeeName"] == "Evo Studio" + assert order.json()["order"]["payeeAccount"] == "merchant-001" + + set_ledger_for_tests(None) + + +def test_account_routes_dataset_reward_flow(tmp_path) -> None: + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + reward = client.post( + "/api/account/rewards/dataset-upload", + json={"username": "pearl", "dataset_id": "cloud/verify-so101", "reward_points": 20}, + ) + + assert reward.status_code == 200 + assert reward.json()["granted"] is True + assert reward.json()["wallet"]["availableBalanceCents"] == 0 + assert reward.json()["wallet"]["creditPoints"] == 20 + assert reward.json()["record"]["kind"] == "dataset_reward" + + duplicate = client.post( + "/api/account/rewards/dataset-upload", + json={"username": "pearl", "dataset_id": "cloud/verify-so101", "reward_points": 20}, + ) + assert duplicate.status_code == 200 + assert duplicate.json()["granted"] is False + assert duplicate.json()["wallet"]["availableBalanceCents"] == 0 + assert duplicate.json()["wallet"]["creditPoints"] == 20 + + set_ledger_for_tests(None) + + +def test_account_routes_reject_insufficient_balance(tmp_path) -> None: + from roboclaw.http.routes import account as account_routes + + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + with patch.dict(account_routes.os.environ, {"EVO_STUDIO_ADMIN_TOKEN": "admin-test"}): + response = client.post( + "/api/billing/freeze", + json={"username": "pearl", "amount_cents": 1}, + headers={"X-Roboclaw-Admin-Token": "admin-test"}, + ) + + assert response.status_code == 409 + assert "insufficient" in response.json()["detail"] + set_ledger_for_tests(None) + + +def test_account_admin_routes_require_token(tmp_path) -> None: + from roboclaw.http.routes import account as account_routes + + set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + app = FastAPI() + register_account_routes(app) + client = TestClient(app) + + with patch.dict(account_routes.os.environ, {"EVO_STUDIO_ADMIN_TOKEN": "admin-test"}): + missing = client.post( + "/api/admin/account/recharge", + json={"username": "pearl", "amount_cents": 10_000}, + ) + wrong = client.post( + "/api/admin/account/recharge", + json={"username": "pearl", "amount_cents": 10_000}, + headers={"X-Roboclaw-Admin-Token": "wrong"}, + ) + + assert missing.status_code == 403 + assert wrong.status_code == 403 + set_ledger_for_tests(None) diff --git a/tests/test_evo_train_routes.py b/tests/test_evo_train_routes.py index f943db75..713332ca 100644 --- a/tests/test_evo_train_routes.py +++ b/tests/test_evo_train_routes.py @@ -14,8 +14,10 @@ from roboclaw.embodied.embodiment.manifest import Manifest from roboclaw.embodied.service import EmbodiedService from roboclaw.cloud.evo_train import EvoTrainBridge, EvoTrainSettings +from roboclaw.account import AccountLedger from roboclaw.http.routes.policies import register_policy_routes from roboclaw.http.routes.train import register_train_routes +from roboclaw.http.routes import train_cloud as train_cloud_routes from roboclaw.http.routes.train_cloud import register_train_cloud_routes from roboclaw.training import TrainingJobStatus @@ -194,6 +196,88 @@ def test_train_start_uses_cloud_bridge_when_enabled(route_app): assert bridge.start_calls[0]["username"] == "13800138000" +def test_train_start_freezes_first_hour_balance_when_cost_is_declared(route_app, tmp_path): + app, _, _ = route_app + bridge = StubBridge() + ledger = AccountLedger(tmp_path / "ledger.json") + ledger.admin_recharge("13800138000", 2_000) + train_cloud_routes.set_ledger_for_tests(ledger) + + with patch("roboclaw.training.service.EvoTrainBridge", return_value=bridge): + register_train_cloud_routes(app, app.state.embodied_service) + + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/train/cloud/start", + json={ + "dataset_name": "demo", + "policy_type": "act", + "steps": 5000, + "username": "13800138000", + "hourly_cost_cents": 900, + "service_fee_bps": 1000, + }, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["billing"]["holdCents"] == 990 + assert data["billing"]["record"]["jobId"] == "cloud-job-1" + wallet = ledger.wallet("13800138000") + assert wallet.balance_cents == 2_000 + assert wallet.frozen_cents == 990 + train_cloud_routes.set_ledger_for_tests(None) + + +def test_train_start_rejects_when_balance_is_insufficient(route_app, tmp_path): + app, _, _ = route_app + bridge = StubBridge() + train_cloud_routes.set_ledger_for_tests(AccountLedger(tmp_path / "ledger.json")) + + with patch("roboclaw.training.service.EvoTrainBridge", return_value=bridge): + register_train_cloud_routes(app, app.state.embodied_service) + + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/train/cloud/start", + json={"dataset_name": "demo", "username": "13800138000", "hourly_cost_cents": 900}, + ) + + assert resp.status_code == 409 + assert "insufficient" in resp.json()["detail"] + assert bridge.start_calls == [] + train_cloud_routes.set_ledger_for_tests(None) + + +def test_train_cloud_billing_settle_charges_service_fee_and_releases_remainder(route_app, tmp_path): + app, _, _ = route_app + ledger = AccountLedger(tmp_path / "ledger.json") + ledger.admin_recharge("13800138000", 2_000) + ledger.freeze("13800138000", 990, job_id="cloud-job-1", task_name="demo") + train_cloud_routes.set_ledger_for_tests(ledger) + + register_train_cloud_routes(app, app.state.embodied_service) + client = TestClient(app, raise_server_exceptions=False) + resp = client.post( + "/api/train/cloud/billing/settle", + json={ + "username": "13800138000", + "job_id": "cloud-job-1", + "provider_cost_cents": 500, + "service_fee_bps": 1000, + "task_name": "demo", + }, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["chargeCents"] == 550 + assert data["releaseRecord"]["amountCents"] == 440 + assert data["wallet"]["availableBalanceCents"] == 1_450 + assert data["wallet"]["frozenBalanceCents"] == 0 + train_cloud_routes.set_ledger_for_tests(None) + + def test_train_plan_forwards_skill_request_to_evo_train(route_app): app, _, _ = route_app bridge = StubBridge()