Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 109 additions & 85 deletions server/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,52 @@ import (
// made by the API handlers.
const rpcTimeout = 15 * time.Second

// basketStateABI is parsed once at startup and reused across all handler
// invocations.
var basketStateABI abi.ABI

// claimableRevenueABI is parsed once at startup and reused by getClaimableSnapshots.
var claimableRevenueABI abi.ABI

func init() {
var err error

basketStateABI, err = abi.JSON(strings.NewReader(`[{
"inputs": [],
"name": "basketState",
"outputs": [
{"internalType":"address[]","name":"constituents", "type":"address[]"},
{"internalType":"uint256[]","name":"targetWeights", "type":"uint256[]"},
{"internalType":"uint256[]","name":"currentWeights", "type":"uint256[]"},
{"internalType":"uint256[]","name":"balances", "type":"uint256[]"},
{"internalType":"uint256", "name":"totalValue", "type":"uint256"},
{"internalType":"uint256", "name":"nav", "type":"uint256"},
{"internalType":"bool", "name":"rebalancingEnabled", "type":"bool"},
{"internalType":"uint256", "name":"driftThresholdBps", "type":"uint256"},
{"internalType":"uint256", "name":"maxDrift", "type":"uint256"}
],
"stateMutability":"view",
"type":"function"
}]`))
if err != nil {
panic(fmt.Sprintf("api: parse basketStateABI: %v", err))
}

claimableRevenueABI, err = abi.JSON(strings.NewReader(`[{
"inputs":[
{"internalType":"address","name":"account", "type":"address"},
{"internalType":"uint256","name":"snapshotId", "type":"uint256"}
],
"name":"claimableRevenue",
"outputs":[{"internalType":"uint256","name":"","type":"uint256"}],
"stateMutability":"view",
"type":"function"
}]`))
if err != nil {
panic(fmt.Sprintf("api: parse claimableRevenueABI: %v", err))
}
}

// snapshotEntry is shared between getCreatorDashboard and getClaimableSnapshots.
type snapshotEntry struct {
SnapshotID int64 `json:"snapshotId"`
Expand All @@ -34,14 +80,31 @@ type snapshotEntry struct {
ClaimableUsdg string `json:"claimableByWallet"`
}

// NewRouter wires all HTTP routes.
// NewRouter wires all HTTP routes. The RPC client is initialised once here
// and shared across all handlers that need it, avoiding per-request dial
// overhead on cache misses.
func NewRouter(database *db.DB, openAIKey string, openAIModel string) http.Handler {
mux := http.NewServeMux()

rpcURL := os.Getenv("RPC_URL")
if rpcURL == "" {
rpcURL = "https://rpc.testnet.chain.robinhood.com"
}

dialCtx, dialCancel := context.WithTimeout(context.Background(), rpcTimeout)
defer dialCancel()

rpcClient, err := ethclient.DialContext(dialCtx, rpcURL)
if err != nil {
log.Printf("api: RPC dial warning: %v — RPC-dependent handlers will degrade gracefully", err)
rpcClient = nil
}

h := &handler{
db: database,
openAIKey: openAIKey,
openAIModel: openAIModel,
rpcClient: rpcClient,
}

mux.HandleFunc("GET /baskets", h.listBaskets)
Expand All @@ -65,6 +128,7 @@ type handler struct {
db *db.DB
openAIKey string
openAIModel string
rpcClient *ethclient.Client
}

// Marketplace
Expand Down Expand Up @@ -320,7 +384,8 @@ func (h *handler) getBasket(w http.ResponseWriter, r *http.Request) {
}
}

var navPerToken, totalValueUsdg string
navPerToken := "0"
totalValueUsdg := "0"
var maxDriftBps int64
var needsRebalancing bool

Expand All @@ -330,10 +395,14 @@ func (h *handler) getBasket(w http.ResponseWriter, r *http.Request) {
maxDriftBps = state.MaxDriftBps
needsRebalancing = state.NeedsRebalancing
} else {
h.db.QueryRow(`
var dbNav, dbTv string
if h.db.QueryRow(`
SELECT nav_per_token, total_value_usdg FROM nav_history
WHERE basket_address = ? ORDER BY timestamp DESC LIMIT 1`, addr,
).Scan(&navPerToken, &totalValueUsdg)
).Scan(&dbNav, &dbTv) == nil {
navPerToken = dbNav
totalValueUsdg = dbTv
}
}

navChange24h := h.navChangePct(addr, 86400)
Expand Down Expand Up @@ -441,44 +510,16 @@ func (h *handler) getBasketStateFromCache(ctx context.Context, basketAddr string
}

func (h *handler) fetchBasketStateRPC(ctx context.Context, basketAddr string) (*basketStateCache, error) {
basketStateABI, _ := abi.JSON(strings.NewReader(`[{
"inputs": [],
"name": "basketState",
"outputs": [
{"internalType":"address[]","name":"constituents", "type":"address[]"},
{"internalType":"uint256[]","name":"targetWeights", "type":"uint256[]"},
{"internalType":"uint256[]","name":"currentWeights", "type":"uint256[]"},
{"internalType":"uint256[]","name":"balances", "type":"uint256[]"},
{"internalType":"uint256", "name":"totalValue", "type":"uint256"},
{"internalType":"uint256", "name":"nav", "type":"uint256"},
{"internalType":"bool", "name":"rebalancingEnabled", "type":"bool"},
{"internalType":"uint256", "name":"driftThresholdBps", "type":"uint256"},
{"internalType":"uint256", "name":"maxDrift", "type":"uint256"}
],
"stateMutability":"view",
"type":"function"
}]`))

rpcURL := os.Getenv("RPC_URL")
if rpcURL == "" {
rpcURL = "https://rpc.testnet.chain.robinhood.com"
}

dialCtx, dialCancel := context.WithTimeout(ctx, rpcTimeout)
defer dialCancel()

client, err := ethclient.DialContext(dialCtx, rpcURL)
if err != nil {
return nil, fmt.Errorf("dial: %w", err)
if h.rpcClient == nil {
return nil, fmt.Errorf("RPC client not available")
}
defer client.Close()

addr := common.HexToAddress(basketAddr)

callCtx, callCancel := context.WithTimeout(ctx, rpcTimeout)
defer callCancel()

data, err := client.CallContract(callCtx, ethereum.CallMsg{
data, err := h.rpcClient.CallContract(callCtx, ethereum.CallMsg{
To: &addr,
Data: basketStateABI.Methods["basketState"].ID,
}, nil)
Expand All @@ -497,9 +538,17 @@ func (h *handler) fetchBasketStateRPC(ctx context.Context, basketAddr string) (*
balances, _ := unpacked[3].([]*big.Int)
totalValue, _ := unpacked[4].(*big.Int)
nav, _ := unpacked[5].(*big.Int)
needsRebal, _ := unpacked[6].(bool)
rebalancingEnabled, _ := unpacked[6].(bool)
driftThresholdBps, _ := unpacked[7].(*big.Int)
maxDrift, _ := unpacked[8].(*big.Int)

// needsRebalancing is true when rebalancing is enabled and the maximum
// drift across all constituents meets or exceeds the basket's threshold.
needsRebal := false
if rebalancingEnabled && maxDrift != nil && driftThresholdBps != nil {
needsRebal = maxDrift.Cmp(driftThresholdBps) >= 0
}

constituentAddrs := make([]string, len(constituents))
for i, c := range constituents {
constituentAddrs[i] = strings.ToLower(c.Hex())
Expand Down Expand Up @@ -1329,23 +1378,6 @@ func (h *handler) getCreatorDashboard(w http.ResponseWriter, r *http.Request) {
}
}

rpcURL := os.Getenv("RPC_URL")
if rpcURL == "" {
rpcURL = "https://rpc.testnet.chain.robinhood.com"
}

rpcCtx, rpcCancel := context.WithTimeout(r.Context(), rpcTimeout)
defer rpcCancel()

rpcClient, rpcErr := ethclient.DialContext(rpcCtx, rpcURL)
if rpcErr != nil {
log.Printf("api: getCreatorDashboard dial: %v", rpcErr)
rpcClient = nil
}
if rpcClient != nil {
defer rpcClient.Close()
}

totalClaimable := new(big.Int)
result := make([]BasketEntry, 0, len(basketOrder))

Expand All @@ -1356,7 +1388,7 @@ func (h *handler) getCreatorDashboard(w http.ResponseWriter, r *http.Request) {
snaps = []snapshotEntry{}
}

unclaimed := h.getClaimableSnapshots(rpcCtx, rpcClient, wallet, addr, m.creatorToken, snaps)
unclaimed := h.getClaimableSnapshots(r.Context(), wallet, addr, m.creatorToken, snaps)

basketClaimable := new(big.Int)
for _, s := range unclaimed {
Expand Down Expand Up @@ -1387,28 +1419,16 @@ func (h *handler) getCreatorDashboard(w http.ResponseWriter, r *http.Request) {

// getClaimableSnapshots returns claimable amounts per snapshot, reading from
// creator_claimable_cache if fresh (< 60s), otherwise calling claimableRevenue()
// on the contract.
// on the contract via the shared RPC client.
func (h *handler) getClaimableSnapshots(
ctx context.Context,
client *ethclient.Client,
wallet, basketAddr, creatorTokenAddr string,
snapshots []snapshotEntry,
) []snapshotEntry {
if len(snapshots) == 0 {
return []snapshotEntry{}
}

claimableABI, _ := abi.JSON(strings.NewReader(`[{
"inputs":[
{"internalType":"address","name":"account", "type":"address"},
{"internalType":"uint256","name":"snapshotId", "type":"uint256"}
],
"name":"claimableRevenue",
"outputs":[{"internalType":"uint256","name":"","type":"uint256"}],
"stateMutability":"view",
"type":"function"
}]`))

now := time.Now().Unix()

type cachedEntry struct {
Expand Down Expand Up @@ -1446,15 +1466,16 @@ func (h *handler) getClaimableSnapshots(
}

claimable := "0"
if client != nil {
input, err := claimableABI.Pack("claimableRevenue",
rpcSucceeded := false

if h.rpcClient != nil {
input, err := claimableRevenueABI.Pack("claimableRevenue",
walletAddr,
new(big.Int).SetInt64(snap.SnapshotID),
)
if err == nil {
// Per-call timeout within the overall request context.
callCtx, callCancel := context.WithTimeout(ctx, rpcTimeout)
data, err := client.CallContract(callCtx, ethereum.CallMsg{
data, err := h.rpcClient.CallContract(callCtx, ethereum.CallMsg{
To: &ctAddr,
Data: input,
}, nil)
Expand All @@ -1463,26 +1484,31 @@ func (h *handler) getClaimableSnapshots(
if err != nil {
log.Printf("api: claimableRevenue RPC snapshot=%d wallet=%s: %v", snap.SnapshotID, wallet, err)
} else {
unpacked, err := claimableABI.Methods["claimableRevenue"].Outputs.Unpack(data)
unpacked, err := claimableRevenueABI.Methods["claimableRevenue"].Outputs.Unpack(data)
if err == nil && len(unpacked) > 0 {
if amount, ok := unpacked[0].(*big.Int); ok && amount != nil {
claimable = amount.String()
rpcSucceeded = true
}
}

h.db.Exec(`
INSERT INTO creator_claimable_cache
(wallet_address, snapshot_id, basket_address, claimable_usdg, cached_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(wallet_address, snapshot_id, basket_address) DO UPDATE SET
claimable_usdg = excluded.claimable_usdg,
cached_at = excluded.cached_at`,
wallet, snap.SnapshotID, basketAddr, claimable, now,
)
}
}
}

// Only write to cache on RPC success. Caching a zero on RPC failure
// would poison the cache and hide real claimable amounts.
if rpcSucceeded {
h.db.Exec(`
INSERT INTO creator_claimable_cache
(wallet_address, snapshot_id, basket_address, claimable_usdg, cached_at)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(wallet_address, snapshot_id, basket_address) DO UPDATE SET
claimable_usdg = excluded.claimable_usdg,
cached_at = excluded.cached_at`,
wallet, snap.SnapshotID, basketAddr, claimable, now,
)
}

snap.ClaimableUsdg = claimable
result = append(result, snap)
}
Expand Down Expand Up @@ -1544,7 +1570,6 @@ func (h *handler) aiCompose(w http.ResponseWriter, r *http.Request) {
var req struct {
Thesis string `json:"thesis"`
}
// 32KB limit
if err := json.NewDecoder(io.LimitReader(r.Body, 32768)).Decode(&req); err != nil || len(req.Thesis) < 20 {
jsonError(w, "thesis must be at least 20 characters", http.StatusBadRequest)
return
Expand Down Expand Up @@ -1601,7 +1626,6 @@ func jsonOK(w http.ResponseWriter, v any) {
json.NewEncoder(w).Encode(v)
}

// jsonError returns a JSON error body containing only the message string.
func jsonError(w http.ResponseWriter, msg string, code int) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
Expand Down Expand Up @@ -1662,4 +1686,4 @@ func (h *handler) serveDocs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write([]byte(swaggerHTML))
}
}
Loading
Loading