diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 75830cc11..fb87cb220 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -25,6 +25,7 @@ jobs: github.com/solana-foundation/pay-kit/go/paycore \ github.com/solana-foundation/pay-kit/go/paycore/solanatx \ github.com/solana-foundation/pay-kit/go/paycore/signer \ + github.com/solana-foundation/pay-kit/go/paycore/paymentchannels \ github.com/solana-foundation/pay-kit/go/paykit \ github.com/solana-foundation/pay-kit/go/protocols/mpp \ github.com/solana-foundation/pay-kit/go/protocols/mpp/core \ @@ -147,3 +148,68 @@ jobs: # x402 server is the go paykit server under test). X402_HARNESS_SERVERS: go run: pnpm exec vitest run test/e2e.test.ts --testTimeout 180000 + + playground-go: + name: "Payment links E2E: Playground (Go)" + needs: test-go + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-go@v6 + with: + go-version-file: go/go.mod + cache-dependency-path: go/go.sum + - uses: pnpm/action-setup@v5 + with: + package_json_file: package.json + - uses: actions/setup-node@v5 + with: + node-version: 22 + cache: pnpm + cache-dependency-path: typescript/pnpm-lock.yaml + - name: Install Surfnet helper dependencies + working-directory: harness + run: pnpm install --frozen-lockfile + - name: Start Surfnet + working-directory: . + run: | + node harness/start-surfnet-proxy.mjs & + ready=0 + for i in $(seq 1 50); do + if curl -sf -X POST http://localhost:8899 \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc":"2.0","id":1,"method":"getHealth","params":[]}' \ + | grep -q '"result":"ok"'; then + ready=1 + break + fi + sleep 0.2 + done + test "$ready" -eq 1 + - name: Build Go playground API + working-directory: go + env: + GOCACHE: /tmp/go-build-cache + run: go build ./examples/playground-api + - name: Start Go playground API + working-directory: go + env: + GOCACHE: /tmp/go-build-cache + PORT: "3002" + NETWORK: localnet + RPC_URL: http://localhost:8899 + MPP_SECRET_KEY: playground-ci-secret + run: go run ./examples/playground-api & + - name: Wait for playground server + working-directory: . + run: | + for i in $(seq 1 30); do + curl -sf http://localhost:3002/api/v1/health && break + sleep 1 + done + - name: Install HTML dependencies & Playwright + working-directory: html + run: npm install && npx playwright install chromium + - name: Run Playwright tests (Go playground) + working-directory: html + run: FORTUNE_PATH=/api/v1/fortune npm run test:e2e:go diff --git a/.gitignore b/.gitignore index df7b14cd6..26dfc384a 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ __pycache__/ .claude/ .gocache .build/ +go/build/ # Generated API docs — see `just docs`. Single tree at the repo root; # each language emits markdown into `docs/api//`. diff --git a/go/.golangci.yml b/go/.golangci.yml index 12bcb2556..eb1b47657 100644 --- a/go/.golangci.yml +++ b/go/.golangci.yml @@ -30,6 +30,12 @@ linters: exclusions: paths: - examples/simple-server + # Codama-generated on-chain program clients (regenerated via + # `pnpm run payment-channels:go` in skills/.../codegen). These files carry + # a "DO NOT EDIT" header, so lint findings can't be fixed in place; the + # generator's output is byte-for-byte reproducible and matches the + # gagliardetto/solana-go generated-client conventions. + - protocols/programs/.*/.*\.go rules: - path: _test\.go linters: diff --git a/go/Justfile b/go/Justfile index 245bed382..f0d670627 100644 --- a/go/Justfile +++ b/go/Justfile @@ -19,20 +19,37 @@ test: fmt: gofmt -s -w . -# Lint pipeline: gofmt-check + go vet + staticcheck +# Lint pipeline: gofmt-check + go vet + staticcheck. +# Codama-generated program clients under protocols/programs// carry a +# "DO NOT EDIT" header and follow gagliardetto/solana-go's generated-client +# idioms (unkeyed VariantType literals, embedded-field selectors), which trip +# `go vet -composites` and staticcheck QF1008. They are byte-for-byte +# reproducible from idl/.json, so we exclude the generated dirs from the +# authored-code lint gates — mirroring the golangci-lint path exclusion and the +# Rust crate's pure-passthrough src/generated/ treatment. lint: test -z "$(gofmt -s -l . | tee /dev/stderr)" - go vet ./... - go run honnef.co/go/tools/cmd/staticcheck@latest ./... + # -composites is disabled module-wide because the generated clients (and any + # package importing them, e.g. the parity guard) surface gagliardetto's + # unkeyed VariantType literals; golangci-lint's govet (the CI gate) already + # runs with composites off, so this keeps the local justfile vet aligned. + go vet -composites=false $(go list ./... | grep -v '/protocols/programs/paymentchannels$') + go run honnef.co/go/tools/cmd/staticcheck@latest $(go list ./... | grep -v '/protocols/programs/paymentchannels$') # Run module-deps audit (govulncheck) audit: go run golang.org/x/vuln/cmd/govulncheck@latest ./... -# Test with coverage gate (defaults to 90) +# Test with coverage gate (defaults to 90). +# The codama-generated payment-channels client and the runnable examples are +# filtered from the profile before the gate: the generated client is +# byte-for-byte reproducible from idl/payment-channels.json (the same +# carve-out the lint recipe applies), and examples are demo binaries that the +# other SDKs in this repo also keep out of their coverage denominators. test-cover gate="90": mkdir -p build - go test -coverprofile=build/coverage.out ./... + go test -coverprofile=build/coverage.raw.out ./... + grep -v -e '/protocols/programs/paymentchannels/' -e '/examples/' build/coverage.raw.out > build/coverage.out go tool cover -func=build/coverage.out @awk -v gate={{gate}} '/^total:/ {pct=$NF+0; if (pct+0 < gate+0) {printf "coverage %s below gate %d%%\n",$NF,gate; exit 1} else {printf "coverage %s meets gate %d%%\n",$NF,gate}}' <(go tool cover -func=build/coverage.out) @@ -43,3 +60,6 @@ check: build lint audit test-cover serve-example port="4567": go run ./examples/simple-server +# Boot the playground API example (same endpoints as typescript/examples/playground-api) +serve-playground port="3000": + PORT={{port}} go run ./examples/playground-api diff --git a/go/README.md b/go/README.md index 1a8d4ab3f..2f002d882 100644 --- a/go/README.md +++ b/go/README.md @@ -112,7 +112,7 @@ The Solana charge intent, in both pull (client-signed) and push |---|:---:|:---:| | `mpp/charge/pull` | ✅ | ✅ | | `mpp/charge/push` | ✅ | ✅ | -| `mpp/session` | — | — | +| `mpp/session` | ✅ | ✅ | | `mpp/subscription` | — | — | For `mpp/charge/pull`: the server owns the full lifecycle. It issues @@ -130,6 +130,64 @@ with `getTransaction`, rejects failed or missing metadata, reuses the same structural transaction verifier as pull mode, consumes the signature through replay storage, and emits the same receipt shape. +For `mpp/session`: both sides ship. + +Client side: + +- session challenge parsing and selection (`ParseSessionChallenge`, + `SelectSessionChallenge` with network/currency/mode filters; omitted + or empty `modes` means push-only), +- payment-channel open builders driven by the challenge (deposit + defaults to the cap, grace period 900s, random salt, token program + resolved from the currency so Token-2022 mints work, operator as fee + payer with a payer partial-sign, challenge `recentBlockhash` echo, + `PendingServerSignature` placeholder) for push and pull/clientVoucher, +- `ActiveSession` voucher signing with the prepare/record watermark + split, `SessionConsumer` for metered deliveries, and the metered SSE + layer (`SseDecoder`, `MeteredSseSession`, `MeteredSseStream`, + `HTTPCommitTransport`). + +Server side (`NewSession`, mirroring the TypeScript `session()` method +over the rust `SessionServer` core): + +- HMAC-bound 402 session challenges (`Session.Challenge`): cap clamped + to the server max, `minVoucherDelta` only when positive, `modes` + omitted when push-only, `pullVoucherStrategy` only when pull is + offered, optional `recentBlockhash` prefetch via the configured RPC + client, +- credential verification (`Session.VerifyCredential`) dispatching the + open / voucher / commit / topUp / close actions over an atomic + per-channel `ChannelStore` with the harness-tested voucher check + order, idempotent open replays that never reset the watermark, and a + re-drivable close until a settlement signature is recorded, +- on-chain open handling: structural `VerifyOpenTx` for client-broadcast + opens (legacy and v0 encodings, payload signature binding, channel + PDA re-derivation) and `SubmitOpenTx` server broadcast that completes + the fee-payer signature and waits for confirmation, +- the reserve/commit metering side channel (`Session.Routes`) hosts + mount at `POST /__402/session/deliveries` and + `POST /__402/session/commit` (a TypeScript-server extension, not in + the rust crate), plus `SessionMiddleware` for `net/http` routes, +- a server-side metered SSE writer (`MeteredStream`) emitting the + `mpp.metering` / `mpp.usage` / `[DONE]` frames the client decoder + consumes, +- an idle-close watchdog (`CloseDelay`) and close settlement + (settle_and_finalize + Ed25519 precompile + distribute in one + merchant-signed transaction), both of which settle on-chain only when + a merchant `Signer` and an `RPC` client are configured; without them + payload claims are trusted as provided, matching rust with `rpc_url` + unset. + +Out of scope: pull/operatedVoucher (multi-delegate program builders) on +both sides, including the `initMultiDelegateTx` submission seam in the +TypeScript open handler, the SPL `approve` delegation transaction for +non-channel pull opens (the on-chain delegation happens out of band), +and a `SessionFetch`-style drop-in fetch wrapper. The TypeScript +`SessionFetchClient` semantics that wrapper would own (per-channel +commit watermark reset on re-open, failed-commit retryability without +latching) therefore have no Go counterpart; the `ActiveSession` +prepare/record split is the building block callers compose instead. + ## Examples One runnable example ships with this package: diff --git a/go/cmd/conformance/main.go b/go/cmd/conformance/main.go index 9d7116a58..98254826d 100644 --- a/go/cmd/conformance/main.go +++ b/go/cmd/conformance/main.go @@ -1,8 +1,7 @@ // Command conformance is the Go cross-SDK conformance-vector runner. // -// It honors the same stdin/stdout contract as the TypeScript reference -// runner (harness/src/conformance/ts-runner.ts): read one conformance -// vector as JSON on stdin, drive the real Go pay_kit SDK +// It honors the harness conformance runner stdin/stdout contract: read one +// conformance vector as JSON on stdin, drive the real Go pay_kit SDK // (paycore + protocols/mpp client build, server pre-broadcast verify, and // the wire canonical-JSON / base64url encoders) for the requested mode, and // emit one RunnerResult line as JSON on stdout. @@ -27,11 +26,13 @@ import ( "io" "os" "regexp" + "strconv" "strings" solana "github.com/gagliardetto/solana-go" "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" "github.com/solana-foundation/pay-kit/go/paycore/solanatx" "github.com/solana-foundation/pay-kit/go/protocols/mpp/client" "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" @@ -49,144 +50,312 @@ const ( defaultSPLDecimals = 6 ) -// Vector mirrors harness/src/conformance/schema.ts ConformanceVector. +// Vector is the top-level conformance-vector shape consumed from stdin. type Vector struct { - ID string `json:"id"` - Intent string `json:"intent"` - Mode string `json:"mode"` - Description string `json:"description"` - Input VectorInput `json:"input"` - Expect json.RawMessage `json:"expect"` + // ID is the unique vector identifier, echoed back in RunnerResult so + // the harness can pair each result line with its vector. + ID string `json:"id"` + // Intent selects the runner path: "x402-exact" dispatches to the x402 + // envelope oracle, anything else runs the MPP charge paths. + Intent string `json:"intent"` + // Mode picks what to exercise: "build-transaction", + // "verify-transaction", or "canonical-bytes". + Mode string `json:"mode"` + // Description is the human-readable summary of what the vector + // exercises; the runner never branches on it. + Description string `json:"description"` + // Input carries the per-mode inputs (request, pinned fixtures, + // encoder payloads) this runner consumes. + Input VectorInput `json:"input"` + // Expect is the expected-outcome JSON asserted by the harness driver; + // it is opaque to this runner and passed through unread. + Expect json.RawMessage `json:"expect"` } -// VectorInput mirrors schema.ts VectorInput. +// VectorInput carries the per-mode inputs of a conformance vector. type VectorInput struct { - Request *ChargeRequest `json:"request"` - Transaction string `json:"transaction"` - SignerSecretKey []byte `json:"signerSecretKey"` - RPCFixtures *RPCFixtures `json:"rpcFixtures"` - Value json.RawMessage `json:"value"` + // Request is the charge request that drives the build/verify modes; + // nil for encoder-only (canonical-bytes) vectors. + Request *ChargeRequest `json:"request"` + // Transaction is a pinned base64 wire transaction to verify; empty in + // verify mode means build one from Request first and verify that. + Transaction string `json:"transaction"` + // SignerSecretKey is the 64-byte ed25519 secret key (carried as a JSON + // byte array) acting as transfer authority and default fee payer on + // the client build path. + SignerSecretKey []byte `json:"signerSecretKey"` + // RPCFixtures pins RPC-derived values (mint owners) so build/verify + // stay RPC-free; nil when the vector needs none. + RPCFixtures *RPCFixtures `json:"rpcFixtures"` + // Value is a raw JSON value to canonicalize (JCS) and base64url-encode + // in canonical-bytes mode; absent otherwise. + Value json.RawMessage `json:"value"` + // EncodeBase64URL supplies raw bytes (hex or UTF-8) to base64url-encode + // in canonical-bytes mode; nil otherwise. EncodeBase64URL *EncodeBase64URL `json:"encodeBase64Url"` - ChallengeID *ChallengeID `json:"challengeId"` - - // x402-exact inputs (mirror schema.ts VectorInput x402 fields). - X402Offer *X402Offer `json:"x402Offer"` - X402Version int `json:"x402Version"` - X402PinnedTransaction string `json:"x402PinnedTransaction"` - X402ServerNetwork string `json:"x402ServerNetwork"` - X402ServerRecipient string `json:"x402ServerRecipient"` - X402ServerCurrency string `json:"x402ServerCurrency"` - X402ServerAmount string `json:"x402ServerAmount"` - X402PaymentHeader string `json:"x402PaymentHeader"` - - // x402-exact v2 extensions inputs (mirror schema.ts VectorInput). - X402AdvertisedExtensions json.RawMessage `json:"x402AdvertisedExtensions"` - X402PaymentIdentifierID string `json:"x402PaymentIdentifierId"` - X402ServerRequiresPaymentIdentifier bool `json:"x402ServerRequiresPaymentIdentifier"` + // ChallengeID supplies the inputs to the MPP challenge-id HMAC-SHA256 + // derivation in canonical-bytes mode; nil otherwise. + ChallengeID *ChallengeID `json:"challengeId"` + // VoucherPreimage supplies the inputs to the 48-byte session voucher + // preimage in canonical-bytes mode; nil otherwise. + VoucherPreimage *VoucherPreimage `json:"voucherPreimage"` + + // x402-exact inputs. + X402Offer *X402Offer `json:"x402Offer"` + // X402Version is the x402Version the build path should produce: + // 1 builds the legacy top-level scheme/network envelope, 2 the + // accepted-echo envelope, and 0 (absent) exercises the default + // producer, which also emits 2. + X402Version int `json:"x402Version"` + // X402PinnedTransaction is the placeholder base64 transaction proof + // placed in payload.transaction on build; the envelope shape, not + // these bytes, is the conformance oracle. + X402PinnedTransaction string `json:"x402PinnedTransaction"` + // X402ServerNetwork is the route network the verify gate expects; + // cluster slugs, legacy slugs, and CAIP-2 ids are all normalized to + // CAIP-2 before comparison. + X402ServerNetwork string `json:"x402ServerNetwork"` + // X402ServerRecipient is the route recipient (base58 address) the + // envelope's accepted.payTo must equal on verify. + X402ServerRecipient string `json:"x402ServerRecipient"` + // X402ServerCurrency is the route asset the envelope's accepted.asset + // must equal on verify. + X402ServerCurrency string `json:"x402ServerCurrency"` + // X402ServerAmount is the route amount (decimal string of token base + // units) the envelope's accepted.amount must equal on verify. + X402ServerAmount string `json:"x402ServerAmount"` + // X402PaymentHeader is the base64(JSON) x402 payment header the verify + // mode decodes and gates against the route. + X402PaymentHeader string `json:"x402PaymentHeader"` + + // x402-exact extensions inputs. + X402AdvertisedExtensions json.RawMessage `json:"x402AdvertisedExtensions"` + // X402PaymentIdentifierID pins the payment-identifier id appended when + // the advertised extensions require one; empty means generate a fresh + // id via the production helper. + X402PaymentIdentifierID string `json:"x402PaymentIdentifierId"` + // X402ServerRequiresPaymentIdentifier makes verify reject envelopes + // whose echoed extensions carry no valid payment-identifier id. + X402ServerRequiresPaymentIdentifier bool `json:"x402ServerRequiresPaymentIdentifier"` } -// ChargeRequest mirrors schema.ts VectorChargeRequest. +// ChargeRequest is the charge-intent request carried in a vector input. type ChargeRequest struct { - Amount string `json:"amount"` - Currency string `json:"currency"` - ExternalID string `json:"externalId"` - Recipient string `json:"recipient"` - PayTo string `json:"payTo"` - Asset string `json:"asset"` - MethodDetails *MethodDetails `json:"methodDetails"` - ComputeUnitLimit *uint32 `json:"computeUnitLimit"` - ComputeUnitPrice *string `json:"computeUnitPrice"` + // Amount is the total charge as a decimal string of integer base units + // (lamports for SOL, token base units for SPL); no display decimals. + Amount string `json:"amount"` + // Currency is the asset symbol (e.g. "USDC", "SOL") or mint address; + // Asset takes precedence over it when both are set. + Currency string `json:"currency"` + // ExternalID is an external reference recorded on-chain as a Memo + // instruction; empty means no memo is added. + ExternalID string `json:"externalId"` + // Recipient is the destination address (base58); PayTo takes + // precedence over it when both are set. + Recipient string `json:"recipient"` + // PayTo is the preferred recipient field; per the conformance + // precedence rules it wins over Recipient. + PayTo string `json:"payTo"` + // Asset is the preferred asset field; per the conformance precedence + // rules it wins over Currency. + Asset string `json:"asset"` + // MethodDetails carries the Solana-specific build/verify knobs + // (network, blockhash, token program, splits); nil uses defaults. + MethodDetails *MethodDetails `json:"methodDetails"` + // ComputeUnitLimit caps compute units for the built transaction; + // nil leaves the SDK default in effect. + ComputeUnitLimit *uint32 `json:"computeUnitLimit"` + // ComputeUnitPrice is the priority fee in micro-lamports per compute + // unit, as a decimal string; nil leaves the SDK default in effect. + ComputeUnitPrice *string `json:"computeUnitPrice"` } -// MethodDetails mirrors schema.ts VectorChargeRequest.methodDetails. +// MethodDetails is the methodDetails block of a vector charge request. type MethodDetails struct { - Network string `json:"network"` - Decimals *uint8 `json:"decimals"` - TokenProgram string `json:"tokenProgram"` - RecentBlockhash string `json:"recentBlockhash"` - FeePayer *bool `json:"feePayer"` - FeePayerKey string `json:"feePayerKey"` - Splits []paycore.Split `json:"splits"` + // Network is the Solana cluster slug (e.g. "mainnet", "devnet"); + // empty defaults to mainnet. + Network string `json:"network"` + // Decimals is the SPL mint decimals used for transferChecked; nil + // defaults to 6 for non-SOL currencies and is unused for SOL. + Decimals *uint8 `json:"decimals"` + // TokenProgram is the base58 id of the program owning the mint (Token + // or Token-2022); empty resolves via the rpc-fixture mint owner, then + // the default-by-currency table, keeping the run RPC-free. + TokenProgram string `json:"tokenProgram"` + // RecentBlockhash pins the blockhash (base58) used to build the + // transaction so no live validator is contacted. + RecentBlockhash string `json:"recentBlockhash"` + // FeePayer enables server fee sponsorship when true and FeePayerKey is + // set; nil or false keeps the signer as fee payer. + FeePayer *bool `json:"feePayer"` + // FeePayerKey is the base58 public key of the sponsoring fee payer + // account used when FeePayer is true. + FeePayerKey string `json:"feePayerKey"` + // Splits lists additional same-asset transfers carved out of the total + // amount, each with its own recipient. + Splits []paycore.Split `json:"splits"` } -// RPCFixtures mirrors schema.ts VectorRpcFixtures. +// RPCFixtures pins the RPC-derived values a vector needs so the run stays +// RPC-free. type RPCFixtures struct { - RecentBlockhash string `json:"recentBlockhash"` - MintOwners map[string]string `json:"mintOwners"` + // RecentBlockhash is a pinned blockhash (base58) a vector may carry; + // the build path reads the blockhash from methodDetails, so this stays + // informational for this runner. + RecentBlockhash string `json:"recentBlockhash"` + // MintOwners maps mint address (base58) to its owning token program + // (base58), standing in for the getAccountInfo owner lookup. + MintOwners map[string]string `json:"mintOwners"` } -// EncodeBase64URL mirrors schema.ts encodeBase64Url. +// EncodeBase64URL holds the raw bytes (hex or UTF-8) to base64url-encode. type EncodeBase64URL struct { + // HexBytes is a hex string decoded to raw bytes before base64url + // encoding; it takes precedence over UTF8 when both are set. HexBytes string `json:"hexBytes"` - UTF8 string `json:"utf8"` + // UTF8 is a literal string whose UTF-8 bytes are base64url-encoded + // when HexBytes is empty. + UTF8 string `json:"utf8"` } -// ChallengeID mirrors schema.ts VectorInput.challengeId: the inputs to the -// MPP charge challenge-id HMAC derivation. +// ChallengeID holds the inputs to the MPP charge challenge-id HMAC +// derivation. type ChallengeID struct { + // SecretKey is the server-side secret keying the HMAC-SHA256; it never + // appears in the HMAC input itself. SecretKey string `json:"secretKey"` - Realm string `json:"realm"` - Method string `json:"method"` - Intent string `json:"intent"` - Request string `json:"request"` - Expires string `json:"expires"` - Digest string `json:"digest"` - Opaque string `json:"opaque"` + // Realm is the challenge realm parameter, the first "|"-joined segment + // of the HMAC input. + Realm string `json:"realm"` + // Method is the HTTP method bound into the challenge id. + Method string `json:"method"` + // Intent is the MPP intent (e.g. "charge") bound into the challenge id. + Intent string `json:"intent"` + // Request is the request binding segment of the HMAC input, joined + // verbatim; empty when the challenge omits it. + Request string `json:"request"` + // Expires is the challenge expiry exactly as carried on the wire, + // joined verbatim into the HMAC input. + Expires string `json:"expires"` + // Digest is the body digest challenge parameter; absent optionals join + // as empty strings. + Digest string `json:"digest"` + // Opaque is the opaque challenge parameter (base64url JSON on the + // wire), joined verbatim; empty when absent. + Opaque string `json:"opaque"` } -// Transfer mirrors schema.ts TransactionShape.transfers element. +// VoucherPreimage holds the inputs to the 48-byte session voucher message +// bytes. +type VoucherPreimage struct { + // ChannelID is the payment-channel address (base58); its 32 raw bytes + // form the preimage prefix. + ChannelID string `json:"channelId"` + // CumulativeAmount is the channel's cumulative spend in token base + // units, as a decimal u64 string; encoded little-endian at offset 32. + CumulativeAmount string `json:"cumulativeAmount"` + // ExpiresAt is the voucher expiry as unix epoch seconds; encoded as a + // little-endian i64 at offset 40. + ExpiresAt int64 `json:"expiresAt"` +} + +// Transfer is one decoded transfer in a transaction shape. type Transfer struct { - Kind string `json:"kind"` - Destination string `json:"destination,omitempty"` + // Kind is the transfer family: "sol" for System Program transfers, + // "spl" for token-program transferChecked. + Kind string `json:"kind"` + // Destination is the base58 receiving account: the recipient wallet + // for SOL, the destination token account for SPL. + Destination string `json:"destination,omitempty"` + // DestinationOwner is the base58 wallet owning the destination token + // account; this decoder leaves it empty (omitted on the wire). DestinationOwner string `json:"destinationOwner,omitempty"` - Mint string `json:"mint,omitempty"` - Amount string `json:"amount"` - Decimals *uint8 `json:"decimals,omitempty"` - TokenProgram string `json:"tokenProgram,omitempty"` + // Mint is the base58 token mint of an SPL transfer; omitted for SOL. + Mint string `json:"mint,omitempty"` + // Amount is the transferred quantity as a decimal u64 string in base + // units: lamports for SOL, token base units for SPL. + Amount string `json:"amount"` + // Decimals is the decimals byte asserted by transferChecked; nil for + // SOL transfers, which carry none. + Decimals *uint8 `json:"decimals,omitempty"` + // TokenProgram is the base58 id of the program executing the transfer + // (Token or Token-2022); omitted for SOL. + TokenProgram string `json:"tokenProgram,omitempty"` } -// TransactionShape mirrors schema.ts TransactionShape. +// TransactionShape is the decoded semantic shape of a built transaction. type TransactionShape struct { - FeePayer string `json:"feePayer,omitempty"` - Transfers []Transfer `json:"transfers,omitempty"` - ForbiddenPrograms []string `json:"forbiddenPrograms,omitempty"` - MaxComputeUnitLimit *uint32 `json:"maxComputeUnitLimit,omitempty"` - MaxComputeUnitPrice string `json:"maxComputeUnitPrice,omitempty"` - Memo []string `json:"memo,omitempty"` + // FeePayer is the base58 key of account[0], the transaction fee payer. + FeePayer string `json:"feePayer,omitempty"` + // Transfers lists the decoded SOL and SPL transfers in instruction + // order. + Transfers []Transfer `json:"transfers,omitempty"` + // ForbiddenPrograms lists base58 ids of disallowed programs found in + // the transaction; this decoder reports none, so it stays empty and is + // omitted on the wire. + ForbiddenPrograms []string `json:"forbiddenPrograms,omitempty"` + // MaxComputeUnitLimit is the cap from the ComputeBudget + // SetComputeUnitLimit instruction; nil when the transaction sets none. + MaxComputeUnitLimit *uint32 `json:"maxComputeUnitLimit,omitempty"` + // MaxComputeUnitPrice is the SetComputeUnitPrice value in + // micro-lamports per compute unit, as a decimal u64 string; empty when + // the transaction sets none. + MaxComputeUnitPrice string `json:"maxComputeUnitPrice,omitempty"` + // Memo lists the Memo Program instruction payloads as UTF-8 strings, + // in instruction order. + Memo []string `json:"memo,omitempty"` } -// ExactBytes mirrors schema.ts RunnerResult.exactBytes. +// ExactBytes carries the exact encoder outputs for canonical-bytes vectors. type ExactBytes struct { + // CanonicalJSON is the canonical (JCS) JSON text produced by the wire + // encoder, where byte-for-byte agreement across SDKs is asserted. CanonicalJSON string `json:"canonicalJson,omitempty"` - Base64URL string `json:"base64Url,omitempty"` - Bytes []int `json:"bytes,omitempty"` + // Base64URL is the unpadded base64url encoding of the produced bytes + // (canonical JSON, raw input bytes, challenge id, or voucher preimage). + Base64URL string `json:"base64Url,omitempty"` + // Bytes is the raw output, one int (0-255) per byte, so the harness + // can diff exact bytes across SDKs. + Bytes []int `json:"bytes,omitempty"` } -// RunnerResult mirrors schema.ts RunnerResult. +// RunnerResult is the single JSON result line emitted on stdout. type RunnerResult struct { - ID string `json:"id"` - Outcome string `json:"outcome"` - TransactionShape *TransactionShape `json:"transactionShape,omitempty"` + // ID echoes the vector's id so the harness can pair result to vector. + ID string `json:"id"` + // Outcome is "accept" or "reject". + Outcome string `json:"outcome"` + // TransactionShape is the decoded semantic shape for accepted MPP + // build/verify vectors; nil for other modes and on reject. + TransactionShape *TransactionShape `json:"transactionShape,omitempty"` + // X402EnvelopeShape is the decoded envelope shape for accepted + // x402-exact vectors; nil for other intents and on reject. X402EnvelopeShape *X402EnvelopeShape `json:"x402EnvelopeShape,omitempty"` - ExactBytes *ExactBytes `json:"exactBytes,omitempty"` - Error string `json:"error,omitempty"` - RejectCode string `json:"rejectCode,omitempty"` + // ExactBytes carries the encoder outputs for canonical-bytes vectors; + // nil for other modes. + ExactBytes *ExactBytes `json:"exactBytes,omitempty"` + // Error is the SDK's native error message when Outcome is "reject"; + // omitted on accept. + Error string `json:"error,omitempty"` + // RejectCode is the normalized cross-SDK reject category from + // classifyReject; empty when the message is unclassified so the + // harness can surface it instead of silently passing. + RejectCode string `json:"rejectCode,omitempty"` } // rejectPattern pairs a compiled regex with the normalized RejectCode it // classifies a Go SDK reject message into. type rejectPattern struct { - re *regexp.Regexp - code string + re *regexp.Regexp // case-insensitive pattern matched against the SDK reject message + code string // normalized cross-SDK RejectCode emitted when re matches } -// rejectPatterns mirrors harness/src/conformance/reject.ts: it maps the Go -// pay_kit SDK's native reject error strings onto the shared cross-SDK -// RejectCode vocabulary. The Go messages are tuned here against the real -// strings the SDK emits (e.g. "no matching token transfer for ..."), so the -// alternation includes "token". As in the reference, a transferChecked -// decimals mismatch is enforced through the transfer match key and so -// honestly surfaces as the generic no-matching-transfer category, not a +// rejectPatterns maps the Go pay_kit SDK's native reject error strings onto +// the shared cross-SDK RejectCode vocabulary. The Go messages are tuned here +// against the real strings the SDK emits (e.g. "no matching token transfer +// for ..."), so the alternation includes "token". A transferChecked decimals +// mismatch is enforced through the transfer match key and so honestly +// surfaces as the generic no-matching-transfer category, not a // decimals-specific code. var rejectPatterns = []rejectPattern{ {regexp.MustCompile(`(?i)compute unit price .* exceeds (maximum|cap)`), "compute-price-over-cap"}, @@ -201,11 +370,10 @@ var rejectPatterns = []rejectPattern{ // x402-exact reject categories. `unsupported x402 version` must be // checked before the generic invalid/payload fallback (the message is // "invalid payload: unsupported x402 version: N"). `network mismatch` - // likewise precedes the fallback. Mirrors harness/src/conformance/reject.ts. + // likewise precedes the fallback. {regexp.MustCompile(`(?i)unsupported x402 version`), "unsupported-version"}, {regexp.MustCompile(`(?i)network mismatch`), "wrong-network"}, - // payment-identifier gate: required-but-missing/invalid id. Mirrors - // harness/src/conformance/reject.ts payment-identifier-required. + // payment-identifier gate: required-but-missing/invalid id. {regexp.MustCompile(`(?i)payment.identifier .*(required|missing|invalid)`), "payment-identifier-required"}, } @@ -233,7 +401,7 @@ func classifyReject(message string) string { // same interface the Go client build path consumes. The vectors carry the // signer as the transfer authority / fee payer. type localSigner struct { - priv solana.PrivateKey + priv solana.PrivateKey // 64-byte ed25519 keypair (seed || public key) backing PublicKey and Sign } func newLocalSigner(secret []byte) (*localSigner, error) { @@ -328,8 +496,8 @@ func rejected(id string, err error) RunnerResult { return RunnerResult{ID: id, Outcome: "reject", Error: msg, RejectCode: classifyReject(msg)} } -// flattenRequest applies the same precedence rules as the TS reference -// runner: top-level asset / payTo win over currency / recipient, and the +// flattenRequest applies the conformance contract's precedence rules: +// top-level asset / payTo win over currency / recipient, and the // token program resolves explicit -> rpc-fixture mint owner -> // default-by-currency so the build path stays RPC-free. It returns the // charge fields plus the resolved paycore.MethodDetails the Go SDK consumes. @@ -506,18 +674,39 @@ func runCanonicalBytes(vector Vector) (*ExactBytes, error) { if c := in.ChallengeID; c != nil { // base64url(HMAC-SHA256(secret, realm|method|intent|request|expires| // digest|opaque)); absent optionals join as empty strings. Drives the - // production SDK derivation (wire.ComputeChallengeID), which mirrors - // rust compute_challenge_id (protocol/core/challenge.rs). + // production SDK derivation (wire.ComputeChallengeID). eb.Base64URL = wire.ComputeChallengeID( c.SecretKey, c.Realm, c.Method, c.Intent, c.Request, c.Expires, c.Digest, c.Opaque, ) } + if v := in.VoucherPreimage; v != nil { + // The 48-byte session voucher preimage, computed by the production SDK + // glue (paymentchannels.VoucherMessageBytes) so a byte mismatch is + // caught here cross-SDK rather than behind a live channel. + channel, err := solana.PublicKeyFromBase58(v.ChannelID) + if err != nil { + return nil, fmt.Errorf("invalid voucher channelId: %w", err) + } + cumulative, err := strconv.ParseUint(v.CumulativeAmount, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid voucher cumulativeAmount: %w", err) + } + preimage, err := paymentchannels.VoucherMessageBytes(channel, cumulative, v.ExpiresAt) + if err != nil { + return nil, err + } + ints := make([]int, len(preimage)) + for i, b := range preimage { + ints[i] = int(b) + } + eb.Bytes = ints + eb.Base64URL = wire.Base64URLEncode(preimage) + } return eb, nil } // shapeFromTransaction decodes a base64 wire transaction into the semantic -// shape the conformance driver asserts against. It mirrors the TS reference -// decoder (harness/src/conformance/decode.ts): fee payer is account[0], SPL +// shape the conformance driver asserts against: fee payer is account[0], SPL // transfers come from transferChecked (discriminator 12), SOL transfers from // the System Program transfer (discriminator 2), memos from the Memo Program, // and compute caps from the ComputeBudget program. diff --git a/go/cmd/protocol-runner/format_test.go b/go/cmd/protocol-runner/format_test.go new file mode 100644 index 000000000..52480b934 --- /dev/null +++ b/go/cmd/protocol-runner/format_test.go @@ -0,0 +1,143 @@ +package main + +// Round-trip coverage for the format verbs (challenge.format, +// credential.format, receipt.format) and their parse counterparts, including +// opaque blobs, credential payloads, and the malformed-input failure paths. + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/wire" +) + +func TestDispatchChallengeFormatParseRoundtrip(t *testing.T) { + format := dispatch(req(t, "challenge.format", map[string]any{ + "id": "ch_roundtrip", + "realm": "api.example.com", + "method": "solana", + "intent": "session", + "request": map[string]any{"cap": "1000000", "currency": "USDC"}, + "expires": "2030-01-01T00:00:00Z", + "description": "Metered stream", + "digest": "sha256=abc", + "opaque": map[string]any{"hint": "value"}, + })) + if !format.Success { + t.Fatalf("challenge.format failed: %s", format.Error) + } + header := format.Result.(headerInput).Header + if !strings.HasPrefix(header, "Payment ") { + t.Fatalf("formatted header = %q", header) + } + + parse := dispatch(req(t, "challenge.parse", map[string]string{"header": header})) + if !parse.Success { + t.Fatalf("challenge.parse failed: %s", parse.Error) + } + obj := parse.Result.(challengeObject) + if obj.ID != "ch_roundtrip" || obj.Intent != "session" || obj.Description != "Metered stream" { + t.Fatalf("round-tripped challenge = %+v", obj) + } + request, okType := obj.Request.(map[string]any) + if !okType || request["cap"] != "1000000" { + t.Fatalf("round-tripped request = %#v", obj.Request) + } + opaque, okType := obj.Opaque.(map[string]any) + if !okType || opaque["hint"] != "value" { + t.Fatalf("round-tripped opaque = %#v", obj.Opaque) + } +} + +func TestDispatchChallengeFormatMalformedInput(t *testing.T) { + resp := dispatch(request{Op: "challenge.format", Input: json.RawMessage(`"not-an-object"`)}) + if resp.Success || resp.ErrorType != "format_error" { + t.Fatalf("malformed challenge.format = %+v", resp) + } +} + +func TestDispatchCredentialFormatParseRoundtrip(t *testing.T) { + format := dispatch(req(t, "credential.format", map[string]any{ + "challenge": map[string]any{ + "id": "ch_cred", + "realm": "api.example.com", + "method": "solana", + "intent": "session", + "request": map[string]any{"cap": "1000"}, + "expires": "2030-01-01T00:00:00Z", + "opaque": map[string]any{"k": "v"}, + }, + "source": "wallet", + "payload": map[string]any{"action": "close", "channelId": "abc"}, + })) + if !format.Success { + t.Fatalf("credential.format failed: %s", format.Error) + } + header := format.Result.(headerInput).Header + + parse := dispatch(req(t, "credential.parse", map[string]string{"header": header})) + if !parse.Success { + t.Fatalf("credential.parse failed: %s", parse.Error) + } + credential := parse.Result.(wire.PaymentCredential) + if credential.Challenge.ID != "ch_cred" || credential.Source != "wallet" { + t.Fatalf("round-tripped credential = %+v", credential) + } + if credential.Payload == nil || !strings.Contains(string(*credential.Payload), `"close"`) { + t.Fatalf("round-tripped payload = %v", credential.Payload) + } +} + +func TestDispatchCredentialFormatAndParseMalformedInput(t *testing.T) { + format := dispatch(request{Op: "credential.format", Input: json.RawMessage(`"nope"`)}) + if format.Success || format.ErrorType != "format_error" { + t.Fatalf("malformed credential.format = %+v", format) + } + parse := dispatch(req(t, "credential.parse", map[string]string{"header": "Payment !!!"})) + if parse.Success || parse.ErrorType != "parse_error" { + t.Fatalf("malformed credential.parse = %+v", parse) + } +} + +func TestDispatchReceiptFormatParseRoundtrip(t *testing.T) { + format := dispatch(req(t, "receipt.format", map[string]any{ + "status": "success", + "method": "solana", + "timestamp": "2030-01-01T00:00:00Z", + "reference": "5sig", + })) + if !format.Success { + t.Fatalf("receipt.format failed: %s", format.Error) + } + header := format.Result.(headerInput).Header + + parse := dispatch(req(t, "receipt.parse", map[string]string{"header": header})) + if !parse.Success { + t.Fatalf("receipt.parse failed: %s", parse.Error) + } + receipt := parse.Result.(wire.Receipt) + if receipt.Status != wire.ReceiptStatusSuccess || receipt.Reference != "5sig" { + t.Fatalf("round-tripped receipt = %+v", receipt) + } +} + +func TestDispatchReceiptMalformedInput(t *testing.T) { + format := dispatch(request{Op: "receipt.format", Input: json.RawMessage(`"nope"`)}) + if format.Success || format.ErrorType != "format_error" { + t.Fatalf("malformed receipt.format = %+v", format) + } + parse := dispatch(req(t, "receipt.parse", map[string]string{"header": "!!!"})) + if parse.Success || parse.ErrorType != "parse_error" { + t.Fatalf("malformed receipt.parse = %+v", parse) + } +} + +func TestDispatchHeaderInputDecodeFailures(t *testing.T) { + for _, op := range []string{"challenge.parse", "credential.parse", "receipt.parse", "base64url.encode", "base64url.decode"} { + resp := dispatch(request{Op: op, Input: json.RawMessage(`5`)}) + if resp.Success { + t.Fatalf("%s accepted a malformed input", op) + } + } +} diff --git a/go/examples/playground-api/README.md b/go/examples/playground-api/README.md new file mode 100644 index 000000000..1a6d92d5a --- /dev/null +++ b/go/examples/playground-api/README.md @@ -0,0 +1,155 @@ +# playground-api (Go) + +The Go port of [`typescript/examples/playground-api`](../../../typescript/examples/playground-api/), +the HTTP API behind the [pay-kit playground](../../../playground/). It serves +the same endpoints with the same payment gating semantics against the Solana +Payment Sandbox (a hosted test validator, no real funds): + +- **Charges**: `solana.charge` endpoints (stock quote, marketplace purchase + with multi-recipient splits, fortune payment link) gated through the Go + `paykit` umbrella client, plus a faucet that funds wallets through surfpool + cheatcodes. +- **Sessions**: the in-process Go session method gating `/sessions/stream` + (pay-per-chunk SSE) and `/sessions/compute` (pay-per-call), with real + payment-channel opens (server-completed fee-payer signature), voucher + metering through the `/__402/session/*` side channel, and on-chain + settlement via the idle-close watchdog. +- **x402**: two `exact`-scheme demo routes plus the embedded facilitator + endpoints. +- `/api/v1/config`: the endpoint catalog and wallet/network metadata the web + app renders. + +## Running + +```bash +cd go +go run ./examples/playground-api # listens on :3000 +``` + +or through the justfile: + +```bash +just -f go/Justfile serve-playground # :3000 +just -f go/Justfile serve-playground 3210 # custom port +``` + +## Pointing the playground at this server + +Set `PAYKIT_PLAYGROUND_API_URL` and the playground's `pnpm dev` skips +launching the TypeScript server; the web app's dev proxy targets this one +instead: + +```bash +# terminal 1: the Go API +cd go && PORT=3210 go run ./examples/playground-api + +# terminal 2: UI only, proxied to the running API +cd playground +PAYKIT_PLAYGROUND_API_URL=http://localhost:3210 pnpm dev +``` + +## Environment variables + +Same table as the TypeScript example: + +| Variable | Default | Purpose | +|----------|---------|---------| +| `PORT` | `3000` | Listen port | +| `NETWORK` | `localnet` | Solana network tag for MPP / x402 challenges | +| `RPC_URL` | `https://402.surfnet.dev:8899` | Surfpool RPC endpoint (hosted sandbox by default) | +| `RECIPIENT` | (auto-generated) | Solana address that receives payments | +| `FEE_PAYER_KEY` | (auto-generated) | Base58 fee-payer keypair (server signs as fee payer) | +| `MPP_SECRET_KEY` | (random per-boot) | MPP secret key for challenge HMAC | + +Additional Go-only knobs: `DOCS_ROOT` overrides the generated-docs directory +when the binary runs outside the repository checkout, and the standard +`PAY_KIT_DISABLE_PREFLIGHT=1` skips the paykit boot preflight. + +## Endpoints + +| Method | Path | Gate | +|--------|------|------| +| GET | `/api/v1/health` | free | +| GET | `/api/v1/config` | free | +| GET | `/api/v1/docs`, `/api/v1/docs/:lang/tree`, `/api/v1/docs/:lang/file` | free | +| GET | `/api/v1/faucet/status` | free | +| POST | `/api/v1/faucet/airdrop` | free | +| GET | `/api/v1/stocks/quote/:symbol` | charge 0.01 USDC | +| GET | `/api/v1/stocks/search?q=` | charge 0.01 USDC | +| GET | `/api/v1/stocks/history/:symbol` | charge 0.05 USDC | +| GET | `/api/v1/weather/:city` | charge 0.01 USDC | +| GET | `/api/v1/marketplace/products` | free | +| GET | `/api/v1/marketplace/buy/:productId?referrer=` | charge with splits | +| GET | `/api/v1/fortune` | charge 0.01 USDC, HTML payment link | +| GET | `/api/v1/premium/feed` | 501 stub (see below) | +| GET | `/sessions/stream` | session, cap 1.00 USDC, 0.0001 USDC/chunk | +| POST | `/sessions/stream` | session voucher commits | +| POST | `/sessions/compute` | session, cap 0.50 USDC, 0.005 USDC/call | +| POST | `/__402/session/deliveries` | session side channel | +| POST | `/__402/session/commit` | session side channel | +| GET | `/sessions/receipt/:channelId` | free settle-status poll | +| GET | `/facilitator/supported` | free | +| POST | `/facilitator/verify`, `/facilitator/settle` | free | +| GET | `/x402/joke`, `/x402/fact` | x402 exact, $0.001 | + +As in the TypeScript example, the stocks-search / stocks-history / weather / +fortune and `/x402/*` routes stay live server-side but are not advertised in +the `/api/v1/config` nav catalog. + +## Differences from the TypeScript example + +Nothing is silently dropped; where the Go SDK lacks a capability the closest +faithful behavior is served and listed here: + +1. **Subscriptions**: the Go SDK does not implement the + `solana.subscription` server method yet, so there is no plan bootstrap + and `GET /api/v1/premium/feed` answers `501 {"error":"not_implemented"}`. + The endpoint catalog omits the subscription entry, which is exactly how + the TypeScript server behaves when its plan bootstrap fails, so the + playground UI renders its graceful empty state. +2. **x402 gating is self-hosted**: the TypeScript routes are gated by + `x402-express` POSTing to the embedded facilitator; the Go x402 adapter + only implements self-hosted mode, so `/x402/joke` and `/x402/fact` verify + and settle in-process with the operator signer. The + `/facilitator/supported|verify|settle` endpoints are still served with + the same response shapes for external x402 clients. The challenge + advertises the configured `NETWORK` instead of the TypeScript example's + hardcoded `solana-devnet` (localnet shares the devnet genesis hash). + +The stocks endpoints call the same Yahoo Finance endpoints as the +`yahoo-finance2` package the TypeScript server uses (v7 quote with crumb +auth, v1 search, v8 chart) and apply the same field coercions, so the +response bodies match the TypeScript server's field for field. + +## Layout + +Mirrors the TypeScript module structure as files of one `main` package: + +``` +main.go # bootstrap, fee payer, surfpool funding, /api/v1/{health,config}, CORS, SPA +charges.go # stocks/weather/marketplace + fortune payment link +yahoo.go # Yahoo Finance client matching yahoo-finance2's response shapes +sessions.go # in-process session methods + side-channel routes + receipt +subscriptions.go # documented 501 stub (no Go subscription method yet) +x402.go # embedded facilitator + x402-gated routes +faucet.go # SOL + USDC airdrop via surfpool cheatcodes +docs.go # generated-docs browser with path-escape guard +constants.go # example-specific constants (faucet amounts, USDC decimals) +utils.go # rpcCall, ANSI helpers, receipt logging +``` + +## Tests + +```bash +cd go +go test ./examples/playground-api/ # offline smoke test (stub RPC) +go test ./examples/playground-api/ -run SessionE2ESurfpool # sandbox-gated session lifecycle +``` + +The smoke test boots the full route table against a stub JSON-RPC server and +checks every endpoint's unauthenticated behavior. The e2e mirrors +`playground-session-e2e.test.ts` (real channel open, metered SSE, side-channel +commit, on-chain settle) and skips explicitly when the sandbox is unreachable +or under `-short`. CI also boots this server against a local surfnet and runs +the payment-link Playwright suite from `html/` against `/api/v1/fortune`, the +same coverage the TypeScript playground API gets. diff --git a/go/examples/playground-api/charges.go b/go/examples/playground-api/charges.go new file mode 100644 index 000000000..1c81c10f0 --- /dev/null +++ b/go/examples/playground-api/charges.go @@ -0,0 +1,325 @@ +package main + +// Charge-gated endpoints: stock data, weather, a marketplace purchase with +// multi-recipient splits (all gated through the paykit umbrella client), and +// the fortune payment link served straight from the protocol-layer MPP +// server with the HTML challenge page enabled. The 402 challenge fires +// before any upstream fetch. + +import ( + "fmt" + "math/rand" + "net/http" + "strings" + + "github.com/shopspring/decimal" + + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paykit" + server "github.com/solana-foundation/pay-kit/go/protocols/mpp/server" +) + +// weatherInfo is the canned per-city weather payload. +type weatherInfo struct { + // Temperature is the air temperature in whole degrees Celsius. + Temperature int `json:"temperature"` + // Conditions is the human-readable sky/precipitation label + // (e.g. "Foggy", "Partly Cloudy"). + Conditions string `json:"conditions"` + // Humidity is the relative humidity as a whole-number percentage (0-100). + Humidity int `json:"humidity"` +} + +// weatherByCity is the canned weather demo table. +var weatherByCity = map[string]weatherInfo{ + "san-francisco": {Temperature: 15, Conditions: "Foggy", Humidity: 85}, + "new-york": {Temperature: 22, Conditions: "Partly Cloudy", Humidity: 60}, + "london": {Temperature: 12, Conditions: "Rainy", Humidity: 90}, + "tokyo": {Temperature: 26, Conditions: "Sunny", Humidity: 55}, + "paris": {Temperature: 18, Conditions: "Overcast", Humidity: 70}, + "sydney": {Temperature: 24, Conditions: "Clear", Humidity: 45}, + "berlin": {Temperature: 10, Conditions: "Cloudy", Humidity: 75}, + "dubai": {Temperature: 38, Conditions: "Sunny", Humidity: 30}, +} + +// product is one marketplace catalog entry. +type product struct { + // Name is the display title shown in the catalog listing and receipt. + Name string + // Price is the seller's list price in USD; the platform and referral + // basis-point fees are charged on top of it, not carved out of it. + Price paykit.Price + // Seller is the base58 wallet address that receives the list price as + // the charge's primary PayTo recipient. + Seller string + // Description is the one-line marketing blurb shown in the listing. + Description string +} + +// products is the canned marketplace catalog. +var products = map[string]product{ + "sol-hoodie": { + Name: "Solana Hoodie", + Price: paykit.MustParseUSD("2.00"), + Seller: "7xKXtg2CW87d97TXJSDpbD5jBkheTqA83TZRuJosgAsU", + Description: "Premium Solana-branded hoodie", + }, + "validator-mug": { + Name: "Validator Mug", + Price: paykit.MustParseUSD("1.00"), + Seller: "7xKXtg2CW87d97TXJSDpbD5jBkheTqA83TZRuJosgAsU", + Description: "Ceramic mug for node operators", + }, + "nft-sticker-pack": { + Name: "NFT Sticker Pack", + Price: paykit.MustParseUSD("0.50"), + Seller: "7xKXtg2CW87d97TXJSDpbD5jBkheTqA83TZRuJosgAsU", + Description: "Holographic sticker collection", + }, +} + +const ( + platformFeeBps = 500 // 5% + referralFeeBps = 200 // 2% +) + +// fortunes is the canned fortune-cookie pool. +var fortunes = []string{ + "A beautiful, smart, and loving person will be coming into your life.", + "A faithful friend is a strong defense.", + "A golden egg of opportunity falls into your lap this month.", + "All your hard work will soon pay off.", + "Curiosity kills boredom. Nothing can kill curiosity.", + "Every day in your life is a special occasion.", + "Good news will come to you by mail.", + "If you continually give, you will continually have.", +} + +// bps returns the given basis-point percentage of a price, e.g. +// bps(usd 2.00, 500) is usd 0.10. +func bps(p paykit.Price, basisPoints int64) paykit.Price { + amount := p.Amount().Mul(decimal.NewFromInt(basisPoints)).Div(decimal.NewFromInt(10_000)) + return paykit.MustParseUSD(amount.String()) +} + +// displayUSD renders a price as the playground's two-decimal USDC label. +func displayUSD(p paykit.Price) string { + return p.Amount().StringFixed(2) + " USDC" +} + +// registerCharges mounts every charge-gated endpoint plus the free +// marketplace catalog. +func registerCharges(mux *http.ServeMux, a *app, client *paykit.Client) error { + platform := a.recipient + + // logged surfaces the settlement signature once a gated handler runs. + logged := func(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if payment, ok := paykit.PaymentFrom(r.Context()); ok && payment.Transaction != "" { + logTx(r.URL.Path, payment.Transaction) + } + handler(w, r) + } + } + + staticGate := func(amount, name string, describe func(r *http.Request) string) func(*http.Request) (paykit.Gate, error) { + return func(r *http.Request) (paykit.Gate, error) { + return paykit.Gate{ + Amount: paykit.MustParseUSD(amount), + Name: name, + Desc: describe(r), + }, nil + } + } + + // Stocks, backed by the same Yahoo Finance endpoints (and response + // shapes) as the yahoo-finance2 package the TypeScript server uses. + yahoo := newYahooClient() + + mux.Handle("GET /api/v1/stocks/quote/{symbol}", + client.RequireFunc(staticGate("0.01", "stockQuote", func(r *http.Request) string { + return "Stock quote: " + r.PathValue("symbol") + }))(logged(func(w http.ResponseWriter, r *http.Request) { + quote, err := yahoo.quote(r.Context(), r.PathValue("symbol")) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to fetch quote") + return + } + if quote == nil { + // Unknown or delisted symbol: an empty 200 body, the way + // Express serializes res.json(undefined). + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + return + } + writeJSON(w, http.StatusOK, quote) + }))) + + mux.Handle("GET /api/v1/stocks/search", + requireQuery("q", client.RequireFunc(staticGate("0.01", "stockSearch", func(r *http.Request) string { + return "Stock search: " + r.URL.Query().Get("q") + }))(logged(func(w http.ResponseWriter, r *http.Request) { + quotes, err := yahoo.search(r.Context(), r.URL.Query().Get("q")) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to search") + return + } + writeJSON(w, http.StatusOK, quotes) + })))) + + mux.Handle("GET /api/v1/stocks/history/{symbol}", + client.RequireFunc(staticGate("0.05", "stockHistory", func(r *http.Request) string { + return "Stock history: " + r.PathValue("symbol") + }))(logged(func(w http.ResponseWriter, r *http.Request) { + history, err := yahoo.history(r.Context(), r.PathValue("symbol"), r.URL.Query().Get("range")) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "Failed to fetch history") + return + } + writeJSON(w, http.StatusOK, history) + }))) + + // Weather: unknown cities 404 before the payment gate. + mux.Handle("GET /api/v1/weather/{city}", requireKnownCity( + client.RequireFunc(staticGate("0.01", "weather", func(r *http.Request) string { + return "Weather for " + r.PathValue("city") + }))(logged(func(w http.ResponseWriter, r *http.Request) { + city := r.PathValue("city") + info := weatherByCity[cityKey(city)] + writeJSON(w, http.StatusOK, map[string]any{ + "city": city, + "temperature": info.Temperature, + "conditions": info.Conditions, + "humidity": info.Humidity, + }) + })))) + + // Marketplace: free catalog plus the split purchase. + mux.HandleFunc("GET /api/v1/marketplace/products", func(w http.ResponseWriter, _ *http.Request) { + list := []map[string]string{} + for _, id := range []string{"sol-hoodie", "validator-mug", "nft-sticker-pack"} { + p := products[id] + list = append(list, map[string]string{ + "id": id, + "name": p.Name, + "description": p.Description, + "price": displayUSD(p.Price), + "priceRaw": p.Price.Amount().Shift(usdcDecimals).Truncate(0).String(), + }) + } + writeJSON(w, http.StatusOK, list) + }) + + buyGate := func(r *http.Request) (paykit.Gate, error) { + p := products[r.PathValue("productId")] // validated before payment, below + fees := paykit.Fees{paykit.Address(platform): bps(p.Price, platformFeeBps)} + if referrer := r.URL.Query().Get("referrer"); referrer != "" { + fees[paykit.Address(referrer)] = bps(p.Price, referralFeeBps) + } + return paykit.Gate{ + Amount: p.Price, + PayTo: paykit.Address(p.Seller), + Name: "marketplaceBuy", + Desc: "Purchase: " + p.Name, + FeeOnTop: fees, + }, nil + } + mux.Handle("GET /api/v1/marketplace/buy/{productId}", requireKnownProduct( + client.RequireFunc(buyGate)(logged(func(w http.ResponseWriter, r *http.Request) { + p := products[r.PathValue("productId")] + platformFee := bps(p.Price, platformFeeBps) + total := p.Price.Amount().Add(platformFee.Amount()) + breakdown := map[string]string{ + "seller": displayUSD(p.Price), + "platformFee": displayUSD(platformFee), + } + if referrer := r.URL.Query().Get("referrer"); referrer != "" { + referralFee := bps(p.Price, referralFeeBps) + breakdown["referralFee"] = displayUSD(referralFee) + total = total.Add(referralFee.Amount()) + } + breakdown["total"] = total.StringFixed(2) + " USDC" + writeJSON(w, http.StatusOK, map[string]any{ + "product": p.Name, + "breakdown": breakdown, + "status": "purchased", + }) + })))) + + // Fortune: a charge payment link with the interactive HTML challenge + // page. Stays on the protocol layer directly (server.Mpp with HTML + // enabled) because the paykit dispatcher renders the cross-SDK JSON + // challenge body; dropping down a layer is the intended escape hatch. + fortuneMpp, err := server.New(server.Config{ + Recipient: a.recipient, + Currency: paycore.USDCMainnetMint, + Decimals: usdcDecimals, + Network: a.network, + RPCURL: a.rpcURL, + SecretKey: a.secretKey, + HTML: true, + FeePayerSigner: a.feePayer, + RPC: a.rpcClient, + }) + if err != nil { + return fmt.Errorf("fortune mpp server: %w", err) + } + fortuneHandler := server.PaymentMiddleware(fortuneMpp, func(*http.Request) (string, server.ChargeOptions, error) { + return "0.01", server.ChargeOptions{Description: "Open a fortune cookie"}, nil + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fortune := fortunes[rand.Intn(len(fortunes))] + logPayment(r.URL.Path, w.Header()) + writeJSON(w, http.StatusOK, map[string]string{"fortune": fortune}) + })) + mux.HandleFunc("GET /api/v1/fortune", func(w http.ResponseWriter, r *http.Request) { + // The interactive payment page registers its service worker at + // scope "/" from a script served under /api/v1/fortune, which + // browsers only allow with this header. + if server.IsServiceWorkerRequest(r) { + w.Header().Set("Service-Worker-Allowed", "/") + } + fortuneHandler.ServeHTTP(w, r) + }) + + return nil +} + +// requireQuery rejects requests missing the given non-empty query parameter +// before the payment gate runs. +func requireQuery(name string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get(name) == "" { + writeJSONError(w, http.StatusBadRequest, "Missing ?"+name+"= parameter") + return + } + next.ServeHTTP(w, r) + }) +} + +// cityKey normalizes a city path segment onto the weather table key. +func cityKey(city string) string { + return strings.ReplaceAll(strings.ToLower(city), " ", "-") +} + +// requireKnownCity 404s unknown cities before payment. +func requireKnownCity(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := weatherByCity[cityKey(r.PathValue("city"))]; !ok { + writeJSONError(w, http.StatusNotFound, + "City not found. Available: san-francisco, new-york, london, tokyo, paris, sydney, berlin, dubai") + return + } + next.ServeHTTP(w, r) + }) +} + +// requireKnownProduct 404s unknown products before payment. +func requireKnownProduct(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, ok := products[r.PathValue("productId")]; !ok { + writeJSONError(w, http.StatusNotFound, "Product not found") + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/go/examples/playground-api/constants.go b/go/examples/playground-api/constants.go new file mode 100644 index 000000000..da9a36717 --- /dev/null +++ b/go/examples/playground-api/constants.go @@ -0,0 +1,18 @@ +package main + +// Example-specific constants. Program ids and the USDC mint come straight +// from paycore at the call sites; only knobs without an SDK equivalent +// live here. + +const ( + // usdcDecimals is the USDC token decimal count. The SDK does not + // export a decimals constant (paykit, paycore, and protocols/mpp all + // default to 6 internally), so the example pins it locally. + usdcDecimals = 6 + + // solFundLamports is the faucet SOL amount (100 SOL). + solFundLamports = 100_000_000_000 + + // usdcFundAmount is the faucet USDC amount (100 USDC at 6 decimals). + usdcFundAmount = 100_000_000 +) diff --git a/go/examples/playground-api/docs.go b/go/examples/playground-api/docs.go new file mode 100644 index 000000000..38cbaac92 --- /dev/null +++ b/go/examples/playground-api/docs.go @@ -0,0 +1,186 @@ +package main + +// Serves the generated API reference markdown from /docs/api, +// with a path-escape guard. Override the root with the DOCS_ROOT env var +// when running the binary outside the repository checkout. + +import ( + "net/http" + "os" + "path/filepath" + "sort" + "strings" +) + +// docLangs are the languages the playground docs browser knows about. +var docLangs = []string{"typescript", "rust", "go", "python", "ruby", "php", "lua", "kotlin", "swift"} + +// docsTreeNode is one entry of the docs file tree. +type docsTreeNode struct { + // Name is the base name of the file or directory (e.g. "README.md"). + Name string `json:"name"` + // Path is the slash-separated path relative to the language docs root; + // the web app passes it back as the ?path= query of the file endpoint. + Path string `json:"path"` + // Type is "dir" for directories and "file" for markdown files; dirs + // sort before files within a level. + Type string `json:"type"` + // Children holds the directory's child nodes; omitted for files. + Children []docsTreeNode `json:"children,omitempty"` +} + +// docsRoot resolves the generated-docs directory. +func docsRoot(repoRoot string) string { + if override := os.Getenv("DOCS_ROOT"); override != "" { + return override + } + if repoRoot == "" { + return "" + } + return filepath.Join(repoRoot, "docs", "api") +} + +// registerDocs mounts the generated-docs browsing endpoints. +func registerDocs(mux *http.ServeMux, a *app) { + root := docsRoot(a.repoRoot) + + mux.HandleFunc("GET /api/v1/docs", func(w http.ResponseWriter, _ *http.Request) { + available := map[string]bool{} + for _, lang := range docLangs { + _, err := os.Stat(filepath.Join(root, lang, "README.md")) + available[lang] = err == nil + } + writeJSON(w, http.StatusOK, map[string]any{"root": root, "available": available}) + }) + + mux.HandleFunc("GET /api/v1/docs/{lang}/tree", func(w http.ResponseWriter, r *http.Request) { + lang := r.PathValue("lang") + if !isDocLang(lang) { + writeJSONError(w, http.StatusNotFound, "unknown_lang") + return + } + langRoot := filepath.Join(root, lang) + if _, err := os.Stat(langRoot); err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{ + "error": "not_generated", + "hint": "Run: just docs-" + docsRecipeSlug(lang), + }) + return + } + tree, err := buildDocsTree(langRoot, "") + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{ + "error": "tree_failed", + "detail": err.Error(), + }) + return + } + writeJSON(w, http.StatusOK, map[string]any{"lang": lang, "tree": tree}) + }) + + mux.HandleFunc("GET /api/v1/docs/{lang}/file", func(w http.ResponseWriter, r *http.Request) { + lang := r.PathValue("lang") + if !isDocLang(lang) { + writeJSONError(w, http.StatusNotFound, "unknown_lang") + return + } + rel := r.URL.Query().Get("path") + if rel == "" { + rel = "README.md" + } + langRoot := filepath.Join(root, lang) + abs, ok := safeJoin(langRoot, rel) + if !ok { + writeJSONError(w, http.StatusBadRequest, "unsafe_path") + return + } + if !strings.HasSuffix(abs, ".md") { + writeJSONError(w, http.StatusBadRequest, "not_markdown") + return + } + content, err := os.ReadFile(abs) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{ + "error": "not_found", + "detail": err.Error(), + }) + return + } + w.Header().Set("Content-Type", "text/markdown; charset=utf-8") + _, _ = w.Write(content) + }) +} + +// isDocLang reports whether lang is a known docs language. +func isDocLang(lang string) bool { + for _, known := range docLangs { + if lang == known { + return true + } + } + return false +} + +// buildDocsTree walks the language docs directory: folders first, then +// markdown files, both alphabetical, skipping dotfiles and node_modules. +func buildDocsTree(absDir, relDir string) ([]docsTreeNode, error) { + entries, err := os.ReadDir(absDir) + if err != nil { + return nil, err + } + nodes := []docsTreeNode{} + for _, entry := range entries { + name := entry.Name() + if strings.HasPrefix(name, ".") || name == "node_modules" { + continue + } + relPath := name + if relDir != "" { + relPath = relDir + "/" + name + } + if entry.IsDir() { + children, err := buildDocsTree(filepath.Join(absDir, name), relPath) + if err != nil { + return nil, err + } + nodes = append(nodes, docsTreeNode{Name: name, Path: relPath, Type: "dir", Children: children}) + } else if strings.HasSuffix(name, ".md") { + nodes = append(nodes, docsTreeNode{Name: name, Path: relPath, Type: "file"}) + } + } + sort.SliceStable(nodes, func(i, j int) bool { + if nodes[i].Type != nodes[j].Type { + return nodes[i].Type == "dir" + } + return nodes[i].Name < nodes[j].Name + }) + return nodes, nil +} + +// safeJoin joins rel onto root and rejects any path escaping the root. +func safeJoin(root, rel string) (string, bool) { + joined := filepath.Join(root, filepath.FromSlash(rel)) + relBack, err := filepath.Rel(root, joined) + if err != nil || relBack == ".." || strings.HasPrefix(relBack, ".."+string(filepath.Separator)) { + return "", false + } + return joined, true +} + +// docsRecipeSlug maps a docs language to its justfile recipe suffix. +func docsRecipeSlug(lang string) string { + switch lang { + case "typescript": + return "ts" + case "rust": + return "rs" + case "python": + return "py" + case "ruby": + return "rb" + case "kotlin": + return "kt" + default: + return lang + } +} diff --git a/go/examples/playground-api/faucet.go b/go/examples/playground-api/faucet.go new file mode 100644 index 000000000..b10bd7b04 --- /dev/null +++ b/go/examples/playground-api/faucet.go @@ -0,0 +1,61 @@ +package main + +// SOL + USDC airdrops via the surfnet cheatcodes. + +import ( + "encoding/json" + "net/http" + + "github.com/solana-foundation/pay-kit/go/paycore" +) + +// registerFaucet mounts the faucet status and airdrop endpoints. +func registerFaucet(mux *http.ServeMux, a *app) { + mux.HandleFunc("GET /api/v1/faucet/status", func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{ + "solAmount": "100 SOL", + "usdcAmount": "100 USDC", + "usdcMint": paycore.USDCMainnetMint, + }) + }) + + mux.HandleFunc("POST /api/v1/faucet/airdrop", func(w http.ResponseWriter, r *http.Request) { + var body struct { + Address string `json:"address"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Address == "" { + writeJSONError(w, http.StatusBadRequest, "Missing `address` in request body") + return + } + _, err := rpcCall(r.Context(), a.rpcURL, "surfnet_setAccount", []any{ + body.Address, + map[string]any{ + "lamports": solFundLamports, + "data": "", + "executable": false, + "owner": paycore.SystemProgram, + "rentEpoch": 0, + }, + }) + if err == nil { + _, err = rpcCall(r.Context(), a.rpcURL, "surfnet_setTokenAccount", []any{ + body.Address, + paycore.USDCMainnetMint, + map[string]any{"amount": usdcFundAmount, "state": "initialized"}, + paycore.TokenProgram, + }) + } + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{ + "error": "Airdrop failed", + "details": err.Error(), + }) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "sol": "100 SOL", + "usdc": "100 USDC", + }) + }) +} diff --git a/go/examples/playground-api/main.go b/go/examples/playground-api/main.go new file mode 100644 index 000000000..5cdcd6b92 --- /dev/null +++ b/go/examples/playground-api/main.go @@ -0,0 +1,420 @@ +// The HTTP API behind the pay-kit playground. Serves the playground +// endpoints with their payment gating (MPP charges through paykit, x402 +// through the Go x402 adapter, sessions through the Go session method), so +// the playground web app works against it by only setting +// PAYKIT_PLAYGROUND_API_URL. +// +// cd go +// go run ./examples/playground-api +// +// Environment: PORT, NETWORK, RPC_URL, RECIPIENT, FEE_PAYER_KEY, +// MPP_SECRET_KEY. See README.md for the full table. +package main + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "strconv" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/signer" + "github.com/solana-foundation/pay-kit/go/paykit" + _ "github.com/solana-foundation/pay-kit/go/protocols/mpp" + _ "github.com/solana-foundation/pay-kit/go/protocols/x402" +) + +// app carries the boot configuration shared by every module. +type app struct { + network string // raw NETWORK tag: localnet | devnet | mainnet + // rpcURL is the Solana JSON-RPC endpoint (RPC_URL env var; defaults to + // the hosted Solana Payment Sandbox). + rpcURL string + // recipient is the base58 address paid by the charge endpoints + // (RECIPIENT env var; defaults to the fee payer's pubkey). + recipient string + // secretKey is the MPP challenge-binding HMAC secret (MPP_SECRET_KEY env + // var; a random per-boot hex secret when unset). + secretKey string + // feePayer is the operator keypair that signs and pays fees for + // settlement transactions (FEE_PAYER_KEY env var; random when unset). + feePayer solana.PrivateKey + // rpcClient is the shared RPC client bound to rpcURL. + rpcClient *rpc.Client + // repoRoot is the repository checkout root; "" outside a checkout, which + // disables the docs browser default root and the SPA file server. + repoRoot string +} + +func main() { + network := envOr("NETWORK", "localnet") + // Default to the hosted Solana Payment Sandbox so the playground works + // zero-config: it has the payment-channels program preloaded and supports + // the surfnet cheatcodes used by the faucet. Override RPC_URL to point at + // a local surfpool when you need offline iteration. + rpcURL := envOr("RPC_URL", "https://402.surfnet.dev:8899") + secretKey := os.Getenv("MPP_SECRET_KEY") + if secretKey == "" { + secretKey = randomHexSecret() + } + port, err := strconv.Atoi(envOr("PORT", "3000")) + if err != nil { + log.Fatalf("invalid PORT: %v", err) + } + + var feePayer solana.PrivateKey + if raw := os.Getenv("FEE_PAYER_KEY"); raw != "" { + feePayer, err = solana.PrivateKeyFromBase58(raw) + if err != nil { + log.Fatalf("invalid FEE_PAYER_KEY: %v", err) + } + } else { + feePayer, err = solana.NewRandomPrivateKey() + if err != nil { + log.Fatalf("generate fee payer: %v", err) + } + } + recipient := envOr("RECIPIENT", feePayer.PublicKey().String()) + + a := &app{ + network: network, + rpcURL: rpcURL, + recipient: recipient, + secretKey: secretKey, + feePayer: feePayer, + rpcClient: rpc.New(rpcURL), + repoRoot: findRepoRoot(), + } + + bootstrapFunding(a) + + handler, shutdown, err := newApp(a) + if err != nil { + log.Fatalf("playground-api: %v", err) + } + defer shutdown() + + addr := fmt.Sprintf(":%d", port) + log.Println() + log.Printf(" %s %s", bold("PayKit Playground (Go)"), dim(fmt.Sprintf("http://localhost:%d", port))) + log.Println() + log.Printf(" %s %s", dim("Network"), magenta(a.network)) + log.Printf(" %s %s", dim("RPC"), cyan(a.rpcURL)) + log.Printf(" %s %s", dim("Recipient"), green(a.recipient)) + log.Printf(" %s %s", dim("Fee payer"), green(a.feePayer.PublicKey().String())) + log.Printf(" %s %s", dim("Plan"), yellow("not bootstrapped (subscriptions are not implemented in the Go SDK)")) + log.Printf(" %s %s", dim("Sessions"), green("enabled (in-process)")) + log.Println() + if err := http.ListenAndServe(addr, handler); err != nil { + log.Fatal(err) + } +} + +// newApp wires every module onto one handler. Split from main so the smoke +// test can boot the full route table against a stub RPC without binding a +// real port or funding accounts. +func newApp(a *app) (http.Handler, func(), error) { + mux := http.NewServeMux() + + registerHealthAndConfig(mux, a) + registerFaucet(mux, a) + registerDocs(mux, a) + + chargesClient, err := newChargesClient(a) + if err != nil { + return nil, nil, fmt.Errorf("charges paykit client: %w", err) + } + if err := registerCharges(mux, a, chargesClient); err != nil { + return nil, nil, fmt.Errorf("charges module: %w", err) + } + + registerSubscriptions(mux) + + sessionsShutdown, err := registerSessions(mux, a) + if err != nil { + return nil, nil, fmt.Errorf("sessions module: %w", err) + } + + if err := registerX402(mux, a); err != nil { + sessionsShutdown() + return nil, nil, fmt.Errorf("x402 module: %w", err) + } + + registerSPA(mux, a.repoRoot) + + return corsMiddleware(mux), sessionsShutdown, nil +} + +// newChargesClient builds the paykit client gating the charge endpoints. +// MPP is the only accepted protocol. +func newChargesClient(a *app) (*paykit.Client, error) { + network, err := paykit.ParseNetwork(a.network) + if err != nil { + return nil, err + } + operatorSigner, err := signer.FromBase58(a.feePayer.String()) + if err != nil { + return nil, err + } + return paykit.New(paykit.Config{ + Network: network, + RPCURL: a.rpcURL, + Accept: []paykit.Protocol{paykit.MPP}, + Operator: paykit.Operator{ + Recipient: paykit.Address(a.recipient), + Signer: operatorSigner, + FeePayer: true, + }, + MPP: paykit.MPPConfig{ + Realm: "PayKit Playground", + ChallengeBindingSecret: []byte(a.secretKey), + }, + }) +} + +// bootstrapFunding funds the fee payer and recipient on the local surfnet so +// the demo works zero-config. Best-effort: a warning is logged when the +// sandbox is unreachable. +func bootstrapFunding(a *app) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _, err := rpcCall(ctx, a.rpcURL, "surfnet_setAccount", []any{ + a.feePayer.PublicKey().String(), + map[string]any{ + "lamports": solFundLamports, + "data": "", + "executable": false, + "owner": paycore.SystemProgram, + "rentEpoch": 0, + }, + }) + if err == nil { + _, err = rpcCall(ctx, a.rpcURL, "surfnet_setTokenAccount", []any{ + a.recipient, + paycore.USDCMainnetMint, + map[string]any{"amount": usdcFundAmount, "state": "initialized"}, + paycore.TokenProgram, + }) + } + if err != nil { + log.Println(yellow(" Surfpool not reachable; fee payer may not have SOL for fees.")) + } +} + +// registerHealthAndConfig mounts the health check and the endpoint catalog +// that drives the playground web app's sidebar. +func registerHealthAndConfig(mux *http.ServeMux, a *app) { + mux.HandleFunc("GET /api/v1/health", func(w http.ResponseWriter, r *http.Request) { + body := map[string]any{ + "ok": true, + "feePayer": a.feePayer.PublicKey().String(), + "recipient": a.recipient, + "network": a.network, + "rpcUrl": a.rpcURL, + } + ctx, cancel := context.WithTimeout(r.Context(), 4*time.Second) + defer cancel() + if out, err := a.rpcClient.GetBalance(ctx, a.feePayer.PublicKey(), rpc.CommitmentConfirmed); err == nil && out != nil { + body["feePayerBalance"] = float64(out.Value) / 1e9 + } + writeJSON(w, http.StatusOK, body) + }) + + mux.HandleFunc("GET /api/v1/config", func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]any{ + "recipient": a.recipient, + "network": a.network, + "rpcUrl": a.rpcURL, + "feePayer": a.feePayer.PublicKey().String(), + "endpoints": buildEndpointList(), + }) + }) +} + +// endpointParam describes one path or query parameter of a catalog entry. +type endpointParam struct { + // Name is the parameter name as it appears in the :param path + // placeholder or query string. + Name string `json:"name"` + // Default is the value the playground form pre-fills; "" renders an + // empty input for an optional parameter. + Default string `json:"default"` + // Description is optional help text for the form field; omitted when + // empty. + Description string `json:"description,omitempty"` +} + +// endpointInfo is one entry of the /api/v1/config endpoint catalog. +type endpointInfo struct { + // ID is the stable slug the web app uses to key the endpoint in its nav. + ID string `json:"id"` + // Primitive is the payment primitive gating the route ("charge" or + // "session"). + Primitive string `json:"primitive"` + // Method is the HTTP method the endpoint expects (GET, POST, ...). + Method string `json:"method"` + // Path is the route pattern with :param placeholders the web app + // substitutes from Params. + Path string `json:"path"` + // Title is the short label shown in the playground sidebar. + Title string `json:"title"` + // Description is the one-line summary shown under the title. + Description string `json:"description"` + // Cost is the human-readable price label (e.g. "0.01 USDC", "varies"). + Cost string `json:"cost"` + // UnitPrice is the per-unit price of session endpoints in USDC base + // units (6 decimals) as a decimal string; omitted for charge endpoints. + UnitPrice string `json:"unitPrice,omitempty"` + // Params lists the path/query parameters the web app renders as form + // inputs; omitted when the endpoint takes none. + Params []endpointParam `json:"params,omitempty"` +} + +// buildEndpointList builds the /api/v1/config endpoint catalog. The +// subscription entry is omitted because the Go SDK has no subscription +// server method (see README.md); the stocks-search / stocks-history / +// weather / fortune / x402 routes stay live server-side but are not +// advertised in the nav. +func buildEndpointList() []endpointInfo { + return []endpointInfo{ + { + ID: "stocks-quote", + Primitive: "charge", + Method: "GET", + Path: "/api/v1/stocks/quote/:symbol", + Title: "Stock quote", + Description: "Real-time price for a single ticker.", + Cost: "0.01 USDC", + Params: []endpointParam{{Name: "symbol", Default: "AAPL"}}, + }, + { + ID: "marketplace-buy", + Primitive: "charge", + Method: "GET", + Path: "/api/v1/marketplace/buy/:productId", + Title: "Marketplace purchase", + Description: "Multi-recipient split (seller + platform + referral).", + Cost: "varies", + Params: []endpointParam{ + {Name: "productId", Default: "sol-hoodie"}, + {Name: "referrer", Default: ""}, + }, + }, + { + ID: "sessions-stream", + Primitive: "session", + Method: "GET", + Path: "/sessions/stream", + Title: "Metered stream", + Description: "Pay-per-chunk SSE delivery via session vouchers.", + Cost: "0.0001 USDC / chunk", + UnitPrice: "100", + }, + { + ID: "sessions-compute", + Primitive: "session", + Method: "POST", + Path: "/sessions/compute", + Title: "Pay-per-call compute", + Description: "Voucher-billed inference; cap 0.50 USDC per session.", + Cost: "0.005 USDC / call", + UnitPrice: "5000", + }, + } +} + +// corsMiddleware sets permissive origins and exposes to browsers the +// payment headers the web app reads. +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := w.Header() + header.Set("Access-Control-Allow-Origin", "*") + header.Set("Access-Control-Expose-Headers", + "www-authenticate, payment-receipt, x-payment-required, x-payment-response") + if r.Method == http.MethodOptions { + header.Set("Access-Control-Allow-Methods", "GET,HEAD,PUT,PATCH,POST,DELETE") + if requested := r.Header.Get("Access-Control-Request-Headers"); requested != "" { + header.Set("Access-Control-Allow-Headers", requested) + } + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} + +// registerSPA serves the built playground web app (playground/app/dist at +// the repo root) with an index.html catch-all. +func registerSPA(mux *http.ServeMux, repoRoot string) { + dist := "" + if repoRoot != "" { + dist = filepath.Join(repoRoot, "playground", "app", "dist") + } + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if dist != "" { + candidate := filepath.Join(dist, filepath.FromSlash(r.URL.Path)) + if info, err := os.Stat(candidate); err == nil && !info.IsDir() { + http.ServeFile(w, r, candidate) + return + } + index := filepath.Join(dist, "index.html") + if _, err := os.Stat(index); err == nil { + http.ServeFile(w, r, index) + return + } + } + writeJSONError(w, http.StatusNotFound, + "not found (build playground/app to serve the web app from this server)") + }) +} + +// findRepoRoot walks up from the working directory to the repository root +// (the directory containing .git or the top-level justfile). Returns "" +// when no marker is found. +func findRepoRoot() string { + dir, err := os.Getwd() + if err != nil { + return "" + } + for { + if _, err := os.Stat(filepath.Join(dir, ".git")); err == nil { + return dir + } + if _, err := os.Stat(filepath.Join(dir, "justfile")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return "" + } + dir = parent + } +} + +// envOr returns the environment variable value, or fallback when unset or +// empty. +func envOr(name, fallback string) string { + if v := os.Getenv(name); v != "" { + return v + } + return fallback +} + +// randomHexSecret generates the per-boot challenge HMAC secret used when +// MPP_SECRET_KEY is unset. +func randomHexSecret() string { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + log.Fatalf("generate MPP secret: %v", err) + } + return hex.EncodeToString(buf) +} diff --git a/go/examples/playground-api/main_test.go b/go/examples/playground-api/main_test.go new file mode 100644 index 000000000..57b0852fe --- /dev/null +++ b/go/examples/playground-api/main_test.go @@ -0,0 +1,433 @@ +package main + +// Offline smoke test for the playground API: boots the full route table +// against a stub JSON-RPC server (no network, no funded accounts) and checks +// every endpoint's unauthenticated behavior. + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/paycore" +) + +// newStubRPC serves the JSON-RPC answers the playground needs at boot and +// challenge-build time. +func newStubRPC(t *testing.T, blockhash string) *httptest.Server { + t.Helper() + stub := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var request struct { + ID any `json:"id"` + Method string `json:"method"` + } + _ = json.NewDecoder(r.Body).Decode(&request) + var result any + switch request.Method { + case "getLatestBlockhash": + result = map[string]any{ + "context": map[string]any{"slot": 1}, + "value": map[string]any{ + "blockhash": blockhash, + "lastValidBlockHeight": 100, + }, + } + case "getBalance": + result = map[string]any{ + "context": map[string]any{"slot": 1}, + "value": 5_000_000_000, + } + case "sendTransaction": + result = "stub-signature" + default: + result = "ok" + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": request.ID, + "result": result, + }) + })) + t.Cleanup(stub.Close) + return stub +} + +// newTestServer boots the playground handler against the stub RPC. +func newTestServer(t *testing.T) (*httptest.Server, *app) { + t.Helper() + t.Setenv("PAY_KIT_DISABLE_PREFLIGHT", "1") + + feePayer, err := solana.NewRandomPrivateKey() + if err != nil { + t.Fatalf("generate fee payer: %v", err) + } + blockhash, err := solana.NewRandomPrivateKey() + if err != nil { + t.Fatalf("generate blockhash: %v", err) + } + stub := newStubRPC(t, blockhash.PublicKey().String()) + + a := &app{ + network: "localnet", + rpcURL: stub.URL, + recipient: feePayer.PublicKey().String(), + secretKey: "playground-smoke-secret", + feePayer: feePayer, + rpcClient: rpc.New(stub.URL), + repoRoot: t.TempDir(), // empty root: no docs generated, no SPA dist + } + handler, shutdown, err := newApp(a) + if err != nil { + t.Fatalf("newApp: %v", err) + } + t.Cleanup(shutdown) + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + return httpServer, a +} + +// doRequest performs a request and returns the response with its body read. +func doRequest(t *testing.T, method, url string, body string, header map[string]string) (*http.Response, string) { + t.Helper() + var reader io.Reader + if body != "" { + reader = strings.NewReader(body) + } + request, err := http.NewRequest(method, url, reader) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if body != "" { + request.Header.Set("Content-Type", "application/json") + } + for k, v := range header { + request.Header.Set(k, v) + } + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("%s %s: %v", method, url, err) + } + raw, err := io.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + return response, string(raw) +} + +// decodeBody unmarshals a JSON response body. +func decodeBody(t *testing.T, body string, out any) { + t.Helper() + if err := json.Unmarshal([]byte(body), out); err != nil { + t.Fatalf("unmarshal %q: %v", body, err) + } +} + +func TestPlaygroundEndpoints(t *testing.T) { + httpServer, a := newTestServer(t) + base := httpServer.URL + + t.Run("health", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/health", "", nil) + if response.StatusCode != http.StatusOK { + t.Fatalf("status = %d: %s", response.StatusCode, body) + } + var health struct { + OK bool `json:"ok"` + FeePayer string `json:"feePayer"` + FeePayerBalance *float64 `json:"feePayerBalance"` + Recipient string `json:"recipient"` + Network string `json:"network"` + RPCURL string `json:"rpcUrl"` + } + decodeBody(t, body, &health) + if !health.OK || health.FeePayer != a.feePayer.PublicKey().String() || + health.Recipient != a.recipient || health.Network != "localnet" || health.RPCURL != a.rpcURL { + t.Fatalf("health = %+v", health) + } + if health.FeePayerBalance == nil || *health.FeePayerBalance != 5 { + t.Fatalf("feePayerBalance = %v, want 5", health.FeePayerBalance) + } + }) + + t.Run("config catalog", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/config", "", nil) + if response.StatusCode != http.StatusOK { + t.Fatalf("status = %d: %s", response.StatusCode, body) + } + var config struct { + Recipient string `json:"recipient"` + FeePayer string `json:"feePayer"` + Endpoints []endpointInfo `json:"endpoints"` + } + decodeBody(t, body, &config) + if config.Recipient != a.recipient { + t.Fatalf("recipient = %q", config.Recipient) + } + ids := map[string]endpointInfo{} + for _, e := range config.Endpoints { + ids[e.ID] = e + } + for _, want := range []string{"stocks-quote", "marketplace-buy", "sessions-stream", "sessions-compute"} { + if _, ok := ids[want]; !ok { + t.Fatalf("catalog missing %q: %s", want, body) + } + } + if ids["sessions-stream"].UnitPrice != "100" || ids["sessions-compute"].UnitPrice != "5000" { + t.Fatalf("unit prices = %q / %q", ids["sessions-stream"].UnitPrice, ids["sessions-compute"].UnitPrice) + } + if _, ok := ids["premium-feed"]; ok { + t.Fatal("catalog must omit the subscription entry (no Go subscription method)") + } + }) + + t.Run("charge endpoints issue MPP challenges", func(t *testing.T) { + for _, path := range []string{ + "/api/v1/stocks/quote/AAPL", + "/api/v1/stocks/search?q=apple", + "/api/v1/stocks/history/AAPL", + "/api/v1/weather/tokyo", + "/api/v1/marketplace/buy/sol-hoodie", + } { + response, body := doRequest(t, http.MethodGet, base+path, "", nil) + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("%s status = %d: %s", path, response.StatusCode, body) + } + wwwAuth := response.Header.Get("WWW-Authenticate") + if !strings.Contains(wwwAuth, "intent=\"charge\"") { + t.Fatalf("%s WWW-Authenticate = %q", path, wwwAuth) + } + var challenge struct { + Error string `json:"error"` + Accepts []struct { + Protocol string `json:"protocol"` + } `json:"accepts"` + } + decodeBody(t, body, &challenge) + if challenge.Error != "payment_required" || len(challenge.Accepts) != 1 || challenge.Accepts[0].Protocol != "mpp" { + t.Fatalf("%s challenge body = %s", path, body) + } + } + }) + + t.Run("pre-gate validation runs before payment", func(t *testing.T) { + for path, wantStatus := range map[string]int{ + "/api/v1/weather/atlantis": http.StatusNotFound, + "/api/v1/marketplace/buy/unknown": http.StatusNotFound, + "/api/v1/stocks/search": http.StatusBadRequest, + "/api/v1/marketplace/buy/sol-shirt": http.StatusNotFound, + } { + response, body := doRequest(t, http.MethodGet, base+path, "", nil) + if response.StatusCode != wantStatus { + t.Fatalf("%s status = %d, want %d: %s", path, response.StatusCode, wantStatus, body) + } + } + }) + + t.Run("marketplace products are free", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/marketplace/products", "", nil) + if response.StatusCode != http.StatusOK { + t.Fatalf("status = %d: %s", response.StatusCode, body) + } + var list []struct { + ID string `json:"id"` + PriceRaw string `json:"priceRaw"` + } + decodeBody(t, body, &list) + if len(list) != 3 || list[0].ID != "sol-hoodie" || list[0].PriceRaw != "2000000" { + t.Fatalf("products = %s", body) + } + }) + + t.Run("fortune serves JSON, HTML, and service worker challenges", func(t *testing.T) { + response, _ := doRequest(t, http.MethodGet, base+"/api/v1/fortune", "", nil) + if response.StatusCode != http.StatusPaymentRequired || + !strings.Contains(response.Header.Get("Content-Type"), "json") { + t.Fatalf("JSON challenge: status = %d type = %q", response.StatusCode, response.Header.Get("Content-Type")) + } + + response, _ = doRequest(t, http.MethodGet, base+"/api/v1/fortune", "", map[string]string{"Accept": "text/html"}) + if response.StatusCode != http.StatusPaymentRequired || + !strings.Contains(response.Header.Get("Content-Type"), "text/html") { + t.Fatalf("HTML challenge: status = %d type = %q", response.StatusCode, response.Header.Get("Content-Type")) + } + + response, body := doRequest(t, http.MethodGet, base+"/api/v1/fortune?__mpp_worker", "", nil) + if response.StatusCode != http.StatusOK || + !strings.Contains(response.Header.Get("Content-Type"), "javascript") || + response.Header.Get("Service-Worker-Allowed") != "/" { + t.Fatalf("service worker: status = %d type = %q sw-allowed = %q body = %.40s", + response.StatusCode, response.Header.Get("Content-Type"), + response.Header.Get("Service-Worker-Allowed"), body) + } + }) + + t.Run("sessions issue session challenges", func(t *testing.T) { + for method, path := range map[string]string{ + http.MethodGet: "/sessions/stream", + http.MethodPost: "/sessions/compute", + } { + response, body := doRequest(t, method, base+path, "", nil) + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("%s %s status = %d: %s", method, path, response.StatusCode, body) + } + wwwAuth := response.Header.Get("WWW-Authenticate") + if !strings.Contains(wwwAuth, "intent=\"session\"") { + t.Fatalf("%s WWW-Authenticate = %q", path, wwwAuth) + } + } + }) + + t.Run("session side channel validates input", func(t *testing.T) { + response, body := doRequest(t, http.MethodPost, base+"/__402/session/deliveries", + `{"amount":"100"}`, nil) + if response.StatusCode != http.StatusBadRequest || !strings.Contains(body, "sessionId") { + t.Fatalf("deliveries: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodPost, base+"/__402/session/commit", + `{"deliveryId":"d-1"}`, nil) + if response.StatusCode != http.StatusBadRequest || !strings.Contains(body, "voucher") { + t.Fatalf("commit: status = %d body = %s", response.StatusCode, body) + } + }) + + t.Run("session receipt", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/sessions/receipt/unknown-channel", "", nil) + if response.StatusCode != http.StatusNotFound || !strings.Contains(body, "channel-not-found") { + t.Fatalf("status = %d body = %s", response.StatusCode, body) + } + }) + + t.Run("premium feed is a documented stub", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/premium/feed", "", nil) + if response.StatusCode != http.StatusNotImplemented || !strings.Contains(body, "not_implemented") { + t.Fatalf("status = %d body = %s", response.StatusCode, body) + } + }) + + t.Run("faucet", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/faucet/status", "", nil) + if response.StatusCode != http.StatusOK || !strings.Contains(body, paycore.USDCMainnetMint) { + t.Fatalf("status: %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodPost, base+"/api/v1/faucet/airdrop", `{}`, nil) + if response.StatusCode != http.StatusBadRequest { + t.Fatalf("missing address: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodPost, base+"/api/v1/faucet/airdrop", + `{"address":"`+a.recipient+`"}`, nil) + if response.StatusCode != http.StatusOK || !strings.Contains(body, `"ok":true`) { + t.Fatalf("airdrop: status = %d body = %s", response.StatusCode, body) + } + }) + + t.Run("facilitator", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/facilitator/supported", "", nil) + if response.StatusCode != http.StatusOK || !strings.Contains(body, `"scheme":"exact"`) { + t.Fatalf("supported: status = %d body = %s", response.StatusCode, body) + } + + _, body = doRequest(t, http.MethodPost, base+"/facilitator/verify", `{}`, nil) + if !strings.Contains(body, `"isValid":false`) { + t.Fatalf("verify missing payload: %s", body) + } + _, body = doRequest(t, http.MethodPost, base+"/facilitator/verify", + `{"paymentPayload":{"payload":{"authorization":{"from":"payer-address"}}}}`, nil) + if !strings.Contains(body, `"isValid":true`) || !strings.Contains(body, "payer-address") { + t.Fatalf("verify: %s", body) + } + + _, body = doRequest(t, http.MethodPost, base+"/facilitator/settle", `{}`, nil) + if !strings.Contains(body, `"success":false`) { + t.Fatalf("settle missing payload: %s", body) + } + _, body = doRequest(t, http.MethodPost, base+"/facilitator/settle", + `{"paymentPayload":{"payload":{"transaction":"AAAA"}}}`, nil) + if !strings.Contains(body, `"success":true`) || !strings.Contains(body, "stub-signature") { + t.Fatalf("settle: %s", body) + } + }) + + t.Run("x402 routes issue x402 challenges", func(t *testing.T) { + for _, path := range []string{"/x402/joke", "/x402/fact"} { + response, body := doRequest(t, http.MethodGet, base+path, "", nil) + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("%s status = %d: %s", path, response.StatusCode, body) + } + var challenge struct { + Accepts []struct { + Protocol string `json:"protocol"` + Scheme string `json:"scheme"` + } `json:"accepts"` + } + decodeBody(t, body, &challenge) + if len(challenge.Accepts) != 1 || challenge.Accepts[0].Protocol != "x402" || challenge.Accepts[0].Scheme != "exact" { + t.Fatalf("%s challenge = %s", path, body) + } + } + }) + + t.Run("docs", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/api/v1/docs", "", nil) + if response.StatusCode != http.StatusOK || !strings.Contains(body, `"go":false`) { + t.Fatalf("docs index: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodGet, base+"/api/v1/docs/cobol/tree", "", nil) + if response.StatusCode != http.StatusNotFound || !strings.Contains(body, "unknown_lang") { + t.Fatalf("unknown lang: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodGet, base+"/api/v1/docs/go/tree", "", nil) + if response.StatusCode != http.StatusNotFound || !strings.Contains(body, "not_generated") { + t.Fatalf("not generated: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodGet, + base+"/api/v1/docs/go/file?path=../../../go.mod", "", nil) + if response.StatusCode != http.StatusBadRequest || !strings.Contains(body, "unsafe_path") { + t.Fatalf("path escape: status = %d body = %s", response.StatusCode, body) + } + response, body = doRequest(t, http.MethodGet, + base+"/api/v1/docs/go/file?path=notes.txt", "", nil) + if response.StatusCode != http.StatusBadRequest || !strings.Contains(body, "not_markdown") { + t.Fatalf("non markdown: status = %d body = %s", response.StatusCode, body) + } + }) + + t.Run("CORS exposes the payment headers", func(t *testing.T) { + response, _ := doRequest(t, http.MethodGet, base+"/api/v1/health", "", nil) + exposed := response.Header.Get("Access-Control-Expose-Headers") + for _, header := range []string{"www-authenticate", "payment-receipt", "x-payment-required", "x-payment-response"} { + if !strings.Contains(exposed, header) { + t.Fatalf("expose headers = %q missing %q", exposed, header) + } + } + request, err := http.NewRequest(http.MethodOptions, base+"/api/v1/fortune", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + request.Header.Set("Access-Control-Request-Headers", "authorization") + preflight, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("OPTIONS: %v", err) + } + preflight.Body.Close() + if preflight.StatusCode != http.StatusNoContent || + preflight.Header.Get("Access-Control-Allow-Headers") != "authorization" { + t.Fatalf("preflight: status = %d allow-headers = %q", + preflight.StatusCode, preflight.Header.Get("Access-Control-Allow-Headers")) + } + }) + + t.Run("catch-all", func(t *testing.T) { + response, body := doRequest(t, http.MethodGet, base+"/nonexistent", "", nil) + if response.StatusCode != http.StatusNotFound { + t.Fatalf("status = %d body = %s", response.StatusCode, body) + } + }) +} diff --git a/go/examples/playground-api/playground_e2e_test.go b/go/examples/playground-api/playground_e2e_test.go new file mode 100644 index 000000000..4088c35c8 --- /dev/null +++ b/go/examples/playground-api/playground_e2e_test.go @@ -0,0 +1,238 @@ +package main + +// Surfpool-gated end-to-end test: boots the real playground handler against +// the hosted Solana Payment Sandbox, funds a wallet through the faucet +// cheatcodes, opens a payment channel on the /sessions/stream 402 (client +// pre-signs, server completes the fee-payer signature and broadcasts), +// streams the metered SSE chunks, commits a voucher through the side +// channel, and polls /sessions/receipt until the idle-close watchdog settles +// the channel on-chain. Skips explicitly (never silently passes) when the +// sandbox is unreachable or under -short. + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/client" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" +) + +// sandboxRPCURL resolves the sandbox endpoint, honoring the harness override. +func sandboxRPCURL() string { + if url := os.Getenv("MPP_HARNESS_RPC_URL"); url != "" { + return url + } + return "https://402.surfnet.dev:8899" +} + +// requireSandbox skips the test explicitly when the sandbox is unreachable. +func requireSandbox(t *testing.T) *rpc.Client { + t.Helper() + if testing.Short() { + t.Skip("skipping surfpool e2e in -short mode") + } + url := sandboxRPCURL() + rpcClient := rpc.New(url) + probeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := rpcClient.GetLatestBlockhash(probeCtx, rpc.CommitmentConfirmed); err != nil { + t.Skipf("surfpool sandbox unreachable at %s: %v", url, err) + } + return rpcClient +} + +func TestPlaygroundSessionE2ESurfpool(t *testing.T) { + rpcClient := requireSandbox(t) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + feePayer, err := solana.NewRandomPrivateKey() + if err != nil { + t.Fatalf("generate fee payer: %v", err) + } + a := &app{ + network: "localnet", + rpcURL: sandboxRPCURL(), + recipient: feePayer.PublicKey().String(), + secretKey: "playground-e2e-secret", + feePayer: feePayer, + rpcClient: rpcClient, + repoRoot: t.TempDir(), + } + bootstrapFunding(a) + + handler, shutdown, err := newApp(a) + if err != nil { + t.Fatalf("newApp: %v", err) + } + t.Cleanup(shutdown) + httpServer := httptest.NewServer(handler) + t.Cleanup(httpServer.Close) + + // Fund the paying wallet through the playground's own faucet endpoint. + payer, err := solana.NewRandomPrivateKey() + if err != nil { + t.Fatalf("generate payer: %v", err) + } + response, body := playgroundRequest(t, http.MethodPost, httpServer.URL+"/api/v1/faucet/airdrop", + `{"address":"`+payer.PublicKey().String()+`"}`, "") + if response.StatusCode != http.StatusOK { + t.Fatalf("faucet airdrop failed: %d %s", response.StatusCode, body) + } + + // 1. Unauthenticated request: 402 with a session challenge carrying a + // recent blockhash from the sandbox. + streamURL := httpServer.URL + "/sessions/stream" + response, body = playgroundRequest(t, http.MethodGet, streamURL, "", "") + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("expected 402, got %d: %s", response.StatusCode, body) + } + challenge, request, err := client.ParseSessionChallenge(response.Header.Get(core.WWWAuthenticateHeader)) + if err != nil { + t.Fatalf("ParseSessionChallenge: %v", err) + } + if request.RecentBlockhash == nil { + t.Fatal("challenge missing recentBlockhash") + } + + // 2. Open: the client derives the channel and partial-signs as the payer; + // the playground completes the fee-payer signature and broadcasts. + sessionSigner, err := client.NewEphemeralSessionSigner() + if err != nil { + t.Fatalf("NewEphemeralSessionSigner: %v", err) + } + opener, err := client.CreatePaymentChannelSessionOpener(request, payer, sessionSigner, "", + client.PaymentChannelSessionOpenOptions{}) + if err != nil { + t.Fatalf("CreatePaymentChannelSessionOpener: %v", err) + } + openAuthorization, err := client.SerializeSessionCredential(challenge, opener.Action) + if err != nil { + t.Fatalf("serialize open credential: %v", err) + } + response, body = playgroundRequest(t, http.MethodGet, streamURL, "", openAuthorization) + if response.StatusCode != http.StatusOK { + t.Fatalf("open failed: %d %s", response.StatusCode, body) + } + if !strings.Contains(body, "payment channel") || !strings.Contains(body, "[DONE]") { + t.Fatalf("stream body missing chunks or sentinel: %s", body) + } + channelID := opener.Session.ChannelIDString() + + // 3. Side-channel reserve + commit for the seven streamed chunks. + directive := struct { + DeliveryID string `json:"deliveryId"` + }{} + response, body = playgroundRequest(t, http.MethodPost, httpServer.URL+"/__402/session/deliveries", + `{"sessionId":"`+channelID+`","amount":"700"}`, "") + if response.StatusCode != http.StatusOK { + t.Fatalf("reserve failed: %d %s", response.StatusCode, body) + } + if err := json.Unmarshal([]byte(body), &directive); err != nil || directive.DeliveryID == "" { + t.Fatalf("reserve directive = %s (%v)", body, err) + } + voucher, err := opener.Session.PrepareIncrement(700) + if err != nil { + t.Fatalf("PrepareIncrement: %v", err) + } + voucherJSON, err := json.Marshal(voucher) + if err != nil { + t.Fatalf("marshal voucher: %v", err) + } + response, body = playgroundRequest(t, http.MethodPost, httpServer.URL+"/__402/session/commit", + `{"deliveryId":"`+directive.DeliveryID+`","voucher":`+string(voucherJSON)+`}`, "") + if response.StatusCode != http.StatusOK || !strings.Contains(body, `"committed"`) { + t.Fatalf("commit failed: %d %s", response.StatusCode, body) + } + if err := opener.Session.RecordVoucher(voucher); err != nil { + t.Fatalf("RecordVoucher: %v", err) + } + + // 4. The idle-close watchdog settles on-chain ~2s after the last + // voucher; poll the receipt endpoint the way the web app does. + receipt := struct { + Finalized bool `json:"finalized"` + Cumulative string `json:"cumulative"` + SettledSignature *string `json:"settledSignature"` + }{} + deadline := time.Now().Add(60 * time.Second) + for { + response, body = playgroundRequest(t, http.MethodGet, + httpServer.URL+"/sessions/receipt/"+channelID, "", "") + if response.StatusCode == http.StatusOK { + if err := json.Unmarshal([]byte(body), &receipt); err != nil { + t.Fatalf("receipt body = %s (%v)", body, err) + } + if receipt.Finalized && receipt.SettledSignature != nil { + break + } + } + if time.Now().After(deadline) { + t.Fatalf("receipt never finalized: %d %s", response.StatusCode, body) + } + time.Sleep(time.Second) + } + if receipt.Cumulative != "700" { + t.Fatalf("settled cumulative = %s, want 700", receipt.Cumulative) + } + + // 5. The settle transaction confirmed on-chain. + settleSignature, err := solana.SignatureFromBase58(*receipt.SettledSignature) + if err != nil { + t.Fatalf("settled signature %q invalid: %v", *receipt.SettledSignature, err) + } + confirmDeadline := time.Now().Add(30 * time.Second) + for { + statuses, err := rpcClient.GetSignatureStatuses(ctx, true, settleSignature) + if err == nil && len(statuses.Value) > 0 && statuses.Value[0] != nil { + if statuses.Value[0].Err != nil { + t.Fatalf("settlement failed on-chain: %+v", statuses.Value[0].Err) + } + break + } + if time.Now().After(confirmDeadline) { + t.Fatalf("settlement %s never confirmed", settleSignature) + } + time.Sleep(time.Second) + } +} + +// playgroundRequest performs one HTTP request against the playground under +// test and returns the response plus its body. +func playgroundRequest(t *testing.T, method, url, body, authorization string) (*http.Response, string) { + t.Helper() + var reader io.Reader + if body != "" { + reader = strings.NewReader(body) + } + request, err := http.NewRequest(method, url, reader) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if body != "" { + request.Header.Set("Content-Type", "application/json") + } + if authorization != "" { + request.Header.Set(core.AuthorizationHeader, authorization) + } + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("%s %s: %v", method, url, err) + } + raw, err := io.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + return response, string(raw) +} diff --git a/go/examples/playground-api/sessions.go b/go/examples/playground-api/sessions.go new file mode 100644 index 000000000..478240539 --- /dev/null +++ b/go/examples/playground-api/sessions.go @@ -0,0 +1,192 @@ +package main + +// Two session-gated demo endpoints driven by the in-process session method, +// the reserve/commit metering side channel, and the settle-status receipt +// poll. Both methods share one channel store so the receipt endpoint can +// read the settled signature whichever endpoint opened the channel. + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" + server "github.com/solana-foundation/pay-kit/go/protocols/mpp/server" +) + +// tokenChunks is the canned token stream payload. +var tokenChunks = []string{ + "A payment channel ", + "lets a client and server ", + "authorize many small ", + "off-chain debits ", + "against a single on-chain ", + "deposit, settling the highest ", + "cumulative voucher at close.", +} + +// registerSessions mounts the session endpoints and returns the watchdog +// shutdown hook. +// +// Routes: +// - GET /sessions/stream: pay-per-chunk SSE, cap 1.00 USDC, 0.0001 USDC/chunk +// - POST /sessions/stream: voucher commits for the stream endpoint +// - POST /sessions/compute: pay-per-call compute, cap 0.50 USDC, 0.005 USDC/call +// (also accepts voucher commits) +// - POST /__402/session/deliveries: SessionFetch-style delivery reservation +// - POST /__402/session/commit: body-voucher commit variant of the above +// - GET /sessions/receipt/{channelId}: settle-status poll for the UI +func registerSessions(mux *http.ServeMux, a *app) (func(), error) { + // Shared store across both session methods so /sessions/receipt can read + // channel state regardless of which endpoint opened the channel. + sharedStore := server.NewMemoryChannelStore() + strategy := intents.SessionPullVoucherStrategyClientVoucher + + newMethod := func(cap uint64) (*server.Session, error) { + return server.NewSession(server.SessionOptions{ + Operator: a.feePayer.PublicKey().String(), + Recipient: a.recipient, + Cap: cap, + Currency: paycore.USDCMainnetMint, + Decimals: usdcDecimals, + Network: a.network, + SecretKey: a.secretKey, + // Real on-chain opens: the browser pre-signs a payment-channel + // open transaction (fee payer = operator) and the server + // completes the signature, broadcasts, and waits for + // confirmation before metering. + Modes: []intents.SessionMode{intents.SessionModePull}, + PullVoucherStrategy: &strategy, + OpenTxSubmitter: server.OpenTxSubmitterServer, + // Settle roughly two seconds after the stream ends so the UI's + // receipt poll resolves quickly. + CloseDelay: 2 * time.Second, + PaymentChannelPayerSigner: a.feePayer, + Signer: a.feePayer, + RPC: a.rpcClient, + Store: sharedStore, + }) + } + + streamSession, err := newMethod(1_000_000) // 1.00 USDC + if err != nil { + return nil, fmt.Errorf("stream session method: %w", err) + } + computeSession, err := newMethod(500_000) // 0.50 USDC + if err != nil { + streamSession.Shutdown() + return nil, fmt.Errorf("compute session method: %w", err) + } + shutdown := func() { + streamSession.Shutdown() + computeSession.Shutdown() + } + + streamGate := server.SessionMiddleware(streamSession, func(*http.Request) (server.SessionChallengeOptions, error) { + return server.SessionChallengeOptions{Cap: "1000000", Description: "Metered token stream"}, nil + }) + computeGate := server.SessionMiddleware(computeSession, func(*http.Request) (server.SessionChallengeOptions, error) { + return server.SessionChallengeOptions{Cap: "500000", Description: "Voucher-billed inference call"}, nil + }) + + // GET /sessions/stream: stream tokens as SSE; each chunk costs 0.0001 USDC. + mux.Handle("GET /sessions/stream", streamGate(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + stream := server.NewMeteredStream(w) + w.WriteHeader(http.StatusOK) + for _, chunk := range tokenChunks { + if err := stream.WriteJSON(map[string]string{"chunk": chunk, "cost": "100"}); err != nil { + return + } + time.Sleep(80 * time.Millisecond) + } + _ = stream.WriteDone() + }))) + + // POST /sessions/stream: voucher commits arrive on the URL the session + // was opened against, with the signed voucher in the Authorization + // credential. The middleware's verify path applies it; the body is an ack. + mux.Handle("POST /sessions/stream", streamGate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, commitAck(r)) + }))) + + // POST /sessions/compute: pay-per-call compute; the same handler also + // accepts voucher commits (a deliveryId in the body discriminates). + mux.Handle("POST /sessions/compute", computeGate(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body struct { + Prompt string `json:"prompt"` + DeliveryID string `json:"deliveryId"` + } + _ = json.NewDecoder(r.Body).Decode(&body) + if body.DeliveryID != "" { + writeJSON(w, http.StatusOK, map[string]string{ + "amount": "0", + "deliveryId": body.DeliveryID, + "status": "committed", + }) + return + } + logPayment(r.URL.Path, w.Header()) + writeJSON(w, http.StatusOK, map[string]string{ + "prompt": body.Prompt, + "output": "Echo: " + body.Prompt + " (computed for 0.005 USDC)", + "computedAt": time.Now().UTC().Format(time.RFC3339), + }) + }))) + + // Side-channel metering routes: SessionFetch-style clients reserve + // capacity for each metered delivery before signing + committing the + // voucher. Both handlers share the methods' channel store. + routes := streamSession.Routes() + mux.HandleFunc("POST /__402/session/deliveries", routes.Deliveries) + mux.HandleFunc("POST /__402/session/commit", routes.Commit) + + // Receipt poll endpoint: the UI hits this after the stream ends to learn + // the on-chain settle signature. The idle-close watchdog fires about + // CloseDelay after the last voucher and, with Signer + RPC configured + // above, attempts the on-chain settle-and-distribute. + mux.HandleFunc("GET /sessions/receipt/{channelId}", func(w http.ResponseWriter, r *http.Request) { + channelID := r.PathValue("channelId") + if channelID == "" { + writeJSONError(w, http.StatusBadRequest, "invalid-channel-id") + return + } + state, err := sharedStore.GetChannel(r.Context(), channelID) + if err != nil || state == nil { + writeJSONError(w, http.StatusNotFound, "channel-not-found") + return + } + var settledSignature any + if state.SettledSignature != nil { + settledSignature = *state.SettledSignature + } + writeJSON(w, http.StatusOK, map[string]any{ + "channelId": state.ChannelID, + "cumulative": fmt.Sprintf("%d", state.Cumulative), + "deposit": fmt.Sprintf("%d", state.Deposit), + "finalized": state.Finalized, + "settledSignature": settledSignature, + }) + }) + + return shutdown, nil +} + +// commitAck is the minimal CommitReceipt-shaped JSON ack the stream commit +// handler returns. +func commitAck(r *http.Request) map[string]string { + var body struct { + Amount string `json:"amount"` + DeliveryID string `json:"deliveryId"` + } + _ = json.NewDecoder(r.Body).Decode(&body) + if body.Amount == "" { + body.Amount = "0" + } + return map[string]string{ + "amount": body.Amount, + "deliveryId": body.DeliveryID, + "status": "committed", + } +} diff --git a/go/examples/playground-api/subscriptions.go b/go/examples/playground-api/subscriptions.go new file mode 100644 index 000000000..cdc7e3dc0 --- /dev/null +++ b/go/examples/playground-api/subscriptions.go @@ -0,0 +1,21 @@ +package main + +// Subscriptions module. The Go SDK does not implement the subscription +// server method yet, so this module keeps the /api/v1/premium/feed route +// (nothing is silently dropped) and answers 501 with an explicit pointer at +// the gap. The endpoint catalog omits the subscription entry, so the +// playground UI renders its graceful empty state. See README.md. + +import "net/http" + +// registerSubscriptions mounts the documented subscription stub. +func registerSubscriptions(mux *http.ServeMux) { + mux.HandleFunc("GET /api/v1/premium/feed", func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusNotImplemented, map[string]string{ + "error": "not_implemented", + "detail": "The Go SDK does not ship the solana.subscription server method yet; " + + "this route exists for parity with typescript/examples/playground-api and " + + "will be gated once the Go subscription intent lands.", + }) + }) +} diff --git a/go/examples/playground-api/utils.go b/go/examples/playground-api/utils.go new file mode 100644 index 000000000..dcd010adf --- /dev/null +++ b/go/examples/playground-api/utils.go @@ -0,0 +1,102 @@ +package main + +// Shared helpers: ANSI color helpers, the surfnet JSON-RPC cheatcode +// caller, and the settlement / receipt log lines. + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "os" + "time" + + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" +) + +const ansiReset = "\x1b[0m" + +func dim(s string) string { return "\x1b[2m" + s + ansiReset } +func green(s string) string { return "\x1b[32m" + s + ansiReset } +func cyan(s string) string { return "\x1b[36m" + s + ansiReset } +func yellow(s string) string { return "\x1b[33m" + s + ansiReset } +func magenta(s string) string { return "\x1b[35m" + s + ansiReset } +func bold(s string) string { return "\x1b[1m" + s + ansiReset } + +// rpcCall performs a JSON-RPC call against the surfnet endpoint and returns +// the raw result. Used for the surfnet_* cheatcodes the standard RPC client +// does not expose. +func rpcCall(ctx context.Context, rpcURL, method string, params []any) (json.RawMessage, error) { + payload, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": method, + "params": params, + }) + if err != nil { + return nil, err + } + callCtx, cancel := context.WithTimeout(ctx, 8*time.Second) + defer cancel() + request, err := http.NewRequestWithContext(callCtx, http.MethodPost, rpcURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + request.Header.Set("Content-Type", "application/json") + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, err + } + defer func() { _ = response.Body.Close() }() + var body struct { + Result json.RawMessage `json:"result"` + Error *struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.NewDecoder(response.Body).Decode(&body); err != nil { + return nil, err + } + if body.Error != nil { + return nil, fmt.Errorf("%s: %s", method, body.Error.Message) + } + return body.Result, nil +} + +// logTx prints a settlement-signature link for quick eyeball debugging. +func logTx(path, reference string) { + studio := os.Getenv("STUDIO_PORT") + if studio == "" { + studio = "18488" + } + log.Printf(" %s %s %s %s", green("ok"), path, dim("tx:"), + cyan(fmt.Sprintf("http://localhost:%s/?t=%s", studio, reference))) +} + +// logPayment prints the receipt reference from a Payment-Receipt response +// header, when present. +func logPayment(path string, header http.Header) { + value := header.Get(core.PaymentReceiptHeader) + if value == "" { + return + } + receipt, err := core.ParseReceipt(value) + if err != nil || receipt.Reference == "" { + return + } + logTx(path, receipt.Reference) +} + +// writeJSON writes v as a JSON response with the given status code. +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +// writeJSONError writes the standard {"error": message} JSON error body. +func writeJSONError(w http.ResponseWriter, status int, message string) { + writeJSON(w, status, map[string]string{"error": message}) +} diff --git a/go/examples/playground-api/x402.go b/go/examples/playground-api/x402.go new file mode 100644 index 000000000..e130f5070 --- /dev/null +++ b/go/examples/playground-api/x402.go @@ -0,0 +1,165 @@ +package main + +// The embedded facilitator endpoints plus two x402-gated demo routes. +// +// The Go x402 adapter only implements self-hosted mode (it verifies and +// settles in-process with the operator signer), so the /x402/joke and +// /x402/fact gates here settle locally instead of POSTing to the embedded +// facilitator. The facilitator endpoints are still served with the standard +// shapes for external x402 clients. See README.md. + +import ( + "encoding/json" + "math/rand" + "net/http" + + "github.com/solana-foundation/pay-kit/go/paycore/signer" + "github.com/solana-foundation/pay-kit/go/paykit" +) + +// jokes is the canned joke pool. +var jokes = []string{ + "Why do programmers prefer dark mode? Because light attracts bugs.", + "There are 10 types of people: those who understand binary and those who don’t.", + "A SQL query walks into a bar, sees two tables, and asks: \"Can I JOIN you?\"", + "A photon checks into a hotel; the bellhop asks if it has any luggage. \"No, I’m traveling light.\"", +} + +// facts is the canned fun-fact pool. +var facts = []string{ + "Honey never spoils. Archaeologists found 3000-year-old honey in Egyptian tombs.", + "Octopuses have three hearts and blue blood.", + "A group of flamingos is called a \"flamboyance\".", + "Bananas are berries; strawberries are not.", +} + +// registerX402 mounts the embedded facilitator and the x402-gated routes. +func registerX402(mux *http.ServeMux, a *app) error { + // Embedded facilitator. + mux.HandleFunc("GET /facilitator/supported", func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]any{ + "kinds": []map[string]any{ + { + "scheme": "exact", + "network": "solana-devnet", + "extra": map[string]string{"feePayer": a.feePayer.PublicKey().String()}, + }, + }, + }) + }) + + mux.HandleFunc("POST /facilitator/verify", func(w http.ResponseWriter, r *http.Request) { + var body struct { + PaymentPayload *struct { + Payload *struct { + Authorization *struct { + From string `json:"from"` + } `json:"authorization"` + } `json:"payload"` + } `json:"paymentPayload"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil || + body.PaymentPayload == nil || body.PaymentPayload.Payload == nil { + writeJSON(w, http.StatusOK, map[string]any{ + "isValid": false, + "invalidReason": "Missing payload", + }) + return + } + payer := "unknown" + if auth := body.PaymentPayload.Payload.Authorization; auth != nil && auth.From != "" { + payer = auth.From + } + writeJSON(w, http.StatusOK, map[string]any{"isValid": true, "payer": payer}) + }) + + mux.HandleFunc("POST /facilitator/settle", func(w http.ResponseWriter, r *http.Request) { + var body struct { + PaymentPayload *struct { + Payload *struct { + Transaction string `json:"transaction"` + } `json:"payload"` + } `json:"paymentPayload"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil || + body.PaymentPayload == nil || body.PaymentPayload.Payload == nil { + writeJSON(w, http.StatusOK, map[string]any{ + "success": false, + "errorReason": "Missing payload", + }) + return + } + transaction := body.PaymentPayload.Payload.Transaction + if transaction == "" { + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "transaction": "local-facilitator-settled", + }) + return + } + result, err := rpcCall(r.Context(), a.rpcURL, "sendTransaction", []any{ + transaction, + map[string]any{"encoding": "base64", "skipPreflight": true}, + }) + if err != nil { + writeJSON(w, http.StatusOK, map[string]any{ + "success": false, + "errorReason": err.Error(), + }) + return + } + var signature string + _ = json.Unmarshal(result, &signature) + writeJSON(w, http.StatusOK, map[string]any{"success": true, "transaction": signature}) + }) + + // x402-gated routes: a dedicated x402-only paykit client, self-hosted + // verification + settlement against the configured RPC. + network, err := paykit.ParseNetwork(a.network) + if err != nil { + return err + } + operatorSigner, err := signer.FromBase58(a.feePayer.String()) + if err != nil { + return err + } + client, err := paykit.New(paykit.Config{ + Network: network, + RPCURL: a.rpcURL, + Accept: []paykit.Protocol{paykit.X402}, + Operator: paykit.Operator{ + Recipient: paykit.Address(a.recipient), + Signer: operatorSigner, + FeePayer: true, + }, + }) + if err != nil { + return err + } + + jokeGate := paykit.Gate{ + Amount: paykit.MustParseUSD("0.001"), + Name: "x402Joke", + Desc: "A random programmer joke", + } + mux.Handle("GET /x402/joke", client.Require(jokeGate)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{ + "joke": jokes[rand.Intn(len(jokes))], + "source": "x402", + }) + }))) + + factGate := paykit.Gate{ + Amount: paykit.MustParseUSD("0.001"), + Name: "x402Fact", + Desc: "A random fun fact", + } + mux.Handle("GET /x402/fact", client.Require(factGate)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{ + "fact": facts[rand.Intn(len(facts))], + "source": "x402", + }) + }))) + + return nil +} diff --git a/go/examples/playground-api/yahoo.go b/go/examples/playground-api/yahoo.go new file mode 100644 index 000000000..bef0e701e --- /dev/null +++ b/go/examples/playground-api/yahoo.go @@ -0,0 +1,588 @@ +package main + +// Yahoo Finance client returning the same JSON shapes as the yahoo-finance2 +// npm package (v3), which the playground API contract is defined against: +// the v7 quote endpoint (crumb-authenticated), the v1 search endpoint, and +// the v8 chart endpoint with the package's "array" result layout. +// Epoch-second date fields become ISO-8601 millisecond strings, "low - high" +// range strings become {low, high} objects, and chart indicator columns are +// zipped into per-day quote rows. + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +// yahooUserAgent is sent on every upstream request; Yahoo rejects the Go +// default agent on the crumb endpoint. +const yahooUserAgent = "Mozilla/5.0 (compatible; pay-kit-playground/1.0)" + +// yahooClient calls the public Yahoo Finance endpoints, holding the cookie +// jar and crumb the v7 quote endpoint requires. +type yahooClient struct { + // httpClient carries the cookie jar shared by the crumb and data calls. + httpClient *http.Client + // mu guards crumb. + mu sync.Mutex + // crumb is the cached anti-CSRF token for the v7 quote endpoint. + crumb string +} + +// newYahooClient builds a client with a fresh in-memory cookie jar. +func newYahooClient() *yahooClient { + jar, _ := cookiejar.New(nil) + return &yahooClient{httpClient: &http.Client{Jar: jar, Timeout: 10 * time.Second}} +} + +// get fetches a Yahoo endpoint and returns the raw body for 2xx responses. +func (c *yahooClient) get(ctx context.Context, rawURL string) ([]byte, int, error) { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, 0, err + } + request.Header.Set("User-Agent", yahooUserAgent) + response, err := c.httpClient.Do(request) + if err != nil { + return nil, 0, err + } + defer func() { _ = response.Body.Close() }() + body, err := io.ReadAll(io.LimitReader(response.Body, 8<<20)) + if err != nil { + return nil, response.StatusCode, err + } + if response.StatusCode != http.StatusOK { + return body, response.StatusCode, fmt.Errorf("yahoo finance: HTTP %d", response.StatusCode) + } + return body, response.StatusCode, nil +} + +// getJSON fetches a Yahoo endpoint and decodes the JSON body. Numbers +// decode to float64, the representation JSON.parse uses, so re-encoding +// renders them identically to yahoo-finance2 output. +func (c *yahooClient) getJSON(ctx context.Context, rawURL string, out any) error { + body, _, err := c.get(ctx, rawURL) + if err != nil { + return err + } + return json.Unmarshal(body, out) +} + +// getCrumb returns the cached crumb, fetching cookies plus a fresh crumb on +// first use. +func (c *yahooClient) getCrumb(ctx context.Context) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.crumb != "" { + return c.crumb, nil + } + // Any fc.yahoo.com response sets the session cookie the crumb endpoint + // checks; the 404 body itself is irrelevant. + if _, _, err := c.get(ctx, "https://fc.yahoo.com/"); err != nil && !strings.Contains(err.Error(), "HTTP") { + return "", err + } + body, _, err := c.get(ctx, "https://query1.finance.yahoo.com/v1/test/getcrumb") + if err != nil { + return "", err + } + crumb := strings.TrimSpace(string(body)) + if crumb == "" || strings.Contains(crumb, "Too Many Requests") { + return "", fmt.Errorf("yahoo finance: could not obtain crumb") + } + c.crumb = crumb + return crumb, nil +} + +// invalidateCrumb drops the cached crumb so the next call refreshes it. +func (c *yahooClient) invalidateCrumb() { + c.mu.Lock() + c.crumb = "" + c.mu.Unlock() +} + +// quoteDateFields are the v7 quote fields yahoo-finance2 types as Date +// (epoch seconds or date strings upstream, ISO strings in the response). +var quoteDateFields = map[string]bool{ + "dividendDate": true, + "earningsTimestamp": true, + "earningsTimestampStart": true, + "earningsTimestampEnd": true, + "earningsCallTimestampStart": true, + "earningsCallTimestampEnd": true, + "expireDate": true, + "expireIsoDate": true, + "extendedMarketTime": true, + "ipoExpectedDate": true, + "nameChangeDate": true, + "newListingDate": true, + "postMarketTime": true, + "preMarketTime": true, + "regularMarketTime": true, + "startDate": true, +} + +// quoteDateMsFields are the v7 quote fields typed as millisecond dates. +var quoteDateMsFields = map[string]bool{ + "firstTradeDateMilliseconds": true, +} + +// quoteRangeFields are the v7 quote fields delivered as "low - high" +// strings and returned as {low, high} objects. +var quoteRangeFields = map[string]bool{ + "fiftyTwoWeekRange": true, + "regularMarketDayRange": true, +} + +// searchDateFields are the search-quote fields typed as dates. +var searchDateFields = map[string]bool{ + "newListingDate": true, + "nameChangeDate": true, +} + +// quote returns the first v7 quote for symbol with yahoo-finance2's field +// coercions applied, or nil when the symbol is unknown or delisted. +func (c *yahooClient) quote(ctx context.Context, symbol string) (map[string]any, error) { + crumb, err := c.getCrumb(ctx) + if err != nil { + return nil, err + } + quoteURL := "https://query2.finance.yahoo.com/v7/finance/quote?symbols=" + + url.QueryEscape(symbol) + "&crumb=" + url.QueryEscape(crumb) + var body struct { + QuoteResponse struct { + Result []map[string]any `json:"result"` + Error any `json:"error"` + } `json:"quoteResponse"` + Finance struct { + Error *struct { + Description string `json:"description"` + } `json:"error"` + } `json:"finance"` + } + if err := c.getJSON(ctx, quoteURL, &body); err != nil { + // An expired crumb surfaces as HTTP 401; refresh once and retry. + c.invalidateCrumb() + if crumb, err = c.getCrumb(ctx); err != nil { + return nil, err + } + quoteURL = "https://query2.finance.yahoo.com/v7/finance/quote?symbols=" + + url.QueryEscape(symbol) + "&crumb=" + url.QueryEscape(crumb) + if err := c.getJSON(ctx, quoteURL, &body); err != nil { + return nil, err + } + } + if body.Finance.Error != nil { + return nil, fmt.Errorf("yahoo finance: %s", body.Finance.Error.Description) + } + for _, result := range body.QuoteResponse.Result { + if quoteType, _ := result["quoteType"].(string); quoteType == "NONE" { + continue + } + if err := coerceQuoteFields(result); err != nil { + return nil, err + } + return result, nil + } + return nil, nil +} + +// search returns the search endpoint's quotes array for query, issuing +// yahoo-finance2's default request parameters so the result list (counts +// included) matches the package's output. +func (c *yahooClient) search(ctx context.Context, query string) ([]map[string]any, error) { + params := url.Values{ + "q": {query}, + "lang": {"en-US"}, + "region": {"US"}, + "quotesCount": {"6"}, + "newsCount": {"4"}, + "enableFuzzyQuery": {"false"}, + "quotesQueryId": {"tss_match_phrase_query"}, + "multiQuoteQueryId": {"multi_quote_single_token_query"}, + "newsQueryId": {"news_cie_vespa"}, + "enableCb": {"true"}, + "enableNavLinks": {"true"}, + "enableEnhancedTrivialQuery": {"true"}, + } + var body struct { + Quotes []map[string]any `json:"quotes"` + } + searchURL := "https://query2.finance.yahoo.com/v1/finance/search?" + params.Encode() + if err := c.getJSON(ctx, searchURL, &body); err != nil { + return nil, err + } + for _, quote := range body.Quotes { + for field := range searchDateFields { + if value, ok := quote[field]; ok { + coerced, err := coerceDate(value, false) + if err != nil { + return nil, err + } + quote[field] = coerced + } + } + } + return body.Quotes, nil +} + +// chartRangeDays maps the playground's range parameter onto a day count +// (unknown ranges fall back to 30). +var chartRangeDays = map[string]int{"1d": 1, "5d": 5, "1mo": 30, "3mo": 90, "6mo": 180, "1y": 365} + +// chartQuote is one per-day row of the chart "array" layout. Field order +// and nullability match yahoo-finance2's assembled quote objects. +type chartQuote struct { + // Date is the trading day as an ISO-8601 millisecond string. + Date string `json:"date"` + // High is the day's high price, null when Yahoo has no datum. + High any `json:"high"` + // Volume is the day's traded volume, null when Yahoo has no datum. + Volume any `json:"volume"` + // Open is the day's opening price, null when Yahoo has no datum. + Open any `json:"open"` + // Low is the day's low price, null when Yahoo has no datum. + Low any `json:"low"` + // Close is the day's closing price, null when Yahoo has no datum. + Close any `json:"close"` + // Adjclose is the dividend/split-adjusted close (possibly null), + // omitted when the upstream response carries no adjclose column. + Adjclose *any `json:"adjclose,omitempty"` +} + +// history returns the v8 chart result for symbol over chartRange in +// yahoo-finance2's default "array" layout: the coerced meta object, the +// indicator columns zipped into per-day quote rows, and dividend/split +// events flattened into arrays. +func (c *yahooClient) history(ctx context.Context, symbol, chartRange string) (map[string]any, error) { + days, ok := chartRangeDays[chartRange] + if !ok { + days = 30 + } + now := time.Now() + params := url.Values{ + "useYfid": {"true"}, + "interval": {"1d"}, + "includePrePost": {"true"}, + "events": {"div|split|earn"}, + "lang": {"en-US"}, + "period1": {strconv.FormatInt(now.Add(-time.Duration(days)*24*time.Hour).Unix(), 10)}, + "period2": {strconv.FormatInt(now.Unix(), 10)}, + } + chartURL := "https://query2.finance.yahoo.com/v8/finance/chart/" + + url.PathEscape(symbol) + "?" + params.Encode() + var body struct { + Chart struct { + Result []map[string]any `json:"result"` + Error *struct { + Description string `json:"description"` + } `json:"error"` + } `json:"chart"` + } + if err := c.getJSON(ctx, chartURL, &body); err != nil { + return nil, err + } + if body.Chart.Error != nil { + return nil, fmt.Errorf("yahoo finance: %s", body.Chart.Error.Description) + } + if len(body.Chart.Result) == 0 { + return nil, fmt.Errorf("yahoo finance: empty chart result") + } + return chartToArrayLayout(body.Chart.Result[0]) +} + +// chartToArrayLayout converts one raw v8 chart result into yahoo-finance2's +// "array" return shape: {meta, quotes[], events?}. +func chartToArrayLayout(result map[string]any) (map[string]any, error) { + meta, _ := result["meta"].(map[string]any) + if err := coerceChartMeta(meta); err != nil { + return nil, err + } + out := map[string]any{"meta": meta, "quotes": []chartQuote{}} + + timestamps, _ := result["timestamp"].([]any) + indicators, _ := result["indicators"].(map[string]any) + if len(timestamps) > 0 { + quoteColumns, err := chartIndicatorColumn(indicators, "quote") + if err != nil { + return nil, err + } + adjcloseColumns, _ := chartIndicatorColumn(indicators, "adjclose") + var adjclose []any + if adjcloseColumns != nil { + adjclose, _ = adjcloseColumns["adjclose"].([]any) + } + quotes := make([]chartQuote, len(timestamps)) + for i, timestamp := range timestamps { + date, err := coerceDate(timestamp, false) + if err != nil { + return nil, err + } + quotes[i] = chartQuote{ + Date: date, + High: columnValue(quoteColumns, "high", i), + Volume: columnValue(quoteColumns, "volume", i), + Open: columnValue(quoteColumns, "open", i), + Low: columnValue(quoteColumns, "low", i), + Close: columnValue(quoteColumns, "close", i), + } + if adjclose != nil && i < len(adjclose) { + quotes[i].Adjclose = &adjclose[i] + } + } + out["quotes"] = quotes + } + + if rawEvents, ok := result["events"].(map[string]any); ok { + events := map[string]any{} + for _, kind := range []string{"dividends", "splits"} { + byTimestamp, ok := rawEvents[kind].(map[string]any) + if !ok { + continue + } + keys := make([]string, 0, len(byTimestamp)) + for key := range byTimestamp { + keys = append(keys, key) + } + // JS object iteration yields integer-like keys in ascending + // numeric order; Yahoo keys these maps by epoch seconds. + sort.Slice(keys, func(i, j int) bool { + a, _ := strconv.ParseInt(keys[i], 10, 64) + b, _ := strconv.ParseInt(keys[j], 10, 64) + return a < b + }) + items := make([]any, 0, len(keys)) + for _, key := range keys { + item := byTimestamp[key] + if event, ok := item.(map[string]any); ok { + if value, ok := event["date"]; ok { + date, err := coerceDate(value, false) + if err != nil { + return nil, err + } + event["date"] = date + } + } + items = append(items, item) + } + events[kind] = items + } + out["events"] = events + } + return out, nil +} + +// chartIndicatorColumn returns indicators.[0] as a column map. +func chartIndicatorColumn(indicators map[string]any, name string) (map[string]any, error) { + rows, ok := indicators[name].([]any) + if !ok || len(rows) == 0 { + if name == "quote" { + return nil, fmt.Errorf("yahoo finance: chart result missing quote indicators") + } + return nil, nil + } + columns, _ := rows[0].(map[string]any) + return columns, nil +} + +// columnValue returns column[i] or nil when the column is missing/short. +func columnValue(columns map[string]any, name string, i int) any { + values, _ := columns[name].([]any) + if i >= len(values) { + return nil + } + return values[i] +} + +// chartMetaDateFields are the chart meta fields typed as epoch-second dates. +var chartMetaDateFields = map[string]bool{ + "firstTradeDate": true, + "regularMarketTime": true, +} + +// coerceChartMeta applies yahoo-finance2's date coercions to the chart meta +// object, including the nested trading-period blocks. +func coerceChartMeta(meta map[string]any) error { + if meta == nil { + return nil + } + for field := range chartMetaDateFields { + if value, ok := meta[field]; ok && value != nil { + coerced, err := coerceDate(value, false) + if err != nil { + return err + } + meta[field] = coerced + } + } + if current, ok := meta["currentTradingPeriod"].(map[string]any); ok { + for _, key := range []string{"pre", "regular", "post"} { + if period, ok := current[key].(map[string]any); ok { + if err := coerceTradingPeriod(period); err != nil { + return err + } + } + } + } + switch periods := meta["tradingPeriods"].(type) { + case map[string]any: + for _, rows := range periods { + if err := coerceTradingPeriodRows(rows); err != nil { + return err + } + } + case []any: + if err := coerceTradingPeriodRows(periods); err != nil { + return err + } + } + return nil +} + +// coerceTradingPeriodRows coerces a [][]tradingPeriod nest. +func coerceTradingPeriodRows(rows any) error { + outer, ok := rows.([]any) + if !ok { + return nil + } + for _, inner := range outer { + periods, ok := inner.([]any) + if !ok { + continue + } + for _, entry := range periods { + if period, ok := entry.(map[string]any); ok { + if err := coerceTradingPeriod(period); err != nil { + return err + } + } + } + } + return nil +} + +// coerceTradingPeriod coerces one {timezone, start, end, gmtoffset} block. +func coerceTradingPeriod(period map[string]any) error { + for _, key := range []string{"start", "end"} { + if value, ok := period[key]; ok && value != nil { + coerced, err := coerceDate(value, false) + if err != nil { + return err + } + period[key] = coerced + } + } + return nil +} + +// coerceQuoteFields applies the quote schema's date and range coercions to +// one v7 quote result in place. +func coerceQuoteFields(result map[string]any) error { + for field, value := range result { + switch { + case quoteDateFields[field]: + coerced, err := coerceDate(value, false) + if err != nil { + return err + } + result[field] = coerced + case quoteDateMsFields[field]: + coerced, err := coerceDate(value, true) + if err != nil { + return err + } + result[field] = coerced + case quoteRangeFields[field]: + coerced, err := coerceRange(value) + if err != nil { + return err + } + result[field] = coerced + } + } + return nil +} + +// isoDatePattern matches the bare "YYYY-MM-DD" date strings Yahoo uses for +// listing-change fields. +var isoDatePattern = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}$`) + +// isoDateTimePattern matches full ISO-8601 timestamps with optional +// milliseconds. +var isoDateTimePattern = regexp.MustCompile(`^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d{3})?Z$`) + +// coerceDate converts a Yahoo date value (epoch number, {raw} wrapper, or +// date string) into the ISO-8601 millisecond string a serialized JS Date +// produces. inMilliseconds flags fields already scaled to milliseconds. +func coerceDate(value any, inMilliseconds bool) (string, error) { + switch v := value.(type) { + case float64: + if !inMilliseconds { + v *= 1000 + } + return formatJSDate(int64(v)), nil + case map[string]any: + if raw, ok := v["raw"].(float64); ok { + return formatJSDate(int64(raw * 1000)), nil + } + case string: + if isoDatePattern.MatchString(v) { + t, err := time.Parse("2006-01-02", v) + if err == nil { + return formatJSDate(t.UnixMilli()), nil + } + } + if isoDateTimePattern.MatchString(v) { + t, err := time.Parse(time.RFC3339, v) + if err == nil { + return formatJSDate(t.UnixMilli()), nil + } + } + } + return "", fmt.Errorf("yahoo finance: unexpected date value %v", value) +} + +// formatJSDate renders epoch milliseconds the way Date.prototype.toJSON +// does: UTC with exactly three fractional digits. +func formatJSDate(unixMilli int64) string { + return time.UnixMilli(unixMilli).UTC().Format("2006-01-02T15:04:05.000Z") +} + +// coerceRange converts a "low - high" string into the {low, high} object +// yahoo-finance2 returns; pre-shaped objects pass through. +func coerceRange(value any) (any, error) { + switch v := value.(type) { + case map[string]any: + return v, nil + case string: + parts := strings.SplitN(v, "-", 2) + if len(parts) == 2 { + low, errLow := parseFloatPrefix(parts[0]) + high, errHigh := parseFloatPrefix(parts[1]) + if errLow == nil && errHigh == nil { + return map[string]float64{"low": low, "high": high}, nil + } + } + } + return nil, fmt.Errorf("yahoo finance: unexpected range value %v", value) +} + +// parseFloatPrefix parses a float like JS parseFloat: surrounding +// whitespace is ignored. +func parseFloatPrefix(s string) (float64, error) { + f, err := strconv.ParseFloat(strings.TrimSpace(s), 64) + if err != nil || math.IsNaN(f) { + return 0, fmt.Errorf("not a number: %q", s) + } + return f, nil +} diff --git a/go/go.mod b/go/go.mod index 9fbed46dd..9a745f870 100644 --- a/go/go.mod +++ b/go/go.mod @@ -5,6 +5,8 @@ go 1.26.1 require ( github.com/gagliardetto/binary v0.8.0 github.com/gagliardetto/solana-go v0.0.0-20260403020633-3cb13b392078 + github.com/mr-tron/base58 v1.2.0 + github.com/shopspring/decimal v1.4.0 ) replace github.com/gagliardetto/solana-go => github.com/lgalabru/solana-go v0.0.0-20260403020633-3cb13b392078 @@ -26,8 +28,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mostynb/zstdpool-freelist v0.0.0-20201229113212-927304c0c3b1 // indirect - github.com/mr-tron/base58 v1.2.0 // indirect - github.com/shopspring/decimal v1.4.0 // indirect github.com/streamingfast/logging v0.0.0-20250404134358-92b15d2fbd2e // indirect go.mongodb.org/mongo-driver v1.17.3 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go/go.sum b/go/go.sum index 57e2487fc..856f3f90d 100644 --- a/go/go.sum +++ b/go/go.sum @@ -56,7 +56,6 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= diff --git a/go/paycore/errors_test.go b/go/paycore/errors_test.go new file mode 100644 index 000000000..8f789355b --- /dev/null +++ b/go/paycore/errors_test.go @@ -0,0 +1,43 @@ +package paycore + +import ( + "errors" + "testing" +) + +func TestErrorMessageAndNilReceiver(t *testing.T) { + err := NewError(ErrCodeTooManySplits, "too many splits") + if err.Error() != "too many splits" { + t.Fatalf("Error() = %q", err.Error()) + } + if err.Unwrap() != nil { + t.Fatalf("Unwrap() = %v, want nil", err.Unwrap()) + } + var nilErr *Error + if nilErr.Error() != "" { + t.Fatalf("nil Error() = %q", nilErr.Error()) + } + if nilErr.Unwrap() != nil { + t.Fatalf("nil Unwrap() = %v", nilErr.Unwrap()) + } +} + +func TestWrapErrorAttachesCause(t *testing.T) { + cause := errors.New("rpc timeout") + wrapped := WrapError(ErrCodeSplitsExceed, "splits exceed amount", cause) + if wrapped.Code != ErrCodeSplitsExceed { + t.Fatalf("Code = %q", wrapped.Code) + } + if !errors.Is(wrapped, cause) { + t.Fatal("wrapped error does not unwrap to its cause") + } + if wrapped.Error() != "splits exceed amount: rpc timeout" { + t.Fatalf("Error() = %q", wrapped.Error()) + } + + // A nil cause degrades to NewError. + plain := WrapError(ErrCodeSplitsExceed, "splits exceed amount", nil) + if plain.Err != nil || plain.Error() != "splits exceed amount" { + t.Fatalf("nil-cause wrap = %+v", plain) + } +} diff --git a/go/paycore/network.go b/go/paycore/network.go new file mode 100644 index 000000000..39c8e7275 --- /dev/null +++ b/go/paycore/network.go @@ -0,0 +1,35 @@ +package paycore + +import "strings" + +// SolanaNetwork is the canonical Solana cluster slug carried by method +// details and used by the client-side challenge selectors. The zero +// value means "unspecified" (selectors treat it as no filter). +type SolanaNetwork string + +// Canonical cluster slugs. The wire format writes these exact strings. +const ( + // NetworkMainnet is the Solana mainnet cluster. + NetworkMainnet SolanaNetwork = "mainnet" + // NetworkDevnet is the Solana devnet cluster. + NetworkDevnet SolanaNetwork = "devnet" + // NetworkTestnet is the Solana testnet cluster. + NetworkTestnet SolanaNetwork = "testnet" + // NetworkLocalnet is a local or hosted Surfpool test validator. + NetworkLocalnet SolanaNetwork = "localnet" +) + +// ParseSolanaNetwork folds cluster aliases onto the canonical slug: +// the legacy "mainnet-beta" spelling (any case) becomes +// [NetworkMainnet]; every other value passes through unchanged so +// unknown slugs keep their server-provided spelling. +func ParseSolanaNetwork(network string) SolanaNetwork { + lower := strings.ToLower(network) + if lower == "mainnet" || lower == "mainnet-beta" { + return NetworkMainnet + } + return SolanaNetwork(network) +} + +// String returns the canonical slug. +func (n SolanaNetwork) String() string { return string(n) } diff --git a/go/paycore/network_test.go b/go/paycore/network_test.go new file mode 100644 index 000000000..eee96db42 --- /dev/null +++ b/go/paycore/network_test.go @@ -0,0 +1,25 @@ +package paycore + +import "testing" + +// TestParseSolanaNetwork pins the alias folding and pass-through rules. +func TestParseSolanaNetwork(t *testing.T) { + cases := []struct { + in string + want SolanaNetwork + }{ + {"mainnet", NetworkMainnet}, + {"mainnet-beta", NetworkMainnet}, + {"MAINNET-BETA", NetworkMainnet}, + {"devnet", NetworkDevnet}, + {"testnet", NetworkTestnet}, + {"localnet", NetworkLocalnet}, + {"", SolanaNetwork("")}, + {"surfnet", SolanaNetwork("surfnet")}, + } + for _, c := range cases { + if got := ParseSolanaNetwork(c.in); got != c.want { + t.Fatalf("ParseSolanaNetwork(%q) = %q, want %q", c.in, got, c.want) + } + } +} diff --git a/go/paycore/paymentchannels/paymentchannels.go b/go/paycore/paymentchannels/paymentchannels.go new file mode 100644 index 000000000..86c6e15ec --- /dev/null +++ b/go/paycore/paymentchannels/paymentchannels.go @@ -0,0 +1,313 @@ +// Package paymentchannels is the thin, hand-written on-chain glue over the +// codama-generated payment-channels client in +// protocols/programs/paymentchannels. It provides PDA derivation, associated +// token derivation, voucher preimage bytes, and convenience instruction +// builders for the push-mode session flow (open + top_up). +// +// The instruction bytes and PDA derivations built here must stay +// byte-identical across the language SDKs so the on-chain program accepts +// them. In particular the production program id pinned here (GuoKrza...) +// overrides the IDL placeholder baked into the generated package, which is +// not deployed. +package paymentchannels + +import ( + "bytes" + "encoding/binary" + "fmt" + + ag_binary "github.com/gagliardetto/binary" + solana "github.com/gagliardetto/solana-go" + + generated "github.com/solana-foundation/pay-kit/go/protocols/programs/paymentchannels" +) + +// ProgramID is the canonical payment-channels program id deployed to the +// network. The codama-generated package defaults its ProgramID var to the IDL +// placeholder "CQAyft83tN1w2bRofB5PZ79eVDU2xZUVo43LU1qL4zRg", which is NOT the +// production deployment; every PDA derivation and instruction built here uses +// this value instead. +const ProgramID = "GuoKrzaBiZnW5DvJ3yZVE7xHqbcBvaX9SH6P6Cn9gNvc" + +// channelSeed is the channel PDA seed prefix. +const channelSeed = "channel" + +// eventAuthoritySeed is the event-authority PDA seed prefix. +const eventAuthoritySeed = "event_authority" + +// programPubkey is the parsed production program id used for derivation and +// instruction emission. +var programPubkey = solana.MustPublicKeyFromBase58(ProgramID) + +func init() { + // Pin the generated package's ProgramID to the production deployment so + // any path that reads generated.ProgramID (e.g. Instruction.ProgramID()) + // observes GuoKrza... rather than the IDL placeholder. + generated.SetProgramID(programPubkey) +} + +// ProgramPubkey returns the parsed production program id. +func ProgramPubkey() solana.PublicKey { + return programPubkey +} + +// SetProgramID overrides the program id used for PDA derivation and instruction +// emission, for SDK consumers targeting a non-mainnet deployment (a devnet or +// localnet program is deployed at a different address). It also pins the +// generated package so Instruction.ProgramID() agrees. The default is the +// canonical mainnet ProgramID; callers on mainnet never need this. +func SetProgramID(id solana.PublicKey) { + programPubkey = id + generated.SetProgramID(id) +} + +// Distribution is a single payout recipient and its basis-point share. +type Distribution struct { + // Recipient is the wallet whose associated token account receives this + // share when settled channel funds are distributed. + Recipient solana.PublicKey + // Bps is the recipient's share of distributed funds in basis points + // (100 bps = 1%). + Bps uint16 +} + +// OpenChannelParams carries the inputs required to build an Open instruction. +type OpenChannelParams struct { + // Payer is the wallet funding the channel deposit; it signs the Open + // transaction and is a channel PDA seed. + Payer solana.PublicKey + // Payee is the counterparty the channel pays out to; a channel PDA seed. + Payee solana.PublicKey + // Mint is the SPL token mint the channel escrows (e.g. USDC); a channel + // PDA seed. + Mint solana.PublicKey + // AuthorizedSigner is the key allowed to sign vouchers against this + // channel; a channel PDA seed. + AuthorizedSigner solana.PublicKey + // Salt distinguishes multiple channels sharing the same + // payer/payee/mint/signer; encoded little-endian as the final channel + // PDA seed. + Salt uint64 + // Deposit is the initial escrow amount in token base units + // (10^-6 USDC per unit for a 6-decimal mint). + Deposit uint64 + // GracePeriod is the channel close grace period in seconds; the on-chain + // program rejects zero. + GracePeriod uint32 + // Recipients is the basis-point payout split applied when settled funds + // are distributed. + Recipients []Distribution + + // TokenProgram is the program owning Mint (SPL Token or Token-2022), + // used to derive the payer and channel associated token accounts. + TokenProgram solana.PublicKey + + // ProgramID is the payment-channels program targeted by this open. The + // zero value resolves to the package program id (ProgramPubkey, or the + // last SetProgramID override). + ProgramID solana.PublicKey +} + +// TopUpParams carries the inputs required to build a TopUp instruction. +type TopUpParams struct { + // Payer is the wallet whose token account funds the top-up; it signs + // the TopUp transaction. + Payer solana.PublicKey + // Channel is the channel PDA whose escrow receives the deposit. + Channel solana.PublicKey + // Mint is the SPL token mint the channel escrows. + Mint solana.PublicKey + // Amount is the additional deposit in token base units + // (10^-6 USDC per unit for a 6-decimal mint). + Amount uint64 + // TokenProgram is the program owning Mint (SPL Token or Token-2022), + // used to derive the payer and channel associated token accounts. + TokenProgram solana.PublicKey + + // ProgramID is the payment-channels program targeted by this top-up. The + // zero value resolves to the package program id (ProgramPubkey, or the + // last SetProgramID override). + ProgramID solana.PublicKey +} + +// resolveProgram resolves an optional per-call program id to the package +// program id when unset. +func resolveProgram(programID solana.PublicKey) solana.PublicKey { + if programID.IsZero() { + return programPubkey + } + return programID +} + +// VoucherMessageBytes returns the 48-byte voucher preimage signed by the +// authorized signer: channelId (32) || cumulativeAmount as little-endian u64 +// (offset 32) || expiresAt as little-endian i64 (offset 40). This is the exact +// Borsh layout of VoucherArgs. +func VoucherMessageBytes(channelID solana.PublicKey, cumulative uint64, expiresAt int64) ([]byte, error) { + id := channelID.Bytes() + if len(id) != 32 { + return nil, fmt.Errorf("channel id must be exactly 32 bytes, got %d", len(id)) + } + out := make([]byte, 48) + copy(out[:32], id) + binary.LittleEndian.PutUint64(out[32:40], cumulative) + binary.LittleEndian.PutUint64(out[40:48], uint64(expiresAt)) + return out, nil +} + +// FindChannelPDA derives the channel PDA from +// ["channel", payer, payee, mint, authorizedSigner, salt as little-endian u64] +// against the production program id. +func FindChannelPDA(payer, payee, mint, authorizedSigner solana.PublicKey, salt uint64) (solana.PublicKey, uint8, error) { + return FindChannelPDAForProgram(payer, payee, mint, authorizedSigner, salt, programPubkey) +} + +// FindChannelPDAForProgram derives the channel PDA against an explicit program +// id, for callers honoring a per-challenge programId. +func FindChannelPDAForProgram(payer, payee, mint, authorizedSigner solana.PublicKey, salt uint64, programID solana.PublicKey) (solana.PublicKey, uint8, error) { + saltLE := make([]byte, 8) + binary.LittleEndian.PutUint64(saltLE, salt) + addr, bump, err := solana.FindProgramAddress( + [][]byte{ + []byte(channelSeed), + payer.Bytes(), + payee.Bytes(), + mint.Bytes(), + authorizedSigner.Bytes(), + saltLE, + }, + resolveProgram(programID), + ) + if err != nil { + return solana.PublicKey{}, 0, fmt.Errorf("derive channel pda: %w", err) + } + return addr, bump, nil +} + +// FindEventAuthorityPDA derives the event-authority PDA from +// ["event_authority"] against the production program id. +func FindEventAuthorityPDA() (solana.PublicKey, uint8, error) { + return FindEventAuthorityPDAForProgram(programPubkey) +} + +// FindEventAuthorityPDAForProgram derives the event-authority PDA against an +// explicit program id, for callers honoring a per-challenge programId. +func FindEventAuthorityPDAForProgram(programID solana.PublicKey) (solana.PublicKey, uint8, error) { + addr, bump, err := solana.FindProgramAddress( + [][]byte{[]byte(eventAuthoritySeed)}, + resolveProgram(programID), + ) + if err != nil { + return solana.PublicKey{}, 0, fmt.Errorf("derive event-authority pda: %w", err) + } + return addr, bump, nil +} + +// BuildOpenInstruction derives the channel PDA, payer/channel ATAs, and +// event-authority PDA, then builds the Open instruction with every account set +// in the exact order the on-chain program expects, using the production +// program id. +func BuildOpenInstruction(params OpenChannelParams) (solana.Instruction, error) { + programID := resolveProgram(params.ProgramID) + channel, _, err := FindChannelPDAForProgram(params.Payer, params.Payee, params.Mint, params.AuthorizedSigner, params.Salt, programID) + if err != nil { + return nil, err + } + payerToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payer, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive payer token account: %w", err) + } + channelToken, _, err := solana.FindAssociatedTokenAddressWithProgram(channel, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive channel token account: %w", err) + } + eventAuthority, _, err := FindEventAuthorityPDAForProgram(programID) + if err != nil { + return nil, err + } + + recipients := make([]generated.DistributionEntry, 0, len(params.Recipients)) + for _, entry := range params.Recipients { + recipients = append(recipients, generated.DistributionEntry{ + Recipient: entry.Recipient, + Bps: entry.Bps, + }) + } + + builder := generated.NewOpenInstructionBuilder(). + SetPayerAccount(params.Payer). + SetPayeeAccount(params.Payee). + SetMintAccount(params.Mint). + SetAuthorizedSignerAccount(params.AuthorizedSigner). + SetChannelAccount(channel). + SetPayerTokenAccountAccount(payerToken). + SetChannelTokenAccountAccount(channelToken). + SetTokenProgramAccount(params.TokenProgram). + SetSystemProgramAccount(solana.SystemProgramID). + SetRentAccount(solana.SysVarRentPubkey). + SetAssociatedTokenProgramAccount(solana.SPLAssociatedTokenAccountProgramID). + SetEventAuthorityAccount(eventAuthority). + SetSelfProgramAccount(programID). + SetOpenArgs(generated.OpenArgs{ + Salt: params.Salt, + Deposit: params.Deposit, + GracePeriod: params.GracePeriod, + Recipients: recipients, + }) + + if _, err := builder.ValidateAndBuild(); err != nil { + return nil, fmt.Errorf("build open instruction: %w", err) + } + return materialize(builder, builder.GetAccounts(), programID) +} + +// BuildTopUpInstruction derives the payer/channel ATAs and builds the TopUp +// instruction with every account set in the exact order the on-chain program +// expects, using the production program id. +func BuildTopUpInstruction(params TopUpParams) (solana.Instruction, error) { + payerToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payer, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive payer token account: %w", err) + } + channelToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Channel, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive channel token account: %w", err) + } + + builder := generated.NewTopUpInstructionBuilder(). + SetPayerAccount(params.Payer). + SetChannelAccount(params.Channel). + SetPayerTokenAccountAccount(payerToken). + SetChannelTokenAccountAccount(channelToken). + SetMintAccount(params.Mint). + SetTokenProgramAccount(params.TokenProgram). + SetTopUpArgs(generated.TopUpArgs{Amount: params.Amount}) + + if _, err := builder.ValidateAndBuild(); err != nil { + return nil, fmt.Errorf("build top_up instruction: %w", err) + } + return materialize(builder, builder.GetAccounts(), resolveProgram(params.ProgramID)) +} + +// materialize borsh-encodes a validated generated instruction implementation +// and returns a solana.GenericInstruction pinned to the production program id. +// +// Two generated-package quirks are handled here: +// - The instruction implementation (*Open/*TopUp) is encoded directly so its +// MarshalWithEncoder writes the program's real one-byte discriminator +// (Open=1, TopUp=3). Wrapping it in the generated Instruction.Data() would +// prepend a spurious NoTypeID-default byte, corrupting the on-chain data. +// - The implementation is stored by value inside the generated Instruction, +// so its Accounts() accessor type-asserts to a pointer-receiver interface +// and panics; passing the builder's own GetAccounts() avoids that path. +// +// The result's ProgramID() is the resolved per-call program id (the production +// ProgramID by default, a SetProgramID override, or an explicit per-call +// ProgramID for a non-mainnet cluster). +func materialize(impl ag_binary.EncoderDecoder, accounts []*solana.AccountMeta, programID solana.PublicKey) (solana.Instruction, error) { + buf := new(bytes.Buffer) + if err := ag_binary.NewBorshEncoder(buf).Encode(impl); err != nil { + return nil, fmt.Errorf("encode instruction data: %w", err) + } + return solana.NewInstruction(programID, accounts, buf.Bytes()), nil +} diff --git a/go/paycore/paymentchannels/paymentchannels_test.go b/go/paycore/paymentchannels/paymentchannels_test.go new file mode 100644 index 000000000..5bba6b812 --- /dev/null +++ b/go/paycore/paymentchannels/paymentchannels_test.go @@ -0,0 +1,504 @@ +package paymentchannels + +import ( + "bytes" + "encoding/binary" + "testing" + + ag_binary "github.com/gagliardetto/binary" + solana "github.com/gagliardetto/solana-go" + + generated "github.com/solana-foundation/pay-kit/go/protocols/programs/paymentchannels" +) + +// pk returns a deterministic 32-byte public key filled with the given byte. +func pk(b byte) solana.PublicKey { + var out solana.PublicKey + for i := range out { + out[i] = b + } + return out +} + +func TestProgramIDIsProduction(t *testing.T) { + if ProgramID != "GuoKrzaBiZnW5DvJ3yZVE7xHqbcBvaX9SH6P6Cn9gNvc" { + t.Fatalf("unexpected program id: %s", ProgramID) + } + if ProgramPubkey().String() != ProgramID { + t.Fatalf("parsed program id mismatch: %s", ProgramPubkey()) + } + // init() must have pinned the generated package to the production id. + if generated.ProgramID.String() != ProgramID { + t.Fatalf("generated ProgramID not pinned to production: %s", generated.ProgramID) + } +} + +func TestSetProgramIDOverridesDerivation(t *testing.T) { + // SetProgramID lets a consumer target a non-mainnet (devnet/localnet) + // deployment at a different address; it must move PDA derivation and pin the + // generated package. Restore the production default for other tests. + t.Cleanup(func() { SetProgramID(solana.MustPublicKeyFromBase58(ProgramID)) }) + + custom := solana.NewWallet().PublicKey() + SetProgramID(custom) + if !ProgramPubkey().Equals(custom) { + t.Fatalf("ProgramPubkey not overridden: %s", ProgramPubkey()) + } + if !generated.ProgramID.Equals(custom) { + t.Fatalf("generated ProgramID not pinned to override: %s", generated.ProgramID) + } + + payer := solana.NewWallet().PublicKey() + payee := solana.NewWallet().PublicKey() + mint := solana.NewWallet().PublicKey() + signer := solana.NewWallet().PublicKey() + overridden, _, err := FindChannelPDA(payer, payee, mint, signer, 1) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + SetProgramID(solana.MustPublicKeyFromBase58(ProgramID)) + production, _, err := FindChannelPDA(payer, payee, mint, signer, 1) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + if overridden.Equals(production) { + t.Fatal("channel PDA did not change with the program id override") + } +} + +func TestPerCallProgramIDOverridesDerivationAndInstruction(t *testing.T) { + custom := solana.NewWallet().PublicKey() + payer := solana.NewWallet().PublicKey() + payee := solana.NewWallet().PublicKey() + mint := solana.NewWallet().PublicKey() + signer := solana.NewWallet().PublicKey() + + defaultPDA, _, err := FindChannelPDA(payer, payee, mint, signer, 1) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + customPDA, _, err := FindChannelPDAForProgram(payer, payee, mint, signer, 1, custom) + if err != nil { + t.Fatalf("FindChannelPDAForProgram: %v", err) + } + if defaultPDA.Equals(customPDA) { + t.Fatal("channel PDA did not change with the per-call program id") + } + zeroPDA, _, err := FindChannelPDAForProgram(payer, payee, mint, signer, 1, solana.PublicKey{}) + if err != nil { + t.Fatalf("FindChannelPDAForProgram zero: %v", err) + } + if !zeroPDA.Equals(defaultPDA) { + t.Fatal("zero per-call program id should resolve to the package default") + } + + params := OpenChannelParams{ + Payer: payer, + Payee: payee, + Mint: mint, + AuthorizedSigner: signer, + Salt: 1, + Deposit: 10, + GracePeriod: 900, + TokenProgram: solana.TokenProgramID, + ProgramID: custom, + } + ix, err := BuildOpenInstruction(params) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + if !ix.ProgramID().Equals(custom) { + t.Fatalf("open instruction program id = %s, want per-call override", ix.ProgramID()) + } + accounts := ix.Accounts() + if !accounts[4].PublicKey.Equals(customPDA) { + t.Fatalf("open channel account = %s, want PDA derived against the per-call program", accounts[4].PublicKey) + } + + topUp, err := BuildTopUpInstruction(TopUpParams{ + Payer: payer, + Channel: customPDA, + Mint: mint, + Amount: 5, + TokenProgram: solana.TokenProgramID, + ProgramID: custom, + }) + if err != nil { + t.Fatalf("BuildTopUpInstruction: %v", err) + } + if !topUp.ProgramID().Equals(custom) { + t.Fatalf("top_up instruction program id = %s, want per-call override", topUp.ProgramID()) + } +} + +func TestVoucherMessageBytesLayout(t *testing.T) { + const cumulative uint64 = 42 + const expiresAt int64 = 1234 + channel := pk(9) + + got, err := VoucherMessageBytes(channel, cumulative, expiresAt) + if err != nil { + t.Fatalf("VoucherMessageBytes: %v", err) + } + if len(got) != 48 { + t.Fatalf("expected 48 bytes, got %d", len(got)) + } + if !bytes.Equal(got[:32], channel.Bytes()) { + t.Fatalf("offset 0..32 should be channel id") + } + wantCumulative := make([]byte, 8) + binary.LittleEndian.PutUint64(wantCumulative, cumulative) + if !bytes.Equal(got[32:40], wantCumulative) { + t.Fatalf("offset 32..40 should be cumulative LE u64, got %x", got[32:40]) + } + wantExpires := make([]byte, 8) + binary.LittleEndian.PutUint64(wantExpires, uint64(expiresAt)) + if !bytes.Equal(got[40:48], wantExpires) { + t.Fatalf("offset 40..48 should be expiresAt LE i64, got %x", got[40:48]) + } +} + +func TestVoucherMessageBytesMatchesGeneratedBorsh(t *testing.T) { + const cumulative uint64 = 7 + var expiresAt int64 = -5 // negative i64 exercises two's-complement LE + channel := pk(3) + + got, err := VoucherMessageBytes(channel, cumulative, expiresAt) + if err != nil { + t.Fatalf("VoucherMessageBytes: %v", err) + } + + want := make([]byte, 0, 48) + want = append(want, channel.Bytes()...) + c := make([]byte, 8) + binary.LittleEndian.PutUint64(c, cumulative) + want = append(want, c...) + e := make([]byte, 8) + binary.LittleEndian.PutUint64(e, uint64(expiresAt)) + want = append(want, e...) + + if !bytes.Equal(got, want) { + t.Fatalf("voucher bytes mismatch:\n got=%x\nwant=%x", got, want) + } + // Sanity: the wire layout equals the field order of generated.VoucherArgs. + _ = generated.VoucherArgs{ChannelId: channel, CumulativeAmount: cumulative, ExpiresAt: expiresAt} +} + +func TestVoucherMessageBytesRejectsNon32(t *testing.T) { + // solana.PublicKey is a fixed [32]byte, so we cannot pass a short id at + // the type level; assert the happy path is exactly 32 and that a default + // (zero) key still yields 32 bytes. The length guard protects against any + // future non-fixed input path. + got, err := VoucherMessageBytes(solana.PublicKey{}, 0, 0) + if err != nil { + t.Fatalf("zero key should be valid 32 bytes: %v", err) + } + if len(got) != 48 { + t.Fatalf("expected 48 bytes, got %d", len(got)) + } +} + +func TestFindChannelPDADeterministic(t *testing.T) { + a, bumpA, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 99) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + b, bumpB, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 99) + if err != nil { + t.Fatalf("FindChannelPDA repeat: %v", err) + } + if a != b || bumpA != bumpB { + t.Fatalf("channel pda not deterministic: %s/%d vs %s/%d", a, bumpA, b, bumpB) + } + + // Reproduce the seeds against the production program id directly. + saltLE := make([]byte, 8) + binary.LittleEndian.PutUint64(saltLE, 99) + want, wantBump, err := solana.FindProgramAddress( + [][]byte{ + []byte("channel"), + pk(1).Bytes(), pk(2).Bytes(), pk(3).Bytes(), pk(4).Bytes(), + saltLE, + }, + programPubkey, + ) + if err != nil { + t.Fatalf("reference derivation: %v", err) + } + if a != want || bumpA != wantBump { + t.Fatalf("channel pda mismatch: got %s/%d want %s/%d", a, bumpA, want, wantBump) + } +} + +func TestFindChannelPDAUsesGuoKrza(t *testing.T) { + got, _, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 99) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + // Deriving against the IDL placeholder must produce a different PDA. + saltLE := make([]byte, 8) + binary.LittleEndian.PutUint64(saltLE, 99) + placeholder := solana.MustPublicKeyFromBase58("CQAyft83tN1w2bRofB5PZ79eVDU2xZUVo43LU1qL4zRg") + other, _, err := solana.FindProgramAddress( + [][]byte{ + []byte("channel"), + pk(1).Bytes(), pk(2).Bytes(), pk(3).Bytes(), pk(4).Bytes(), + saltLE, + }, + placeholder, + ) + if err != nil { + t.Fatalf("placeholder derivation: %v", err) + } + if got == other { + t.Fatalf("channel pda should differ from the IDL-placeholder derivation") + } +} + +func TestFindChannelPDASaltSensitivity(t *testing.T) { + a, _, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 1) + if err != nil { + t.Fatalf("FindChannelPDA salt 1: %v", err) + } + b, _, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 2) + if err != nil { + t.Fatalf("FindChannelPDA salt 2: %v", err) + } + if a == b { + t.Fatalf("different salts must yield different channel pdas") + } +} + +func TestFindEventAuthorityPDA(t *testing.T) { + got, bump, err := FindEventAuthorityPDA() + if err != nil { + t.Fatalf("FindEventAuthorityPDA: %v", err) + } + want, wantBump, err := solana.FindProgramAddress([][]byte{[]byte("event_authority")}, programPubkey) + if err != nil { + t.Fatalf("reference derivation: %v", err) + } + if got != want || bump != wantBump { + t.Fatalf("event-authority pda mismatch: got %s/%d want %s/%d", got, bump, want, wantBump) + } +} + +func openParams() OpenChannelParams { + return OpenChannelParams{ + Payer: pk(1), + Payee: pk(2), + Mint: pk(3), + AuthorizedSigner: pk(4), + Salt: 99, + Deposit: 1_000_000, + GracePeriod: 3600, + Recipients: []Distribution{ + {Recipient: pk(5), Bps: 7_500}, + {Recipient: pk(6), Bps: 2_500}, + }, + TokenProgram: solana.TokenProgramID, + } +} + +func TestBuildOpenInstructionProgramIDAndAccounts(t *testing.T) { + params := openParams() + inst, err := BuildOpenInstruction(params) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + + if inst.ProgramID().String() != ProgramID { + t.Fatalf("open instruction program id is %s, want %s", inst.ProgramID(), ProgramID) + } + + metas := inst.Accounts() + if len(metas) != 13 { + t.Fatalf("expected 13 accounts, got %d", len(metas)) + } + + channel, _, err := FindChannelPDA(params.Payer, params.Payee, params.Mint, params.AuthorizedSigner, params.Salt) + if err != nil { + t.Fatalf("channel pda: %v", err) + } + payerToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payer, params.Mint, params.TokenProgram) + if err != nil { + t.Fatalf("payer ata: %v", err) + } + channelToken, _, err := solana.FindAssociatedTokenAddressWithProgram(channel, params.Mint, params.TokenProgram) + if err != nil { + t.Fatalf("channel ata: %v", err) + } + eventAuthority, _, err := FindEventAuthorityPDA() + if err != nil { + t.Fatalf("event-authority pda: %v", err) + } + + want := []solana.PublicKey{ + params.Payer, + params.Payee, + params.Mint, + params.AuthorizedSigner, + channel, + payerToken, + channelToken, + params.TokenProgram, + solana.SystemProgramID, + solana.SysVarRentPubkey, + solana.SPLAssociatedTokenAccountProgramID, + eventAuthority, + programPubkey, + } + for i, w := range want { + if metas[i].PublicKey != w { + t.Fatalf("account[%d] = %s, want %s", i, metas[i].PublicKey, w) + } + } + + // Writable/signer flags for the load-bearing accounts. + if !metas[0].IsSigner || !metas[0].IsWritable { + t.Fatalf("payer must be writable signer") + } + if !metas[4].IsWritable { + t.Fatalf("channel must be writable") + } + if !metas[5].IsWritable || !metas[6].IsWritable { + t.Fatalf("token accounts must be writable") + } +} + +func TestBuildOpenInstructionArgsRoundTrip(t *testing.T) { + params := openParams() + inst, err := BuildOpenInstruction(params) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + data, err := inst.Data() + if err != nil { + t.Fatalf("encode instruction data: %v", err) + } + + // The first byte is the program's Open discriminator (1); the remainder is + // borsh-encoded OpenArgs. + if len(data) == 0 || data[0] != byte(generated.OpenDiscriminator) { + t.Fatalf("expected leading Open discriminator %d, got %v", generated.OpenDiscriminator, data) + } + var args generated.OpenArgs + if err := ag_binary.NewBorshDecoder(data[1:]).Decode(&args); err != nil { + t.Fatalf("decode open args: %v", err) + } + if args.Salt != params.Salt || args.Deposit != params.Deposit || args.GracePeriod != params.GracePeriod { + t.Fatalf("open args round-trip mismatch: %+v", args) + } + if len(args.Recipients) != len(params.Recipients) { + t.Fatalf("recipients length mismatch: got %d", len(args.Recipients)) + } + for i, r := range params.Recipients { + if args.Recipients[i].Recipient != r.Recipient || args.Recipients[i].Bps != r.Bps { + t.Fatalf("recipient[%d] round-trip mismatch: %+v", i, args.Recipients[i]) + } + } +} + +func TestBuildOpenInstructionEmptyRecipients(t *testing.T) { + params := openParams() + params.Recipients = nil + inst, err := BuildOpenInstruction(params) + if err != nil { + t.Fatalf("BuildOpenInstruction empty recipients: %v", err) + } + data, err := inst.Data() + if err != nil { + t.Fatalf("encode: %v", err) + } + var args generated.OpenArgs + if err := ag_binary.NewBorshDecoder(data[1:]).Decode(&args); err != nil { + t.Fatalf("decode open args: %v", err) + } + if len(args.Recipients) != 0 { + t.Fatalf("expected zero recipients, got %d", len(args.Recipients)) + } +} + +func TestBuildTopUpInstructionProgramIDAndAccounts(t *testing.T) { + channel, _, err := FindChannelPDA(pk(1), pk(2), pk(3), pk(4), 99) + if err != nil { + t.Fatalf("channel pda: %v", err) + } + params := TopUpParams{ + Payer: pk(1), + Channel: channel, + Mint: pk(3), + Amount: 250_000, + TokenProgram: solana.TokenProgramID, + } + inst, err := BuildTopUpInstruction(params) + if err != nil { + t.Fatalf("BuildTopUpInstruction: %v", err) + } + + if inst.ProgramID().String() != ProgramID { + t.Fatalf("top_up program id is %s, want %s", inst.ProgramID(), ProgramID) + } + + metas := inst.Accounts() + if len(metas) != 6 { + t.Fatalf("expected 6 accounts, got %d", len(metas)) + } + + payerToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payer, params.Mint, params.TokenProgram) + if err != nil { + t.Fatalf("payer ata: %v", err) + } + channelToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Channel, params.Mint, params.TokenProgram) + if err != nil { + t.Fatalf("channel ata: %v", err) + } + + want := []solana.PublicKey{ + params.Payer, + params.Channel, + payerToken, + channelToken, + params.Mint, + params.TokenProgram, + } + for i, w := range want { + if metas[i].PublicKey != w { + t.Fatalf("account[%d] = %s, want %s", i, metas[i].PublicKey, w) + } + } + if !metas[0].IsSigner || !metas[0].IsWritable { + t.Fatalf("payer must be writable signer") + } + if !metas[1].IsWritable || !metas[2].IsWritable || !metas[3].IsWritable { + t.Fatalf("channel and token accounts must be writable") + } +} + +func TestBuildTopUpInstructionArgsRoundTrip(t *testing.T) { + params := TopUpParams{ + Payer: pk(1), + Channel: pk(7), + Mint: pk(3), + Amount: 987_654, + TokenProgram: solana.TokenProgramID, + } + inst, err := BuildTopUpInstruction(params) + if err != nil { + t.Fatalf("BuildTopUpInstruction: %v", err) + } + data, err := inst.Data() + if err != nil { + t.Fatalf("encode: %v", err) + } + if len(data) == 0 || data[0] != byte(generated.TopUpDiscriminator) { + t.Fatalf("expected leading TopUp discriminator %d, got %v", generated.TopUpDiscriminator, data) + } + var args generated.TopUpArgs + if err := ag_binary.NewBorshDecoder(data[1:]).Decode(&args); err != nil { + t.Fatalf("decode top_up args: %v", err) + } + if args.Amount != params.Amount { + t.Fatalf("amount round-trip mismatch: got %d want %d", args.Amount, params.Amount) + } +} diff --git a/go/paycore/paymentchannels/settlement.go b/go/paycore/paymentchannels/settlement.go new file mode 100644 index 000000000..4ec609bbb --- /dev/null +++ b/go/paycore/paymentchannels/settlement.go @@ -0,0 +1,246 @@ +package paymentchannels + +// Server-side settlement instruction builders for the push-mode session +// close path: the Ed25519 signature-verification precompile, the +// settle_and_finalize instruction that must immediately follow it, and the +// distribute instruction bundled into the same transaction. +// +// The instruction bytes built here must stay identical across the language +// SDKs; the cross-language harness pins them. + +import ( + "encoding/binary" + "fmt" + "math" + + solana "github.com/gagliardetto/solana-go" + + generated "github.com/solana-foundation/pay-kit/go/protocols/programs/paymentchannels" +) + +// Ed25519ProgramID is the Ed25519 signature-verification native precompile +// program id. +const Ed25519ProgramID = "Ed25519SigVerify111111111111111111111111111" + +// ed25519ProgramPubkey is the parsed precompile program id. +var ed25519ProgramPubkey = solana.MustPublicKeyFromBase58(Ed25519ProgramID) + +// Ed25519ProgramPubkey returns the parsed Ed25519 precompile program id. +func Ed25519ProgramPubkey() solana.PublicKey { + return ed25519ProgramPubkey +} + +// TreasuryOwner returns the treasury owner used by the current +// payment-channels program deployment: 32 bytes of repeated 0xBE 0xEF. +func TreasuryOwner() solana.PublicKey { + var key solana.PublicKey + for i := 0; i < len(key); i += 2 { + key[i] = 0xBE + key[i+1] = 0xEF + } + return key +} + +// BuildEd25519VerifyInstruction builds an Ed25519 precompile instruction +// verifying signature over message against authorizedSigner, with the +// signature material embedded in the instruction itself (every +// instruction-index field is 0xFFFF, "current instruction"). The data layout +// is the precompile's fixed header (public key at offset 16, signature at 48, +// message at 112) followed by the message bytes. +func BuildEd25519VerifyInstruction(authorizedSigner solana.PublicKey, signature [64]byte, message []byte) (solana.Instruction, error) { + const publicKeyOffset = 16 + const signatureOffset = publicKeyOffset + 32 // 48 + const messageDataOffset = signatureOffset + 64 // 112 + const currentInstruction = uint16(math.MaxUint16) + + if len(message) > math.MaxUint16 { + return nil, fmt.Errorf("voucher message too long: %d bytes", len(message)) + } + + data := make([]byte, messageDataOffset+len(message)) + data[0] = 1 // num_signatures + data[1] = 0 // padding + binary.LittleEndian.PutUint16(data[2:4], signatureOffset) + binary.LittleEndian.PutUint16(data[4:6], currentInstruction) + binary.LittleEndian.PutUint16(data[6:8], publicKeyOffset) + binary.LittleEndian.PutUint16(data[8:10], currentInstruction) + binary.LittleEndian.PutUint16(data[10:12], messageDataOffset) + binary.LittleEndian.PutUint16(data[12:14], uint16(len(message))) + binary.LittleEndian.PutUint16(data[14:16], currentInstruction) + copy(data[publicKeyOffset:], authorizedSigner.Bytes()) + copy(data[signatureOffset:], signature[:]) + copy(data[messageDataOffset:], message) + + return solana.NewInstruction(ed25519ProgramPubkey, solana.AccountMetaSlice{}, data), nil +} + +// SettleAndFinalizeParams carries the inputs required to build the +// settle_and_finalize instruction sequence. +type SettleAndFinalizeParams struct { + // Merchant is the signer authorized to settle the channel. + Merchant solana.PublicKey + + // Channel is the payment-channel address being settled. + Channel solana.PublicKey + + // AuthorizedSigner is the voucher signing key recorded at open. Only + // read when Signature is set. + AuthorizedSigner solana.PublicKey + + // Signature is the Ed25519 signature of the highest accepted voucher. + // Nil settles without a voucher (hasVoucher = 0, no precompile). + Signature *[64]byte + + // CumulativeAmount is the settled watermark committed on-chain. + CumulativeAmount uint64 + + // ExpiresAt is the expiry of the settled voucher (Unix seconds). + ExpiresAt int64 + + // ProgramID is the payment-channels program targeted by this settle. The + // zero value resolves to the package program id. + ProgramID solana.PublicKey +} + +// BuildSettleAndFinalizeInstructions builds the instruction sequence for an +// on-chain settle_and_finalize. When a voucher signature is provided, an +// Ed25519 precompile instruction over the canonical 48-byte voucher message +// is placed immediately before the settle_and_finalize instruction, which +// references it through the instructions sysvar, and hasVoucher is set to 1. +func BuildSettleAndFinalizeInstructions(params SettleAndFinalizeParams) ([]solana.Instruction, error) { + programID := resolveProgram(params.ProgramID) + instructions := make([]solana.Instruction, 0, 2) + hasVoucher := uint8(0) + + if params.Signature != nil { + message, err := VoucherMessageBytes(params.Channel, params.CumulativeAmount, params.ExpiresAt) + if err != nil { + return nil, err + } + verify, err := BuildEd25519VerifyInstruction(params.AuthorizedSigner, *params.Signature, message) + if err != nil { + return nil, err + } + instructions = append(instructions, verify) + hasVoucher = 1 + } + + builder := generated.NewSettleAndFinalizeInstructionBuilder(). + SetMerchantAccount(params.Merchant). + SetChannelAccount(params.Channel). + SetInstructionsSysvarAccount(solana.SysVarInstructionsPubkey). + SetSettleAndFinalizeArgs(generated.SettleAndFinalizeArgs{ + Voucher: generated.VoucherArgs{ + ChannelId: params.Channel, + CumulativeAmount: params.CumulativeAmount, + ExpiresAt: params.ExpiresAt, + }, + HasVoucher: hasVoucher, + }) + if _, err := builder.ValidateAndBuild(); err != nil { + return nil, fmt.Errorf("build settle_and_finalize instruction: %w", err) + } + settle, err := materialize(builder, builder.GetAccounts(), programID) + if err != nil { + return nil, err + } + return append(instructions, settle), nil +} + +// DistributeParams carries the inputs required to build a Distribute +// instruction. +type DistributeParams struct { + // Channel is the settled payment-channel address. + Channel solana.PublicKey + + // Payer is the original channel payer, refunded the unsettled remainder. + Payer solana.PublicKey + + // Payee is the primary payment recipient. + Payee solana.PublicKey + + // Treasury is the treasury owner of the program deployment. The zero + // value resolves to TreasuryOwner(). + Treasury solana.PublicKey + + // Mint is the SPL mint locked in the channel. + Mint solana.PublicKey + + // Recipients are the basis-point splits distributed at close. + Recipients []Distribution + + // TokenProgram owning the mint (Token or Token-2022). + TokenProgram solana.PublicKey + + // ProgramID is the payment-channels program targeted by this distribute. + // The zero value resolves to the package program id. + ProgramID solana.PublicKey +} + +// BuildDistributeInstruction derives the channel/payer/payee/treasury ATAs +// plus one ATA per split recipient and builds the Distribute instruction: +// the 10 fixed accounts in the exact order the on-chain program expects, +// followed by one writable recipient token account per split. +func BuildDistributeInstruction(params DistributeParams) (solana.Instruction, error) { + programID := resolveProgram(params.ProgramID) + treasury := params.Treasury + if treasury.IsZero() { + treasury = TreasuryOwner() + } + + channelToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Channel, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive channel token account: %w", err) + } + payerToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payer, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive payer token account: %w", err) + } + payeeToken, _, err := solana.FindAssociatedTokenAddressWithProgram(params.Payee, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive payee token account: %w", err) + } + treasuryToken, _, err := solana.FindAssociatedTokenAddressWithProgram(treasury, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive treasury token account: %w", err) + } + eventAuthority, _, err := FindEventAuthorityPDAForProgram(programID) + if err != nil { + return nil, err + } + + entries := make([]generated.DistributionEntry, 0, len(params.Recipients)) + recipientTokenAccounts := make([]*solana.AccountMeta, 0, len(params.Recipients)) + for _, entry := range params.Recipients { + recipientToken, _, err := solana.FindAssociatedTokenAddressWithProgram(entry.Recipient, params.Mint, params.TokenProgram) + if err != nil { + return nil, fmt.Errorf("derive recipient token account for %s: %w", entry.Recipient, err) + } + recipientTokenAccounts = append(recipientTokenAccounts, solana.Meta(recipientToken).WRITE()) + entries = append(entries, generated.DistributionEntry{ + Recipient: entry.Recipient, + Bps: entry.Bps, + }) + } + + builder := generated.NewDistributeInstructionBuilder(). + SetChannelAccount(params.Channel). + SetPayerAccount(params.Payer). + SetChannelTokenAccountAccount(channelToken). + SetPayerTokenAccountAccount(payerToken). + SetPayeeTokenAccountAccount(payeeToken). + SetTreasuryTokenAccountAccount(treasuryToken). + SetMintAccount(params.Mint). + SetTokenProgramAccount(params.TokenProgram). + SetEventAuthorityAccount(eventAuthority). + SetSelfProgramAccount(programID). + SetDistributeArgs(generated.DistributeArgs{Recipients: entries}) + + if _, err := builder.ValidateAndBuild(); err != nil { + return nil, fmt.Errorf("build distribute instruction: %w", err) + } + accounts := make([]*solana.AccountMeta, 0, len(builder.GetAccounts())+len(recipientTokenAccounts)) + accounts = append(accounts, builder.GetAccounts()...) + accounts = append(accounts, recipientTokenAccounts...) + return materialize(builder, accounts, programID) +} diff --git a/go/paycore/paymentchannels/settlement_test.go b/go/paycore/paymentchannels/settlement_test.go new file mode 100644 index 000000000..5f7cf70c8 --- /dev/null +++ b/go/paycore/paymentchannels/settlement_test.go @@ -0,0 +1,375 @@ +package paymentchannels + +// Settlement builder byte-equivalence tests. +// +// These pin the Ed25519 precompile layout and the settle_and_finalize, +// top_up, distribute, and open instruction bytes so any drift from the +// on-chain program encoding is caught at unit-test time. + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "testing" + + solana "github.com/gagliardetto/solana-go" +) + +// fixedKey returns a deterministic 32-byte public key filled with b. +func fixedKey(b byte) solana.PublicKey { + var key solana.PublicKey + for i := range key { + key[i] = b + } + return key +} + +const zeroChannelID = "11111111111111111111111111111111" + +// ── Ed25519 precompile ── + +func TestBuildEd25519VerifyInstructionLayout(t *testing.T) { + signer := fixedKey(0xAA) + var signature [64]byte + for i := range signature { + signature[i] = 0xBB + } + message := bytes.Repeat([]byte{0xCC}, 48) + + ix, err := BuildEd25519VerifyInstruction(signer, signature, message) + if err != nil { + t.Fatalf("BuildEd25519VerifyInstruction: %v", err) + } + if !ix.ProgramID().Equals(Ed25519ProgramPubkey()) { + t.Fatalf("program id = %s, want %s", ix.ProgramID(), Ed25519ProgramID) + } + if len(ix.Accounts()) != 0 { + t.Fatalf("accounts = %d, want 0", len(ix.Accounts())) + } + + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + if len(data) != 160 { + t.Fatalf("data length = %d, want 160", len(data)) + } + if data[0] != 1 || data[1] != 0 { + t.Fatalf("header = [%d %d], want [1 0] (num_signatures, padding)", data[0], data[1]) + } + // Offsets: signature 48, public key 16, message 112, size 48; every + // instruction-index field is 0xFFFF (current instruction). + expectHeader := []struct { + offset int + value uint16 + label string + }{ + {2, 48, "signature_offset"}, + {4, 0xFFFF, "signature_instruction_index"}, + {6, 16, "public_key_offset"}, + {8, 0xFFFF, "public_key_instruction_index"}, + {10, 112, "message_data_offset"}, + {12, 48, "message_data_size"}, + {14, 0xFFFF, "message_instruction_index"}, + } + for _, field := range expectHeader { + if got := binary.LittleEndian.Uint16(data[field.offset : field.offset+2]); got != field.value { + t.Fatalf("%s = %d, want %d", field.label, got, field.value) + } + } + if !bytes.Equal(data[16:48], signer.Bytes()) { + t.Fatal("public key bytes not at offset 16") + } + if !bytes.Equal(data[48:112], signature[:]) { + t.Fatal("signature bytes not at offset 48") + } + if !bytes.Equal(data[112:160], message) { + t.Fatal("message bytes not at offset 112") + } +} + +func TestBuildEd25519VerifyInstructionRejectsOversizedMessage(t *testing.T) { + if _, err := BuildEd25519VerifyInstruction(fixedKey(1), [64]byte{}, make([]byte, 0x10000)); err == nil { + t.Fatal("expected oversized-message rejection") + } +} + +// ── settle_and_finalize ── + +func TestBuildSettleAndFinalizeVoucherless(t *testing.T) { + merchant := fixedKey(0x05) + channel := solana.MustPublicKeyFromBase58(zeroChannelID) + + instructions, err := BuildSettleAndFinalizeInstructions(SettleAndFinalizeParams{ + Merchant: merchant, + Channel: channel, + }) + if err != nil { + t.Fatalf("BuildSettleAndFinalizeInstructions: %v", err) + } + if len(instructions) != 1 { + t.Fatalf("instructions = %d, want 1 (no precompile without a voucher)", len(instructions)) + } + + ix := instructions[0] + if !ix.ProgramID().Equals(ProgramPubkey()) { + t.Fatalf("program id = %s, want %s", ix.ProgramID(), ProgramID) + } + accounts := ix.Accounts() + if len(accounts) != 3 { + t.Fatalf("accounts = %d, want 3", len(accounts)) + } + if !accounts[0].PublicKey.Equals(merchant) || !accounts[0].IsSigner || accounts[0].IsWritable { + t.Fatalf("merchant meta = %+v, want readonly signer", accounts[0]) + } + if !accounts[1].PublicKey.Equals(channel) || accounts[1].IsSigner || !accounts[1].IsWritable { + t.Fatalf("channel meta = %+v, want writable non-signer", accounts[1]) + } + if !accounts[2].PublicKey.Equals(solana.SysVarInstructionsPubkey) || accounts[2].IsSigner || accounts[2].IsWritable { + t.Fatalf("sysvar meta = %+v, want readonly instructions sysvar", accounts[2]) + } + + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + // [disc=4][channel 32][cumulative u64][expiresAt i64][hasVoucher=0] = 50 bytes. + if len(data) != 50 { + t.Fatalf("data length = %d, want 50", len(data)) + } + if data[0] != 4 { + t.Fatalf("discriminator = %d, want 4", data[0]) + } + if data[49] != 0 { + t.Fatalf("hasVoucher = %d, want 0", data[49]) + } +} + +func TestBuildSettleAndFinalizeWithVoucherPrependsPrecompile(t *testing.T) { + merchant := fixedKey(0x05) + authorizedSigner := fixedKey(0x04) + channel := solana.MustPublicKeyFromBase58(zeroChannelID) + var signature [64]byte + for i := range signature { + signature[i] = 0xAA + } + + instructions, err := BuildSettleAndFinalizeInstructions(SettleAndFinalizeParams{ + Merchant: merchant, + Channel: channel, + AuthorizedSigner: authorizedSigner, + Signature: &signature, + CumulativeAmount: 500, + ExpiresAt: 4_102_444_800, + }) + if err != nil { + t.Fatalf("BuildSettleAndFinalizeInstructions: %v", err) + } + if len(instructions) != 2 { + t.Fatalf("instructions = %d, want 2 (precompile + settle_and_finalize)", len(instructions)) + } + + precompile := instructions[0] + if !precompile.ProgramID().Equals(Ed25519ProgramPubkey()) { + t.Fatalf("instruction 0 program = %s, want Ed25519 precompile", precompile.ProgramID()) + } + precompileData, err := precompile.Data() + if err != nil { + t.Fatalf("precompile.Data: %v", err) + } + wantMessage, err := VoucherMessageBytes(channel, 500, 4_102_444_800) + if err != nil { + t.Fatalf("VoucherMessageBytes: %v", err) + } + if !bytes.Equal(precompileData[112:160], wantMessage) { + t.Fatal("precompile message != canonical 48-byte voucher payload") + } + if !bytes.Equal(precompileData[48:112], signature[:]) { + t.Fatal("precompile signature != voucher signature") + } + if !bytes.Equal(precompileData[16:48], authorizedSigner.Bytes()) { + t.Fatal("precompile public key != authorized signer") + } + + settleData, err := instructions[1].Data() + if err != nil { + t.Fatalf("settle.Data: %v", err) + } + if settleData[len(settleData)-1] != 1 { + t.Fatalf("hasVoucher = %d, want 1", settleData[len(settleData)-1]) + } + if got := binary.LittleEndian.Uint64(settleData[33:41]); got != 500 { + t.Fatalf("cumulativeAmount@33 = %d, want 500", got) + } + if got := int64(binary.LittleEndian.Uint64(settleData[41:49])); got != 4_102_444_800 { + t.Fatalf("expiresAt@41 = %d, want 4102444800", got) + } +} + +// ── distribute ── + +func TestBuildDistributeAppendsRecipientTokenAccounts(t *testing.T) { + channel := solana.MustPublicKeyFromBase58(zeroChannelID) + payer := fixedKey(0x01) + payee := fixedKey(0x03) + mint := solana.MustPublicKeyFromBase58("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v") + tokenProgram := solana.TokenProgramID + splitRecipient := solana.MustPublicKeyFromBase58("HQyfh1JGDB47A6Az4MD9KgF9LqcL3ESCkN8AT9Y8atGD") + + ix, err := BuildDistributeInstruction(DistributeParams{ + Channel: channel, + Payer: payer, + Payee: payee, + Mint: mint, + Recipients: []Distribution{ + {Recipient: splitRecipient, Bps: 1000}, + {Recipient: splitRecipient, Bps: 250}, + }, + TokenProgram: tokenProgram, + }) + if err != nil { + t.Fatalf("BuildDistributeInstruction: %v", err) + } + if !ix.ProgramID().Equals(ProgramPubkey()) { + t.Fatalf("program id = %s, want %s", ix.ProgramID(), ProgramID) + } + + accounts := ix.Accounts() + if len(accounts) != 12 { + t.Fatalf("accounts = %d, want 12 (10 fixed + 2 recipient ATAs)", len(accounts)) + } + recipientATA, _, err := solana.FindAssociatedTokenAddressWithProgram(splitRecipient, mint, tokenProgram) + if err != nil { + t.Fatalf("derive recipient ATA: %v", err) + } + for slot := 10; slot < 12; slot++ { + if !accounts[slot].PublicKey.Equals(recipientATA) { + t.Fatalf("tail account %d = %s, want recipient ATA %s", slot, accounts[slot].PublicKey, recipientATA) + } + if !accounts[slot].IsWritable || accounts[slot].IsSigner { + t.Fatalf("tail account %d meta = %+v, want writable non-signer", slot, accounts[slot]) + } + } + treasuryATA, _, err := solana.FindAssociatedTokenAddressWithProgram(TreasuryOwner(), mint, tokenProgram) + if err != nil { + t.Fatalf("derive treasury ATA: %v", err) + } + if !accounts[5].PublicKey.Equals(treasuryATA) { + t.Fatalf("treasury token account = %s, want %s", accounts[5].PublicKey, treasuryATA) + } + + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + // [disc=7][recipients_count u32][(pubkey32 + bps u16) x 2]. + if data[0] != 7 { + t.Fatalf("discriminator = %d, want 7", data[0]) + } + if got := binary.LittleEndian.Uint32(data[1:5]); got != 2 { + t.Fatalf("recipients count = %d, want 2", got) + } + if got := binary.LittleEndian.Uint16(data[5+32 : 5+34]); got != 1000 { + t.Fatalf("first bps = %d, want 1000", got) + } + if got := binary.LittleEndian.Uint16(data[5+32+34 : 5+32+36]); got != 250 { + t.Fatalf("second bps = %d, want 250", got) + } +} + +func TestBuildDistributeZeroSplits(t *testing.T) { + ix, err := BuildDistributeInstruction(DistributeParams{ + Channel: solana.MustPublicKeyFromBase58(zeroChannelID), + Payer: fixedKey(0x01), + Payee: fixedKey(0x03), + Mint: solana.MustPublicKeyFromBase58("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildDistributeInstruction: %v", err) + } + if len(ix.Accounts()) != 10 { + t.Fatalf("accounts = %d, want 10 fixed accounts only", len(ix.Accounts())) + } + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + if len(data) != 5 { + t.Fatalf("data length = %d, want 5 ([disc][count=0])", len(data)) + } + if got := binary.LittleEndian.Uint32(data[1:5]); got != 0 { + t.Fatalf("recipients count = %d, want 0", got) + } +} + +func TestBuildDistributeToken2022DerivesProgramSpecificATAs(t *testing.T) { + channel := solana.MustPublicKeyFromBase58(zeroChannelID) + payee := fixedKey(0x03) + mint := solana.MustPublicKeyFromBase58("2b1kV6DkPAnxd5ixfnxCpjxmKwqjjaYmCZfHsFu24GXo") // PYUSD mainnet + token2022 := solana.MustPublicKeyFromBase58("TokenzQdBNbLqP5VEhdkAS6EPFLC1PHnBqCXEpPxuEb") + + ix, err := BuildDistributeInstruction(DistributeParams{ + Channel: channel, + Payer: fixedKey(0x01), + Payee: payee, + Mint: mint, + TokenProgram: token2022, + }) + if err != nil { + t.Fatalf("BuildDistributeInstruction: %v", err) + } + accounts := ix.Accounts() + if !accounts[7].PublicKey.Equals(token2022) { + t.Fatalf("token program account = %s, want Token-2022", accounts[7].PublicKey) + } + want2022, _, err := solana.FindAssociatedTokenAddressWithProgram(payee, mint, token2022) + if err != nil { + t.Fatalf("derive token-2022 ATA: %v", err) + } + if !accounts[4].PublicKey.Equals(want2022) { + t.Fatalf("payee token account = %s, want token-2022 ATA %s", accounts[4].PublicKey, want2022) + } + wantLegacy, _, err := solana.FindAssociatedTokenAddressWithProgram(payee, mint, solana.TokenProgramID) + if err != nil { + t.Fatalf("derive legacy ATA: %v", err) + } + if accounts[4].PublicKey.Equals(wantLegacy) { + t.Fatal("payee token account was derived with the legacy token program") + } +} + +// ── open instruction golden ── + +// TestBuildOpenInstructionMatchesTypescriptGolden pins the open instruction +// data for fixed inputs (salt=42, deposit=1_000_000, gracePeriod=900, one +// HQyfh.../250bps recipient) to the golden bytes shared with the vendored +// Codama TS client and the pre-Codama hand encoder, so all three agree byte +// for byte. +func TestBuildOpenInstructionMatchesTypescriptGolden(t *testing.T) { + const goldenDataHex = "012a0000000000000040420f00000000008403000001000000f3df6c4f444efb2d860ce6dae0b568b6dadee3c402fc33edab10836490385896fa00" + + ix, err := BuildOpenInstruction(OpenChannelParams{ + Payer: fixedKey(0x01), + Payee: fixedKey(0x03), + Mint: solana.MustPublicKeyFromBase58("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), + AuthorizedSigner: fixedKey(0x04), + Salt: 42, + Deposit: 1_000_000, + GracePeriod: 900, + Recipients: []Distribution{ + {Recipient: solana.MustPublicKeyFromBase58("HQyfh1JGDB47A6Az4MD9KgF9LqcL3ESCkN8AT9Y8atGD"), Bps: 250}, + }, + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + if got := hex.EncodeToString(data); got != goldenDataHex { + t.Fatalf("open instruction data mismatch\n got: %s\nwant: %s", got, goldenDataHex) + } +} diff --git a/go/paycore/signer/signer.go b/go/paycore/signer/signer.go index f023d55ca..47cb11833 100644 --- a/go/paycore/signer/signer.go +++ b/go/paycore/signer/signer.go @@ -24,7 +24,10 @@ import ( // InvalidKeyError is returned by the fallible factories when the input // cannot be parsed into a 64-byte Ed25519 secret key. type InvalidKeyError struct { + // Source names the input form that failed to parse: "bytes", "json", + // "hex", "base58", "file", or "env". Source string + // Reason is the human-readable parse failure detail embedded in Error(). Reason string } @@ -34,8 +37,12 @@ func (e *InvalidKeyError) Error() string { // localSigner is the concrete value behind every local factory. type localSigner struct { - priv ed25519.PrivateKey - pub paykit.Address + // priv is the 64-byte Ed25519 private key that produces signatures. + priv ed25519.PrivateKey + // pub is the base58 public key derived from priv, returned by Pubkey. + pub paykit.Address + // isDemo marks the package-shipped demo keypair so paykit can warn on + // use and reject it on mainnet. isDemo bool } @@ -46,8 +53,8 @@ func (s *localSigner) Sign(_ context.Context, msg []byte) ([]byte, error) { func (s *localSigner) IsDemo() bool { return s.isDemo } // demoSecret is the 64-byte secret of the package-shipped demo -// keypair, identical to Ruby's PayKit::Signer::Demo and PHP's -// PayKit\Signer\Demo. Pubkey: ALtYSsZuYyKrNSe6GnVCzxj1T2RPMTPzXMe51xhbmXEq. +// keypair, identical across the language SDKs. +// Pubkey: ALtYSsZuYyKrNSe6GnVCzxj1T2RPMTPzXMe51xhbmXEq. var demoSecret = func() []byte { raw, _ := hex.DecodeString( "1a3d75c009e81833598769b62f0953f40bd655aae353aa1a37813a7259a0c333" + @@ -64,9 +71,11 @@ func Demo() paykit.Signer { return &localSigner{priv: priv, pub: pubkeyOf(priv), isDemo: true} } -// Generate produces a fresh ephemeral keypair. Test-only; production -// callers load from a file or env so the same identity survives -// restarts. +// Generate produces a fresh ephemeral keypair. Use it for identities that +// are ephemeral by design (tests, or short-lived signing keys); session +// clients use client.NewEphemeralSessionSigner for the per-session +// authorizedSigner. Persistent server identities load from a file or env so +// the same identity survives restarts. func Generate() paykit.Signer { _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { diff --git a/go/paykit/network_test.go b/go/paykit/network_test.go new file mode 100644 index 000000000..b4b9e4ae7 --- /dev/null +++ b/go/paykit/network_test.go @@ -0,0 +1,39 @@ +package paykit + +import "testing" + +// TestParseNetwork covers every accepted spelling and the error path. +func TestParseNetwork(t *testing.T) { + cases := []struct { + tag string + want Network + }{ + {"localnet", SolanaLocalnet}, + {"devnet", SolanaDevnet}, + {"mainnet", SolanaMainnet}, + {"mainnet-beta", SolanaMainnet}, + {"solana_localnet", SolanaLocalnet}, + {"solana_devnet", SolanaDevnet}, + {"solana_mainnet", SolanaMainnet}, + {"MAINNET", SolanaMainnet}, + {" devnet ", SolanaDevnet}, + } + for _, c := range cases { + got, err := ParseNetwork(c.tag) + if err != nil { + t.Fatalf("ParseNetwork(%q): unexpected error %v", c.tag, err) + } + if got != c.want { + t.Fatalf("ParseNetwork(%q) = %q, want %q", c.tag, got, c.want) + } + } +} + +// TestParseNetworkRejectsUnknownTags pins the error path. +func TestParseNetworkRejectsUnknownTags(t *testing.T) { + for _, tag := range []string{"", "testnet", "solana", "main net"} { + if _, err := ParseNetwork(tag); err == nil { + t.Fatalf("ParseNetwork(%q): expected error", tag) + } + } +} diff --git a/go/paykit/types.go b/go/paykit/types.go index 9b2d4183e..4b5f97b6d 100644 --- a/go/paykit/types.go +++ b/go/paykit/types.go @@ -1,6 +1,8 @@ package paykit import ( + "fmt" + "strings" "time" "github.com/shopspring/decimal" @@ -27,8 +29,8 @@ const ( EURC Stablecoin = "EURC" ) -// Network is the Solana cluster slug. Backing values match the Rust -// spine's `Network::as_str()` so a wire round-trip is trivial. +// Network is the Solana cluster slug. Backing values are the wire slugs +// shared across the language SDKs, so a wire round-trip is trivial. type Network string const ( @@ -37,10 +39,28 @@ const ( SolanaLocalnet Network = "solana_localnet" ) +// ParseNetwork maps a cluster tag onto the typed [Network] enum. It +// accepts the short tags the cross-language configure() surfaces use +// ("localnet", "devnet", "mainnet"), the legacy "mainnet-beta" alias, +// and the canonical wire slugs ("solana_localnet", "solana_devnet", +// "solana_mainnet"), case-insensitively. +func ParseNetwork(tag string) (Network, error) { + switch strings.ToLower(strings.TrimSpace(tag)) { + case "localnet", string(SolanaLocalnet): + return SolanaLocalnet, nil + case "devnet", string(SolanaDevnet): + return SolanaDevnet, nil + case "mainnet", "mainnet-beta", string(SolanaMainnet): + return SolanaMainnet, nil + default: + return "", fmt.Errorf("unsupported network %q (want localnet, devnet, or mainnet)", tag) + } +} + // DefaultRPCURL is the public RPC endpoint the kit falls back to when // [Config.RPCURL] is "". Localnet defaults to the hosted Surfpool // endpoint (mainnet-state fork) so the example apps boot without a -// local validator. Mirrors Ruby PR #142 + Lua PR #141 caveat #2. +// local validator. func (n Network) DefaultRPCURL() string { switch n { case SolanaMainnet: @@ -103,8 +123,16 @@ const ( // struct directly so the internal invariant (positive decimal, valid // currency) stays enforced. type Price struct { - amount decimal.Decimal - currency Currency + // amount is the exact decimal quote. Constructors enforce that it is + // positive; no rounding happens until conversion to mint base units + // at challenge-build time. + amount decimal.Decimal + // currency is the fiat denomination (USD, EUR, or GBP), fixed by the + // Parse constructor used. + currency Currency + // settlements is the ordered stablecoin preference for settling this + // price; nil means no narrowing, falling back to the kit-level + // [Config.Stablecoins] list. settlements []Stablecoin } @@ -140,9 +168,16 @@ func (p Price) Settlements() []Stablecoin { // merchant flows where the operator signer also pays Solana network fees // on settlement. type Operator struct { + // Recipient is the base58 Solana address where settled funds land; + // "" defaults to Signer.Pubkey() at [New] time. Recipient Address - Signer Signer - FeePayer bool + // Signer is the operator's Ed25519 signer, used to cosign x402 + // challenges and to fee-pay settlement transactions when FeePayer is + // set; nil defaults to the registered demo signer (non-mainnet only). + Signer Signer + // FeePayer, when true, makes the operator Signer also pay Solana + // network fees on settlement transactions instead of the client. + FeePayer bool } // X402Config groups the x402-specific knobs. @@ -164,29 +199,52 @@ type X402Config struct { // extension with info.required=true on the 402 challenge, and rejects // any submitted credential that does not echo a valid `pay_`-shaped id // (coinbase x402 payment_identifier spec: HTTP 400). When false - // (default) the challenge carries no `extensions` object, matching the - // rust spine's PaymentRequiredEnvelope.extensions: None default. + // (default) the challenge carries no `extensions` object; extensions + // default to absent on the wire. RequirePaymentIdentifier bool } // MPPConfig groups the MPP-charge-specific knobs. type MPPConfig struct { - Realm string + // Realm is the realm string advertised in the MPP WWW-Authenticate + // challenge and bound into the HMAC challenge ID; "" defaults to + // "PayKit". + Realm string + // ChallengeBindingSecret is the HMAC-SHA256 key that binds challenge + // IDs to their contents (replay/tamper protection). When empty and + // MPP is in [Config.Accept], [New] resolves one automatically. ChallengeBindingSecret []byte - ExpiresIn time.Duration + // ExpiresIn is how long an issued challenge stays valid; sent on the + // wire in whole seconds. Zero defaults to 2 minutes. + ExpiresIn time.Duration } // Config is the boot-time configuration passed to [New]. Zero-value // [Config] is invalid because Network is required; every other field // has a sensible default. type Config struct { - Network Network - Accept []Protocol + // Network is the Solana cluster the kit settles on. Required; the + // only Config field with no default. + Network Network + // Accept lists the protocols served, in preference order (first + // entry wins when a client supports several); empty defaults to + // [X402, MPP]. + Accept []Protocol + // Stablecoins lists the settlement assets offered, in preference + // order; empty defaults to USDC. Mints resolve per Network. Stablecoins []Stablecoin - RPCURL string - Operator Operator - X402 X402Config - MPP MPPConfig + // RPCURL is the Solana JSON-RPC endpoint used for verification and + // settlement; "" falls back to [Network.DefaultRPCURL]. + RPCURL string + // Operator is the merchant identity: settlement recipient, Ed25519 + // signer, and the fee-payer flag. + Operator Operator + // X402 holds the x402-specific knobs (facilitator URL, scheme, + // signer override, payment-identifier extension). + X402 X402Config + // MPP holds the MPP-charge-specific knobs (realm, challenge-binding + // HMAC secret, challenge expiry). + MPP MPPConfig // Preflight runs the soundness checks at New() time. Defaults to // true; set to false (or export PAY_KIT_DISABLE_PREFLIGHT=1) to @@ -205,9 +263,21 @@ type Config struct { // the middleware accepts a credential. Handlers read it via // [PaymentFrom] / [IsPaid] / [IsPaidFor]. type Payment struct { - Protocol Protocol - Gate string - Transaction string + // Protocol is the payment protocol (x402 or MPP) that verified and + // settled the credential. + Protocol Protocol + // Gate is the [Gate.Name] of the route gate the payment satisfied; + // matched by [IsPaidFor]. + Gate string + // Transaction is the settlement reference: the base58 Solana + // transaction signature for x402, the receipt reference for MPP. + Transaction string + // SettlementHeaders are the protocol response headers (settlement + // signature plus payment-response or Payment-Receipt) that the + // middleware copies onto the HTTP response. SettlementHeaders map[string]string - Raw string + // Raw is the credential exactly as the client presented it: the + // Authorization header value for MPP, the payment-signature header + // payload for x402. + Raw string } diff --git a/go/protocols/mpp/client/challenge_selection.go b/go/protocols/mpp/client/challenge_selection.go new file mode 100644 index 000000000..c4ad0d233 --- /dev/null +++ b/go/protocols/mpp/client/challenge_selection.go @@ -0,0 +1,145 @@ +// Session challenge selection. +// +// Servers can return multiple 402 challenges for the same resource (one per +// supported currency or intent). These helpers pick the Solana session +// challenge a client should open, filtering by network, currency, and funding +// mode while preserving server order otherwise. +package client + +import ( + "fmt" + + "github.com/solana-foundation/pay-kit/go/paycore" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// SessionRequestModes returns the funding modes a session challenge offers. +// An omitted or empty modes list means push-only. +func SessionRequestModes(request intents.SessionRequest) []intents.SessionMode { + if len(request.Modes) > 0 { + return request.Modes + } + return []intents.SessionMode{intents.SessionModePush} +} + +// SelectSessionChallengeOptions filters the session challenges a client is +// willing to open. Zero-value fields do not filter. +type SelectSessionChallengeOptions struct { + // Network is the Solana network the client wants to pay on. Use the + // paycore.Network* constants; raw strings (including the legacy + // "mainnet-beta" alias) are folded onto the canonical slug before + // matching. + Network paycore.SolanaNetwork + + // Currencies are the currency symbols or mint addresses the client wants + // to pay with. A challenge matches when its currency resolves to the same + // mint as any entry. + Currencies []string + + // Modes are the funding modes the client supports. When set, the selected + // challenge must advertise at least one of them (an omitted or empty + // challenge modes list advertises push only). + Modes []intents.SessionMode +} + +// SelectedSessionChallenge is a session challenge paired with its decoded +// request. +type SelectedSessionChallenge struct { + // Challenge is the matched 402 session challenge, kept whole so it can be + // echoed back when serializing the payment credential. + Challenge core.PaymentChallenge + + // Request is the challenge's session request decoded from its base64url + // JSON request value (currency, cap, recipient, modes, ...). + Request intents.SessionRequest +} + +// SelectSessionChallenge selects the Solana session challenge the client +// should open, or nil when none matches. A challenge with the session intent +// but an undecodable request is an error. +func SelectSessionChallenge( + challenges []core.PaymentChallenge, + options SelectSessionChallengeOptions, +) (*SelectedSessionChallenge, error) { + var candidates []SelectedSessionChallenge + + for _, challenge := range challenges { + if challenge.Method != core.NewMethodName("solana") || !challenge.Intent.IsSession() { + continue + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + return nil, fmt.Errorf("invalid Solana session challenge request: %w", err) + } + if !matchesSessionNetwork(request, options.Network) { + continue + } + if !matchesSessionCurrency(request, options.Currencies) { + continue + } + candidates = append(candidates, SelectedSessionChallenge{Challenge: challenge, Request: request}) + } + + if len(options.Modes) == 0 { + if len(candidates) == 0 { + return nil, nil + } + return &candidates[0], nil + } + + for _, candidate := range candidates { + challengeModes := SessionRequestModes(candidate.Request) + for _, accepted := range options.Modes { + for _, mode := range challengeModes { + if mode == accepted { + selected := candidate + return &selected, nil + } + } + } + } + return nil, nil +} + +// SelectSessionChallengeFromHeaders parses WWW-Authenticate header values and +// selects the Solana session challenge the client should open. Pass +// response.Header.Values(core.WWWAuthenticateHeader). +func SelectSessionChallengeFromHeaders( + headers []string, + options SelectSessionChallengeOptions, +) (*SelectedSessionChallenge, error) { + return SelectSessionChallenge(core.ParseWWWAuthenticateAll(headers), options) +} + +// matchesSessionNetwork reports whether the challenge network equals the +// requested network, treating mainnet and mainnet-beta as equivalent. +func matchesSessionNetwork(request intents.SessionRequest, network paycore.SolanaNetwork) bool { + if network == "" { + return true + } + challengeNetwork := string(paycore.NetworkMainnet) + if request.Network != nil { + challengeNetwork = *request.Network + } + return paycore.ParseSolanaNetwork(challengeNetwork) == paycore.ParseSolanaNetwork(string(network)) +} + +// matchesSessionCurrency reports whether the challenge currency resolves to +// the same mint as any accepted currency on the challenge network. +func matchesSessionCurrency(request intents.SessionRequest, currencies []string) bool { + if len(currencies) == 0 { + return true + } + network := "" + if request.Network != nil { + network = *request.Network + } + challengeMint := paycore.ResolveMint(request.Currency, network) + for _, accepted := range currencies { + if paycore.ResolveMint(accepted, network) == challengeMint { + return true + } + } + return false +} diff --git a/go/protocols/mpp/client/challenge_selection_test.go b/go/protocols/mpp/client/challenge_selection_test.go new file mode 100644 index 000000000..9715d484f --- /dev/null +++ b/go/protocols/mpp/client/challenge_selection_test.go @@ -0,0 +1,235 @@ +package client + +import ( + "strings" + "testing" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +func sessionChallenge(t *testing.T, request intents.SessionRequest) core.PaymentChallenge { + t.Helper() + encoded, err := core.NewBase64URLJSONValue(request) + if err != nil { + t.Fatalf("encode session request: %v", err) + } + return core.PaymentChallenge{ + ID: "challenge-id", + Realm: "example", + Method: core.NewMethodName("solana"), + Intent: core.NewIntentName("session"), + Request: encoded, + } +} + +func chargeIntentChallenge(t *testing.T) core.PaymentChallenge { + t.Helper() + challenge := sessionChallenge(t, testSessionRequest( + testutil.NewPrivateKey().PublicKey(), testutil.NewPrivateKey().PublicKey())) + challenge.Intent = core.NewIntentName("charge") + return challenge +} + +func TestSessionRequestModesDefaultsToPushOnly(t *testing.T) { + cases := []struct { + name string + modes []intents.SessionMode + want []intents.SessionMode + }{ + {"omitted", nil, []intents.SessionMode{intents.SessionModePush}}, + {"explicit empty", []intents.SessionMode{}, []intents.SessionMode{intents.SessionModePush}}, + {"advertised", []intents.SessionMode{intents.SessionModePull}, + []intents.SessionMode{intents.SessionModePull}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + request := intents.SessionRequest{Modes: tc.modes} + got := SessionRequestModes(request) + if len(got) != len(tc.want) { + t.Fatalf("modes = %v, want %v", got, tc.want) + } + for i := range got { + if got[i] != tc.want[i] { + t.Fatalf("modes = %v, want %v", got, tc.want) + } + } + }) + } +} + +func TestSelectSessionChallengeSkipsNonSessionChallenges(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + session := sessionChallenge(t, testSessionRequest(operator, recipient)) + + selected, err := SelectSessionChallenge( + []core.PaymentChallenge{chargeIntentChallenge(t), session}, + SelectSessionChallengeOptions{}, + ) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected == nil { + t.Fatal("no challenge selected") + } + if !selected.Challenge.Intent.IsSession() { + t.Fatalf("selected intent = %s, want session", selected.Challenge.Intent) + } + if selected.Request.Operator != operator.String() { + t.Fatalf("decoded operator = %s", selected.Request.Operator) + } +} + +func TestSelectSessionChallengeFiltersByNetwork(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + devnet := testSessionRequest(operator, recipient) + devnetName := "devnet" + devnet.Network = &devnetName + mainnet := testSessionRequest(operator, recipient) + mainnetName := "mainnet" + mainnet.Network = &mainnetName + + challenges := []core.PaymentChallenge{sessionChallenge(t, devnet), sessionChallenge(t, mainnet)} + + selected, err := SelectSessionChallenge(challenges, SelectSessionChallengeOptions{Network: "mainnet-beta"}) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected == nil || selected.Request.Network == nil || *selected.Request.Network != "mainnet" { + t.Fatalf("selected = %+v, want the mainnet challenge for mainnet-beta", selected) + } + + selected, err = SelectSessionChallenge(challenges, SelectSessionChallengeOptions{Network: paycore.NetworkLocalnet}) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected != nil { + t.Fatalf("selected = %+v, want none for localnet", selected) + } +} + +func TestSelectSessionChallengeFiltersByCurrencyMint(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + usdc := testSessionRequest(operator, recipient) + pyusd := testSessionRequest(operator, recipient) + pyusd.Currency = "PYUSD" + + challenges := []core.PaymentChallenge{sessionChallenge(t, usdc), sessionChallenge(t, pyusd)} + + selected, err := SelectSessionChallenge(challenges, SelectSessionChallengeOptions{ + Currencies: []string{"PYUSD"}, + }) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected == nil || selected.Request.Currency != "PYUSD" { + t.Fatalf("selected = %+v, want PYUSD challenge", selected) + } + + // A mint address matches its symbol through mint resolution. + mintMatched, err := SelectSessionChallenge(challenges, SelectSessionChallengeOptions{ + Currencies: []string{selectedMintForUSDC()}, + }) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if mintMatched == nil || mintMatched.Request.Currency != "USDC" { + t.Fatalf("selected = %+v, want USDC challenge via mint address", mintMatched) + } +} + +// selectedMintForUSDC returns the mainnet USDC mint (localnet resolves to it). +func selectedMintForUSDC() string { + return "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" +} + +func TestSelectSessionChallengePrefersAdvertisedMode(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + + pushOnly := testSessionRequest(operator, recipient) + pushOnly.Modes = nil + pushOnly.PullVoucherStrategy = nil + pull := testSessionRequest(operator, recipient) + + challenges := []core.PaymentChallenge{sessionChallenge(t, pushOnly), sessionChallenge(t, pull)} + + selected, err := SelectSessionChallenge(challenges, SelectSessionChallengeOptions{ + Modes: []intents.SessionMode{intents.SessionModePull}, + }) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected == nil || len(selected.Request.Modes) == 0 || + selected.Request.Modes[0] != intents.SessionModePull { + t.Fatalf("selected = %+v, want the pull challenge", selected) + } + + // An omitted modes list advertises push, so a push client matches it first. + selected, err = SelectSessionChallenge(challenges, SelectSessionChallengeOptions{ + Modes: []intents.SessionMode{intents.SessionModePush}, + }) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected == nil || len(selected.Request.Modes) != 0 { + t.Fatalf("selected = %+v, want the omitted-modes push-only challenge", selected) + } + + // No advertised mode matches: nothing selected. + selected, err = SelectSessionChallenge( + []core.PaymentChallenge{sessionChallenge(t, pushOnly)}, + SelectSessionChallengeOptions{Modes: []intents.SessionMode{intents.SessionModePull}}, + ) + if err != nil { + t.Fatalf("SelectSessionChallenge: %v", err) + } + if selected != nil { + t.Fatalf("selected = %+v, want none (push-only challenge, pull-only client)", selected) + } +} + +func TestSelectSessionChallengeRejectsUndecodableSessionRequest(t *testing.T) { + challenge := core.PaymentChallenge{ + ID: "challenge-id", + Realm: "example", + Method: core.NewMethodName("solana"), + Intent: core.NewIntentName("session"), + Request: core.NewBase64URLJSONRaw("!!!not-base64url!!!"), + } + _, err := SelectSessionChallenge([]core.PaymentChallenge{challenge}, SelectSessionChallengeOptions{}) + if err == nil || !strings.Contains(err.Error(), "invalid Solana session challenge request") { + t.Fatalf("error = %v, want invalid request", err) + } +} + +func TestSelectSessionChallengeFromHeaders(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + challenge := sessionChallenge(t, testSessionRequest(operator, recipient)) + header, err := core.FormatWWWAuthenticate(challenge) + if err != nil { + t.Fatalf("FormatWWWAuthenticate: %v", err) + } + + selected, err := SelectSessionChallengeFromHeaders([]string{header}, SelectSessionChallengeOptions{}) + if err != nil { + t.Fatalf("SelectSessionChallengeFromHeaders: %v", err) + } + if selected == nil || selected.Request.Operator != operator.String() { + t.Fatalf("selected = %+v, want parsed session challenge", selected) + } + + none, err := SelectSessionChallengeFromHeaders([]string{"Basic realm=x"}, SelectSessionChallengeOptions{}) + if err != nil { + t.Fatalf("SelectSessionChallengeFromHeaders: %v", err) + } + if none != nil { + t.Fatalf("selected = %+v, want none for non-Payment header", none) + } +} diff --git a/go/protocols/mpp/client/charge_test.go b/go/protocols/mpp/client/charge_test.go index 5fd342855..3c556b419 100644 --- a/go/protocols/mpp/client/charge_test.go +++ b/go/protocols/mpp/client/charge_test.go @@ -153,9 +153,9 @@ func TestBuildChargeTransactionTokenPull(t *testing.T) { } // TestBuildChargeTransactionTokenCreateRecipientATAFlag table-tests the -// opt-in CreateRecipientATA flag. The default (false) matches the -// canonical Rust/TS clients which leave primary-recipient ATA creation -// to the server, while setting the flag prepends an idempotent +// opt-in CreateRecipientATA flag. The default (false) leaves +// primary-recipient ATA creation to the server, as the other SDK clients +// do, while setting the flag prepends an idempotent // createAssociatedTokenAccount instruction for first-run wallets that // do not yet hold a token account for the selected mint. func TestBuildChargeTransactionTokenCreateRecipientATAFlag(t *testing.T) { @@ -548,6 +548,8 @@ func TestBuildCredentialHeaderRejectsInvalidMethodDetails(t *testing.T) { // rpcWithBlockhashErr wraps FakeRPC and forces GetLatestBlockhash to error. type rpcWithBlockhashErr struct { + // FakeRPC supplies the rest of the stub RPC surface unchanged; only the + // GetLatestBlockhash method below is overridden to fail. *testutil.FakeRPC } @@ -701,9 +703,6 @@ func TestBuildChargeTransactionTokenWithExternalIDMemoTooLong(t *testing.T) { } } -// rpcSendErr forces SendTransaction to error to cover the broadcast error branch. -type rpcSendErr struct{ *testutil.FakeRPC } - func TestBuildChargeTransactionBroadcastSendError(t *testing.T) { rpcClient := testutil.NewFakeRPC() rpcClient.SendErr = errors.New("send rpc down") diff --git a/go/protocols/mpp/client/http_stream.go b/go/protocols/mpp/client/http_stream.go new file mode 100644 index 000000000..2d5401e5a --- /dev/null +++ b/go/protocols/mpp/client/http_stream.go @@ -0,0 +1,495 @@ +// HTTP streaming helpers for metered sessions. +// +// LLM APIs commonly stream responses over Server-Sent Events (SSE) or chunked +// HTTP. This file keeps the parser transport-neutral (SseDecoder works on raw +// chunks from any reader), then layers a net/http-friendly stream and commit +// transport on top for applications that want batteries included. +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "unicode/utf8" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// SseEvent is a parsed Server-Sent Event frame. Event, ID, and Retry are nil +// when the frame omitted the field. +type SseEvent struct { + // Event is the event name from "event:" lines; nil means the frame used + // the default "message" event. + Event *string + + // Data is the frame payload: all "data:" line values joined with + // newlines, empty when the frame carried no data lines. + Data string + + // ID is the event id from the "id:" line, the value a reconnecting + // client would echo as Last-Event-ID. + ID *string + + // Retry is the reconnection delay in milliseconds from the "retry:" + // line; non-numeric values are dropped. + Retry *uint64 +} + +// SseDecoder is an incremental SSE decoder. +// +// Feed raw HTTP chunks with PushChunk. It returns all complete events decoded +// from that chunk and retains partial data internally. +type SseDecoder struct { + // buffer holds the undecoded tail of the stream: bytes pushed but not + // yet terminated by a newline. + buffer string + + // current accumulates the fields of the in-progress event until a blank + // line dispatches it. + current SseEvent +} + +// PushChunk decodes the events completed by a raw chunk of the stream body. +func (d *SseDecoder) PushChunk(chunk []byte) ([]SseEvent, error) { + if !utf8.Valid(chunk) { + return nil, fmt.Errorf("SSE chunk is not valid UTF-8") + } + d.buffer += string(chunk) + + var events []SseEvent + for { + index := strings.IndexByte(d.buffer, '\n') + if index < 0 { + break + } + line := d.buffer[:index] + d.buffer = d.buffer[index+1:] + line = strings.TrimSuffix(line, "\r") + if event, ok := d.processLine(line); ok { + events = append(events, event) + } + } + return events, nil +} + +// Finish flushes an incomplete final event, if any, at EOF. +func (d *SseDecoder) Finish() ([]SseEvent, error) { + var events []SseEvent + if d.buffer != "" { + line := strings.TrimSuffix(d.buffer, "\r") + d.buffer = "" + if event, ok := d.processLine(line); ok { + events = append(events, event) + } + } + if event, ok := d.dispatchCurrent(); ok { + events = append(events, event) + } + return events, nil +} + +func (d *SseDecoder) processLine(line string) (SseEvent, bool) { + if line == "" { + return d.dispatchCurrent() + } + if strings.HasPrefix(line, ":") { + return SseEvent{}, false + } + + field := line + value := "" + if index := strings.IndexByte(line, ':'); index >= 0 { + field = line[:index] + value = strings.TrimPrefix(line[index+1:], " ") + } + + switch field { + case "event": + event := value + d.current.Event = &event + case "data": + if d.current.Data != "" { + d.current.Data += "\n" + } + d.current.Data += value + case "id": + id := value + d.current.ID = &id + case "retry": + if retry, err := strconv.ParseUint(value, 10, 64); err == nil { + d.current.Retry = &retry + } + } + return SseEvent{}, false +} + +func (d *SseDecoder) dispatchCurrent() (SseEvent, bool) { + if d.current.Event == nil && d.current.Data == "" && d.current.ID == nil && d.current.Retry == nil { + return SseEvent{}, false + } + current := d.current + d.current = SseEvent{} + return current, true +} + +// MeteredSseEventKind discriminates ParseMeteredSseEvent results. +type MeteredSseEventKind int + +// MeteredSseEvent kinds. +const ( + // MeteredSseEventMetering is an mpp.metering / metering directive event. + MeteredSseEventMetering MeteredSseEventKind = iota + + // MeteredSseEventUsage is an mpp.usage / usage final-amount event. + MeteredSseEventUsage + + // MeteredSseEventMessage is an application message event. + MeteredSseEventMessage + + // MeteredSseEventDone is a done event or [DONE] sentinel message. + MeteredSseEventDone + + // MeteredSseEventOther is an unrecognized event passed through untouched. + MeteredSseEventOther +) + +// MeteredSseEvent is a parsed metered SSE event. Exactly the field matching +// Kind is populated. +type MeteredSseEvent struct { + // Kind discriminates which of the payload fields below is set. + Kind MeteredSseEventKind + + // Metering is the decoded directive of an mpp.metering/metering event + // (the per-delivery amount reservation). + Metering *intents.MeteringDirective + + // Usage is the decoded final-amount report of an mpp.usage/usage event. + Usage *intents.MeteringUsage + + // Message is the raw JSON payload of an application message event, left + // for the caller to decode. + Message json.RawMessage + + // Other is the unrecognized event passed through verbatim. + Other *SseEvent +} + +// ParseMeteredSseEvent classifies an SSE event by the metered-session event +// names: "mpp.metering"/"metering", "mpp.usage"/"usage", "done", and the +// "[DONE]" sentinel on the default message event. Application messages keep +// their raw JSON payload for the caller to decode. +func ParseMeteredSseEvent(event SseEvent) (MeteredSseEvent, error) { + eventName := "message" + if event.Event != nil { + eventName = *event.Event + } + switch eventName { + case "mpp.metering", "metering": + directive := intents.MeteringDirective{} + if err := json.Unmarshal([]byte(event.Data), &directive); err != nil { + return MeteredSseEvent{}, fmt.Errorf("invalid mpp.metering event: %w", err) + } + return MeteredSseEvent{Kind: MeteredSseEventMetering, Metering: &directive}, nil + case "mpp.usage", "usage": + usage := intents.MeteringUsage{} + if err := json.Unmarshal([]byte(event.Data), &usage); err != nil { + return MeteredSseEvent{}, fmt.Errorf("invalid mpp.usage event: %w", err) + } + return MeteredSseEvent{Kind: MeteredSseEventUsage, Usage: &usage}, nil + case "done": + return MeteredSseEvent{Kind: MeteredSseEventDone}, nil + case "message": + if strings.TrimSpace(event.Data) == "[DONE]" { + return MeteredSseEvent{Kind: MeteredSseEventDone}, nil + } + if !json.Valid([]byte(event.Data)) { + return MeteredSseEvent{}, fmt.Errorf("invalid SSE message event: %q", event.Data) + } + return MeteredSseEvent{Kind: MeteredSseEventMessage, Message: json.RawMessage(event.Data)}, nil + default: + other := event + return MeteredSseEvent{Kind: MeteredSseEventOther, Other: &other}, nil + } +} + +// meteredStreamState pairs the live metering directive with the optional final +// usage amount. +type meteredStreamState struct { + // directive is the latest mpp.metering directive seen on the stream; + // nil until one arrives, which makes committing an error. + directive *intents.MeteringDirective + + // finalAmount is the usage-reported final amount in token base units; + // nil commits the directive's reserved amount instead. + finalAmount *uint64 + + // done is set once the stream signals completion via a done event, a + // [DONE] sentinel message, or EOF. + done bool +} + +// applyEvent folds one SSE event into the state, returning the raw application +// message when the event carries one. A usage event must reference the live +// directive's deliveryId; it may override only the amount. +func (s *meteredStreamState) applyEvent(event SseEvent) (json.RawMessage, error) { + parsed, err := ParseMeteredSseEvent(event) + if err != nil { + return nil, err + } + switch parsed.Kind { + case MeteredSseEventMetering: + s.directive = parsed.Metering + return nil, nil + case MeteredSseEventUsage: + if s.directive != nil && parsed.Usage.DeliveryID != s.directive.DeliveryID { + return nil, fmt.Errorf( + "usage delivery %s does not match directive %s", + parsed.Usage.DeliveryID, s.directive.DeliveryID) + } + amount, err := parsed.Usage.AmountBaseUnits() + if err != nil { + return nil, err + } + s.finalAmount = &amount + return nil, nil + case MeteredSseEventMessage: + return parsed.Message, nil + case MeteredSseEventDone: + s.done = true + return nil, nil + default: + return nil, nil + } +} + +// directiveForCommit returns the live directive with the final usage amount +// applied, erroring when the stream never emitted a metering event. +func (s *meteredStreamState) directiveForCommit() (intents.MeteringDirective, error) { + if s.directive == nil { + return intents.MeteringDirective{}, fmt.Errorf("stream did not include mpp.metering event") + } + directive := *s.directive + if s.finalAmount != nil { + directive.Amount = strconv.FormatUint(*s.finalAmount, 10) + } + return directive, nil +} + +// MeteredSseSession is a transport-neutral state machine for one metered SSE +// stream: feed it decoded SSE events, then Ack to commit the final amount. +type MeteredSseSession struct { + // consumer is the borrowed session consumer that signs and commits the + // stream's final amount on Ack. + consumer *SessionConsumer + + // state folds the metering directive, final usage amount, and done flag + // out of the accepted events. + state meteredStreamState +} + +// MeteredSse starts a metered SSE state machine borrowing this consumer. +func (c *SessionConsumer) MeteredSse() *MeteredSseSession { + return &MeteredSseSession{consumer: c} +} + +// AcceptEvent folds one decoded SSE event into the stream state and returns +// the raw application message when the event carries one. +func (s *MeteredSseSession) AcceptEvent(event SseEvent) (json.RawMessage, error) { + return s.state.applyEvent(event) +} + +// IsDone reports whether the stream signaled completion. +func (s *MeteredSseSession) IsDone() bool { return s.state.done } + +// Ack commits the stream's final amount (the usage amount when reported, +// otherwise the directive's reserved amount) through the consumer. +func (s *MeteredSseSession) Ack(ctx context.Context) (intents.CommitReceipt, error) { + directive, err := s.state.directiveForCommit() + if err != nil { + return intents.CommitReceipt{}, err + } + return s.consumer.CommitDirective(ctx, directive) +} + +// HTTPCommitTransport is a minimal net/http transport for commit endpoints. +// The zero value posts to each directive's CommitURL with the default client. +type HTTPCommitTransport struct { + // Client is the HTTP client. nil uses http.DefaultClient. + Client *http.Client + + // DefaultCommitURL is the commit endpoint used when a directive omits + // CommitURL. + DefaultCommitURL string + + // Authorization is an optional Authorization header value attached to + // every commit request. + Authorization string +} + +// Commit posts the payload as JSON to the directive's commit endpoint and +// decodes the receipt. +func (t *HTTPCommitTransport) Commit( + ctx context.Context, + directive intents.MeteringDirective, + payload intents.CommitPayload, +) (intents.CommitReceipt, error) { + url := t.DefaultCommitURL + if directive.CommitURL != nil { + url = *directive.CommitURL + } + if url == "" { + return intents.CommitReceipt{}, fmt.Errorf("metering directive missing commitUrl") + } + + body, err := json.Marshal(payload) + if err != nil { + return intents.CommitReceipt{}, fmt.Errorf("encode commit payload: %w", err) + } + request, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return intents.CommitReceipt{}, fmt.Errorf("build commit request: %w", err) + } + request.Header.Set("Content-Type", "application/json") + if t.Authorization != "" { + request.Header.Set("Authorization", t.Authorization) + } + + client := t.Client + if client == nil { + client = http.DefaultClient + } + response, err := client.Do(request) + if err != nil { + return intents.CommitReceipt{}, fmt.Errorf("commit request failed: %w", err) + } + defer func() { _ = response.Body.Close() }() + + if response.StatusCode < 200 || response.StatusCode >= 300 { + detail, _ := io.ReadAll(io.LimitReader(response.Body, 4096)) + return intents.CommitReceipt{}, fmt.Errorf( + "commit endpoint returned %d: %s", response.StatusCode, string(detail)) + } + + receipt := intents.CommitReceipt{} + if err := json.NewDecoder(response.Body).Decode(&receipt); err != nil { + return intents.CommitReceipt{}, fmt.Errorf("invalid commit receipt: %w", err) + } + return receipt, nil +} + +// MeteredSseStream reads a metered SSE response body, yielding raw application +// messages and committing the final amount on Ack. +type MeteredSseStream struct { + // consumer is the session consumer that signs and commits the stream's + // final amount on Ack; IntoConsumer hands it back for the next request. + consumer *SessionConsumer + + // body is the SSE response body being drained; the caller retains + // ownership and closes it after the stream is done. + body io.Reader + + // decoder incrementally parses raw body chunks into SSE events. + decoder SseDecoder + + // pending queues decoded application messages not yet returned by Next. + pending []json.RawMessage + + // state folds the metering directive, final usage amount, and done flag + // out of the decoded events. + state meteredStreamState + + // buf is the reusable scratch buffer for body reads. + buf []byte +} + +// NewMeteredSseStream wraps a consumer and an SSE response body, e.g. +// http.Response.Body. The caller retains ownership of the body and closes it +// after the stream is drained. +func NewMeteredSseStream(consumer *SessionConsumer, body io.Reader) *MeteredSseStream { + return &MeteredSseStream{ + consumer: consumer, + body: body, + buf: make([]byte, 4096), + } +} + +// Next returns the next application message, or nil once the stream is done. +func (s *MeteredSseStream) Next() (json.RawMessage, error) { + for { + if len(s.pending) > 0 { + message := s.pending[0] + s.pending = s.pending[1:] + return message, nil + } + if s.state.done { + return nil, nil + } + + n, readErr := s.body.Read(s.buf) + if n > 0 { + events, err := s.decoder.PushChunk(s.buf[:n]) + if err != nil { + return nil, err + } + if err := s.applyEvents(events); err != nil { + return nil, err + } + } + if readErr != nil { + if readErr != io.EOF { + return nil, fmt.Errorf("stream read failed: %w", readErr) + } + events, err := s.decoder.Finish() + if err != nil { + return nil, err + } + if err := s.applyEvents(events); err != nil { + return nil, err + } + s.state.done = true + } + } +} + +func (s *MeteredSseStream) applyEvents(events []SseEvent) error { + for _, event := range events { + message, err := s.state.applyEvent(event) + if err != nil { + return err + } + if message != nil { + s.pending = append(s.pending, message) + } + } + return nil +} + +// Ack drains any remaining events and commits the stream's final amount. +func (s *MeteredSseStream) Ack(ctx context.Context) (intents.CommitReceipt, error) { + if !s.state.done { + for { + message, err := s.Next() + if err != nil { + return intents.CommitReceipt{}, err + } + if message == nil { + break + } + } + } + directive, err := s.state.directiveForCommit() + if err != nil { + return intents.CommitReceipt{}, err + } + return s.consumer.CommitDirective(ctx, directive) +} + +// IntoConsumer returns the wrapped consumer for reuse on the next request. +func (s *MeteredSseStream) IntoConsumer() *SessionConsumer { + return s.consumer +} diff --git a/go/protocols/mpp/client/http_stream_test.go b/go/protocols/mpp/client/http_stream_test.go new file mode 100644 index 000000000..c9be2cae9 --- /dev/null +++ b/go/protocols/mpp/client/http_stream_test.go @@ -0,0 +1,449 @@ +package client + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +func sseEvt(event string, data string) SseEvent { + if event == "" { + return SseEvent{Data: data} + } + return SseEvent{Event: &event, Data: data} +} + +func mustJSON(t *testing.T, value any) string { + t.Helper() + raw, err := json.Marshal(value) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return string(raw) +} + +type delta struct { + Delta string `json:"delta"` // wire "delta": text fragment of a test app message +} + +func decodeDelta(t *testing.T, raw json.RawMessage) delta { + t.Helper() + out := delta{} + if err := json.Unmarshal(raw, &out); err != nil { + t.Fatalf("decode delta: %v", err) + } + return out +} + +func TestSseDecoderHandlesSplitChunks(t *testing.T) { + decoder := SseDecoder{} + events, err := decoder.PushChunk([]byte("event: message\ndata: {\"delta\"")) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 0 { + t.Fatalf("partial chunk dispatched %d events", len(events)) + } + events, err = decoder.PushChunk([]byte(":\"hi\"}\n\n")) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + if events[0].Event == nil || *events[0].Event != "message" { + t.Fatalf("event = %v, want message", events[0].Event) + } + if events[0].Data != `{"delta":"hi"}` { + t.Fatalf("data = %q", events[0].Data) + } +} + +func TestSseDecoderHandlesMetadataCRLFCommentsAndFinish(t *testing.T) { + decoder := SseDecoder{} + events, err := decoder.PushChunk( + []byte(": keepalive\r\nid: evt-1\r\nretry: 250\r\ndata: hello\r\ndata: world\r\n\r\n")) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 1 { + t.Fatalf("events = %d, want 1", len(events)) + } + got := events[0] + if got.Event != nil { + t.Fatalf("event = %v, want nil", got.Event) + } + if got.Data != "hello\nworld" { + t.Fatalf("data = %q, want multi-line join", got.Data) + } + if got.ID == nil || *got.ID != "evt-1" { + t.Fatalf("id = %v, want evt-1", got.ID) + } + if got.Retry == nil || *got.Retry != 250 { + t.Fatalf("retry = %v, want 250", got.Retry) + } + + events, err = decoder.PushChunk([]byte("retry: nope\nunknown\n\n")) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 0 { + t.Fatalf("invalid retry/unknown field dispatched %d events", len(events)) + } + + events, err = decoder.PushChunk([]byte("event: message\ndata: tail")) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 0 { + t.Fatalf("incomplete event dispatched early") + } + events, err = decoder.Finish() + if err != nil { + t.Fatalf("Finish: %v", err) + } + if len(events) != 1 || events[0].Event == nil || *events[0].Event != "message" || events[0].Data != "tail" { + t.Fatalf("finish events = %+v, want trailing message", events) + } +} + +func TestSseDecoderRejectsInvalidUTF8(t *testing.T) { + decoder := SseDecoder{} + _, err := decoder.PushChunk([]byte{0xff}) + if err == nil || !strings.Contains(err.Error(), "valid UTF-8") { + t.Fatalf("error = %v, want UTF-8 rejection", err) + } +} + +func TestParseMeteredSseEvents(t *testing.T) { + meteringDirective := directive("chan", "1000") + parsed, err := ParseMeteredSseEvent(sseEvt("mpp.metering", mustJSON(t, meteringDirective))) + if err != nil { + t.Fatalf("ParseMeteredSseEvent: %v", err) + } + if parsed.Kind != MeteredSseEventMetering || parsed.Metering.Amount != "1000" { + t.Fatalf("parsed = %+v, want metering amount 1000", parsed) + } + + parsed, err = ParseMeteredSseEvent(sseEvt("message", `{"delta":"hello"}`)) + if err != nil { + t.Fatalf("ParseMeteredSseEvent: %v", err) + } + if parsed.Kind != MeteredSseEventMessage { + t.Fatalf("kind = %v, want message", parsed.Kind) + } + if decodeDelta(t, parsed.Message).Delta != "hello" { + t.Fatalf("message = %s", parsed.Message) + } +} + +func TestParseMeteredSseUsageDoneOtherAndErrors(t *testing.T) { + parsed, err := ParseMeteredSseEvent(sseEvt("mpp.usage", `{"deliveryId":"d1","amount":"17"}`)) + if err != nil { + t.Fatalf("ParseMeteredSseEvent: %v", err) + } + if parsed.Kind != MeteredSseEventUsage { + t.Fatalf("kind = %v, want usage", parsed.Kind) + } + amount, err := parsed.Usage.AmountBaseUnits() + if err != nil || amount != 17 { + t.Fatalf("usage amount = %d (%v), want 17", amount, err) + } + + parsed, err = ParseMeteredSseEvent(sseEvt("done", "")) + if err != nil || parsed.Kind != MeteredSseEventDone { + t.Fatalf("done parse = %+v (%v)", parsed, err) + } + parsed, err = ParseMeteredSseEvent(sseEvt("", " [DONE] ")) + if err != nil || parsed.Kind != MeteredSseEventDone { + t.Fatalf("[DONE] sentinel parse = %+v (%v)", parsed, err) + } + parsed, err = ParseMeteredSseEvent(sseEvt("trace", "ignored")) + if err != nil || parsed.Kind != MeteredSseEventOther { + t.Fatalf("other parse = %+v (%v)", parsed, err) + } + + if _, err := ParseMeteredSseEvent(sseEvt("metering", "{")); err == nil { + t.Fatal("invalid metering JSON accepted") + } + if _, err := ParseMeteredSseEvent(sseEvt("usage", "{")); err == nil { + t.Fatal("invalid usage JSON accepted") + } + if _, err := ParseMeteredSseEvent(sseEvt("", "{")); err == nil { + t.Fatal("invalid message JSON accepted") + } +} + +func TestMeteredSseAckUsesFinalUsageAmount(t *testing.T) { + consumer, _ := newConsumer(t, false) + stream := consumer.MeteredSse() + meteringDirective := directive(consumer.Session().ChannelIDString(), "1000") + meteringDirective.DeliveryID = "stream-1" + + message, err := stream.AcceptEvent(sseEvt("mpp.metering", mustJSON(t, meteringDirective))) + if err != nil || message != nil { + t.Fatalf("metering accept = %s (%v)", message, err) + } + message, err = stream.AcceptEvent(sseEvt("message", `{"delta":"hello"}`)) + if err != nil { + t.Fatalf("AcceptEvent: %v", err) + } + if decodeDelta(t, message).Delta != "hello" { + t.Fatalf("message = %s", message) + } + if _, err := stream.AcceptEvent(sseEvt("mpp.usage", `{"deliveryId":"stream-1","amount":"425"}`)); err != nil { + t.Fatalf("usage accept: %v", err) + } + + receipt, err := stream.Ack(context.Background()) + if err != nil { + t.Fatalf("Ack: %v", err) + } + if receipt.Amount != "425" || receipt.Cumulative != "425" { + t.Fatalf("receipt = %+v, want final usage amount 425", receipt) + } + if consumer.Session().Cumulative() != 425 { + t.Fatalf("session cumulative = %d, want 425", consumer.Session().Cumulative()) + } +} + +func TestMeteredSseAckUsesReservedAmountWithoutUsageAndTracksDone(t *testing.T) { + consumer, _ := newConsumer(t, false) + stream := consumer.MeteredSse() + meteringDirective := directive(consumer.Session().ChannelIDString(), "1000") + + if _, err := stream.AcceptEvent(sseEvt("mpp.metering", mustJSON(t, meteringDirective))); err != nil { + t.Fatalf("metering accept: %v", err) + } + if _, err := stream.AcceptEvent(sseEvt("done", "")); err != nil { + t.Fatalf("done accept: %v", err) + } + if !stream.IsDone() { + t.Fatal("stream should be done") + } + + receipt, err := stream.Ack(context.Background()) + if err != nil { + t.Fatalf("Ack: %v", err) + } + if receipt.Amount != "1000" || receipt.Cumulative != "1000" { + t.Fatalf("receipt = %+v, want reserved amount 1000", receipt) + } +} + +func TestMeteredSseReportsMissingMeteringAndUsageMismatch(t *testing.T) { + consumer, _ := newConsumer(t, false) + stream := consumer.MeteredSse() + if _, err := stream.Ack(context.Background()); err == nil || + !strings.Contains(err.Error(), "mpp.metering") { + t.Fatalf("error = %v, want missing metering", err) + } + + stream = consumer.MeteredSse() + meteringDirective := directive(consumer.Session().ChannelIDString(), "1000") + meteringDirective.DeliveryID = "stream-1" + if _, err := stream.AcceptEvent(sseEvt("mpp.metering", mustJSON(t, meteringDirective))); err != nil { + t.Fatalf("metering accept: %v", err) + } + _, err := stream.AcceptEvent(sseEvt("mpp.usage", `{"deliveryId":"other","amount":"1"}`)) + if err == nil || !strings.Contains(err.Error(), "does not match directive") { + t.Fatalf("error = %v, want usage mismatch", err) + } +} + +func TestMeteredSseUsageBeforeDirectiveAccepted(t *testing.T) { + // A usage event may arrive before the directive; it is accepted and the + // amount applies to whichever directive follows. + consumer, _ := newConsumer(t, false) + stream := consumer.MeteredSse() + if _, err := stream.AcceptEvent(sseEvt("mpp.usage", `{"deliveryId":"stream-1","amount":"7"}`)); err != nil { + t.Fatalf("usage-before-directive rejected: %v", err) + } + meteringDirective := directive(consumer.Session().ChannelIDString(), "1000") + meteringDirective.DeliveryID = "stream-1" + if _, err := stream.AcceptEvent(sseEvt("mpp.metering", mustJSON(t, meteringDirective))); err != nil { + t.Fatalf("metering accept: %v", err) + } + receipt, err := stream.Ack(context.Background()) + if err != nil { + t.Fatalf("Ack: %v", err) + } + if receipt.Amount != "7" { + t.Fatalf("receipt amount = %s, want early usage 7", receipt.Amount) + } +} + +func newCommitServer(t *testing.T) (*httptest.Server, *int) { + t.Helper() + commits := 0 + mux := http.NewServeMux() + mux.HandleFunc("/commit", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer sdk-test" { + http.Error(w, "missing auth", http.StatusUnauthorized) + return + } + payload := intents.CommitPayload{} + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + commits++ + receipt := intents.CommitReceipt{ + DeliveryID: payload.DeliveryID, + SessionID: payload.Voucher.Data.ChannelID, + Amount: payload.Voucher.Data.Cumulative, + Cumulative: payload.Voucher.Data.Cumulative, + Status: intents.CommitStatusCommitted, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(receipt) + }) + mux.HandleFunc("/commit-error", func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "commit failed", http.StatusInternalServerError) + }) + mux.HandleFunc("/commit-invalid-json", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("not json")) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + return server, &commits +} + +func TestHTTPCommitTransportSuccessAndErrors(t *testing.T) { + server, commits := newCommitServer(t) + session, _ := newSession(t) + meteringDirective := directive(session.ChannelIDString(), "88") + voucher, err := session.PrepareIncrement(88) + if err != nil { + t.Fatalf("PrepareIncrement: %v", err) + } + payload := intents.CommitPayload{DeliveryID: meteringDirective.DeliveryID, Voucher: voucher} + + transport := &HTTPCommitTransport{ + DefaultCommitURL: server.URL + "/commit", + Authorization: "Bearer sdk-test", + } + receipt, err := transport.Commit(context.Background(), meteringDirective, payload) + if err != nil { + t.Fatalf("Commit: %v", err) + } + if receipt.Cumulative != "88" { + t.Fatalf("receipt cumulative = %s, want 88", receipt.Cumulative) + } + if *commits != 1 { + t.Fatalf("commits = %d, want 1", *commits) + } + + missingURL := &HTTPCommitTransport{} + if _, err := missingURL.Commit(context.Background(), meteringDirective, payload); err == nil || + !strings.Contains(err.Error(), "missing commitUrl") { + t.Fatalf("error = %v, want missing commitUrl", err) + } + + serverError := &HTTPCommitTransport{DefaultCommitURL: server.URL + "/commit-error"} + if _, err := serverError.Commit(context.Background(), meteringDirective, payload); err == nil || + !strings.Contains(err.Error(), "500") { + t.Fatalf("error = %v, want 500 surfaced", err) + } + + invalidJSON := &HTTPCommitTransport{DefaultCommitURL: server.URL + "/commit-invalid-json"} + if _, err := invalidJSON.Commit(context.Background(), meteringDirective, payload); err == nil || + !strings.Contains(err.Error(), "invalid commit receipt") { + t.Fatalf("error = %v, want invalid receipt", err) + } + + // Directive CommitURL takes precedence over the default. + commitURL := server.URL + "/commit" + meteringDirective.CommitURL = &commitURL + routed := &HTTPCommitTransport{ + DefaultCommitURL: server.URL + "/commit-error", + Authorization: "Bearer sdk-test", + } + if _, err := routed.Commit(context.Background(), meteringDirective, payload); err != nil { + t.Fatalf("Commit via directive URL: %v", err) + } +} + +func TestMeteredSseStreamReadsMessagesAndAckDrains(t *testing.T) { + commitServer, commits := newCommitServer(t) + session, _ := newSession(t) + meteringDirective := directive(session.ChannelIDString(), "275") + meteringDirective.DeliveryID = "stream-1" + + streamBody := "event: mpp.metering\ndata: " + mustJSON(t, meteringDirective) + "\n\n" + + "event: message\ndata: {\"delta\":\"first\"}\n\n" + + "event: message\ndata: {\"delta\":\"second\"}\n\n" + + "event: mpp.usage\ndata: {\"deliveryId\":\"stream-1\",\"amount\":\"275\"}\n\n" + + "data: [DONE]" + + transport := &HTTPCommitTransport{ + DefaultCommitURL: commitServer.URL + "/commit", + Authorization: "Bearer sdk-test", + } + consumer := NewSessionConsumer(session, transport) + stream := NewMeteredSseStream(consumer, strings.NewReader(streamBody)) + + first, err := stream.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if decodeDelta(t, first).Delta != "first" { + t.Fatalf("first message = %s", first) + } + + receipt, err := stream.Ack(context.Background()) + if err != nil { + t.Fatalf("Ack: %v", err) + } + if receipt.Amount != "275" || receipt.Cumulative != "275" { + t.Fatalf("receipt = %+v, want 275", receipt) + } + if *commits != 1 { + t.Fatalf("commits = %d, want 1", *commits) + } +} + +func TestMeteredSseStreamCanReturnConsumer(t *testing.T) { + session, _ := newSession(t) + consumer := NewSessionConsumer(session, &recordingTransport{}) + stream := NewMeteredSseStream(consumer, strings.NewReader("data: [DONE]\n\n")) + message, err := stream.Next() + if err != nil { + t.Fatalf("Next: %v", err) + } + if message != nil { + t.Fatalf("message = %s, want done", message) + } + returned := stream.IntoConsumer() + if returned.Session().Cumulative() != 0 { + t.Fatalf("cumulative = %d, want 0", returned.Session().Cumulative()) + } +} + +func TestMeteredSseStreamSurfacesEventErrors(t *testing.T) { + session, _ := newSession(t) + consumer := NewSessionConsumer(session, &recordingTransport{}) + + invalidUTF8 := NewMeteredSseStream(consumer, strings.NewReader("event: message\ndata: \xff\n\n")) + if _, err := invalidUTF8.Next(); err == nil || !strings.Contains(err.Error(), "valid UTF-8") { + t.Fatalf("error = %v, want UTF-8 rejection", err) + } + + badJSON := NewMeteredSseStream(consumer, strings.NewReader("event: metering\ndata: {\n\n")) + if _, err := badJSON.Next(); err == nil || !strings.Contains(err.Error(), "invalid mpp.metering") { + t.Fatalf("error = %v, want metering rejection", err) + } + + // Ack without a metering directive surfaces the missing-directive error + // after draining. + empty := NewMeteredSseStream(consumer, strings.NewReader("data: [DONE]\n\n")) + if _, err := empty.Ack(context.Background()); err == nil || + !strings.Contains(err.Error(), "mpp.metering") { + t.Fatalf("error = %v, want missing metering", err) + } +} diff --git a/go/protocols/mpp/client/payment_channels.go b/go/protocols/mpp/client/payment_channels.go new file mode 100644 index 000000000..c2841edc9 --- /dev/null +++ b/go/protocols/mpp/client/payment_channels.go @@ -0,0 +1,567 @@ +// Client-side helpers for payment-channel open transactions. +// +// These builders turn a parsed session challenge (SessionRequest) into the +// on-chain open transaction and the matching open action payload, applying the +// cross-SDK defaults: fee payer = challenge operator, deposit = challenge cap, +// grace period 900 seconds, random u64 salt, token program resolved from the +// challenge currency (Token-2022 for PYUSD/USDG/CASH), and the +// PendingServerSignature placeholder while the operator broadcasts. +package client + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "strconv" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// DefaultGracePeriodSeconds is the default payment-channel close grace period, +// shared across the language SDK clients. +const DefaultGracePeriodSeconds uint32 = 900 + +// PendingServerSignature is the placeholder open signature used while the +// operator still needs to submit the server-broadcast open transaction. It is +// the base58 form of an all-zero 64-byte signature (64 ones). +const PendingServerSignature = "1111111111111111111111111111111111111111111111111111111111111111" + +// PaymentChannelOpen is a fully derived payment-channel open: every channel +// parameter resolved from the challenge plus the resulting channel PDA. +type PaymentChannelOpen struct { + // ChannelID is the channel PDA derived from payer, payee, mint, + // authorized signer, and salt against ProgramID. + ChannelID solana.PublicKey + + // Payer is the wallet that funds the escrow deposit. + Payer solana.PublicKey + + // Payee is the channel beneficiary, parsed from the challenge recipient. + Payee solana.PublicKey + + // Mint is the SPL token mint, resolved from the challenge currency and + // network. + Mint solana.PublicKey + + // AuthorizedSigner is the ephemeral session key whose Ed25519 signatures + // authorize the channel's cumulative vouchers. + AuthorizedSigner solana.PublicKey + + // Salt is the random u64 that makes the channel PDA unique per open. + Salt uint64 + + // Deposit is the escrow deposit in token base units; it defaults to the + // challenge cap. + Deposit uint64 + + // GracePeriod is the close grace period in seconds (default 900). + GracePeriod uint32 + + // Recipients are the settlement distribution splits derived from the + // challenge splits; empty means no splits. + Recipients []paymentchannels.Distribution + + // TokenProgram is the program owning Mint (Token, or Token-2022 for + // PYUSD/USDG/CASH). + TokenProgram solana.PublicKey + + // ProgramID is the payment-channels program the open targets; defaults + // to the canonical program unless the challenge pins one. + ProgramID solana.PublicKey +} + +// OpenChannelParams converts the derived open into instruction-builder params. +func (o PaymentChannelOpen) OpenChannelParams() paymentchannels.OpenChannelParams { + return paymentchannels.OpenChannelParams{ + Payer: o.Payer, + Payee: o.Payee, + Mint: o.Mint, + AuthorizedSigner: o.AuthorizedSigner, + Salt: o.Salt, + Deposit: o.Deposit, + GracePeriod: o.GracePeriod, + Recipients: o.Recipients, + TokenProgram: o.TokenProgram, + ProgramID: o.ProgramID, + } +} + +// OpenPayload builds the open action payload carrying the derived channel +// parameters with the given submission mode and confirmation signature. +func (o PaymentChannelOpen) OpenPayload(mode intents.SessionMode, signature string) intents.OpenPayload { + return intents.OpenPayloadPaymentChannelWithMode( + mode, + o.ChannelID.String(), + strconv.FormatUint(o.Deposit, 10), + o.Payer.String(), + o.Payee.String(), + o.Mint.String(), + o.Salt, + o.GracePeriod, + o.AuthorizedSigner.String(), + signature, + ) +} + +// PaymentChannelOpenTransaction is a partially signed open transaction ready +// for the operator to fee-payer sign and broadcast. +type PaymentChannelOpenTransaction struct { + // ChannelID is the derived channel PDA the transaction opens. + ChannelID solana.PublicKey + + // Transaction is the standard base64 (with padding) wire encoding of the + // payer-signed legacy transaction, for OpenPayload.Transaction. + Transaction string +} + +// PaymentChannelOpenOptions overrides the challenge-derived open defaults. +// Every field is optional; the zero value applies the cross-SDK defaults. +type PaymentChannelOpenOptions struct { + // Deposit overrides the escrow deposit. Defaults to the challenge cap. + Deposit *uint64 + + // GracePeriod overrides the close grace period. Defaults to + // DefaultGracePeriodSeconds. + GracePeriod *uint32 + + // ProgramID overrides the payment-channels program. Defaults to the + // challenge programId, falling back to the canonical program. + ProgramID *solana.PublicKey + + // Recipients overrides the distribution splits. nil derives them from the + // challenge splits; a non-nil empty slice means no splits. + Recipients []paymentchannels.Distribution + + // Salt overrides the channel salt. Defaults to a random u64. + Salt *uint64 + + // TokenProgram overrides the token program. Defaults to the program + // resolved from the challenge currency (Token-2022 for PYUSD/USDG/CASH). + TokenProgram *solana.PublicKey +} + +// DerivePaymentChannelOpen resolves every open parameter from a session +// challenge: mint and token program from the currency, payee from the +// recipient, deposit from the cap, splits, program id, grace period 900s, and +// a random salt, then derives the channel PDA. +func DerivePaymentChannelOpen( + request intents.SessionRequest, + payer, authorizedSigner solana.PublicKey, + options PaymentChannelOpenOptions, +) (PaymentChannelOpen, error) { + network := "" + if request.Network != nil { + network = *request.Network + } + + mintAddress := paycore.ResolveMint(request.Currency, network) + if mintAddress == "" { + return PaymentChannelOpen{}, fmt.Errorf("session payment channels require an SPL token") + } + mint, err := parseSessionPubkey(mintAddress, "mint") + if err != nil { + return PaymentChannelOpen{}, err + } + payee, err := parseSessionPubkey(request.Recipient, "recipient") + if err != nil { + return PaymentChannelOpen{}, err + } + + deposit := uint64(0) + if options.Deposit != nil { + deposit = *options.Deposit + } else { + deposit, err = strconv.ParseUint(request.Cap, 10, 64) + if err != nil { + return PaymentChannelOpen{}, fmt.Errorf("invalid session cap: %w", err) + } + } + + gracePeriod := DefaultGracePeriodSeconds + if options.GracePeriod != nil { + gracePeriod = *options.GracePeriod + } + + programID := paymentchannels.ProgramPubkey() + switch { + case options.ProgramID != nil: + programID = *options.ProgramID + case request.ProgramID != nil: + programID, err = parseSessionPubkey(*request.ProgramID, "programId") + if err != nil { + return PaymentChannelOpen{}, err + } + } + + var tokenProgram solana.PublicKey + if options.TokenProgram != nil { + tokenProgram = *options.TokenProgram + } else { + tokenProgram, err = parseSessionPubkey( + paycore.DefaultTokenProgramForCurrency(request.Currency, network), "token program") + if err != nil { + return PaymentChannelOpen{}, err + } + } + + recipients := options.Recipients + if recipients == nil { + recipients, err = parseSessionSplits(request.Splits) + if err != nil { + return PaymentChannelOpen{}, err + } + } + + salt := uint64(0) + if options.Salt != nil { + salt = *options.Salt + } else { + salt, err = randomSalt() + if err != nil { + return PaymentChannelOpen{}, err + } + } + + channelID, _, err := paymentchannels.FindChannelPDAForProgram( + payer, payee, mint, authorizedSigner, salt, programID) + if err != nil { + return PaymentChannelOpen{}, err + } + + return PaymentChannelOpen{ + ChannelID: channelID, + Payer: payer, + Payee: payee, + Mint: mint, + AuthorizedSigner: authorizedSigner, + Salt: salt, + Deposit: deposit, + GracePeriod: gracePeriod, + Recipients: recipients, + TokenProgram: tokenProgram, + ProgramID: programID, + }, nil +} + +// BuildOpenPaymentChannelTransactionParams carries the inputs for +// BuildOpenPaymentChannelTransaction. +type BuildOpenPaymentChannelTransactionParams struct { + // Request is the parsed session challenge. + Request intents.SessionRequest + + // Signer is the payer wallet; it partially signs the open transaction. + Signer solanatx.Signer + + // AuthorizedSigner is the ephemeral session voucher key. + AuthorizedSigner solana.PublicKey + + // FeePayer overrides the transaction fee payer. Defaults to the challenge + // operator, which completes the signature and broadcasts. + FeePayer *solana.PublicKey + + // RecentBlockhash is the base58 blockhash for the transaction lifetime. + // Empty echoes the challenge recentBlockhash. + RecentBlockhash string + + // Options overrides the challenge-derived open defaults. + Options PaymentChannelOpenOptions +} + +// BuildOpenPaymentChannelTransaction derives the open from the challenge and +// assembles the legacy open transaction with the operator as fee payer, +// partially signed by the payer, base64-encoded for OpenPayload.Transaction. +func BuildOpenPaymentChannelTransaction(params BuildOpenPaymentChannelTransactionParams) (PaymentChannelOpenTransaction, error) { + var feePayer solana.PublicKey + if params.FeePayer != nil { + feePayer = *params.FeePayer + } else { + var err error + feePayer, err = parseSessionPubkey(params.Request.Operator, "operator") + if err != nil { + return PaymentChannelOpenTransaction{}, err + } + } + open, err := DerivePaymentChannelOpen( + params.Request, params.Signer.PublicKey(), params.AuthorizedSigner, params.Options) + if err != nil { + return PaymentChannelOpenTransaction{}, err + } + blockhash, err := resolveChallengeBlockhash(params.Request, params.RecentBlockhash) + if err != nil { + return PaymentChannelOpenTransaction{}, err + } + return buildOpenPaymentChannelTx(open, params.Signer, feePayer, blockhash) +} + +// PaymentChannelSessionOpen bundles a derived open, the live session tracking +// it, and the open action ready to serialize into a credential. +type PaymentChannelSessionOpen struct { + // Open holds the fully derived channel parameters, including the channel + // PDA the session settles against. + Open PaymentChannelOpen + + // Session is the live tracker that signs cumulative vouchers for the + // opened channel. + Session *ActiveSession + + // Action is the open session action, ready to serialize into the payment + // credential sent back to the server. + Action intents.SessionAction +} + +// PaymentChannelSessionOpenOptions configures CreatePaymentChannelSessionOpener. +type PaymentChannelSessionOpenOptions struct { + // Open overrides the challenge-derived open defaults. + Open PaymentChannelOpenOptions + + // Signature is the open confirmation signature. Defaults to + // PendingServerSignature when the operator broadcasts. + Signature *string + + // Cumulative resumes the session watermark. Defaults to zero. + Cumulative *uint64 + + // ExpiresAt sets the voucher expiry. Defaults to + // intents.DefaultSessionExpiresAt. + ExpiresAt *int64 +} + +// ServerOpenedPaymentChannelSessionOpenOptions configures +// CreateServerOpenedPaymentChannelSessionOpener. +type ServerOpenedPaymentChannelSessionOpenOptions struct { + // Open overrides the challenge-derived open defaults. + Open PaymentChannelOpenOptions + + // Payer overrides the channel payer. Defaults to the challenge operator, + // which funds the escrow when it opens the channel server-side. + Payer *solana.PublicKey + + // Signature is the open confirmation signature. Defaults to + // PendingServerSignature. + Signature *string + + // Cumulative resumes the session watermark. Defaults to zero. + Cumulative *uint64 + + // ExpiresAt sets the voucher expiry. Defaults to + // intents.DefaultSessionExpiresAt. + ExpiresAt *int64 +} + +// CreatePaymentChannelSessionOpener derives a pull/clientVoucher channel open +// from the challenge, builds the payer-signed open transaction against the +// challenge recentBlockhash, and returns the active session plus the open +// action carrying the transaction for the operator to broadcast. +func CreatePaymentChannelSessionOpener( + request intents.SessionRequest, + payerSigner solanatx.Signer, + sessionSigner VoucherSigner, + recentBlockhash string, + options PaymentChannelSessionOpenOptions, +) (PaymentChannelSessionOpen, error) { + if err := ensureClientVoucherPull(request); err != nil { + return PaymentChannelSessionOpen{}, err + } + authorizedSigner := sessionSigner.PublicKey() + feePayer, err := parseSessionPubkey(request.Operator, "operator") + if err != nil { + return PaymentChannelSessionOpen{}, err + } + open, err := DerivePaymentChannelOpen(request, payerSigner.PublicKey(), authorizedSigner, options.Open) + if err != nil { + return PaymentChannelSessionOpen{}, err + } + blockhash, err := resolveChallengeBlockhash(request, recentBlockhash) + if err != nil { + return PaymentChannelSessionOpen{}, err + } + tx, err := buildOpenPaymentChannelTx(open, payerSigner, feePayer, blockhash) + if err != nil { + return PaymentChannelSessionOpen{}, err + } + session := newConfiguredSession(open.ChannelID, sessionSigner, options.Cumulative, options.ExpiresAt) + signature := PendingServerSignature + if options.Signature != nil { + signature = *options.Signature + } + action := intents.NewOpenAction( + open.OpenPayload(intents.SessionModePull, signature).WithTransaction(tx.Transaction)) + + return PaymentChannelSessionOpen{Open: open, Session: session, Action: action}, nil +} + +// CreateServerOpenedPaymentChannelSessionOpener derives a pull/clientVoucher +// channel open the operator funds and broadcasts entirely server-side: no +// transaction is attached and the signature defaults to +// PendingServerSignature. +func CreateServerOpenedPaymentChannelSessionOpener( + request intents.SessionRequest, + sessionSigner VoucherSigner, + options ServerOpenedPaymentChannelSessionOpenOptions, +) (PaymentChannelSessionOpen, error) { + if err := ensureClientVoucherPull(request); err != nil { + return PaymentChannelSessionOpen{}, err + } + var payer solana.PublicKey + if options.Payer != nil { + payer = *options.Payer + } else { + var err error + payer, err = parseSessionPubkey(request.Operator, "operator") + if err != nil { + return PaymentChannelSessionOpen{}, err + } + } + authorizedSigner := sessionSigner.PublicKey() + open, err := DerivePaymentChannelOpen(request, payer, authorizedSigner, options.Open) + if err != nil { + return PaymentChannelSessionOpen{}, err + } + session := newConfiguredSession(open.ChannelID, sessionSigner, options.Cumulative, options.ExpiresAt) + signature := PendingServerSignature + if options.Signature != nil { + signature = *options.Signature + } + action := intents.NewOpenAction(open.OpenPayload(intents.SessionModePull, signature)) + + return PaymentChannelSessionOpen{Open: open, Session: session, Action: action}, nil +} + +// NewEphemeralSessionSigner generates a fresh in-memory keypair to use as a +// session authorizedSigner. Session voucher keys are ephemeral by design: they +// authorize spend only within one channel's deposit, so generating one per +// session is the production path. +func NewEphemeralSessionSigner() (VoucherSigner, error) { + key, err := solana.NewRandomPrivateKey() + if err != nil { + return nil, fmt.Errorf("generate session signer: %w", err) + } + return key, nil +} + +// buildOpenPaymentChannelTx assembles the single-instruction legacy open +// transaction with the given fee payer and partially signs it with the payer +// wallet, leaving the fee-payer slot zeroed for the operator. +func buildOpenPaymentChannelTx( + open PaymentChannelOpen, + payerSigner solanatx.Signer, + feePayer solana.PublicKey, + recentBlockhash solana.Hash, +) (PaymentChannelOpenTransaction, error) { + ix, err := paymentchannels.BuildOpenInstruction(open.OpenChannelParams()) + if err != nil { + return PaymentChannelOpenTransaction{}, err + } + tx, err := solana.NewTransaction( + []solana.Instruction{ix}, + recentBlockhash, + solana.TransactionPayer(feePayer), + ) + if err != nil { + return PaymentChannelOpenTransaction{}, fmt.Errorf("build payment-channel open transaction: %w", err) + } + if err := solanatx.SignTransaction(tx, payerSigner); err != nil { + return PaymentChannelOpenTransaction{}, fmt.Errorf("payment-channel open signing failed: %w", err) + } + encoded, err := solanatx.EncodeTransactionBase64(tx) + if err != nil { + return PaymentChannelOpenTransaction{}, fmt.Errorf("payment-channel open tx serialization failed: %w", err) + } + return PaymentChannelOpenTransaction{ChannelID: open.ChannelID, Transaction: encoded}, nil +} + +// ensureClientVoucherPull rejects challenges that do not advertise pull mode +// with the clientVoucher strategy, the only combination these openers serve. +func ensureClientVoucherPull(request intents.SessionRequest) error { + pull := false + for _, mode := range request.Modes { + if mode == intents.SessionModePull { + pull = true + break + } + } + if !pull { + return fmt.Errorf("session challenge does not advertise pull mode") + } + if request.PullVoucherStrategy == nil || + *request.PullVoucherStrategy != intents.SessionPullVoucherStrategyClientVoucher { + return fmt.Errorf("session challenge does not advertise pull + clientVoucher") + } + return nil +} + +// newConfiguredSession creates the opener's ActiveSession with the optional +// resumed cumulative and voucher expiry applied. +func newConfiguredSession( + channelID solana.PublicKey, + signer VoucherSigner, + cumulative *uint64, + expiresAt *int64, +) *ActiveSession { + watermark := uint64(0) + if cumulative != nil { + watermark = *cumulative + } + expiry := intents.DefaultSessionExpiresAt + if expiresAt != nil { + expiry = *expiresAt + } + return NewActiveSessionWithWatermark(channelID, signer, watermark, expiry) +} + +// resolveChallengeBlockhash parses the explicit blockhash, falling back to the +// challenge recentBlockhash so server-prefetched lifetimes are echoed into the +// open transaction without a second RPC round-trip. +func resolveChallengeBlockhash(request intents.SessionRequest, explicit string) (solana.Hash, error) { + raw := explicit + if raw == "" && request.RecentBlockhash != nil { + raw = *request.RecentBlockhash + } + if raw == "" { + return solana.Hash{}, fmt.Errorf("session open requires a recent blockhash: none provided and the challenge omits recentBlockhash") + } + hash, err := solana.HashFromBase58(raw) + if err != nil { + return solana.Hash{}, fmt.Errorf("invalid recent blockhash %q: %w", raw, err) + } + return hash, nil +} + +// parseSessionSplits converts challenge splits into instruction distributions. +func parseSessionSplits(splits []intents.SessionSplit) ([]paymentchannels.Distribution, error) { + recipients := make([]paymentchannels.Distribution, 0, len(splits)) + for _, split := range splits { + recipient, err := parseSessionPubkey(split.Recipient, "split recipient") + if err != nil { + return nil, err + } + recipients = append(recipients, paymentchannels.Distribution{ + Recipient: recipient, + Bps: split.BPS, + }) + } + return recipients, nil +} + +// parseSessionPubkey parses a base58 pubkey with a labeled error. +func parseSessionPubkey(value, label string) (solana.PublicKey, error) { + key, err := solana.PublicKeyFromBase58(value) + if err != nil { + return solana.PublicKey{}, fmt.Errorf("invalid %s: %w", label, err) + } + return key, nil +} + +// randomSalt draws a random u64 channel salt from the system CSPRNG. +func randomSalt() (uint64, error) { + var buf [8]byte + if _, err := rand.Read(buf[:]); err != nil { + return 0, fmt.Errorf("generate channel salt: %w", err) + } + return binary.LittleEndian.Uint64(buf[:]), nil +} diff --git a/go/protocols/mpp/client/payment_channels_test.go b/go/protocols/mpp/client/payment_channels_test.go new file mode 100644 index 000000000..58e5893b5 --- /dev/null +++ b/go/protocols/mpp/client/payment_channels_test.go @@ -0,0 +1,553 @@ +package client + +import ( + "strings" + "testing" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +func u64ptr(v uint64) *uint64 { return &v } + +func strptr(v string) *string { return &v } + +func testSessionRequest(operator, recipient solana.PublicKey) intents.SessionRequest { + network := "localnet" + strategy := intents.SessionPullVoucherStrategyClientVoucher + return intents.SessionRequest{ + Cap: "1000", + Currency: "USDC", + Network: &network, + Operator: operator.String(), + Recipient: recipient.String(), + Modes: []intents.SessionMode{intents.SessionModePull}, + PullVoucherStrategy: &strategy, + } +} + +func decodeOpenTransaction(t *testing.T, encoded string) *solana.Transaction { + t.Helper() + tx, err := solanatx.DecodeTransactionBase64(encoded) + if err != nil { + t.Fatalf("decode open transaction: %v", err) + } + return tx +} + +func TestDerivePaymentChannelOpenUsesChallengeDefaultsAndSplits(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + splitRecipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + request.Splits = []intents.SessionSplit{{Recipient: splitRecipient.String(), BPS: 10}} + + payer := testutil.NewPrivateKey().PublicKey() + authorizedSigner := testutil.NewPrivateKey().PublicKey() + open, err := DerivePaymentChannelOpen(request, payer, authorizedSigner, PaymentChannelOpenOptions{ + Salt: u64ptr(42), + }) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + + if !open.Payer.Equals(payer) { + t.Fatalf("payer = %s, want %s", open.Payer, payer) + } + if !open.Payee.Equals(recipient) { + t.Fatalf("payee = %s, want challenge recipient", open.Payee) + } + if !open.AuthorizedSigner.Equals(authorizedSigner) { + t.Fatalf("authorizedSigner = %s", open.AuthorizedSigner) + } + if open.Deposit != 1000 { + t.Fatalf("deposit = %d, want challenge cap 1000", open.Deposit) + } + if open.GracePeriod != DefaultGracePeriodSeconds { + t.Fatalf("gracePeriod = %d, want %d", open.GracePeriod, DefaultGracePeriodSeconds) + } + if open.Salt != 42 { + t.Fatalf("salt = %d, want 42", open.Salt) + } + if len(open.Recipients) != 1 || !open.Recipients[0].Recipient.Equals(splitRecipient) || open.Recipients[0].Bps != 10 { + t.Fatalf("recipients = %+v, want challenge split", open.Recipients) + } + // Localnet resolves to the mainnet USDC mint (Surfpool clones mainnet state). + if open.Mint.String() != paycore.USDCMainnetMint { + t.Fatalf("mint = %s, want mainnet USDC", open.Mint) + } + if open.TokenProgram.String() != paycore.TokenProgram { + t.Fatalf("tokenProgram = %s, want SPL Token", open.TokenProgram) + } + if !open.ProgramID.Equals(paymentchannels.ProgramPubkey()) { + t.Fatalf("programID = %s, want canonical", open.ProgramID) + } + expectedChannel, _, err := paymentchannels.FindChannelPDAForProgram( + payer, recipient, open.Mint, authorizedSigner, 42, open.ProgramID) + if err != nil { + t.Fatalf("FindChannelPDAForProgram: %v", err) + } + if !open.ChannelID.Equals(expectedChannel) { + t.Fatalf("channelID = %s, want %s", open.ChannelID, expectedChannel) + } +} + +func TestDerivePaymentChannelOpenHonorsExplicitOptions(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + splitRecipient := testutil.NewPrivateKey().PublicKey() + programID := testutil.NewPrivateKey().PublicKey() + tokenProgram := solana.MustPublicKeyFromBase58(paycore.Token2022Program) + request := testSessionRequest(operator, recipient) + request.Cap = "not-a-number" + request.Splits = []intents.SessionSplit{{Recipient: "not-a-pubkey", BPS: 999}} + + gracePeriod := uint32(12) + open, err := DerivePaymentChannelOpen( + request, + testutil.NewPrivateKey().PublicKey(), + testutil.NewPrivateKey().PublicKey(), + PaymentChannelOpenOptions{ + Deposit: u64ptr(55), + GracePeriod: &gracePeriod, + ProgramID: &programID, + Recipients: []paymentchannels.Distribution{{Recipient: splitRecipient, Bps: 25}}, + Salt: u64ptr(7), + TokenProgram: &tokenProgram, + }, + ) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + + if open.Deposit != 55 { + t.Fatalf("deposit = %d, want explicit 55", open.Deposit) + } + if open.GracePeriod != 12 { + t.Fatalf("gracePeriod = %d, want explicit 12", open.GracePeriod) + } + if !open.ProgramID.Equals(programID) { + t.Fatalf("programID = %s, want explicit", open.ProgramID) + } + if !open.TokenProgram.Equals(tokenProgram) { + t.Fatalf("tokenProgram = %s, want explicit Token-2022", open.TokenProgram) + } + if len(open.Recipients) != 1 || !open.Recipients[0].Recipient.Equals(splitRecipient) || open.Recipients[0].Bps != 25 { + t.Fatalf("recipients = %+v, want explicit", open.Recipients) + } +} + +func TestDerivePaymentChannelOpenResolvesToken2022FromCurrency(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + request.Currency = "PYUSD" + + open, err := DerivePaymentChannelOpen( + request, + testutil.NewPrivateKey().PublicKey(), + testutil.NewPrivateKey().PublicKey(), + PaymentChannelOpenOptions{Salt: u64ptr(1)}, + ) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + if open.TokenProgram.String() != paycore.Token2022Program { + t.Fatalf("tokenProgram = %s, want Token-2022 for PYUSD", open.TokenProgram) + } + if open.Mint.String() != paycore.PYUSDMainnetMint { + t.Fatalf("mint = %s, want mainnet PYUSD", open.Mint) + } +} + +func TestDerivePaymentChannelOpenDefaultsToRandomSalt(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + payer := testutil.NewPrivateKey().PublicKey() + authorizedSigner := testutil.NewPrivateKey().PublicKey() + + first, err := DerivePaymentChannelOpen(request, payer, authorizedSigner, PaymentChannelOpenOptions{}) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + second, err := DerivePaymentChannelOpen(request, payer, authorizedSigner, PaymentChannelOpenOptions{}) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + if first.Salt == second.Salt { + t.Fatalf("two derived opens reused salt %d; want random default", first.Salt) + } + if first.ChannelID.Equals(second.ChannelID) { + t.Fatal("two derived opens reused the channel PDA; want salt-unique channels") + } +} + +func TestDerivePaymentChannelOpenRejectsInvalidChallengeValues(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + payer := testutil.NewPrivateKey().PublicKey() + authorizedSigner := testutil.NewPrivateKey().PublicKey() + + cases := []struct { + name string + mutate func(*intents.SessionRequest) + wantErr string + }{ + {"native SOL", func(r *intents.SessionRequest) { r.Currency = "SOL" }, "SPL token"}, + {"bad cap", func(r *intents.SessionRequest) { r.Cap = "not-a-number" }, "session cap"}, + {"bad recipient", func(r *intents.SessionRequest) { r.Recipient = "not-a-pubkey" }, "recipient"}, + {"bad programId", func(r *intents.SessionRequest) { r.ProgramID = strptr("not-a-program") }, "programId"}, + {"bad split", func(r *intents.SessionRequest) { + r.Splits = []intents.SessionSplit{{Recipient: "not-a-pubkey", BPS: 10}} + }, "split recipient"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + request := testSessionRequest(operator, recipient) + tc.mutate(&request) + _, err := DerivePaymentChannelOpen(request, payer, authorizedSigner, PaymentChannelOpenOptions{}) + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("error = %v, want substring %q", err, tc.wantErr) + } + }) + } +} + +func TestBuildOpenPaymentChannelTransactionPartiallySignsForOperatorBroadcast(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + payerSigner := testutil.NewPrivateKey() + authorizedSigner := testutil.NewPrivateKey().PublicKey() + blockhash := solana.HashFromBytes(testutil.NewPrivateKey().PublicKey().Bytes()) + + built, err := BuildOpenPaymentChannelTransaction(BuildOpenPaymentChannelTransactionParams{ + Request: request, + Signer: payerSigner, + AuthorizedSigner: authorizedSigner, + RecentBlockhash: blockhash.String(), + Options: PaymentChannelOpenOptions{Salt: u64ptr(99)}, + }) + if err != nil { + t.Fatalf("BuildOpenPaymentChannelTransaction: %v", err) + } + + expected, err := DerivePaymentChannelOpen(request, payerSigner.PublicKey(), authorizedSigner, PaymentChannelOpenOptions{ + Salt: u64ptr(99), + }) + if err != nil { + t.Fatalf("DerivePaymentChannelOpen: %v", err) + } + if !built.ChannelID.Equals(expected.ChannelID) { + t.Fatalf("channelID = %s, want %s", built.ChannelID, expected.ChannelID) + } + + tx := decodeOpenTransaction(t, built.Transaction) + if !tx.Message.AccountKeys[0].Equals(operator) { + t.Fatalf("fee payer = %s, want challenge operator", tx.Message.AccountKeys[0]) + } + if len(tx.Message.Instructions) != 1 { + t.Fatalf("instructions = %d, want 1", len(tx.Message.Instructions)) + } + if tx.Message.RecentBlockhash != blockhash { + t.Fatalf("recentBlockhash = %s, want explicit %s", tx.Message.RecentBlockhash, blockhash) + } + + // Fee-payer (operator) slot left zeroed for the server to complete. + if !tx.Signatures[0].IsZero() { + t.Fatalf("operator signature slot should be zeroed, got %s", tx.Signatures[0]) + } + payerIndex := -1 + for i, key := range tx.Message.Signers() { + if key.Equals(payerSigner.PublicKey()) { + payerIndex = i + break + } + } + if payerIndex < 0 { + t.Fatal("payer signer is not a required transaction signer") + } + if tx.Signatures[payerIndex].IsZero() { + t.Fatal("payer signature missing; want partial sign") + } +} + +func TestBuildOpenPaymentChannelTransactionUsesExplicitFeePayerAndChallengeBlockhash(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + explicitFeePayer := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + challengeBlockhash := solana.HashFromBytes(testutil.NewPrivateKey().PublicKey().Bytes()) + request.RecentBlockhash = strptr(challengeBlockhash.String()) + payerSigner := testutil.NewPrivateKey() + + built, err := BuildOpenPaymentChannelTransaction(BuildOpenPaymentChannelTransactionParams{ + Request: request, + Signer: payerSigner, + AuthorizedSigner: testutil.NewPrivateKey().PublicKey(), + FeePayer: &explicitFeePayer, + Options: PaymentChannelOpenOptions{Salt: u64ptr(123)}, + }) + if err != nil { + t.Fatalf("BuildOpenPaymentChannelTransaction: %v", err) + } + tx := decodeOpenTransaction(t, built.Transaction) + if !tx.Message.AccountKeys[0].Equals(explicitFeePayer) { + t.Fatalf("fee payer = %s, want explicit", tx.Message.AccountKeys[0]) + } + if tx.Message.RecentBlockhash != challengeBlockhash { + t.Fatalf("recentBlockhash = %s, want challenge echo %s", tx.Message.RecentBlockhash, challengeBlockhash) + } +} + +func TestBuildOpenPaymentChannelTransactionRequiresABlockhash(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + + _, err := BuildOpenPaymentChannelTransaction(BuildOpenPaymentChannelTransactionParams{ + Request: request, + Signer: testutil.NewPrivateKey(), + AuthorizedSigner: testutil.NewPrivateKey().PublicKey(), + Options: PaymentChannelOpenOptions{Salt: u64ptr(1)}, + }) + if err == nil || !strings.Contains(err.Error(), "recent blockhash") { + t.Fatalf("error = %v, want recent blockhash requirement", err) + } +} + +func TestCreatePaymentChannelSessionOpenerBuildsPullClientVoucherAction(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + payerSigner := testutil.NewPrivateKey() + sessionSigner := testutil.NewPrivateKey() + blockhash := solana.HashFromBytes(testutil.NewPrivateKey().PublicKey().Bytes()) + + opened, err := CreatePaymentChannelSessionOpener( + request, payerSigner, sessionSigner, blockhash.String(), + PaymentChannelSessionOpenOptions{Open: PaymentChannelOpenOptions{Salt: u64ptr(11)}}, + ) + if err != nil { + t.Fatalf("CreatePaymentChannelSessionOpener: %v", err) + } + + if !opened.Session.ChannelID().Equals(opened.Open.ChannelID) { + t.Fatalf("session channel = %s, want %s", opened.Session.ChannelID(), opened.Open.ChannelID) + } + if opened.Action.Open == nil { + t.Fatal("expected open action") + } + payload := opened.Action.Open + if payload.Mode != intents.SessionModePull { + t.Fatalf("mode = %s, want pull", payload.Mode) + } + if payload.ChannelID == nil || *payload.ChannelID != opened.Open.ChannelID.String() { + t.Fatalf("channelId = %v, want %s", payload.ChannelID, opened.Open.ChannelID) + } + if payload.Payer == nil || *payload.Payer != payerSigner.PublicKey().String() { + t.Fatalf("payer = %v, want payer signer", payload.Payer) + } + if payload.AuthorizedSigner != sessionSigner.PublicKey().String() { + t.Fatalf("authorizedSigner = %s, want session signer", payload.AuthorizedSigner) + } + if payload.Signature != PendingServerSignature { + t.Fatalf("signature = %s, want pending placeholder", payload.Signature) + } + if payload.Transaction == nil { + t.Fatal("transaction missing; want payer-signed open tx attached") + } + if payload.TokenAccount != nil || payload.ApprovedAmount != nil || + payload.InitMultiDelegateTx != nil || payload.UpdateDelegationTx != nil { + t.Fatal("pull SPL-delegation fields must be unset for payment-channel opens") + } + tx := decodeOpenTransaction(t, *payload.Transaction) + if !tx.Message.AccountKeys[0].Equals(operator) { + t.Fatalf("fee payer = %s, want operator", tx.Message.AccountKeys[0]) + } +} + +func TestCreatePaymentChannelSessionOpenerAppliesSessionOptions(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + blockhash := solana.HashFromBytes(testutil.NewPrivateKey().PublicKey().Bytes()) + + expiresAt := int64(1234) + opened, err := CreatePaymentChannelSessionOpener( + request, testutil.NewPrivateKey(), testutil.NewPrivateKey(), blockhash.String(), + PaymentChannelSessionOpenOptions{ + Open: PaymentChannelOpenOptions{Salt: u64ptr(19)}, + Signature: strptr("operator-will-fill"), + Cumulative: u64ptr(20), + ExpiresAt: &expiresAt, + }, + ) + if err != nil { + t.Fatalf("CreatePaymentChannelSessionOpener: %v", err) + } + if opened.Action.Open.Signature != "operator-will-fill" { + t.Fatalf("signature = %s, want explicit", opened.Action.Open.Signature) + } + voucher, err := opened.Session.PrepareIncrement(5) + if err != nil { + t.Fatalf("PrepareIncrement: %v", err) + } + if voucher.Data.Cumulative != "25" { + t.Fatalf("cumulative = %s, want resumed 20 + 5", voucher.Data.Cumulative) + } + if voucher.Data.ExpiresAt != 1234 { + t.Fatalf("expiresAt = %d, want explicit 1234", voucher.Data.ExpiresAt) + } +} + +func TestCreateServerOpenedSessionOpenerUsesOperatorPayerWithoutTransaction(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + sessionSigner := testutil.NewPrivateKey() + + opened, err := CreateServerOpenedPaymentChannelSessionOpener( + request, sessionSigner, + ServerOpenedPaymentChannelSessionOpenOptions{Open: PaymentChannelOpenOptions{Salt: u64ptr(13)}}, + ) + if err != nil { + t.Fatalf("CreateServerOpenedPaymentChannelSessionOpener: %v", err) + } + if !opened.Open.Payer.Equals(operator) { + t.Fatalf("payer = %s, want operator", opened.Open.Payer) + } + payload := opened.Action.Open + if payload == nil { + t.Fatal("expected open action") + } + if payload.Mode != intents.SessionModePull { + t.Fatalf("mode = %s, want pull", payload.Mode) + } + if payload.Payer == nil || *payload.Payer != request.Operator { + t.Fatalf("payer = %v, want operator", payload.Payer) + } + if payload.AuthorizedSigner != sessionSigner.PublicKey().String() { + t.Fatalf("authorizedSigner = %s", payload.AuthorizedSigner) + } + if payload.Signature != PendingServerSignature { + t.Fatalf("signature = %s, want pending placeholder", payload.Signature) + } + if payload.Transaction != nil { + t.Fatal("transaction must be unset for server-opened channels") + } + if payload.TokenAccount != nil || payload.ApprovedAmount != nil { + t.Fatal("pull SPL-delegation fields must be unset") + } +} + +func TestSessionOpenerRejectsNonPullChallenge(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + request.Modes = []intents.SessionMode{intents.SessionModePush} + request.PullVoucherStrategy = nil + + _, err := CreateServerOpenedPaymentChannelSessionOpener( + request, testutil.NewPrivateKey(), ServerOpenedPaymentChannelSessionOpenOptions{}) + if err == nil || !strings.Contains(err.Error(), "pull mode") { + t.Fatalf("error = %v, want pull-mode rejection", err) + } +} + +func TestSessionOpenerRejectsOperatedVoucherPullChallenge(t *testing.T) { + operator := testutil.NewPrivateKey().PublicKey() + recipient := testutil.NewPrivateKey().PublicKey() + request := testSessionRequest(operator, recipient) + operated := intents.SessionPullVoucherStrategyOperatedVoucher + request.PullVoucherStrategy = &operated + + _, err := CreateServerOpenedPaymentChannelSessionOpener( + request, testutil.NewPrivateKey(), ServerOpenedPaymentChannelSessionOpenOptions{}) + if err == nil || !strings.Contains(err.Error(), "does not advertise pull + clientVoucher") { + t.Fatalf("error = %v, want operated-voucher rejection", err) + } +} + +func TestNewEphemeralSessionSignerGeneratesDistinctKeys(t *testing.T) { + a, err := NewEphemeralSessionSigner() + if err != nil { + t.Fatalf("NewEphemeralSessionSigner: %v", err) + } + b, err := NewEphemeralSessionSigner() + if err != nil { + t.Fatalf("NewEphemeralSessionSigner: %v", err) + } + if a.PublicKey().Equals(b.PublicKey()) { + t.Fatal("two ephemeral session signers share a public key") + } + preimage := []byte("test-message") + sig, err := a.Sign(preimage) + if err != nil { + t.Fatalf("Sign: %v", err) + } + if sig.IsZero() { + t.Fatal("ephemeral signer produced a zero signature") + } +} + +func TestSessionOpenerErrorPaths(t *testing.T) { + recipient := testutil.NewPrivateKey().PublicKey() + operator := testutil.NewPrivateKey().PublicKey() + blockhash := solana.HashFromBytes(testutil.NewPrivateKey().PublicKey().Bytes()) + + badOperator := testSessionRequest(operator, recipient) + badOperator.Operator = "not-a-pubkey" + if _, err := CreatePaymentChannelSessionOpener( + badOperator, testutil.NewPrivateKey(), testutil.NewPrivateKey(), blockhash.String(), + PaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "operator") { + t.Fatalf("error = %v, want invalid operator", err) + } + if _, err := CreateServerOpenedPaymentChannelSessionOpener( + badOperator, testutil.NewPrivateKey(), + ServerOpenedPaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "operator") { + t.Fatalf("error = %v, want invalid operator", err) + } + + solCurrency := testSessionRequest(operator, recipient) + solCurrency.Currency = "SOL" + if _, err := CreatePaymentChannelSessionOpener( + solCurrency, testutil.NewPrivateKey(), testutil.NewPrivateKey(), blockhash.String(), + PaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "SPL token") { + t.Fatalf("error = %v, want SPL token requirement", err) + } + if _, err := CreateServerOpenedPaymentChannelSessionOpener( + solCurrency, testutil.NewPrivateKey(), + ServerOpenedPaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "SPL token") { + t.Fatalf("error = %v, want SPL token requirement", err) + } + + noBlockhash := testSessionRequest(operator, recipient) + if _, err := CreatePaymentChannelSessionOpener( + noBlockhash, testutil.NewPrivateKey(), testutil.NewPrivateKey(), "", + PaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "recent blockhash") { + t.Fatalf("error = %v, want blockhash requirement", err) + } + if _, err := CreatePaymentChannelSessionOpener( + noBlockhash, testutil.NewPrivateKey(), testutil.NewPrivateKey(), "!!bad-base58!!", + PaymentChannelSessionOpenOptions{}); err == nil || !strings.Contains(err.Error(), "invalid recent blockhash") { + t.Fatalf("error = %v, want invalid blockhash", err) + } + + badOperatorTx := testSessionRequest(operator, recipient) + badOperatorTx.Operator = "not-a-pubkey" + if _, err := BuildOpenPaymentChannelTransaction(BuildOpenPaymentChannelTransactionParams{ + Request: badOperatorTx, + Signer: testutil.NewPrivateKey(), + AuthorizedSigner: testutil.NewPrivateKey().PublicKey(), + RecentBlockhash: blockhash.String(), + }); err == nil || !strings.Contains(err.Error(), "operator") { + t.Fatalf("error = %v, want invalid operator", err) + } +} diff --git a/go/protocols/mpp/client/session.go b/go/protocols/mpp/client/session.go new file mode 100644 index 000000000..ce1e51296 --- /dev/null +++ b/go/protocols/mpp/client/session.go @@ -0,0 +1,364 @@ +// Client-side session intent implementation. +// +// ActiveSession tracks an open payment channel and signs cumulative vouchers +// for each metered API call. Vouchers are Ed25519-signed over the on-chain +// Borsh voucher layout used by the payment-channels program, so the same bytes +// the server verifies on the HTTP credential are the bytes the on-chain settle +// instruction consumes. +// +// Scope is client-only PUSH (payment-channel) plus pull/clientVoucher, both +// served by the challenge-driven openers in payment_channels.go: the client +// signs cumulative vouchers off-chain over a payment channel the operator +// settles. Pull/operatedVoucher (the multi-delegator program), the SPL +// approve-delegation builder for non-channel pull opens, and the server +// verification path are out of scope. +// +// The language SDKs produce byte-identical voucher signatures and credentials; +// the cross-language interop harness pins this behavior. +package client + +import ( + "fmt" + "strconv" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// VoucherSigner signs the 48-byte voucher preimage with the ephemeral session +// key. It is the minimal Ed25519 message-signing surface shared with the +// charge client (solanatx.Signer satisfies it), so memory signers, hardware +// wallets, and cloud KMS backends all work unchanged. +type VoucherSigner = solanatx.Signer + +// ActiveSession tracks the client-side state of an active payment session. +// +// It holds the session signing key and advances the cumulative watermark with +// each signed voucher. Vouchers are cumulative high-water marks: each one MUST +// strictly exceed the previous, and the signer's public key is the +// authorizedSigner passed to the server in the open action. +// +// ActiveSession is not safe for concurrent use; serialize access from one +// goroutine or guard it with a mutex. +type ActiveSession struct { + // channelID is the on-chain channel PDA the vouchers settle against; its + // raw 32 bytes lead the 48-byte voucher preimage. + channelID solana.PublicKey + + // cumulative is the watermark in token base units: the cumulative total + // covered by the last recorded voucher, not a per-request delta. + cumulative uint64 + + // nonce counts recorded vouchers; it is carried in the voucher JSON for + // server bookkeeping but is not part of the signed 48-byte preimage. + nonce uint64 + + // expiresAt is the voucher expiry as Unix epoch seconds, encoded + // little-endian into the final 8 bytes of each voucher preimage. + expiresAt int64 + + // signer is the ephemeral session key (the channel authorizedSigner) + // that Ed25519-signs voucher preimages. + signer VoucherSigner +} + +// NewActiveSession creates a session tracker for the channel obtained after +// opening, signing vouchers with signer until DefaultSessionExpiresAt. +func NewActiveSession(channelID solana.PublicKey, signer VoucherSigner) *ActiveSession { + return NewActiveSessionAt(channelID, signer, intents.DefaultSessionExpiresAt) +} + +// NewActiveSessionAt creates a session tracker with an explicit voucher expiry. +func NewActiveSessionAt(channelID solana.PublicKey, signer VoucherSigner, expiresAt int64) *ActiveSession { + return &ActiveSession{ + channelID: channelID, + expiresAt: expiresAt, + signer: signer, + } +} + +// NewActiveSessionWithWatermark creates a session tracker resumed at a known +// settled cumulative watermark, e.g. when re-attaching to a channel the server +// already holds vouchers for. Only the cumulative watermark is resumed; the +// nonce starts at zero. +func NewActiveSessionWithWatermark(channelID solana.PublicKey, signer VoucherSigner, cumulative uint64, expiresAt int64) *ActiveSession { + session := NewActiveSessionAt(channelID, signer, expiresAt) + session.cumulative = cumulative + return session +} + +// SetExpiresAt updates the expiry timestamp used for subsequent vouchers. +func (s *ActiveSession) SetExpiresAt(expiresAt int64) { s.expiresAt = expiresAt } + +// Cumulative returns the current cumulative watermark (base units). +func (s *ActiveSession) Cumulative() uint64 { return s.cumulative } + +// Nonce returns the current voucher nonce counter. +func (s *ActiveSession) Nonce() uint64 { return s.nonce } + +// ExpiresAt returns the expiry timestamp applied to new vouchers. +func (s *ActiveSession) ExpiresAt() int64 { return s.expiresAt } + +// ChannelID returns the on-chain channel address. +func (s *ActiveSession) ChannelID() solana.PublicKey { return s.channelID } + +// ChannelIDString returns the channel address as base58. +func (s *ActiveSession) ChannelIDString() string { return s.channelID.String() } + +// AuthorizedSigner returns the session signing key as base58, for the open +// action payload. +func (s *ActiveSession) AuthorizedSigner() string { return s.signer.PublicKey().String() } + +// SignVoucher signs a voucher with an absolute cumulative amount and advances +// the local watermark. cumulative MUST strictly exceed the current watermark. +func (s *ActiveSession) SignVoucher(cumulative uint64) (intents.SignedVoucher, error) { + voucher, err := s.PrepareVoucher(cumulative) + if err != nil { + return intents.SignedVoucher{}, err + } + if err := s.RecordVoucher(voucher); err != nil { + return intents.SignedVoucher{}, err + } + return voucher, nil +} + +// SignIncrement signs a voucher adding amount to the current cumulative. +func (s *ActiveSession) SignIncrement(amount uint64) (intents.SignedVoucher, error) { + next, err := addCumulative(s.cumulative, amount) + if err != nil { + return intents.SignedVoucher{}, err + } + return s.SignVoucher(next) +} + +// PrepareVoucher signs a voucher without advancing the local watermark. +// +// This keeps ack/commit transports safe to retry: a failed commit can be +// retried with the same cumulative amount without the local state drifting +// ahead of the server. cumulative MUST strictly exceed the current watermark. +func (s *ActiveSession) PrepareVoucher(cumulative uint64) (intents.SignedVoucher, error) { + if cumulative <= s.cumulative { + return intents.SignedVoucher{}, fmt.Errorf( + "voucher cumulative %d must exceed current watermark %d", cumulative, s.cumulative) + } + + nonce := s.nonce + 1 + data := intents.VoucherData{ + ChannelID: s.ChannelIDString(), + Cumulative: strconv.FormatUint(cumulative, 10), + ExpiresAt: s.expiresAt, + Nonce: &nonce, + } + + preimage, err := paymentchannels.VoucherMessageBytes(s.channelID, cumulative, s.expiresAt) + if err != nil { + return intents.SignedVoucher{}, fmt.Errorf("voucher preimage: %w", err) + } + sig, err := s.signer.Sign(preimage) + if err != nil { + return intents.SignedVoucher{}, fmt.Errorf("sign voucher: %w", err) + } + + return intents.SignedVoucher{Data: data, Signature: sig.String()}, nil +} + +// PrepareIncrement signs a voucher adding amount to the current cumulative +// without advancing the watermark. +func (s *ActiveSession) PrepareIncrement(amount uint64) (intents.SignedVoucher, error) { + next, err := addCumulative(s.cumulative, amount) + if err != nil { + return intents.SignedVoucher{}, err + } + return s.PrepareVoucher(next) +} + +// RecordVoucher advances the local watermark to a prepared voucher the server +// has accepted. The voucher cumulative MUST strictly exceed the current +// watermark; the nonce advances to the larger of the current nonce and the +// voucher nonce (or +1 when the voucher omits a nonce). +func (s *ActiveSession) RecordVoucher(voucher intents.SignedVoucher) error { + if voucher.Data.ChannelID != s.ChannelIDString() { + return fmt.Errorf( + "voucher channel %s does not match active session %s", + voucher.Data.ChannelID, s.ChannelIDString()) + } + cumulative, err := parseCumulative(voucher.Data.Cumulative) + if err != nil { + return err + } + if cumulative <= s.cumulative { + return fmt.Errorf( + "voucher cumulative %d must exceed current watermark %d", cumulative, s.cumulative) + } + s.cumulative = cumulative + candidate := s.nonce + 1 + if voucher.Data.Nonce != nil && *voucher.Data.Nonce > candidate { + candidate = *voucher.Data.Nonce + } + s.nonce = candidate + return nil +} + +// ReconcileSettled reconciles the local watermark to a server-settled +// cumulative, e.g. the Cumulative of a replayed commit receipt. It advances to +// settled when that is ahead of the current watermark and never regresses, so +// retrying a delivery the server already accepted (lost-response case) catches +// the client up without recording the freshly prepared higher voucher. +// +// When it advances, the request nonce also advances by one, mirroring the +// RecordVoucher accounting for the delivery the server settled, so the next +// prepared voucher does not reuse the already-settled nonce. +func (s *ActiveSession) ReconcileSettled(settled uint64) { + if settled > s.cumulative { + s.cumulative = settled + s.nonce++ + } +} + +// VoucherAction signs a fresh increment and wraps it as a voucher action. +func (s *ActiveSession) VoucherAction(amount uint64) (intents.SessionAction, error) { + voucher, err := s.SignIncrement(amount) + if err != nil { + return intents.SessionAction{}, err + } + return intents.NewVoucherAction(intents.VoucherPayload{Voucher: voucher}), nil +} + +// CloseAction builds a cooperative close action. When finalIncrement > 0 it +// signs one last voucher for the remaining balance before closing; otherwise +// the close carries no voucher. +func (s *ActiveSession) CloseAction(finalIncrement uint64) (intents.SessionAction, error) { + payload := intents.ClosePayload{ChannelID: s.ChannelIDString()} + if finalIncrement > 0 { + voucher, err := s.SignIncrement(finalIncrement) + if err != nil { + return intents.SessionAction{}, err + } + payload.Voucher = &voucher + } + return intents.NewCloseAction(payload), nil +} + +// OpenAction builds a push-mode open action. Call this after the on-chain open +// transaction has confirmed; the session channel ID MUST match the confirmed +// channel address. +func (s *ActiveSession) OpenAction(deposit uint64, openTxSignature string) intents.SessionAction { + return intents.NewOpenAction(intents.OpenPayloadPush( + s.ChannelIDString(), + strconv.FormatUint(deposit, 10), + s.AuthorizedSigner(), + openTxSignature, + )) +} + +// OpenPaymentChannelAction builds a payment-channel push open action carrying +// the full channel parameters. +func (s *ActiveSession) OpenPaymentChannelAction( + deposit uint64, + payer, payee, mint string, + salt uint64, + gracePeriod uint32, + openTxSignature string, +) intents.SessionAction { + return s.OpenPaymentChannelActionWithMode( + intents.SessionModePush, deposit, payer, payee, mint, salt, gracePeriod, openTxSignature) +} + +// OpenPaymentChannelActionWithMode builds a payment-channel open action with an +// explicit submission mode (push, or pull when the operator broadcasts). +func (s *ActiveSession) OpenPaymentChannelActionWithMode( + mode intents.SessionMode, + deposit uint64, + payer, payee, mint string, + salt uint64, + gracePeriod uint32, + openTxSignature string, +) intents.SessionAction { + return intents.NewOpenAction(intents.OpenPayloadPaymentChannelWithMode( + mode, + s.ChannelIDString(), + strconv.FormatUint(deposit, 10), + payer, payee, mint, + salt, gracePeriod, + s.AuthorizedSigner(), + openTxSignature, + )) +} + +// OpenPullAction builds a pull-mode (SPL delegation) open action. The session +// channel ID is used as the token account, so callers should construct the +// ActiveSession with the delegated token account pubkey as the channel ID. +func (s *ActiveSession) OpenPullAction(approvedAmount uint64, owner, approveTxSignature string) intents.SessionAction { + return intents.NewOpenAction(intents.OpenPayloadPull( + s.ChannelIDString(), + strconv.FormatUint(approvedAmount, 10), + owner, + s.AuthorizedSigner(), + approveTxSignature, + )) +} + +// TopUpAction builds a top-up action after a top-up transaction confirms. +func (s *ActiveSession) TopUpAction(newDeposit uint64, topupTxSignature string) intents.SessionAction { + return intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: s.ChannelIDString(), + NewDeposit: strconv.FormatUint(newDeposit, 10), + Signature: topupTxSignature, + }) +} + +// SerializeSessionCredential builds an Authorization header value for a session +// action, echoing the challenge and JCS-canonicalizing the credential. The +// result is "Payment ", the same credential +// framing used for every payment authorization on the wire. +func SerializeSessionCredential(challenge core.PaymentChallenge, action intents.SessionAction) (string, error) { + credential, err := core.NewPaymentCredential(challenge.ToEcho(), action) + if err != nil { + return "", err + } + return core.FormatAuthorization(credential) +} + +// ParseSessionChallenge parses a WWW-Authenticate header value into the +// challenge and the decoded session request. +// +// It rejects non-session intents so callers do not accidentally treat a charge +// challenge as a session. +func ParseSessionChallenge(header string) (core.PaymentChallenge, intents.SessionRequest, error) { + challenge, err := core.ParseWWWAuthenticate(header) + if err != nil { + return core.PaymentChallenge{}, intents.SessionRequest{}, err + } + if !challenge.Intent.IsSession() { + return core.PaymentChallenge{}, intents.SessionRequest{}, fmt.Errorf( + "challenge intent %q is not a session", challenge.Intent) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + return core.PaymentChallenge{}, intents.SessionRequest{}, fmt.Errorf("decode session request: %w", err) + } + return challenge, request, nil +} + +// addCumulative adds amount to current, rejecting u64 overflow so a wrapped +// watermark can never be signed. +func addCumulative(current, amount uint64) (uint64, error) { + next := current + amount + if next < current { + return 0, fmt.Errorf("voucher cumulative overflows u64: %d + %d", current, amount) + } + return next, nil +} + +// parseCumulative parses a decimal voucher cumulative into base units. +func parseCumulative(raw string) (uint64, error) { + value, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid voucher cumulative %q", raw) + } + return value, nil +} diff --git a/go/protocols/mpp/client/session_consumer.go b/go/protocols/mpp/client/session_consumer.go new file mode 100644 index 000000000..6194bb13e --- /dev/null +++ b/go/protocols/mpp/client/session_consumer.go @@ -0,0 +1,168 @@ +// Kafka-style client helpers for metered session deliveries. +// +// SessionConsumer wraps an ActiveSession so applications can process delivered +// messages and call Ack/Commit instead of manually signing and posting +// vouchers. A failed commit never advances the local watermark, so the same +// directive can be retried safely. +package client + +import ( + "context" + "fmt" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// CommitTransport sends a commit payload to the server and returns its receipt. +// +// HTTP clients, queues, and in-process tests all implement this. The directive +// is passed alongside the payload so transports can use CommitURL, Proof, or +// other routing hints without repeating them in the signed commit body. +type CommitTransport interface { + Commit(ctx context.Context, directive intents.MeteringDirective, payload intents.CommitPayload) (intents.CommitReceipt, error) +} + +// SessionConsumer is a client-side consumer for session-metered deliveries. +// +// SessionConsumer is not safe for concurrent use; the underlying ActiveSession +// watermark is advanced under Commit. +type SessionConsumer struct { + // session is the wrapped ActiveSession; its cumulative watermark only + // advances after the transport reports a successful commit. + session *ActiveSession + + // transport posts signed commit payloads to the server and returns its + // receipts. + transport CommitTransport +} + +// NewSessionConsumer wraps a session and a commit transport. +func NewSessionConsumer(session *ActiveSession, transport CommitTransport) *SessionConsumer { + return &SessionConsumer{session: session, transport: transport} +} + +// Session returns the wrapped session. +func (c *SessionConsumer) Session() *ActiveSession { return c.session } + +// CommitDirective signs a voucher for the directive amount, sends it through +// the transport, and advances the local watermark only on success. It rejects +// directives whose session does not match, whose amount is not a valid base-unit +// integer, or whose amount is zero. +func (c *SessionConsumer) CommitDirective(ctx context.Context, directive intents.MeteringDirective) (intents.CommitReceipt, error) { + if err := c.validateDirective(directive); err != nil { + return intents.CommitReceipt{}, err + } + amount, err := directive.AmountBaseUnits() + if err != nil { + return intents.CommitReceipt{}, err + } + if amount == 0 { + return intents.CommitReceipt{}, fmt.Errorf("metered delivery amount must be greater than zero") + } + + voucher, err := c.session.PrepareIncrement(amount) + if err != nil { + return intents.CommitReceipt{}, err + } + payload := intents.CommitPayload{DeliveryID: directive.DeliveryID, Voucher: voucher} + + receipt, err := c.transport.Commit(ctx, directive, payload) + if err != nil { + return intents.CommitReceipt{}, err + } + // A replayed receipt means the server already settled this delivery, so its + // Cumulative is the authoritative settled position. Recording the freshly + // prepared (higher) voucher would push the local watermark past the server's + // state and let a later close sign for more than was agreed; skipping it + // entirely would instead leave the watermark behind the server when the + // original response was lost, so the next delivery signs a non-monotonic + // cumulative. Reconcile to the receipt cumulative on replay (never + // regressing); record the voucher on a fresh committed receipt. + switch receipt.Status { + case intents.CommitStatusReplayed: + settled, perr := parseCumulative(receipt.Cumulative) + if perr != nil { + return intents.CommitReceipt{}, fmt.Errorf("invalid replayed receipt cumulative: %w", perr) + } + // The server is untrusted: clamp to the voucher just prepared in this + // call. An honest lost-response replay settles at or below it (the + // session is single-threaded), so a server reporting a higher cumulative + // cannot push the watermark past what the client actually signed — + // otherwise the next voucher would over-authorize up to the deposit. + prepared, perr := parseCumulative(voucher.Data.Cumulative) + if perr != nil { + return intents.CommitReceipt{}, fmt.Errorf("invalid prepared voucher cumulative: %w", perr) + } + if settled > prepared { + settled = prepared + } + c.session.ReconcileSettled(settled) + case intents.CommitStatusCommitted: + if err := c.session.RecordVoucher(voucher); err != nil { + return intents.CommitReceipt{}, err + } + default: + // A malformed or unknown status must not advance local state. + return intents.CommitReceipt{}, fmt.Errorf("unexpected commit receipt status: %q", receipt.Status) + } + return receipt, nil +} + +func (c *SessionConsumer) validateDirective(directive intents.MeteringDirective) error { + channelID := c.session.ChannelIDString() + if directive.SessionID != channelID { + return fmt.Errorf( + "metered delivery session %s does not match active session %s", directive.SessionID, channelID) + } + return nil +} + +// Accept validates an envelope and returns a delivery handle exposing Ack and +// Commit. The directive is validated up front so a mismatched session is +// rejected before the application processes the payload. +func Accept[P any](c *SessionConsumer, envelope intents.MeteredEnvelope[P]) (*MeteredDelivery[P], error) { + if err := c.validateDirective(envelope.Metering); err != nil { + return nil, err + } + return &MeteredDelivery[P]{ + consumer: c, + payload: envelope.Payload, + metering: envelope.Metering, + }, nil +} + +// MeteredDelivery is a delivered payload paired with its metering directive. +// Call Ack (or its Commit alias) after the application has processed Payload. +type MeteredDelivery[P any] struct { + // consumer is the consumer that accepted the delivery; Ack commits the + // directive amount through it. + consumer *SessionConsumer + + // payload is the delivered application payload. + payload P + + // metering is the directive pricing this delivery; Ack signs a voucher + // for its amount. + metering intents.MeteringDirective +} + +// Payload returns the delivered payload. +func (d *MeteredDelivery[P]) Payload() P { return d.payload } + +// Metering returns the metering directive that accompanied the payload. +func (d *MeteredDelivery[P]) Metering() intents.MeteringDirective { return d.metering } + +// Ack signs and commits a voucher for the directive amount. +func (d *MeteredDelivery[P]) Ack(ctx context.Context) (intents.CommitReceipt, error) { + return d.consumer.CommitDirective(ctx, d.metering) +} + +// Commit is an alias for Ack. +func (d *MeteredDelivery[P]) Commit(ctx context.Context) (intents.CommitReceipt, error) { + return d.Ack(ctx) +} + +// IntoParts returns the payload and metering directive without committing. +func (d *MeteredDelivery[P]) IntoParts() (P, intents.MeteringDirective) { + return d.payload, d.metering +} diff --git a/go/protocols/mpp/client/session_consumer_test.go b/go/protocols/mpp/client/session_consumer_test.go new file mode 100644 index 000000000..24dbc4520 --- /dev/null +++ b/go/protocols/mpp/client/session_consumer_test.go @@ -0,0 +1,383 @@ +package client + +import ( + "context" + "errors" + "strings" + "sync" + "testing" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// recordingTransport captures committed payloads and can be made to fail +// on demand. It also models the +// server-side delivery dedupe: a deliveryId already committed returns a +// "replayed" receipt carrying the originally committed cumulative, so the +// client does not double-count. +type recordingTransport struct { + mu sync.Mutex // guards commits and seen + commits []intents.CommitPayload // every payload accepted as a fresh commit + fail bool // when true, every Commit errors without recording + + // seen maps a deliveryId to the cumulative the server first committed for + // it. A repeat deliveryId is acknowledged as replayed. + seen map[string]string +} + +func (r *recordingTransport) Commit(_ context.Context, directive intents.MeteringDirective, payload intents.CommitPayload) (intents.CommitReceipt, error) { + if r.fail { + return intents.CommitReceipt{}, errors.New("commit failed") + } + r.mu.Lock() + defer r.mu.Unlock() + if r.seen != nil { + if cumulative, ok := r.seen[directive.DeliveryID]; ok { + return intents.CommitReceipt{ + DeliveryID: directive.DeliveryID, + SessionID: directive.SessionID, + Amount: directive.Amount, + Cumulative: cumulative, + Status: intents.CommitStatusReplayed, + }, nil + } + r.seen[directive.DeliveryID] = payload.Voucher.Data.Cumulative + } + r.commits = append(r.commits, payload) + return intents.CommitReceipt{ + DeliveryID: directive.DeliveryID, + SessionID: directive.SessionID, + Amount: directive.Amount, + Cumulative: payload.Voucher.Data.Cumulative, + Status: intents.CommitStatusCommitted, + }, nil +} + +func (r *recordingTransport) count() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.commits) +} + +func newConsumer(t *testing.T, fail bool) (*SessionConsumer, *recordingTransport) { + t.Helper() + session, _ := newSession(t) + transport := &recordingTransport{fail: fail} + return NewSessionConsumer(session, transport), transport +} + +func directive(sessionID, amount string) intents.MeteringDirective { + return intents.MeteringDirective{ + DeliveryID: "d1", + SessionID: sessionID, + Amount: amount, + Currency: "USDC", + Sequence: 1, + ExpiresAt: intents.DefaultSessionExpiresAt, + } +} + +func TestSessionConsumerSessionAccessor(t *testing.T) { + consumer, _ := newConsumer(t, false) + if consumer.Session() == nil { + t.Fatal("expected a session") + } +} + +func TestConsumerAckAdvancesWatermark(t *testing.T) { + consumer, transport := newConsumer(t, false) + sid := consumer.Session().ChannelIDString() + + delivery, err := Accept(consumer, intents.MeteredEnvelope[string]{ + Payload: "work", + Metering: directive(sid, "250"), + }) + if err != nil { + t.Fatalf("accept: %v", err) + } + if delivery.Payload() != "work" { + t.Fatalf("payload: %q", delivery.Payload()) + } + if delivery.Metering().Amount != "250" { + t.Fatalf("metering amount: %q", delivery.Metering().Amount) + } + + receipt, err := delivery.Ack(context.Background()) + if err != nil { + t.Fatalf("ack: %v", err) + } + if receipt.Cumulative != "250" { + t.Fatalf("cumulative: %q", receipt.Cumulative) + } + if consumer.Session().Cumulative() != 250 { + t.Fatalf("session cumulative: %d", consumer.Session().Cumulative()) + } + if transport.count() != 1 { + t.Fatalf("commits: %d", transport.count()) + } +} + +func TestConsumerCommitAliasAndIntoParts(t *testing.T) { + consumer, _ := newConsumer(t, false) + consumer.Session().SetExpiresAt(1234) + sid := consumer.Session().ChannelIDString() + + delivery, err := Accept(consumer, intents.MeteredEnvelope[string]{Payload: "payload", Metering: directive(sid, "50")}) + if err != nil { + t.Fatalf("accept: %v", err) + } + receipt, err := delivery.Commit(context.Background()) + if err != nil { + t.Fatalf("commit: %v", err) + } + if receipt.Cumulative != "50" { + t.Fatalf("cumulative: %q", receipt.Cumulative) + } + + second, err := Accept(consumer, intents.MeteredEnvelope[string]{Payload: "second", Metering: directive(sid, "75")}) + if err != nil { + t.Fatalf("accept second: %v", err) + } + payload, metering := second.IntoParts() + if payload != "second" { + t.Fatalf("payload: %q", payload) + } + if metering.Amount != "75" { + t.Fatalf("metering amount: %q", metering.Amount) + } + // IntoParts must not commit; only the first delivery advanced the watermark. + if consumer.Session().Cumulative() != 50 { + t.Fatalf("cumulative after into_parts: %d", consumer.Session().Cumulative()) + } +} + +func TestConsumerCommitDirectiveDirect(t *testing.T) { + consumer, transport := newConsumer(t, false) + sid := consumer.Session().ChannelIDString() + + receipt, err := consumer.CommitDirective(context.Background(), directive(sid, "25")) + if err != nil { + t.Fatalf("commit directive: %v", err) + } + if receipt.Cumulative != "25" { + t.Fatalf("cumulative: %q", receipt.Cumulative) + } + if transport.count() != 1 { + t.Fatalf("commits: %d", transport.count()) + } +} + +func TestConsumerRejectsWrongSession(t *testing.T) { + consumer, transport := newConsumer(t, false) + _, err := Accept(consumer, intents.MeteredEnvelope[struct{}]{ + Payload: struct{}{}, + Metering: directive("other-session", "1"), + }) + if err == nil || !strings.Contains(err.Error(), "does not match active session") { + t.Fatalf("expected wrong-session rejection, got %v", err) + } + if transport.count() != 0 { + t.Fatalf("no commit expected: %d", transport.count()) + } +} + +func TestConsumerRejectsZeroAndInvalidAmount(t *testing.T) { + consumer, transport := newConsumer(t, false) + sid := consumer.Session().ChannelIDString() + + if _, err := consumer.CommitDirective(context.Background(), directive(sid, "0")); err == nil || !strings.Contains(err.Error(), "greater than zero") { + t.Fatalf("expected zero rejection, got %v", err) + } + if _, err := consumer.CommitDirective(context.Background(), directive(sid, "bad")); err == nil { + t.Fatalf("expected invalid amount rejection") + } + if transport.count() != 0 { + t.Fatalf("no commit expected: %d", transport.count()) + } +} + +func TestConsumerFailedCommitDoesNotAdvanceWatermark(t *testing.T) { + consumer, _ := newConsumer(t, true) + sid := consumer.Session().ChannelIDString() + + _, err := consumer.CommitDirective(context.Background(), directive(sid, "250")) + if err == nil || !strings.Contains(err.Error(), "commit failed") { + t.Fatalf("expected commit failure, got %v", err) + } + if consumer.Session().Cumulative() != 0 { + t.Fatalf("watermark advanced after failed commit: %d", consumer.Session().Cumulative()) + } +} + +func TestConsumerDuplicateDeliveryReplayedNotDoubleCounted(t *testing.T) { + // A server that dedupes by deliveryId returns a "replayed" receipt on the + // second commit of the same deliveryId, carrying the cumulative it first + // settled. The client honors that receipt and does not double-count: the + // transport records exactly one commit. + consumer, transport := newConsumer(t, false) + transport.seen = map[string]string{} + sid := consumer.Session().ChannelIDString() + + first, err := consumer.CommitDirective(context.Background(), directive(sid, "100")) + if err != nil { + t.Fatalf("first commit: %v", err) + } + if first.Status != intents.CommitStatusCommitted { + t.Fatalf("first status: %q", first.Status) + } + if first.Cumulative != "100" { + t.Fatalf("first cumulative: %q", first.Cumulative) + } + + // Replaying the same deliveryId yields a replayed receipt pinned to the + // originally committed cumulative. + replay, err := consumer.CommitDirective(context.Background(), directive(sid, "100")) + if err != nil { + t.Fatalf("replay commit: %v", err) + } + if replay.Status != intents.CommitStatusReplayed { + t.Fatalf("replay status: %q", replay.Status) + } + if replay.Cumulative != "100" { + t.Fatalf("replay cumulative not pinned to original: %q", replay.Cumulative) + } + if transport.count() != 1 { + t.Fatalf("server must record exactly one commit, got %d", transport.count()) + } + // The local watermark must reflect the server's settled position (100), not + // the freshly prepared voucher (200) that the replay would otherwise record. + // Advancing it here would let a later close sign for more than was settled. + if got := consumer.Session().Cumulative(); got != 100 { + t.Fatalf("watermark advanced past settled position on replay: got %d, want 100", got) + } +} + +// replayTransport always reports the delivery as already settled at a fixed +// cumulative, regardless of the voucher it is sent. +type replayTransport struct { + // settled is the fixed cumulative (base units, decimal string) every + // replayed receipt reports as already settled. + settled string +} + +func (r replayTransport) Commit(_ context.Context, directive intents.MeteringDirective, _ intents.CommitPayload) (intents.CommitReceipt, error) { + return intents.CommitReceipt{ + DeliveryID: directive.DeliveryID, + SessionID: directive.SessionID, + Amount: directive.Amount, + Cumulative: r.settled, + Status: intents.CommitStatusReplayed, + }, nil +} + +func TestConsumerReplayReconcilesWatermarkWhenBehind(t *testing.T) { + // Lost-response case: the server already settled this delivery at 100 but + // the client never recorded it (watermark still 0). On replay the client + // must reconcile to the server-settled 100, not jump to the prepared 250 + // and not stay at 0 (which would make the next delivery non-monotonic). + session, _ := newSession(t) + consumer := NewSessionConsumer(session, replayTransport{settled: "100"}) + sid := consumer.Session().ChannelIDString() + + receipt, err := consumer.CommitDirective(context.Background(), directive(sid, "250")) + if err != nil { + t.Fatalf("commit: %v", err) + } + if receipt.Status != intents.CommitStatusReplayed { + t.Fatalf("status: %q", receipt.Status) + } + if got := consumer.Session().Cumulative(); got != 100 { + t.Fatalf("watermark not reconciled to settled position: got %d, want 100", got) + } +} + +func TestConsumerReplayNeverRegressesWatermark(t *testing.T) { + // The client is already ahead at 300; a stale replay settled at 100 must not + // regress the local watermark. + session, _ := newSession(t) + session.ReconcileSettled(300) + consumer := NewSessionConsumer(session, replayTransport{settled: "100"}) + sid := consumer.Session().ChannelIDString() + + if _, err := consumer.CommitDirective(context.Background(), directive(sid, "50")); err != nil { + t.Fatalf("commit: %v", err) + } + if got := consumer.Session().Cumulative(); got != 300 { + t.Fatalf("watermark regressed on stale replay: got %d, want 300", got) + } +} + +func TestConsumerReplayClampedToPreparedVoucher(t *testing.T) { + // A malicious/buggy server cannot push the watermark past the voucher the + // client just signed: it reports a replay settled far above the prepared + // cumulative (250), but the watermark must clamp to the prepared value, not + // the inflated server value, so the next voucher does not over-authorize. + session, _ := newSession(t) + consumer := NewSessionConsumer(session, replayTransport{settled: "1000000"}) + sid := consumer.Session().ChannelIDString() + + if _, err := consumer.CommitDirective(context.Background(), directive(sid, "250")); err != nil { + t.Fatalf("commit: %v", err) + } + if got := consumer.Session().Cumulative(); got != 250 { + t.Fatalf("watermark not clamped to prepared voucher: got %d, want 250", got) + } +} + +// statusTransport returns a fixed (possibly unknown) status, to exercise the +// consumer's rejection of malformed receipts. +type statusTransport struct { + // status is the receipt status echoed for every commit, including values + // outside the known committed/replayed set. + status intents.CommitStatus +} + +func (s statusTransport) Commit(_ context.Context, directive intents.MeteringDirective, payload intents.CommitPayload) (intents.CommitReceipt, error) { + return intents.CommitReceipt{ + DeliveryID: directive.DeliveryID, + SessionID: directive.SessionID, + Amount: directive.Amount, + Cumulative: payload.Voucher.Data.Cumulative, + Status: s.status, + }, nil +} + +func TestConsumerRejectsUnknownReceiptStatus(t *testing.T) { + session, _ := newSession(t) + consumer := NewSessionConsumer(session, statusTransport{status: "bogus"}) + sid := consumer.Session().ChannelIDString() + + _, err := consumer.CommitDirective(context.Background(), directive(sid, "100")) + if err == nil || !strings.Contains(err.Error(), "unexpected commit receipt status") { + t.Fatalf("expected unknown-status rejection, got %v", err) + } + // A malformed receipt must not advance local state. + if consumer.Session().Cumulative() != 0 { + t.Fatalf("watermark advanced on unknown status: %d", consumer.Session().Cumulative()) + } +} + +func TestConsumerDuplicateDeliveryReplayMonotonic(t *testing.T) { + // Two distinct deliveries advance the cumulative monotonically; the + // transport sees increasing cumulative amounts. + consumer, transport := newConsumer(t, false) + sid := consumer.Session().ChannelIDString() + + if _, err := consumer.CommitDirective(context.Background(), directive(sid, "10")); err != nil { + t.Fatalf("first commit: %v", err) + } + d2 := directive(sid, "15") + d2.DeliveryID = "d2" + if _, err := consumer.CommitDirective(context.Background(), d2); err != nil { + t.Fatalf("second commit: %v", err) + } + if consumer.Session().Cumulative() != 25 { + t.Fatalf("cumulative: %d", consumer.Session().Cumulative()) + } + transport.mu.Lock() + defer transport.mu.Unlock() + if transport.commits[0].Voucher.Data.Cumulative != "10" || transport.commits[1].Voucher.Data.Cumulative != "25" { + t.Fatalf("cumulative progression: %q %q", + transport.commits[0].Voucher.Data.Cumulative, transport.commits[1].Voucher.Data.Cumulative) + } +} diff --git a/go/protocols/mpp/client/session_test.go b/go/protocols/mpp/client/session_test.go new file mode 100644 index 000000000..a5a198aa1 --- /dev/null +++ b/go/protocols/mpp/client/session_test.go @@ -0,0 +1,726 @@ +package client + +import ( + "fmt" + "strings" + "testing" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// failingSigner satisfies VoucherSigner but always fails to sign, exercising +// the signing-error propagation paths. +type failingSigner struct { + // pub is the public key the signer reports even though Sign always fails. + pub solana.PublicKey +} + +func (f failingSigner) PublicKey() solana.PublicKey { return f.pub } +func (f failingSigner) Sign([]byte) (solana.Signature, error) { + return solana.Signature{}, fmt.Errorf("signer unavailable") +} + +func TestSigningErrorPropagates(t *testing.T) { + channel := testutil.NewPrivateKey().PublicKey() + s := NewActiveSession(channel, failingSigner{pub: testutil.NewPrivateKey().PublicKey()}) + + if _, err := s.PrepareVoucher(10); err == nil { + t.Fatal("prepare voucher should surface the signing error") + } + if _, err := s.SignVoucher(10); err == nil { + t.Fatal("sign voucher should surface the signing error") + } + if _, err := s.VoucherAction(10); err == nil { + t.Fatal("voucher action should surface the signing error") + } + if _, err := s.CloseAction(10); err == nil { + t.Fatal("close action with increment should surface the signing error") + } + // A zero-increment close never signs, so it must still succeed. + if _, err := s.CloseAction(0); err != nil { + t.Fatalf("close action without increment: %v", err) + } +} + +// newSession builds an ActiveSession over a fresh keypair channel and signer. +func newSession(t *testing.T) (*ActiveSession, solana.PrivateKey) { + t.Helper() + channel := testutil.NewPrivateKey().PublicKey() + signer := testutil.NewPrivateKey() + return NewActiveSession(channel, signer), signer +} + +func TestNewActiveSessionDefaults(t *testing.T) { + s, signer := newSession(t) + if s.Cumulative() != 0 { + t.Fatalf("cumulative = %d, want 0", s.Cumulative()) + } + if s.Nonce() != 0 { + t.Fatalf("nonce = %d, want 0", s.Nonce()) + } + if s.ExpiresAt() != intents.DefaultSessionExpiresAt { + t.Fatalf("expiresAt = %d, want %d", s.ExpiresAt(), intents.DefaultSessionExpiresAt) + } + if got, want := s.AuthorizedSigner(), signer.PublicKey().String(); got != want { + t.Fatalf("authorizedSigner = %q, want %q", got, want) + } + if s.ChannelIDString() != s.ChannelID().String() { + t.Fatalf("channelIdString = %q, want %q", s.ChannelIDString(), s.ChannelID().String()) + } +} + +func TestNewActiveSessionAtAndSetExpiresAt(t *testing.T) { + channel := testutil.NewPrivateKey().PublicKey() + signer := testutil.NewPrivateKey() + s := NewActiveSessionAt(channel, signer, 1234) + + first, err := s.PrepareIncrement(10) + if err != nil { + t.Fatalf("prepare increment: %v", err) + } + if first.Data.ExpiresAt != 1234 { + t.Fatalf("expiresAt = %d, want 1234", first.Data.ExpiresAt) + } + // PrepareIncrement does not advance the watermark. + if s.Cumulative() != 0 { + t.Fatalf("cumulative advanced to %d after prepare", s.Cumulative()) + } + + s.SetExpiresAt(5678) + second, err := s.PrepareIncrement(10) + if err != nil { + t.Fatalf("prepare increment after set: %v", err) + } + if second.Data.ExpiresAt != 5678 { + t.Fatalf("expiresAt = %d, want 5678", second.Data.ExpiresAt) + } +} + +func TestSignIncrementIncreasesCumulative(t *testing.T) { + s, _ := newSession(t) + v, err := s.SignIncrement(100) + if err != nil { + t.Fatalf("sign increment: %v", err) + } + if s.Cumulative() != 100 { + t.Fatalf("cumulative = %d, want 100", s.Cumulative()) + } + if v.Data.Cumulative != "100" { + t.Fatalf("voucher cumulative = %q, want \"100\"", v.Data.Cumulative) + } + if v.Data.Nonce == nil || *v.Data.Nonce != 1 { + t.Fatalf("voucher nonce = %v, want 1", v.Data.Nonce) + } +} + +func TestSignVoucherAbsolute(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignIncrement(50); err != nil { + t.Fatalf("sign increment: %v", err) + } + v, err := s.SignVoucher(200) + if err != nil { + t.Fatalf("sign voucher: %v", err) + } + if s.Cumulative() != 200 { + t.Fatalf("cumulative = %d, want 200", s.Cumulative()) + } + if v.Data.Cumulative != "200" { + t.Fatalf("voucher cumulative = %q, want \"200\"", v.Data.Cumulative) + } +} + +func TestPrepareAndRecordVoucherAreSeparate(t *testing.T) { + s, _ := newSession(t) + prepared, err := s.PrepareIncrement(75) + if err != nil { + t.Fatalf("prepare increment: %v", err) + } + if prepared.Data.Cumulative != "75" { + t.Fatalf("prepared cumulative = %q, want \"75\"", prepared.Data.Cumulative) + } + if prepared.Data.Nonce == nil || *prepared.Data.Nonce != 1 { + t.Fatalf("prepared nonce = %v, want 1", prepared.Data.Nonce) + } + if s.Cumulative() != 0 { + t.Fatalf("cumulative advanced to %d before record", s.Cumulative()) + } + + if err := s.RecordVoucher(prepared); err != nil { + t.Fatalf("record voucher: %v", err) + } + if s.Cumulative() != 75 { + t.Fatalf("cumulative = %d, want 75", s.Cumulative()) + } + // Re-recording the same voucher must be rejected (non-increasing). + if err := s.RecordVoucher(prepared); err == nil { + t.Fatal("re-recording the same voucher should fail") + } +} + +func TestRecordVoucherInvalidAndMissingNonce(t *testing.T) { + s, _ := newSession(t) + + bad := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: s.ChannelIDString(), + Cumulative: "not-a-number", + ExpiresAt: intents.DefaultSessionExpiresAt, + }, + Signature: "sig", + } + if err := s.RecordVoucher(bad); err == nil { + t.Fatal("recording an invalid cumulative should fail") + } + + withoutNonce := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: s.ChannelIDString(), + Cumulative: "15", + ExpiresAt: intents.DefaultSessionExpiresAt, + }, + Signature: "sig", + } + if err := s.RecordVoucher(withoutNonce); err != nil { + t.Fatalf("record voucher without nonce: %v", err) + } + if s.Cumulative() != 15 { + t.Fatalf("cumulative = %d, want 15", s.Cumulative()) + } + if s.Nonce() != 1 { + t.Fatalf("nonce = %d, want 1", s.Nonce()) + } +} + +func TestRecordVoucherRejectsForeignChannel(t *testing.T) { + s, _ := newSession(t) + foreign := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: solana.NewWallet().PublicKey().String(), + Cumulative: "100", + ExpiresAt: intents.DefaultSessionExpiresAt, + }, + Signature: "sig", + } + if err := s.RecordVoucher(foreign); err == nil || !strings.Contains(err.Error(), "does not match active session") { + t.Fatalf("recording a foreign-channel voucher should fail, got %v", err) + } + if s.Cumulative() != 0 { + t.Fatalf("watermark advanced on foreign voucher: %d", s.Cumulative()) + } +} + +func TestReconcileSettledAdvancesButNeverRegresses(t *testing.T) { + s, _ := newSession(t) + s.ReconcileSettled(100) + if s.Cumulative() != 100 { + t.Fatalf("cumulative = %d, want 100", s.Cumulative()) + } + if s.Nonce() != 1 { + t.Fatalf("nonce = %d, want 1 (advance bumps nonce)", s.Nonce()) + } + s.ReconcileSettled(40) // stale, must not regress or touch the nonce + if s.Cumulative() != 100 || s.Nonce() != 1 { + t.Fatalf("stale reconcile changed state: cumulative=%d nonce=%d", s.Cumulative(), s.Nonce()) + } + s.ReconcileSettled(250) + if s.Cumulative() != 250 || s.Nonce() != 2 { + t.Fatalf("cumulative=%d nonce=%d, want 250/2", s.Cumulative(), s.Nonce()) + } +} + +func TestDeliveryAfterReplayDoesNotReuseSettledNonce(t *testing.T) { + // After a lost-response replay reconciles to the settled cumulative, the + // next prepared voucher must carry a fresh nonce, not the settled one. + s, _ := newSession(t) + replayed, err := s.PrepareIncrement(100) + if err != nil { + t.Fatalf("prepare: %v", err) + } + s.ReconcileSettled(100) + next, err := s.PrepareIncrement(50) + if err != nil { + t.Fatalf("prepare next: %v", err) + } + if next.Data.Nonce == nil || replayed.Data.Nonce == nil || *next.Data.Nonce <= *replayed.Data.Nonce { + t.Fatalf("next nonce %v must exceed replayed nonce %v", next.Data.Nonce, replayed.Data.Nonce) + } +} + +func TestRecordVoucherKeepsLargerNonce(t *testing.T) { + s, _ := newSession(t) + nonce := uint64(7) + v := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: s.ChannelIDString(), + Cumulative: "10", + ExpiresAt: intents.DefaultSessionExpiresAt, + Nonce: &nonce, + }, + Signature: "sig", + } + if err := s.RecordVoucher(v); err != nil { + t.Fatalf("record voucher: %v", err) + } + if s.Nonce() != 7 { + t.Fatalf("nonce = %d, want 7 (voucher nonce wins)", s.Nonce()) + } +} + +func TestSignVoucherRejectsNonIncreasing(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignIncrement(100); err != nil { + t.Fatalf("sign increment: %v", err) + } + if _, err := s.SignVoucher(100); err == nil { + t.Fatal("equal cumulative should be rejected") + } + if _, err := s.SignVoucher(50); err == nil { + t.Fatal("lower cumulative should be rejected") + } +} + +func TestSignVoucherZeroRejected(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignVoucher(0); err == nil { + t.Fatal("zero cumulative should be rejected") + } +} + +func TestPrepareVoucherRejectsNonIncreasing(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignIncrement(100); err != nil { + t.Fatalf("sign increment: %v", err) + } + if _, err := s.PrepareVoucher(100); err == nil { + t.Fatal("prepare equal cumulative should be rejected") + } +} + +func TestSignIncrementOverflowRejected(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignVoucher(^uint64(0)); err != nil { + t.Fatalf("sign max voucher: %v", err) + } + if _, err := s.SignIncrement(1); err == nil { + t.Fatal("increment past u64 max should be rejected") + } + if _, err := s.PrepareIncrement(1); err == nil { + t.Fatal("prepare increment past u64 max should be rejected") + } +} + +func TestNonceIncrementsPerVoucher(t *testing.T) { + s, _ := newSession(t) + first, err := s.SignIncrement(10) + if err != nil { + t.Fatalf("sign increment 1: %v", err) + } + second, err := s.SignIncrement(10) + if err != nil { + t.Fatalf("sign increment 2: %v", err) + } + if first.Data.Nonce == nil || *first.Data.Nonce != 1 { + t.Fatalf("first voucher nonce = %v, want 1", first.Data.Nonce) + } + if second.Data.Nonce == nil || *second.Data.Nonce != 2 { + t.Fatalf("second voucher nonce = %v, want 2", second.Data.Nonce) + } +} + +func TestVoucherChannelIDMatchesSession(t *testing.T) { + s, _ := newSession(t) + want := s.ChannelIDString() + v, err := s.SignIncrement(100) + if err != nil { + t.Fatalf("sign increment: %v", err) + } + if v.Data.ChannelID != want { + t.Fatalf("voucher channelId = %q, want %q", v.Data.ChannelID, want) + } +} + +// TestVoucherSignatureVerifies signs an increment and confirms the base58 +// signature verifies against the authorizedSigner pubkey over the exact 48-byte +// VoucherMessageBytes preimage. +func TestVoucherSignatureVerifies(t *testing.T) { + channel := testutil.NewPrivateKey().PublicKey() + signer := testutil.NewPrivateKey() + s := NewActiveSession(channel, signer) + + v, err := s.SignIncrement(123_456) + if err != nil { + t.Fatalf("sign increment: %v", err) + } + + preimage, err := paymentchannels.VoucherMessageBytes(channel, 123_456, intents.DefaultSessionExpiresAt) + if err != nil { + t.Fatalf("voucher message bytes: %v", err) + } + if len(preimage) != 48 { + t.Fatalf("preimage length = %d, want 48", len(preimage)) + } + + sig, err := solana.SignatureFromBase58(v.Signature) + if err != nil { + t.Fatalf("decode signature: %v", err) + } + if !sig.Verify(signer.PublicKey(), preimage) { + t.Fatal("signature does not verify against authorizedSigner over the voucher preimage") + } + // A tampered preimage must not verify. + preimage[0] ^= 0xFF + if sig.Verify(signer.PublicKey(), preimage) { + t.Fatal("signature verified against a tampered preimage") + } +} + +func TestVoucherActionFields(t *testing.T) { + s, _ := newSession(t) + action, err := s.VoucherAction(33) + if err != nil { + t.Fatalf("voucher action: %v", err) + } + if action.Voucher == nil { + t.Fatal("expected a Voucher action") + } + if action.Voucher.Voucher.Data.Cumulative != "33" { + t.Fatalf("voucher cumulative = %q, want \"33\"", action.Voucher.Voucher.Data.Cumulative) + } + if action.Voucher.Voucher.Data.ChannelID != s.ChannelIDString() { + t.Fatalf("voucher channelId = %q, want %q", action.Voucher.Voucher.Data.ChannelID, s.ChannelIDString()) + } +} + +func TestOpenActionFields(t *testing.T) { + s, _ := newSession(t) + channelID := s.ChannelIDString() + authorizedSigner := s.AuthorizedSigner() + action := s.OpenAction(1_000_000, "txsig123") + if action.Open == nil { + t.Fatal("expected an Open action") + } + p := action.Open + if p.Mode != intents.SessionModePush { + t.Fatalf("mode = %q, want push", p.Mode) + } + if p.Deposit == nil || *p.Deposit != "1000000" { + t.Fatalf("deposit = %v, want \"1000000\"", p.Deposit) + } + if p.Signature != "txsig123" { + t.Fatalf("signature = %q, want txsig123", p.Signature) + } + if p.ChannelID == nil || *p.ChannelID != channelID { + t.Fatalf("channelId = %v, want %q", p.ChannelID, channelID) + } + if p.AuthorizedSigner != authorizedSigner { + t.Fatalf("authorizedSigner = %q, want %q", p.AuthorizedSigner, authorizedSigner) + } +} + +func TestOpenPaymentChannelActionFields(t *testing.T) { + s, _ := newSession(t) + channelID := s.ChannelIDString() + action := s.OpenPaymentChannelAction(9_000, "payer", "payee", "mint", 42, 60, "open-sig") + if action.Open == nil { + t.Fatal("expected an Open action") + } + p := action.Open + if p.Mode != intents.SessionModePush { + t.Fatalf("mode = %q, want push", p.Mode) + } + if p.ChannelID == nil || *p.ChannelID != channelID { + t.Fatalf("channelId = %v, want %q", p.ChannelID, channelID) + } + if p.Deposit == nil || *p.Deposit != "9000" { + t.Fatalf("deposit = %v, want \"9000\"", p.Deposit) + } + if p.Payer == nil || *p.Payer != "payer" { + t.Fatalf("payer = %v, want \"payer\"", p.Payer) + } + if p.Payee == nil || *p.Payee != "payee" { + t.Fatalf("payee = %v, want \"payee\"", p.Payee) + } + if p.Mint == nil || *p.Mint != "mint" { + t.Fatalf("mint = %v, want \"mint\"", p.Mint) + } + if p.Salt == nil || *p.Salt != 42 { + t.Fatalf("salt = %v, want 42", p.Salt) + } + if p.GracePeriod == nil || *p.GracePeriod != 60 { + t.Fatalf("gracePeriod = %v, want 60", p.GracePeriod) + } + if p.Signature != "open-sig" { + t.Fatalf("signature = %q, want open-sig", p.Signature) + } +} + +func TestOpenPaymentChannelActionPullMode(t *testing.T) { + s, _ := newSession(t) + channelID := s.ChannelIDString() + action := s.OpenPaymentChannelActionWithMode( + intents.SessionModePull, 9_000, "payer", "payee", "mint", 42, 60, "pending") + if action.Open == nil { + t.Fatal("expected an Open action") + } + p := action.Open + if p.Mode != intents.SessionModePull { + t.Fatalf("mode = %q, want pull", p.Mode) + } + if p.ChannelID == nil || *p.ChannelID != channelID { + t.Fatalf("channelId = %v, want %q", p.ChannelID, channelID) + } + if p.Deposit == nil || *p.Deposit != "9000" { + t.Fatalf("deposit = %v, want \"9000\"", p.Deposit) + } + if p.TokenAccount != nil { + t.Fatalf("tokenAccount = %v, want nil", p.TokenAccount) + } + if p.ApprovedAmount != nil { + t.Fatalf("approvedAmount = %v, want nil", p.ApprovedAmount) + } +} + +func TestOpenPullActionFields(t *testing.T) { + s, _ := newSession(t) + channelID := s.ChannelIDString() // used as tokenAccount in pull mode + authorizedSigner := s.AuthorizedSigner() + action := s.OpenPullAction(5_000_000, "wallet123", "approvesig") + if action.Open == nil { + t.Fatal("expected an Open action") + } + p := action.Open + if p.Mode != intents.SessionModePull { + t.Fatalf("mode = %q, want pull", p.Mode) + } + if p.ApprovedAmount == nil || *p.ApprovedAmount != "5000000" { + t.Fatalf("approvedAmount = %v, want \"5000000\"", p.ApprovedAmount) + } + if p.Signature != "approvesig" { + t.Fatalf("signature = %q, want approvesig", p.Signature) + } + if p.TokenAccount == nil || *p.TokenAccount != channelID { + t.Fatalf("tokenAccount = %v, want %q", p.TokenAccount, channelID) + } + if p.Owner == nil || *p.Owner != "wallet123" { + t.Fatalf("owner = %v, want \"wallet123\"", p.Owner) + } + if p.AuthorizedSigner != authorizedSigner { + t.Fatalf("authorizedSigner = %q, want %q", p.AuthorizedSigner, authorizedSigner) + } + if p.ChannelID != nil { + t.Fatalf("channelId = %v, want nil", p.ChannelID) + } + if p.Deposit != nil { + t.Fatalf("deposit = %v, want nil", p.Deposit) + } +} + +func TestTopUpActionFields(t *testing.T) { + s, _ := newSession(t) + action := s.TopUpAction(5_000_000, "topuptx") + if action.TopUp == nil { + t.Fatal("expected a TopUp action") + } + p := action.TopUp + if p.ChannelID != s.ChannelIDString() { + t.Fatalf("channelId = %q, want %q", p.ChannelID, s.ChannelIDString()) + } + if p.NewDeposit != "5000000" { + t.Fatalf("newDeposit = %q, want \"5000000\"", p.NewDeposit) + } + if p.Signature != "topuptx" { + t.Fatalf("signature = %q, want topuptx", p.Signature) + } +} + +func TestCloseActionNoFinalIncrement(t *testing.T) { + s, _ := newSession(t) + action, err := s.CloseAction(0) + if err != nil { + t.Fatalf("close action: %v", err) + } + if action.Close == nil { + t.Fatal("expected a Close action") + } + if action.Close.Voucher != nil { + t.Fatal("close with zero increment should carry no voucher") + } + if action.Close.ChannelID != s.ChannelIDString() { + t.Fatalf("channelId = %q, want %q", action.Close.ChannelID, s.ChannelIDString()) + } +} + +func TestCloseActionWithFinalIncrement(t *testing.T) { + s, _ := newSession(t) + if _, err := s.SignIncrement(100); err != nil { + t.Fatalf("sign increment: %v", err) + } + action, err := s.CloseAction(50) + if err != nil { + t.Fatalf("close action: %v", err) + } + if action.Close == nil || action.Close.Voucher == nil { + t.Fatal("expected a Close action with a voucher") + } + if action.Close.Voucher.Data.Cumulative != "150" { + t.Fatalf("final voucher cumulative = %q, want \"150\"", action.Close.Voucher.Data.Cumulative) + } +} + +// TestSerializeSessionCredentialRoundTrip serializes a voucher action into an +// Authorization header and confirms it round-trips through ParseAuthorization +// back into the same SessionAction. +func TestSerializeSessionCredentialRoundTrip(t *testing.T) { + s, _ := newSession(t) + challenge := newSessionChallenge(t, "100000") + + action, err := s.VoucherAction(500) + if err != nil { + t.Fatalf("voucher action: %v", err) + } + header, err := SerializeSessionCredential(challenge, action) + if err != nil { + t.Fatalf("serialize credential: %v", err) + } + if !strings.HasPrefix(header, core.PaymentScheme+" ") { + t.Fatalf("header = %q, want %q prefix", header, core.PaymentScheme) + } + + credential, err := core.ParseAuthorization(header) + if err != nil { + t.Fatalf("parse authorization: %v", err) + } + if credential.Challenge.ID != challenge.ID { + t.Fatalf("echoed challenge id = %q, want %q", credential.Challenge.ID, challenge.ID) + } + + var decoded intents.SessionAction + if err := credential.PayloadAs(&decoded); err != nil { + t.Fatalf("decode payload: %v", err) + } + if decoded.Voucher == nil { + t.Fatal("decoded action is not a voucher") + } + if decoded.Voucher.Voucher.Data.Cumulative != "500" { + t.Fatalf("decoded cumulative = %q, want \"500\"", decoded.Voucher.Voucher.Data.Cumulative) + } + if decoded.Voucher.Voucher.Signature != action.Voucher.Voucher.Signature { + t.Fatal("decoded voucher signature does not match") + } +} + +// TestParseSessionChallenge parses a WWW-Authenticate session challenge and +// decodes the embedded SessionRequest. +func TestParseSessionChallenge(t *testing.T) { + challenge := newSessionChallenge(t, "250000") + headerValue, err := core.FormatWWWAuthenticate(challenge) + if err != nil { + t.Fatalf("format www-authenticate: %v", err) + } + + parsed, request, err := ParseSessionChallenge(headerValue) + if err != nil { + t.Fatalf("parse session challenge: %v", err) + } + if parsed.ID != challenge.ID { + t.Fatalf("parsed id = %q, want %q", parsed.ID, challenge.ID) + } + if request.Cap != "250000" { + t.Fatalf("request cap = %q, want \"250000\"", request.Cap) + } + if request.Currency != "USDC" { + t.Fatalf("request currency = %q, want \"USDC\"", request.Currency) + } +} + +func TestParseSessionChallengeRejectsNonSession(t *testing.T) { + chargeRequest, err := core.NewBase64URLJSONValue(map[string]any{ + "amount": "1000", + "currency": "USDC", + "recipient": testutil.NewPrivateKey().PublicKey().String(), + }) + if err != nil { + t.Fatalf("encode charge request: %v", err) + } + challenge := core.NewChallengeWithSecret( + "secret", "api", core.NewMethodName("solana"), core.NewIntentName("charge"), chargeRequest) + headerValue, err := core.FormatWWWAuthenticate(challenge) + if err != nil { + t.Fatalf("format www-authenticate: %v", err) + } + if _, _, err := ParseSessionChallenge(headerValue); err == nil { + t.Fatal("a charge challenge should be rejected by ParseSessionChallenge") + } +} + +func TestParseSessionChallengeRejectsMalformedHeader(t *testing.T) { + if _, _, err := ParseSessionChallenge("Basic realm=\"x\""); err == nil { + t.Fatal("a non-Payment header should be rejected") + } +} + +// TestSerializeSessionCredentialRejectsEmptyAction confirms the credential +// serializer surfaces the SessionAction marshal error when no variant is set. +func TestSerializeSessionCredentialRejectsEmptyAction(t *testing.T) { + challenge := newSessionChallenge(t, "1000") + if _, err := SerializeSessionCredential(challenge, intents.SessionAction{}); err == nil { + t.Fatal("an empty session action should fail to serialize") + } +} + +// TestParseSessionChallengeRejectsUndecodableRequest confirms a session +// challenge whose request bytes are not a SessionRequest object is rejected. +func TestParseSessionChallengeRejectsUndecodableRequest(t *testing.T) { + // A bare JSON array is valid base64url JSON but not a SessionRequest object. + encoded, err := core.NewBase64URLJSONValue([]string{"not", "an", "object"}) + if err != nil { + t.Fatalf("encode: %v", err) + } + challenge := core.NewChallengeWithSecret( + "secret", "api", core.NewMethodName("solana"), core.NewIntentName("session"), encoded) + headerValue, err := core.FormatWWWAuthenticate(challenge) + if err != nil { + t.Fatalf("format: %v", err) + } + if _, _, err := ParseSessionChallenge(headerValue); err == nil { + t.Fatal("a non-object session request should be rejected") + } +} + +// TestParseCumulativeRejectsInvalid exercises the watermark parser guard +// directly, including negative, overflowing, and non-numeric inputs. +func TestParseCumulativeRejectsInvalid(t *testing.T) { + for _, bad := range []string{"-1", "18446744073709551616", "abc", ""} { + if _, err := parseCumulative(bad); err == nil { + t.Fatalf("expected rejection for %q", bad) + } + } + v, err := parseCumulative("18446744073709551615") + if err != nil || v != ^uint64(0) { + t.Fatalf("u64 max should parse: %d %v", v, err) + } +} + +// newSessionChallenge builds an HMAC-bound session challenge carrying a +// SessionRequest with the given cap. +func newSessionChallenge(t *testing.T, sessionCap string) core.PaymentChallenge { + t.Helper() + request := intents.SessionRequest{ + Cap: sessionCap, + Currency: "USDC", + Operator: testutil.NewPrivateKey().PublicKey().String(), + Recipient: testutil.NewPrivateKey().PublicKey().String(), + } + encoded, err := core.NewBase64URLJSONValue(request) + if err != nil { + t.Fatalf("encode session request: %v", err) + } + return core.NewChallengeWithSecret( + "secret", "api", core.NewMethodName("solana"), core.NewIntentName("session"), encoded) +} diff --git a/go/protocols/mpp/intents/charge.go b/go/protocols/mpp/intents/charge.go index efbfc1856..4b990552c 100644 --- a/go/protocols/mpp/intents/charge.go +++ b/go/protocols/mpp/intents/charge.go @@ -1,9 +1,10 @@ -// Package intents carries the MPP intent request bodies. Today this is -// the charge intent (ChargeRequest, with string-encoded base-unit -// amounts so JSON consumers without u64 safety stay correct), plus the -// ParseUnits helper that converts a human-readable decimal amount into -// base units at the SDK boundary. Wire format mirrors -// rust/src/protocol/intents/charge.rs. +// Package intents carries the MPP intent request bodies: the charge intent +// (ChargeRequest, with string-encoded base-unit amounts so JSON consumers +// without u64 safety stay correct) and the session intent (SessionRequest plus +// the SessionAction credential union and signed vouchers). It also exposes the +// ParseUnits helper that converts a human-readable decimal amount into base +// units at the SDK boundary. The JSON wire format is identical across the +// language SDKs; the cross-language harness pins it. package intents import ( @@ -14,13 +15,26 @@ import ( // ChargeRequest is the method-agnostic charge intent body. type ChargeRequest struct { - Amount string `json:"amount"` - Currency string `json:"currency"` - Recipient string `json:"recipient,omitempty"` - Description string `json:"description,omitempty"` - ExternalID string `json:"externalId,omitempty"` - MethodDetails any `json:"methodDetails,omitempty"` + // Amount is the charge amount in token base units, as an unsigned + // decimal string so JSON consumers without u64 safety stay exact + // (e.g. "100000" = 0.10 USDC at 6 decimals). + Amount string `json:"amount"` + // Currency is the asset identifier (e.g. "USDC" or a mint address). + Currency string `json:"currency"` + // Recipient is the payee address (base58 for the Solana method); + // omitted when the payment method's details carry the destination. + Recipient string `json:"recipient,omitempty"` + // Description is an optional human-readable label for the charge. + Description string `json:"description,omitempty"` + // ExternalID is an optional merchant reference, echoed back in the + // payment receipt. + ExternalID string `json:"externalId,omitempty"` + // MethodDetails is the method-specific payload, opaque at the intent + // layer; omitted when nil. + MethodDetails any `json:"methodDetails,omitempty"` + // Decimals, when non-nil, marks Amount as a human-readable decimal + // that WithBaseUnits converts to base units. Never serialized. Decimals *uint8 `json:"-"` } diff --git a/go/protocols/mpp/intents/session.go b/go/protocols/mpp/intents/session.go new file mode 100644 index 000000000..7199947b8 --- /dev/null +++ b/go/protocols/mpp/intents/session.go @@ -0,0 +1,846 @@ +package intents + +// Session intent request and voucher types. +// +// The session intent opens a payment channel between a client and server, +// allowing incremental payments via off-chain signed vouchers backed by the +// on-chain payment-channels program. The JSON wire format is identical across +// the language SDKs; the cross-language harness pins it. + +import ( + "encoding/json" + "fmt" + "strconv" + + "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" +) + +// DefaultSessionExpiresAt is the default session voucher/directive expiry: +// 2100-01-01T00:00:00Z. +// +// This stays below JavaScript's max safe integer so JSON intermediaries do not +// round it before the credential is decoded. +const DefaultSessionExpiresAt int64 = 4_102_444_800 + +// SessionMode is the on-chain funding mechanism for a session. +// +// Advertised by the server in SessionRequest.Modes; the client picks the mode +// it will use when sending its open action. +type SessionMode string + +const ( + // SessionModePush is a payment channel backed by an on-chain escrow + // deposit (client-funded). + SessionModePush SessionMode = "push" + + // SessionModePull is an operator-assisted pull session. Voucher authority + // is declared separately via SessionPullVoucherStrategy. + SessionModePull SessionMode = "pull" +) + +// SessionPullVoucherStrategy is the voucher authority used when +// SessionModePull is advertised. +type SessionPullVoucherStrategy string + +const ( + // SessionPullVoucherStrategyClientVoucher means the client signs + // cumulative vouchers. + SessionPullVoucherStrategyClientVoucher SessionPullVoucherStrategy = "clientVoucher" + + // SessionPullVoucherStrategyOperatedVoucher means the operator signs + // vouchers after metering/receipts. + SessionPullVoucherStrategyOperatedVoucher SessionPullVoucherStrategy = "operatedVoucher" +) + +// CommitStatus is the commit receipt status. +type CommitStatus string + +const ( + // CommitStatusCommitted is the first successful commit for the delivery. + CommitStatusCommitted CommitStatus = "committed" + + // CommitStatusReplayed is an idempotent replay of a previously accepted + // commit. + CommitStatusReplayed CommitStatus = "replayed" +) + +// SessionRequest is the session intent request — the payload embedded in a 402 +// challenge. Describes the channel parameters: cap, currency, splits, network, +// etc. +type SessionRequest struct { + // Cap is the maximum total amount the client may spend in this session + // (base units). + Cap string `json:"cap"` + + // Currency/asset identifier (e.g., "USDC", mint address). + Currency string `json:"currency"` + + // Decimals is the token decimals (default 6 for USDC-like tokens). + Decimals *uint8 `json:"decimals,omitempty"` + + // Network is the Solana network: "mainnet", "devnet", "localnet". + Network *string `json:"network,omitempty"` + + // Operator (server) public key (base58). + Operator string `json:"operator"` + + // Recipient is the primary recipient for channel proceeds (base58). + Recipient string `json:"recipient"` + + // Splits are optional fixed portions routed to specific recipients at + // close. Omitted when empty. + Splits []SessionSplit `json:"splits,omitempty"` + + // ProgramID is the channel program ID (base58). Defaults to the canonical + // payment-channels program. + ProgramID *string `json:"programId,omitempty"` + + // Description is a human-readable description. + Description *string `json:"description,omitempty"` + + // ExternalID is a merchant reference ID. + ExternalID *string `json:"externalId,omitempty"` + + // MinVoucherDelta is the minimum voucher increment (base units). Prevents + // micro-increment spam. + MinVoucherDelta *string `json:"minVoucherDelta,omitempty"` + + // Modes are the session modes supported by this server. + // + // Omitted/empty means only SessionModePush is supported. The client MUST + // use one of the advertised modes in its open action. + Modes []SessionMode `json:"modes,omitempty"` + + // PullVoucherStrategy is the voucher authority for pull-mode sessions. + // + // Required when Modes includes SessionModePull. Omitted when pull is not + // supported. + PullVoucherStrategy *SessionPullVoucherStrategy `json:"pullVoucherStrategy,omitempty"` + + // RecentBlockhash is a recent blockhash pre-fetched by the server + // (base58). Included when the client needs to build server-broadcast + // transactions without a second RPC round-trip. + RecentBlockhash *string `json:"recentBlockhash,omitempty"` +} + +// SessionSplit is a payment split committed at channel open; distributed to a +// specific recipient when the channel closes. +type SessionSplit struct { + // Recipient address (base58). + Recipient string `json:"recipient"` + + // BPS is the share in basis points. + BPS uint16 `json:"bps"` +} + +// ── Client actions ── + +// sessionActionTag is the discriminator used by SessionAction's tagged-union +// serialization. The wire values are camelCase; note "topUp", not "topup". +type sessionActionTag string + +const ( + sessionActionOpen sessionActionTag = "open" + sessionActionVoucher sessionActionTag = "voucher" + sessionActionCommit sessionActionTag = "commit" + sessionActionTopUp sessionActionTag = "topUp" + sessionActionClose sessionActionTag = "close" +) + +// SessionAction is the action submitted by the client in an Authorization +// header. +// +// Serialized as a tagged object with +// "action": "open" | "voucher" | "commit" | "topUp" | "close", +// with the payload fields flattened alongside the discriminator. Exactly one of +// the payload pointers is non-nil for a valid action. +type SessionAction struct { + // Open a new channel/delegation and start the session. + Open *OpenPayload + + // Voucher submits a signed voucher authorizing payment for an API call. + Voucher *VoucherPayload + + // Commit a metered delivery by attaching a signed voucher. + Commit *CommitPayload + + // TopUp an existing channel's deposit. + TopUp *TopUpPayload + + // Close requests cooperative close of the channel. + Close *ClosePayload +} + +// NewOpenAction wraps an OpenPayload as a SessionAction. +func NewOpenAction(payload OpenPayload) SessionAction { + return SessionAction{Open: &payload} +} + +// NewVoucherAction wraps a VoucherPayload as a SessionAction. +func NewVoucherAction(payload VoucherPayload) SessionAction { + return SessionAction{Voucher: &payload} +} + +// NewCommitAction wraps a CommitPayload as a SessionAction. +func NewCommitAction(payload CommitPayload) SessionAction { + return SessionAction{Commit: &payload} +} + +// NewTopUpAction wraps a TopUpPayload as a SessionAction. +func NewTopUpAction(payload TopUpPayload) SessionAction { + return SessionAction{TopUp: &payload} +} + +// NewCloseAction wraps a ClosePayload as a SessionAction. +func NewCloseAction(payload ClosePayload) SessionAction { + return SessionAction{Close: &payload} +} + +// MarshalJSON flattens the active payload alongside an "action" discriminator. +func (a SessionAction) MarshalJSON() ([]byte, error) { + var tag sessionActionTag + var payload any + count := 0 + if a.Open != nil { + tag, payload = sessionActionOpen, a.Open + count++ + } + if a.Voucher != nil { + tag, payload = sessionActionVoucher, a.Voucher + count++ + } + if a.Commit != nil { + tag, payload = sessionActionCommit, a.Commit + count++ + } + if a.TopUp != nil { + tag, payload = sessionActionTopUp, a.TopUp + count++ + } + if a.Close != nil { + tag, payload = sessionActionClose, a.Close + count++ + } + if count == 0 { + return nil, fmt.Errorf("session action: no variant set") + } + if count > 1 { + return nil, fmt.Errorf("session action: multiple variants set") + } + + raw, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal session action payload: %w", err) + } + var fields map[string]json.RawMessage + if err := json.Unmarshal(raw, &fields); err != nil { + return nil, fmt.Errorf("flatten session action payload: %w", err) + } + tagRaw, err := json.Marshal(string(tag)) + if err != nil { + return nil, fmt.Errorf("marshal session action tag: %w", err) + } + fields["action"] = tagRaw + out, err := json.Marshal(fields) + if err != nil { + return nil, fmt.Errorf("marshal session action: %w", err) + } + return out, nil +} + +// UnmarshalJSON reads the "action" discriminator and decodes the flattened +// payload into the matching variant. +func (a *SessionAction) UnmarshalJSON(data []byte) error { + var probe struct { + Action sessionActionTag `json:"action"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return fmt.Errorf("read session action tag: %w", err) + } + + *a = SessionAction{} + switch probe.Action { + case sessionActionOpen: + var p OpenPayload + if err := json.Unmarshal(data, &p); err != nil { + return fmt.Errorf("decode open action: %w", err) + } + a.Open = &p + case sessionActionVoucher: + var p VoucherPayload + if err := json.Unmarshal(data, &p); err != nil { + return fmt.Errorf("decode voucher action: %w", err) + } + a.Voucher = &p + case sessionActionCommit: + var p CommitPayload + if err := json.Unmarshal(data, &p); err != nil { + return fmt.Errorf("decode commit action: %w", err) + } + a.Commit = &p + case sessionActionTopUp: + var p TopUpPayload + if err := json.Unmarshal(data, &p); err != nil { + return fmt.Errorf("decode topUp action: %w", err) + } + a.TopUp = &p + case sessionActionClose: + var p ClosePayload + if err := json.Unmarshal(data, &p); err != nil { + return fmt.Errorf("decode close action: %w", err) + } + a.Close = &p + case "": + return fmt.Errorf("session action: missing action discriminator") + default: + return fmt.Errorf("session action: unknown action %q", probe.Action) + } + return nil +} + +// OpenPayload is the payload for the open action. +// +// Use OpenPayloadPush, OpenPayloadPaymentChannel, or OpenPayloadPull to +// construct. Inspect Mode to distinguish variants on the server. +// +// Salt marshals as a decimal string (authorization headers are JSON +// canonicalized, and arbitrary uint64 values are not safe JSON numbers) and +// decodes from either a string or a JSON number. +type OpenPayload struct { + // Mode is the session mode discriminant. Required (no default). + Mode SessionMode `json:"mode"` + + // ── Push mode ── + + // ChannelID is the payment-channel address (base58). Required for push + // mode. + ChannelID *string `json:"channelId,omitempty"` + + // Deposit locked on-chain (base units). Required for push mode. + Deposit *string `json:"deposit,omitempty"` + + // Payer is the client wallet that funds the payment channel. + Payer *string `json:"payer,omitempty"` + + // Payee is the primary channel payee. + Payee *string `json:"payee,omitempty"` + + // Mint is the SPL mint locked in the channel. + Mint *string `json:"mint,omitempty"` + + // Salt used in the channel PDA seeds. Serialized as a decimal string. + Salt *uint64 `json:"-"` + + // GracePeriod used by the on-chain close path. + GracePeriod *uint32 `json:"gracePeriod,omitempty"` + + // Transaction is the signed payment-channel open transaction (base64), + // when the client wants the server/operator to broadcast it. + Transaction *string `json:"transaction,omitempty"` + + // ── Pull mode ── + + // TokenAccount is the SPL token account being delegated (base58). Required + // for pull mode. + TokenAccount *string `json:"tokenAccount,omitempty"` + + // ApprovedAmount is the amount approved for operator delegation (base + // units). Required for pull mode. + ApprovedAmount *string `json:"approvedAmount,omitempty"` + + // Owner is the client wallet pubkey (base58). Required for pull mode. + Owner *string `json:"owner,omitempty"` + + // InitMultiDelegateTx is a pre-signed transaction (base64) that creates + // the MultiDelegate PDA and an initial FixedDelegation. + InitMultiDelegateTx *string `json:"initMultiDelegateTx,omitempty"` + + // UpdateDelegationTx is a pre-signed transaction (base64) that creates or + // raises the FixedDelegation cap. + UpdateDelegationTx *string `json:"updateDelegationTx,omitempty"` + + // ── Shared ── + + // AuthorizedSigner is the public key authorized to sign vouchers for this + // session (base58). Usually an ephemeral key generated by the client. + AuthorizedSigner string `json:"authorizedSigner"` + + // Signature is the transaction signature (base58) proving the on-chain + // action. + Signature string `json:"signature"` +} + +// openPayloadJSON is the wire shape of OpenPayload with salt typed as +// json.RawMessage so it can be encoded as a string and decoded from +// string-or-number. +type openPayloadJSON struct { + Mode SessionMode `json:"mode"` // funding mode discriminant ("push" or "pull") + ChannelID *string `json:"channelId,omitempty"` // payment-channel address (base58); push mode + Deposit *string `json:"deposit,omitempty"` // on-chain escrow deposit (base units); push mode + Payer *string `json:"payer,omitempty"` // funding client wallet (base58) + Payee *string `json:"payee,omitempty"` // primary channel payee (base58) + Mint *string `json:"mint,omitempty"` // SPL mint locked in the channel (base58) + Salt json.RawMessage `json:"salt,omitempty"` // PDA-seed salt; encoded as decimal string, decoded string-or-number + GracePeriod *uint32 `json:"gracePeriod,omitempty"` // on-chain close grace period + Transaction *string `json:"transaction,omitempty"` // signed channel-open tx (base64) for server broadcast + TokenAccount *string `json:"tokenAccount,omitempty"` // delegated SPL token account (base58); pull mode + ApprovedAmount *string `json:"approvedAmount,omitempty"` // operator delegation cap (base units); pull mode + Owner *string `json:"owner,omitempty"` // client wallet pubkey (base58); pull mode + InitMultiDelegateTx *string `json:"initMultiDelegateTx,omitempty"` // pre-signed MultiDelegate init tx (base64) + UpdateDelegationTx *string `json:"updateDelegationTx,omitempty"` // pre-signed delegation cap-update tx (base64) + AuthorizedSigner string `json:"authorizedSigner"` // voucher-signing session pubkey (base58) + Signature string `json:"signature"` // on-chain proof tx signature (base58) +} + +// MarshalJSON serializes Salt as a decimal string. +func (p OpenPayload) MarshalJSON() ([]byte, error) { + wire := openPayloadJSON{ + Mode: p.Mode, + ChannelID: p.ChannelID, + Deposit: p.Deposit, + Payer: p.Payer, + Payee: p.Payee, + Mint: p.Mint, + GracePeriod: p.GracePeriod, + Transaction: p.Transaction, + TokenAccount: p.TokenAccount, + ApprovedAmount: p.ApprovedAmount, + Owner: p.Owner, + InitMultiDelegateTx: p.InitMultiDelegateTx, + UpdateDelegationTx: p.UpdateDelegationTx, + AuthorizedSigner: p.AuthorizedSigner, + Signature: p.Signature, + } + if p.Salt != nil { + raw, err := json.Marshal(strconv.FormatUint(*p.Salt, 10)) + if err != nil { + return nil, fmt.Errorf("marshal salt: %w", err) + } + wire.Salt = raw + } + out, err := json.Marshal(wire) + if err != nil { + return nil, fmt.Errorf("marshal open payload: %w", err) + } + return out, nil +} + +// UnmarshalJSON decodes Salt from either a decimal string or a JSON number. +func (p *OpenPayload) UnmarshalJSON(data []byte) error { + var wire openPayloadJSON + if err := json.Unmarshal(data, &wire); err != nil { + return fmt.Errorf("decode open payload: %w", err) + } + *p = OpenPayload{ + Mode: wire.Mode, + ChannelID: wire.ChannelID, + Deposit: wire.Deposit, + Payer: wire.Payer, + Payee: wire.Payee, + Mint: wire.Mint, + GracePeriod: wire.GracePeriod, + Transaction: wire.Transaction, + TokenAccount: wire.TokenAccount, + ApprovedAmount: wire.ApprovedAmount, + Owner: wire.Owner, + InitMultiDelegateTx: wire.InitMultiDelegateTx, + UpdateDelegationTx: wire.UpdateDelegationTx, + AuthorizedSigner: wire.AuthorizedSigner, + Signature: wire.Signature, + } + if p.Mode == "" { + return fmt.Errorf("open payload: missing mode") + } + salt, err := parseOptionalSalt(wire.Salt) + if err != nil { + return err + } + p.Salt = salt + return nil +} + +// parseOptionalSalt parses a salt value that may be absent, null, a decimal +// string, or an unsigned 64-bit JSON number. +func parseOptionalSalt(raw json.RawMessage) (*uint64, error) { + if len(raw) == 0 || string(raw) == "null" { + return nil, nil + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + return nil, fmt.Errorf("decode salt: %w", err) + } + switch v := value.(type) { + case string: + parsed, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return nil, fmt.Errorf("salt must be a decimal string: %w", err) + } + return &parsed, nil + case float64: + // Standard json decoding yields float64 for numbers. Recover the + // integer value from the raw bytes to avoid precision loss on large + // u64 values. + parsed, err := strconv.ParseUint(string(raw), 10, 64) + if err != nil { + return nil, fmt.Errorf("salt must be an unsigned 64-bit integer: %w", err) + } + return &parsed, nil + default: + return nil, fmt.Errorf("salt must be a decimal string or unsigned 64-bit integer") + } +} + +// OpenPayloadPush constructs a push payment-channel open payload. +func OpenPayloadPush(channelID, deposit, authorizedSigner, signature string) OpenPayload { + return OpenPayload{ + Mode: SessionModePush, + ChannelID: &channelID, + Deposit: &deposit, + AuthorizedSigner: authorizedSigner, + Signature: signature, + } +} + +// OpenPayloadPaymentChannel constructs a payment-channel push open payload. +func OpenPayloadPaymentChannel( + channelID, deposit, payer, payee, mint string, + salt uint64, + gracePeriod uint32, + authorizedSigner, signature string, +) OpenPayload { + return OpenPayloadPaymentChannelWithMode( + SessionModePush, + channelID, deposit, payer, payee, mint, + salt, gracePeriod, authorizedSigner, signature, + ) +} + +// OpenPayloadPaymentChannelWithMode constructs a payment-channel open payload +// with an explicit submission mode. +func OpenPayloadPaymentChannelWithMode( + mode SessionMode, + channelID, deposit, payer, payee, mint string, + salt uint64, + gracePeriod uint32, + authorizedSigner, signature string, +) OpenPayload { + return OpenPayload{ + Mode: mode, + ChannelID: &channelID, + Deposit: &deposit, + Payer: &payer, + Payee: &payee, + Mint: &mint, + Salt: &salt, + GracePeriod: &gracePeriod, + AuthorizedSigner: authorizedSigner, + Signature: signature, + } +} + +// OpenPayloadPull constructs a pull (SPL delegation) open payload. +func OpenPayloadPull(tokenAccount, approvedAmount, owner, authorizedSigner, signature string) OpenPayload { + return OpenPayload{ + Mode: SessionModePull, + TokenAccount: &tokenAccount, + ApprovedAmount: &approvedAmount, + Owner: &owner, + AuthorizedSigner: authorizedSigner, + Signature: signature, + } +} + +// WithTransaction attaches a signed open transaction for operator/server +// broadcast. +func (p OpenPayload) WithTransaction(txBase64 string) OpenPayload { + p.Transaction = &txBase64 + return p +} + +// WithInitTx attaches a pre-signed InitMultiDelegate + CreateFixedDelegation +// transaction. +func (p OpenPayload) WithInitTx(txBase64 string) OpenPayload { + p.InitMultiDelegateTx = &txBase64 + return p +} + +// WithUpdateTx attaches a pre-signed CreateFixedDelegation (cap update) +// transaction. +func (p OpenPayload) WithUpdateTx(txBase64 string) OpenPayload { + p.UpdateDelegationTx = &txBase64 + return p +} + +// SessionID returns the session identifier used as the store key. +// +// - Payment channel: ChannelID +// - Operated-voucher pull: TokenAccount +func (p OpenPayload) SessionID() (string, error) { + if p.ChannelID != nil { + return *p.ChannelID, nil + } + switch p.Mode { + case SessionModePush: + return "", fmt.Errorf("push open missing channelId") + case SessionModePull: + if p.TokenAccount != nil { + return *p.TokenAccount, nil + } + return "", fmt.Errorf("pull open missing channelId or tokenAccount") + default: + return "", fmt.Errorf("open payload: unknown mode %q", p.Mode) + } +} + +// DepositAmount returns the deposit / approved amount for this open (base +// units). +func (p OpenPayload) DepositAmount() (uint64, error) { + var raw string + switch { + case p.Deposit != nil: + raw = *p.Deposit + case p.Mode == SessionModePush: + return 0, fmt.Errorf("push open missing deposit") + case p.Mode == SessionModePull: + if p.ApprovedAmount == nil { + return 0, fmt.Errorf("pull open missing deposit or approvedAmount") + } + raw = *p.ApprovedAmount + default: + return 0, fmt.Errorf("open payload: unknown mode %q", p.Mode) + } + value, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid deposit amount: %s", raw) + } + return value, nil +} + +// VoucherPayload is the payload for the voucher action (per-request +// micropayment). +type VoucherPayload struct { + // Voucher is the signed voucher authorizing cumulative spend. + Voucher SignedVoucher `json:"voucher"` +} + +// MeteringDirective is the server-issued metering directive attached to a +// delivered message/response. +// +// Clients treat this like an offset in a message log: once the message has been +// processed successfully, ack/commit signs a voucher for Amount and sends a +// CommitPayload back to the server. +type MeteringDirective struct { + // DeliveryID is the server-generated idempotency key for this delivery. + DeliveryID string `json:"deliveryId"` + + // SessionID is the channel/session ID this delivery belongs to. + SessionID string `json:"sessionId"` + + // Amount owed for this delivery in base units. + Amount string `json:"amount"` + + // Currency/asset identifier (e.g., "USDC", mint address). + Currency string `json:"currency"` + + // Sequence is the monotonic per-session delivery sequence. + Sequence uint64 `json:"sequence"` + + // ExpiresAt is the Unix timestamp after which this directive should not be + // committed. + ExpiresAt int64 `json:"expiresAt"` + + // CommitURL is an optional commit endpoint hint for HTTP transports. + CommitURL *string `json:"commitUrl,omitempty"` + + // Proof is optional server proof or opaque metadata for transport + // integrations. + Proof *string `json:"proof,omitempty"` +} + +// AmountBaseUnits parses Amount as base units. +func (d MeteringDirective) AmountBaseUnits() (uint64, error) { + value, err := strconv.ParseUint(d.Amount, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid metering amount: %s", d.Amount) + } + return value, nil +} + +// MeteringUsage is the final usage reported by a streaming response. +// +// The amount MUST be less than or equal to the amount reserved by the original +// MeteringDirective. +type MeteringUsage struct { + // DeliveryID is the delivery id from the original MeteringDirective. + DeliveryID string `json:"deliveryId"` + + // Amount is the final amount owed for this stream in base units. + Amount string `json:"amount"` +} + +// AmountBaseUnits parses Amount as base units. +func (u MeteringUsage) AmountBaseUnits() (uint64, error) { + value, err := strconv.ParseUint(u.Amount, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid metering usage amount: %s", u.Amount) + } + return value, nil +} + +// MeteredEnvelope is a payload paired with the metering directive required to +// acknowledge it. +type MeteredEnvelope[T any] struct { + // Payload is the delivered application message being charged for. + Payload T `json:"payload"` + + // Metering is the server-issued directive the client commits (by + // signing a voucher covering Metering.Amount) after processing Payload. + Metering MeteringDirective `json:"metering"` +} + +// CommitPayload is the payload for the commit action. +type CommitPayload struct { + // DeliveryID from the original MeteringDirective. + DeliveryID string `json:"deliveryId"` + + // Voucher is the signed voucher authorizing the delivery amount. + Voucher SignedVoucher `json:"voucher"` +} + +// CommitReceipt is the result returned after a delivery commit is accepted. +type CommitReceipt struct { + // DeliveryID from the original MeteringDirective. + DeliveryID string `json:"deliveryId"` + + // SessionID is the channel/session ID. + SessionID string `json:"sessionId"` + + // Amount committed for this delivery in base units. + Amount string `json:"amount"` + + // Cumulative is the new settled cumulative watermark in base units. + Cumulative string `json:"cumulative"` + + // Status is the commit receipt status. + Status CommitStatus `json:"status"` +} + +// TopUpPayload is the payload for the topUp action. +type TopUpPayload struct { + // ChannelID is the on-chain channel address (base58). + ChannelID string `json:"channelId"` + + // NewDeposit is the new total deposit amount after the top-up (base + // units). + NewDeposit string `json:"newDeposit"` + + // Signature is the top-up transaction signature (base58). + Signature string `json:"signature"` +} + +// ClosePayload is the payload for the close action. +type ClosePayload struct { + // ChannelID is the on-chain channel address (base58). + ChannelID string `json:"channelId"` + + // Voucher is the final signed voucher for any remaining balance owed. + Voucher *SignedVoucher `json:"voucher,omitempty"` +} + +// ── Vouchers ── + +// SignedVoucher is a signed voucher authorizing cumulative payment up to its +// cumulative amount. +// +// Vouchers are cumulative: the server always uses the latest valid voucher it +// has received. The client MUST increment the cumulative amount with each +// request. +type SignedVoucher struct { + // Data is the voucher content. + Data VoucherData `json:"data"` + + // Signature is the Ed25519 signature over the payment-channel Borsh + // voucher bytes (base58). + Signature string `json:"signature"` +} + +// VoucherData is the canonical content of a voucher, signed by the client's +// session key. +// +// Serialized as the on-chain VoucherArgs layout before signing: +// channelId || cumulativeAmount(LE u64) || expiresAt(LE i64). +type VoucherData struct { + // ChannelID is the channel/session ID this voucher is bound to (base58). + // + // For push sessions: the payment-channel address. + // For pull sessions: the SPL token account address. + ChannelID string `json:"channelId"` + + // Cumulative is the cumulative amount authorized (base units, + // monotonically increasing). The wire field is "cumulativeAmount" with a + // "cumulative" decode alias. + Cumulative string `json:"cumulativeAmount"` + + // ExpiresAt is the Unix timestamp at which this voucher expires. + ExpiresAt int64 `json:"expiresAt"` + + // Nonce is an optional client-side request counter. It is not included in + // the on-chain voucher bytes. + Nonce *uint64 `json:"nonce,omitempty"` +} + +// voucherDataJSON is the wire shape of VoucherData with the "cumulative" decode +// alias handled explicitly. +type voucherDataJSON struct { + ChannelID string `json:"channelId"` // channel/session ID the voucher is bound to (base58) + CumulativeAmount *string `json:"cumulativeAmount,omitempty"` // canonical cumulative total authorized (base units) + CumulativeAlias *string `json:"cumulative,omitempty"` // decode-only alias accepted for cumulativeAmount + ExpiresAt int64 `json:"expiresAt"` // voucher expiry, Unix epoch seconds + Nonce *uint64 `json:"nonce,omitempty"` // optional client request counter; not signed on-chain +} + +// UnmarshalJSON decodes VoucherData, accepting "cumulative" as an alias for +// "cumulativeAmount". +func (v *VoucherData) UnmarshalJSON(data []byte) error { + var wire voucherDataJSON + if err := json.Unmarshal(data, &wire); err != nil { + return fmt.Errorf("decode voucher data: %w", err) + } + *v = VoucherData{ + ChannelID: wire.ChannelID, + ExpiresAt: wire.ExpiresAt, + Nonce: wire.Nonce, + } + switch { + case wire.CumulativeAmount != nil: + v.Cumulative = *wire.CumulativeAmount + case wire.CumulativeAlias != nil: + v.Cumulative = *wire.CumulativeAlias + default: + // The cumulative amount is required on the wire, so a voucher without + // "cumulativeAmount"/"cumulative" is malformed; reject it here rather + // than leave Cumulative empty and fail with a cryptic parse error later + // when the voucher is signed or recorded. + return fmt.Errorf("voucher data missing cumulativeAmount") + } + return nil +} + +// MessageBytes serializes the voucher to the payment-channels VoucherArgs bytes +// signed by Ed25519: channelId(32) || cumulativeAmount(LE u64) || +// expiresAt(LE i64), for a total of exactly 48 bytes. +func (v VoucherData) MessageBytes() ([]byte, error) { + channelID, err := solana.PublicKeyFromBase58(v.ChannelID) + if err != nil { + return nil, fmt.Errorf("invalid channelId %q: %w", v.ChannelID, err) + } + cumulative, err := strconv.ParseUint(v.Cumulative, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid voucher cumulative") + } + // Delegate to the canonical packer so the 48-byte layout has a single + // source of truth. + return paymentchannels.VoucherMessageBytes(channelID, cumulative, v.ExpiresAt) +} diff --git a/go/protocols/mpp/intents/session_decode_test.go b/go/protocols/mpp/intents/session_decode_test.go new file mode 100644 index 000000000..c615bd0a1 --- /dev/null +++ b/go/protocols/mpp/intents/session_decode_test.go @@ -0,0 +1,64 @@ +package intents + +// Decode-failure coverage for the session wire types: each SessionAction +// variant rejects a malformed flattened payload with a variant-specific +// error, and the OpenPayload/VoucherData deserializers reject non-object +// shapes outright. + +import ( + "encoding/json" + "strings" + "testing" +) + +func TestSessionActionVariantDecodeFailures(t *testing.T) { + cases := []struct { + name string + raw string + want string + }{ + {"open", `{"action":"open","mode":"push","salt":[]}`, "decode open action"}, + {"voucher", `{"action":"voucher","voucher":"not-an-object"}`, "decode voucher action"}, + {"commit", `{"action":"commit","deliveryId":5}`, "decode commit action"}, + {"topUp", `{"action":"topUp","channelId":5}`, "decode topUp action"}, + {"close", `{"action":"close","channelId":5}`, "decode close action"}, + {"tag", `{"action":5}`, "read session action tag"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var action SessionAction + err := json.Unmarshal([]byte(tc.raw), &action) + if err == nil || !strings.Contains(err.Error(), tc.want) { + t.Fatalf("decode %s = %v, want %q", tc.raw, err, tc.want) + } + }) + } +} + +func TestOpenPayloadDecodeRejectsNonObject(t *testing.T) { + var payload OpenPayload + if err := json.Unmarshal([]byte(`"push"`), &payload); err == nil || + !strings.Contains(err.Error(), "decode open payload") { + t.Fatalf("non-object open payload = %v", err) + } +} + +func TestVoucherDataDecodeRejectsNonObject(t *testing.T) { + var data VoucherData + if err := json.Unmarshal([]byte(`5`), &data); err == nil || + !strings.Contains(err.Error(), "decode voucher data") { + t.Fatalf("non-object voucher data = %v", err) + } +} + +func TestSessionActionMarshalRejectsInvalidVariantCounts(t *testing.T) { + var empty SessionAction + if _, err := json.Marshal(empty); err == nil || !strings.Contains(err.Error(), "no variant set") { + t.Fatalf("empty action marshal = %v", err) + } + open := OpenPayloadPush("c", "1", "signer", "sig") + double := SessionAction{Open: &open, Close: &ClosePayload{ChannelID: "c"}} + if _, err := json.Marshal(double); err == nil || !strings.Contains(err.Error(), "multiple variants set") { + t.Fatalf("double action marshal = %v", err) + } +} diff --git a/go/protocols/mpp/intents/session_test.go b/go/protocols/mpp/intents/session_test.go new file mode 100644 index 000000000..66e168110 --- /dev/null +++ b/go/protocols/mpp/intents/session_test.go @@ -0,0 +1,1032 @@ +package intents + +import ( + "encoding/binary" + "encoding/json" + "strings" + "testing" + + "github.com/mr-tron/base58" +) + +func ptrStr(s string) *string { return &s } +func ptrU8(v uint8) *uint8 { return &v } + +// ── SessionMode / strategy / status serde ── + +func TestSessionModeSerialization(t *testing.T) { + tests := []struct { + mode SessionMode + want string + }{ + {SessionModePush, `"push"`}, + {SessionModePull, `"pull"`}, + } + for _, tc := range tests { + t.Run(string(tc.mode), func(t *testing.T) { + got, err := json.Marshal(tc.mode) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(got) != tc.want { + t.Fatalf("got %s want %s", got, tc.want) + } + var back SessionMode + if err := json.Unmarshal(got, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back != tc.mode { + t.Fatalf("roundtrip got %q want %q", back, tc.mode) + } + }) + } +} + +func TestSessionPullVoucherStrategySerialization(t *testing.T) { + tests := []struct { + strategy SessionPullVoucherStrategy + want string + }{ + {SessionPullVoucherStrategyClientVoucher, `"clientVoucher"`}, + {SessionPullVoucherStrategyOperatedVoucher, `"operatedVoucher"`}, + } + for _, tc := range tests { + t.Run(string(tc.strategy), func(t *testing.T) { + got, err := json.Marshal(tc.strategy) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(got) != tc.want { + t.Fatalf("got %s want %s", got, tc.want) + } + var back SessionPullVoucherStrategy + if err := json.Unmarshal(got, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back != tc.strategy { + t.Fatalf("roundtrip got %q want %q", back, tc.strategy) + } + }) + } +} + +func TestCommitStatusSerialization(t *testing.T) { + tests := []struct { + status CommitStatus + want string + }{ + {CommitStatusCommitted, `"committed"`}, + {CommitStatusReplayed, `"replayed"`}, + } + for _, tc := range tests { + t.Run(string(tc.status), func(t *testing.T) { + got, err := json.Marshal(tc.status) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if string(got) != tc.want { + t.Fatalf("got %s want %s", got, tc.want) + } + }) + } +} + +func TestDefaultSessionExpiresAt(t *testing.T) { + if DefaultSessionExpiresAt != 4_102_444_800 { + t.Fatalf("got %d", DefaultSessionExpiresAt) + } +} + +// ── SessionRequest ── + +func TestSessionRequestRoundtrip(t *testing.T) { + decimals := uint8(6) + req := SessionRequest{ + Cap: "10000000", + Currency: "USDC", + Decimals: &decimals, + Network: ptrStr("mainnet"), + Operator: "CXhrFZJLKqjzmP3sjYLcF4dTeXWKCy9e2SXXZ2Yo6MPY", + Recipient: "CXhrFZJLKqjzmP3sjYLcF4dTeXWKCy9e2SXXZ2Yo6MPY", + Description: ptrStr("API session"), + Modes: []SessionMode{SessionModePush}, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var back SessionRequest + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Cap != "10000000" || back.Currency != "USDC" { + t.Fatalf("unexpected back: %#v", back) + } + if back.Description == nil || *back.Description != "API session" { + t.Fatalf("description: %v", back.Description) + } + if len(back.Modes) != 1 || back.Modes[0] != SessionModePush { + t.Fatalf("modes: %v", back.Modes) + } +} + +func TestSessionRequestOmitsEmptyFields(t *testing.T) { + req := SessionRequest{ + Cap: "1000", + Currency: "USDC", + Operator: "op", + Recipient: "rec", + Splits: nil, + Modes: nil, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + for _, key := range []string{"splits", "modes", "decimals", "network", "description", "externalId", "programId", "minVoucherDelta", "pullVoucherStrategy", "recentBlockhash"} { + if strings.Contains(js, key) { + t.Fatalf("expected %q omitted, got %s", key, js) + } + } + for _, key := range []string{"cap", "currency", "operator", "recipient"} { + if !strings.Contains(js, key) { + t.Fatalf("expected %q present, got %s", key, js) + } + } +} + +func TestSessionRequestEmptySplitsOmitted(t *testing.T) { + req := SessionRequest{Cap: "1", Currency: "USDC", Operator: "op", Recipient: "rec", Splits: []SessionSplit{}} + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if strings.Contains(string(data), "splits") { + t.Fatalf("empty splits should be omitted: %s", data) + } +} + +func TestSessionRequestWithModesPushAndPull(t *testing.T) { + strategy := SessionPullVoucherStrategyClientVoucher + req := SessionRequest{ + Cap: "1000", + Currency: "USDC", + Operator: "op", + Recipient: "rec", + Modes: []SessionMode{SessionModePush, SessionModePull}, + PullVoucherStrategy: &strategy, + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if !strings.Contains(js, `"push"`) || !strings.Contains(js, `"pull"`) { + t.Fatalf("modes missing: %s", js) + } + if !strings.Contains(js, `"pullVoucherStrategy":"clientVoucher"`) { + t.Fatalf("strategy missing: %s", js) + } + var back SessionRequest + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(back.Modes) != 2 || back.Modes[0] != SessionModePush || back.Modes[1] != SessionModePull { + t.Fatalf("modes: %v", back.Modes) + } + if back.PullVoucherStrategy == nil || *back.PullVoucherStrategy != SessionPullVoucherStrategyClientVoucher { + t.Fatalf("strategy: %v", back.PullVoucherStrategy) + } +} + +func TestSessionRequestWithSplits(t *testing.T) { + req := SessionRequest{ + Cap: "1000", + Currency: "USDC", + Operator: "op", + Recipient: "rec", + Splits: []SessionSplit{{Recipient: "s1", BPS: 100}, {Recipient: "s2", BPS: 200}}, + ProgramID: ptrStr("prog123"), + ExternalID: ptrStr("ref-1"), + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var back SessionRequest + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(back.Splits) != 2 || back.Splits[0].BPS != 100 { + t.Fatalf("splits: %v", back.Splits) + } + if back.ProgramID == nil || *back.ProgramID != "prog123" { + t.Fatalf("programId: %v", back.ProgramID) + } + if back.ExternalID == nil || *back.ExternalID != "ref-1" { + t.Fatalf("externalId: %v", back.ExternalID) + } +} + +func TestSessionRequestWithMinVoucherDelta(t *testing.T) { + req := SessionRequest{ + Cap: "10000000", + Currency: "USDC", + Decimals: ptrU8(6), + Network: ptrStr("mainnet"), + Operator: "op", + Recipient: "rec", + MinVoucherDelta: ptrStr("500"), + } + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if !strings.Contains(string(data), `"minVoucherDelta"`) { + t.Fatalf("minVoucherDelta missing: %s", data) + } + var back SessionRequest + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.MinVoucherDelta == nil || *back.MinVoucherDelta != "500" { + t.Fatalf("minVoucherDelta: %v", back.MinVoucherDelta) + } +} + +func TestSessionRequestRecentBlockhashRoundtrip(t *testing.T) { + req := SessionRequest{Cap: "1", Currency: "USDC", Operator: "op", Recipient: "rec", RecentBlockhash: ptrStr("bh1")} + data, err := json.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if !strings.Contains(string(data), `"recentBlockhash":"bh1"`) { + t.Fatalf("recentBlockhash missing: %s", data) + } + var back SessionRequest + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.RecentBlockhash == nil || *back.RecentBlockhash != "bh1" { + t.Fatalf("recentBlockhash: %v", back.RecentBlockhash) + } +} + +// ── OpenPayload constructors ── + +func TestOpenPayloadPushFields(t *testing.T) { + p := OpenPayloadPush("chan1", "1000000", "signer1", "txsig") + if p.Mode != SessionModePush { + t.Fatalf("mode: %q", p.Mode) + } + if p.ChannelID == nil || *p.ChannelID != "chan1" { + t.Fatalf("channelId: %v", p.ChannelID) + } + if p.Deposit == nil || *p.Deposit != "1000000" { + t.Fatalf("deposit: %v", p.Deposit) + } + if p.TokenAccount != nil || p.ApprovedAmount != nil { + t.Fatal("pull fields should be nil") + } + if p.AuthorizedSigner != "signer1" || p.Signature != "txsig" { + t.Fatalf("shared fields: %#v", p) + } +} + +func TestOpenPayloadPullFields(t *testing.T) { + p := OpenPayloadPull("tokacct", "5000000", "wallet1", "signer1", "approvesig") + if p.Mode != SessionModePull { + t.Fatalf("mode: %q", p.Mode) + } + if p.ChannelID != nil || p.Deposit != nil { + t.Fatal("push fields should be nil") + } + if p.TokenAccount == nil || *p.TokenAccount != "tokacct" { + t.Fatalf("tokenAccount: %v", p.TokenAccount) + } + if p.ApprovedAmount == nil || *p.ApprovedAmount != "5000000" { + t.Fatalf("approvedAmount: %v", p.ApprovedAmount) + } + if p.Owner == nil || *p.Owner != "wallet1" { + t.Fatalf("owner: %v", p.Owner) + } +} + +func TestOpenPayloadPaymentChannelAndTxHelpers(t *testing.T) { + p := OpenPayloadPaymentChannel("chan1", "1000000", "payer1", "payee1", "mint1", 99, 45, "signer1", "txsig"). + WithTransaction("open-tx"). + WithInitTx("init-tx"). + WithUpdateTx("update-tx") + + if p.Mode != SessionModePush { + t.Fatalf("mode: %q", p.Mode) + } + id, err := p.SessionID() + if err != nil || id != "chan1" { + t.Fatalf("sessionID: %q %v", id, err) + } + dep, err := p.DepositAmount() + if err != nil || dep != 1_000_000 { + t.Fatalf("depositAmount: %d %v", dep, err) + } + if p.Payer == nil || *p.Payer != "payer1" { + t.Fatalf("payer: %v", p.Payer) + } + if p.Payee == nil || *p.Payee != "payee1" { + t.Fatalf("payee: %v", p.Payee) + } + if p.Mint == nil || *p.Mint != "mint1" { + t.Fatalf("mint: %v", p.Mint) + } + if p.Salt == nil || *p.Salt != 99 { + t.Fatalf("salt: %v", p.Salt) + } + if p.GracePeriod == nil || *p.GracePeriod != 45 { + t.Fatalf("gracePeriod: %v", p.GracePeriod) + } + if p.Transaction == nil || *p.Transaction != "open-tx" { + t.Fatalf("transaction: %v", p.Transaction) + } + if p.InitMultiDelegateTx == nil || *p.InitMultiDelegateTx != "init-tx" { + t.Fatalf("init: %v", p.InitMultiDelegateTx) + } + if p.UpdateDelegationTx == nil || *p.UpdateDelegationTx != "update-tx" { + t.Fatalf("update: %v", p.UpdateDelegationTx) + } +} + +func TestOpenPayloadPullPaymentChannelUsesChannelIDAndDeposit(t *testing.T) { + p := OpenPayloadPaymentChannelWithMode(SessionModePull, "chan1", "1000000", "payer1", "payee1", "mint1", 99, 45, "signer1", "pending"). + WithTransaction("open-tx") + if p.Mode != SessionModePull { + t.Fatalf("mode: %q", p.Mode) + } + id, err := p.SessionID() + if err != nil || id != "chan1" { + t.Fatalf("sessionID: %q %v", id, err) + } + dep, err := p.DepositAmount() + if err != nil || dep != 1_000_000 { + t.Fatalf("depositAmount: %d %v", dep, err) + } + if p.TokenAccount != nil || p.ApprovedAmount != nil { + t.Fatal("token fields should be nil") + } + if p.Transaction == nil || *p.Transaction != "open-tx" { + t.Fatalf("transaction: %v", p.Transaction) + } +} + +func TestOpenPayloadPushSessionIDAndDeposit(t *testing.T) { + p := OpenPayloadPush("chan1", "2000000", "s", "sig") + id, err := p.SessionID() + if err != nil || id != "chan1" { + t.Fatalf("sessionID: %q %v", id, err) + } + dep, err := p.DepositAmount() + if err != nil || dep != 2_000_000 { + t.Fatalf("depositAmount: %d %v", dep, err) + } +} + +func TestOpenPayloadPullSessionIDAndDeposit(t *testing.T) { + p := OpenPayloadPull("tokacct", "3000000", "wallet1", "s", "sig") + id, err := p.SessionID() + if err != nil || id != "tokacct" { + t.Fatalf("sessionID: %q %v", id, err) + } + dep, err := p.DepositAmount() + if err != nil || dep != 3_000_000 { + t.Fatalf("depositAmount: %d %v", dep, err) + } +} + +func TestOpenPayloadMissingRequiredFieldsAndInvalidDeposit(t *testing.T) { + push := OpenPayloadPush("chan1", "bad", "s", "sig") + if _, err := push.DepositAmount(); err == nil { + t.Fatal("expected invalid deposit error") + } + push.Deposit = nil + if _, err := push.DepositAmount(); err == nil { + t.Fatal("expected missing deposit error") + } + push.ChannelID = nil + if _, err := push.SessionID(); err == nil { + t.Fatal("expected missing channelId error") + } + + pull := OpenPayloadPull("tokacct", "bad", "wallet", "s", "sig") + if _, err := pull.DepositAmount(); err == nil { + t.Fatal("expected invalid pull deposit error") + } + pull.ApprovedAmount = nil + if _, err := pull.DepositAmount(); err == nil { + t.Fatal("expected missing approvedAmount error") + } + pull.TokenAccount = nil + if _, err := pull.SessionID(); err == nil { + t.Fatal("expected missing tokenAccount error") + } +} + +func TestOpenPayloadPushRoundtripJSON(t *testing.T) { + p := OpenPayloadPush("chan1", "1000000", "signer1", "txsig") + data, err := json.Marshal(p) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if !strings.Contains(js, `"mode":"push"`) { + t.Fatalf("mode missing: %s", js) + } + if !strings.Contains(js, `"channelId":"chan1"`) { + t.Fatalf("channelId missing: %s", js) + } + if strings.Contains(js, "tokenAccount") { + t.Fatalf("tokenAccount should be omitted: %s", js) + } + var back OpenPayload + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Mode != SessionModePush || back.ChannelID == nil || *back.ChannelID != "chan1" { + t.Fatalf("back: %#v", back) + } +} + +func TestOpenPayloadPullRoundtripJSON(t *testing.T) { + p := OpenPayloadPull("tokacct", "5000000", "wallet1", "signer1", "approvesig") + data, err := json.Marshal(p) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if !strings.Contains(js, `"mode":"pull"`) { + t.Fatalf("mode missing: %s", js) + } + if !strings.Contains(js, `"tokenAccount":"tokacct"`) { + t.Fatalf("tokenAccount missing: %s", js) + } + if !strings.Contains(js, `"owner":"wallet1"`) { + t.Fatalf("owner missing: %s", js) + } + if strings.Contains(js, "channelId") { + t.Fatalf("channelId should be omitted: %s", js) + } + var back OpenPayload + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Mode != SessionModePull || back.TokenAccount == nil || *back.TokenAccount != "tokacct" { + t.Fatalf("back: %#v", back) + } + if back.Owner == nil || *back.Owner != "wallet1" { + t.Fatalf("owner: %v", back.Owner) + } +} + +func TestOpenPayloadSaltSerializesAsStringAndAcceptsNumber(t *testing.T) { + const salt = ^uint64(0) - 7 // u64::MAX - 7 + p := OpenPayloadPaymentChannel("chan1", "1000000", "payer1", "payee1", "mint1", salt, 900, "signer1", "txsig") + data, err := json.Marshal(p) + if err != nil { + t.Fatalf("marshal: %v", err) + } + want := `"salt":"18446744073709551608"` + if !strings.Contains(string(data), want) { + t.Fatalf("salt not a decimal string: %s", data) + } + var back OpenPayload + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Salt == nil || *back.Salt != salt { + t.Fatalf("salt roundtrip: %v", back.Salt) + } + + legacy := `{"mode":"push","channelId":"chan1","deposit":"1000000","payer":"payer1","payee":"payee1","mint":"mint1","salt":42,"gracePeriod":900,"authorizedSigner":"signer1","signature":"txsig"}` + var legacyBack OpenPayload + if err := json.Unmarshal([]byte(legacy), &legacyBack); err != nil { + t.Fatalf("legacy unmarshal: %v", err) + } + if legacyBack.Salt == nil || *legacyBack.Salt != 42 { + t.Fatalf("legacy salt: %v", legacyBack.Salt) + } +} + +func TestOpenPayloadSaltBigNumberDecodesWithoutPrecisionLoss(t *testing.T) { + // A u64 number larger than 2^53 must survive number-form decode. + legacy := `{"mode":"push","channelId":"c","deposit":"1","salt":18446744073709551608,"authorizedSigner":"s","signature":"sig"}` + var p OpenPayload + if err := json.Unmarshal([]byte(legacy), &p); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if p.Salt == nil || *p.Salt != ^uint64(0)-7 { + t.Fatalf("big number salt: %v", p.Salt) + } +} + +func TestOpenPayloadSaltInvalid(t *testing.T) { + cases := []string{ + `{"mode":"push","channelId":"c","deposit":"1","salt":"notanumber","authorizedSigner":"s","signature":"sig"}`, + `{"mode":"push","channelId":"c","deposit":"1","salt":-1,"authorizedSigner":"s","signature":"sig"}`, + `{"mode":"push","channelId":"c","deposit":"1","salt":true,"authorizedSigner":"s","signature":"sig"}`, + } + for _, c := range cases { + var p OpenPayload + if err := json.Unmarshal([]byte(c), &p); err == nil { + t.Fatalf("expected error decoding %s", c) + } + } +} + +func TestOpenPayloadSaltNullAndAbsent(t *testing.T) { + null := `{"mode":"push","channelId":"c","deposit":"1","salt":null,"authorizedSigner":"s","signature":"sig"}` + var p OpenPayload + if err := json.Unmarshal([]byte(null), &p); err != nil { + t.Fatalf("unmarshal null: %v", err) + } + if p.Salt != nil { + t.Fatalf("salt should be nil for null, got %v", p.Salt) + } + absent := `{"mode":"push","channelId":"c","deposit":"1","authorizedSigner":"s","signature":"sig"}` + var q OpenPayload + if err := json.Unmarshal([]byte(absent), &q); err != nil { + t.Fatalf("unmarshal absent: %v", err) + } + if q.Salt != nil { + t.Fatalf("salt should be nil when absent, got %v", q.Salt) + } +} + +func TestOpenPayloadMissingModeFailsDecode(t *testing.T) { + js := `{"channelId":"chan1","deposit":"1000","authorizedSigner":"s","signature":"sig"}` + var p OpenPayload + if err := json.Unmarshal([]byte(js), &p); err == nil { + t.Fatal("expected missing mode error") + } +} + +func TestOpenPayloadUnknownModeSessionIDAndDeposit(t *testing.T) { + p := OpenPayload{Mode: SessionMode("weird"), AuthorizedSigner: "s", Signature: "sig"} + if _, err := p.SessionID(); err == nil { + t.Fatal("expected unknown mode sessionID error") + } + if _, err := p.DepositAmount(); err == nil { + t.Fatal("expected unknown mode deposit error") + } +} + +// ── Metering ── + +func TestMeteringAmountParsersAndUsageRoundtrip(t *testing.T) { + directive := MeteringDirective{ + DeliveryID: "d1", + SessionID: "chan1", + Amount: "not-a-number", + Currency: "USDC", + Sequence: 1, + ExpiresAt: DefaultSessionExpiresAt, + Proof: ptrStr("proof"), + } + if _, err := directive.AmountBaseUnits(); err == nil { + t.Fatal("expected invalid metering amount error") + } + + usage := MeteringUsage{DeliveryID: "d1", Amount: "42"} + data, err := json.Marshal(usage) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if !strings.Contains(string(data), `"deliveryId":"d1"`) { + t.Fatalf("deliveryId missing: %s", data) + } + var back MeteringUsage + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + v, err := back.AmountBaseUnits() + if err != nil || v != 42 { + t.Fatalf("amount: %d %v", v, err) + } + + bad := MeteringUsage{DeliveryID: "d1", Amount: "bad"} + if _, err := bad.AmountBaseUnits(); err == nil { + t.Fatal("expected bad usage amount error") + } +} + +func TestMeteringDirectiveAndEnvelopeRoundtrip(t *testing.T) { + directive := MeteringDirective{ + DeliveryID: "d1", + SessionID: "chan1", + Amount: "125", + Currency: "USDC", + Sequence: 7, + ExpiresAt: 4_102_444_800, + CommitURL: ptrStr("https://example.test/commit"), + } + v, err := directive.AmountBaseUnits() + if err != nil || v != 125 { + t.Fatalf("amount: %d %v", v, err) + } + + envelope := MeteredEnvelope[map[string]any]{ + Payload: map[string]any{"ok": true}, + Metering: directive, + } + data, err := json.Marshal(envelope) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if !strings.Contains(js, `"deliveryId":"d1"`) { + t.Fatalf("deliveryId missing: %s", js) + } + if !strings.Contains(js, `"commitUrl":"https://example.test/commit"`) { + t.Fatalf("commitUrl missing: %s", js) + } + var back MeteredEnvelope[map[string]any] + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Metering.Sequence != 7 { + t.Fatalf("sequence: %d", back.Metering.Sequence) + } + if back.Payload["ok"] != true { + t.Fatalf("payload: %v", back.Payload) + } +} + +func TestMeteringDirectiveOmitsOptionalFields(t *testing.T) { + directive := MeteringDirective{DeliveryID: "d1", SessionID: "c", Amount: "1", Currency: "USDC", Sequence: 1, ExpiresAt: 1} + data, err := json.Marshal(directive) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if strings.Contains(js, "commitUrl") || strings.Contains(js, "proof") { + t.Fatalf("optional fields should be omitted: %s", js) + } +} + +func TestCommitReceiptRoundtrip(t *testing.T) { + receipt := CommitReceipt{ + DeliveryID: "d1", + SessionID: "chan1", + Amount: "100", + Cumulative: "500", + Status: CommitStatusCommitted, + } + data, err := json.Marshal(receipt) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if !strings.Contains(string(data), `"status":"committed"`) { + t.Fatalf("status: %s", data) + } + var back CommitReceipt + if err := json.Unmarshal(data, &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Status != CommitStatusCommitted || back.Cumulative != "500" { + t.Fatalf("back: %#v", back) + } +} + +// ── SessionAction variants ── + +func mustMarshal(t *testing.T, v any) string { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + return string(data) +} + +func TestSessionActionOpenPushRoundtrip(t *testing.T) { + action := NewOpenAction(OpenPayloadPush("chan123", "5000000", "signer123", "sig456")) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"open"`) { + t.Fatalf("action tag missing: %s", js) + } + if !strings.Contains(js, `"mode":"push"`) { + t.Fatalf("mode missing: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Open == nil { + t.Fatal("expected Open variant") + } + if back.Open.Mode != SessionModePush { + t.Fatalf("mode: %q", back.Open.Mode) + } + id, err := back.Open.SessionID() + if err != nil || id != "chan123" { + t.Fatalf("sessionID: %q %v", id, err) + } + dep, err := back.Open.DepositAmount() + if err != nil || dep != 5_000_000 { + t.Fatalf("deposit: %d %v", dep, err) + } + if back.Open.AuthorizedSigner != "signer123" { + t.Fatalf("signer: %q", back.Open.AuthorizedSigner) + } +} + +func TestSessionActionOpenPullRoundtrip(t *testing.T) { + action := NewOpenAction(OpenPayloadPull("tokacct", "3000000", "wallet1", "signer1", "approvesig")) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"open"`) { + t.Fatalf("action tag missing: %s", js) + } + if !strings.Contains(js, `"mode":"pull"`) || !strings.Contains(js, "tokenAccount") { + t.Fatalf("pull fields missing: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Open == nil || back.Open.Mode != SessionModePull { + t.Fatalf("back: %#v", back.Open) + } + id, _ := back.Open.SessionID() + if id != "tokacct" { + t.Fatalf("sessionID: %q", id) + } +} + +func TestSessionActionVoucherRoundtrip(t *testing.T) { + nonce := uint64(3) + action := NewVoucherAction(VoucherPayload{ + Voucher: SignedVoucher{ + Data: VoucherData{ChannelID: "chan1", Cumulative: "500000", ExpiresAt: 1 << 62, Nonce: &nonce}, + Signature: "sig_here", + }, + }) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"voucher"`) { + t.Fatalf("action tag missing: %s", js) + } + if !strings.Contains(js, `"cumulativeAmount":"500000"`) { + t.Fatalf("cumulativeAmount missing: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Voucher == nil { + t.Fatal("expected Voucher variant") + } + if back.Voucher.Voucher.Data.Cumulative != "500000" { + t.Fatalf("cumulative: %q", back.Voucher.Voucher.Data.Cumulative) + } + if back.Voucher.Voucher.Data.Nonce == nil || *back.Voucher.Voucher.Data.Nonce != 3 { + t.Fatalf("nonce: %v", back.Voucher.Voucher.Data.Nonce) + } +} + +func TestSessionActionCommitRoundtrip(t *testing.T) { + nonce := uint64(3) + action := NewCommitAction(CommitPayload{ + DeliveryID: "delivery-1", + Voucher: SignedVoucher{ + Data: VoucherData{ChannelID: "chan1", Cumulative: "500000", ExpiresAt: 1 << 62, Nonce: &nonce}, + Signature: "sig_here", + }, + }) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"commit"`) { + t.Fatalf("action tag missing: %s", js) + } + if !strings.Contains(js, `"deliveryId":"delivery-1"`) { + t.Fatalf("deliveryId missing: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Commit == nil || back.Commit.DeliveryID != "delivery-1" { + t.Fatalf("back: %#v", back.Commit) + } + if back.Commit.Voucher.Data.Cumulative != "500000" { + t.Fatalf("cumulative: %q", back.Commit.Voucher.Data.Cumulative) + } +} + +func TestSessionActionTopUpRoundtrip(t *testing.T) { + action := NewTopUpAction(TopUpPayload{ChannelID: "chan1", NewDeposit: "9000000", Signature: "txsig"}) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"topUp"`) { + t.Fatalf("expected topUp camelCase tag: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.TopUp == nil || back.TopUp.NewDeposit != "9000000" || back.TopUp.Signature != "txsig" { + t.Fatalf("back: %#v", back.TopUp) + } +} + +func TestSessionActionCloseNoVoucherRoundtrip(t *testing.T) { + action := NewCloseAction(ClosePayload{ChannelID: "chan1"}) + js := mustMarshal(t, action) + if !strings.Contains(js, `"action":"close"`) { + t.Fatalf("action tag missing: %s", js) + } + if strings.Contains(js, "voucher") { + t.Fatalf("voucher should be omitted: %s", js) + } + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Close == nil || back.Close.Voucher != nil { + t.Fatalf("back: %#v", back.Close) + } +} + +func TestSessionActionCloseWithVoucherRoundtrip(t *testing.T) { + nonce := uint64(7) + action := NewCloseAction(ClosePayload{ + ChannelID: "chan1", + Voucher: &SignedVoucher{ + Data: VoucherData{ChannelID: "chan1", Cumulative: "700000", ExpiresAt: 1 << 62, Nonce: &nonce}, + Signature: "final_sig", + }, + }) + js := mustMarshal(t, action) + var back SessionAction + if err := json.Unmarshal([]byte(js), &back); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if back.Close == nil || back.Close.Voucher == nil { + t.Fatalf("back: %#v", back.Close) + } + if back.Close.Voucher.Data.Cumulative != "700000" { + t.Fatalf("cumulative: %q", back.Close.Voucher.Data.Cumulative) + } +} + +func TestSessionActionMarshalErrors(t *testing.T) { + var empty SessionAction + if _, err := json.Marshal(empty); err == nil { + t.Fatal("expected error marshaling empty action") + } + multi := SessionAction{Open: &OpenPayload{Mode: SessionModePush}, Close: &ClosePayload{ChannelID: "c"}} + if _, err := json.Marshal(multi); err == nil { + t.Fatal("expected error marshaling multi-variant action") + } +} + +func TestSessionActionUnmarshalErrors(t *testing.T) { + cases := []string{ + `{"channelId":"c"}`, // missing action + `{"action":"bogus","channelId":"c"}`, // unknown action + `not json`, // malformed + `{"action":"open"}`, // open missing mode + } + for _, c := range cases { + var a SessionAction + if err := json.Unmarshal([]byte(c), &a); err == nil { + t.Fatalf("expected error decoding %q", c) + } + } +} + +// ── VoucherData ── + +func TestVoucherDataCumulativeAlias(t *testing.T) { + js := `{"channelId":"chan1","cumulative":"123","expiresAt":42}` + var v VoucherData + if err := json.Unmarshal([]byte(js), &v); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if v.Cumulative != "123" { + t.Fatalf("alias not honored: %q", v.Cumulative) + } + // Canonical key still works and takes precedence when both present. + both := `{"channelId":"chan1","cumulativeAmount":"999","cumulative":"123","expiresAt":42}` + var v2 VoucherData + if err := json.Unmarshal([]byte(both), &v2); err != nil { + t.Fatalf("unmarshal both: %v", err) + } + if v2.Cumulative != "999" { + t.Fatalf("canonical should win: %q", v2.Cumulative) + } +} + +func TestVoucherDataMarshalUsesCumulativeAmount(t *testing.T) { + v := VoucherData{ChannelID: "c", Cumulative: "100", ExpiresAt: 42} + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + js := string(data) + if !strings.Contains(js, `"cumulativeAmount":"100"`) { + t.Fatalf("expected cumulativeAmount: %s", js) + } + if strings.Contains(js, `"cumulative":`) { + t.Fatalf("should not emit cumulative alias: %s", js) + } +} + +func TestVoucherDataMessageBytesWithNonce(t *testing.T) { + raw := make([]byte, 32) + for i := range raw { + raw[i] = 3 + } + channelID := base58.Encode(raw) + nonce := uint64(1) + data := VoucherData{ChannelID: channelID, Cumulative: "1000", ExpiresAt: 42, Nonce: &nonce} + bytes, err := data.MessageBytes() + if err != nil { + t.Fatalf("messageBytes: %v", err) + } + if len(bytes) != 48 { + t.Fatalf("len: %d", len(bytes)) + } + decoded, _ := base58.Decode(channelID) + if string(bytes[:32]) != string(decoded) { + t.Fatal("channelId prefix mismatch") + } + var cumWant [8]byte + binary.LittleEndian.PutUint64(cumWant[:], 1000) + if string(bytes[32:40]) != string(cumWant[:]) { + t.Fatal("cumulative LE mismatch") + } + var expWant [8]byte + binary.LittleEndian.PutUint64(expWant[:], 42) + if string(bytes[40:48]) != string(expWant[:]) { + t.Fatal("expiresAt LE mismatch") + } +} + +func TestVoucherDataMessageBytesWithoutNonce(t *testing.T) { + raw := make([]byte, 32) + for i := range raw { + raw[i] = 4 + } + data := VoucherData{ChannelID: base58.Encode(raw), Cumulative: "1000", ExpiresAt: 42} + bytes, err := data.MessageBytes() + if err != nil { + t.Fatalf("messageBytes: %v", err) + } + if len(bytes) != 48 { + t.Fatalf("len: %d", len(bytes)) + } +} + +func TestVoucherDataMessageBytesDeterministicAndDiffersByCumulative(t *testing.T) { + raw := make([]byte, 32) + for i := range raw { + raw[i] = 6 + } + channelID := base58.Encode(raw) + a := VoucherData{ChannelID: channelID, Cumulative: "100", ExpiresAt: 42} + a2 := VoucherData{ChannelID: channelID, Cumulative: "100", ExpiresAt: 42} + b := VoucherData{ChannelID: channelID, Cumulative: "200", ExpiresAt: 42} + ab, _ := a.MessageBytes() + a2b, _ := a2.MessageBytes() + bb, _ := b.MessageBytes() + if string(ab) != string(a2b) { + t.Fatal("expected deterministic bytes") + } + if string(ab) == string(bb) { + t.Fatal("expected different bytes for different cumulative") + } +} + +func TestVoucherDataMessageBytesErrors(t *testing.T) { + // Non-base58 channelId. + bad := VoucherData{ChannelID: "0OIl", Cumulative: "1", ExpiresAt: 1} + if _, err := bad.MessageBytes(); err == nil { + t.Fatal("expected invalid channelId error") + } + // Channel id not 32 bytes. + short := VoucherData{ChannelID: base58.Encode([]byte{1, 2, 3}), Cumulative: "1", ExpiresAt: 1} + if _, err := short.MessageBytes(); err == nil { + t.Fatal("expected 32-byte length error") + } + // Invalid cumulative. + rawb := make([]byte, 32) + good := VoucherData{ChannelID: base58.Encode(rawb), Cumulative: "notnum", ExpiresAt: 1} + if _, err := good.MessageBytes(); err == nil { + t.Fatal("expected invalid cumulative error") + } +} + +func TestSignedVoucherFields(t *testing.T) { + v := SignedVoucher{ + Data: VoucherData{ChannelID: "c", Cumulative: "100", ExpiresAt: 1 << 62}, + Signature: "abc123", + } + if v.Data.Cumulative != "100" || v.Signature != "abc123" { + t.Fatalf("fields: %#v", v) + } +} diff --git a/go/protocols/mpp/server/session.go b/go/protocols/mpp/server/session.go new file mode 100644 index 000000000..93eda55ec --- /dev/null +++ b/go/protocols/mpp/server/session.go @@ -0,0 +1,716 @@ +package server + +// Server-side session intent: challenge issuance, voucher verification, and +// channel lifecycle management. +// +// 1. The server calls SessionServer.BuildChallengeRequest to produce the +// SessionRequest embedded in a 402 challenge. +// 2. The client responds with an open action; the server calls +// SessionServer.ProcessOpen to record the channel. +// 3. For each subsequent API call the client attaches a voucher action; the +// server calls SessionServer.VerifyVoucher to validate and advance the +// settled watermark atomically. +// 4. At session end the client (or server) triggers close via +// SessionServer.ProcessClose; on-chain settlement is driven by the host +// once the close-pending state is recorded. +// +// On-chain verification is a seam in this layer: when +// SessionConfig.VerifyOpenTx / VerifyTopUpTx are set, ProcessOpen (push mode) +// and ProcessTopUp invoke them before persisting channel state, binding the +// payload to the attached transaction and confirming the signature on-chain. +// When nil, the transaction signature and +// deposit amount are trusted as provided, which is suitable only for unit +// tests or deployments that verify transactions out of band. + +import ( + "context" + "fmt" + "math" + "strconv" + "time" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// Split is a payment split committed at channel open; distributed at close. +type Split struct { + // Recipient of this split. + Recipient solana.PublicKey + + // BPS is the share in basis points. + BPS uint16 +} + +// SessionTxVerifier confirms an on-chain transaction referenced by a session +// payload before channel state is persisted. Implementations typically decode +// the attached transaction, bind the payload signature to it, and confirm the +// signature on-chain. This is the seam the on-chain layer plugs into; nil +// skips verification. +type SessionTxVerifier[P any] func(ctx context.Context, payload *P) error + +// SessionConfig is the server configuration for the session intent. +type SessionConfig struct { + // Operator public key (base58). Shown to clients in the challenge. + Operator string + + // Recipient is the primary payment recipient (base58). + Recipient string + + // Splits are optional splits routed to specific recipients at close. + Splits []Split + + // MaxCap is the maximum cap the server will offer per session (base + // units). Clients may request a lower cap but not a higher one. + MaxCap uint64 + + // Currency identifier (e.g., "USDC", mint address). + Currency string + + // Decimals is the token decimals (default 6 for USDC). + Decimals uint8 + + // Network is the Solana network: "mainnet", "devnet", "localnet". + Network string + + // ProgramID is the payment-channel program ID. Nil defaults to the + // canonical program. + ProgramID *solana.PublicKey + + // MinVoucherDelta is the minimum voucher increment (base units). 0 = no + // minimum. + MinVoucherDelta uint64 + + // Modes are the session modes this server accepts, advertised to clients + // in the 402 challenge. An empty list or [push] means only the + // payment-channel push mode is supported. + Modes []intents.SessionMode + + // PullVoucherStrategy is the voucher authority used for pull sessions. + // Required when Modes includes pull. + PullVoucherStrategy *intents.SessionPullVoucherStrategy + + // VerifyOpenTx, when set, confirms the open transaction on-chain (push + // mode) before ProcessOpen persists channel state. See SessionTxVerifier. + VerifyOpenTx SessionTxVerifier[intents.OpenPayload] + + // VerifyTopUpTx, when set, confirms the top-up transaction on-chain + // before ProcessTopUp raises the deposit. See SessionTxVerifier. + VerifyTopUpTx SessionTxVerifier[intents.TopUpPayload] +} + +// DeliveryRequest is a request to reserve a metered delivery for client-side +// ack/commit. Zero values mean "absent" for the optional fields. +type DeliveryRequest struct { + // SessionID is the channel/session ID that will pay for the delivery. + SessionID string + + // Amount owed for this delivery in base units. + Amount uint64 + + // DeliveryID is an optional idempotency key. When empty the server + // derives ":". + DeliveryID string + + // CommitURL is an optional commit endpoint hint surfaced to the client. + CommitURL string + + // Proof is an optional opaque proof surfaced to the client. + Proof string + + // ExpiresAt is an optional directive expiry (Unix seconds). Zero defaults + // to intents.DefaultSessionExpiresAt. + ExpiresAt int64 +} + +// SessionServer is the server-side session manager. Pluggable over the +// channel store to support in-memory testing and production persistence +// backends. +type SessionServer struct { + // config is the immutable server configuration captured at construction. + config SessionConfig + + // store persists per-channel state; every mutation goes through its + // atomic UpdateChannel so voucher watermarks stay double-spend safe. + store ChannelStore +} + +// NewSessionServer creates a SessionServer over the given store. +func NewSessionServer(config SessionConfig, store ChannelStore) *SessionServer { + return &SessionServer{config: config, store: store} +} + +// Store returns the channel store backing this server, so hosts can share it +// with metering side channels. +func (s *SessionServer) Store() ChannelStore { + return s.store +} + +// BuildChallengeRequest builds the SessionRequest to embed in a 402 +// challenge. cap is the maximum this session will allow, clamped to +// SessionConfig.MaxCap. MinVoucherDelta is included only when positive, +// Modes is omitted when push-only, and PullVoucherStrategy is included only +// when pull is offered. +func (s *SessionServer) BuildChallengeRequest(cap uint64) intents.SessionRequest { + effectiveCap := min(cap, s.config.MaxCap) + decimals := s.config.Decimals + + request := intents.SessionRequest{ + Cap: strconv.FormatUint(effectiveCap, 10), + Currency: s.config.Currency, + Decimals: &decimals, + Operator: s.config.Operator, + Recipient: s.config.Recipient, + } + if s.config.Network != "" { + network := s.config.Network + request.Network = &network + } + for _, split := range s.config.Splits { + request.Splits = append(request.Splits, intents.SessionSplit{ + Recipient: split.Recipient.String(), + BPS: split.BPS, + }) + } + if s.config.ProgramID != nil { + programID := s.config.ProgramID.String() + request.ProgramID = &programID + } + if s.config.MinVoucherDelta > 0 { + minDelta := strconv.FormatUint(s.config.MinVoucherDelta, 10) + request.MinVoucherDelta = &minDelta + } + // Omit modes when only push is supported; clients assume push when modes + // is absent. + if !s.pushOnly() { + request.Modes = append([]intents.SessionMode(nil), s.config.Modes...) + } + if s.supportsMode(intents.SessionModePull) && s.config.PullVoucherStrategy != nil { + strategy := *s.config.PullVoucherStrategy + request.PullVoucherStrategy = &strategy + } + return request +} + +// pushOnly reports whether the configured modes reduce to push-only. +func (s *SessionServer) pushOnly() bool { + return len(s.config.Modes) == 0 || + (len(s.config.Modes) == 1 && s.config.Modes[0] == intents.SessionModePush) +} + +// supportsMode reports whether the server accepts mode. Empty configured +// modes mean push-only. +func (s *SessionServer) supportsMode(mode intents.SessionMode) bool { + if len(s.config.Modes) == 0 { + return mode == intents.SessionModePush + } + for _, supported := range s.config.Modes { + if supported == mode { + return true + } + } + return false +} + +// ProcessOpen processes an open action and persists the channel state. +// +// The channel is keyed by OpenPayload.SessionID (channelId first, then +// tokenAccount for pull opens). Replayed opens are idempotent: when a channel +// already exists for the session id with the same authorized signer, the +// existing state is returned unchanged and the voucher watermark is never +// reset. Opens for an existing channel are rejected when the channel is +// finalized or when the payload's authorized signer differs from the stored +// one. +func (s *SessionServer) ProcessOpen(ctx context.Context, payload *intents.OpenPayload) (ChannelState, error) { + if !s.supportsMode(payload.Mode) { + return ChannelState{}, fmt.Errorf("session mode %q is not supported by this challenge", payload.Mode) + } + + sessionID, err := payload.SessionID() + if err != nil { + return ChannelState{}, err + } + deposit, err := payload.DepositAmount() + if err != nil { + return ChannelState{}, err + } + if deposit == 0 { + return ChannelState{}, fmt.Errorf("deposit must be greater than zero") + } + if deposit > s.config.MaxCap { + return ChannelState{}, fmt.Errorf("deposit %d exceeds max cap %d", deposit, s.config.MaxCap) + } + + // On-chain verification seam (push mode only; pull-mode host integrations + // submit server-broadcast transactions or validate delegated-token state + // before invoking this lower-level store method). + if payload.Mode == intents.SessionModePush && s.config.VerifyOpenTx != nil { + if err := s.config.VerifyOpenTx(ctx, payload); err != nil { + return ChannelState{}, fmt.Errorf("open tx verification failed: %w", err) + } + } + + operator := payload.Owner + if operator == nil { + operator = payload.Payer + } + fresh := ChannelState{ + ChannelID: sessionID, + AuthorizedSigner: payload.AuthorizedSigner, + Deposit: deposit, + Operator: operator, + } + + // Atomic check-and-insert: a replayed open re-passes all checks above + // (the referenced tx is genuinely confirmed), so it MUST NOT overwrite + // existing state; that would reset the voucher watermark and erase + // accepted vouchers before close. + return s.store.UpdateChannel(ctx, sessionID, func(existing *ChannelState) (ChannelState, error) { + if existing != nil { + if existing.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", sessionID) + } + if existing.AuthorizedSigner != payload.AuthorizedSigner { + return ChannelState{}, fmt.Errorf("channel %s already exists with a different authorized signer", sessionID) + } + // Idempotent replay: keep existing state untouched. + return *existing, nil + } + return fresh, nil + }) +} + +// VerifyVoucher verifies a voucher, advances the watermark, and returns the +// new cumulative. +// +// The full ordered check sequence runs as a preflight outside the store lock +// (see VerifyVoucherForChannel), then the state-dependent checks are +// re-applied inside the atomic mutator before the watermark is persisted. +func (s *SessionServer) VerifyVoucher(ctx context.Context, payload *intents.VoucherPayload) (uint64, error) { + voucher := payload.Voucher + channelID := voucher.Data.ChannelID + + state, err := s.store.GetChannel(ctx, channelID) + if err != nil { + return 0, err + } + if state == nil { + return 0, fmt.Errorf("channel %s not found", channelID) + } + + // Preflight outside the lock (expensive signature check happens before + // touching the store). + result := VerifyVoucherForChannel(VerifyVoucherArgs{ + State: *state, + Signed: voucher, + Deposit: state.Deposit, + MinVoucherDelta: s.config.MinVoucherDelta, + }) + switch result.Status { + case VoucherVerifyRejected: + // Surface the stable reject tag ahead of the detail + // (": "). + return 0, fmt.Errorf("%s: %s", result.Reason, result.Detail) + case VoucherVerifyReplayed: + return result.NewCumulative, nil + } + + newCumulative := result.NewCumulative + newSignature := result.NewSignature + newExpiresAt := result.NewExpiresAt + + // Atomic read-modify-write: re-check everything state-dependent inside + // the mutator. + newState, err := s.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.CloseRequestedAt != nil { + return ChannelState{}, fmt.Errorf("channel %s close is pending; no further vouchers accepted", channelID) + } + // Idempotent replay inside the mutator. + if newCumulative == current.Cumulative && + current.HighestVoucherSignature != nil && + *current.HighestVoucherSignature == newSignature { + return *current, nil + } + // Concurrent watermark advancement check. + if newCumulative <= current.Cumulative { + return ChannelState{}, fmt.Errorf("concurrent update: watermark advanced") + } + next := *current + next.Cumulative = newCumulative + next.HighestVoucherSignature = &newSignature + next.HighestVoucherExpiresAt = &newExpiresAt + return next, nil + }) + if err != nil { + return 0, err + } + return newState.Cumulative, nil +} + +// ProcessTopUp processes a topUp action: atomically raise the channel's +// deposit cap. +// +// The new deposit must exceed the current deposit and must not exceed the +// configured max cap. Top-ups are rejected once the channel is finalized or a +// close has been requested. +func (s *SessionServer) ProcessTopUp(ctx context.Context, payload *intents.TopUpPayload) (ChannelState, error) { + newDeposit, err := strconv.ParseUint(payload.NewDeposit, 10, 64) + if err != nil { + return ChannelState{}, fmt.Errorf("invalid newDeposit: %s", payload.NewDeposit) + } + + // On-chain verification seam (same shape as ProcessOpen). + if s.config.VerifyTopUpTx != nil { + if err := s.config.VerifyTopUpTx(ctx, payload); err != nil { + return ChannelState{}, fmt.Errorf("top-up tx verification failed: %w", err) + } + } + + maxCap := s.config.MaxCap + channelID := payload.ChannelID + return s.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.CloseRequestedAt != nil { + return ChannelState{}, fmt.Errorf("channel %s close is pending; no further top-ups accepted", channelID) + } + if newDeposit <= current.Deposit { + return ChannelState{}, fmt.Errorf("new deposit %d must exceed current deposit %d", newDeposit, current.Deposit) + } + if newDeposit > maxCap { + return ChannelState{}, fmt.Errorf("new deposit %d exceeds max cap %d", newDeposit, maxCap) + } + next := *current + next.Deposit = newDeposit + return next, nil + }) +} + +// BeginDelivery reserves capacity for a delivered message/response and +// returns the metering directive the client must commit after processing it. +// +// The reservation requires cumulative + pendingTotal + amount <= deposit, +// assigns the next sequence, and defaults the delivery id to +// ":". +func (s *SessionServer) BeginDelivery(ctx context.Context, request DeliveryRequest) (intents.MeteringDirective, error) { + if request.Amount == 0 { + return intents.MeteringDirective{}, fmt.Errorf("delivery amount must be greater than zero") + } + + sessionID := request.SessionID + amount := request.Amount + expiresAt := request.ExpiresAt + if expiresAt == 0 { + expiresAt = intents.DefaultSessionExpiresAt + } + + var directive intents.MeteringDirective + _, err := s.store.UpdateChannel(ctx, sessionID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", sessionID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", sessionID) + } + if current.CloseRequestedAt != nil { + return ChannelState{}, fmt.Errorf("channel %s close is pending; no further deliveries accepted", sessionID) + } + pendingTotal := uint64(0) + for _, delivery := range current.PendingDeliveries { + pendingTotal += delivery.Amount + } + if !fitsInDeposit(current.Cumulative, pendingTotal, amount, current.Deposit) { + return ChannelState{}, fmt.Errorf("delivery amount %d exceeds available deposit", amount) + } + + sequence := current.NextDeliverySequence + 1 + deliveryID := request.DeliveryID + if deliveryID == "" { + deliveryID = fmt.Sprintf("%s:%d", sessionID, sequence) + } + for _, delivery := range current.PendingDeliveries { + if delivery.DeliveryID == deliveryID { + return ChannelState{}, fmt.Errorf("delivery %s already exists", deliveryID) + } + } + for _, delivery := range current.CommittedDeliveries { + if delivery.DeliveryID == deliveryID { + return ChannelState{}, fmt.Errorf("delivery %s already exists", deliveryID) + } + } + + next := *current + next.NextDeliverySequence = sequence + next.PendingDeliveries = append(next.PendingDeliveries, PendingDelivery{ + DeliveryID: deliveryID, + Amount: amount, + Sequence: sequence, + ExpiresAt: expiresAt, + }) + + directive = intents.MeteringDirective{ + DeliveryID: deliveryID, + SessionID: sessionID, + Amount: strconv.FormatUint(amount, 10), + Currency: s.config.Currency, + Sequence: sequence, + ExpiresAt: expiresAt, + } + if request.CommitURL != "" { + commitURL := request.CommitURL + directive.CommitURL = &commitURL + } + if request.Proof != "" { + proof := request.Proof + directive.Proof = &proof + } + return next, nil + }) + if err != nil { + return intents.MeteringDirective{}, err + } + return directive, nil +} + +// fitsInDeposit reports whether cumulative + pendingTotal + amount <= deposit +// without overflowing u64; any overflow is treated as exceeding the deposit. +func fitsInDeposit(cumulative, pendingTotal, amount, deposit uint64) bool { + if pendingTotal > math.MaxUint64-cumulative { + return false + } + reserved := cumulative + pendingTotal + if amount > math.MaxUint64-reserved { + return false + } + return reserved+amount <= deposit +} + +// ProcessCommit commits a reserved delivery by verifying the attached +// voucher and advancing the settled watermark. Replaying a commit for an +// already-committed delivery (same cumulative and same signature) returns the +// cached receipt with status replayed after re-verifying the voucher +// signature. +func (s *SessionServer) ProcessCommit(ctx context.Context, payload *intents.CommitPayload) (intents.CommitReceipt, error) { + channelID := payload.Voucher.Data.ChannelID + newCumulative, err := strconv.ParseUint(payload.Voucher.Data.Cumulative, 10, 64) + if err != nil { + return intents.CommitReceipt{}, fmt.Errorf("invalid cumulative in commit voucher: %s", payload.Voucher.Data.Cumulative) + } + + state, err := s.store.GetChannel(ctx, channelID) + if err != nil { + return intents.CommitReceipt{}, err + } + if state == nil { + return intents.CommitReceipt{}, fmt.Errorf("channel %s not found", channelID) + } + + // Preflight outside the lock. + if committed := findCommitted(state.CommittedDeliveries, payload.DeliveryID); committed != nil { + if committed.Cumulative == newCumulative && committed.VoucherSignature == payload.Voucher.Signature { + if err := verifySessionVoucher(payload.Voucher, state.AuthorizedSigner); err != nil { + return intents.CommitReceipt{}, err + } + return commitReceipt(payload.DeliveryID, channelID, committed.Amount, committed.Cumulative, intents.CommitStatusReplayed), nil + } + return intents.CommitReceipt{}, fmt.Errorf("delivery %s was already committed with different voucher", payload.DeliveryID) + } + pending := findPending(state.PendingDeliveries, payload.DeliveryID) + if pending == nil { + return intents.CommitReceipt{}, fmt.Errorf("delivery %s not found", payload.DeliveryID) + } + now := time.Now().Unix() + if pending.ExpiresAt <= now { + return intents.CommitReceipt{}, fmt.Errorf("delivery %s has expired", payload.DeliveryID) + } + if newCumulative <= state.Cumulative { + return intents.CommitReceipt{}, fmt.Errorf("commit cumulative %d must exceed watermark %d", newCumulative, state.Cumulative) + } + if err := verifySessionVoucher(payload.Voucher, state.AuthorizedSigner); err != nil { + return intents.CommitReceipt{}, err + } + + deliveryID := payload.DeliveryID + signature := payload.Voucher.Signature + voucherExpiresAt := payload.Voucher.Data.ExpiresAt + + var receiptAmount, receiptCumulative uint64 + var receiptStatus intents.CommitStatus + _, err = s.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.CloseRequestedAt != nil { + return ChannelState{}, fmt.Errorf("channel %s close is pending; no further commits accepted", channelID) + } + if committed := findCommitted(current.CommittedDeliveries, deliveryID); committed != nil { + if committed.Cumulative == newCumulative && committed.VoucherSignature == signature { + receiptAmount, receiptCumulative = committed.Amount, committed.Cumulative + receiptStatus = intents.CommitStatusReplayed + return *current, nil + } + return ChannelState{}, fmt.Errorf("delivery %s was already committed with different voucher", deliveryID) + } + pendingIndex := -1 + for i, delivery := range current.PendingDeliveries { + if delivery.DeliveryID == deliveryID { + pendingIndex = i + break + } + } + if pendingIndex < 0 { + return ChannelState{}, fmt.Errorf("delivery %s not found", deliveryID) + } + reserved := current.PendingDeliveries[pendingIndex] + if reserved.ExpiresAt <= now { + return ChannelState{}, fmt.Errorf("delivery %s has expired", deliveryID) + } + if newCumulative <= current.Cumulative { + return ChannelState{}, fmt.Errorf("commit cumulative %d must exceed watermark %d", newCumulative, current.Cumulative) + } + actualAmount := newCumulative - current.Cumulative + if actualAmount > reserved.Amount { + return ChannelState{}, fmt.Errorf("commit amount %d exceeds reserved amount %d", actualAmount, reserved.Amount) + } + + next := *current + next.PendingDeliveries = append( + append([]PendingDelivery(nil), current.PendingDeliveries[:pendingIndex]...), + current.PendingDeliveries[pendingIndex+1:]..., + ) + next.Cumulative = newCumulative + next.HighestVoucherSignature = &signature + next.HighestVoucherExpiresAt = &voucherExpiresAt + next.CommittedDeliveries = append(append([]CommittedDelivery(nil), current.CommittedDeliveries...), CommittedDelivery{ + DeliveryID: deliveryID, + Amount: actualAmount, + Cumulative: newCumulative, + VoucherSignature: signature, + }) + receiptAmount, receiptCumulative = actualAmount, newCumulative + receiptStatus = intents.CommitStatusCommitted + return next, nil + }) + if err != nil { + return intents.CommitReceipt{}, err + } + return commitReceipt(deliveryID, channelID, receiptAmount, receiptCumulative, receiptStatus), nil +} + +// commitReceipt builds a CommitReceipt with stringified amounts. +func commitReceipt(deliveryID, sessionID string, amount, cumulative uint64, status intents.CommitStatus) intents.CommitReceipt { + return intents.CommitReceipt{ + DeliveryID: deliveryID, + SessionID: sessionID, + Amount: strconv.FormatUint(amount, 10), + Cumulative: strconv.FormatUint(cumulative, 10), + Status: status, + } +} + +// findPending returns the pending delivery with the given id, or nil. +func findPending(deliveries []PendingDelivery, deliveryID string) *PendingDelivery { + for i := range deliveries { + if deliveries[i].DeliveryID == deliveryID { + return &deliveries[i] + } + } + return nil +} + +// findCommitted returns the committed delivery with the given id, or nil. +func findCommitted(deliveries []CommittedDelivery, deliveryID string) *CommittedDelivery { + for i := range deliveries { + if deliveries[i].DeliveryID == deliveryID { + return &deliveries[i] + } + } + return nil +} + +// ProcessClose processes a close action: atomically set close-pending and +// accept a final voucher if provided. +// +// Once CloseRequestedAt is set, vouchers, deliveries, commits, and top-ups +// are all rejected, and a second close is rejected with "close already +// requested". A non-monotonic final voucher is a hard error (unless it is an +// idempotent replay of the current highest voucher) and leaves the state +// unchanged. On-chain settlement (settle_and_finalize + distribute) is driven +// by the host after this returns; see MarkFinalized for the post-settlement +// transition. +func (s *SessionServer) ProcessClose(ctx context.Context, payload *intents.ClosePayload) (ChannelState, error) { + now := uint64(time.Now().Unix()) + channelID := payload.ChannelID + voucher := payload.Voucher + + return s.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.CloseRequestedAt != nil { + return ChannelState{}, fmt.Errorf("close already requested") + } + + next := *current + if voucher != nil { + cumulative, err := strconv.ParseUint(voucher.Data.Cumulative, 10, 64) + if err != nil { + return ChannelState{}, fmt.Errorf("invalid cumulative in final voucher: %s", voucher.Data.Cumulative) + } + if cumulative <= current.Cumulative { + // Idempotent replay of the current highest voucher is allowed; + // any other non-monotonic final voucher is a hard error. + replay := cumulative == current.Cumulative && + current.HighestVoucherSignature != nil && + *current.HighestVoucherSignature == voucher.Signature + if !replay { + return ChannelState{}, fmt.Errorf( + "final voucher cumulative %d must exceed watermark %d", cumulative, current.Cumulative) + } + if next.HighestVoucherExpiresAt == nil { + expiresAt := voucher.Data.ExpiresAt + next.HighestVoucherExpiresAt = &expiresAt + } + } else { + if cumulative > current.Deposit { + return ChannelState{}, fmt.Errorf("final voucher exceeds deposit") + } + if err := verifySessionVoucher(*voucher, current.AuthorizedSigner); err != nil { + return ChannelState{}, err + } + signature := voucher.Signature + expiresAt := voucher.Data.ExpiresAt + next.Cumulative = cumulative + next.HighestVoucherSignature = &signature + next.HighestVoucherExpiresAt = &expiresAt + } + } + closeRequestedAt := now + next.CloseRequestedAt = &closeRequestedAt + return next, nil + }) +} + +// MarkFinalized marks a channel as finalized. Call after the on-chain +// finalize transaction confirms. +func (s *SessionServer) MarkFinalized(ctx context.Context, channelID string) error { + _, err := s.store.MarkFinalized(ctx, channelID) + return err +} diff --git a/go/protocols/mpp/server/session_concurrency_test.go b/go/protocols/mpp/server/session_concurrency_test.go new file mode 100644 index 000000000..32c6c5998 --- /dev/null +++ b/go/protocols/mpp/server/session_concurrency_test.go @@ -0,0 +1,257 @@ +package server + +// Adversarial coverage of the re-check-inside-the-mutator paths: the +// preflight runs outside the store lock, so every state-dependent check must +// hold again inside the atomic mutator. These tests interleave a competing +// write between the preflight read and the mutator using a racing store +// wrapper. + +import ( + "context" + "math" + "strings" + "testing" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// racingChannelStore wraps a ChannelStore and runs interleave exactly once, +// immediately before the next UpdateChannel applies its mutator. This +// simulates a concurrent writer that slips in between a handler's preflight +// read and its atomic read-modify-write. +type racingChannelStore struct { + // ChannelStore is the wrapped real store the interleaved writes land in. + ChannelStore + + // interleave runs once immediately before the next UpdateChannel applies + // its mutator, then disarms itself. + interleave func(ctx context.Context, store ChannelStore) +} + +func (s *racingChannelStore) UpdateChannel(ctx context.Context, channelID string, mutator ChannelMutator) (ChannelState, error) { + if s.interleave != nil { + race := s.interleave + s.interleave = nil + race(ctx, s.ChannelStore) + } + return s.ChannelStore.UpdateChannel(ctx, channelID, mutator) +} + +func TestVerifyVoucherDetectsConcurrentWatermarkAdvance(t *testing.T) { + racing := &racingChannelStore{ChannelStore: NewMemoryChannelStore()} + server := NewSessionServer(sessionTestConfig(), racing) + signer, channelID := openTestChannel(t, server, 1_000_000) + + // Between the preflight and the mutator a competing voucher advances the + // watermark past this voucher's cumulative. + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + next := *current + next.Cumulative = 500 + return next, nil + }); err != nil { + t.Fatalf("interleaved update: %v", err) + } + } + + _, err := submitVoucher(t, server, signer, channelID, 100) + if err == nil || !strings.Contains(err.Error(), "concurrent update") { + t.Fatalf("err = %v, want concurrent-update rejection", err) + } +} + +func TestVerifyVoucherDetectsConcurrentClose(t *testing.T) { + racing := &racingChannelStore{ChannelStore: NewMemoryChannelStore()} + server := NewSessionServer(sessionTestConfig(), racing) + signer, channelID := openTestChannel(t, server, 1_000_000) + + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + next := *current + closeAt := uint64(1) + next.CloseRequestedAt = &closeAt + return next, nil + }); err != nil { + t.Fatalf("interleaved update: %v", err) + } + } + + _, err := submitVoucher(t, server, signer, channelID, 100) + if err == nil || !strings.Contains(err.Error(), "close is pending") { + t.Fatalf("err = %v, want close-pending rejection inside the mutator", err) + } +} + +func TestVerifyVoucherDetectsConcurrentFinalize(t *testing.T) { + racing := &racingChannelStore{ChannelStore: NewMemoryChannelStore()} + server := NewSessionServer(sessionTestConfig(), racing) + signer, channelID := openTestChannel(t, server, 1_000_000) + + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.MarkFinalized(ctx, channelID); err != nil { + t.Fatalf("interleaved finalize: %v", err) + } + } + + _, err := submitVoucher(t, server, signer, channelID, 100) + if err == nil || !strings.Contains(err.Error(), "finalized") { + t.Fatalf("err = %v, want finalized rejection inside the mutator", err) + } +} + +func TestVerifyVoucherConcurrentIdenticalReplayInsideMutator(t *testing.T) { + racing := &racingChannelStore{ChannelStore: NewMemoryChannelStore()} + server := NewSessionServer(sessionTestConfig(), racing) + signer, channelID := openTestChannel(t, server, 1_000_000) + + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + // The same voucher lands twice concurrently: the slower submission sees + // the watermark already advanced with its own signature and resolves as + // an idempotent replay instead of a concurrent-update error. + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + next := *current + next.Cumulative = 100 + signature := voucher.Signature + next.HighestVoucherSignature = &signature + expiresAt := voucher.Data.ExpiresAt + next.HighestVoucherExpiresAt = &expiresAt + return next, nil + }); err != nil { + t.Fatalf("interleaved update: %v", err) + } + } + + cumulative, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}) + if err != nil { + t.Fatalf("VerifyVoucher: %v", err) + } + if cumulative != 100 { + t.Fatalf("cumulative = %d, want 100", cumulative) + } +} + +func TestProcessCommitDetectsConcurrentReplayAndClose(t *testing.T) { + racing := &racingChannelStore{ChannelStore: NewMemoryChannelStore()} + server := NewSessionServer(sessionTestConfig(), racing) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + payload := &intents.CommitPayload{DeliveryID: directive.DeliveryID, Voucher: voucher} + + // A concurrent identical commit completes between preflight and mutator: + // the mutator resolves it as a replay using the committed-deliveries log. + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + next := *current + next.PendingDeliveries = nil + next.Cumulative = 100 + signature := voucher.Signature + next.HighestVoucherSignature = &signature + next.CommittedDeliveries = []CommittedDelivery{{ + DeliveryID: directive.DeliveryID, + Amount: 100, + Cumulative: 100, + VoucherSignature: voucher.Signature, + }} + return next, nil + }); err != nil { + t.Fatalf("interleaved update: %v", err) + } + } + receipt, err := server.ProcessCommit(context.Background(), payload) + if err != nil { + t.Fatalf("ProcessCommit: %v", err) + } + if receipt.Status != intents.CommitStatusReplayed { + t.Fatalf("status = %s, want replayed", receipt.Status) + } + + // A concurrent close between preflight and mutator rejects the commit. + directive2, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher2 := signer.SignVoucher(t, channelID, 200, farFuture()) + racing.interleave = func(ctx context.Context, store ChannelStore) { + if _, err := store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + next := *current + closeAt := uint64(1) + next.CloseRequestedAt = &closeAt + return next, nil + }); err != nil { + t.Fatalf("interleaved close: %v", err) + } + } + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive2.DeliveryID, Voucher: voucher2, + }) + if err == nil || !strings.Contains(err.Error(), "close is pending") { + t.Fatalf("err = %v, want close-pending rejection inside the mutator", err) + } +} + +func TestFitsInDepositOverflowGuards(t *testing.T) { + cases := []struct { + name string + cumulative, pendingTotal, amount, cap uint64 + want bool + }{ + {"boundary holds", 400, 500, 100, 1_000, true}, + {"one over cap", 400, 500, 101, 1_000, false}, + {"cumulative plus pending overflows", math.MaxUint64, 1, 1, math.MaxUint64, false}, + {"reserved plus amount overflows", math.MaxUint64 - 1, 1, 1, math.MaxUint64, false}, + {"max values without overflow", 0, 0, math.MaxUint64, math.MaxUint64, true}, + } + for _, tc := range cases { + if got := fitsInDeposit(tc.cumulative, tc.pendingTotal, tc.amount, tc.cap); got != tc.want { + t.Errorf("%s: fitsInDeposit(%d, %d, %d, %d) = %v, want %v", + tc.name, tc.cumulative, tc.pendingTotal, tc.amount, tc.cap, got, tc.want) + } + } +} + +func TestVerifyVoucherForChannelMalformedSignatureEncodingRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + voucher := signer.SignVoucher(t, state.ChannelID, 100, farFuture()) + voucher.Signature = "not base58!!" + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectInvalidSignature { + t.Fatalf("result = %+v, want invalid-signature rejection", result) + } +} + +func TestVerifyVoucherForChannelMalformedAuthorizedSignerRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState("not-a-pubkey") + voucher := signer.SignVoucher(t, state.ChannelID, 100, farFuture()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectInvalidSignature { + t.Fatalf("result = %+v, want invalid-signature rejection", result) + } +} + +func TestProcessCommitExpiredVoucherRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + // The directive is live but the voucher itself is expired. + expired := signer.SignVoucher(t, channelID, 100, -10) + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: expired, + }) + if err == nil || !strings.Contains(err.Error(), "voucher has expired") { + t.Fatalf("err = %v, want expired-voucher rejection", err) + } +} diff --git a/go/protocols/mpp/server/session_e2e_test.go b/go/protocols/mpp/server/session_e2e_test.go new file mode 100644 index 000000000..34fcaaf89 --- /dev/null +++ b/go/protocols/mpp/server/session_e2e_test.go @@ -0,0 +1,326 @@ +package server + +// Surfpool-gated end-to-end session lifecycle test. +// +// Exercises a real payment-channel open completed and broadcast by the +// server, metered vouchers, side-channel reserve/commit, and on-chain settle +// at close against the hosted Solana Payment Sandbox. The suite gates at +// runtime: it skips explicitly (never silently passes) when the sandbox is +// unreachable or the suite runs with -short. + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/client" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// surfpoolRPCURL resolves the sandbox RPC endpoint, honoring the harness +// override. +func surfpoolRPCURL() string { + if url := os.Getenv("MPP_HARNESS_RPC_URL"); url != "" { + return url + } + return "https://402.surfnet.dev:8899" +} + +// requireSurfpool skips the test explicitly when the sandbox is unreachable. +func requireSurfpool(t *testing.T) *rpc.Client { + t.Helper() + if testing.Short() { + t.Skip("skipping surfpool e2e in -short mode") + } + url := surfpoolRPCURL() + rpcClient := rpc.New(url) + probeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := rpcClient.GetLatestBlockhash(probeCtx, rpc.CommitmentConfirmed); err != nil { + t.Skipf("surfpool sandbox unreachable at %s: %v", url, err) + } + return rpcClient +} + +// surfnetSetAccount funds owner with lamports via the surfnet cheatcode. +func surfnetSetAccount(ctx context.Context, t *testing.T, rpcClient *rpc.Client, owner solana.PublicKey, lamports uint64) { + t.Helper() + params := []any{ + owner.String(), + map[string]any{ + "lamports": lamports, + "data": "", + "executable": false, + "owner": "11111111111111111111111111111111", + "rentEpoch": 0, + }, + } + var out json.RawMessage + if err := rpcClient.RPCCallForInto(ctx, &out, "surfnet_setAccount", params); err != nil { + t.Fatalf("surfnet_setAccount(%s): %v", owner, err) + } +} + +// surfnetSetTokenAccount provisions owner's token account via the surfnet +// cheatcode. +func surfnetSetTokenAccount(ctx context.Context, t *testing.T, rpcClient *rpc.Client, owner solana.PublicKey, mint string, amount uint64) { + t.Helper() + params := []any{ + owner.String(), + mint, + map[string]any{"amount": amount, "state": "initialized"}, + paycore.TokenProgram, + } + var out json.RawMessage + if err := rpcClient.RPCCallForInto(ctx, &out, "surfnet_setTokenAccount", params); err != nil { + t.Fatalf("surfnet_setTokenAccount(%s): %v", owner, err) + } +} + +// authedGet performs a GET with the given Authorization header and returns +// the response plus its body. +func authedGet(t *testing.T, url, authorization string) (*http.Response, string) { + t.Helper() + request, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + if authorization != "" { + request.Header.Set(core.AuthorizationHeader, authorization) + } + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("GET %s: %v", url, err) + } + body, err := io.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Fatalf("read body: %v", err) + } + return response, string(body) +} + +func TestSessionServerE2ESurfpool(t *testing.T) { + rpcClient := requireSurfpool(t) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // The operator funds fees, completes the open signature server-side, and + // receives the proceeds. + operator := testutil.NewPrivateKey() + payer := testutil.NewPrivateKey() + mint := paycore.ResolveMint("USDC", "localnet") + + surfnetSetAccount(ctx, t, rpcClient, operator.PublicKey(), 10_000_000_000) + surfnetSetAccount(ctx, t, rpcClient, payer.PublicKey(), 10_000_000_000) + surfnetSetTokenAccount(ctx, t, rpcClient, payer.PublicKey(), mint, 100_000_000) + surfnetSetTokenAccount(ctx, t, rpcClient, operator.PublicKey(), mint, 0) + + strategy := intents.SessionPullVoucherStrategyClientVoucher + session, err := NewSession(SessionOptions{ + Operator: operator.PublicKey().String(), + Recipient: operator.PublicKey().String(), + Cap: 1_000_000, // 1.00 USDC + Currency: "USDC", + Decimals: 6, + Network: "localnet", + SecretKey: "session-e2e-secret", + Realm: "e2e.test", + Modes: []intents.SessionMode{intents.SessionModePull}, + PullVoucherStrategy: &strategy, + OpenTxSubmitter: OpenTxSubmitterServer, + PaymentChannelPayerSigner: operator, + Signer: operator, + RPC: rpcClient, + }) + if err != nil { + t.Fatalf("NewSession: %v", err) + } + t.Cleanup(session.Shutdown) + + mux := http.NewServeMux() + routes := session.Routes() + mux.HandleFunc("/__402/session/deliveries", routes.Deliveries) + mux.HandleFunc("/__402/session/commit", routes.Commit) + mux.Handle("/stream", SessionMiddleware(session, func(*http.Request) (SessionChallengeOptions, error) { + return SessionChallengeOptions{Description: "Metered token stream"}, nil + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }))) + httpServer := httptest.NewServer(mux) + defer httpServer.Close() + streamURL := httpServer.URL + "/stream" + + // 1. Unauthenticated request: 402 with a session challenge carrying a + // recent blockhash from the sandbox. + response, body := authedGet(t, streamURL, "") + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("expected 402, got %d: %s", response.StatusCode, body) + } + challenge, request, err := client.ParseSessionChallenge(response.Header.Get(core.WWWAuthenticateHeader)) + if err != nil { + t.Fatalf("ParseSessionChallenge: %v", err) + } + if request.RecentBlockhash == nil { + t.Fatal("challenge missing recentBlockhash") + } + + // 2. Open: the client derives the channel, partial-signs as the payer + // against the challenge blockhash, and the server completes the fee-payer + // signature and broadcasts. + sessionSigner, err := client.NewEphemeralSessionSigner() + if err != nil { + t.Fatalf("NewEphemeralSessionSigner: %v", err) + } + opener, err := client.CreatePaymentChannelSessionOpener(request, payer, sessionSigner, "", client.PaymentChannelSessionOpenOptions{}) + if err != nil { + t.Fatalf("CreatePaymentChannelSessionOpener: %v", err) + } + openAuthorization, err := client.SerializeSessionCredential(challenge, opener.Action) + if err != nil { + t.Fatalf("serialize open credential: %v", err) + } + response, body = authedGet(t, streamURL, openAuthorization) + if response.StatusCode != http.StatusOK { + t.Fatalf("open failed: %d %s", response.StatusCode, body) + } + channelID := opener.Session.ChannelIDString() + state := mustGetChannel(t, session, channelID) + if state == nil || state.Deposit != 1_000_000 { + t.Fatalf("channel state after open = %+v", state) + } + + // The broadcast open transaction confirmed on-chain. + openReceipt, err := core.ParseReceipt(response.Header.Get(core.PaymentReceiptHeader)) + if err != nil { + t.Fatalf("parse open receipt: %v", err) + } + openSignature, err := solana.SignatureFromBase58(openReceipt.Reference) + if err != nil { + t.Fatalf("open receipt reference %q is not a signature: %v", openReceipt.Reference, err) + } + statuses, err := rpcClient.GetSignatureStatuses(ctx, true, openSignature) + if err != nil || len(statuses.Value) == 0 || statuses.Value[0] == nil || statuses.Value[0].Err != nil { + t.Fatalf("open signature %s not confirmed: %v %+v", openSignature, err, statuses) + } + + // 3. In-band voucher: advances the watermark. + voucherAction, err := opener.Session.VoucherAction(100) + if err != nil { + t.Fatalf("VoucherAction: %v", err) + } + voucherAuthorization, err := client.SerializeSessionCredential(challenge, voucherAction) + if err != nil { + t.Fatalf("serialize voucher credential: %v", err) + } + response, body = authedGet(t, streamURL, voucherAuthorization) + if response.StatusCode != http.StatusOK { + t.Fatalf("voucher failed: %d %s", response.StatusCode, body) + } + if mustGetChannel(t, session, channelID).Cumulative != 100 { + t.Fatal("voucher did not advance the watermark") + } + + // 4. Side-channel reserve + commit. + reserve := reserveDeliveryHTTP(t, httpServer.URL, map[string]any{"sessionId": channelID, "amount": "200"}) + voucher, err := opener.Session.PrepareIncrement(150) + if err != nil { + t.Fatalf("PrepareIncrement: %v", err) + } + receipt := commitDeliveryHTTP(t, httpServer.URL, reserve.DeliveryID, voucher) + if receipt.Status != intents.CommitStatusCommitted || receipt.Cumulative != "250" { + t.Fatalf("commit receipt = %+v", receipt) + } + if err := opener.Session.RecordVoucher(voucher); err != nil { + t.Fatalf("RecordVoucher: %v", err) + } + + // 5. Close: settles the highest voucher on-chain and finalizes. + closeAuthorization, err := client.SerializeSessionCredential(challenge, + intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})) + if err != nil { + t.Fatalf("serialize close credential: %v", err) + } + response, body = authedGet(t, streamURL, closeAuthorization) + if response.StatusCode != http.StatusOK { + t.Fatalf("close failed: %d %s", response.StatusCode, body) + } + state = mustGetChannel(t, session, channelID) + if !state.Finalized || state.SettledSignature == nil { + t.Fatalf("channel not settled: %+v", state) + } + settleSignature, err := solana.SignatureFromBase58(*state.SettledSignature) + if err != nil { + t.Fatalf("settled signature %q invalid: %v", *state.SettledSignature, err) + } + deadline := time.Now().Add(30 * time.Second) + for { + statuses, err := rpcClient.GetSignatureStatuses(ctx, true, settleSignature) + if err == nil && len(statuses.Value) > 0 && statuses.Value[0] != nil { + if statuses.Value[0].Err != nil { + t.Fatalf("settlement failed on-chain: %+v", statuses.Value[0].Err) + } + break + } + if time.Now().After(deadline) { + t.Fatalf("settlement %s never confirmed", settleSignature) + } + time.Sleep(time.Second) + } +} + +// reserveDeliveryHTTP reserves a delivery through the live side channel. +func reserveDeliveryHTTP(t *testing.T, baseURL string, body map[string]any) intents.MeteringDirective { + t.Helper() + directive := intents.MeteringDirective{} + postSessionJSON(t, baseURL+"/__402/session/deliveries", body, &directive) + return directive +} + +// commitDeliveryHTTP commits a delivery through the live side channel. +func commitDeliveryHTTP(t *testing.T, baseURL, deliveryID string, voucher intents.SignedVoucher) intents.CommitReceipt { + t.Helper() + receipt := intents.CommitReceipt{} + postSessionJSON(t, baseURL+"/__402/session/commit", map[string]any{ + "deliveryId": deliveryID, + "voucher": voucher, + }, &receipt) + return receipt +} + +// postSessionJSON POSTs a JSON body and decodes the 200 response into out. +func postSessionJSON(t *testing.T, url string, body map[string]any, out any) { + t.Helper() + encoded, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + response, err := http.Post(url, "application/json", bytes.NewReader(encoded)) + if err != nil { + t.Fatalf("POST %s: %v", url, err) + } + defer response.Body.Close() + raw, err := io.ReadAll(response.Body) + if err != nil { + t.Fatalf("read response: %v", err) + } + if response.StatusCode != http.StatusOK { + t.Fatalf("POST %s: %d %s", url, response.StatusCode, raw) + } + if err := json.Unmarshal(raw, out); err != nil { + t.Fatalf("decode response: %v", err) + } +} diff --git a/go/protocols/mpp/server/session_lifecycle.go b/go/protocols/mpp/server/session_lifecycle.go new file mode 100644 index 000000000..71bda9f32 --- /dev/null +++ b/go/protocols/mpp/server/session_lifecycle.go @@ -0,0 +1,103 @@ +package server + +// Per-channel idle-close lifecycle. +// +// When the server accepts an open, it arms a single-shot timer keyed on the +// channel id. Every voucher / commit / topUp Touch resets the timer. When the +// timer fires, the closeOnIdle handler is invoked with the channel id so the +// server can run its close-and-settle path without waiting for a client close +// action. +// +// The idle-close watchdog is an extension beyond the draft MPP spec; +// without it, hosts drive close explicitly. + +import ( + "sync" + "time" +) + +// SessionLifecycle is the idle-close watchdog. Touch resets the per-channel +// timer, RemoveChannel cancels it, and Shutdown cancels everything. +type SessionLifecycle struct { + // mu guards timers and shutdown. + mu sync.Mutex + + // timers holds the armed single-shot idle timer per channel id. + timers map[string]*time.Timer + + // closeDelay is the idle duration before a channel is auto-closed; + // <= 0 disables the watchdog entirely. + closeDelay time.Duration + + // closeOnIdle is invoked with the channel id when its idle timer fires. + closeOnIdle func(channelID string) + + // shutdown, once true, turns every later Touch into a no-op and stops + // already-fired timers from invoking closeOnIdle. + shutdown bool +} + +// NewSessionLifecycle creates an idle-close watchdog. closeDelay <= 0 +// disables the timer entirely (all operations become no-ops), the right +// default for tests and for callers that drive close explicitly. +// +// closeOnIdle is invoked with the channel id when a timer fires. Errors +// during idle close have no synchronous caller to report to; the handler is +// expected to log internally. +func NewSessionLifecycle(closeOnIdle func(channelID string), closeDelay time.Duration) *SessionLifecycle { + return &SessionLifecycle{ + timers: map[string]*time.Timer{}, + closeDelay: closeDelay, + closeOnIdle: closeOnIdle, + } +} + +// Touch resets the idle timer for channelID. No-op when the close delay is +// disabled or the lifecycle is shut down. +func (l *SessionLifecycle) Touch(channelID string) { + if l.closeDelay <= 0 { + return + } + l.mu.Lock() + defer l.mu.Unlock() + if l.shutdown { + return + } + l.cancelLocked(channelID) + l.timers[channelID] = time.AfterFunc(l.closeDelay, func() { + l.mu.Lock() + delete(l.timers, channelID) + stopped := l.shutdown + l.mu.Unlock() + if stopped { + return + } + l.closeOnIdle(channelID) + }) +} + +// RemoveChannel cancels the idle timer for channelID. +func (l *SessionLifecycle) RemoveChannel(channelID string) { + l.mu.Lock() + defer l.mu.Unlock() + l.cancelLocked(channelID) +} + +// Shutdown cancels every outstanding timer and disables future touches. +func (l *SessionLifecycle) Shutdown() { + l.mu.Lock() + defer l.mu.Unlock() + l.shutdown = true + for channelID, timer := range l.timers { + timer.Stop() + delete(l.timers, channelID) + } +} + +// cancelLocked stops and forgets the timer for channelID. Callers hold l.mu. +func (l *SessionLifecycle) cancelLocked(channelID string) { + if timer, ok := l.timers[channelID]; ok { + timer.Stop() + delete(l.timers, channelID) + } +} diff --git a/go/protocols/mpp/server/session_lifecycle_test.go b/go/protocols/mpp/server/session_lifecycle_test.go new file mode 100644 index 000000000..5ab3ae49b --- /dev/null +++ b/go/protocols/mpp/server/session_lifecycle_test.go @@ -0,0 +1,120 @@ +package server + +// Unit coverage of the SessionLifecycle idle-close watchdog: zero-delay +// disablement, idle firing, touch resets, channel removal, and shutdown. + +import ( + "sync" + "testing" + "time" +) + +// idleRecorder collects closeOnIdle invocations. +type idleRecorder struct { + // mu guards fired. + mu sync.Mutex + + // fired accumulates the channel ids passed to the handler, in order. + fired []string + + // ch receives each fired channel id so tests can block until the + // watchdog fires. + ch chan string +} + +func newIdleRecorder() *idleRecorder { + return &idleRecorder{ch: make(chan string, 16)} +} + +func (r *idleRecorder) handler(channelID string) { + r.mu.Lock() + r.fired = append(r.fired, channelID) + r.mu.Unlock() + r.ch <- channelID +} + +func (r *idleRecorder) count() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.fired) +} + +func TestSessionLifecycleZeroDelayDisablesTimers(t *testing.T) { + recorder := newIdleRecorder() + lifecycle := NewSessionLifecycle(recorder.handler, 0) + lifecycle.Touch("c1") + + time.Sleep(30 * time.Millisecond) + if recorder.count() != 0 { + t.Fatalf("closeOnIdle fired %d times with disabled delay", recorder.count()) + } +} + +func TestSessionLifecycleFiresAfterIdle(t *testing.T) { + recorder := newIdleRecorder() + lifecycle := NewSessionLifecycle(recorder.handler, 10*time.Millisecond) + defer lifecycle.Shutdown() + + lifecycle.Touch("c1") + select { + case channelID := <-recorder.ch: + if channelID != "c1" { + t.Fatalf("fired for %q, want c1", channelID) + } + case <-time.After(2 * time.Second): + t.Fatal("closeOnIdle never fired") + } +} + +func TestSessionLifecycleTouchResetsTimer(t *testing.T) { + recorder := newIdleRecorder() + lifecycle := NewSessionLifecycle(recorder.handler, 80*time.Millisecond) + defer lifecycle.Shutdown() + + lifecycle.Touch("c1") + // Keep touching before the delay elapses; the timer must keep resetting. + for range 3 { + time.Sleep(30 * time.Millisecond) + lifecycle.Touch("c1") + if recorder.count() != 0 { + t.Fatal("closeOnIdle fired while the channel was being touched") + } + } + select { + case <-recorder.ch: + case <-time.After(2 * time.Second): + t.Fatal("closeOnIdle never fired after touches stopped") + } + if recorder.count() != 1 { + t.Fatalf("closeOnIdle fired %d times, want 1", recorder.count()) + } +} + +func TestSessionLifecycleRemoveChannelCancelsTimer(t *testing.T) { + recorder := newIdleRecorder() + lifecycle := NewSessionLifecycle(recorder.handler, 20*time.Millisecond) + defer lifecycle.Shutdown() + + lifecycle.Touch("c1") + lifecycle.RemoveChannel("c1") + + time.Sleep(60 * time.Millisecond) + if recorder.count() != 0 { + t.Fatalf("closeOnIdle fired %d times after RemoveChannel", recorder.count()) + } +} + +func TestSessionLifecycleShutdownCancelsAllTimersAndDisablesTouch(t *testing.T) { + recorder := newIdleRecorder() + lifecycle := NewSessionLifecycle(recorder.handler, 20*time.Millisecond) + + lifecycle.Touch("c1") + lifecycle.Touch("c2") + lifecycle.Shutdown() + lifecycle.Touch("c3") + + time.Sleep(60 * time.Millisecond) + if recorder.count() != 0 { + t.Fatalf("closeOnIdle fired %d times after Shutdown", recorder.count()) + } +} diff --git a/go/protocols/mpp/server/session_method.go b/go/protocols/mpp/server/session_method.go new file mode 100644 index 000000000..d532b4969 --- /dev/null +++ b/go/protocols/mpp/server/session_method.go @@ -0,0 +1,803 @@ +package server + +// HTTP-facing session method. +// +// A Session issues HMAC-bound 402 challenges carrying a SessionRequest +// (Challenge), verifies Authorization credentials whose payload is one of the +// five session actions (VerifyCredential dispatching to open / voucher / +// commit / topUp / close), exposes the reserve/commit metering side channel +// (Routes), and drives on-chain settlement at close when both a merchant +// signer and an RPC client are configured. The lower-level building blocks +// (SessionServer, ChannelStore, the voucher verifier, and the on-chain +// helpers) are composed here. +// +// The close settlement path, the idle-close watchdog, the re-drivable close, +// and the side-channel routes are extensions beyond the draft MPP spec and +// are documented as such where they extend it. + +import ( + "context" + "fmt" + "log" + "strconv" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// OpenTxSubmitter selects who broadcasts a push-mode payment-channel open +// transaction. +type OpenTxSubmitter string + +const ( + // OpenTxSubmitterClient means the client broadcasts the open transaction + // itself and the server only verifies it. Default. + OpenTxSubmitterClient OpenTxSubmitter = "client" + + // OpenTxSubmitterServer means the server completes the fee-payer + // signature, broadcasts the client-built open transaction, and waits for + // confirmation before persisting channel state. + OpenTxSubmitterServer OpenTxSubmitter = "server" +) + +// SessionOptions configures NewSession. +type SessionOptions struct { + // Operator public key (base58), shown to clients in the challenge. + Operator string + + // Recipient is the primary payment recipient (base58). Required. + Recipient string + + // Cap is the maximum session cap the server will offer (base units). + // Required, must be positive. + Cap uint64 + + // Currency identifier (e.g. "USDC" or an SPL mint address). Default USDC. + Currency string + + // Decimals is the token decimals. Default 6. + Decimals uint8 + + // Network is the Solana network. Default "mainnet". + Network string + + // SecretKey is the challenge HMAC secret. Defaults to MPP_SECRET_KEY. + SecretKey string + + // Realm is the challenge realm. Defaults to DetectRealm(). + Realm string + + // ProgramID overrides the payment-channels program id. Nil defaults to + // the canonical program. + ProgramID *solana.PublicKey + + // MinVoucherDelta is the minimum voucher increment (base units). 0 = no + // minimum. + MinVoucherDelta uint64 + + // Modes are the funding modes advertised to clients. Empty means push + // only. + Modes []intents.SessionMode + + // PullVoucherStrategy is the voucher authority for pull-mode sessions. + // Required when Modes includes pull. + PullVoucherStrategy *intents.SessionPullVoucherStrategy + + // Splits are optional basis-point splits distributed at close. Max 8. + Splits []Split + + // CloseDelay arms the idle-close watchdog; zero disables it. + CloseDelay time.Duration + + // OpenTxSubmitter selects who broadcasts push-mode open transactions. + // Default OpenTxSubmitterClient. + OpenTxSubmitter OpenTxSubmitter + + // Signer is the merchant signer for the settle_and_finalize + distribute + // settlement transaction. Settlement at close (and on idle close) only + // runs when both Signer and RPC are configured. + Signer solanatx.Signer + + // PaymentChannelPayerSigner completes the fee-payer signature when the + // server broadcasts a client-built open (OpenTxSubmitterServer). + PaymentChannelPayerSigner solanatx.Signer + + // Store is the pluggable channel store. Defaults to in-memory. + Store ChannelStore + + // RPC is the optional RPC client used for on-chain checks, the + // recentBlockhash prefetch, and settlement broadcasts. Nil skips every + // on-chain check and trusts payload claims as provided. + RPC solanatx.RPCClient +} + +// Session is the server-side session method handler. Create with NewSession. +type Session struct { + // core is the lower-level SessionServer dispatching open / voucher / + // commit / topUp / close against the channel store. + core *SessionServer + + // lifecycle is the idle-close watchdog; nil when CloseDelay is zero. + lifecycle *SessionLifecycle + + // secretKey is the HMAC secret binding 402 challenges to this server. + secretKey string + + // realm is the challenge realm advertised in 402 responses. + realm string + + // cap is the maximum session cap offered in challenges (token base + // units); per-challenge requested caps are clamped to it. + cap uint64 + + // currency is the challenge currency (symbol such as "USDC" or an SPL + // mint address). + currency string + + // recipient is the primary payment recipient pubkey (base58). + recipient string + + // network is the Solana network ("mainnet", "devnet", "localnet"). + network string + + // openTxSubmitter selects whether the client or the server broadcasts + // push-mode open transactions. + openTxSubmitter OpenTxSubmitter + + // signer is the merchant signer for the close settlement transaction; + // settlement only runs when both signer and rpc are configured. + signer solanatx.Signer + + // payerSigner completes the fee-payer signature on server-broadcast + // opens (OpenTxSubmitterServer). + payerSigner solanatx.Signer + + // rpc is the optional RPC client for on-chain checks, the blockhash + // prefetch, and settlement broadcasts; nil skips every on-chain check + // and trusts payload claims as provided. + rpc solanatx.RPCClient +} + +// NewSession creates the server-side session method. +func NewSession(options SessionOptions) (*Session, error) { + if options.Cap == 0 { + return nil, core.NewError(core.ErrCodeInvalidConfig, "cap must be positive") + } + if options.Recipient == "" { + return nil, core.NewError(core.ErrCodeInvalidConfig, "recipient is required") + } + if _, err := solana.PublicKeyFromBase58(options.Recipient); err != nil { + return nil, core.WrapError(core.ErrCodeInvalidConfig, "invalid recipient pubkey", err) + } + if len(options.Splits) > maxSplits { + return nil, core.NewError(core.ErrCodeInvalidConfig, + fmt.Sprintf("splits cannot exceed %d entries", maxSplits)) + } + if options.SecretKey == "" { + options.SecretKey = DetectSecretKey() + } + if options.SecretKey == "" { + return nil, core.NewError(core.ErrCodeInvalidConfig, "missing secret key") + } + if options.Currency == "" { + options.Currency = "USDC" + } + if options.Decimals == 0 { + options.Decimals = 6 + } + if options.Network == "" { + options.Network = "mainnet" + } + if options.Realm == "" { + options.Realm = DetectRealm() + } + switch options.OpenTxSubmitter { + case "": + options.OpenTxSubmitter = OpenTxSubmitterClient + case OpenTxSubmitterClient, OpenTxSubmitterServer: + default: + return nil, core.NewError(core.ErrCodeInvalidConfig, + fmt.Sprintf("openTxSubmitter must be %q or %q, got %q", + OpenTxSubmitterClient, OpenTxSubmitterServer, options.OpenTxSubmitter)) + } + supportsPull := false + for _, mode := range options.Modes { + if mode == intents.SessionModePull { + supportsPull = true + } + } + if supportsPull && options.PullVoucherStrategy == nil { + return nil, core.NewError(core.ErrCodeInvalidConfig, + "pullVoucherStrategy is required when modes includes pull") + } + store := options.Store + if store == nil { + store = NewMemoryChannelStore() + } + + config := SessionConfig{ + Operator: options.Operator, + Recipient: options.Recipient, + Splits: options.Splits, + MaxCap: options.Cap, + Currency: options.Currency, + Decimals: options.Decimals, + Network: options.Network, + ProgramID: options.ProgramID, + MinVoucherDelta: options.MinVoucherDelta, + Modes: options.Modes, + PullVoucherStrategy: options.PullVoucherStrategy, + } + session := &Session{ + core: NewSessionServer(config, store), + secretKey: options.SecretKey, + realm: options.Realm, + cap: options.Cap, + currency: options.Currency, + recipient: options.Recipient, + network: options.Network, + openTxSubmitter: options.OpenTxSubmitter, + signer: options.Signer, + payerSigner: options.PaymentChannelPayerSigner, + rpc: options.RPC, + } + if options.CloseDelay > 0 { + session.lifecycle = NewSessionLifecycle(session.closeOnIdle, options.CloseDelay) + } + return session, nil +} + +// Core returns the underlying SessionServer so hosts can reach the channel +// store and the lower-level lifecycle methods. +func (s *Session) Core() *SessionServer { return s.core } + +// Shutdown cancels the idle-close watchdog timers. Hosts should call it when +// tearing the session method down. +func (s *Session) Shutdown() { + if s.lifecycle != nil { + s.lifecycle.Shutdown() + } +} + +// touch resets the idle-close timer for channelID when the watchdog is armed. +func (s *Session) touch(channelID string) { + if s.lifecycle != nil { + s.lifecycle.Touch(channelID) + } +} + +// closeOnIdle is the idle-close watchdog handler: settle the channel +// on-chain when both a merchant signer and an RPC client are configured. +// Errors have no synchronous caller to report to and are logged instead. +func (s *Session) closeOnIdle(channelID string) { + if s.signer == nil || s.rpc == nil { + return + } + if _, err := s.closeAndSettleChannel(context.Background(), channelID); err != nil { + log.Printf("[solana-mpp] idle-close settle failed for %s: %v", channelID, err) + } +} + +// SessionChallengeOptions customize a single 402 session challenge. +type SessionChallengeOptions struct { + // Cap is the requested session cap (base units, decimal string). Empty + // uses the server maximum; larger requests are clamped to it. + Cap string + + // Description is a human-readable challenge description. + Description string + + // ExternalID is a merchant reference id echoed on the receipt. + ExternalID string + + // Expires is the challenge expiry (RFC 3339). Default five minutes. + Expires string +} + +// Challenge builds the HMAC-bound 402 challenge embedding a SessionRequest. +// +// The requested cap is clamped to the server maximum, minVoucherDelta is +// included only when positive, modes are omitted when push-only, +// pullVoucherStrategy is included only when pull is offered, and a recent +// blockhash is prefetched (non-fatally) when an RPC client is configured. +// The blockhash source is the injected RPC client rather than a raw URL +// fetch so unit tests stay offline. +func (s *Session) Challenge(ctx context.Context, options SessionChallengeOptions) (core.PaymentChallenge, error) { + capValue := s.cap + if options.Cap != "" { + requested, err := parseSessionU64(options.Cap, "cap") + if err != nil { + return core.PaymentChallenge{}, core.WrapError(core.ErrCodeInvalidPayload, "invalid requested cap", err) + } + capValue = requested + } + request := s.core.BuildChallengeRequest(capValue) + if options.Description != "" { + description := options.Description + request.Description = &description + } + if options.ExternalID != "" { + externalID := options.ExternalID + request.ExternalID = &externalID + } + if s.rpc != nil { + // Non-fatal: the client fetches its own blockhash when absent. + if out, err := s.rpc.GetLatestBlockhash(ctx, rpc.CommitmentConfirmed); err == nil && out != nil && out.Value != nil { + blockhash := out.Value.Blockhash.String() + request.RecentBlockhash = &blockhash + } + } + requestValue, err := core.NewBase64URLJSONValue(request) + if err != nil { + return core.PaymentChallenge{}, err + } + expires := options.Expires + if expires == "" { + expires = core.Minutes(5) + } + return core.NewChallengeWithSecretFull( + s.secretKey, + s.realm, + core.NewMethodName("solana"), + core.NewIntentName("session"), + requestValue, + expires, + "", + options.Description, + nil, + ), nil +} + +// VerifyCredential verifies a session Authorization credential: Tier-1 HMAC +// and expiry, the Tier-2 pinned-field backstop, then dispatch on the payload +// action (open / voucher / commit / topUp / close). +func (s *Session) VerifyCredential(ctx context.Context, credential core.PaymentCredential) (core.Receipt, error) { + challenge := core.PaymentChallenge{ + ID: credential.Challenge.ID, + Realm: credential.Challenge.Realm, + Method: credential.Challenge.Method, + Intent: credential.Challenge.Intent, + Request: credential.Challenge.Request, + Expires: credential.Challenge.Expires, + Digest: credential.Challenge.Digest, + Opaque: credential.Challenge.Opaque, + } + if !challenge.Verify(s.secretKey) { + return core.Receipt{}, core.NewError(core.ErrCodeChallengeMismatch, "challenge ID mismatch") + } + if challenge.IsExpired(time.Now()) { + return core.Receipt{}, core.NewError(core.ErrCodeChallengeExpired, + fmt.Sprintf("challenge expired at %s", challenge.Expires)) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + return core.Receipt{}, err + } + if err := s.verifyPinnedSessionFields(credential, request); err != nil { + return core.Receipt{}, err + } + var action intents.SessionAction + if err := credential.PayloadAs(&action); err != nil { + return core.Receipt{}, core.WrapError(core.ErrCodeInvalidPayload, "decode session action", err) + } + + var reference string + var err error + switch { + case action.Open != nil: + reference, err = s.handleOpen(ctx, action.Open) + case action.Voucher != nil: + reference, err = s.handleVoucher(ctx, action.Voucher) + case action.Commit != nil: + reference, err = s.handleCommit(ctx, action.Commit) + case action.TopUp != nil: + reference, err = s.handleTopUp(ctx, action.TopUp) + case action.Close != nil: + reference, err = s.handleClose(ctx, action.Close) + default: + return core.Receipt{}, core.NewError(core.ErrCodeInvalidPayload, "unknown session action") + } + if err != nil { + return core.Receipt{}, err + } + externalID := "" + if request.ExternalID != nil { + externalID = *request.ExternalID + } + return successReceipt(reference, credential.Challenge.ID, externalID), nil +} + +// verifyPinnedSessionFields is the Tier-2 backstop for session credentials: +// after Tier-1 HMAC confirms the challenge was issued by this server, fields +// fixed at construction time are compared so a credential issued by a +// different method/intent/realm or for a different recipient/currency cannot +// reach the action handlers. Same rationale as the charge handler's +// verifyPinnedFields. +func (s *Session) verifyPinnedSessionFields(credential core.PaymentCredential, request intents.SessionRequest) error { + const methodName = "solana" + if string(credential.Challenge.Method) != methodName { + return core.NewError(core.ErrCodeChallengeRouteMismatch, + fmt.Sprintf("credential method %q does not match this server (expected %q)", + credential.Challenge.Method, methodName)) + } + if !credential.Challenge.Intent.IsSession() { + return core.NewError(core.ErrCodeChallengeRouteMismatch, + fmt.Sprintf("credential intent %q is not a session", credential.Challenge.Intent)) + } + if credential.Challenge.Realm != s.realm { + return core.NewError(core.ErrCodeChallengeRouteMismatch, + fmt.Sprintf("credential realm %q does not match this server (expected %q)", + credential.Challenge.Realm, s.realm)) + } + if request.Currency != s.currency { + return core.NewError(core.ErrCodeChallengeRouteMismatch, + fmt.Sprintf("credential currency %q does not match this server (expected %q)", + request.Currency, s.currency)) + } + if request.Recipient != s.recipient { + return core.NewError(core.ErrCodeRecipientMismatch, + "credential recipient does not match this server") + } + return nil +} + +// handleOpen processes an open action: resolve the channel facts from the +// payload (verifying or broadcasting the attached transaction when present), +// enforce the deposit invariants, and insert the channel state atomically and +// idempotently. +func (s *Session) handleOpen(ctx context.Context, payload *intents.OpenPayload) (string, error) { + mode := payload.Mode + if !s.core.supportsMode(mode) { + return "", fmt.Errorf("session mode %q is not supported by this challenge", mode) + } + if mode == intents.SessionModePull && s.core.config.PullVoucherStrategy == nil { + return "", fmt.Errorf("pull-mode open requires a pullVoucherStrategy on the server config") + } + // Empty strings count as missing. + hasTransaction := payload.Transaction != nil && *payload.Transaction != "" + hasChannelID := payload.ChannelID != nil && *payload.ChannelID != "" + if mode == intents.SessionModePush && !hasTransaction && !hasChannelID { + return "", fmt.Errorf("open payload missing transaction or channelId") + } + + var channelID string + var deposit uint64 + signature := payload.Signature + + switch { + case hasTransaction: + // Payment-channel-backed open: push sessions and clientVoucher pull + // sessions whose deposit lives in an on-chain payment channel both + // attach the pre-signed open transaction. + expected := VerifyOpenTxExpected{ + AuthorizedSigner: payload.AuthorizedSigner, + Currency: s.currency, + MaxCap: s.cap, + Network: s.network, + ProgramID: s.core.config.ProgramID, + Recipient: s.recipient, + } + if s.openTxSubmitter == OpenTxSubmitterServer { + if s.rpc == nil { + return "", fmt.Errorf("openTxSubmitter=server requires an rpc client") + } + // Decode-only first so an idempotent replay of an + // already-persisted open does not rebroadcast the transaction. + preVerified, err := VerifyOpenTx(ctx, expected, payload, nil) + if err != nil { + return "", err + } + existing, err := s.core.store.GetChannel(ctx, preVerified.ChannelID) + if err != nil { + return "", err + } + if existing != nil { + channelID = preVerified.ChannelID + deposit = preVerified.Deposit + } else { + submitted, err := SubmitOpenTx(ctx, expected, payload, s.payerSigner, s.rpc) + if err != nil { + return "", err + } + channelID = submitted.ChannelID + deposit = submitted.Deposit + signature = submitted.Signature + } + } else { + verified, err := VerifyOpenTx(ctx, expected, payload, s.rpc) + if err != nil { + return "", err + } + channelID = verified.ChannelID + deposit = verified.Deposit + } + case mode == intents.SessionModePush: + // No transaction in the payload: the client asserts a previously + // broadcast open. With an RPC client the open signature is confirmed + // on-chain before persisting; without one the channelId/deposit + // fields are trusted as-is. + channelID = *payload.ChannelID + var err error + deposit, err = payload.DepositAmount() + if err != nil { + return "", err + } + if s.rpc != nil { + if err := confirmTransactionSignature(ctx, s.rpc, signature, "open"); err != nil { + return "", err + } + } + default: + // Pull mode without a channel transaction: trust the + // channelId/tokenAccount + approvedAmount. Keying order is channelId + // first, then tokenAccount. + // + // The Go SDK has no multi-delegate program builders, so + // operated-voucher opens do not submit a multi-delegate init + // transaction here (the client cannot produce those transactions + // either; see go/README.md scope notes). + var err error + channelID, err = payload.SessionID() + if err != nil { + return "", err + } + deposit, err = payload.DepositAmount() + if err != nil { + return "", err + } + } + + if deposit == 0 { + return "", fmt.Errorf("deposit must be greater than zero") + } + if deposit > s.cap { + return "", fmt.Errorf("deposit %d exceeds cap %d", deposit, s.cap) + } + + operator := payload.Owner + if operator == nil { + operator = payload.Payer + } + fresh := ChannelState{ + ChannelID: channelID, + AuthorizedSigner: payload.AuthorizedSigner, + Deposit: deposit, + Operator: operator, + } + + // The existence check lives inside the atomic mutator so a concurrent + // open replay cannot race a fresh create. Replays must never reset the + // voucher watermark. + if _, err := s.core.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current != nil { + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.AuthorizedSigner != payload.AuthorizedSigner { + return ChannelState{}, fmt.Errorf( + "open replay: authorizedSigner %s does not match existing channel %s", + payload.AuthorizedSigner, channelID) + } + return *current, nil + } + return fresh, nil + }); err != nil { + return "", err + } + s.touch(channelID) + + if signature == "" { + return channelID, nil + } + return signature, nil +} + +// handleVoucher verifies a cumulative voucher and advances the watermark. +// The receipt reference is ":". +func (s *Session) handleVoucher(ctx context.Context, payload *intents.VoucherPayload) (string, error) { + channelID := payload.Voucher.Data.ChannelID + cumulative, err := s.core.VerifyVoucher(ctx, payload) + if err != nil { + return "", err + } + s.touch(channelID) + return fmt.Sprintf("%s:%d", channelID, cumulative), nil +} + +// handleCommit commits a reserved metered delivery. The receipt reference is +// "::". +func (s *Session) handleCommit(ctx context.Context, payload *intents.CommitPayload) (string, error) { + receipt, err := s.core.ProcessCommit(ctx, payload) + if err != nil { + return "", err + } + s.touch(receipt.SessionID) + return fmt.Sprintf("%s:%s:%s", receipt.SessionID, receipt.DeliveryID, receipt.Cumulative), nil +} + +// handleTopUp raises a channel's deposit after optional on-chain +// confirmation of the top-up signature. The receipt reference is the top-up +// transaction signature. +func (s *Session) handleTopUp(ctx context.Context, payload *intents.TopUpPayload) (string, error) { + newDeposit, err := parseSessionU64(payload.NewDeposit, "newDeposit") + if err != nil { + return "", err + } + if newDeposit > s.cap { + return "", fmt.Errorf("newDeposit %d exceeds cap %d", newDeposit, s.cap) + } + + // Cheap store pre-checks before touching the network. + existing, err := s.core.store.GetChannel(ctx, payload.ChannelID) + if err != nil { + return "", err + } + if existing == nil { + return "", fmt.Errorf("channel %s not found", payload.ChannelID) + } + if existing.Finalized { + return "", fmt.Errorf("channel %s is already finalized", payload.ChannelID) + } + if existing.CloseRequestedAt != nil { + return "", fmt.Errorf("channel %s close is pending; no further top-ups accepted", payload.ChannelID) + } + if s.rpc != nil { + if err := confirmTransactionSignature(ctx, s.rpc, payload.Signature, "topUp"); err != nil { + return "", err + } + } + if _, err := s.core.ProcessTopUp(ctx, payload); err != nil { + return "", err + } + s.touch(payload.ChannelID) + return payload.Signature, nil +} + +// handleClose accepts the optional final voucher, flips close-pending +// atomically, and settles on-chain when both a merchant signer and an RPC +// client are configured. The receipt reference is the on-chain settle +// signature when one exists, else the channel id. +// +// Unlike SessionServer.ProcessClose, where a second close is always +// rejected, the close here is re-drivable: when a prior close flipped the +// close-pending flag but settlement never recorded a signature, the retry +// proceeds so a transient settlement failure cannot strand the channel. +func (s *Session) handleClose(ctx context.Context, payload *intents.ClosePayload) (string, error) { + channelID := payload.ChannelID + now := uint64(time.Now().Unix()) + + if _, err := s.core.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + if current.Finalized { + return ChannelState{}, fmt.Errorf("channel %s is already finalized", channelID) + } + if current.CloseRequestedAt != nil { + if current.SettledSignature == nil { + // Re-drivable close: leave state untouched and let the + // settlement retry proceed. + return *current, nil + } + return ChannelState{}, fmt.Errorf("close already requested") + } + + next := *current + closeRequestedAt := now + if payload.Voucher != nil { + voucher := *payload.Voucher + // Idempotent replay of the current highest voucher (same + // cumulative AND same signature) is accepted as-is. + replay := current.HighestVoucherSignature != nil && + *current.HighestVoucherSignature == voucher.Signature && + voucher.Data.Cumulative == strconv.FormatUint(current.Cumulative, 10) + if !replay { + verdict := VerifyVoucherForChannel(VerifyVoucherArgs{ + State: *current, + Signed: voucher, + Deposit: current.Deposit, + }) + switch verdict.Status { + case VoucherVerifyRejected: + // A non-replay final voucher at or below the watermark is + // a hard error: the close must abort rather than silently + // settle a stale amount. + return ChannelState{}, fmt.Errorf("%s: %s", verdict.Reason, verdict.Detail) + case VoucherVerifyAccepted: + next.Cumulative = verdict.NewCumulative + signature := verdict.NewSignature + expiresAt := verdict.NewExpiresAt + next.HighestVoucherSignature = &signature + next.HighestVoucherExpiresAt = &expiresAt + } + } + } + next.CloseRequestedAt = &closeRequestedAt + return next, nil + }); err != nil { + return "", err + } + + reference := channelID + if s.signer != nil && s.rpc != nil { + settleSignature, err := s.closeAndSettleChannel(ctx, channelID) + if err != nil { + return "", err + } + if settleSignature != "" { + reference = settleSignature + } + } + if s.lifecycle != nil { + s.lifecycle.RemoveChannel(channelID) + } + return reference, nil +} + +// closeAndSettleChannel builds settle_and_finalize (+ the Ed25519 precompile +// when a voucher was accepted) + distribute for a channel that has flipped +// to close-pending, submits them as one merchant-signed transaction, and +// marks the channel finalized with the settled signature. Returns "" when +// the channel does not exist. +func (s *Session) closeAndSettleChannel(ctx context.Context, channelID string) (string, error) { + state, err := s.core.store.GetChannel(ctx, channelID) + if err != nil { + return "", err + } + if state == nil { + return "", nil + } + merchant := s.signer.PublicKey() + // The recipient backstops the distribute payer for channels that never + // recorded an operator. + instructions, err := s.core.settlementInstructionsForState(*state, channelID, merchant, s.recipient) + if err != nil { + return "", err + } + blockhash, err := s.rpc.GetLatestBlockhash(ctx, rpc.CommitmentConfirmed) + if err != nil { + return "", core.WrapError(core.ErrCodeRPC, "fetch settlement blockhash", err) + } + if blockhash == nil || blockhash.Value == nil { + return "", core.NewError(core.ErrCodeRPC, "fetch settlement blockhash: empty response") + } + tx, err := solana.NewTransaction(instructions, blockhash.Value.Blockhash, solana.TransactionPayer(merchant)) + if err != nil { + return "", fmt.Errorf("build settlement transaction: %w", err) + } + if err := solanatx.SignTransaction(tx, s.signer); err != nil { + return "", fmt.Errorf("sign settlement transaction: %w", err) + } + signature, err := solanatx.SendTransaction(ctx, s.rpc, tx) + if err != nil { + return "", core.WrapError(core.ErrCodeRPC, "send settlement transaction", err) + } + settled := signature.String() + if _, err := s.core.store.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s disappeared during settle", channelID) + } + next := *current + next.Finalized = true + next.SettledSignature = &settled + return next, nil + }); err != nil { + return "", err + } + return settled, nil +} + +// parseSessionU64 parses a non-negative decimal string into a u64, naming +// the field in the error. +func parseSessionU64(value, name string) (uint64, error) { + parsed, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return 0, fmt.Errorf("%s is not an unsigned integer string: %s", name, value) + } + return parsed, nil +} diff --git a/go/protocols/mpp/server/session_method_branch_test.go b/go/protocols/mpp/server/session_method_branch_test.go new file mode 100644 index 000000000..d2b37e186 --- /dev/null +++ b/go/protocols/mpp/server/session_method_branch_test.go @@ -0,0 +1,792 @@ +package server + +// Adversarial branch coverage for the session method layer: store and RPC +// failure surfacing, malformed payload fields, settlement error paths, the +// SubmitOpenTx failure matrix, malformed open instructions, and the +// side-channel/middleware error responses. + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// failingGetStore wraps a ChannelStore and fails GetChannel. +type failingGetStore struct { + // ChannelStore is the wrapped store handling everything but GetChannel. + ChannelStore + + // getErr, when set, is returned by every GetChannel call. + getErr error +} + +func (f *failingGetStore) GetChannel(ctx context.Context, channelID string) (*ChannelState, error) { + if f.getErr != nil { + return nil, f.getErr + } + return f.ChannelStore.GetChannel(ctx, channelID) +} + +// failingSigner satisfies solanatx.Signer but always fails to sign. +type failingSigner struct { + key solana.PrivateKey // supplies the pubkey; Sign always fails regardless +} + +func (f failingSigner) PublicKey() solana.PublicKey { return f.key.PublicKey() } + +func (f failingSigner) Sign([]byte) (solana.Signature, error) { + return solana.Signature{}, errors.New("hardware signer unavailable") +} + +// failingBlockhashRPC fails GetLatestBlockhash on top of FakeRPC. +type failingBlockhashRPC struct { + // FakeRPC handles every RPC call other than GetLatestBlockhash. + *testutil.FakeRPC + + // err, when set, is returned by GetLatestBlockhash. + err error + + // empty makes GetLatestBlockhash return a nil result with no error. + empty bool +} + +func (f *failingBlockhashRPC) GetLatestBlockhash(ctx context.Context, commitment rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { + if f.err != nil { + return nil, f.err + } + if f.empty { + return nil, nil + } + return f.FakeRPC.GetLatestBlockhash(ctx, commitment) +} + +// failingStatusRPC fails GetSignatureStatuses on top of FakeRPC. +type failingStatusRPC struct { + // FakeRPC handles every RPC call other than GetSignatureStatuses, which + // this wrapper always fails. + *testutil.FakeRPC +} + +func (f *failingStatusRPC) GetSignatureStatuses(context.Context, bool, ...solana.Signature) (*rpc.GetSignatureStatusesResult, error) { + return nil, errors.New("rpc unavailable") +} + +func seedChannel(t *testing.T, store ChannelStore, state ChannelState) { + t.Helper() + if _, err := store.UpdateChannel(context.Background(), state.ChannelID, func(*ChannelState) (ChannelState, error) { + return state, nil + }); err != nil { + t.Fatalf("seed channel: %v", err) + } +} + +// ── VerifyCredential decode failures ── + +func TestVerifyCredentialRejectsUndecodableRequestAndMissingPayload(t *testing.T) { + session := newTestSession(t, nil) + + // A challenge whose HMAC verifies but whose request is not a session + // request JSON object. + raw := core.NewBase64URLJSONRaw(`"just-a-string"`) + challenge := core.NewChallengeWithSecretFull( + sessionMethodSecret, "api.test", core.NewMethodName("solana"), core.NewIntentName("session"), + raw, core.Minutes(5), "", "", nil) + credential, err := core.NewPaymentCredential(challenge.ToEcho(), map[string]string{"action": "close"}) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + if _, err := session.VerifyCredential(context.Background(), credential); err == nil { + t.Fatal("expected request decode error") + } + + // A credential with no payload reaches the unknown-action default. + good, err := session.Challenge(context.Background(), SessionChallengeOptions{}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + noPayload := core.PaymentCredential{Challenge: good.ToEcho()} + if _, err := session.VerifyCredential(context.Background(), noPayload); err == nil || + !strings.Contains(err.Error(), "unknown session action") { + t.Fatalf("missing payload error = %v", err) + } +} + +func TestVerifyCredentialRejectsWrongMethodAndRealm(t *testing.T) { + session := newTestSession(t, nil) + action := intents.NewCloseAction(intents.ClosePayload{ChannelID: solana.NewWallet().PublicKey().String()}) + + request, err := core.NewBase64URLJSONValue(session.core.BuildChallengeRequest(1_000)) + if err != nil { + t.Fatalf("encode request: %v", err) + } + + wrongMethod := core.NewChallengeWithSecretFull( + sessionMethodSecret, "api.test", core.NewMethodName("stripe"), core.NewIntentName("session"), + request, core.Minutes(5), "", "", nil) + credential, err := core.NewPaymentCredential(wrongMethod.ToEcho(), action) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + if _, err := session.VerifyCredential(context.Background(), credential); err == nil || + !strings.Contains(err.Error(), "method") { + t.Fatalf("wrong method error = %v", err) + } + + wrongRealm := core.NewChallengeWithSecretFull( + sessionMethodSecret, "other.realm", core.NewMethodName("solana"), core.NewIntentName("session"), + request, core.Minutes(5), "", "", nil) + credential, err = core.NewPaymentCredential(wrongRealm.ToEcho(), action) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + if _, err := session.VerifyCredential(context.Background(), credential); err == nil || + !strings.Contains(err.Error(), "realm") { + t.Fatalf("wrong realm error = %v", err) + } +} + +// ── open payload failures ── + +func TestSessionOpenMalformedAmountsRejected(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + session := newTestSession(t, func(o *SessionOptions) { + o.Modes = []intents.SessionMode{intents.SessionModePush, intents.SessionModePull} + o.PullVoucherStrategy = &strategy + }) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + + badDeposit := intents.OpenPayloadPush(channelID, "one-usdc", signer.Address(), "sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(badDeposit)); err == nil || + !strings.Contains(err.Error(), "invalid deposit amount") { + t.Fatalf("bad deposit error = %v", err) + } + + pullNoKey := intents.OpenPayload{Mode: intents.SessionModePull, AuthorizedSigner: signer.Address(), Signature: "sig"} + if _, err := verifySessionAction(t, session, intents.NewOpenAction(pullNoKey)); err == nil || + !strings.Contains(err.Error(), "missing channelId or tokenAccount") { + t.Fatalf("pull keying error = %v", err) + } + + tokenAccount := solana.NewWallet().PublicKey().String() + pullNoAmount := intents.OpenPayload{ + Mode: intents.SessionModePull, TokenAccount: &tokenAccount, + AuthorizedSigner: signer.Address(), Signature: "sig", + } + if _, err := verifySessionAction(t, session, intents.NewOpenAction(pullNoAmount)); err == nil || + !strings.Contains(err.Error(), "missing deposit or approvedAmount") { + t.Fatalf("pull amount error = %v", err) + } +} + +func TestSessionPullOpenWithoutSignatureReferencesChannel(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + session := newTestSession(t, func(o *SessionOptions) { + o.Modes = []intents.SessionMode{intents.SessionModePull} + o.PullVoucherStrategy = &strategy + }) + signer := newTestVoucherSigner(t) + tokenAccount := solana.NewWallet().PublicKey().String() + payload := intents.OpenPayloadPull(tokenAccount, "1000", solana.NewWallet().PublicKey().String(), signer.Address(), "") + + receipt, err := verifySessionAction(t, session, intents.NewOpenAction(payload)) + if err != nil { + t.Fatalf("pull open: %v", err) + } + if receipt.Reference != tokenAccount { + t.Fatalf("reference = %q, want token account", receipt.Reference) + } +} + +func TestSessionOpenSurfacesStoreFailures(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fake := testutil.NewFakeRPC() + store := &failingGetStore{ChannelStore: NewMemoryChannelStore(), getErr: errors.New("store offline")} + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.OpenTxSubmitter = OpenTxSubmitterServer + o.RPC = fake + o.Store = store + }) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err == nil || + !strings.Contains(err.Error(), "store offline") { + t.Fatalf("store failure error = %v", err) + } +} + +func TestSessionServerSubmitterSurfacesBroadcastFailure(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fake := testutil.NewFakeRPC() + fake.SendErr = errors.New("blockhash not found") + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.OpenTxSubmitter = OpenTxSubmitterServer + o.RPC = fake + }) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err == nil || + !strings.Contains(err.Error(), "broadcast open transaction") { + t.Fatalf("broadcast failure error = %v", err) + } +} + +// ── topUp / close failures ── + +func TestSessionTopUpMalformedDepositAndStoreFailure(t *testing.T) { + session := newTestSession(t, nil) + _, channelID := openTrustedChannel(t, session, 1_000) + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "ten", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "not an unsigned integer") { + t.Fatalf("malformed deposit error = %v", err) + } + + store := &failingGetStore{ChannelStore: NewMemoryChannelStore(), getErr: errors.New("store offline")} + failing := newTestSession(t, func(o *SessionOptions) { o.Store = store }) + if _, err := verifySessionAction(t, failing, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "5000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "store offline") { + t.Fatalf("store failure error = %v", err) + } +} + +func TestSessionCloseUnknownChannelAndSettledDoubleClose(t *testing.T) { + session := newTestSession(t, nil) + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ + ChannelID: solana.NewWallet().PublicKey().String(), + })); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("unknown channel error = %v", err) + } + + // A close-pending channel that already recorded a settlement signature + // (but is not yet marked finalized) is not re-drivable. + channelID := solana.NewWallet().PublicKey().String() + closeRequestedAt := uint64(1) + settled := confirmedSignature(0xAB) + seedChannel(t, session.Core().Store(), ChannelState{ + ChannelID: channelID, + AuthorizedSigner: newTestVoucherSigner(t).Address(), + Deposit: 1_000, + CloseRequestedAt: &closeRequestedAt, + SettledSignature: &settled, + }) + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ + ChannelID: channelID, + })); err == nil || !strings.Contains(err.Error(), "close already requested") { + t.Fatalf("settled double-close error = %v", err) + } +} + +// ── closeAndSettleChannel failure matrix ── + +func TestCloseAndSettleChannelFailureMatrix(t *testing.T) { + merchant := testutil.NewPrivateKey() + ctx := context.Background() + + // Unknown channel settles to nothing. + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + }) + if signature, err := session.closeAndSettleChannel(ctx, solana.NewWallet().PublicKey().String()); err != nil || signature != "" { + t.Fatalf("unknown channel settle = %q, %v", signature, err) + } + + // Store read failure surfaces. + store := &failingGetStore{ChannelStore: NewMemoryChannelStore(), getErr: errors.New("store offline")} + failingStoreSession := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + o.Store = store + }) + if _, err := failingStoreSession.closeAndSettleChannel(ctx, "any"); err == nil || + !strings.Contains(err.Error(), "store offline") { + t.Fatalf("store failure = %v", err) + } + + // A non-base58 channel id fails instruction derivation. + badChannel := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + }) + seedChannel(t, badChannel.Core().Store(), ChannelState{ + ChannelID: "not-base58!", + AuthorizedSigner: newTestVoucherSigner(t).Address(), + Deposit: 1_000, + }) + if _, err := badChannel.closeAndSettleChannel(ctx, "not-base58!"); err == nil || + !strings.Contains(err.Error(), "invalid channel id") { + t.Fatalf("bad channel id error = %v", err) + } + + // Blockhash fetch failure and empty response both surface. + blockhashErr := &failingBlockhashRPC{FakeRPC: testutil.NewFakeRPC(), err: errors.New("rpc down")} + noBlockhash := newTestSession(t, func(o *SessionOptions) { + o.RPC = blockhashErr + o.Signer = merchant + }) + _, channelID := openTrustedChannel(t, noBlockhash, 1_000) + if _, err := noBlockhash.closeAndSettleChannel(ctx, channelID); err == nil || + !strings.Contains(err.Error(), "fetch settlement blockhash") { + t.Fatalf("blockhash failure = %v", err) + } + blockhashErr.err = nil + blockhashErr.empty = true + if _, err := noBlockhash.closeAndSettleChannel(ctx, channelID); err == nil || + !strings.Contains(err.Error(), "empty response") { + t.Fatalf("empty blockhash = %v", err) + } + + // Merchant signer failure surfaces. + badSigner := newTestSession(t, func(o *SessionOptions) { + o.RPC = testutil.NewFakeRPC() + o.Signer = failingSigner{key: merchant} + }) + _, signerChannel := openTrustedChannel(t, badSigner, 1_000) + if _, err := badSigner.closeAndSettleChannel(ctx, signerChannel); err == nil || + !strings.Contains(err.Error(), "sign settlement transaction") { + t.Fatalf("signer failure = %v", err) + } +} + +func TestSessionIdleCloseLogsSettlementFailure(t *testing.T) { + fake := &countingBlockhashRPC{FakeRPC: testutil.NewFakeRPC()} + fake.SendErr = errors.New("blockhash not found") + merchant := testutil.NewPrivateKey() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + o.CloseDelay = 15 * time.Millisecond + }) + _, channelID := openTrustedChannel(t, session, 1_000) + baseline := fake.calls() + + // The watchdog fires, the settle fails (the broadcast is blocked), and + // the channel stays re-drivable rather than finalized. + deadline := time.Now().Add(3 * time.Second) + for fake.calls() == baseline { + if time.Now().After(deadline) { + t.Fatal("idle-close watchdog never attempted settlement") + } + time.Sleep(5 * time.Millisecond) + } + state := mustGetChannel(t, session, channelID) + if state.Finalized || state.SettledSignature != nil { + t.Fatalf("failed settle mutated state: %+v", state) + } +} + +// ── SettlementInstructions error paths ── + +func TestSettlementInstructionsStateErrorPaths(t *testing.T) { + ctx := context.Background() + merchant := testutil.NewPrivateKey().PublicKey() + channelID := solana.NewWallet().PublicKey().String() + operator := solana.NewWallet().PublicKey().String() + + // Store read failure. + failing := NewSessionServer(sessionTestConfig(), &failingGetStore{ + ChannelStore: NewMemoryChannelStore(), getErr: errors.New("store offline"), + }) + if _, err := failing.SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "store offline") { + t.Fatalf("store failure = %v", err) + } + + seed := func(t *testing.T, config SessionConfig, state ChannelState) *SessionServer { + server := NewSessionServer(config, NewMemoryChannelStore()) + seedChannel(t, server.Store(), state) + return server + } + base := func() ChannelState { + expiresAt := farFuture() + signature := confirmedSignature(0xCD) + return ChannelState{ + ChannelID: channelID, + AuthorizedSigner: newTestVoucherSigner(t).Address(), + Deposit: 1_000, + Cumulative: 500, + Operator: &operator, + HighestVoucherSignature: &signature, + HighestVoucherExpiresAt: &expiresAt, + } + } + + // Invalid stored voucher signature. + badSignature := base() + invalid := "not-base58!" + badSignature.HighestVoucherSignature = &invalid + if _, err := seed(t, sessionTestConfig(), badSignature).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "invalid stored voucher signature") { + t.Fatalf("bad signature = %v", err) + } + + // Invalid stored authorized signer. + badSigner := base() + badSigner.AuthorizedSigner = "not-base58!" + if _, err := seed(t, sessionTestConfig(), badSigner).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "invalid stored authorized signer") { + t.Fatalf("bad authorized signer = %v", err) + } + + // Voucher signature without an expiry. + noExpiry := base() + noExpiry.HighestVoucherExpiresAt = nil + if _, err := seed(t, sessionTestConfig(), noExpiry).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "no voucher expiry") { + t.Fatalf("missing expiry = %v", err) + } + + // Native SOL currency cannot settle a token channel. + solConfig := sessionTestConfig() + solConfig.Currency = "SOL" + if _, err := seed(t, solConfig, base()).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "requires an SPL token") { + t.Fatalf("SOL currency = %v", err) + } + + // Invalid stored channel payer. + badPayer := base() + badPayerValue := "not-base58!" + badPayer.Operator = &badPayerValue + if _, err := seed(t, sessionTestConfig(), badPayer).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "invalid channel payer") { + t.Fatalf("bad payer = %v", err) + } + + // Invalid configured recipient. + badRecipient := sessionTestConfig() + badRecipient.Recipient = "not-base58!" + if _, err := seed(t, badRecipient, base()).SettlementInstructions(ctx, channelID, merchant); err == nil || + !strings.Contains(err.Error(), "invalid recipient") { + t.Fatalf("bad recipient = %v", err) + } +} + +// ── SubmitOpenTx failure matrix ── + +func TestSubmitOpenTxFailureMatrix(t *testing.T) { + ctx := context.Background() + fixture := buildOpenTxFixture(t, false) + + if _, err := SubmitOpenTx(ctx, fixture.expected, &fixture.payload, nil, nil); err == nil || + !strings.Contains(err.Error(), "requires an RPC client") { + t.Fatalf("nil rpc = %v", err) + } + + // Structural validation failures propagate before any broadcast. + fake := testutil.NewFakeRPC() + wrongRecipient := fixture.expected + wrongRecipient.Recipient = solana.NewWallet().PublicKey().String() + if _, err := SubmitOpenTx(ctx, wrongRecipient, &fixture.payload, nil, fake); err == nil || + !strings.Contains(err.Error(), "payee") { + t.Fatalf("verification failure = %v", err) + } + if len(fake.Sent) != 0 { + t.Fatal("broadcast happened despite verification failure") + } + + // Unsigned fee payer with no payer signer cannot broadcast. + operator := testutil.NewPrivateKey() + unsigned := buildServerCompletedOpenFixture(t, operator) + if _, err := SubmitOpenTx(ctx, unsigned.expected, &unsigned.payload, nil, fake); err == nil || + !strings.Contains(err.Error(), "missing the fee-payer signature") { + t.Fatalf("unsigned fee payer = %v", err) + } + + // A payer signer that is not required by the transaction does not help. + stranger := testutil.NewPrivateKey() + if _, err := SubmitOpenTx(ctx, unsigned.expected, &unsigned.payload, stranger, fake); err == nil || + !strings.Contains(err.Error(), "missing the fee-payer signature") { + t.Fatalf("stranger signer = %v", err) + } + + // A required signer that fails to sign surfaces the error. + if _, err := SubmitOpenTx(ctx, unsigned.expected, &unsigned.payload, failingSigner{key: operator}, fake); err == nil || + !strings.Contains(err.Error(), "co-sign open transaction") { + t.Fatalf("co-sign failure = %v", err) + } + + // Confirmation failure after broadcast surfaces. + confirmFail := testutil.NewFakeRPC() + confirmFail.Statuses[fixture.signature] = &rpc.SignatureStatusesResult{ + Err: map[string]any{"InstructionError": []any{0, "Custom"}}, + } + if _, err := SubmitOpenTx(ctx, fixture.expected, &fixture.payload, nil, confirmFail); err == nil || + !strings.Contains(err.Error(), "confirm open transaction") { + t.Fatalf("confirmation failure = %v", err) + } +} + +// ── VerifyOpenTx malformed instruction matrix ── + +// buildRawOpenPayload wraps a hand-built instruction targeting the +// payment-channels program into a signed transaction + open payload. +func buildRawOpenPayload(t *testing.T, accounts []*solana.AccountMeta, data []byte) (intents.OpenPayload, VerifyOpenTxExpected) { + t.Helper() + payer := testutil.NewPrivateKey() + ix := solana.NewInstruction(paymentchannels.ProgramPubkey(), accounts, data) + blockhash := solana.MustHashFromBase58("EkSnNWid2cvwEVnVx9aBqawnmiCNiDgp3gUdkDPTKN1N") + tx, err := solana.NewTransaction([]solana.Instruction{ix}, blockhash, solana.TransactionPayer(payer.PublicKey())) + if err != nil { + t.Fatalf("NewTransaction: %v", err) + } + if err := solanatx.SignTransaction(tx, payer); err != nil { + t.Fatalf("sign: %v", err) + } + encoded, err := solanatx.EncodeTransactionBase64(tx) + if err != nil { + t.Fatalf("encode: %v", err) + } + payload := intents.OpenPayloadPush("ignored", "1000", payer.PublicKey().String(), tx.Signatures[0].String()) + payload.ChannelID = nil + payload.Transaction = &encoded + expected := VerifyOpenTxExpected{ + AuthorizedSigner: payer.PublicKey().String(), + Currency: "USDC", + MaxCap: 5_000_000, + Network: "localnet", + Recipient: payer.PublicKey().String(), + } + return payload, expected +} + +func TestVerifyOpenTxMalformedInstructions(t *testing.T) { + ctx := context.Background() + wallet := func() *solana.AccountMeta { + return solana.Meta(solana.NewWallet().PublicKey()) + } + + // An empty currency with no explicit mint cannot resolve a mint. + fixture := buildOpenTxFixture(t, false) + unknownCurrency := fixture.expected + unknownCurrency.Currency = "" + unknownCurrency.Mint = "" + if _, err := VerifyOpenTx(ctx, unknownCurrency, &fixture.payload, nil); err == nil || + !strings.Contains(err.Error(), "could not resolve mint") { + t.Fatalf("empty currency = %v", err) + } + + // Too few accounts on the open instruction. + fewAccounts, expected := buildRawOpenPayload(t, + []*solana.AccountMeta{wallet(), wallet(), wallet()}, + append([]byte{openInstructionDiscriminator}, make([]byte, 20)...)) + if _, err := VerifyOpenTx(ctx, expected, &fewAccounts, nil); err == nil || + !strings.Contains(err.Error(), "too few accounts") { + t.Fatalf("few accounts = %v", err) + } + + // Short instruction data. + accounts := make([]*solana.AccountMeta, 0, 8) + for i := 0; i < 8; i++ { + accounts = append(accounts, wallet()) + } + shortData, shortExpected := buildRawOpenPayload(t, accounts, []byte{openInstructionDiscriminator, 1, 2, 3}) + // Point the expectations at the instruction's actual payee/mint/signer so + // the data-length check is what fails. + shortExpected.Recipient = accounts[1].PublicKey.String() + shortExpected.Mint = accounts[2].PublicKey.String() + shortExpected.AuthorizedSigner = accounts[3].PublicKey.String() + if _, err := VerifyOpenTx(ctx, shortExpected, &shortData, nil); err == nil || + !strings.Contains(err.Error(), "data too short") { + t.Fatalf("short data = %v", err) + } + + // No open instruction at all (wrong discriminator). + wrongDisc, wrongExpected := buildRawOpenPayload(t, accounts, []byte{9, 9, 9}) + if _, err := VerifyOpenTx(ctx, wrongExpected, &wrongDisc, nil); err == nil || + !strings.Contains(err.Error(), "no payment-channels open instruction") { + t.Fatalf("wrong discriminator = %v", err) + } +} + +func TestConfirmTransactionSignatureRPCErrorSurfaces(t *testing.T) { + failing := &failingStatusRPC{FakeRPC: testutil.NewFakeRPC()} + if err := confirmTransactionSignature(context.Background(), failing, confirmedSignature(0xEF), "open"); err == nil || + !strings.Contains(err.Error(), "RPC error") { + t.Fatalf("rpc error = %v", err) + } +} + +// ── routes + middleware failure responses ── + +func TestSessionRoutesCommitErrorBodies(t *testing.T) { + session := newTestSession(t, nil) + routes := session.Routes() + + invalid := httptest.NewRequest(http.MethodPost, "/__402/session/commit", strings.NewReader("not-json")) + recorder := httptest.NewRecorder() + routes.Commit(recorder, invalid) + if recorder.Code != http.StatusBadRequest || !strings.Contains(recorder.Body.String(), "invalid request body") { + t.Fatalf("invalid commit body: %d %s", recorder.Code, recorder.Body) + } + + signer, channelID := openTrustedChannel(t, session, 1_000) + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + unknown := commitDeliveryViaRoutes(t, routes, map[string]any{"deliveryId": "ghost", "voucher": voucher}) + if unknown.Code != http.StatusBadRequest || !strings.Contains(unknown.Body.String(), "not found") { + t.Fatalf("unknown delivery: %d %s", unknown.Code, unknown.Body) + } +} + +func TestSessionMiddlewareErrorResponses(t *testing.T) { + session := newTestSession(t, nil) + + // challengeFn failure becomes a 500. + failing := SessionMiddleware(session, func(*http.Request) (SessionChallengeOptions, error) { + return SessionChallengeOptions{}, errors.New("route metadata unavailable") + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) + recorder := httptest.NewRecorder() + failing.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("challengeFn failure status = %d", recorder.Code) + } + + // A challenge build failure (malformed cap) becomes a 500. + badCap := SessionMiddleware(session, func(*http.Request) (SessionChallengeOptions, error) { + return SessionChallengeOptions{Cap: "1.5"}, nil + })(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) + recorder = httptest.NewRecorder() + badCap.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + if recorder.Code != http.StatusInternalServerError { + t.Fatalf("bad cap status = %d", recorder.Code) + } + + // An empty Payment token falls through to the 402 challenge. + ok := SessionMiddleware(session, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + request := httptest.NewRequest(http.MethodGet, "/", nil) + request.Header.Set(core.AuthorizationHeader, "Payment ") + recorder = httptest.NewRecorder() + ok.ServeHTTP(recorder, request) + if recorder.Code != http.StatusPaymentRequired { + t.Fatalf("empty token status = %d", recorder.Code) + } +} + +// ── stream writer failures ── + +// failAfterWriter fails every write after the first n bytes budget runs out. +type failAfterWriter struct { + budget int // remaining bytes accepted before writes start failing +} + +func (f *failAfterWriter) Write(p []byte) (int, error) { + if f.budget <= 0 { + return 0, errors.New("client disconnected") + } + f.budget -= len(p) + return len(p), nil +} + +func TestMeteredStreamSurfacesWriteFailures(t *testing.T) { + stream := NewMeteredStreamWriter(&failAfterWriter{}) + if err := stream.WriteMetering(intents.MeteringDirective{DeliveryID: "d", SessionID: "s", Amount: "1", Currency: "USDC"}); err == nil { + t.Fatal("expected metering write failure") + } + if err := stream.WriteUsage(intents.MeteringUsage{DeliveryID: "d", Amount: "1"}); err == nil { + t.Fatal("expected usage write failure") + } + if err := stream.WriteEnvelope(map[string]string{"chunk": "x"}, intents.MeteringDirective{}); err == nil { + t.Fatal("expected envelope write failure") + } + if err := stream.WriteDone(); err == nil { + t.Fatal("expected done write failure") + } +} + +// ── core SessionServer gaps ── + +func TestBuildChallengeRequestIncludesProgramIDOverride(t *testing.T) { + programID := solana.NewWallet().PublicKey() + config := sessionTestConfig() + config.ProgramID = &programID + server := newSessionTestServer(config) + request := server.BuildChallengeRequest(1_000) + if request.ProgramID == nil || *request.ProgramID != programID.String() { + t.Fatalf("programId = %v", request.ProgramID) + } +} + +func TestVerifyVoucherSurfacesStoreFailure(t *testing.T) { + server := NewSessionServer(sessionTestConfig(), &failingGetStore{ + ChannelStore: NewMemoryChannelStore(), getErr: errors.New("store offline"), + }) + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, solana.NewWallet().PublicKey().String(), 100, farFuture()) + if _, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}); err == nil || + !strings.Contains(err.Error(), "store offline") { + t.Fatalf("store failure = %v", err) + } +} + +func TestProcessOpenPayloadFieldErrors(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer := newTestVoucherSigner(t) + + missingChannel := intents.OpenPayload{Mode: intents.SessionModePush, AuthorizedSigner: signer.Address(), Signature: "sig"} + if _, err := server.ProcessOpen(context.Background(), &missingChannel); err == nil || + !strings.Contains(err.Error(), "missing channelId") { + t.Fatalf("missing channelId = %v", err) + } + + channelID := solana.NewWallet().PublicKey().String() + badDeposit := intents.OpenPayloadPush(channelID, strconv.Quote("x"), signer.Address(), "sig") + if _, err := server.ProcessOpen(context.Background(), &badDeposit); err == nil || + !strings.Contains(err.Error(), "invalid deposit amount") { + t.Fatalf("bad deposit = %v", err) + } +} + +func TestProcessTopUpMalformedDeposit(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + if _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: "c", NewDeposit: "five", Signature: "sig", + }); err == nil || !strings.Contains(err.Error(), "invalid newDeposit") { + t.Fatalf("malformed deposit = %v", err) + } +} + +func TestProcessCommitMalformedCumulative(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + payload := intents.CommitPayload{ + DeliveryID: "d-1", + Voucher: intents.SignedVoucher{ + Data: intents.VoucherData{ChannelID: "c", Cumulative: "ten", ExpiresAt: farFuture()}, + Signature: confirmedSignature(0x01), + }, + } + if _, err := server.ProcessCommit(context.Background(), &payload); err == nil || + !strings.Contains(err.Error(), "invalid cumulative") { + t.Fatalf("malformed cumulative = %v", err) + } +} + +func TestProcessCloseMalformedFinalVoucher(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000) + voucher := intents.SignedVoucher{ + Data: intents.VoucherData{ChannelID: channelID, Cumulative: "ten", ExpiresAt: farFuture()}, + Signature: confirmedSignature(0x02), + } + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ + ChannelID: channelID, Voucher: &voucher, + }); err == nil || !strings.Contains(err.Error(), "invalid cumulative in final voucher") { + t.Fatalf("malformed final voucher = %v", err) + } +} diff --git a/go/protocols/mpp/server/session_method_gap_test.go b/go/protocols/mpp/server/session_method_gap_test.go new file mode 100644 index 000000000..933628d37 --- /dev/null +++ b/go/protocols/mpp/server/session_method_gap_test.go @@ -0,0 +1,151 @@ +package server + +// Remaining behavioral gaps on the session method layer: external id +// propagation onto receipts, the pull-strategy handler guard, server-submit +// pre-verification failures, lifecycle teardown on close, and settlement +// store-write failures. + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// failingUpdateStore wraps a ChannelStore and fails UpdateChannel once armed. +type failingUpdateStore struct { + // ChannelStore is the wrapped store used while fail is unset. + ChannelStore + + // fail, once armed, makes every UpdateChannel return a write error. + fail bool +} + +func (f *failingUpdateStore) UpdateChannel(ctx context.Context, channelID string, mutator ChannelMutator) (ChannelState, error) { + if f.fail { + return ChannelState{}, errors.New("store write rejected") + } + return f.ChannelStore.UpdateChannel(ctx, channelID, mutator) +} + +func TestSessionReceiptCarriesExternalID(t *testing.T) { + session := newTestSession(t, nil) + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{ExternalID: "order-42"}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + signer := newTestVoucherSigner(t) + credential, err := core.NewPaymentCredential(challenge.ToEcho(), intents.NewOpenAction( + intents.OpenPayloadPush(solana.NewWallet().PublicKey().String(), "1000", signer.Address(), "sig"))) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + receipt, err := session.VerifyCredential(context.Background(), credential) + if err != nil { + t.Fatalf("VerifyCredential: %v", err) + } + if receipt.ExternalID != "order-42" { + t.Fatalf("receipt externalId = %q", receipt.ExternalID) + } +} + +func TestSessionOpenPullRequiresStrategyAtHandler(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + session := newTestSession(t, func(o *SessionOptions) { + o.Modes = []intents.SessionMode{intents.SessionModePull} + o.PullVoucherStrategy = &strategy + }) + // Simulate a misconfigured lower-level core (the constructor enforces the + // invariant, but the handler re-checks it defensively). + session.core.config.PullVoucherStrategy = nil + signer := newTestVoucherSigner(t) + payload := intents.OpenPayloadPull( + solana.NewWallet().PublicKey().String(), "1000", + solana.NewWallet().PublicKey().String(), signer.Address(), "sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(payload)); err == nil || + !strings.Contains(err.Error(), "requires a pullVoucherStrategy") { + t.Fatalf("missing strategy error = %v", err) + } +} + +func TestSessionServerSubmitterPreVerificationFailure(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fake := testutil.NewFakeRPC() + // The session recipient differs from the fixture payee, so the decode-only + // pre-verification fails before any broadcast. + session := newTestSession(t, func(o *SessionOptions) { + o.OpenTxSubmitter = OpenTxSubmitterServer + o.RPC = fake + }) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err == nil || + !strings.Contains(err.Error(), "payee") { + t.Fatalf("pre-verification error = %v", err) + } + if len(fake.Sent) != 0 { + t.Fatal("broadcast happened despite pre-verification failure") + } +} + +func TestSessionCloseCancelsIdleTimer(t *testing.T) { + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.CloseDelay = 50 * time.Millisecond + }) + _, channelID := openTrustedChannel(t, session, 1_000) + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})); err != nil { + t.Fatalf("close: %v", err) + } + // With no merchant signer the close never settles, and the canceled + // watchdog must not fire afterward either. + time.Sleep(120 * time.Millisecond) + state := mustGetChannel(t, session, channelID) + if state.Finalized || len(fake.Sent) != 0 { + t.Fatalf("idle timer fired after close: %+v sends=%d", state, len(fake.Sent)) + } +} + +func TestCloseAndSettleSurfacesStoreWriteFailure(t *testing.T) { + fake := testutil.NewFakeRPC() + merchant := testutil.NewPrivateKey() + store := &failingUpdateStore{ChannelStore: NewMemoryChannelStore()} + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + o.Store = store + }) + _, channelID := openTrustedChannel(t, session, 1_000) + + store.fail = true + if _, err := session.closeAndSettleChannel(context.Background(), channelID); err == nil || + !strings.Contains(err.Error(), "store write rejected") { + t.Fatalf("store write failure = %v", err) + } +} + +func TestSettlementInstructionsInvalidMintCurrency(t *testing.T) { + config := sessionTestConfig() + // An unknown currency resolves to itself; a non-base58 value then fails + // mint parsing. + config.Currency = "not-a-mint!" + server := NewSessionServer(config, NewMemoryChannelStore()) + operator := solana.NewWallet().PublicKey().String() + channelID := solana.NewWallet().PublicKey().String() + seedChannel(t, server.Store(), ChannelState{ + ChannelID: channelID, + AuthorizedSigner: newTestVoucherSigner(t).Address(), + Deposit: 1_000, + Operator: &operator, + }) + if _, err := server.SettlementInstructions(context.Background(), channelID, solana.NewWallet().PublicKey()); err == nil || + !strings.Contains(err.Error(), "invalid mint") { + t.Fatalf("invalid mint = %v", err) + } +} diff --git a/go/protocols/mpp/server/session_method_test.go b/go/protocols/mpp/server/session_method_test.go new file mode 100644 index 000000000..711da8f4d --- /dev/null +++ b/go/protocols/mpp/server/session_method_test.go @@ -0,0 +1,1453 @@ +package server + +// Method-level coverage through the real credential layer: challenge +// issuance (canonical shape, cap clamping, pull advertisement, blockhash +// prefetch), the five verify() actions with their replay/hardening +// semantics, the side-channel routes, settlement retry, and the store +// sharing between the method and its routes. + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +const sessionMethodSecret = "session-method-secret" + +// confirmedSignature returns a base58 signature string registered as +// confirmed on the fake RPC. +func confirmedSignature(fill byte) string { + raw := make([]byte, 64) + for i := range raw { + raw[i] = fill + } + return solana.SignatureFromBytes(raw).String() +} + +func newTestSession(t *testing.T, mutate func(*SessionOptions)) *Session { + t.Helper() + options := SessionOptions{ + Operator: sessionTestRecipient, + Recipient: sessionTestRecipient, + Cap: 5_000_000, + Currency: "USDC", + Decimals: 6, + Network: "localnet", + SecretKey: sessionMethodSecret, + Realm: "api.test", + } + if mutate != nil { + mutate(&options) + } + session, err := NewSession(options) + if err != nil { + t.Fatalf("NewSession: %v", err) + } + t.Cleanup(session.Shutdown) + return session +} + +// sessionActionCredential issues a fresh challenge and wraps action into the +// credential a client would send. +func sessionActionCredential(t *testing.T, session *Session, action any) core.PaymentCredential { + t.Helper() + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + credential, err := core.NewPaymentCredential(challenge.ToEcho(), action) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + return credential +} + +func verifySessionAction(t *testing.T, session *Session, action any) (core.Receipt, error) { + t.Helper() + return session.VerifyCredential(context.Background(), sessionActionCredential(t, session, action)) +} + +// openTrustedChannel opens a transactionless push channel through the +// credential layer and returns the voucher signer plus channel id. The open +// signature is a valid base58 signature so the helper also works on sessions +// with an RPC client configured (the fake RPC confirms unknown signatures). +func openTrustedChannel(t *testing.T, session *Session, deposit uint64) (testVoucherSigner, string) { + t.Helper() + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + openSessionChannel(t, session, channelID, deposit, signer.Address(), confirmedSignature(0x99)) + return signer, channelID +} + +func openSessionChannel(t *testing.T, session *Session, channelID string, deposit uint64, authorizedSigner, signature string) core.Receipt { + t.Helper() + payload := intents.OpenPayloadPush(channelID, fmt.Sprintf("%d", deposit), authorizedSigner, signature) + receipt, err := verifySessionAction(t, session, intents.NewOpenAction(payload)) + if err != nil { + t.Fatalf("open: %v", err) + } + return receipt +} + +func mustGetChannel(t *testing.T, session *Session, channelID string) *ChannelState { + t.Helper() + state, err := session.Core().Store().GetChannel(context.Background(), channelID) + if err != nil { + t.Fatalf("GetChannel: %v", err) + } + return state +} + +// ── NewSession validation ── + +func TestNewSessionValidation(t *testing.T) { + base := func() SessionOptions { + return SessionOptions{ + Operator: sessionTestRecipient, + Recipient: sessionTestRecipient, + Cap: 1_000, + SecretKey: sessionMethodSecret, + } + } + + zeroCap := base() + zeroCap.Cap = 0 + if _, err := NewSession(zeroCap); err == nil || !strings.Contains(err.Error(), "cap must be positive") { + t.Fatalf("zero cap error = %v", err) + } + + noRecipient := base() + noRecipient.Recipient = "" + if _, err := NewSession(noRecipient); err == nil || !strings.Contains(err.Error(), "recipient is required") { + t.Fatalf("missing recipient error = %v", err) + } + + badRecipient := base() + badRecipient.Recipient = "not-base58!" + if _, err := NewSession(badRecipient); err == nil || !strings.Contains(err.Error(), "invalid recipient") { + t.Fatalf("invalid recipient error = %v", err) + } + + manySplits := base() + for i := 0; i < 9; i++ { + manySplits.Splits = append(manySplits.Splits, Split{Recipient: solana.NewWallet().PublicKey(), BPS: 1}) + } + if _, err := NewSession(manySplits); err == nil || !strings.Contains(err.Error(), "splits cannot exceed") { + t.Fatalf("splits error = %v", err) + } + + pullNoStrategy := base() + pullNoStrategy.Modes = []intents.SessionMode{intents.SessionModePull} + if _, err := NewSession(pullNoStrategy); err == nil || !strings.Contains(err.Error(), "pullVoucherStrategy is required") { + t.Fatalf("pull strategy error = %v", err) + } + + badSubmitter := base() + badSubmitter.OpenTxSubmitter = OpenTxSubmitter("relay") + if _, err := NewSession(badSubmitter); err == nil || !strings.Contains(err.Error(), "openTxSubmitter") { + t.Fatalf("openTxSubmitter error = %v", err) + } + + t.Setenv(secretKeyEnvVar, "") + noSecret := base() + noSecret.SecretKey = "" + if _, err := NewSession(noSecret); err == nil || !strings.Contains(err.Error(), "missing secret key") { + t.Fatalf("missing secret error = %v", err) + } +} + +func TestNewSessionDefaults(t *testing.T) { + session := newTestSession(t, func(o *SessionOptions) { + o.Currency = "" + o.Decimals = 0 + o.Network = "" + o.OpenTxSubmitter = "" + }) + if session.currency != "USDC" || session.network != "mainnet" { + t.Fatalf("defaults: currency=%q network=%q", session.currency, session.network) + } + if session.openTxSubmitter != OpenTxSubmitterClient { + t.Fatalf("openTxSubmitter default = %q", session.openTxSubmitter) + } + if session.core.config.Decimals != 6 { + t.Fatalf("decimals default = %d", session.core.config.Decimals) + } +} + +// ── Challenge ── + +func TestSessionChallengeCanonicalShape(t *testing.T) { + session := newTestSession(t, nil) + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{ + Cap: "1000000", + Description: "Metered token stream", + }) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + if !challenge.Verify(sessionMethodSecret) { + t.Fatal("challenge HMAC does not verify") + } + if !challenge.Intent.IsSession() { + t.Fatalf("intent = %q, want session", challenge.Intent) + } + if string(challenge.Method) != "solana" || challenge.Realm != "api.test" { + t.Fatalf("method=%q realm=%q", challenge.Method, challenge.Realm) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + t.Fatalf("decode request: %v", err) + } + if request.Cap != "1000000" || request.Currency != "USDC" { + t.Fatalf("cap=%q currency=%q", request.Cap, request.Currency) + } + if request.Operator != sessionTestRecipient || request.Recipient != sessionTestRecipient { + t.Fatalf("operator=%q recipient=%q", request.Operator, request.Recipient) + } + if request.Network == nil || *request.Network != "localnet" { + t.Fatalf("network = %v", request.Network) + } + if request.Decimals == nil || *request.Decimals != 6 { + t.Fatalf("decimals = %v", request.Decimals) + } + if request.Description == nil || *request.Description != "Metered token stream" { + t.Fatalf("description = %v", request.Description) + } + if request.Modes != nil { + t.Fatalf("modes should be omitted when push-only, got %v", request.Modes) + } + if request.RecentBlockhash != nil { + t.Fatalf("recentBlockhash should be absent without an RPC client, got %v", *request.RecentBlockhash) + } +} + +func TestSessionChallengeClampsRequestedCap(t *testing.T) { + session := newTestSession(t, func(o *SessionOptions) { o.Cap = 1_000_000 }) + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{Cap: "50000000"}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + t.Fatalf("decode request: %v", err) + } + if request.Cap != "1000000" { + t.Fatalf("cap = %q, want clamped 1000000", request.Cap) + } +} + +func TestSessionChallengeInvalidCapRejected(t *testing.T) { + session := newTestSession(t, nil) + if _, err := session.Challenge(context.Background(), SessionChallengeOptions{Cap: "1.5"}); err == nil { + t.Fatal("expected invalid cap error") + } +} + +func TestSessionChallengeIncludesBlockhashWithRPC(t *testing.T) { + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { o.RPC = fake }) + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + t.Fatalf("decode request: %v", err) + } + if request.RecentBlockhash == nil || *request.RecentBlockhash != fake.Blockhash.String() { + t.Fatalf("recentBlockhash = %v, want %s", request.RecentBlockhash, fake.Blockhash) + } +} + +func TestSessionChallengeAdvertisesPullStrategy(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + session := newTestSession(t, func(o *SessionOptions) { + o.Modes = []intents.SessionMode{intents.SessionModePull, intents.SessionModePush} + o.PullVoucherStrategy = &strategy + }) + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{ExternalID: "ref-7"}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + t.Fatalf("decode request: %v", err) + } + if len(request.Modes) != 2 { + t.Fatalf("modes = %v", request.Modes) + } + if request.PullVoucherStrategy == nil || *request.PullVoucherStrategy != strategy { + t.Fatalf("pullVoucherStrategy = %v", request.PullVoucherStrategy) + } + if request.ExternalID == nil || *request.ExternalID != "ref-7" { + t.Fatalf("externalId = %v", request.ExternalID) + } +} + +// ── VerifyCredential: tier-1 + tier-2 ── + +func TestVerifyCredentialRejectsTamperedAndExpiredChallenges(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + action := intents.NewOpenAction(intents.OpenPayloadPush(channelID, "1000", signer.Address(), "sig")) + + credential := sessionActionCredential(t, session, action) + credential.Challenge.Realm = "tampered.example" + if _, err := session.VerifyCredential(context.Background(), credential); err == nil || + !strings.Contains(err.Error(), "challenge ID mismatch") { + t.Fatalf("tampered realm error = %v", err) + } + + request, err := core.NewBase64URLJSONValue(session.core.BuildChallengeRequest(1_000)) + if err != nil { + t.Fatalf("encode request: %v", err) + } + expired := core.NewChallengeWithSecretFull( + sessionMethodSecret, "api.test", core.NewMethodName("solana"), core.NewIntentName("session"), + request, "2020-01-01T00:00:00Z", "", "", nil) + expiredCredential, err := core.NewPaymentCredential(expired.ToEcho(), action) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + if _, err := session.VerifyCredential(context.Background(), expiredCredential); err == nil || + !strings.Contains(err.Error(), "expired") { + t.Fatalf("expired challenge error = %v", err) + } +} + +func TestVerifyCredentialPinnedFieldBackstop(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + action := intents.NewOpenAction(intents.OpenPayloadPush( + solana.NewWallet().PublicKey().String(), "1000", signer.Address(), "sig")) + + issue := func(intent string, request intents.SessionRequest) core.PaymentCredential { + encoded, err := core.NewBase64URLJSONValue(request) + if err != nil { + t.Fatalf("encode request: %v", err) + } + challenge := core.NewChallengeWithSecretFull( + sessionMethodSecret, "api.test", core.NewMethodName("solana"), core.NewIntentName(intent), + encoded, core.Minutes(5), "", "", nil) + credential, err := core.NewPaymentCredential(challenge.ToEcho(), action) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + return credential + } + + chargeIntent := issue("charge", session.core.BuildChallengeRequest(1_000)) + if _, err := session.VerifyCredential(context.Background(), chargeIntent); err == nil || + !strings.Contains(err.Error(), "not a session") { + t.Fatalf("wrong intent error = %v", err) + } + + wrongCurrency := session.core.BuildChallengeRequest(1_000) + wrongCurrency.Currency = "USDT" + if _, err := session.VerifyCredential(context.Background(), issue("session", wrongCurrency)); err == nil || + !strings.Contains(err.Error(), "currency") { + t.Fatalf("wrong currency error = %v", err) + } + + wrongRecipient := session.core.BuildChallengeRequest(1_000) + wrongRecipient.Recipient = solana.NewWallet().PublicKey().String() + if _, err := session.VerifyCredential(context.Background(), issue("session", wrongRecipient)); err == nil || + !strings.Contains(err.Error(), "recipient") { + t.Fatalf("wrong recipient error = %v", err) + } + + unknownAction := sessionActionCredential(t, session, map[string]string{"action": "refund"}) + if _, err := session.VerifyCredential(context.Background(), unknownAction); err == nil || + !strings.Contains(err.Error(), "decode session action") { + t.Fatalf("unknown action error = %v", err) + } +} + +// ── open ── + +func TestSessionOpenTrustsChannelIDAndDeposit(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + + receipt := openSessionChannel(t, session, channelID, 1_000_000, signer.Address(), "sig-1") + if receipt.Status != core.ReceiptStatusSuccess { + t.Fatalf("status = %q", receipt.Status) + } + if receipt.Reference != "sig-1" { + t.Fatalf("reference = %q, want sig-1", receipt.Reference) + } + state := mustGetChannel(t, session, channelID) + if state == nil || state.Deposit != 1_000_000 || state.Cumulative != 0 || state.AuthorizedSigner != signer.Address() { + t.Fatalf("stored state = %+v", state) + } +} + +func TestSessionOpenRejectsUnadvertisedMode(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + payload := intents.OpenPayloadPull( + solana.NewWallet().PublicKey().String(), "1000", + solana.NewWallet().PublicKey().String(), signer.Address(), "sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(payload)); err == nil || + !strings.Contains(err.Error(), "not supported") { + t.Fatalf("unadvertised mode error = %v", err) + } +} + +func TestSessionOpenRejectsBadDeposits(t *testing.T) { + session := newTestSession(t, func(o *SessionOptions) { o.Cap = 1_000 }) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + + over := intents.OpenPayloadPush(channelID, "10000", signer.Address(), "sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(over)); err == nil || + !strings.Contains(err.Error(), "exceeds cap") { + t.Fatalf("over-cap error = %v", err) + } + + zero := intents.OpenPayloadPush(channelID, "0", signer.Address(), "sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(zero)); err == nil || + !strings.Contains(err.Error(), "greater than zero") { + t.Fatalf("zero deposit error = %v", err) + } + + missing := intents.OpenPayload{Mode: intents.SessionModePush, AuthorizedSigner: signer.Address(), Signature: "sig"} + if _, err := verifySessionAction(t, session, intents.NewOpenAction(missing)); err == nil || + !strings.Contains(err.Error(), "missing transaction or channelId") { + t.Fatalf("missing channel error = %v", err) + } +} + +// TestSessionOpenRejectsEmptyStringFields pins that empty strings count as +// missing on the push open path: transaction="" with no channelId (and the +// all-empty variant) must reject gracefully instead of dereferencing a nil +// ChannelID. +func TestSessionOpenRejectsEmptyStringFields(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + empty := "" + + emptyTx := intents.OpenPayload{ + Mode: intents.SessionModePush, Transaction: &empty, + AuthorizedSigner: signer.Address(), Signature: "sig", + } + if _, err := verifySessionAction(t, session, intents.NewOpenAction(emptyTx)); err == nil || + !strings.Contains(err.Error(), "missing transaction or channelId") { + t.Fatalf("empty transaction error = %v", err) + } + + emptyBoth := intents.OpenPayload{ + Mode: intents.SessionModePush, Transaction: &empty, ChannelID: &empty, + AuthorizedSigner: signer.Address(), Signature: "sig", + } + if _, err := verifySessionAction(t, session, intents.NewOpenAction(emptyBoth)); err == nil || + !strings.Contains(err.Error(), "missing transaction or channelId") { + t.Fatalf("empty transaction and channelId error = %v", err) + } +} + +func TestSessionOpenReplaySemantics(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + + if _, err := submitMethodVoucher(t, session, signer, channelID, 250); err != nil { + t.Fatalf("voucher: %v", err) + } + + // Idempotent replay preserves the watermark. + openSessionChannel(t, session, channelID, 1_000, signer.Address(), "open-sig") + state := mustGetChannel(t, session, channelID) + if state.Cumulative != 250 || state.HighestVoucherSignature == nil { + t.Fatalf("replay reset watermark: %+v", state) + } + + // Different authorizedSigner rejects without overwriting. + intruder := newTestVoucherSigner(t) + payload := intents.OpenPayloadPush(channelID, "1000", intruder.Address(), "open-sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(payload)); err == nil || + !strings.Contains(err.Error(), "authorizedSigner") { + t.Fatalf("intruder replay error = %v", err) + } + if mustGetChannel(t, session, channelID).AuthorizedSigner != signer.Address() { + t.Fatal("intruder replay overwrote the authorized signer") + } + + // Finalized channel rejects replays. + if _, err := session.Core().Store().MarkFinalized(context.Background(), channelID); err != nil { + t.Fatalf("MarkFinalized: %v", err) + } + replay := intents.OpenPayloadPush(channelID, "1000", signer.Address(), "open-sig") + if _, err := verifySessionAction(t, session, intents.NewOpenAction(replay)); err == nil || + !strings.Contains(err.Error(), "finalized") { + t.Fatalf("finalized replay error = %v", err) + } +} + +func TestSessionOpenVerifiesSignatureOnChain(t *testing.T) { + fake := testutil.NewFakeRPC() + okSig := confirmedSignature(0x11) + ghostSig := confirmedSignature(0x22) + failedSig := confirmedSignature(0x33) + fake.Statuses[ghostSig] = nil + fake.Statuses[failedSig] = &rpc.SignatureStatusesResult{Err: map[string]any{"InstructionError": []any{0, "Custom"}}} + + session := newTestSession(t, func(o *SessionOptions) { o.RPC = fake }) + signer := newTestVoucherSigner(t) + + channelID := solana.NewWallet().PublicKey().String() + receipt := openSessionChannel(t, session, channelID, 1_000, signer.Address(), okSig) + if receipt.Reference != okSig { + t.Fatalf("reference = %q", receipt.Reference) + } + + ghostChannel := solana.NewWallet().PublicKey().String() + ghost := intents.OpenPayloadPush(ghostChannel, "1000", signer.Address(), ghostSig) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(ghost)); err == nil || + !strings.Contains(err.Error(), "not found") { + t.Fatalf("ghost signature error = %v", err) + } + if mustGetChannel(t, session, ghostChannel) != nil { + t.Fatal("channel persisted despite unknown signature") + } + + failed := intents.OpenPayloadPush(solana.NewWallet().PublicKey().String(), "1000", signer.Address(), failedSig) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(failed)); err == nil || + !strings.Contains(err.Error(), "failed on-chain") { + t.Fatalf("failed signature error = %v", err) + } +} + +func TestSessionPullOpenPrefersChannelIDOverTokenAccount(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + session := newTestSession(t, func(o *SessionOptions) { + o.Modes = []intents.SessionMode{intents.SessionModePull} + o.PullVoucherStrategy = &strategy + }) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + tokenAccount := solana.NewWallet().PublicKey().String() + + payload := intents.OpenPayloadPull(tokenAccount, "1000", solana.NewWallet().PublicKey().String(), signer.Address(), "sig-1") + payload.ChannelID = &channelID + if _, err := verifySessionAction(t, session, intents.NewOpenAction(payload)); err != nil { + t.Fatalf("pull open: %v", err) + } + if mustGetChannel(t, session, channelID) == nil { + t.Fatal("channel not keyed by channelId") + } + if mustGetChannel(t, session, tokenAccount) != nil { + t.Fatal("channel unexpectedly keyed by tokenAccount") + } + // Pull opens record the owner as the channel operator. + if state := mustGetChannel(t, session, channelID); state.Operator == nil { + t.Fatal("pull open did not record the operator") + } +} + +// ── voucher ── + +func submitMethodVoucher(t *testing.T, session *Session, signer testVoucherSigner, channelID string, cumulative uint64) (core.Receipt, error) { + t.Helper() + voucher := signer.SignVoucher(t, channelID, cumulative, farFuture()) + return verifySessionAction(t, session, intents.NewVoucherAction(intents.VoucherPayload{Voucher: voucher})) +} + +func TestSessionVoucherAdvancesWatermark(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + + voucher := signer.SignVoucher(t, channelID, 250, farFuture()) + receipt, err := verifySessionAction(t, session, intents.NewVoucherAction(intents.VoucherPayload{Voucher: voucher})) + if err != nil { + t.Fatalf("voucher: %v", err) + } + if receipt.Reference != channelID+":250" { + t.Fatalf("reference = %q", receipt.Reference) + } + state := mustGetChannel(t, session, channelID) + if state.Cumulative != 250 || state.HighestVoucherSignature == nil || *state.HighestVoucherSignature != voucher.Signature { + t.Fatalf("state after voucher = %+v", state) + } +} + +func TestSessionVoucherUnknownChannelRejected(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + if _, err := submitMethodVoucher(t, session, signer, solana.NewWallet().PublicKey().String(), 100); err == nil || + !strings.Contains(err.Error(), "not found") { + t.Fatalf("unknown channel error = %v", err) + } +} + +func TestSessionVoucherNonMonotonicTaggedRejection(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + if _, err := submitMethodVoucher(t, session, signer, channelID, 100); err != nil { + t.Fatalf("first voucher: %v", err) + } + if _, err := submitMethodVoucher(t, session, signer, channelID, 50); err == nil || + !strings.Contains(err.Error(), "cumulative-not-monotonic") { + t.Fatalf("stale voucher error = %v", err) + } +} + +func TestSessionVoucherAcceptsCumulativeAliasOnTheWire(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + canonical := signer.SignVoucher(t, channelID, 250, farFuture()) + + aliased := map[string]any{ + "action": "voucher", + "voucher": map[string]any{ + "data": map[string]any{ + "channelId": channelID, + "cumulative": "250", + "expiresAt": canonical.Data.ExpiresAt, + }, + "signature": canonical.Signature, + }, + } + receipt, err := verifySessionAction(t, session, aliased) + if err != nil { + t.Fatalf("aliased voucher: %v", err) + } + if receipt.Reference != channelID+":250" { + t.Fatalf("reference = %q", receipt.Reference) + } + if mustGetChannel(t, session, channelID).Cumulative != 250 { + t.Fatal("alias voucher did not advance the watermark") + } + + neither := map[string]any{ + "action": "voucher", + "voucher": map[string]any{ + "data": map[string]any{"channelId": channelID, "expiresAt": canonical.Data.ExpiresAt}, + "signature": canonical.Signature, + }, + } + if _, err := verifySessionAction(t, session, neither); err == nil || + !strings.Contains(err.Error(), "cumulativeAmount") { + t.Fatalf("missing cumulative error = %v", err) + } +} + +// ── topUp ── + +func TestSessionTopUpUpdatesDeposit(t *testing.T) { + session := newTestSession(t, nil) + _, channelID := openTrustedChannel(t, session, 1_000) + + receipt, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "5000", Signature: "topup-sig", + })) + if err != nil { + t.Fatalf("topUp: %v", err) + } + if receipt.Reference != "topup-sig" { + t.Fatalf("reference = %q", receipt.Reference) + } + if mustGetChannel(t, session, channelID).Deposit != 5_000 { + t.Fatal("deposit not raised") + } +} + +func TestSessionTopUpHardening(t *testing.T) { + session := newTestSession(t, nil) + _, channelID := openTrustedChannel(t, session, 5_000) + + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "1000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "must exceed current deposit") { + t.Fatalf("below-current error = %v", err) + } + + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "99000000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "exceeds cap") { + t.Fatalf("over-cap error = %v", err) + } + + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: solana.NewWallet().PublicKey().String(), NewDeposit: "9000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("unknown channel error = %v", err) + } + + // Close-pending blocks top-ups. + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})); err != nil { + t.Fatalf("close: %v", err) + } + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "9000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "close is pending") { + t.Fatalf("close-pending error = %v", err) + } + + // Finalized blocks top-ups. + _, finalizedChannel := openTrustedChannel(t, session, 5_000) + if _, err := session.Core().Store().MarkFinalized(context.Background(), finalizedChannel); err != nil { + t.Fatalf("MarkFinalized: %v", err) + } + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: finalizedChannel, NewDeposit: "9000", Signature: "sig", + })); err == nil || !strings.Contains(err.Error(), "finalized") { + t.Fatalf("finalized error = %v", err) + } +} + +func TestSessionTopUpVerifiesSignatureOnChain(t *testing.T) { + fake := testutil.NewFakeRPC() + openSig := confirmedSignature(0x44) + topupSig := confirmedSignature(0x55) + ghostSig := confirmedSignature(0x66) + fake.Statuses[ghostSig] = nil + + session := newTestSession(t, func(o *SessionOptions) { o.RPC = fake }) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + openSessionChannel(t, session, channelID, 1_000, signer.Address(), openSig) + + receipt, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "5000", Signature: topupSig, + })) + if err != nil { + t.Fatalf("topUp: %v", err) + } + if receipt.Reference != topupSig { + t.Fatalf("reference = %q", receipt.Reference) + } + if mustGetChannel(t, session, channelID).Deposit != 5_000 { + t.Fatal("deposit not raised") + } + + if _, err := verifySessionAction(t, session, intents.NewTopUpAction(intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "9000", Signature: ghostSig, + })); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("ghost top-up error = %v", err) + } + if mustGetChannel(t, session, channelID).Deposit != 5_000 { + t.Fatal("deposit raised despite unknown signature") + } +} + +// ── close ── + +func TestSessionCloseFlipsClosePending(t *testing.T) { + session := newTestSession(t, nil) + _, channelID := openTrustedChannel(t, session, 1_000) + + receipt, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})) + if err != nil { + t.Fatalf("close: %v", err) + } + if receipt.Reference != channelID { + t.Fatalf("reference = %q, want channel id", receipt.Reference) + } + state := mustGetChannel(t, session, channelID) + if state.CloseRequestedAt == nil || state.Finalized { + t.Fatalf("state after close = %+v", state) + } +} + +func TestSessionCloseWithFinalVoucherAdvancesWatermark(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + + final := signer.SignVoucher(t, channelID, 750, farFuture()) + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ + ChannelID: channelID, Voucher: &final, + })); err != nil { + t.Fatalf("close: %v", err) + } + state := mustGetChannel(t, session, channelID) + if state.Cumulative != 750 || state.CloseRequestedAt == nil { + t.Fatalf("state after close = %+v", state) + } +} + +func TestSessionCloseNonMonotonicFinalVoucherHardError(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + if _, err := submitMethodVoucher(t, session, signer, channelID, 250); err != nil { + t.Fatalf("voucher: %v", err) + } + + stale := signer.SignVoucher(t, channelID, 100, farFuture()) + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ + ChannelID: channelID, Voucher: &stale, + })); err == nil || !strings.Contains(err.Error(), "cumulative-not-monotonic") { + t.Fatalf("stale final voucher error = %v", err) + } + state := mustGetChannel(t, session, channelID) + if state.CloseRequestedAt != nil || state.Cumulative != 250 { + t.Fatalf("close mutated state on hard error: %+v", state) + } +} + +func TestSessionCloseAcceptsReplayOfHighestVoucher(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + voucher := signer.SignVoucher(t, channelID, 250, farFuture()) + if _, err := verifySessionAction(t, session, intents.NewVoucherAction(intents.VoucherPayload{Voucher: voucher})); err != nil { + t.Fatalf("voucher: %v", err) + } + + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ + ChannelID: channelID, Voucher: &voucher, + })); err != nil { + t.Fatalf("close with replayed highest voucher: %v", err) + } + state := mustGetChannel(t, session, channelID) + if state.CloseRequestedAt == nil || state.Cumulative != 250 { + t.Fatalf("state after replay close = %+v", state) + } +} + +func TestSessionCloseRetryAfterFailedSettlement(t *testing.T) { + fake := testutil.NewFakeRPC() + merchant := testutil.NewPrivateKey() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + }) + signer, channelID := openTrustedChannel(t, session, 1_000) + if _, err := submitMethodVoucher(t, session, signer, channelID, 400); err != nil { + t.Fatalf("voucher: %v", err) + } + + // First close: settlement broadcast fails; close stays pending and + // re-drivable. + fake.SendErr = fmt.Errorf("blockhash not found") + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})); err == nil || + !strings.Contains(err.Error(), "blockhash not found") { + t.Fatalf("settlement failure error = %v", err) + } + state := mustGetChannel(t, session, channelID) + if state.CloseRequestedAt == nil || state.Finalized || state.SettledSignature != nil { + t.Fatalf("state after failed settle = %+v", state) + } + + // Retry succeeds and finalizes the channel. + fake.SendErr = nil + receipt, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})) + if err != nil { + t.Fatalf("close retry: %v", err) + } + if len(fake.Sent) != 1 { + t.Fatalf("settlement broadcasts = %d, want 1", len(fake.Sent)) + } + state = mustGetChannel(t, session, channelID) + if !state.Finalized || state.SettledSignature == nil { + t.Fatalf("state after settle = %+v", state) + } + if receipt.Reference != *state.SettledSignature { + t.Fatalf("reference = %q, want settled signature %q", receipt.Reference, *state.SettledSignature) + } + + // A third close on the finalized channel rejects. + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})); err == nil || + !strings.Contains(err.Error(), "finalized") { + t.Fatalf("third close error = %v", err) + } +} + +func TestSessionCloseWithoutSignerDoesNotSettle(t *testing.T) { + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { o.RPC = fake }) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + openSessionChannel(t, session, channelID, 1_000, signer.Address(), confirmedSignature(0x77)) + + if _, err := verifySessionAction(t, session, intents.NewCloseAction(intents.ClosePayload{ChannelID: channelID})); err != nil { + t.Fatalf("close: %v", err) + } + if len(fake.Sent) != 0 { + t.Fatalf("settlement broadcast without a merchant signer: %d sends", len(fake.Sent)) + } +} + +// ── commit + routes ── + +func reserveDelivery(t *testing.T, routes SessionRoutes, body map[string]any) *httptest.ResponseRecorder { + t.Helper() + encoded, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/__402/session/deliveries", bytes.NewReader(encoded)) + recorder := httptest.NewRecorder() + routes.Deliveries(recorder, request) + return recorder +} + +func commitDeliveryViaRoutes(t *testing.T, routes SessionRoutes, body map[string]any) *httptest.ResponseRecorder { + t.Helper() + encoded, err := json.Marshal(body) + if err != nil { + t.Fatalf("marshal body: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/__402/session/commit", bytes.NewReader(encoded)) + recorder := httptest.NewRecorder() + routes.Commit(recorder, request) + return recorder +} + +func TestSessionCommitForReservedDelivery(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + routes := session.Routes() + + reserve := reserveDelivery(t, routes, map[string]any{"amount": "200", "sessionId": channelID}) + if reserve.Code != http.StatusOK { + t.Fatalf("reserve status = %d body=%s", reserve.Code, reserve.Body) + } + var directive intents.MeteringDirective + if err := json.Unmarshal(reserve.Body.Bytes(), &directive); err != nil { + t.Fatalf("decode directive: %v", err) + } + if directive.DeliveryID != channelID+":1" || directive.Sequence != 1 { + t.Fatalf("directive = %+v", directive) + } + if directive.Currency != "USDC" || directive.Amount != "200" { + t.Fatalf("directive fields = %+v", directive) + } + + voucher := signer.SignVoucher(t, channelID, 150, farFuture()) + receipt, err := verifySessionAction(t, session, intents.NewCommitAction(intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: voucher, + })) + if err != nil { + t.Fatalf("commit: %v", err) + } + wantReference := fmt.Sprintf("%s:%s:150", channelID, directive.DeliveryID) + if receipt.Reference != wantReference { + t.Fatalf("reference = %q, want %q", receipt.Reference, wantReference) + } + state := mustGetChannel(t, session, channelID) + if state.Cumulative != 150 || len(state.CommittedDeliveries) != 1 || len(state.PendingDeliveries) != 0 { + t.Fatalf("state after commit = %+v", state) + } +} + +func TestSessionRoutesValidation(t *testing.T) { + session := newTestSession(t, nil) + routes := session.Routes() + + if recorder := reserveDelivery(t, routes, map[string]any{"amount": "10", "sessionId": "ghost"}); recorder.Code != http.StatusBadRequest { + t.Fatalf("unknown channel status = %d", recorder.Code) + } + if recorder := reserveDelivery(t, routes, map[string]any{"amount": "10"}); recorder.Code != http.StatusBadRequest || + !strings.Contains(recorder.Body.String(), "sessionId required") { + t.Fatalf("missing sessionId: %d %s", recorder.Code, recorder.Body) + } + if recorder := reserveDelivery(t, routes, map[string]any{"amount": "0", "sessionId": "x"}); recorder.Code != http.StatusBadRequest || + !strings.Contains(recorder.Body.String(), "amount must be positive") { + t.Fatalf("zero amount: %d %s", recorder.Code, recorder.Body) + } + if recorder := reserveDelivery(t, routes, map[string]any{"amount": "ten", "sessionId": "x"}); recorder.Code != http.StatusBadRequest { + t.Fatalf("non-numeric amount status = %d", recorder.Code) + } + + invalid := httptest.NewRequest(http.MethodPost, "/__402/session/deliveries", strings.NewReader("not-json")) + recorder := httptest.NewRecorder() + routes.Deliveries(recorder, invalid) + if recorder.Code != http.StatusBadRequest || !strings.Contains(recorder.Body.String(), "invalid request body") { + t.Fatalf("invalid body: %d %s", recorder.Code, recorder.Body) + } + + get := httptest.NewRequest(http.MethodGet, "/__402/session/deliveries", nil) + recorder = httptest.NewRecorder() + routes.Deliveries(recorder, get) + if recorder.Code != http.StatusMethodNotAllowed { + t.Fatalf("GET deliveries status = %d", recorder.Code) + } + + if recorder := commitDeliveryViaRoutes(t, routes, map[string]any{"voucher": map[string]any{}}); recorder.Code != http.StatusBadRequest || + !strings.Contains(recorder.Body.String(), "deliveryId required") { + t.Fatalf("missing deliveryId: %d %s", recorder.Code, recorder.Body) + } + if recorder := commitDeliveryViaRoutes(t, routes, map[string]any{"deliveryId": "d-1"}); recorder.Code != http.StatusBadRequest || + !strings.Contains(recorder.Body.String(), "voucher required") { + t.Fatalf("missing voucher: %d %s", recorder.Code, recorder.Body) + } + getCommit := httptest.NewRequest(http.MethodGet, "/__402/session/commit", nil) + recorder = httptest.NewRecorder() + routes.Commit(recorder, getCommit) + if recorder.Code != http.StatusMethodNotAllowed { + t.Fatalf("GET commit status = %d", recorder.Code) + } +} + +func TestSessionRoutesCommitReplayStatus(t *testing.T) { + session := newTestSession(t, nil) + signer, channelID := openTrustedChannel(t, session, 1_000) + routes := session.Routes() + + reserve := reserveDelivery(t, routes, map[string]any{"amount": "50", "sessionId": channelID}) + var directive intents.MeteringDirective + if err := json.Unmarshal(reserve.Body.Bytes(), &directive); err != nil { + t.Fatalf("decode directive: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 50, farFuture()) + + commitBody := map[string]any{"deliveryId": directive.DeliveryID, "voucher": voucher} + first := commitDeliveryViaRoutes(t, routes, commitBody) + if first.Code != http.StatusOK { + t.Fatalf("first commit: %d %s", first.Code, first.Body) + } + var firstReceipt intents.CommitReceipt + if err := json.Unmarshal(first.Body.Bytes(), &firstReceipt); err != nil { + t.Fatalf("decode receipt: %v", err) + } + if firstReceipt.Status != intents.CommitStatusCommitted || firstReceipt.Amount != "50" { + t.Fatalf("first receipt = %+v", firstReceipt) + } + + replay := commitDeliveryViaRoutes(t, routes, commitBody) + if replay.Code != http.StatusOK { + t.Fatalf("replay commit: %d %s", replay.Code, replay.Body) + } + var replayReceipt intents.CommitReceipt + if err := json.Unmarshal(replay.Body.Bytes(), &replayReceipt); err != nil { + t.Fatalf("decode replay receipt: %v", err) + } + if replayReceipt.Status != intents.CommitStatusReplayed { + t.Fatalf("replay status = %q", replayReceipt.Status) + } +} + +func TestSessionCommitReplayReVerifiesSignature(t *testing.T) { + session := newTestSession(t, nil) + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + forged := solana.SignatureFromBytes(bytes.Repeat([]byte{0xAA}, 64)).String() + + // Seed a channel whose committed delivery carries a forged signature; a + // replay must fail the signature re-verification. + if _, err := session.Core().Store().UpdateChannel(context.Background(), channelID, func(*ChannelState) (ChannelState, error) { + return ChannelState{ + ChannelID: channelID, + AuthorizedSigner: signer.Address(), + Deposit: 1_000, + Cumulative: 50, + NextDeliverySequence: 1, + CommittedDeliveries: []CommittedDelivery{ + {DeliveryID: "d-1", Amount: 50, Cumulative: 50, VoucherSignature: forged}, + }, + }, nil + }); err != nil { + t.Fatalf("seed store: %v", err) + } + + forgedVoucher := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: channelID, + Cumulative: "50", + ExpiresAt: farFuture(), + }, + Signature: forged, + } + if _, err := verifySessionAction(t, session, intents.NewCommitAction(intents.CommitPayload{ + DeliveryID: "d-1", Voucher: forgedVoucher, + })); err == nil || !strings.Contains(err.Error(), "signature") { + t.Fatalf("forged replay error = %v", err) + } +} + +func TestSessionRoutesShareStoreWithMethod(t *testing.T) { + session := newTestSession(t, nil) + _, channelID := openTrustedChannel(t, session, 1_000) + + recorder := reserveDelivery(t, session.Routes(), map[string]any{"amount": "100", "sessionId": channelID}) + if recorder.Code != http.StatusOK { + t.Fatalf("reserve status = %d body=%s", recorder.Code, recorder.Body) + } + var directive intents.MeteringDirective + if err := json.Unmarshal(recorder.Body.Bytes(), &directive); err != nil { + t.Fatalf("decode directive: %v", err) + } + if directive.DeliveryID != channelID+":1" { + t.Fatalf("deliveryId = %q", directive.DeliveryID) + } +} + +// ── open with a transaction ── + +func TestSessionOpenVerifiesAttachedTransaction(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.Operator = fixture.payee.String() + o.Network = "localnet" + }) + + receipt, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)) + if err != nil { + t.Fatalf("open with transaction: %v", err) + } + if receipt.Reference != fixture.signature { + t.Fatalf("reference = %q, want tx signature", receipt.Reference) + } + state := mustGetChannel(t, session, fixture.channel.String()) + if state == nil || state.Deposit != openFixtureDeposit { + t.Fatalf("state = %+v", state) + } + // Push channel opens record the channel payer as the operator. + if state.Operator == nil || *state.Operator != fixture.payer.PublicKey().String() { + t.Fatalf("operator = %v", state.Operator) + } +} + +func TestSessionOpenRejectsTransactionForWrongRecipient(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + session := newTestSession(t, nil) // recipient differs from the fixture payee + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err == nil || + !strings.Contains(err.Error(), "payee") { + t.Fatalf("wrong recipient error = %v", err) + } +} + +func TestSessionServerSubmitterBroadcastsOnceAndReplaysWithoutRebroadcast(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.OpenTxSubmitter = OpenTxSubmitterServer + o.RPC = fake + }) + + receipt, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)) + if err != nil { + t.Fatalf("server-submitted open: %v", err) + } + if len(fake.Sent) != 1 { + t.Fatalf("broadcasts = %d, want 1", len(fake.Sent)) + } + if receipt.Reference != fixture.signature { + t.Fatalf("reference = %q, want broadcast signature", receipt.Reference) + } + + // Idempotent replay of the persisted open must not rebroadcast. + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err != nil { + t.Fatalf("open replay: %v", err) + } + if len(fake.Sent) != 1 { + t.Fatalf("replay rebroadcast the open: %d sends", len(fake.Sent)) + } +} + +func TestSessionServerSubmitterRequiresRPC(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.OpenTxSubmitter = OpenTxSubmitterServer + }) + if _, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)); err == nil || + !strings.Contains(err.Error(), "requires an rpc client") { + t.Fatalf("missing rpc error = %v", err) + } +} + +func TestSessionServerSubmitterCompletesFeePayerSignature(t *testing.T) { + // Client partial-signs as the channel payer and leaves the fee-payer + // (operator) slot for the server, with the pending placeholder as the + // payload signature: the createServerOpenedPaymentChannelSessionOpener + // flow. + operator := testutil.NewPrivateKey() + fixture := buildServerCompletedOpenFixture(t, operator) + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { + o.Recipient = fixture.payee.String() + o.Operator = operator.PublicKey().String() + o.OpenTxSubmitter = OpenTxSubmitterServer + o.PaymentChannelPayerSigner = operator + o.RPC = fake + }) + + receipt, err := verifySessionAction(t, session, intents.NewOpenAction(fixture.payload)) + if err != nil { + t.Fatalf("server-completed open: %v", err) + } + if len(fake.Sent) != 1 { + t.Fatalf("broadcasts = %d, want 1", len(fake.Sent)) + } + if fake.Sent[0].Signatures[0].IsZero() { + t.Fatal("fee-payer signature was not completed before broadcast") + } + if receipt.Reference != fake.Sent[0].Signatures[0].String() { + t.Fatalf("reference = %q, want broadcast signature", receipt.Reference) + } +} + +// buildServerCompletedOpenFixture builds an open transaction whose fee payer +// is the operator (unsigned) while the channel payer has partial-signed, +// paired with a placeholder payload signature. +func buildServerCompletedOpenFixture(t *testing.T, operator solana.PrivateKey) openTxFixture { + t.Helper() + fixture := buildOpenTxFixture(t, false) + // Rebuild the open transaction with the operator as fee payer; only the + // channel payer partial-signs, leaving the fee-payer slot zeroed. + ix, err := paymentchannels.BuildOpenInstruction(paymentchannels.OpenChannelParams{ + Payer: fixture.payer.PublicKey(), + Payee: fixture.payee, + Mint: fixture.mint, + AuthorizedSigner: fixture.authorized, + Salt: openFixtureSalt, + Deposit: openFixtureDeposit, + GracePeriod: openFixtureGrace, + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + blockhash := solana.MustHashFromBase58("EkSnNWid2cvwEVnVx9aBqawnmiCNiDgp3gUdkDPTKN1N") + tx, err := solana.NewTransaction([]solana.Instruction{ix}, blockhash, solana.TransactionPayer(operator.PublicKey())) + if err != nil { + t.Fatalf("NewTransaction: %v", err) + } + if err := solanatx.SignTransaction(tx, fixture.payer); err != nil { + t.Fatalf("partial-sign open tx: %v", err) + } + encoded, err := solanatx.EncodeTransactionBase64(tx) + if err != nil { + t.Fatalf("EncodeTransactionBase64: %v", err) + } + payload := fixture.payload + payload.Signature = strings.Repeat("1", 64) + payload.Transaction = &encoded + fixture.payload = payload + fixture.expected.Recipient = fixture.payee.String() + return fixture +} + +// ── middleware ── + +func TestSessionMiddlewareChallengeAndVerifyFlow(t *testing.T) { + session := newTestSession(t, nil) + var receiptInContext *core.Receipt + handler := SessionMiddleware(session, func(*http.Request) (SessionChallengeOptions, error) { + return SessionChallengeOptions{Cap: "1000000", Description: "Stream"}, nil + })(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if receipt, ok := ReceiptFromContext(r.Context()); ok { + receiptInContext = &receipt + } + w.WriteHeader(http.StatusOK) + })) + server := httptest.NewServer(handler) + defer server.Close() + + // No credential: 402 with a session challenge. + response, err := http.Get(server.URL) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer response.Body.Close() + if response.StatusCode != http.StatusPaymentRequired { + t.Fatalf("status = %d, want 402", response.StatusCode) + } + header := response.Header.Get(core.WWWAuthenticateHeader) + challenge, err := core.ParseWWWAuthenticate(header) + if err != nil { + t.Fatalf("parse challenge: %v", err) + } + if !challenge.Intent.IsSession() { + t.Fatalf("intent = %q", challenge.Intent) + } + var request intents.SessionRequest + if err := challenge.Request.Decode(&request); err != nil { + t.Fatalf("decode request: %v", err) + } + if request.Cap != "1000000" { + t.Fatalf("cap = %q", request.Cap) + } + + // Open credential: passes through with a receipt. + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + credential, err := core.NewPaymentCredential(challenge.ToEcho(), intents.NewOpenAction( + intents.OpenPayloadPush(channelID, "1000", signer.Address(), "open-sig"))) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + authorization, err := core.FormatAuthorization(credential) + if err != nil { + t.Fatalf("FormatAuthorization: %v", err) + } + authedRequest, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + authedRequest.Header.Set(core.AuthorizationHeader, authorization) + authedResponse, err := http.DefaultClient.Do(authedRequest) + if err != nil { + t.Fatalf("authed GET: %v", err) + } + defer authedResponse.Body.Close() + if authedResponse.StatusCode != http.StatusOK { + t.Fatalf("authed status = %d", authedResponse.StatusCode) + } + receiptHeader := authedResponse.Header.Get(core.PaymentReceiptHeader) + if receiptHeader == "" { + t.Fatal("missing Payment-Receipt header") + } + receipt, err := core.ParseReceipt(receiptHeader) + if err != nil { + t.Fatalf("ParseReceipt: %v", err) + } + if receipt.Reference != "open-sig" { + t.Fatalf("receipt reference = %q", receipt.Reference) + } + if receiptInContext == nil || receiptInContext.Reference != "open-sig" { + t.Fatalf("receipt in context = %+v", receiptInContext) + } + if mustGetChannel(t, session, channelID) == nil { + t.Fatal("middleware did not persist the opened channel") + } + + // Garbage credential: 402 with a problem+json body. + badRequest, err := http.NewRequest(http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } + badRequest.Header.Set(core.AuthorizationHeader, "Payment not-base64url") + badResponse, err := http.DefaultClient.Do(badRequest) + if err != nil { + t.Fatalf("bad GET: %v", err) + } + defer badResponse.Body.Close() + if badResponse.StatusCode != http.StatusPaymentRequired { + t.Fatalf("bad credential status = %d", badResponse.StatusCode) + } + if contentType := badResponse.Header.Get("Content-Type"); contentType != "application/problem+json" { + t.Fatalf("bad credential content type = %q", contentType) + } +} + +func TestSessionMiddlewareSkipsBlockhashPrefetchOnVerifyPath(t *testing.T) { + fake := &countingBlockhashRPC{FakeRPC: testutil.NewFakeRPC()} + session := newTestSession(t, func(o *SessionOptions) { o.RPC = fake }) + handler := SessionMiddleware(session, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Verify path: a valid credential never triggers the prefetch. + challenge, err := session.Challenge(context.Background(), SessionChallengeOptions{}) + if err != nil { + t.Fatalf("Challenge: %v", err) + } + calls := fake.calls() + signer := newTestVoucherSigner(t) + credential, err := core.NewPaymentCredential(challenge.ToEcho(), intents.NewOpenAction( + intents.OpenPayloadPush(solana.NewWallet().PublicKey().String(), "1000", signer.Address(), confirmedSignature(0x88)))) + if err != nil { + t.Fatalf("NewPaymentCredential: %v", err) + } + authorization, err := core.FormatAuthorization(credential) + if err != nil { + t.Fatalf("FormatAuthorization: %v", err) + } + request := httptest.NewRequest(http.MethodGet, "/", nil) + request.Header.Set(core.AuthorizationHeader, authorization) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("verify path status = %d", recorder.Code) + } + if fake.calls() != calls { + t.Fatalf("verify path fetched a blockhash: %d -> %d", calls, fake.calls()) + } + + // Challenge path fetches exactly once. + recorder = httptest.NewRecorder() + handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil)) + if recorder.Code != http.StatusPaymentRequired { + t.Fatalf("challenge path status = %d", recorder.Code) + } + if fake.calls() != calls+1 { + t.Fatalf("challenge path blockhash calls = %d, want %d", fake.calls(), calls+1) + } +} + +// countingBlockhashRPC counts GetLatestBlockhash calls on top of FakeRPC. +// The counter is atomic because the idle-close watchdog reads blockhashes +// from its own goroutine. +type countingBlockhashRPC struct { + // FakeRPC handles every RPC call; GetLatestBlockhash is counted first. + *testutil.FakeRPC + + // blockhashCalls counts GetLatestBlockhash invocations; atomic because + // the idle-close watchdog fetches blockhashes from its own goroutine. + blockhashCalls atomic.Int64 +} + +// calls returns the GetLatestBlockhash call count. +func (c *countingBlockhashRPC) calls() int64 { return c.blockhashCalls.Load() } + +func (c *countingBlockhashRPC) GetLatestBlockhash(ctx context.Context, commitment rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { + c.blockhashCalls.Add(1) + return c.FakeRPC.GetLatestBlockhash(ctx, commitment) +} + +// ── idle-close lifecycle ── + +func TestSessionIdleCloseSettlesOnChain(t *testing.T) { + fake := testutil.NewFakeRPC() + merchant := testutil.NewPrivateKey() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.Signer = merchant + o.CloseDelay = 25 * time.Millisecond + }) + signer, channelID := openTrustedChannel(t, session, 1_000) + if _, err := submitMethodVoucher(t, session, signer, channelID, 300); err != nil { + t.Fatalf("voucher: %v", err) + } + + deadline := time.Now().Add(3 * time.Second) + for { + state := mustGetChannel(t, session, channelID) + if state != nil && state.Finalized && state.SettledSignature != nil { + break + } + if time.Now().After(deadline) { + t.Fatalf("idle close never settled; state = %+v", state) + } + time.Sleep(10 * time.Millisecond) + } + if len(fake.Sent) != 1 { + t.Fatalf("settlement broadcasts = %d, want 1", len(fake.Sent)) + } +} + +func TestSessionIdleCloseWithoutSignerIsInert(t *testing.T) { + fake := testutil.NewFakeRPC() + session := newTestSession(t, func(o *SessionOptions) { + o.RPC = fake + o.CloseDelay = 10 * time.Millisecond + }) + _, channelID := openTrustedChannel(t, session, 1_000) + + time.Sleep(80 * time.Millisecond) + state := mustGetChannel(t, session, channelID) + if state.Finalized || len(fake.Sent) != 0 { + t.Fatalf("idle close ran without a signer: state=%+v sends=%d", state, len(fake.Sent)) + } +} diff --git a/go/protocols/mpp/server/session_onchain.go b/go/protocols/mpp/server/session_onchain.go new file mode 100644 index 000000000..2a79a4cc2 --- /dev/null +++ b/go/protocols/mpp/server/session_onchain.go @@ -0,0 +1,490 @@ +package server + +// On-chain verification and settlement for the session intent. +// +// Trust model: when no verifier is installed on SessionConfig (the seam is +// nil), transaction signatures and deposit amounts are trusted as +// provided. NewOpenTxVerifier always +// validates an attached open transaction structurally (decode, bind the +// payload signature, check the open instruction against the challenge, +// re-derive the channel PDA); confirming that the transaction actually landed +// additionally requires an RPC client. NewTopUpTxVerifier is purely RPC-backed +// (the top-up payload carries only a signature, no transaction), so without an +// RPC client the top-up seam stays nil and the new deposit is trusted as +// provided. + +import ( + "context" + "encoding/binary" + "fmt" + "strings" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// openInstructionDiscriminator is the payment-channel open instruction +// discriminator (single-byte Anchor-numeric form, not the 8-byte sha256 +// convention). Matches OPEN_DISCRIMINATOR in the vendored Codama clients. +const openInstructionDiscriminator = 1 + +// VerifyOpenTxExpected carries the challenge-side values a client-submitted +// open transaction is validated against. +type VerifyOpenTxExpected struct { + // AuthorizedSigner is the voucher signing key claimed by the open payload + // (base58); the transaction's authorizedSigner account must match it. + AuthorizedSigner string + + // Currency is the challenge currency (symbol or mint address). + Currency string + + // MaxCap is the maximum deposit the server accepts (base units). + MaxCap uint64 + + // Mint optionally overrides the SPL mint; empty resolves it from + // Currency/Network. + Mint string + + // Network is the Solana network used for mint resolution. + Network string + + // ProgramID optionally overrides the payment-channels program id; nil + // defaults to the canonical program. + ProgramID *solana.PublicKey + + // Recipient is the primary payment recipient (challenge recipient, + // base58); the transaction's payee account must match it. + Recipient string +} + +// VerifyOpenTxResult carries the channel facts extracted from a verified open +// transaction. +type VerifyOpenTxResult struct { + // ChannelID is the channel PDA derived from the open instruction (base58). + ChannelID string + + // Deposit locked by the open, in base units. + Deposit uint64 + + // GracePeriod is the close grace period in seconds. + GracePeriod uint32 + + // Salt is the channel-derivation salt. + Salt uint64 +} + +// VerifyOpenTx decodes and validates a client-submitted payment-channel open +// transaction against the session challenge. +// +// Both legacy and v0 transaction encodings are accepted (clients across the +// language SDKs emit either). The embedded open +// instruction must target the configured payment-channels program, the payee +// must equal the challenge recipient, the mint must match the challenge +// currency/network, the authorizedSigner must match the payload, the deposit +// must be positive and within the cap, and the channel account must equal +// the PDA re-derived from the instruction's own seeds. +// +// When the payload carries a non-placeholder signature, it must equal the +// transaction's own fee-payer signature: a client must not be able to pair +// an unrelated (but confirmed) signature with different transaction bytes. +// If rpcClient is non-nil, that bound signature is additionally confirmed +// on-chain; nil skips the liveness check (structural validation only). +func VerifyOpenTx(ctx context.Context, expected VerifyOpenTxExpected, payload *intents.OpenPayload, rpcClient solanatx.RPCClient) (VerifyOpenTxResult, error) { + if payload.Transaction == nil || *payload.Transaction == "" { + return VerifyOpenTxResult{}, fmt.Errorf("openPayload.transaction is required for push-mode open verification") + } + + tx, err := solanatx.DecodeTransactionBase64(*payload.Transaction) + if err != nil { + return VerifyOpenTxResult{}, fmt.Errorf("decode open transaction: %w", err) + } + + // Bind the claimed signature to this transaction before trusting it. + boundSignature := payload.Signature != "" && !isPlaceholderSignature(payload.Signature) + if boundSignature { + if len(tx.Signatures) == 0 || tx.Signatures[0].IsZero() { + return VerifyOpenTxResult{}, fmt.Errorf("openPayload.signature is set but the transaction carries no fee-payer signature") + } + if txSignature := tx.Signatures[0].String(); txSignature != payload.Signature { + return VerifyOpenTxResult{}, fmt.Errorf("openPayload.signature %s != transaction signature %s", payload.Signature, txSignature) + } + } + + programID := paymentchannels.ProgramPubkey() + if expected.ProgramID != nil { + programID = *expected.ProgramID + } + expectedMint := expected.Mint + if expectedMint == "" { + expectedMint = paycore.ResolveMint(expected.Currency, expected.Network) + } + if expectedMint == "" { + return VerifyOpenTxResult{}, fmt.Errorf("could not resolve mint from currency %q", expected.Currency) + } + + accountKeys := tx.Message.AccountKeys + accountAt := func(indices []uint16, slot int, label string) (solana.PublicKey, error) { + if slot >= len(indices) || int(indices[slot]) >= len(accountKeys) { + return solana.PublicKey{}, fmt.Errorf("open instruction is missing the %s account at slot %d", label, slot) + } + return accountKeys[indices[slot]], nil + } + + var openIx *solana.CompiledInstruction + for i := range tx.Message.Instructions { + ix := &tx.Message.Instructions[i] + if int(ix.ProgramIDIndex) >= len(accountKeys) || !accountKeys[ix.ProgramIDIndex].Equals(programID) { + continue + } + if len(ix.Data) < 1 || ix.Data[0] != openInstructionDiscriminator { + continue + } + openIx = ix + break + } + if openIx == nil { + return VerifyOpenTxResult{}, fmt.Errorf("no payment-channels open instruction found") + } + + // Open instruction account layout (matches the generated client): + // 0 payer, 1 payee, 2 mint, 3 authorizedSigner, 4 channel, + // 5 payerTokenAccount, 6 channelTokenAccount, 7 tokenProgram, ... + if len(openIx.Accounts) < 7 { + return VerifyOpenTxResult{}, fmt.Errorf("open instruction has too few accounts (%d)", len(openIx.Accounts)) + } + payer, err := accountAt(openIx.Accounts, 0, "payer") + if err != nil { + return VerifyOpenTxResult{}, err + } + payee, err := accountAt(openIx.Accounts, 1, "payee") + if err != nil { + return VerifyOpenTxResult{}, err + } + mint, err := accountAt(openIx.Accounts, 2, "mint") + if err != nil { + return VerifyOpenTxResult{}, err + } + authorizedSigner, err := accountAt(openIx.Accounts, 3, "authorizedSigner") + if err != nil { + return VerifyOpenTxResult{}, err + } + channel, err := accountAt(openIx.Accounts, 4, "channel") + if err != nil { + return VerifyOpenTxResult{}, err + } + + if payee.String() != expected.Recipient { + return VerifyOpenTxResult{}, fmt.Errorf("open payee %s != expected recipient %s", payee, expected.Recipient) + } + if mint.String() != expectedMint { + return VerifyOpenTxResult{}, fmt.Errorf("open mint %s != expected mint %s", mint, expectedMint) + } + if authorizedSigner.String() != expected.AuthorizedSigner { + return VerifyOpenTxResult{}, fmt.Errorf("open authorizedSigner %s != expected %s", authorizedSigner, expected.AuthorizedSigner) + } + + // Instruction data: [discriminator u8][salt u64][deposit u64][grace u32][recipients]. + if len(openIx.Data) < 1+8+8+4 { + return VerifyOpenTxResult{}, fmt.Errorf("open instruction data too short (%d bytes)", len(openIx.Data)) + } + salt := binary.LittleEndian.Uint64(openIx.Data[1:9]) + deposit := binary.LittleEndian.Uint64(openIx.Data[9:17]) + gracePeriod := binary.LittleEndian.Uint32(openIx.Data[17:21]) + + if deposit == 0 { + return VerifyOpenTxResult{}, fmt.Errorf("open deposit must be greater than zero") + } + if deposit > expected.MaxCap { + return VerifyOpenTxResult{}, fmt.Errorf("open deposit %d exceeds max cap %d", deposit, expected.MaxCap) + } + + // Re-derive the channel PDA from the instruction's own seeds. + derivedChannel, _, err := paymentchannels.FindChannelPDAForProgram(payer, payee, mint, authorizedSigner, salt, programID) + if err != nil { + return VerifyOpenTxResult{}, err + } + if !derivedChannel.Equals(channel) { + return VerifyOpenTxResult{}, fmt.Errorf("open channel PDA %s != derived %s", channel, derivedChannel) + } + if payload.ChannelID != nil && *payload.ChannelID != channel.String() { + return VerifyOpenTxResult{}, fmt.Errorf("openPayload.channelId %s != transaction channel %s", *payload.ChannelID, channel) + } + + // Optional liveness check: only when the caller provides an RPC client + // and the client already populated the transaction signature. + if rpcClient != nil && boundSignature { + if err := confirmTransactionSignature(ctx, rpcClient, payload.Signature, "open"); err != nil { + return VerifyOpenTxResult{}, err + } + } + + return VerifyOpenTxResult{ + ChannelID: channel.String(), + Deposit: deposit, + GracePeriod: gracePeriod, + Salt: salt, + }, nil +} + +// NewOpenTxVerifier returns the on-chain open verifier to install on +// SessionConfig.VerifyOpenTx. When the open payload carries a transaction, +// it is structurally validated against the challenge via VerifyOpenTx (with +// an on-chain liveness check when rpcClient is non-nil). When the payload +// carries only a confirmation signature, rpcClient is required and the +// signature is confirmed on-chain via getSignatureStatuses. +func NewOpenTxVerifier(config SessionConfig, rpcClient solanatx.RPCClient) SessionTxVerifier[intents.OpenPayload] { + return func(ctx context.Context, payload *intents.OpenPayload) error { + if payload.Transaction != nil && *payload.Transaction != "" { + expected := VerifyOpenTxExpected{ + AuthorizedSigner: payload.AuthorizedSigner, + Currency: config.Currency, + MaxCap: config.MaxCap, + Network: config.Network, + ProgramID: config.ProgramID, + Recipient: config.Recipient, + } + _, err := VerifyOpenTx(ctx, expected, payload, rpcClient) + return err + } + if rpcClient == nil { + return fmt.Errorf("open verification requires a transaction or an RPC client") + } + return confirmTransactionSignature(ctx, rpcClient, payload.Signature, "open") + } +} + +// NewTopUpTxVerifier returns the on-chain top-up verifier to install on +// SessionConfig.VerifyTopUpTx: it confirms the top-up transaction signature +// on-chain via getSignatureStatuses. +// A nil rpcClient returns nil so the seam stays unset, and the new deposit is +// trusted as provided; suitable only for unit tests or deployments that +// verify transactions out of band. +func NewTopUpTxVerifier(rpcClient solanatx.RPCClient) SessionTxVerifier[intents.TopUpPayload] { + if rpcClient == nil { + return nil + } + return func(ctx context.Context, payload *intents.TopUpPayload) error { + return confirmTransactionSignature(ctx, rpcClient, payload.Signature, "top-up") + } +} + +// SettlementInstructions builds the on-chain settlement sequence for a +// channel: settle_and_finalize over the stored watermark (preceded by the +// Ed25519 precompile instruction when a voucher was accepted) plus the +// distribute instruction, to be bundled into one merchant-signed +// transaction. Hosts drive this after ProcessClose records the close-pending +// state, then call MarkFinalized once the transaction confirms. +// +// The mint and token program are resolved from the configured currency and +// network (Token-2022 for PYUSD/USDG/CASH). +func (s *SessionServer) SettlementInstructions(ctx context.Context, channelID string, merchant solana.PublicKey) ([]solana.Instruction, error) { + state, err := s.store.GetChannel(ctx, channelID) + if err != nil { + return nil, err + } + if state == nil { + return nil, fmt.Errorf("channel %s not found", channelID) + } + return s.settlementInstructionsForState(*state, channelID, merchant, "") +} + +// settlementInstructionsForState derives the settlement instruction sequence +// for an already-read channel snapshot. payerFallback, when non-empty, is +// used as the distribute payer when the channel never recorded an operator; +// empty keeps the strict unknown-payer error. +func (s *SessionServer) settlementInstructionsForState(state ChannelState, channelID string, merchant solana.PublicKey, payerFallback string) ([]solana.Instruction, error) { + channel, err := solana.PublicKeyFromBase58(channelID) + if err != nil { + return nil, fmt.Errorf("invalid channel id %q: %w", channelID, err) + } + programID := paymentchannels.ProgramPubkey() + if s.config.ProgramID != nil { + programID = *s.config.ProgramID + } + + var voucherSignature *[64]byte + var authorizedSigner solana.PublicKey + expiresAt := int64(0) + if state.HighestVoucherSignature != nil { + signature, err := solana.SignatureFromBase58(*state.HighestVoucherSignature) + if err != nil { + return nil, fmt.Errorf("invalid stored voucher signature: %w", err) + } + signatureBytes := [64]byte(signature) + voucherSignature = &signatureBytes + authorizedSigner, err = solana.PublicKeyFromBase58(state.AuthorizedSigner) + if err != nil { + return nil, fmt.Errorf("invalid stored authorized signer %q: %w", state.AuthorizedSigner, err) + } + if state.HighestVoucherExpiresAt == nil { + return nil, fmt.Errorf("channel %s has a voucher signature but no voucher expiry", channelID) + } + expiresAt = *state.HighestVoucherExpiresAt + } + + instructions, err := paymentchannels.BuildSettleAndFinalizeInstructions(paymentchannels.SettleAndFinalizeParams{ + Merchant: merchant, + Channel: channel, + AuthorizedSigner: authorizedSigner, + Signature: voucherSignature, + CumulativeAmount: state.Cumulative, + ExpiresAt: expiresAt, + ProgramID: programID, + }) + if err != nil { + return nil, err + } + + mintAddress := paycore.ResolveMint(s.config.Currency, s.config.Network) + if mintAddress == "" { + return nil, fmt.Errorf("session settlement requires an SPL token, got currency %q", s.config.Currency) + } + mint, err := solana.PublicKeyFromBase58(mintAddress) + if err != nil { + return nil, fmt.Errorf("invalid mint %q: %w", mintAddress, err) + } + tokenProgram, err := solana.PublicKeyFromBase58(paycore.DefaultTokenProgramForCurrency(s.config.Currency, s.config.Network)) + if err != nil { + return nil, fmt.Errorf("invalid token program: %w", err) + } + payerAddress := payerFallback + if state.Operator != nil { + payerAddress = *state.Operator + } + if payerAddress == "" { + return nil, fmt.Errorf("channel %s payer is unknown; cannot derive the refund token account", channelID) + } + payer, err := solana.PublicKeyFromBase58(payerAddress) + if err != nil { + return nil, fmt.Errorf("invalid channel payer %q: %w", payerAddress, err) + } + payee, err := solana.PublicKeyFromBase58(s.config.Recipient) + if err != nil { + return nil, fmt.Errorf("invalid recipient %q: %w", s.config.Recipient, err) + } + + recipients := make([]paymentchannels.Distribution, 0, len(s.config.Splits)) + for _, split := range s.config.Splits { + recipients = append(recipients, paymentchannels.Distribution{ + Recipient: split.Recipient, + Bps: split.BPS, + }) + } + + distribute, err := paymentchannels.BuildDistributeInstruction(paymentchannels.DistributeParams{ + Channel: channel, + Payer: payer, + Payee: payee, + Treasury: paymentchannels.TreasuryOwner(), + Mint: mint, + Recipients: recipients, + TokenProgram: tokenProgram, + ProgramID: programID, + }) + if err != nil { + return nil, err + } + return append(instructions, distribute), nil +} + +// SubmitOpenTxResult carries the verified channel facts plus the broadcast +// signature of a server-submitted open. +type SubmitOpenTxResult struct { + // VerifyOpenTxResult carries the channel facts (channel PDA, deposit, + // grace period, salt) extracted during pre-broadcast validation. + VerifyOpenTxResult + + // Signature of the broadcast open transaction (base58). + Signature string +} + +// SubmitOpenTx validates a client-built payment-channel open transaction, +// completes the fee-payer signature when payerSigner is required by the +// transaction, broadcasts it, and waits for at least confirmed commitment. +// Callers must not persist channel state for a transaction that never +// landed. Used when the session is configured with the server open-tx +// submitter. +func SubmitOpenTx(ctx context.Context, expected VerifyOpenTxExpected, payload *intents.OpenPayload, payerSigner solanatx.Signer, rpcClient solanatx.RPCClient) (SubmitOpenTxResult, error) { + if rpcClient == nil { + return SubmitOpenTxResult{}, fmt.Errorf("SubmitOpenTx requires an RPC client") + } + // Structural validation only: the transaction has not been broadcast yet, + // so there is no on-chain liveness to check. + verified, err := VerifyOpenTx(ctx, expected, payload, nil) + if err != nil { + return SubmitOpenTxResult{}, err + } + tx, err := solanatx.DecodeTransactionBase64(*payload.Transaction) + if err != nil { + return SubmitOpenTxResult{}, fmt.Errorf("decode open transaction: %w", err) + } + // Complete the fee-payer signature when the client left the slot for the + // server (the createServerOpenedPaymentChannelSessionOpener flow builds + // the open with the operator as fee payer and only partial-signs as the + // channel payer). + if payerSigner != nil && signerIsRequired(tx, payerSigner.PublicKey()) { + if err := solanatx.SignTransaction(tx, payerSigner); err != nil { + return SubmitOpenTxResult{}, fmt.Errorf("co-sign open transaction: %w", err) + } + } + if len(tx.Signatures) == 0 || tx.Signatures[0].IsZero() { + return SubmitOpenTxResult{}, fmt.Errorf("open transaction is missing the fee-payer signature") + } + signature, err := solanatx.SendTransaction(ctx, rpcClient, tx) + if err != nil { + return SubmitOpenTxResult{}, fmt.Errorf("broadcast open transaction: %w", err) + } + if err := solanatx.WaitForConfirmation(ctx, rpcClient, signature); err != nil { + return SubmitOpenTxResult{}, fmt.Errorf("confirm open transaction: %w", err) + } + return SubmitOpenTxResult{VerifyOpenTxResult: verified, Signature: signature.String()}, nil +} + +// signerIsRequired reports whether key is one of the transaction's required +// signers. +func signerIsRequired(tx *solana.Transaction, key solana.PublicKey) bool { + for _, signer := range tx.Message.Signers() { + if signer.Equals(key) { + return true + } + } + return false +} + +// confirmTransactionSignature checks once via getSignatureStatuses that the +// base58 signature names a known, successful transaction. label names the +// transaction in error messages ("open", "top-up"). +func confirmTransactionSignature(ctx context.Context, rpcClient solanatx.RPCClient, signature, label string) error { + parsed, err := solana.SignatureFromBase58(signature) + if err != nil { + return fmt.Errorf("invalid %s tx signature %q: %w", label, signature, err) + } + out, err := rpcClient.GetSignatureStatuses(ctx, true, parsed) + if err != nil { + return fmt.Errorf("RPC error verifying %s tx: %w", label, err) + } + if out == nil || len(out.Value) == 0 || out.Value[0] == nil { + return fmt.Errorf("%s tx %q not found; not yet confirmed or does not exist", label, signature) + } + if out.Value[0].Err != nil { + return fmt.Errorf("%s tx %q failed on-chain: %v", label, signature, out.Value[0].Err) + } + return nil +} + +// isPlaceholderSignature reports whether the signature is the pending +// placeholder produced by the server-completed open flow (an empty string or +// a run of 40+ '1' characters, the base58 encoding of the all-ones marker). +func isPlaceholderSignature(signature string) bool { + if signature == "" { + return true + } + if len(signature) < 40 { + return false + } + return strings.Count(signature, "1") == len(signature) +} diff --git a/go/protocols/mpp/server/session_onchain_test.go b/go/protocols/mpp/server/session_onchain_test.go new file mode 100644 index 000000000..69859309b --- /dev/null +++ b/go/protocols/mpp/server/session_onchain_test.go @@ -0,0 +1,716 @@ +package server + +// Coverage of VerifyOpenTx and the settle-and-distribute composition: +// legacy and v0 transaction decoding, payload-signature binding, challenge +// validation failure modes, RPC-backed confirmation, and the settlement +// instruction sequence derived from stored channel state. + +import ( + "bytes" + "context" + "encoding/binary" + "strconv" + "strings" + "testing" + + solana "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/solana-foundation/pay-kit/go/internal/testutil" + "github.com/solana-foundation/pay-kit/go/paycore" + "github.com/solana-foundation/pay-kit/go/paycore/paymentchannels" + "github.com/solana-foundation/pay-kit/go/paycore/solanatx" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// openTxFixture bundles a freshly built and signed payment-channel open +// transaction with the payload and challenge expectations that accept it. +type openTxFixture struct { + payer solana.PrivateKey // channel payer keypair; fee payer and sole signer of the open tx + payee solana.PublicKey // channel recipient the challenge expects + authorized solana.PublicKey // voucher-signing pubkey baked into the channel + mint solana.PublicKey // SPL mint the channel settles in (mainnet USDC) + channel solana.PublicKey // channel PDA derived from payer/payee/mint/authorized + salt + signature string // fee-payer signature of the open tx (base58) + payload intents.OpenPayload // open payload carrying the base64-encoded wire tx + expected VerifyOpenTxExpected // challenge-side expectations that accept this fixture +} + +const ( + openFixtureSalt = uint64(7) + openFixtureDeposit = uint64(1_000_000) + openFixtureGrace = uint32(900) +) + +// buildOpenTxFixture builds a payer-signed open transaction in the requested +// encoding (clients across the language SDKs emit either). +func buildOpenTxFixture(t *testing.T, v0 bool) openTxFixture { + t.Helper() + + payer := testutil.NewPrivateKey() + payee := testutil.NewPrivateKey().PublicKey() + authorized := testutil.NewPrivateKey().PublicKey() + mint := solana.MustPublicKeyFromBase58(paycore.USDCMainnetMint) + + channel, _, err := paymentchannels.FindChannelPDA(payer.PublicKey(), payee, mint, authorized, openFixtureSalt) + if err != nil { + t.Fatalf("FindChannelPDA: %v", err) + } + ix, err := paymentchannels.BuildOpenInstruction(paymentchannels.OpenChannelParams{ + Payer: payer.PublicKey(), + Payee: payee, + Mint: mint, + AuthorizedSigner: authorized, + Salt: openFixtureSalt, + Deposit: openFixtureDeposit, + GracePeriod: openFixtureGrace, + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + + fixture := openTxFixture{ + payer: payer, + payee: payee, + authorized: authorized, + mint: mint, + channel: channel, + } + fixture.signature, fixture.payload = signAndAttachOpenTx(t, &fixture, ix, v0) + fixture.expected = VerifyOpenTxExpected{ + AuthorizedSigner: authorized.String(), + Currency: "USDC", + MaxCap: 5_000_000, + Network: "localnet", + Recipient: payee.String(), + } + return fixture +} + +// signAndAttachOpenTx assembles, signs, and base64-encodes the open +// transaction for ix, returning the fee-payer signature and the open payload +// carrying the wire transaction. +func signAndAttachOpenTx(t *testing.T, fixture *openTxFixture, ix solana.Instruction, v0 bool) (string, intents.OpenPayload) { + t.Helper() + blockhash := solana.MustHashFromBase58("EkSnNWid2cvwEVnVx9aBqawnmiCNiDgp3gUdkDPTKN1N") + tx, err := solana.NewTransaction([]solana.Instruction{ix}, blockhash, solana.TransactionPayer(fixture.payer.PublicKey())) + if err != nil { + t.Fatalf("NewTransaction: %v", err) + } + if v0 { + tx.Message.SetVersion(solana.MessageVersionV0) + } + if _, err := tx.Sign(func(key solana.PublicKey) *solana.PrivateKey { + if key.Equals(fixture.payer.PublicKey()) { + payerKey := fixture.payer + return &payerKey + } + return nil + }); err != nil { + t.Fatalf("sign open transaction: %v", err) + } + encoded, err := solanatx.EncodeTransactionBase64(tx) + if err != nil { + t.Fatalf("EncodeTransactionBase64: %v", err) + } + signature := tx.Signatures[0].String() + payload := intents.OpenPayloadPaymentChannel( + fixture.channel.String(), + strconv.FormatUint(openFixtureDeposit, 10), + fixture.payer.PublicKey().String(), + fixture.payee.String(), + fixture.mint.String(), + openFixtureSalt, + openFixtureGrace, + fixture.authorized.String(), + signature, + ).WithTransaction(encoded) + return signature, payload +} + +// ── VerifyOpenTx: accepted encodings ── + +func TestVerifyOpenTxAcceptsLegacyEncoding(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + result, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil) + if err != nil { + t.Fatalf("VerifyOpenTx: %v", err) + } + if result.ChannelID != fixture.channel.String() { + t.Fatalf("channelId = %s, want %s", result.ChannelID, fixture.channel) + } + if result.Deposit != openFixtureDeposit || result.GracePeriod != openFixtureGrace || result.Salt != openFixtureSalt { + t.Fatalf("result = %+v, want deposit/grace/salt %d/%d/%d", result, openFixtureDeposit, openFixtureGrace, openFixtureSalt) + } +} + +func TestVerifyOpenTxAcceptsV0Encoding(t *testing.T) { + fixture := buildOpenTxFixture(t, true) + // Confirm the fixture really emits the v0 wire prefix before asserting it + // verifies: the message must round-trip through the versioned decoder. + decoded, err := solanatx.DecodeTransactionBase64(*fixture.payload.Transaction) + if err != nil { + t.Fatalf("decode v0 fixture: %v", err) + } + if decoded.Message.GetVersion() != solana.MessageVersionV0 { + t.Fatalf("fixture message version = %v, want v0", decoded.Message.GetVersion()) + } + + result, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil) + if err != nil { + t.Fatalf("VerifyOpenTx: %v", err) + } + if result.ChannelID != fixture.channel.String() { + t.Fatalf("channelId = %s, want %s", result.ChannelID, fixture.channel) + } +} + +func TestVerifyOpenTxHonorsExplicitMintAndProgramOverrides(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.expected.Currency = "not-a-currency" + fixture.expected.Mint = fixture.mint.String() + programID := paymentchannels.ProgramPubkey() + fixture.expected.ProgramID = &programID + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err != nil { + t.Fatalf("VerifyOpenTx with explicit mint/program overrides: %v", err) + } +} + +// ── VerifyOpenTx: failure modes ── + +func TestVerifyOpenTxRejectsUndecodableTransaction(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + garbage := "not-base64!" + fixture.payload.Transaction = &garbage + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "decode open transaction") { + t.Fatalf("err = %v, want decode rejection", err) + } +} + +func TestVerifyOpenTxRequiresTransaction(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.payload.Transaction = nil + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "transaction is required") { + t.Fatalf("err = %v, want transaction-required rejection", err) + } +} + +func TestVerifyOpenTxRejectsWrongPayee(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.expected.Recipient = fixture.payer.PublicKey().String() + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "payee") { + t.Fatalf("err = %v, want payee rejection", err) + } +} + +func TestVerifyOpenTxRejectsWrongMint(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.expected.Currency = "USDT" + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "mint") { + t.Fatalf("err = %v, want mint rejection", err) + } +} + +func TestVerifyOpenTxRejectsWrongAuthorizedSigner(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.expected.AuthorizedSigner = testutil.NewPrivateKey().PublicKey().String() + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "authorizedSigner") { + t.Fatalf("err = %v, want authorizedSigner rejection", err) + } +} + +func TestVerifyOpenTxRejectsOverCapDeposit(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.expected.MaxCap = openFixtureDeposit - 1 + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "exceeds max cap") { + t.Fatalf("err = %v, want over-cap rejection", err) + } +} + +func TestVerifyOpenTxRejectsZeroDeposit(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + // Rebuild the open instruction with a zero deposit; the channel PDA does + // not embed the deposit, so only the deposit check can reject it. + ix, err := paymentchannels.BuildOpenInstruction(paymentchannels.OpenChannelParams{ + Payer: fixture.payer.PublicKey(), + Payee: fixture.payee, + Mint: fixture.mint, + AuthorizedSigner: fixture.authorized, + Salt: openFixtureSalt, + Deposit: 0, + GracePeriod: openFixtureGrace, + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + _, fixture.payload = signAndAttachOpenTx(t, &fixture, ix, false) + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "greater than zero") { + t.Fatalf("err = %v, want zero-deposit rejection", err) + } +} + +func TestVerifyOpenTxRejectsUnboundSignature(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + other := testutil.NewPrivateKey() + unrelated, err := other.Sign([]byte("unrelated transaction")) + if err != nil { + t.Fatalf("sign unrelated payload: %v", err) + } + fixture.payload.Signature = unrelated.String() + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "transaction signature") { + t.Fatalf("err = %v, want signature-binding rejection", err) + } +} + +func TestVerifyOpenTxRejectsSignatureWithoutFeePayerSignature(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + tx, err := solanatx.DecodeTransactionBase64(*fixture.payload.Transaction) + if err != nil { + t.Fatalf("decode fixture transaction: %v", err) + } + tx.Signatures = []solana.Signature{{}} + stripped, err := solanatx.EncodeTransactionBase64(tx) + if err != nil { + t.Fatalf("re-encode stripped transaction: %v", err) + } + fixture.payload.Transaction = &stripped + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "no fee-payer signature") { + t.Fatalf("err = %v, want missing fee-payer-signature rejection", err) + } +} + +func TestVerifyOpenTxAcceptsPlaceholderSignatureWithoutBinding(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fixture.payload.Signature = strings.Repeat("1", 64) + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err != nil { + t.Fatalf("VerifyOpenTx with placeholder signature: %v", err) + } +} + +func TestVerifyOpenTxRejectsMissingOpenInstruction(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + memo, err := solanatx.BuildMemoInstruction("not an open") + if err != nil { + t.Fatalf("BuildMemoInstruction: %v", err) + } + _, fixture.payload = signAndAttachOpenTx(t, &fixture, memo, false) + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "no payment-channels open instruction") { + t.Fatalf("err = %v, want missing-open-instruction rejection", err) + } +} + +func TestVerifyOpenTxRejectsChannelPDAMismatch(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + ix, err := paymentchannels.BuildOpenInstruction(paymentchannels.OpenChannelParams{ + Payer: fixture.payer.PublicKey(), + Payee: fixture.payee, + Mint: fixture.mint, + AuthorizedSigner: fixture.authorized, + Salt: openFixtureSalt, + Deposit: openFixtureDeposit, + GracePeriod: openFixtureGrace, + TokenProgram: solana.TokenProgramID, + }) + if err != nil { + t.Fatalf("BuildOpenInstruction: %v", err) + } + // Swap the channel account (slot 4) for an unrelated key while keeping + // the instruction data intact: the re-derived PDA must catch it. + data, err := ix.Data() + if err != nil { + t.Fatalf("ix.Data: %v", err) + } + accounts := make(solana.AccountMetaSlice, len(ix.Accounts())) + copy(accounts, ix.Accounts()) + tampered := *accounts[4] + tampered.PublicKey = testutil.NewPrivateKey().PublicKey() + accounts[4] = &tampered + forged := solana.NewInstruction(ix.ProgramID(), accounts, data) + + _, fixture.payload = signAndAttachOpenTx(t, &fixture, forged, false) + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "PDA") { + t.Fatalf("err = %v, want channel-PDA rejection", err) + } +} + +func TestVerifyOpenTxRejectsPayloadChannelIDMismatch(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + other := testutil.NewPrivateKey().PublicKey().String() + fixture.payload.ChannelID = &other + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, nil); err == nil || !strings.Contains(err.Error(), "channelId") { + t.Fatalf("err = %v, want payload-channelId rejection", err) + } +} + +// ── VerifyOpenTx: RPC liveness ── + +func TestVerifyOpenTxConfirmsBoundSignatureViaRPC(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fakeRPC := testutil.NewFakeRPC() + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, fakeRPC); err != nil { + t.Fatalf("VerifyOpenTx with confirmed signature: %v", err) + } +} + +func TestVerifyOpenTxSurfacesRPCFailure(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fakeRPC := testutil.NewFakeRPC() + fakeRPC.Statuses[fixture.signature] = &rpc.SignatureStatusesResult{Err: "InstructionError"} + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, fakeRPC); err == nil || !strings.Contains(err.Error(), "failed on-chain") { + t.Fatalf("err = %v, want on-chain failure rejection", err) + } +} + +func TestVerifyOpenTxSurfacesRPCNotFound(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + fakeRPC := testutil.NewFakeRPC() + fakeRPC.Statuses[fixture.signature] = nil + if _, err := VerifyOpenTx(context.Background(), fixture.expected, &fixture.payload, fakeRPC); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want not-found rejection", err) + } +} + +func TestIsPlaceholderSignature(t *testing.T) { + cases := []struct { + signature string + want bool + }{ + {"", true}, + {strings.Repeat("1", 64), true}, + {strings.Repeat("1", 40), true}, + {strings.Repeat("1", 39), false}, + {strings.Repeat("1", 63) + "2", false}, + {"5VERYrealLookingBase58SignatureValue11111111111111111111111111111", false}, + } + for _, tc := range cases { + if got := isPlaceholderSignature(tc.signature); got != tc.want { + t.Fatalf("isPlaceholderSignature(%q) = %v, want %v", tc.signature, got, tc.want) + } + } +} + +// ── NewOpenTxVerifier wiring ── + +// openSessionConfig returns a session config whose challenge values accept +// the fixture's open transaction. +func openSessionConfig(fixture openTxFixture) SessionConfig { + return SessionConfig{ + Operator: fixture.payee.String(), + Recipient: fixture.payee.String(), + MaxCap: 5_000_000, + Currency: "USDC", + Decimals: 6, + Network: "localnet", + Modes: []intents.SessionMode{intents.SessionModePush}, + } +} + +func TestNewOpenTxVerifierAcceptsValidOpenThroughProcessOpen(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + config := openSessionConfig(fixture) + config.VerifyOpenTx = NewOpenTxVerifier(config, nil) + server := newSessionTestServer(config) + + state, err := server.ProcessOpen(context.Background(), &fixture.payload) + if err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if state.ChannelID != fixture.channel.String() { + t.Fatalf("channelId = %s, want %s", state.ChannelID, fixture.channel) + } +} + +func TestNewOpenTxVerifierRejectsForeignRecipientThroughProcessOpen(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + config := openSessionConfig(fixture) + config.Recipient = fixture.payer.PublicKey().String() // not the tx payee + config.VerifyOpenTx = NewOpenTxVerifier(config, nil) + server := newSessionTestServer(config) + + if _, err := server.ProcessOpen(context.Background(), &fixture.payload); err == nil || !strings.Contains(err.Error(), "payee") { + t.Fatalf("err = %v, want payee rejection through the verifier seam", err) + } +} + +func TestNewOpenTxVerifierWithoutTransactionRequiresRPC(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + config := openSessionConfig(fixture) + verifier := NewOpenTxVerifier(config, nil) + payload := fixture.payload + payload.Transaction = nil + if err := verifier(context.Background(), &payload); err == nil || !strings.Contains(err.Error(), "RPC client") { + t.Fatalf("err = %v, want rpc-required rejection", err) + } +} + +func TestNewOpenTxVerifierWithoutTransactionConfirmsSignature(t *testing.T) { + fixture := buildOpenTxFixture(t, false) + config := openSessionConfig(fixture) + verifier := NewOpenTxVerifier(config, testutil.NewFakeRPC()) + payload := fixture.payload + payload.Transaction = nil + if err := verifier(context.Background(), &payload); err != nil { + t.Fatalf("verifier with confirmed signature: %v", err) + } +} + +// ── NewTopUpTxVerifier ── + +func TestNewTopUpTxVerifierNilRPCDisablesTheSeam(t *testing.T) { + if verifier := NewTopUpTxVerifier(nil); verifier != nil { + t.Fatal("NewTopUpTxVerifier(nil) must return nil so the seam stays trust-as-provided") + } +} + +func TestNewTopUpTxVerifierConfirmsSignature(t *testing.T) { + signer := testutil.NewPrivateKey() + signature, err := signer.Sign([]byte("top-up")) + if err != nil { + t.Fatalf("sign: %v", err) + } + verifier := NewTopUpTxVerifier(testutil.NewFakeRPC()) + payload := &intents.TopUpPayload{ChannelID: "chan", NewDeposit: "2000000", Signature: signature.String()} + if err := verifier(context.Background(), payload); err != nil { + t.Fatalf("verifier with confirmed signature: %v", err) + } +} + +func TestNewTopUpTxVerifierSurfacesFailureAndNotFound(t *testing.T) { + signer := testutil.NewPrivateKey() + signature, err := signer.Sign([]byte("top-up")) + if err != nil { + t.Fatalf("sign: %v", err) + } + fakeRPC := testutil.NewFakeRPC() + fakeRPC.Statuses[signature.String()] = &rpc.SignatureStatusesResult{Err: "InstructionError"} + verifier := NewTopUpTxVerifier(fakeRPC) + payload := &intents.TopUpPayload{ChannelID: "chan", NewDeposit: "2000000", Signature: signature.String()} + if err := verifier(context.Background(), payload); err == nil || !strings.Contains(err.Error(), "top-up") { + t.Fatalf("err = %v, want top-up failure rejection", err) + } + + fakeRPC.Statuses[signature.String()] = nil + if err := verifier(context.Background(), payload); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want not-found rejection", err) + } + + if err := verifier(context.Background(), &intents.TopUpPayload{Signature: "not-base58!"}); err == nil || !strings.Contains(err.Error(), "invalid top-up tx signature") { + t.Fatalf("err = %v, want invalid-signature rejection", err) + } +} + +// ── SettlementInstructions ── + +// openSettlementChannel opens a payment-channel-shaped session (payer set, so +// the distribute refund account can be derived) and returns the voucher +// signer plus the channel id. +func openSettlementChannel(t *testing.T, server *SessionServer, payer solana.PublicKey) (testVoucherSigner, string) { + t.Helper() + signer := newTestVoucherSigner(t) + channelID := testutil.NewPrivateKey().PublicKey().String() + payload := intents.OpenPayloadPaymentChannel( + channelID, "1000000", + payer.String(), + sessionTestRecipient, + paycore.USDCMainnetMint, + openFixtureSalt, openFixtureGrace, + signer.Address(), "dummy_tx_sig", + ) + if _, err := server.ProcessOpen(context.Background(), &payload); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + return signer, channelID +} + +func TestSettlementInstructionsWithVoucher(t *testing.T) { + config := sessionTestConfig() + config.Splits = []Split{{Recipient: testutil.NewPrivateKey().PublicKey(), BPS: 250}} + server := newSessionTestServer(config) + payer := testutil.NewPrivateKey().PublicKey() + merchant := testutil.NewPrivateKey().PublicKey() + signer, channelID := openSettlementChannel(t, server, payer) + + if _, err := submitVoucher(t, server, signer, channelID, 500); err != nil { + t.Fatalf("submitVoucher: %v", err) + } + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}); err != nil { + t.Fatalf("ProcessClose: %v", err) + } + + instructions, err := server.SettlementInstructions(context.Background(), channelID, merchant) + if err != nil { + t.Fatalf("SettlementInstructions: %v", err) + } + if len(instructions) != 3 { + t.Fatalf("instructions = %d, want 3 (ed25519 + settle_and_finalize + distribute)", len(instructions)) + } + + // Instruction 0: the Ed25519 precompile over the stored highest voucher. + if !instructions[0].ProgramID().Equals(paymentchannels.Ed25519ProgramPubkey()) { + t.Fatalf("instruction 0 program = %s, want Ed25519 precompile", instructions[0].ProgramID()) + } + state, err := server.Store().GetChannel(context.Background(), channelID) + if err != nil || state == nil { + t.Fatalf("GetChannel: %v", err) + } + if state.HighestVoucherExpiresAt == nil { + t.Fatal("expected a stored voucher expiry") + } + channel := solana.MustPublicKeyFromBase58(channelID) + wantMessage, err := paymentchannels.VoucherMessageBytes(channel, 500, *state.HighestVoucherExpiresAt) + if err != nil { + t.Fatalf("VoucherMessageBytes: %v", err) + } + precompileData, err := instructions[0].Data() + if err != nil { + t.Fatalf("precompile.Data: %v", err) + } + if !bytes.Equal(precompileData[112:160], wantMessage) { + t.Fatal("precompile message != stored voucher payload") + } + + // Instruction 1: settle_and_finalize committing the watermark. + settleData, err := instructions[1].Data() + if err != nil { + t.Fatalf("settle.Data: %v", err) + } + if settleData[0] != 4 || settleData[len(settleData)-1] != 1 { + t.Fatalf("settle data disc/hasVoucher = %d/%d, want 4/1", settleData[0], settleData[len(settleData)-1]) + } + if got := binary.LittleEndian.Uint64(settleData[33:41]); got != 500 { + t.Fatalf("settled cumulative = %d, want 500", got) + } + if !instructions[1].Accounts()[0].PublicKey.Equals(merchant) { + t.Fatalf("settle merchant = %s, want %s", instructions[1].Accounts()[0].PublicKey, merchant) + } + + // Instruction 2: distribute with the configured split appended. + distributeData, err := instructions[2].Data() + if err != nil { + t.Fatalf("distribute.Data: %v", err) + } + if distributeData[0] != 7 { + t.Fatalf("distribute discriminator = %d, want 7", distributeData[0]) + } + if got := len(instructions[2].Accounts()); got != 11 { + t.Fatalf("distribute accounts = %d, want 11 (10 fixed + 1 split ATA)", got) + } + payerATA, _, err := solana.FindAssociatedTokenAddressWithProgram( + payer, solana.MustPublicKeyFromBase58(paycore.USDCMainnetMint), solana.TokenProgramID) + if err != nil { + t.Fatalf("derive payer ATA: %v", err) + } + if !instructions[2].Accounts()[3].PublicKey.Equals(payerATA) { + t.Fatalf("distribute payer token account = %s, want %s", instructions[2].Accounts()[3].PublicKey, payerATA) + } +} + +func TestSettlementInstructionsVoucherlessClose(t *testing.T) { + config := sessionTestConfig() + programID := paymentchannels.ProgramPubkey() + config.ProgramID = &programID + server := newSessionTestServer(config) + payer := testutil.NewPrivateKey().PublicKey() + merchant := testutil.NewPrivateKey().PublicKey() + _, channelID := openSettlementChannel(t, server, payer) + + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}); err != nil { + t.Fatalf("ProcessClose: %v", err) + } + instructions, err := server.SettlementInstructions(context.Background(), channelID, merchant) + if err != nil { + t.Fatalf("SettlementInstructions: %v", err) + } + if len(instructions) != 2 { + t.Fatalf("instructions = %d, want 2 (no precompile without a voucher)", len(instructions)) + } + settleData, err := instructions[0].Data() + if err != nil { + t.Fatalf("settle.Data: %v", err) + } + if settleData[len(settleData)-1] != 0 { + t.Fatalf("hasVoucher = %d, want 0", settleData[len(settleData)-1]) + } + if got := binary.LittleEndian.Uint64(settleData[33:41]); got != 0 { + t.Fatalf("settled cumulative = %d, want 0", got) + } +} + +func TestSettlementInstructionsResolvesToken2022FromCurrency(t *testing.T) { + config := sessionTestConfig() + config.Currency = "PYUSD" + config.Network = "mainnet" + server := newSessionTestServer(config) + payer := testutil.NewPrivateKey().PublicKey() + merchant := testutil.NewPrivateKey().PublicKey() + + signer := newTestVoucherSigner(t) + channelID := testutil.NewPrivateKey().PublicKey().String() + payload := intents.OpenPayloadPaymentChannel( + channelID, "1000000", + payer.String(), sessionTestRecipient, paycore.PYUSDMainnetMint, + openFixtureSalt, openFixtureGrace, signer.Address(), "dummy_tx_sig", + ) + if _, err := server.ProcessOpen(context.Background(), &payload); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + + instructions, err := server.SettlementInstructions(context.Background(), channelID, merchant) + if err != nil { + t.Fatalf("SettlementInstructions: %v", err) + } + distribute := instructions[len(instructions)-1] + accounts := distribute.Accounts() + if got := accounts[6].PublicKey.String(); got != paycore.PYUSDMainnetMint { + t.Fatalf("distribute mint = %s, want PYUSD mainnet mint", got) + } + if got := accounts[7].PublicKey.String(); got != paycore.Token2022Program { + t.Fatalf("distribute token program = %s, want Token-2022", got) + } +} + +func TestSettlementInstructionsErrorPaths(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + merchant := testutil.NewPrivateKey().PublicKey() + + if _, err := server.SettlementInstructions(context.Background(), "missing-channel", merchant); err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want channel-not-found rejection", err) + } + + // A channel opened without a payer/owner has no refund token account. + _, channelID := openTestChannel(t, server, 1_000_000) + if _, err := server.SettlementInstructions(context.Background(), channelID, merchant); err == nil || !strings.Contains(err.Error(), "payer is unknown") { + t.Fatalf("err = %v, want unknown-payer rejection", err) + } + + // SOL is not an SPL token, so settlement cannot derive token accounts. + solConfig := sessionTestConfig() + solConfig.Currency = "SOL" + solServer := newSessionTestServer(solConfig) + payer := testutil.NewPrivateKey().PublicKey() + _, solChannel := openSettlementChannel(t, solServer, payer) + if _, err := solServer.SettlementInstructions(context.Background(), solChannel, merchant); err == nil || !strings.Contains(err.Error(), "SPL token") { + t.Fatalf("err = %v, want SPL-token rejection", err) + } + + // A pull-style session id that is not a base58 pubkey cannot be settled + // through the payment-channels program. + if _, err := server.ProcessOpen(context.Background(), sessionOpenPayload("not-a-pubkey!", 1_000_000, "signer1")); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if _, err := server.SettlementInstructions(context.Background(), "not-a-pubkey!", merchant); err == nil || !strings.Contains(err.Error(), "invalid channel id") { + t.Fatalf("err = %v, want invalid-channel-id rejection", err) + } + + // A challenge recipient that is not a valid pubkey fails distribute + // derivation. + badRecipientConfig := sessionTestConfig() + badRecipientConfig.Recipient = "not-a-recipient!" + badRecipientServer := newSessionTestServer(badRecipientConfig) + _, badChannel := openSettlementChannel(t, badRecipientServer, payer) + if _, err := badRecipientServer.SettlementInstructions(context.Background(), badChannel, merchant); err == nil || !strings.Contains(err.Error(), "invalid recipient") { + t.Fatalf("err = %v, want invalid-recipient rejection", err) + } +} diff --git a/go/protocols/mpp/server/session_routes.go b/go/protocols/mpp/server/session_routes.go new file mode 100644 index 000000000..6fbc3c1ad --- /dev/null +++ b/go/protocols/mpp/server/session_routes.go @@ -0,0 +1,230 @@ +package server + +// Metering side channel and HTTP middleware for the session method. +// +// The reserve/commit side channel is an extension beyond the draft MPP +// spec: SessionFetch-style clients POST to /__402/session/deliveries to +// reserve capacity for a metered delivery and to /__402/session/commit to +// commit it with a signed voucher. Hosts mount the two handlers on those +// paths themselves. + +import ( + "context" + "encoding/json" + "net/http" + + core "github.com/solana-foundation/pay-kit/go/protocols/mpp/core" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/errorcodes" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// SessionRoutes carries the metering side-channel handlers built by +// Session.Routes. Both share the session's channel store, so deliveries see +// channels opened through VerifyCredential. +type SessionRoutes struct { + // Deliveries reserves capacity for a metered delivery. Mount at + // POST /__402/session/deliveries. + Deliveries http.HandlerFunc + + // Commit commits a reserved delivery with a signed voucher. Mount at + // POST /__402/session/commit. + Commit http.HandlerFunc +} + +// sessionDeliveryRequestBody is the JSON body of a delivery reservation. +type sessionDeliveryRequestBody struct { + // SessionID is the channel/session id that will pay for the delivery. + SessionID string `json:"sessionId"` + + // Amount owed for the delivery: a decimal u64 string in token base units. + Amount string `json:"amount"` + + // DeliveryID is an optional idempotency key; when empty the server + // derives ":". + DeliveryID string `json:"deliveryId,omitempty"` + + // CommitURL is an optional commit endpoint hint echoed back to the + // client in the metering directive. + CommitURL string `json:"commitUrl,omitempty"` + + // ExpiresAt is an optional delivery expiry (Unix seconds); zero defaults + // to intents.DefaultSessionExpiresAt. + ExpiresAt int64 `json:"expiresAt,omitempty"` + + // Proof is an optional opaque proof echoed back to the client in the + // metering directive. + Proof string `json:"proof,omitempty"` +} + +// sessionCommitRequestBody is the JSON body of a side-channel commit. +type sessionCommitRequestBody struct { + // DeliveryID names the reserved delivery being committed. Required. + DeliveryID string `json:"deliveryId"` + + // Voucher is the signed voucher whose cumulative (a lifetime total, not + // a per-request delta) settles the delivery. Required; nil is rejected. + Voucher *intents.SignedVoucher `json:"voucher"` +} + +// Routes builds the metering side-channel handlers for this session. +func (s *Session) Routes() SessionRoutes { + return SessionRoutes{ + Deliveries: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeSessionRouteError(w, http.StatusMethodNotAllowed, "POST required") + return + } + var body sessionDeliveryRequestBody + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeSessionRouteError(w, http.StatusBadRequest, "invalid request body") + return + } + if body.SessionID == "" { + writeSessionRouteError(w, http.StatusBadRequest, "sessionId required") + return + } + amount, err := parseSessionU64(body.Amount, "amount") + if err != nil { + writeSessionRouteError(w, http.StatusBadRequest, err.Error()) + return + } + if amount == 0 { + writeSessionRouteError(w, http.StatusBadRequest, "amount must be positive") + return + } + directive, err := s.core.BeginDelivery(r.Context(), DeliveryRequest{ + SessionID: body.SessionID, + Amount: amount, + DeliveryID: body.DeliveryID, + CommitURL: body.CommitURL, + Proof: body.Proof, + ExpiresAt: body.ExpiresAt, + }) + if err != nil { + writeSessionRouteError(w, http.StatusBadRequest, err.Error()) + return + } + s.touch(body.SessionID) + writeSessionRouteJSON(w, http.StatusOK, directive) + }, + Commit: func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeSessionRouteError(w, http.StatusMethodNotAllowed, "POST required") + return + } + var body sessionCommitRequestBody + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeSessionRouteError(w, http.StatusBadRequest, "invalid request body") + return + } + if body.DeliveryID == "" { + writeSessionRouteError(w, http.StatusBadRequest, "deliveryId required") + return + } + if body.Voucher == nil { + writeSessionRouteError(w, http.StatusBadRequest, "voucher required") + return + } + receipt, err := s.core.ProcessCommit(r.Context(), &intents.CommitPayload{ + DeliveryID: body.DeliveryID, + Voucher: *body.Voucher, + }) + if err != nil { + writeSessionRouteError(w, http.StatusBadRequest, err.Error()) + return + } + s.touch(receipt.SessionID) + writeSessionRouteJSON(w, http.StatusOK, receipt) + }, + } +} + +// writeSessionRouteJSON writes a JSON response body with the given status. +func writeSessionRouteJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +// writeSessionRouteError writes the {"error": message} failure body the +// side-channel clients expect. +func writeSessionRouteError(w http.ResponseWriter, status int, message string) { + writeSessionRouteJSON(w, status, map[string]string{"error": message}) +} + +// SessionChallengeFunc returns the per-request challenge options for a route +// gated by SessionMiddleware. A nil function uses zero options (the server +// cap, no description). +type SessionChallengeFunc func(r *http.Request) (SessionChallengeOptions, error) + +// SessionMiddleware wraps an http.Handler to enforce MPP session payments. +// +// Requests without a valid credential receive a 402 with a session challenge +// in WWW-Authenticate. Requests with a valid credential have the action +// applied (open / voucher / commit / topUp / close), the receipt exposed in +// Payment-Receipt and the request context, and are passed through. The +// challenge (and its recentBlockhash prefetch) is only built when a 402 is +// actually issued, so the verify path never fetches a blockhash. +func SessionMiddleware(s *Session, challengeFn SessionChallengeFunc) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var verificationErr error + authHeader := r.Header.Get(core.AuthorizationHeader) + if paymentToken, ok := core.ExtractPaymentScheme(authHeader); ok && paymentToken != "" { + credential, err := core.ParseAuthorization(authHeader) + if err != nil { + verificationErr = core.WrapError(core.ErrCodeInvalidPayload, "parse authorization", err) + } else { + receipt, verifyErr := s.VerifyCredential(r.Context(), credential) + if verifyErr == nil { + if receiptHeader, fmtErr := core.FormatReceipt(receipt); fmtErr == nil { + w.Header().Set(core.PaymentReceiptHeader, receiptHeader) + } + markAuthorizationBoundResponse(w.Header()) + ctx := context.WithValue(r.Context(), receiptContextKey, receipt) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + verificationErr = verifyErr + } + } + + options := SessionChallengeOptions{} + if challengeFn != nil { + var err error + options, err = challengeFn(r) + if err != nil { + http.Error(w, "challenge function error", http.StatusInternalServerError) + return + } + } + challenge, err := s.Challenge(r.Context(), options) + if err != nil { + http.Error(w, "failed to create challenge", http.StatusInternalServerError) + return + } + wwwAuth, err := core.FormatWWWAuthenticate(challenge) + if err != nil { + http.Error(w, "failed to format challenge", http.StatusInternalServerError) + return + } + w.Header().Set(core.WWWAuthenticateHeader, wwwAuth) + markAuthorizationBoundResponse(w.Header()) + + code := errorcodes.PaymentInvalid + message := "Payment required" + if verificationErr != nil { + code = errorcodes.CanonicalFromError(verificationErr) + message = verificationErr.Error() + } + body, err := json.Marshal(errorcodes.NewPaymentRequiredBody(code, message)) + if err != nil { + http.Error(w, "failed to marshal challenge body", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/problem+json") + w.WriteHeader(http.StatusPaymentRequired) + _, _ = w.Write(body) + }) + } +} diff --git a/go/protocols/mpp/server/session_server_test.go b/go/protocols/mpp/server/session_server_test.go new file mode 100644 index 000000000..eaaea4961 --- /dev/null +++ b/go/protocols/mpp/server/session_server_test.go @@ -0,0 +1,939 @@ +package server + +// Off-chain session handler coverage: open, voucher verification, top-up, +// delivery begin/commit, close, and challenge-request building. + +import ( + "context" + "encoding/json" + "errors" + "strconv" + "strings" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +const sessionTestRecipient = "CXhrFZJLKqjzmP3sjYLcF4dTeXWKCy9e2SXXZ2Yo6MPY" + +func sessionTestConfig() SessionConfig { + return SessionConfig{ + Operator: sessionTestRecipient, + Recipient: sessionTestRecipient, + MaxCap: 10_000_000, + Currency: "USDC", + Decimals: 6, + Network: "localnet", + Modes: []intents.SessionMode{intents.SessionModePush}, + } +} + +func newSessionTestServer(config SessionConfig) *SessionServer { + return NewSessionServer(config, NewMemoryChannelStore()) +} + +func sessionOpenPayload(channelID string, deposit uint64, signer string) *intents.OpenPayload { + payload := intents.OpenPayloadPush(channelID, strconv.FormatUint(deposit, 10), signer, "dummy_tx_sig") + return &payload +} + +// openTestChannel opens a channel signed by a fresh keypair and returns the +// signer plus the channel id (a valid base58 32-byte key so vouchers can be +// signed against it). +func openTestChannel(t *testing.T, server *SessionServer, deposit uint64) (testVoucherSigner, string) { + t.Helper() + signer := newTestVoucherSigner(t) + channelID := solana.NewWallet().PublicKey().String() + if _, err := server.ProcessOpen(context.Background(), sessionOpenPayload(channelID, deposit, signer.Address())); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + return signer, channelID +} + +// submitVoucher signs and submits a voucher for cumulative, far in the future. +func submitVoucher(t *testing.T, server *SessionServer, signer testVoucherSigner, channelID string, cumulative uint64) (uint64, error) { + t.Helper() + voucher := signer.SignVoucher(t, channelID, cumulative, farFuture()) + return server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}) +} + +// ── BuildChallengeRequest ── + +func TestBuildChallengeRequestCanonicalShape(t *testing.T) { + config := sessionTestConfig() + config.MinVoucherDelta = 0 + server := newSessionTestServer(config) + + request := server.BuildChallengeRequest(1_000_000) + if request.Cap != "1000000" { + t.Fatalf("cap = %q, want 1000000", request.Cap) + } + if request.Currency != "USDC" || request.Operator != sessionTestRecipient || request.Recipient != sessionTestRecipient { + t.Fatalf("unexpected request fields: %+v", request) + } + if request.Decimals == nil || *request.Decimals != 6 { + t.Fatalf("decimals = %v, want 6", request.Decimals) + } + if request.Network == nil || *request.Network != "localnet" { + t.Fatalf("network = %v, want localnet", request.Network) + } + // minVoucherDelta omitted when zero, modes omitted when push-only, + // pullVoucherStrategy omitted when pull is not offered. + raw, err := json.Marshal(request) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + for _, absent := range []string{"minVoucherDelta", "modes", "pullVoucherStrategy", "recentBlockhash"} { + if strings.Contains(string(raw), absent) { + t.Fatalf("challenge JSON unexpectedly contains %q: %s", absent, raw) + } + } +} + +func TestBuildChallengeRequestClampsCapToMax(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + request := server.BuildChallengeRequest(99_000_000) + if request.Cap != "10000000" { + t.Fatalf("cap = %q, want clamped 10000000", request.Cap) + } +} + +func TestBuildChallengeRequestIncludesMinVoucherDeltaWhenPositive(t *testing.T) { + config := sessionTestConfig() + config.MinVoucherDelta = 250 + server := newSessionTestServer(config) + request := server.BuildChallengeRequest(1_000) + if request.MinVoucherDelta == nil || *request.MinVoucherDelta != "250" { + t.Fatalf("minVoucherDelta = %v, want 250", request.MinVoucherDelta) + } +} + +func TestBuildChallengeRequestAdvertisesPullModeAndStrategy(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + config := sessionTestConfig() + config.Modes = []intents.SessionMode{intents.SessionModePush, intents.SessionModePull} + config.PullVoucherStrategy = &strategy + config.Splits = []Split{{Recipient: solana.MustPublicKeyFromBase58(sessionTestRecipient), BPS: 10}} + server := newSessionTestServer(config) + + request := server.BuildChallengeRequest(1_000) + if len(request.Modes) != 2 { + t.Fatalf("modes = %v, want push+pull", request.Modes) + } + if request.PullVoucherStrategy == nil || *request.PullVoucherStrategy != strategy { + t.Fatalf("pullVoucherStrategy = %v, want clientVoucher", request.PullVoucherStrategy) + } + if len(request.Splits) != 1 || request.Splits[0].Recipient != sessionTestRecipient || request.Splits[0].BPS != 10 { + t.Fatalf("splits = %+v", request.Splits) + } +} + +// ── ProcessOpen ── + +func TestProcessOpenStoresState(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + state, err := server.ProcessOpen(context.Background(), sessionOpenPayload("chan1", 1_000_000, "signer1")) + if err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if state.Deposit != 1_000_000 || state.Cumulative != 0 || state.Finalized { + t.Fatalf("state = %+v", state) + } + if state.AuthorizedSigner != "signer1" { + t.Fatalf("authorizedSigner = %q, want signer1", state.AuthorizedSigner) + } +} + +func TestProcessOpenZeroDepositRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + if _, err := server.ProcessOpen(context.Background(), sessionOpenPayload("chan1", 0, "signer1")); err == nil { + t.Fatal("expected zero-deposit rejection") + } +} + +func TestProcessOpenExceedsCapRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + if _, err := server.ProcessOpen(context.Background(), sessionOpenPayload("chan1", 20_000_000, "signer1")); err == nil { + t.Fatal("expected over-cap rejection") + } +} + +func TestProcessOpenRejectsUnadvertisedPullMode(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + payload := intents.OpenPayloadPaymentChannelWithMode( + intents.SessionModePull, + "chan1", "1000000", "payer", sessionTestRecipient, "mint", + 1, 900, "signer1", "pending", + ) + _, err := server.ProcessOpen(context.Background(), &payload) + if err == nil || !strings.Contains(err.Error(), "not supported") { + t.Fatalf("err = %v, want mode-not-supported", err) + } +} + +func TestProcessOpenAcceptsAdvertisedPullClientVoucherChannel(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + config := sessionTestConfig() + config.Modes = []intents.SessionMode{intents.SessionModePull} + config.PullVoucherStrategy = &strategy + server := newSessionTestServer(config) + payload := intents.OpenPayloadPaymentChannelWithMode( + intents.SessionModePull, + "chan1", "1000000", "payer", sessionTestRecipient, "mint", + 1, 900, "signer1", "pending", + ) + state, err := server.ProcessOpen(context.Background(), &payload) + if err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if state.ChannelID != "chan1" || state.Deposit != 1_000_000 { + t.Fatalf("state = %+v", state) + } + if state.Operator == nil || *state.Operator != "payer" { + t.Fatalf("operator = %v, want payer fallback", state.Operator) + } +} + +func TestProcessOpenPrefersChannelIDOverTokenAccount(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + config := sessionTestConfig() + config.Modes = []intents.SessionMode{intents.SessionModePull} + config.PullVoucherStrategy = &strategy + server := newSessionTestServer(config) + + payload := intents.OpenPayloadPull("token-acct", "1000", "owner", "signer1", "sig") + channelID := "delegation-pda" + payload.ChannelID = &channelID + + state, err := server.ProcessOpen(context.Background(), &payload) + if err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if state.ChannelID != "delegation-pda" { + t.Fatalf("session key = %q, want channelId to win over tokenAccount", state.ChannelID) + } + if state.Operator == nil || *state.Operator != "owner" { + t.Fatalf("operator = %v, want owner", state.Operator) + } +} + +func TestProcessOpenReplayPreservesWatermark(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := submitVoucher(t, server, signer, channelID, 250); err != nil { + t.Fatalf("voucher: %v", err) + } + + replayed, err := server.ProcessOpen(context.Background(), sessionOpenPayload(channelID, 1_000_000, signer.Address())) + if err != nil { + t.Fatalf("replayed open: %v", err) + } + if replayed.Cumulative != 250 { + t.Fatalf("replayed open reset the watermark: cumulative = %d, want 250", replayed.Cumulative) + } + if replayed.HighestVoucherSignature == nil { + t.Fatal("replayed open erased the highest voucher signature") + } +} + +func TestProcessOpenReplayWithDifferentSignerRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + other := newTestVoucherSigner(t) + _, err := server.ProcessOpen(context.Background(), sessionOpenPayload(channelID, 1_000_000, other.Address())) + if err == nil || !strings.Contains(err.Error(), "different authorized signer") { + t.Fatalf("err = %v, want different-authorized-signer rejection", err) + } +} + +func TestProcessOpenReplayOnFinalizedChannelRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if err := server.MarkFinalized(context.Background(), channelID); err != nil { + t.Fatalf("MarkFinalized: %v", err) + } + _, err := server.ProcessOpen(context.Background(), sessionOpenPayload(channelID, 1_000_000, signer.Address())) + if err == nil || !strings.Contains(err.Error(), "finalized") { + t.Fatalf("err = %v, want finalized rejection", err) + } +} + +func TestProcessOpenInvokesVerifyOpenTxSeamForPush(t *testing.T) { + verified := 0 + config := sessionTestConfig() + config.VerifyOpenTx = func(_ context.Context, payload *intents.OpenPayload) error { + verified++ + if payload.Signature != "dummy_tx_sig" { + t.Fatalf("verifier got signature %q", payload.Signature) + } + return nil + } + server := newSessionTestServer(config) + if _, err := server.ProcessOpen(context.Background(), sessionOpenPayload("chan1", 1_000, "signer1")); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } + if verified != 1 { + t.Fatalf("VerifyOpenTx invoked %d times, want 1", verified) + } +} + +func TestProcessOpenVerifyOpenTxErrorRejectsWithoutPersisting(t *testing.T) { + wantErr := errors.New("tx not found") + config := sessionTestConfig() + config.VerifyOpenTx = func(context.Context, *intents.OpenPayload) error { return wantErr } + server := newSessionTestServer(config) + + _, err := server.ProcessOpen(context.Background(), sessionOpenPayload("chan1", 1_000, "signer1")) + if !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } + state, err := server.Store().GetChannel(context.Background(), "chan1") + if err != nil { + t.Fatalf("GetChannel: %v", err) + } + if state != nil { + t.Fatalf("channel persisted despite failed verification: %+v", state) + } +} + +func TestProcessOpenSkipsVerifyOpenTxForPull(t *testing.T) { + strategy := intents.SessionPullVoucherStrategyClientVoucher + config := sessionTestConfig() + config.Modes = []intents.SessionMode{intents.SessionModePull} + config.PullVoucherStrategy = &strategy + config.VerifyOpenTx = func(context.Context, *intents.OpenPayload) error { + t.Fatal("VerifyOpenTx must not run for pull opens") + return nil + } + server := newSessionTestServer(config) + + payload := intents.OpenPayloadPull("token-acct", "1000", "owner", "signer1", "sig") + if _, err := server.ProcessOpen(context.Background(), &payload); err != nil { + t.Fatalf("ProcessOpen: %v", err) + } +} + +// ── VerifyVoucher ── + +func TestVerifyVoucherAdvancesWatermark(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + cumulative, err := submitVoucher(t, server, signer, channelID, 100) + if err != nil { + t.Fatalf("VerifyVoucher: %v", err) + } + if cumulative != 100 { + t.Fatalf("cumulative = %d, want 100", cumulative) + } + + cumulative, err = submitVoucher(t, server, signer, channelID, 300) + if err != nil { + t.Fatalf("VerifyVoucher: %v", err) + } + if cumulative != 300 { + t.Fatalf("cumulative = %d, want 300", cumulative) + } + + state, err := server.Store().GetChannel(context.Background(), channelID) + if err != nil || state == nil { + t.Fatalf("GetChannel: state=%v err=%v", state, err) + } + if state.Cumulative != 300 || state.HighestVoucherSignature == nil || state.HighestVoucherExpiresAt == nil { + t.Fatalf("state = %+v", state) + } +} + +func TestVerifyVoucherUnknownChannelRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + _, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want channel-not-found", err) + } +} + +func TestVerifyVoucherNonMonotonicRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := submitVoucher(t, server, signer, channelID, 200); err != nil { + t.Fatalf("voucher: %v", err) + } + // Decreasing cumulative. + _, err := submitVoucher(t, server, signer, channelID, 150) + if err == nil || !strings.Contains(err.Error(), "must exceed watermark") { + t.Fatalf("err = %v, want non-monotonic rejection", err) + } + // Equal cumulative with a different signature (different expiry) is not a + // replay and must also be rejected as non-monotonic. + different := signer.SignVoucher(t, channelID, 200, farFuture()+60) + _, err = server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: different}) + if err == nil || !strings.Contains(err.Error(), "must exceed watermark") { + t.Fatalf("err = %v, want non-monotonic rejection for equal cumulative", err) + } +} + +func TestVerifyVoucherIdempotentReplayReturnsSameCumulative(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + voucher := signer.SignVoucher(t, channelID, 150, farFuture()) + if _, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}); err != nil { + t.Fatalf("first submit: %v", err) + } + cumulative, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: voucher}) + if err != nil { + t.Fatalf("replay: %v", err) + } + if cumulative != 150 { + t.Fatalf("replay cumulative = %d, want 150", cumulative) + } +} + +func TestVerifyVoucherRespectsMinVoucherDelta(t *testing.T) { + config := sessionTestConfig() + config.MinVoucherDelta = 100 + server := newSessionTestServer(config) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := submitVoucher(t, server, signer, channelID, 50); err == nil { + t.Fatal("expected below-min-delta rejection") + } + if _, err := submitVoucher(t, server, signer, channelID, 100); err != nil { + t.Fatalf("delta == min must pass: %v", err) + } +} + +func TestVerifyVoucherAcceptsLegacyCumulativeAlias(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + signed := signer.SignVoucher(t, channelID, 400, farFuture()) + // Re-encode the voucher payload with the legacy "cumulative" wire alias. + wire := []byte(`{"voucher":{"data":{"channelId":"` + channelID + + `","cumulative":"400","expiresAt":` + strconv.FormatInt(signed.Data.ExpiresAt, 10) + + `},"signature":"` + signed.Signature + `"}}`) + var payload intents.VoucherPayload + if err := json.Unmarshal(wire, &payload); err != nil { + t.Fatalf("decode aliased payload: %v", err) + } + + cumulative, err := server.VerifyVoucher(context.Background(), &payload) + if err != nil { + t.Fatalf("VerifyVoucher: %v", err) + } + if cumulative != 400 { + t.Fatalf("cumulative = %d, want 400", cumulative) + } +} + +// ── ProcessTopUp ── + +func TestProcessTopUpRaisesDeposit(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + state, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, + NewDeposit: "2000000", + Signature: "topup_sig", + }) + if err != nil { + t.Fatalf("ProcessTopUp: %v", err) + } + if state.Deposit != 2_000_000 { + t.Fatalf("deposit = %d, want 2000000", state.Deposit) + } +} + +func TestProcessTopUpRejectsNonIncreasingDeposit(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "1000000", Signature: "sig", + }) + if err == nil || !strings.Contains(err.Error(), "must exceed current deposit") { + t.Fatalf("err = %v, want non-increasing rejection", err) + } +} + +func TestProcessTopUpRejectsOverMaxCap(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "20000000", Signature: "sig", + }) + if err == nil || !strings.Contains(err.Error(), "exceeds max cap") { + t.Fatalf("err = %v, want over-cap rejection", err) + } +} + +func TestProcessTopUpRejectsWhenFinalizedOrClosePending(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}); err != nil { + t.Fatalf("ProcessClose: %v", err) + } + _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "2000000", Signature: "sig", + }) + if err == nil || !strings.Contains(err.Error(), "close is pending") { + t.Fatalf("err = %v, want close-pending rejection", err) + } + + server2 := newSessionTestServer(sessionTestConfig()) + _, channelID2 := openTestChannel(t, server2, 1_000_000) + if err := server2.MarkFinalized(context.Background(), channelID2); err != nil { + t.Fatalf("MarkFinalized: %v", err) + } + _, err = server2.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID2, NewDeposit: "2000000", Signature: "sig", + }) + if err == nil || !strings.Contains(err.Error(), "finalized") { + t.Fatalf("err = %v, want finalized rejection", err) + } +} + +func TestProcessTopUpInvokesVerifyTopUpTxSeam(t *testing.T) { + wantErr := errors.New("topup tx unknown") + config := sessionTestConfig() + config.VerifyTopUpTx = func(_ context.Context, payload *intents.TopUpPayload) error { + if payload.Signature != "topup_sig" { + t.Fatalf("verifier got signature %q", payload.Signature) + } + return wantErr + } + server := newSessionTestServer(config) + _, channelID := openTestChannel(t, server, 1_000_000) + + _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "2000000", Signature: "topup_sig", + }) + if !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } + state, getErr := server.Store().GetChannel(context.Background(), channelID) + if getErr != nil || state == nil { + t.Fatalf("GetChannel: state=%v err=%v", state, getErr) + } + if state.Deposit != 1_000_000 { + t.Fatalf("deposit raised despite failed verification: %d", state.Deposit) + } +} + +func TestVoucherAcceptedAfterTopUpRaisesDeposit(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000) + + if _, err := submitVoucher(t, server, signer, channelID, 2_000); err == nil { + t.Fatal("expected exceeds-deposit rejection before top-up") + } + if _, err := server.ProcessTopUp(context.Background(), &intents.TopUpPayload{ + ChannelID: channelID, NewDeposit: "5000", Signature: "sig", + }); err != nil { + t.Fatalf("ProcessTopUp: %v", err) + } + if _, err := submitVoucher(t, server, signer, channelID, 2_000); err != nil { + t.Fatalf("voucher after top-up: %v", err) + } +} + +// ── BeginDelivery ── + +func TestBeginDeliveryAssignsSequenceAndDefaultDeliveryID(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + first, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + if first.DeliveryID != channelID+":1" || first.Sequence != 1 { + t.Fatalf("directive = %+v, want sequence 1 and default id", first) + } + if first.Amount != "100" || first.Currency != "USDC" || first.SessionID != channelID { + t.Fatalf("directive = %+v", first) + } + if first.ExpiresAt != intents.DefaultSessionExpiresAt { + t.Fatalf("expiresAt = %d, want default %d", first.ExpiresAt, intents.DefaultSessionExpiresAt) + } + + second, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 50}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + if second.DeliveryID != channelID+":2" || second.Sequence != 2 { + t.Fatalf("directive = %+v, want sequence 2", second) + } +} + +func TestBeginDeliveryHonorsExplicitFields(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + expiresAt := time.Now().Unix() + 60 + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{ + SessionID: channelID, + Amount: 100, + DeliveryID: "custom-id", + CommitURL: "https://example.test/commit", + Proof: "proof-blob", + ExpiresAt: expiresAt, + }) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + if directive.DeliveryID != "custom-id" || directive.ExpiresAt != expiresAt { + t.Fatalf("directive = %+v", directive) + } + if directive.CommitURL == nil || *directive.CommitURL != "https://example.test/commit" { + t.Fatalf("commitUrl = %v", directive.CommitURL) + } + if directive.Proof == nil || *directive.Proof != "proof-blob" { + t.Fatalf("proof = %v", directive.Proof) + } +} + +func TestBeginDeliveryRejectsZeroAmountAndUnknownChannel(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: "ghost", Amount: 0}); err == nil { + t.Fatal("expected zero-amount rejection") + } + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: "ghost", Amount: 5}); err == nil { + t.Fatal("expected unknown-channel rejection") + } +} + +func TestBeginDeliveryRejectsDuplicateDeliveryID(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{ + SessionID: channelID, Amount: 10, DeliveryID: "dup", + }); err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + _, err := server.BeginDelivery(context.Background(), DeliveryRequest{ + SessionID: channelID, Amount: 10, DeliveryID: "dup", + }) + if err == nil || !strings.Contains(err.Error(), "already exists") { + t.Fatalf("err = %v, want duplicate rejection", err) + } +} + +func TestBeginDeliveryReservationMath(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000) + + // Advance the watermark to 400 so the reservation has to account for it. + if _, err := submitVoucher(t, server, signer, channelID, 400); err != nil { + t.Fatalf("voucher: %v", err) + } + // Reserve 500: cumulative 400 + pending 500 = 900 <= 1000. + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 500}); err != nil { + t.Fatalf("first reservation: %v", err) + } + // Reserve 100 more: 400 + 500 + 100 = 1000 <= 1000 (boundary holds). + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}); err != nil { + t.Fatalf("boundary reservation: %v", err) + } + // One more unit must fail: 400 + 600 + 1 > 1000. + _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 1}) + if err == nil || !strings.Contains(err.Error(), "exceeds available deposit") { + t.Fatalf("err = %v, want reservation overflow rejection", err) + } +} + +func TestBeginDeliveryRejectedWhenClosePending(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}); err != nil { + t.Fatalf("ProcessClose: %v", err) + } + _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 5}) + if err == nil || !strings.Contains(err.Error(), "close is pending") { + t.Fatalf("err = %v, want close-pending rejection", err) + } +} + +// ── ProcessCommit ── + +func TestProcessCommitCommitsReservedDelivery(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + receipt, err := server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, + Voucher: voucher, + }) + if err != nil { + t.Fatalf("ProcessCommit: %v", err) + } + if receipt.Status != intents.CommitStatusCommitted || receipt.Amount != "100" || receipt.Cumulative != "100" { + t.Fatalf("receipt = %+v", receipt) + } + + state, err := server.Store().GetChannel(context.Background(), channelID) + if err != nil || state == nil { + t.Fatalf("GetChannel: state=%v err=%v", state, err) + } + if state.Cumulative != 100 || len(state.PendingDeliveries) != 0 || len(state.CommittedDeliveries) != 1 { + t.Fatalf("state = %+v", state) + } +} + +func TestProcessCommitReplayReturnsCachedReceipt(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + payload := &intents.CommitPayload{DeliveryID: directive.DeliveryID, Voucher: voucher} + + if _, err := server.ProcessCommit(context.Background(), payload); err != nil { + t.Fatalf("first commit: %v", err) + } + replay, err := server.ProcessCommit(context.Background(), payload) + if err != nil { + t.Fatalf("replayed commit: %v", err) + } + if replay.Status != intents.CommitStatusReplayed || replay.Amount != "100" || replay.Cumulative != "100" { + t.Fatalf("replay receipt = %+v", replay) + } +} + +func TestProcessCommitReplayWithDifferentVoucherRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 200}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + first := signer.SignVoucher(t, channelID, 100, farFuture()) + if _, err := server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: first, + }); err != nil { + t.Fatalf("first commit: %v", err) + } + + different := signer.SignVoucher(t, channelID, 150, farFuture()) + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: different, + }) + if err == nil || !strings.Contains(err.Error(), "already committed with different voucher") { + t.Fatalf("err = %v, want different-voucher rejection", err) + } +} + +func TestProcessCommitReplayReVerifiesSignature(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + if _, err := server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: voucher, + }); err != nil { + t.Fatalf("first commit: %v", err) + } + + // Same signature and cumulative, but tampered expiry: the replayed + // voucher no longer verifies and must be rejected. + tampered := voucher + tampered.Data.ExpiresAt = voucher.Data.ExpiresAt + 1 + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: tampered, + }) + if err == nil { + t.Fatal("expected replayed-commit signature re-verification failure") + } +} + +func TestProcessCommitUnknownDeliveryRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + _, err := server.ProcessCommit(context.Background(), &intents.CommitPayload{DeliveryID: "ghost", Voucher: voucher}) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want delivery-not-found", err) + } +} + +func TestProcessCommitExpiredDirectiveRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{ + SessionID: channelID, + Amount: 100, + ExpiresAt: time.Now().Unix() - 10, + }) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + voucher := signer.SignVoucher(t, channelID, 100, farFuture()) + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: voucher, + }) + if err == nil || !strings.Contains(err.Error(), "has expired") { + t.Fatalf("err = %v, want expired-directive rejection", err) + } +} + +func TestProcessCommitOverReservedAmountRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + directive, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 100}) + if err != nil { + t.Fatalf("BeginDelivery: %v", err) + } + // The voucher claims 150 against a 100 reservation. + voucher := signer.SignVoucher(t, channelID, 150, farFuture()) + _, err = server.ProcessCommit(context.Background(), &intents.CommitPayload{ + DeliveryID: directive.DeliveryID, Voucher: voucher, + }) + if err == nil || !strings.Contains(err.Error(), "exceeds reserved amount") { + t.Fatalf("err = %v, want over-reservation rejection", err) + } +} + +// ── ProcessClose ── + +func TestProcessCloseFlipsClosePendingAndBlocksFurtherActivity(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + state, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}) + if err != nil { + t.Fatalf("ProcessClose: %v", err) + } + if state.CloseRequestedAt == nil { + t.Fatal("closeRequestedAt not set") + } + + if _, err := submitVoucher(t, server, signer, channelID, 100); err == nil { + t.Fatal("expected voucher rejection after close") + } + if _, err := server.BeginDelivery(context.Background(), DeliveryRequest{SessionID: channelID, Amount: 1}); err == nil { + t.Fatal("expected delivery rejection after close") + } +} + +func TestProcessCloseDoubleCloseRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}); err != nil { + t.Fatalf("ProcessClose: %v", err) + } + _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID}) + if err == nil || !strings.Contains(err.Error(), "close already requested") { + t.Fatalf("err = %v, want double-close rejection", err) + } +} + +func TestProcessCloseFinalVoucherAdvancesWatermark(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := submitVoucher(t, server, signer, channelID, 100); err != nil { + t.Fatalf("voucher: %v", err) + } + final := signer.SignVoucher(t, channelID, 500, farFuture()) + state, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID, Voucher: &final}) + if err != nil { + t.Fatalf("ProcessClose: %v", err) + } + if state.Cumulative != 500 { + t.Fatalf("cumulative = %d, want 500", state.Cumulative) + } + if state.HighestVoucherSignature == nil || *state.HighestVoucherSignature != final.Signature { + t.Fatalf("highest signature not updated: %+v", state) + } +} + +func TestProcessCloseNonMonotonicFinalVoucherIsHardError(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + if _, err := submitVoucher(t, server, signer, channelID, 300); err != nil { + t.Fatalf("voucher: %v", err) + } + stale := signer.SignVoucher(t, channelID, 200, farFuture()) + _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID, Voucher: &stale}) + if err == nil || !strings.Contains(err.Error(), "must exceed watermark") { + t.Fatalf("err = %v, want non-monotonic hard error", err) + } + + // The failed close must not flip close-pending. + state, getErr := server.Store().GetChannel(context.Background(), channelID) + if getErr != nil || state == nil { + t.Fatalf("GetChannel: state=%v err=%v", state, getErr) + } + if state.CloseRequestedAt != nil { + t.Fatal("failed close flipped close-pending") + } + if state.Cumulative != 300 { + t.Fatalf("cumulative = %d, want unchanged 300", state.Cumulative) + } +} + +func TestProcessCloseAcceptsReplayOfCurrentHighestVoucher(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000_000) + + highest := signer.SignVoucher(t, channelID, 300, farFuture()) + if _, err := server.VerifyVoucher(context.Background(), &intents.VoucherPayload{Voucher: highest}); err != nil { + t.Fatalf("voucher: %v", err) + } + state, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID, Voucher: &highest}) + if err != nil { + t.Fatalf("ProcessClose with replayed highest voucher: %v", err) + } + if state.CloseRequestedAt == nil || state.Cumulative != 300 { + t.Fatalf("state = %+v", state) + } +} + +func TestProcessCloseFinalVoucherExceedingDepositRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + signer, channelID := openTestChannel(t, server, 1_000) + + final := signer.SignVoucher(t, channelID, 2_000, farFuture()) + _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: channelID, Voucher: &final}) + if err == nil || !strings.Contains(err.Error(), "exceeds deposit") { + t.Fatalf("err = %v, want exceeds-deposit rejection", err) + } +} + +func TestProcessCloseUnknownChannelRejected(t *testing.T) { + server := newSessionTestServer(sessionTestConfig()) + _, err := server.ProcessClose(context.Background(), &intents.ClosePayload{ChannelID: "ghost"}) + if err == nil || !strings.Contains(err.Error(), "not found") { + t.Fatalf("err = %v, want channel-not-found", err) + } +} diff --git a/go/protocols/mpp/server/session_store.go b/go/protocols/mpp/server/session_store.go new file mode 100644 index 000000000..3eee93843 --- /dev/null +++ b/go/protocols/mpp/server/session_store.go @@ -0,0 +1,309 @@ +package server + +// Per-channel state store for the MPP session server. +// +// The in-memory implementation serializes UpdateChannel calls per channel id +// with a per-channel mutex, so the read-modify-write sequence inside the +// mutator is atomic from the perspective of any other caller targeting the +// same channel while updates to different channels run concurrently. +// +// The voucher verifier (see session_voucher.go) is intentionally +// side-effect-free: it computes a verdict, and the caller persists any +// accepted delta through ChannelStore.UpdateChannel. + +import ( + "context" + "fmt" + "sync" +) + +// PendingDelivery is one delivery the server has reserved against a channel +// but not yet received a signed voucher for. +type PendingDelivery struct { + // DeliveryID is the idempotency key for this delivery. + DeliveryID string `json:"deliveryId"` + + // Amount reserved for this delivery in base units. + Amount uint64 `json:"amount"` + + // Sequence is the monotonic per-channel delivery sequence. + Sequence uint64 `json:"sequence"` + + // ExpiresAt is the Unix timestamp after which the delivery should not be + // committed. + ExpiresAt int64 `json:"expiresAt"` +} + +// CommittedDelivery is a delivery that has been committed by a signed +// voucher. Kept for idempotent commit replay. +type CommittedDelivery struct { + // DeliveryID is the idempotency key for this delivery. + DeliveryID string `json:"deliveryId"` + + // Amount committed for this delivery in base units. + Amount uint64 `json:"amount"` + + // Cumulative is the channel watermark after this commit. + Cumulative uint64 `json:"cumulative"` + + // VoucherSignature is the signature of the committing voucher (base58). + VoucherSignature string `json:"voucherSignature"` +} + +// ChannelState is the persisted state of a single payment channel from the +// server's point of view. The JSON tags are the shared snake_case wire +// names, so durable stores can interoperate across the language SDKs. +type ChannelState struct { + // ChannelID is the on-chain channel address (base58). + // + // Push sessions: the payment-channel address. + // Pull sessions: the FixedDelegation PDA address. + ChannelID string `json:"channel_id"` + + // AuthorizedSigner is the public key authorized to sign vouchers for this + // session (base58). + AuthorizedSigner string `json:"authorized_signer"` + + // Deposit is the total deposit / approved amount locked for this session + // (base units). + Deposit uint64 `json:"deposit"` + + // Cumulative is the highest cumulative amount accepted by the server (the + // settled watermark). + Cumulative uint64 `json:"cumulative"` + + // Finalized is true once the channel has been finalized on-chain. + Finalized bool `json:"finalized"` + + // HighestVoucherSignature is the signature of the highest accepted voucher + // (base58). Stored for idempotent replay detection. + HighestVoucherSignature *string `json:"highest_voucher_signature"` + + // HighestVoucherExpiresAt is the expiry timestamp from the highest + // accepted voucher. Needed when the server later settles that voucher + // on-chain. + HighestVoucherExpiresAt *int64 `json:"highest_voucher_expires_at"` + + // CloseRequestedAt is the Unix timestamp (seconds) when cooperative close + // was requested. Once set, no further vouchers are accepted. + CloseRequestedAt *uint64 `json:"close_requested_at"` + + // SettledSignature is the signature (base58) of the broadcast + // settle-and-distribute transaction. A close-pending channel with no + // settled signature is re-drivable: a close retry may attempt settlement + // again. + // + // An extension beyond the core channel-state shape, recorded only when + // this server drives on-chain settlement. Serialized with omitempty so a + // channel state without a settlement round-trips cleanly. + SettledSignature *string `json:"settled_signature,omitempty"` + + // Operator is the client wallet pubkey (base58) for pull-mode sessions; + // nil for push sessions. + Operator *string `json:"operator"` + + // NextDeliverySequence is the next server-side metered delivery sequence. + NextDeliverySequence uint64 `json:"next_delivery_sequence"` + + // PendingDeliveries are reserved by the server but not yet committed. + PendingDeliveries []PendingDelivery `json:"pending_deliveries"` + + // CommittedDeliveries are recently committed deliveries, kept for + // idempotent commit replay. + CommittedDeliveries []CommittedDelivery `json:"committed_deliveries"` +} + +// clone returns a deep copy so callers can never alias store-internal state. +func (s ChannelState) clone() ChannelState { + out := s + if s.HighestVoucherSignature != nil { + v := *s.HighestVoucherSignature + out.HighestVoucherSignature = &v + } + if s.HighestVoucherExpiresAt != nil { + v := *s.HighestVoucherExpiresAt + out.HighestVoucherExpiresAt = &v + } + if s.CloseRequestedAt != nil { + v := *s.CloseRequestedAt + out.CloseRequestedAt = &v + } + if s.SettledSignature != nil { + v := *s.SettledSignature + out.SettledSignature = &v + } + if s.Operator != nil { + v := *s.Operator + out.Operator = &v + } + if s.PendingDeliveries != nil { + out.PendingDeliveries = append([]PendingDelivery(nil), s.PendingDeliveries...) + } + if s.CommittedDeliveries != nil { + out.CommittedDeliveries = append([]CommittedDelivery(nil), s.CommittedDeliveries...) + } + return out +} + +// ListChannelsFilter is an optional filter for ChannelStore.ListChannels. +type ListChannelsFilter struct { + // Finalized, when non-nil, only includes channels matching this finalized + // state. + Finalized *bool + + // ClosePending, when non-nil, only includes channels whose + // CloseRequestedAt presence matches. + ClosePending *bool +} + +// ChannelMutator is handed to UpdateChannel. It receives the current state +// (nil if no channel exists) and returns the next state or an error, in which +// case the stored state is left unchanged. +// +// Implementations MUST guarantee the mutator runs without interleaving with +// other UpdateChannel calls for the same channel id. +type ChannelMutator func(current *ChannelState) (ChannelState, error) + +// ChannelStore is the pluggable store for per-channel session state. +// +// UpdateChannel is the only way to mutate a channel: the voucher verifier +// always needs an atomic read-modify-write to avoid double-spend under +// concurrent vouchers, so no direct put is exposed. +type ChannelStore interface { + // GetChannel reads a channel. Returns nil when it does not exist. + GetChannel(ctx context.Context, channelID string) (*ChannelState, error) + + // UpdateChannel atomically read-modify-writes a channel's state and + // returns the stored result. + UpdateChannel(ctx context.Context, channelID string, mutator ChannelMutator) (ChannelState, error) + + // DeleteChannel removes a channel from the store. Deleting a missing + // channel is a no-op. + DeleteChannel(ctx context.Context, channelID string) error + + // ListChannels returns a snapshot list. The filter is applied after read; + // nil means no filter. + ListChannels(ctx context.Context, filter *ListChannelsFilter) ([]ChannelState, error) + + // MarkFinalized flips Finalized to true. Errors when the channel is not + // found. + MarkFinalized(ctx context.Context, channelID string) (ChannelState, error) +} + +// MemoryChannelStore is an in-memory ChannelStore with per-channel locking: +// UpdateChannel calls for the same channel id run strictly sequentially while +// calls for different ids run concurrently. +type MemoryChannelStore struct { + // mu guards data and locks. + mu sync.Mutex + + // data maps channel id to stored state; values are cloned on the way in + // and out so callers never share memory with the store. + data map[string]ChannelState + + // locks holds the per-channel mutex serializing UpdateChannel calls for + // the same channel id. + locks map[string]*sync.Mutex +} + +// NewMemoryChannelStore creates an empty MemoryChannelStore. +func NewMemoryChannelStore() *MemoryChannelStore { + return &MemoryChannelStore{ + data: map[string]ChannelState{}, + locks: map[string]*sync.Mutex{}, + } +} + +// channelLock returns the mutex serializing updates for channelID. +func (s *MemoryChannelStore) channelLock(channelID string) *sync.Mutex { + s.mu.Lock() + defer s.mu.Unlock() + lock, ok := s.locks[channelID] + if !ok { + lock = &sync.Mutex{} + s.locks[channelID] = lock + } + return lock +} + +// GetChannel reads a channel. Returns nil when it does not exist. +func (s *MemoryChannelStore) GetChannel(_ context.Context, channelID string) (*ChannelState, error) { + s.mu.Lock() + defer s.mu.Unlock() + state, ok := s.data[channelID] + if !ok { + return nil, nil + } + out := state.clone() + return &out, nil +} + +// UpdateChannel atomically read-modify-writes a channel's state. A mutator +// error leaves the stored state unchanged and does not poison later updates. +func (s *MemoryChannelStore) UpdateChannel(_ context.Context, channelID string, mutator ChannelMutator) (ChannelState, error) { + lock := s.channelLock(channelID) + lock.Lock() + defer lock.Unlock() + + s.mu.Lock() + current, ok := s.data[channelID] + s.mu.Unlock() + + var currentPtr *ChannelState + if ok { + snapshot := current.clone() + currentPtr = &snapshot + } + next, err := mutator(currentPtr) + if err != nil { + return ChannelState{}, err + } + + s.mu.Lock() + s.data[channelID] = next.clone() + s.mu.Unlock() + return next, nil +} + +// DeleteChannel removes a channel from the store. +func (s *MemoryChannelStore) DeleteChannel(_ context.Context, channelID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.data, channelID) + return nil +} + +// ListChannels returns a snapshot of all channels matching the filter. +func (s *MemoryChannelStore) ListChannels(_ context.Context, filter *ListChannelsFilter) ([]ChannelState, error) { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]ChannelState, 0, len(s.data)) + for _, state := range s.data { + if filter != nil { + if filter.Finalized != nil && state.Finalized != *filter.Finalized { + continue + } + if filter.ClosePending != nil { + closePending := state.CloseRequestedAt != nil + if closePending != *filter.ClosePending { + continue + } + } + } + out = append(out, state.clone()) + } + return out, nil +} + +// MarkFinalized flips Finalized to true, erroring when the channel is +// missing. +func (s *MemoryChannelStore) MarkFinalized(ctx context.Context, channelID string) (ChannelState, error) { + return s.UpdateChannel(ctx, channelID, func(current *ChannelState) (ChannelState, error) { + if current == nil { + return ChannelState{}, fmt.Errorf("channel %s not found", channelID) + } + next := *current + next.Finalized = true + return next, nil + }) +} diff --git a/go/protocols/mpp/server/session_store_test.go b/go/protocols/mpp/server/session_store_test.go new file mode 100644 index 000000000..18cde8b27 --- /dev/null +++ b/go/protocols/mpp/server/session_store_test.go @@ -0,0 +1,269 @@ +package server + +// MemoryChannelStore coverage: insert-on-missing updates, mutator error +// handling, concurrent update serialization, list filtering, delete, +// finalization, and clone isolation. + +import ( + "context" + "errors" + "sync" + "testing" +) + +func testChannelState(channelID string, deposit uint64) ChannelState { + return ChannelState{ + ChannelID: channelID, + AuthorizedSigner: "11111111111111111111111111111111", + Deposit: deposit, + } +} + +func TestMemoryChannelStoreUpdateChannelInsertsWhenMissing(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + result, err := store.UpdateChannel(ctx, "c1", func(current *ChannelState) (ChannelState, error) { + if current != nil { + t.Fatalf("expected nil current state, got %+v", current) + } + return testChannelState("c1", 5), nil + }) + if err != nil { + t.Fatalf("UpdateChannel: %v", err) + } + if result.Deposit != 5 { + t.Fatalf("deposit = %d, want 5", result.Deposit) + } + + stored, err := store.GetChannel(ctx, "c1") + if err != nil || stored == nil { + t.Fatalf("GetChannel: state=%v err=%v", stored, err) + } + if stored.Deposit != 5 { + t.Fatalf("stored deposit = %d, want 5", stored.Deposit) + } +} + +func TestMemoryChannelStoreUpdateChannelSeesPriorWrites(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + return testChannelState("c1", 1), nil + }); err != nil { + t.Fatalf("insert: %v", err) + } + + next, err := store.UpdateChannel(ctx, "c1", func(current *ChannelState) (ChannelState, error) { + if current == nil || current.Deposit != 1 { + t.Fatalf("current = %+v, want deposit 1", current) + } + out := *current + out.Deposit = 2 + return out, nil + }) + if err != nil { + t.Fatalf("update: %v", err) + } + if next.Deposit != 2 { + t.Fatalf("deposit = %d, want 2", next.Deposit) + } +} + +func TestMemoryChannelStoreSerializesConcurrentUpdates(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + return testChannelState("c1", 1_000_000), nil + }); err != nil { + t.Fatalf("insert: %v", err) + } + + // Fire 50 concurrent increments; each must see the previous value. + const workers = 50 + var wg sync.WaitGroup + wg.Add(workers) + for range workers { + go func() { + defer wg.Done() + _, err := store.UpdateChannel(ctx, "c1", func(current *ChannelState) (ChannelState, error) { + out := *current + out.Cumulative++ + return out, nil + }) + if err != nil { + t.Errorf("concurrent update: %v", err) + } + }() + } + wg.Wait() + + stored, err := store.GetChannel(ctx, "c1") + if err != nil || stored == nil { + t.Fatalf("GetChannel: state=%v err=%v", stored, err) + } + if stored.Cumulative != workers { + t.Fatalf("cumulative = %d, want %d", stored.Cumulative, workers) + } +} + +func TestMemoryChannelStoreMutatorErrorLeavesStateUnchanged(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + state := testChannelState("c1", 1_000_000) + state.Cumulative = 7 + return state, nil + }); err != nil { + t.Fatalf("insert: %v", err) + } + + wantErr := errors.New("nope") + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + return ChannelState{}, wantErr + }); !errors.Is(err, wantErr) { + t.Fatalf("err = %v, want %v", err, wantErr) + } + + stored, err := store.GetChannel(ctx, "c1") + if err != nil || stored == nil { + t.Fatalf("GetChannel: state=%v err=%v", stored, err) + } + if stored.Cumulative != 7 || stored.Deposit != 1_000_000 { + t.Fatalf("state mutated by failed update: %+v", stored) + } + + // A failed update must not poison subsequent updates on the same channel. + next, err := store.UpdateChannel(ctx, "c1", func(current *ChannelState) (ChannelState, error) { + out := *current + out.Cumulative++ + return out, nil + }) + if err != nil { + t.Fatalf("follow-up update: %v", err) + } + if next.Cumulative != 8 { + t.Fatalf("cumulative = %d, want 8", next.Cumulative) + } +} + +func TestMemoryChannelStoreListChannelsAppliesFilters(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + mustInsert := func(state ChannelState) { + t.Helper() + if _, err := store.UpdateChannel(ctx, state.ChannelID, func(*ChannelState) (ChannelState, error) { + return state, nil + }); err != nil { + t.Fatalf("insert %s: %v", state.ChannelID, err) + } + } + mustInsert(testChannelState("a", 1)) + finalized := testChannelState("b", 1) + finalized.Finalized = true + mustInsert(finalized) + closing := testChannelState("c", 1) + closeAt := uint64(123) + closing.CloseRequestedAt = &closeAt + mustInsert(closing) + + all, err := store.ListChannels(ctx, nil) + if err != nil { + t.Fatalf("ListChannels: %v", err) + } + if len(all) != 3 { + t.Fatalf("len(all) = %d, want 3", len(all)) + } + + wantTrue, wantFalse := true, false + onlyFinalized, err := store.ListChannels(ctx, &ListChannelsFilter{Finalized: &wantTrue}) + if err != nil { + t.Fatalf("ListChannels finalized: %v", err) + } + if len(onlyFinalized) != 1 || onlyFinalized[0].ChannelID != "b" { + t.Fatalf("finalized filter = %+v, want only b", onlyFinalized) + } + + closePending, err := store.ListChannels(ctx, &ListChannelsFilter{Finalized: &wantFalse, ClosePending: &wantTrue}) + if err != nil { + t.Fatalf("ListChannels closePending: %v", err) + } + if len(closePending) != 1 || closePending[0].ChannelID != "c" { + t.Fatalf("closePending filter = %+v, want only c", closePending) + } +} + +func TestMemoryChannelStoreDeleteAndMarkFinalized(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + return testChannelState("c1", 1), nil + }); err != nil { + t.Fatalf("insert: %v", err) + } + + state, err := store.MarkFinalized(ctx, "c1") + if err != nil { + t.Fatalf("MarkFinalized: %v", err) + } + if !state.Finalized { + t.Fatal("expected finalized state") + } + stored, err := store.GetChannel(ctx, "c1") + if err != nil || stored == nil || !stored.Finalized { + t.Fatalf("stored state = %+v err=%v, want finalized", stored, err) + } + + if err := store.DeleteChannel(ctx, "c1"); err != nil { + t.Fatalf("DeleteChannel: %v", err) + } + missing, err := store.GetChannel(ctx, "c1") + if err != nil { + t.Fatalf("GetChannel after delete: %v", err) + } + if missing != nil { + t.Fatalf("expected nil after delete, got %+v", missing) + } + + if _, err := store.MarkFinalized(ctx, "ghost"); err == nil { + t.Fatal("expected error marking missing channel finalized") + } +} + +func TestMemoryChannelStoreReturnsClones(t *testing.T) { + store := NewMemoryChannelStore() + ctx := context.Background() + + signature := "sig" + if _, err := store.UpdateChannel(ctx, "c1", func(*ChannelState) (ChannelState, error) { + state := testChannelState("c1", 1) + state.HighestVoucherSignature = &signature + state.PendingDeliveries = []PendingDelivery{{DeliveryID: "c1:1", Amount: 1, Sequence: 1, ExpiresAt: 9}} + return state, nil + }); err != nil { + t.Fatalf("insert: %v", err) + } + + got, err := store.GetChannel(ctx, "c1") + if err != nil || got == nil { + t.Fatalf("GetChannel: state=%v err=%v", got, err) + } + *got.HighestVoucherSignature = "tampered" + got.PendingDeliveries[0].Amount = 99 + + fresh, err := store.GetChannel(ctx, "c1") + if err != nil || fresh == nil { + t.Fatalf("GetChannel: state=%v err=%v", fresh, err) + } + if *fresh.HighestVoucherSignature != "sig" { + t.Fatalf("stored signature mutated through returned pointer: %q", *fresh.HighestVoucherSignature) + } + if fresh.PendingDeliveries[0].Amount != 1 { + t.Fatalf("stored pending delivery mutated through returned slice: %+v", fresh.PendingDeliveries) + } +} diff --git a/go/protocols/mpp/server/session_stream.go b/go/protocols/mpp/server/session_stream.go new file mode 100644 index 000000000..99613c42e --- /dev/null +++ b/go/protocols/mpp/server/session_stream.go @@ -0,0 +1,133 @@ +package server + +// Server-side metered SSE stream writer. +// +// Emits the Server-Sent Event frames the metered session clients decode: +// "mpp.metering" directives, "mpp.usage" final-usage events, plain data +// payload messages, and the terminal "[DONE]" sentinel. The event names are +// canonical: they are the ones the SDK session clients parse (the Go +// client's SseDecoder/ParseMeteredSseEvent among them). + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// doneSentinel is the terminal data-only message recognized by the metered +// SSE decoders alongside the "done" event name. +const doneSentinel = "[DONE]" + +// MeteredStream writes metered Server-Sent Events to an HTTP response. Build +// with NewMeteredStream; every write flushes so chunks reach the client as +// they are produced. +type MeteredStream struct { + // writer receives the encoded SSE frames. + writer io.Writer + + // flusher, when non-nil, is flushed after every frame so chunks reach + // the client incrementally; nil writers buffer as usual. + flusher http.Flusher +} + +// NewMeteredStream prepares w for Server-Sent Events (Content-Type +// text/event-stream, no caching) and returns the stream writer. The +// ResponseWriter does not need to implement http.Flusher, but streaming is +// only incremental when it does. +func NewMeteredStream(w http.ResponseWriter) *MeteredStream { + header := w.Header() + header.Set("Content-Type", "text/event-stream") + header.Set("Cache-Control", "no-cache") + header.Set("Connection", "keep-alive") + flusher, _ := w.(http.Flusher) + return &MeteredStream{writer: w, flusher: flusher} +} + +// NewMeteredStreamWriter wraps a raw writer (no header handling) for +// transports other than net/http. +func NewMeteredStreamWriter(w io.Writer) *MeteredStream { + return &MeteredStream{writer: w} +} + +// WriteEvent writes one SSE frame with an explicit event name. Empty event +// names emit a default (message) frame. The data must not be empty; +// multi-line data is split into one data: line per line per the SSE format. +func (m *MeteredStream) WriteEvent(event string, data []byte) error { + if len(data) == 0 { + return fmt.Errorf("SSE event data must not be empty") + } + frame := "" + if event != "" { + frame = "event: " + event + "\n" + } + start := 0 + for i := 0; i <= len(data); i++ { + if i == len(data) || data[i] == '\n' { + frame += "data: " + string(data[start:i]) + "\n" + start = i + 1 + } + } + frame += "\n" + if _, err := io.WriteString(m.writer, frame); err != nil { + return err + } + if m.flusher != nil { + m.flusher.Flush() + } + return nil +} + +// WriteJSON writes a default (message) frame whose data is the JSON encoding +// of v. Use for application payload chunks. +func (m *MeteredStream) WriteJSON(v any) error { + data, err := json.Marshal(v) + if err != nil { + return err + } + return m.WriteEvent("", data) +} + +// WriteMetering emits an "mpp.metering" event carrying the metering +// directive the client must commit after processing the paired payload. +func (m *MeteredStream) WriteMetering(directive intents.MeteringDirective) error { + data, err := json.Marshal(directive) + if err != nil { + return err + } + return m.WriteEvent("mpp.metering", data) +} + +// WriteUsage emits an "mpp.usage" event reporting the final amount owed for +// a streamed delivery. The amount must not exceed the amount reserved by the +// original directive. +func (m *MeteredStream) WriteUsage(usage intents.MeteringUsage) error { + data, err := json.Marshal(usage) + if err != nil { + return err + } + return m.WriteEvent("mpp.usage", data) +} + +// WriteEnvelope emits the payload as a default data frame followed by its +// "mpp.metering" directive, the pairing the metered session consumers +// expect. +func (m *MeteredStream) WriteEnvelope(payload any, directive intents.MeteringDirective) error { + if err := m.WriteJSON(payload); err != nil { + return err + } + return m.WriteMetering(directive) +} + +// WriteDone emits the terminal "[DONE]" sentinel message. +func (m *MeteredStream) WriteDone() error { + return m.WriteEvent("", []byte(doneSentinel)) +} + +// WriteDoneEvent emits an explicit "done" event, the alternative terminal +// frame the decoders accept. +func (m *MeteredStream) WriteDoneEvent() error { + return m.WriteEvent("done", []byte(doneSentinel)) +} diff --git a/go/protocols/mpp/server/session_stream_test.go b/go/protocols/mpp/server/session_stream_test.go new file mode 100644 index 000000000..e24c24fcd --- /dev/null +++ b/go/protocols/mpp/server/session_stream_test.go @@ -0,0 +1,139 @@ +package server + +// Round-trips the server-side metered SSE writer through the client metered +// SSE decoder (SseDecoder + ParseMeteredSseEvent), proving the emitted frames +// carry the event names and payloads the metered session clients consume. + +import ( + "net/http/httptest" + "strings" + "testing" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/client" + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +func decodeMeteredEvents(t *testing.T, raw string) []client.MeteredSseEvent { + t.Helper() + decoder := &client.SseDecoder{} + events, err := decoder.PushChunk([]byte(raw)) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + tail, err := decoder.Finish() + if err != nil { + t.Fatalf("Finish: %v", err) + } + events = append(events, tail...) + parsed := make([]client.MeteredSseEvent, 0, len(events)) + for _, event := range events { + metered, err := client.ParseMeteredSseEvent(event) + if err != nil { + t.Fatalf("ParseMeteredSseEvent(%+v): %v", event, err) + } + parsed = append(parsed, metered) + } + return parsed +} + +func TestMeteredStreamRoundTripsThroughClientDecoder(t *testing.T) { + recorder := httptest.NewRecorder() + stream := NewMeteredStream(recorder) + + directive := intents.MeteringDirective{ + DeliveryID: "session-1:1", + SessionID: "session-1", + Amount: "100", + Currency: "USDC", + Sequence: 1, + ExpiresAt: intents.DefaultSessionExpiresAt, + } + usage := intents.MeteringUsage{DeliveryID: "session-1:1", Amount: "80"} + + if err := stream.WriteEnvelope(map[string]string{"chunk": "A payment channel "}, directive); err != nil { + t.Fatalf("WriteEnvelope: %v", err) + } + if err := stream.WriteUsage(usage); err != nil { + t.Fatalf("WriteUsage: %v", err) + } + if err := stream.WriteDone(); err != nil { + t.Fatalf("WriteDone: %v", err) + } + + if contentType := recorder.Header().Get("Content-Type"); contentType != "text/event-stream" { + t.Fatalf("Content-Type = %q", contentType) + } + if cacheControl := recorder.Header().Get("Cache-Control"); cacheControl != "no-cache" { + t.Fatalf("Cache-Control = %q", cacheControl) + } + if !recorder.Flushed { + t.Fatal("stream writes did not flush the response") + } + + events := decodeMeteredEvents(t, recorder.Body.String()) + if len(events) != 4 { + t.Fatalf("decoded %d events, want 4: %+v", len(events), events) + } + if events[0].Kind != client.MeteredSseEventMessage || !strings.Contains(string(events[0].Message), "A payment channel") { + t.Fatalf("event 0 = %+v", events[0]) + } + if events[1].Kind != client.MeteredSseEventMetering || events[1].Metering.DeliveryID != directive.DeliveryID { + t.Fatalf("event 1 = %+v", events[1]) + } + if events[1].Metering.Amount != "100" || events[1].Metering.Sequence != 1 { + t.Fatalf("metering payload = %+v", events[1].Metering) + } + if events[2].Kind != client.MeteredSseEventUsage || events[2].Usage.Amount != "80" { + t.Fatalf("event 2 = %+v", events[2]) + } + if events[3].Kind != client.MeteredSseEventDone { + t.Fatalf("event 3 = %+v", events[3]) + } +} + +func TestMeteredStreamDoneEventVariant(t *testing.T) { + recorder := httptest.NewRecorder() + stream := NewMeteredStream(recorder) + if err := stream.WriteDoneEvent(); err != nil { + t.Fatalf("WriteDoneEvent: %v", err) + } + events := decodeMeteredEvents(t, recorder.Body.String()) + if len(events) != 1 || events[0].Kind != client.MeteredSseEventDone { + t.Fatalf("events = %+v", events) + } +} + +func TestMeteredStreamSplitsMultiLineData(t *testing.T) { + var buffer strings.Builder + stream := NewMeteredStreamWriter(&buffer) + if err := stream.WriteEvent("note", []byte("line-1\nline-2")); err != nil { + t.Fatalf("WriteEvent: %v", err) + } + raw := buffer.String() + if raw != "event: note\ndata: line-1\ndata: line-2\n\n" { + t.Fatalf("frame = %q", raw) + } + + decoder := &client.SseDecoder{} + events, err := decoder.PushChunk([]byte(raw)) + if err != nil { + t.Fatalf("PushChunk: %v", err) + } + if len(events) != 1 || events[0].Data != "line-1\nline-2" { + t.Fatalf("events = %+v", events) + } +} + +func TestMeteredStreamRejectsEmptyData(t *testing.T) { + stream := NewMeteredStreamWriter(&strings.Builder{}) + if err := stream.WriteEvent("note", nil); err == nil { + t.Fatal("expected empty-data error") + } +} + +func TestMeteredStreamWriteJSONMarshalError(t *testing.T) { + stream := NewMeteredStreamWriter(&strings.Builder{}) + if err := stream.WriteJSON(func() {}); err == nil { + t.Fatal("expected marshal error") + } +} diff --git a/go/protocols/mpp/server/session_voucher.go b/go/protocols/mpp/server/session_voucher.go new file mode 100644 index 000000000..c417a6297 --- /dev/null +++ b/go/protocols/mpp/server/session_voucher.go @@ -0,0 +1,248 @@ +package server + +// Voucher verifier for the MPP session server. +// +// Pure function: given a current channel snapshot and a signed voucher, +// decide whether to accept (and what the new watermark would be), reject, +// or treat as an idempotent replay. The caller persists any accepted delta +// through ChannelStore.UpdateChannel, re-checking inside the atomic mutator. +// +// The check sequence (order and operators) is pinned across the language +// SDKs and harness-tested: +// parse u64 -> finalized -> close pending -> idempotent replay (same +// cumulative AND same signature, signature re-verified) -> cumulative > +// watermark strictly -> cumulative <= deposit -> delta >= minVoucherDelta -> +// Ed25519 verify against the stored authorizedSigner -> expiresAt > now. + +import ( + "crypto/ed25519" + "fmt" + "strconv" + "time" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +// VoucherVerifyStatus is the outcome class of a voucher verification. +type VoucherVerifyStatus string + +const ( + // VoucherVerifyAccepted means the voucher advanced the channel watermark. + VoucherVerifyAccepted VoucherVerifyStatus = "accepted" + + // VoucherVerifyReplayed means an already-accepted voucher was re-submitted + // (idempotent). + VoucherVerifyReplayed VoucherVerifyStatus = "replayed" + + // VoucherVerifyRejected means the voucher was rejected; see + // VoucherRejectReason. + VoucherVerifyRejected VoucherVerifyStatus = "rejected" +) + +// VoucherRejectReason is a stable string tag for voucher rejections so the +// caller can map to HTTP statuses / log levels without parsing free text. +// The tag values are stable across the language SDKs. +type VoucherRejectReason string + +const ( + // VoucherRejectBelowMinDelta: the delta is below the configured minimum. + VoucherRejectBelowMinDelta VoucherRejectReason = "below-min-delta" + + // VoucherRejectChannelClosePending: a close was already requested. + VoucherRejectChannelClosePending VoucherRejectReason = "channel-close-pending" + + // VoucherRejectChannelFinalized: the channel is already finalized. + VoucherRejectChannelFinalized VoucherRejectReason = "channel-finalized" + + // VoucherRejectCumulativeNotMonotonic: the cumulative does not strictly + // exceed the watermark. + VoucherRejectCumulativeNotMonotonic VoucherRejectReason = "cumulative-not-monotonic" + + // VoucherRejectExceedsDeposit: the cumulative exceeds the deposit cap. + VoucherRejectExceedsDeposit VoucherRejectReason = "exceeds-deposit" + + // VoucherRejectExpired: the voucher expiry is not in the future. + VoucherRejectExpired VoucherRejectReason = "expired" + + // VoucherRejectInvalidCumulative: the cumulative does not parse as a u64. + VoucherRejectInvalidCumulative VoucherRejectReason = "invalid-cumulative" + + // VoucherRejectInvalidSignature: the Ed25519 signature check failed. + VoucherRejectInvalidSignature VoucherRejectReason = "invalid-signature" +) + +// VoucherVerifyResult is the verdict of VerifyVoucherForChannel. +// +// Status selects which fields are meaningful: NewCumulative for accepted and +// replayed; NewExpiresAt and NewSignature for accepted only; Reason and +// Detail for rejected only. +type VoucherVerifyResult struct { + // Status is the outcome class. + Status VoucherVerifyStatus + + // NewCumulative is the watermark to persist (accepted) or the existing + // watermark (replayed). + NewCumulative uint64 + + // NewExpiresAt is the expiry of the now-highest voucher (accepted only). + NewExpiresAt int64 + + // NewSignature is the signature to persist as HighestVoucherSignature + // (accepted only, base58). + NewSignature string + + // Reason is the stable rejection tag (rejected only). + Reason VoucherRejectReason + + // Detail is a human-readable rejection detail. Safe to log; not stable. + Detail string +} + +// VerifyVoucherArgs are the inputs to VerifyVoucherForChannel. +type VerifyVoucherArgs struct { + // State is the channel snapshot, typically read just before calling. + State ChannelState + + // Signed is the voucher being submitted. + Signed intents.SignedVoucher + + // Deposit is the authoritative deposit cap. Passed in (rather than read + // off State) because some callers carry an updated cap after a recent + // top-up that has not yet been written back into the store. + Deposit uint64 + + // MinVoucherDelta is the optional minimum delta from the previous + // cumulative. Zero disables the check. + MinVoucherDelta uint64 + + // NowSeconds overrides the clock (Unix seconds) for deterministic tests. + // Nil defaults to time.Now(). + NowSeconds *int64 +} + +// VerifyVoucherForChannel verifies a voucher against a channel snapshot. +// +// Returns a verdict; the caller is responsible for persisting any accepted +// delta via ChannelStore.UpdateChannel. The verifier is pure: no store, +// network, or clock side effects (the clock is injectable). +func VerifyVoucherForChannel(args VerifyVoucherArgs) VoucherVerifyResult { + state := args.State + signed := args.Signed + + // 1. Parse new cumulative from the payload. + newCumulative, err := strconv.ParseUint(signed.Data.Cumulative, 10, 64) + if err != nil { + return voucherReject(VoucherRejectInvalidCumulative, + fmt.Sprintf("invalid cumulative in voucher: %s", signed.Data.Cumulative)) + } + + // 2. Channel must not be finalized. + if state.Finalized { + return voucherReject(VoucherRejectChannelFinalized, + fmt.Sprintf("channel %s is already finalized", state.ChannelID)) + } + + // 3. Channel must not be in close-pending. + if state.CloseRequestedAt != nil { + return voucherReject(VoucherRejectChannelClosePending, + fmt.Sprintf("channel %s close is pending; no further vouchers accepted", state.ChannelID)) + } + + // 4. Idempotent replay: same cumulative AND same signature. The signature + // is re-verified so a replay of a forged voucher cannot slip through. + if newCumulative == state.Cumulative && + state.HighestVoucherSignature != nil && + *state.HighestVoucherSignature == signed.Signature { + if err := verifyVoucherSignatureBytes(signed, state.AuthorizedSigner); err != nil { + return voucherReject(VoucherRejectInvalidSignature, err.Error()) + } + if signed.Data.ExpiresAt <= voucherNow(args.NowSeconds) { + return voucherReject(VoucherRejectExpired, "voucher has expired") + } + return VoucherVerifyResult{Status: VoucherVerifyReplayed, NewCumulative: newCumulative} + } + + // 5. Must strictly exceed the watermark (non-replay case). + if newCumulative <= state.Cumulative { + return voucherReject(VoucherRejectCumulativeNotMonotonic, + fmt.Sprintf("voucher cumulative %d must exceed watermark %d", newCumulative, state.Cumulative)) + } + + // 6. Must not exceed the deposit. + if newCumulative > args.Deposit { + return voucherReject(VoucherRejectExceedsDeposit, + fmt.Sprintf("voucher cumulative %d exceeds deposit %d", newCumulative, args.Deposit)) + } + + // 7. Min delta check. + delta := newCumulative - state.Cumulative + if args.MinVoucherDelta > 0 && delta < args.MinVoucherDelta { + return voucherReject(VoucherRejectBelowMinDelta, + fmt.Sprintf("voucher delta %d is below minimum %d", delta, args.MinVoucherDelta)) + } + + // 8. Verify the Ed25519 signature over the 48-byte canonical payload. + if err := verifyVoucherSignatureBytes(signed, state.AuthorizedSigner); err != nil { + return voucherReject(VoucherRejectInvalidSignature, err.Error()) + } + + // 9. Expiry. The caller may override NowSeconds for deterministic tests. + if signed.Data.ExpiresAt <= voucherNow(args.NowSeconds) { + return voucherReject(VoucherRejectExpired, "voucher has expired") + } + + return VoucherVerifyResult{ + Status: VoucherVerifyAccepted, + NewCumulative: newCumulative, + NewExpiresAt: signed.Data.ExpiresAt, + NewSignature: signed.Signature, + } +} + +// voucherReject builds a rejected verdict. +func voucherReject(reason VoucherRejectReason, detail string) VoucherVerifyResult { + return VoucherVerifyResult{Status: VoucherVerifyRejected, Reason: reason, Detail: detail} +} + +// voucherNow returns the override when set, otherwise the wall clock in Unix +// seconds. +func voucherNow(override *int64) int64 { + if override != nil { + return *override + } + return time.Now().Unix() +} + +// verifyVoucherSignatureBytes checks the voucher's Ed25519 signature over the +// canonical 48-byte voucher payload against the authorized signer (both +// base58). The expiry check is not included; callers order it explicitly. +func verifyVoucherSignatureBytes(signed intents.SignedVoucher, authorizedSigner string) error { + message, err := signed.Data.MessageBytes() + if err != nil { + return err + } + signature, err := solana.SignatureFromBase58(signed.Signature) + if err != nil { + return fmt.Errorf("invalid signature encoding: %w", err) + } + pubkey, err := solana.PublicKeyFromBase58(authorizedSigner) + if err != nil { + return fmt.Errorf("invalid authorized signer: %w", err) + } + if !ed25519.Verify(ed25519.PublicKey(pubkey.Bytes()), message, signature[:]) { + return fmt.Errorf("voucher signature verification failed") + } + return nil +} + +// verifySessionVoucher checks expiry first (against the wall clock), then +// the Ed25519 signature. Used by the commit and close paths; the voucher +// handler orders the two checks itself. +func verifySessionVoucher(signed intents.SignedVoucher, authorizedSigner string) error { + if signed.Data.ExpiresAt <= time.Now().Unix() { + return fmt.Errorf("voucher has expired") + } + return verifyVoucherSignatureBytes(signed, authorizedSigner) +} diff --git a/go/protocols/mpp/server/session_voucher_test.go b/go/protocols/mpp/server/session_voucher_test.go new file mode 100644 index 000000000..f1cd469fc --- /dev/null +++ b/go/protocols/mpp/server/session_voucher_test.go @@ -0,0 +1,359 @@ +package server + +// Voucher verifier coverage plus adversarial ordering checks: the check +// sequence (order and operators) is part of the wire contract. + +import ( + "crypto/ed25519" + "crypto/rand" + "strconv" + "testing" + "time" + + solana "github.com/gagliardetto/solana-go" + + "github.com/solana-foundation/pay-kit/go/protocols/mpp/intents" +) + +const testVoucherChannelID = "11111111111111111111111111111111" + +// testVoucherSigner is an in-memory Ed25519 keypair for voucher tests. +type testVoucherSigner struct { + pub ed25519.PublicKey // verify key; its base58 form is the channel's authorized signer + priv ed25519.PrivateKey // signing key for the canonical 48-byte voucher preimage +} + +func newTestVoucherSigner(t *testing.T) testVoucherSigner { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("generate keypair: %v", err) + } + return testVoucherSigner{pub: pub, priv: priv} +} + +// Address returns the signer pubkey as base58. +func (s testVoucherSigner) Address() string { + return solana.PublicKeyFromBytes(s.pub).String() +} + +// SignVoucher signs the canonical 48-byte voucher payload. +func (s testVoucherSigner) SignVoucher(t *testing.T, channelID string, cumulative uint64, expiresAt int64) intents.SignedVoucher { + t.Helper() + data := intents.VoucherData{ + ChannelID: channelID, + Cumulative: strconv.FormatUint(cumulative, 10), + ExpiresAt: expiresAt, + } + message, err := data.MessageBytes() + if err != nil { + t.Fatalf("voucher message bytes: %v", err) + } + signature := ed25519.Sign(s.priv, message) + return intents.SignedVoucher{Data: data, Signature: solana.SignatureFromBytes(signature).String()} +} + +func farFuture() int64 { + return time.Now().Unix() + 3600 +} + +func voucherTestState(authorizedSigner string) ChannelState { + return ChannelState{ + ChannelID: testVoucherChannelID, + AuthorizedSigner: authorizedSigner, + Deposit: 1_000, + } +} + +func TestVerifyVoucherForChannelHappyPath(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + expiresAt := farFuture() + voucher := signer.SignVoucher(t, state.ChannelID, 100, expiresAt) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: state.Deposit}) + if result.Status != VoucherVerifyAccepted { + t.Fatalf("status = %s (%s: %s), want accepted", result.Status, result.Reason, result.Detail) + } + if result.NewCumulative != 100 { + t.Fatalf("newCumulative = %d, want 100", result.NewCumulative) + } + if result.NewSignature != voucher.Signature { + t.Fatalf("newSignature = %q, want voucher signature", result.NewSignature) + } + if result.NewExpiresAt != expiresAt { + t.Fatalf("newExpiresAt = %d, want %d", result.NewExpiresAt, expiresAt) + } +} + +func TestVerifyVoucherForChannelIdempotentReplay(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + state := voucherTestState(signer.Address()) + state.Cumulative = 100 + state.HighestVoucherSignature = &voucher.Signature + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyReplayed { + t.Fatalf("status = %s, want replayed", result.Status) + } + if result.NewCumulative != 100 { + t.Fatalf("newCumulative = %d, want 100", result.NewCumulative) + } +} + +func TestVerifyVoucherForChannelReplayReVerifiesSignature(t *testing.T) { + signer := newTestVoucherSigner(t) + forger := newTestVoucherSigner(t) + // A forged voucher whose signature somehow got persisted as the highest: + // the replay path must still reject it on signature re-verification. + forged := forger.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + state := voucherTestState(signer.Address()) + state.Cumulative = 100 + state.HighestVoucherSignature = &forged.Signature + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: forged, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectInvalidSignature { + t.Fatalf("result = %+v, want invalid-signature rejection", result) + } +} + +func TestVerifyVoucherForChannelReplayOfExpiredVoucherRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + past := time.Now().Unix() - 10 + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, past) + state := voucherTestState(signer.Address()) + state.Cumulative = 100 + state.HighestVoucherSignature = &voucher.Signature + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectExpired { + t.Fatalf("result = %+v, want expired rejection", result) + } +} + +func TestVerifyVoucherForChannelDecreasingCumulativeRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 50, farFuture()) + state := voucherTestState(signer.Address()) + state.Cumulative = 100 + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectCumulativeNotMonotonic { + t.Fatalf("result = %+v, want cumulative-not-monotonic rejection", result) + } +} + +func TestVerifyVoucherForChannelEqualCumulativeWithoutMatchingSignatureRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + otherSignature := "5J6vbXSpEpGv4VLLqDhuRG6Tbj5n6dgEgvtTwTKpoSjvSwLTW9PSqQc6dpMUDPCvD3KZ5dGsmiTk5jzwYZyD8Xkz" + state := voucherTestState(signer.Address()) + state.Cumulative = 100 + state.HighestVoucherSignature = &otherSignature + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectCumulativeNotMonotonic { + t.Fatalf("result = %+v, want cumulative-not-monotonic rejection", result) + } +} + +func TestVerifyVoucherForChannelExceedsDepositRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 2_000, farFuture()) + state := voucherTestState(signer.Address()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectExceedsDeposit { + t.Fatalf("result = %+v, want exceeds-deposit rejection", result) + } +} + +func TestVerifyVoucherForChannelBelowMinDeltaRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 5, farFuture()) + state := voucherTestState(signer.Address()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{ + State: state, + Signed: voucher, + Deposit: 1_000, + MinVoucherDelta: 100, + }) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectBelowMinDelta { + t.Fatalf("result = %+v, want below-min-delta rejection", result) + } +} + +func TestVerifyVoucherForChannelBadSignatureRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + other := newTestVoucherSigner(t) + // Sign with other, but the channel authorizes signer; sig must fail. + voucher := other.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + state := voucherTestState(signer.Address()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectInvalidSignature { + t.Fatalf("result = %+v, want invalid-signature rejection", result) + } +} + +func TestVerifyVoucherForChannelExpiredRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, time.Now().Unix()-10) + state := voucherTestState(signer.Address()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectExpired { + t.Fatalf("result = %+v, want expired rejection", result) + } +} + +func TestVerifyVoucherForChannelFinalizedRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + state := voucherTestState(signer.Address()) + state.Finalized = true + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectChannelFinalized { + t.Fatalf("result = %+v, want channel-finalized rejection", result) + } +} + +func TestVerifyVoucherForChannelClosePendingRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + state := voucherTestState(signer.Address()) + closeAt := uint64(1) + state.CloseRequestedAt = &closeAt + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectChannelClosePending { + t.Fatalf("result = %+v, want channel-close-pending rejection", result) + } +} + +func TestVerifyVoucherForChannelNowSecondsOverrideIsDeterministic(t *testing.T) { + signer := newTestVoucherSigner(t) + voucher := signer.SignVoucher(t, testVoucherChannelID, 100, 1_000) + state := voucherTestState(signer.Address()) + + late := int64(2_000) + expired := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000, NowSeconds: &late}) + if expired.Status != VoucherVerifyRejected || expired.Reason != VoucherRejectExpired { + t.Fatalf("result = %+v, want expired rejection at now=2000", expired) + } + + early := int64(500) + fresh := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000, NowSeconds: &early}) + if fresh.Status != VoucherVerifyAccepted { + t.Fatalf("result = %+v, want accepted at now=500", fresh) + } +} + +func TestVerifyVoucherForChannelInvalidCumulativeRejected(t *testing.T) { + signer := newTestVoucherSigner(t) + real := signer.SignVoucher(t, testVoucherChannelID, 100, farFuture()) + // Tamper the data field after signing; the verifier should reject on + // parse before the signature check. + tampered := intents.SignedVoucher{ + Data: intents.VoucherData{ + ChannelID: real.Data.ChannelID, + Cumulative: "not-a-number", + ExpiresAt: real.Data.ExpiresAt, + }, + Signature: real.Signature, + } + state := voucherTestState(signer.Address()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: tampered, Deposit: 1_000}) + if result.Status != VoucherVerifyRejected || result.Reason != VoucherRejectInvalidCumulative { + t.Fatalf("result = %+v, want invalid-cumulative rejection", result) + } +} + +// Ordering checks: each earlier step must win over every later failure +// present in the same voucher. + +func TestVerifyVoucherForChannelOrderingParseBeatsFinalized(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + state.Finalized = true + voucher := intents.SignedVoucher{ + Data: intents.VoucherData{ChannelID: state.ChannelID, Cumulative: "bogus", ExpiresAt: farFuture()}, + Signature: "sig", + } + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Reason != VoucherRejectInvalidCumulative { + t.Fatalf("reason = %s, want invalid-cumulative before channel-finalized", result.Reason) + } +} + +func TestVerifyVoucherForChannelOrderingFinalizedBeatsClosePending(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + state.Finalized = true + closeAt := uint64(1) + state.CloseRequestedAt = &closeAt + voucher := signer.SignVoucher(t, state.ChannelID, 100, farFuture()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Reason != VoucherRejectChannelFinalized { + t.Fatalf("reason = %s, want channel-finalized before channel-close-pending", result.Reason) + } +} + +func TestVerifyVoucherForChannelOrderingMonotonicBeatsDeposit(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + state.Deposit = 10 + state.Cumulative = 100 + // Non-monotonic AND over deposit: monotonicity is checked first. + voucher := signer.SignVoucher(t, state.ChannelID, 50, farFuture()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 10}) + if result.Reason != VoucherRejectCumulativeNotMonotonic { + t.Fatalf("reason = %s, want cumulative-not-monotonic before exceeds-deposit", result.Reason) + } +} + +func TestVerifyVoucherForChannelOrderingDepositBeatsMinDelta(t *testing.T) { + signer := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + state.Deposit = 10 + // Over deposit AND below min delta relative to a large min: deposit wins. + voucher := signer.SignVoucher(t, state.ChannelID, 20, farFuture()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 10, MinVoucherDelta: 100}) + if result.Reason != VoucherRejectExceedsDeposit { + t.Fatalf("reason = %s, want exceeds-deposit before below-min-delta", result.Reason) + } +} + +func TestVerifyVoucherForChannelOrderingMinDeltaBeatsSignature(t *testing.T) { + signer := newTestVoucherSigner(t) + other := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + // Below min delta AND wrongly signed: min delta is checked first. + voucher := other.SignVoucher(t, state.ChannelID, 5, farFuture()) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000, MinVoucherDelta: 100}) + if result.Reason != VoucherRejectBelowMinDelta { + t.Fatalf("reason = %s, want below-min-delta before invalid-signature", result.Reason) + } +} + +func TestVerifyVoucherForChannelOrderingSignatureBeatsExpiry(t *testing.T) { + signer := newTestVoucherSigner(t) + other := newTestVoucherSigner(t) + state := voucherTestState(signer.Address()) + // Wrongly signed AND expired: the signature is verified before expiry. + voucher := other.SignVoucher(t, state.ChannelID, 100, time.Now().Unix()-10) + + result := VerifyVoucherForChannel(VerifyVoucherArgs{State: state, Signed: voucher, Deposit: 1_000}) + if result.Reason != VoucherRejectInvalidSignature { + t.Fatalf("reason = %s, want invalid-signature before expired", result.Reason) + } +} diff --git a/go/protocols/mpp/wire/types.go b/go/protocols/mpp/wire/types.go index c81f3fc08..d3ff9d056 100644 --- a/go/protocols/mpp/wire/types.go +++ b/go/protocols/mpp/wire/types.go @@ -36,8 +36,14 @@ func NewIntentName(name string) IntentName { return IntentName(strings.ToLower(n // IsCharge returns whether the intent is the charge intent. func (i IntentName) IsCharge() bool { return strings.EqualFold(string(i), "charge") } +// IsSession returns whether the intent is the session intent. +func (i IntentName) IsSession() bool { return strings.EqualFold(string(i), "session") } + // Base64URLJSON preserves a base64url-encoded JSON blob. type Base64URLJSON struct { + // raw is the base64url-encoded JSON kept verbatim as it appeared on + // the wire (never re-encoded), so the HMAC challenge ID computed over + // it stays byte-stable; "" means the value is absent. raw string } diff --git a/go/protocols/mpp/wire/types_test.go b/go/protocols/mpp/wire/types_test.go index 3cc512c8d..7e42d073a 100644 --- a/go/protocols/mpp/wire/types_test.go +++ b/go/protocols/mpp/wire/types_test.go @@ -119,6 +119,15 @@ func TestIntentNameIsCharge(t *testing.T) { } } +func TestIntentNameIsSession(t *testing.T) { + if !NewIntentName("Session").IsSession() { + t.Fatal("expected session intent") + } + if NewIntentName("charge").IsSession() { + t.Fatal("charge must not be a session intent") + } +} + func TestMethodNameInvalid(t *testing.T) { tests := []struct { name string diff --git a/go/protocols/programs/paymentchannels/account_channel.go b/go/protocols/programs/paymentchannels/account_channel.go new file mode 100644 index 000000000..b3f63feaf --- /dev/null +++ b/go/protocols/programs/paymentchannels/account_channel.go @@ -0,0 +1,251 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +type Channel struct { + // Discriminator is the single-byte account tag stored at offset 0; + // AccountDiscriminator_Channel (0) marks a live channel account. + Discriminator uint8 + // Version is the channel account layout version byte, allowing the + // program to migrate the on-chain struct shape. + Version uint8 + // Bump is the canonical PDA bump seed for the channel address derived + // from ["channel", payer, payee, mint, authorizedSigner, salt as + // little-endian u64]. + Bump uint8 + // Status is the channel lifecycle state as a ChannelStatus byte: + // 0 Open, 1 Finalized, 2 Closing. + Status uint8 + // Salt is the caller-chosen u64 mixed little-endian into the channel PDA + // seeds, letting one payer/payee/mint/signer tuple open many channels. + Salt uint64 + // Deposit is the total locked in the channel (open deposit plus topUps), + // in base units of Mint (e.g. 6 decimals for USDC). + Deposit uint64 + // Settlement holds the cumulative settled and paid-out watermarks + // accrued over the channel lifetime. + Settlement SettlementWatermarks + // ClosureStartedAt is the Unix timestamp (seconds) when channel closure + // was requested, starting the grace period; zero while no closure is + // pending. + ClosureStartedAt int64 + // PayerWithdrawnAt is the Unix timestamp (seconds) when the payer + // withdrew the unsettled remainder; zero until that withdrawal happens. + PayerWithdrawnAt int64 + // GracePeriod is the close grace period in seconds: the window after + // closure starts during which outstanding vouchers can still be settled. + GracePeriod uint32 + // DistributionHash is the 32-byte blake3 commitment, fixed at open, over + // the recipient split list (count as little-endian u32, then each + // recipient pubkey followed by its bps as little-endian u16); the + // distribute instruction must supply a list hashing to this value. + DistributionHash [32]uint8 + // Payer is the raw 32-byte public key of the wallet that funded the + // channel and is refunded the unsettled remainder. + Payer ag_solanago.PublicKey + // Payee is the raw 32-byte public key of the primary payment recipient. + Payee ag_solanago.PublicKey + // AuthorizedSigner is the public key whose Ed25519 signature over the + // 48-byte voucher preimage authorizes settlements against this channel. + AuthorizedSigner ag_solanago.PublicKey + // Mint is the SPL token mint locked in the channel; all channel amounts + // are denominated in this mint's base units. + Mint ag_solanago.PublicKey +} + +func (obj *Channel) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(obj.Discriminator) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Version) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Bump) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Status) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Salt) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Deposit) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Settlement) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.ClosureStartedAt) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.PayerWithdrawnAt) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.GracePeriod) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.DistributionHash) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Payer) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Payee) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.AuthorizedSigner) + if err != nil { + return err + } + } + { + err := encoder.Encode(obj.Mint) + if err != nil { + return err + } + } + return nil +} + +func (obj *Channel) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + err := decoder.Decode(&obj.Discriminator) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Version) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Bump) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Status) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Salt) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Deposit) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Settlement) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.ClosureStartedAt) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.PayerWithdrawnAt) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.GracePeriod) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.DistributionHash) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Payer) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Payee) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.AuthorizedSigner) + if err != nil { + return err + } + } + { + err := decoder.Decode(&obj.Mint) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/account_closed_channel.go b/go/protocols/programs/paymentchannels/account_closed_channel.go new file mode 100644 index 000000000..7defb294b --- /dev/null +++ b/go/protocols/programs/paymentchannels/account_closed_channel.go @@ -0,0 +1,38 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_binary "github.com/gagliardetto/binary" +) + +type ClosedChannel struct { + // Discriminator is the single-byte account tag stored at offset 0; + // AccountDiscriminator_ClosedChannel (1) marks a channel account that + // has been closed, distinguishing it from a live Channel. + Discriminator uint8 +} + +func (obj *ClosedChannel) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(obj.Discriminator) + if err != nil { + return err + } + } + return nil +} + +func (obj *ClosedChannel) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + err := decoder.Decode(&obj.Discriminator) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/errors.go b/go/protocols/programs/paymentchannels/errors.go new file mode 100644 index 000000000..993815ffe --- /dev/null +++ b/go/protocols/programs/paymentchannels/errors.go @@ -0,0 +1,268 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" +) + +// PaymentChannelsError represents a program-specific error. +type PaymentChannelsError uint32 + +const ( + // NotImplementedError - Not implemented + NotImplementedError PaymentChannelsError = 0x0 + // MissingRequiredSignatureError - A signature was required but not found + MissingRequiredSignatureError PaymentChannelsError = 0x1 + // InvalidChannelStatusError - Invalid channel status + InvalidChannelStatusError PaymentChannelsError = 0x2 + // InvalidAccountDiscriminatorError - Invalid account discriminator + InvalidAccountDiscriminatorError PaymentChannelsError = 0x3 + // UnsupportedChannelVersionError - Unsupported channel version + UnsupportedChannelVersionError PaymentChannelsError = 0x4 + // InvalidChannelPayerError - Account does not match channel payer + InvalidChannelPayerError PaymentChannelsError = 0x5 + // InvalidChannelPayeeError - Account does not match channel payee + InvalidChannelPayeeError PaymentChannelsError = 0x6 + // InvalidChannelMintError - Account does not match channel mint + InvalidChannelMintError PaymentChannelsError = 0x7 + // InvalidEventAuthorityError - Invalid event authority + InvalidEventAuthorityError PaymentChannelsError = 0x8 + // NotEnoughAccountKeysError - Not enough accounts were provided + NotEnoughAccountKeysError PaymentChannelsError = 0x9 + // ChannelAccountMismatchError - Channel account does not match derived PDA + ChannelAccountMismatchError PaymentChannelsError = 0x32 + // InvalidChannelTokenAccountError - Channel token account is not ATA(channel, mint, token_program) + InvalidChannelTokenAccountError PaymentChannelsError = 0x33 + // InvalidChannelTokenExtensionsError - Channel token account has invalid extensions + InvalidChannelTokenExtensionsError PaymentChannelsError = 0x34 + // MintAccountMismatchError - Mint account does not match channel.mint + MintAccountMismatchError PaymentChannelsError = 0x35 + // InvalidMintTokenProgramError - Token program must be SPL Token or Token-2022 + InvalidMintTokenProgramError PaymentChannelsError = 0x36 + // MalformedMintTokenAccountDataError - Token account or mint TLV trailer is malformed + MalformedMintTokenAccountDataError PaymentChannelsError = 0x37 + // MalformedMintTokenExtensionsError - Token account or mint TLV trailer is malformed + MalformedMintTokenExtensionsError PaymentChannelsError = 0x38 + // PayerAccountMismatchError - Payer token account is not ATA(payer, token_program, mint) + PayerAccountMismatchError PaymentChannelsError = 0x39 + // InvalidPayerTokenAccountError - Payer token account is invalid + InvalidPayerTokenAccountError PaymentChannelsError = 0x3A + // InvalidPayerTokenExtensionsError - Payer token account has invalid extensions + InvalidPayerTokenExtensionsError PaymentChannelsError = 0x3B + // PayeeAccountMismatchError - Payee token account is not ATA(payee, token_program, mint) + PayeeAccountMismatchError PaymentChannelsError = 0x3C + // InvalidPayeeTokenAccountError - Payee token account is invalid + InvalidPayeeTokenAccountError PaymentChannelsError = 0x3D + // InvalidPayeeTokenExtensionsError - Payee token account has invalid extensions + InvalidPayeeTokenExtensionsError PaymentChannelsError = 0x3E + // DepositMustBeNonZeroError - Deposit must be non-zero + DepositMustBeNonZeroError PaymentChannelsError = 0xC8 + // GracePeriodMustBeNonZeroError - Grace period must be non-zero + GracePeriodMustBeNonZeroError PaymentChannelsError = 0xC9 + // MissingEd25519VerificationError - Missing Ed25519 precompile ix at current-1 + MissingEd25519VerificationError PaymentChannelsError = 0xE6 + // MalformedEd25519InstructionError - Malformed Ed25519 precompile instruction + MalformedEd25519InstructionError PaymentChannelsError = 0xE7 + // VoucherChannelMismatchError - Voucher channel_id does not match channel PDA + VoucherChannelMismatchError PaymentChannelsError = 0xE8 + // VoucherExpiredError - Voucher expired + VoucherExpiredError PaymentChannelsError = 0xE9 + // VoucherWatermarkNotMonotonicError - Voucher watermark not strictly monotonic + VoucherWatermarkNotMonotonicError PaymentChannelsError = 0xEA + // VoucherOverDepositError - Voucher cumulative_amount exceeds channel deposit + VoucherOverDepositError PaymentChannelsError = 0xEB + // VoucherMessageMismatchError - Ed25519 message does not match Borsh voucher payload + VoucherMessageMismatchError PaymentChannelsError = 0xEC + // VoucherSignerMismatchError - Voucher signer does not match channel authorized_signer + VoucherSignerMismatchError PaymentChannelsError = 0xED + // InvalidRecipientCountError - num_recipients outside [0, 32] + InvalidRecipientCountError PaymentChannelsError = 0x104 + // InvalidSplitConfigError - Each shareBps must be non-zero and Σbps must be at most 10_000 + InvalidSplitConfigError PaymentChannelsError = 0x105 + // DistributionPartsOverflowError - num_recipients outside [0, 32] + DistributionPartsOverflowError PaymentChannelsError = 0x106 + // DuplicateRecipientError - Distribution plan contains a duplicate recipient address + DuplicateRecipientError PaymentChannelsError = 0x107 + // DistributionAmountOverflowError - num_recipients outside [0, 32] + DistributionAmountOverflowError PaymentChannelsError = 0x108 + // DistributionPreimageLengthOverflowError - Distribution preimage length calculation overflow + DistributionPreimageLengthOverflowError PaymentChannelsError = 0x109 + // ChannelAddressMismatchError - Derived channel account address does not match the user provided address + ChannelAddressMismatchError PaymentChannelsError = 0x7D0 + // PayerPayeeMustDifferError - Payer and payee must be different accounts + PayerPayeeMustDifferError PaymentChannelsError = 0x7D1 + // InvalidAuthorizedSignerError - authorized_signer must be a valid Ed25519 public key + InvalidAuthorizedSignerError PaymentChannelsError = 0x7D2 + // TopUpDepositOverflowError - Deposit must be non-zero + TopUpDepositOverflowError PaymentChannelsError = 0x834 + // FinalizeDeadlineOverflowError - Deadline overflow on grace period + FinalizeDeadlineOverflowError PaymentChannelsError = 0x898 + // PayerAlreadyWithdrawnError - Payer refund has already been claimed + PayerAlreadyWithdrawnError PaymentChannelsError = 0x8FC + // RefundCalculationOverflowError - Payer refund amount calculation underflow + RefundCalculationOverflowError PaymentChannelsError = 0x8FD + // ChannelNotDistributableError - Channel is not in OPEN or FINALIZED + ChannelNotDistributableError PaymentChannelsError = 0x960 + // TreasuryAccountMismatchError - Treasury token account is not ATA(TREASURY_OWNER, mint, token_program) + TreasuryAccountMismatchError PaymentChannelsError = 0x961 + // InvalidTreasuryTokenAccountError - Treasury token account is invalid + InvalidTreasuryTokenAccountError PaymentChannelsError = 0x962 + // InvalidTreasuryTokenExtensionsError - Treasury token account has invalid extensions + InvalidTreasuryTokenExtensionsError PaymentChannelsError = 0x963 + // RecipientAccountMismatchError - Recipient token account is not ATA(recipient, token_program, mint) + RecipientAccountMismatchError PaymentChannelsError = 0x964 + // InvalidRecipientTokenAccountError - Recipient token account is invalid + InvalidRecipientTokenAccountError PaymentChannelsError = 0x965 + // InvalidRecipientTokenExtensionsError - Recipient token account has invalid extensions + InvalidRecipientTokenExtensionsError PaymentChannelsError = 0x966 + // InvalidDistributionHashError - Distribution hash mismatch + InvalidDistributionHashError PaymentChannelsError = 0x967 + // NothingToDistributeError - No newly settled funds to distribute + NothingToDistributeError PaymentChannelsError = 0x968 + // RecipientAccountCountMismatchError - Recipient ATA tail length does not match the committed plan's entry count + RecipientAccountCountMismatchError PaymentChannelsError = 0x969 + // DistributePoolOverflowError - Distribution pool calculation underflow + DistributePoolOverflowError PaymentChannelsError = 0x96A + // DistributeBalanceCalculationOverflowError - Channel rent rebalance calculation underflow + DistributeBalanceCalculationOverflowError PaymentChannelsError = 0x96B + // DistributePayerBalanceOverflowError - Payer lamports overflow on rent refund + DistributePayerBalanceOverflowError PaymentChannelsError = 0x96C + // DistributeTransferQueueOverflowError - Transfer queue capacity exceeded + DistributeTransferQueueOverflowError PaymentChannelsError = 0x96D +) + +func (e PaymentChannelsError) Error() string { + switch e { + case NotImplementedError: + return "Not implemented" + case MissingRequiredSignatureError: + return "A signature was required but not found" + case InvalidChannelStatusError: + return "Invalid channel status" + case InvalidAccountDiscriminatorError: + return "Invalid account discriminator" + case UnsupportedChannelVersionError: + return "Unsupported channel version" + case InvalidChannelPayerError: + return "Account does not match channel payer" + case InvalidChannelPayeeError: + return "Account does not match channel payee" + case InvalidChannelMintError: + return "Account does not match channel mint" + case InvalidEventAuthorityError: + return "Invalid event authority" + case NotEnoughAccountKeysError: + return "Not enough accounts were provided" + case ChannelAccountMismatchError: + return "Channel account does not match derived PDA" + case InvalidChannelTokenAccountError: + return "Channel token account is not ATA(channel, mint, token_program)" + case InvalidChannelTokenExtensionsError: + return "Channel token account has invalid extensions" + case MintAccountMismatchError: + return "Mint account does not match channel.mint" + case InvalidMintTokenProgramError: + return "Token program must be SPL Token or Token-2022" + case MalformedMintTokenAccountDataError: + return "Token account or mint TLV trailer is malformed" + case MalformedMintTokenExtensionsError: + return "Token account or mint TLV trailer is malformed" + case PayerAccountMismatchError: + return "Payer token account is not ATA(payer, token_program, mint)" + case InvalidPayerTokenAccountError: + return "Payer token account is invalid" + case InvalidPayerTokenExtensionsError: + return "Payer token account has invalid extensions" + case PayeeAccountMismatchError: + return "Payee token account is not ATA(payee, token_program, mint)" + case InvalidPayeeTokenAccountError: + return "Payee token account is invalid" + case InvalidPayeeTokenExtensionsError: + return "Payee token account has invalid extensions" + case DepositMustBeNonZeroError: + return "Deposit must be non-zero" + case GracePeriodMustBeNonZeroError: + return "Grace period must be non-zero" + case MissingEd25519VerificationError: + return "Missing Ed25519 precompile ix at current-1" + case MalformedEd25519InstructionError: + return "Malformed Ed25519 precompile instruction" + case VoucherChannelMismatchError: + return "Voucher channel_id does not match channel PDA" + case VoucherExpiredError: + return "Voucher expired" + case VoucherWatermarkNotMonotonicError: + return "Voucher watermark not strictly monotonic" + case VoucherOverDepositError: + return "Voucher cumulative_amount exceeds channel deposit" + case VoucherMessageMismatchError: + return "Ed25519 message does not match Borsh voucher payload" + case VoucherSignerMismatchError: + return "Voucher signer does not match channel authorized_signer" + case InvalidRecipientCountError: + return "num_recipients outside [0, 32]" + case InvalidSplitConfigError: + return "Each shareBps must be non-zero and Σbps must be at most 10_000" + case DistributionPartsOverflowError: + return "num_recipients outside [0, 32]" + case DuplicateRecipientError: + return "Distribution plan contains a duplicate recipient address" + case DistributionAmountOverflowError: + return "num_recipients outside [0, 32]" + case DistributionPreimageLengthOverflowError: + return "Distribution preimage length calculation overflow" + case ChannelAddressMismatchError: + return "Derived channel account address does not match the user provided address" + case PayerPayeeMustDifferError: + return "Payer and payee must be different accounts" + case InvalidAuthorizedSignerError: + return "authorized_signer must be a valid Ed25519 public key" + case TopUpDepositOverflowError: + return "Deposit must be non-zero" + case FinalizeDeadlineOverflowError: + return "Deadline overflow on grace period" + case PayerAlreadyWithdrawnError: + return "Payer refund has already been claimed" + case RefundCalculationOverflowError: + return "Payer refund amount calculation underflow" + case ChannelNotDistributableError: + return "Channel is not in OPEN or FINALIZED" + case TreasuryAccountMismatchError: + return "Treasury token account is not ATA(TREASURY_OWNER, mint, token_program)" + case InvalidTreasuryTokenAccountError: + return "Treasury token account is invalid" + case InvalidTreasuryTokenExtensionsError: + return "Treasury token account has invalid extensions" + case RecipientAccountMismatchError: + return "Recipient token account is not ATA(recipient, token_program, mint)" + case InvalidRecipientTokenAccountError: + return "Recipient token account is invalid" + case InvalidRecipientTokenExtensionsError: + return "Recipient token account has invalid extensions" + case InvalidDistributionHashError: + return "Distribution hash mismatch" + case NothingToDistributeError: + return "No newly settled funds to distribute" + case RecipientAccountCountMismatchError: + return "Recipient ATA tail length does not match the committed plan's entry count" + case DistributePoolOverflowError: + return "Distribution pool calculation underflow" + case DistributeBalanceCalculationOverflowError: + return "Channel rent rebalance calculation underflow" + case DistributePayerBalanceOverflowError: + return "Payer lamports overflow on rent refund" + case DistributeTransferQueueOverflowError: + return "Transfer queue capacity exceeded" + default: + return fmt.Sprintf("unknown error: %d", e) + } +} + +func (e PaymentChannelsError) Code() uint32 { + return uint32(e) +} diff --git a/go/protocols/programs/paymentchannels/instruction_distribute.go b/go/protocols/programs/paymentchannels/instruction_distribute.go new file mode 100644 index 000000000..2042f4765 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_distribute.go @@ -0,0 +1,254 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var DistributeDiscriminator = 7 + +// Distribute is the `Distribute` instruction. +type Distribute struct { + // [0] = [WRITE] Channel + // [1] = [WRITE] Payer + // [2] = [WRITE] ChannelTokenAccount + // [3] = [WRITE] PayerTokenAccount + // [4] = [WRITE] PayeeTokenAccount + // [5] = [WRITE] TreasuryTokenAccount + // [6] = [] Mint + // [7] = [] TokenProgram + // [8] = [] EventAuthority + // [9] = [] SelfProgram + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` + // DistributeArgs holds the Borsh-encoded instruction arguments written + // after the distribute discriminator byte (7): the recipient split list, + // which must hash to the channel's stored distribution hash. + DistributeArgs DistributeArgs +} + +// NewDistributeInstructionBuilder creates a new `Distribute` instruction builder. +func NewDistributeInstructionBuilder() *Distribute { + nd := &Distribute{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 10) + return nd +} + +// SetDistributeArgs sets the "distribute_args" parameter. +func (inst *Distribute) SetDistributeArgs(distributeArgs DistributeArgs) *Distribute { + inst.DistributeArgs = distributeArgs + return inst +} + +// SetChannelAccount sets the "channel" account. +func (inst *Distribute) SetChannelAccount(channel ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[0] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *Distribute) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetPayerAccount sets the "payer" account. +func (inst *Distribute) SetPayerAccount(payer ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[1] = ag_solanago.Meta(payer).WRITE() + return inst +} + +// GetPayerAccount gets the "payer" account. +func (inst *Distribute) GetPayerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +// SetChannelTokenAccountAccount sets the "channel_token_account" account. +func (inst *Distribute) SetChannelTokenAccountAccount(channelTokenAccount ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[2] = ag_solanago.Meta(channelTokenAccount).WRITE() + return inst +} + +// GetChannelTokenAccountAccount gets the "channel_token_account" account. +func (inst *Distribute) GetChannelTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +// SetPayerTokenAccountAccount sets the "payer_token_account" account. +func (inst *Distribute) SetPayerTokenAccountAccount(payerTokenAccount ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[3] = ag_solanago.Meta(payerTokenAccount).WRITE() + return inst +} + +// GetPayerTokenAccountAccount gets the "payer_token_account" account. +func (inst *Distribute) GetPayerTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[3] +} + +// SetPayeeTokenAccountAccount sets the "payee_token_account" account. +func (inst *Distribute) SetPayeeTokenAccountAccount(payeeTokenAccount ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[4] = ag_solanago.Meta(payeeTokenAccount).WRITE() + return inst +} + +// GetPayeeTokenAccountAccount gets the "payee_token_account" account. +func (inst *Distribute) GetPayeeTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[4] +} + +// SetTreasuryTokenAccountAccount sets the "treasury_token_account" account. +func (inst *Distribute) SetTreasuryTokenAccountAccount(treasuryTokenAccount ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[5] = ag_solanago.Meta(treasuryTokenAccount).WRITE() + return inst +} + +// GetTreasuryTokenAccountAccount gets the "treasury_token_account" account. +func (inst *Distribute) GetTreasuryTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[5] +} + +// SetMintAccount sets the "mint" account. +func (inst *Distribute) SetMintAccount(mint ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[6] = ag_solanago.Meta(mint) + return inst +} + +// GetMintAccount gets the "mint" account. +func (inst *Distribute) GetMintAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[6] +} + +// SetTokenProgramAccount sets the "token_program" account. +func (inst *Distribute) SetTokenProgramAccount(tokenProgram ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[7] = ag_solanago.Meta(tokenProgram) + return inst +} + +// GetTokenProgramAccount gets the "token_program" account. +func (inst *Distribute) GetTokenProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[7] +} + +// SetEventAuthorityAccount sets the "event_authority" account. +func (inst *Distribute) SetEventAuthorityAccount(eventAuthority ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[8] = ag_solanago.Meta(eventAuthority) + return inst +} + +// GetEventAuthorityAccount gets the "event_authority" account. +func (inst *Distribute) GetEventAuthorityAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[8] +} + +// SetSelfProgramAccount sets the "self_program" account. +func (inst *Distribute) SetSelfProgramAccount(selfProgram ag_solanago.PublicKey) *Distribute { + inst.AccountMetaSlice[9] = ag_solanago.Meta(selfProgram) + return inst +} + +// GetSelfProgramAccount gets the "self_program" account. +func (inst *Distribute) GetSelfProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[9] +} + +func (inst Distribute) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst Distribute) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *Distribute) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Payer is not set") + } + if inst.AccountMetaSlice[2] == nil { + return fmt.Errorf("accounts.ChannelTokenAccount is not set") + } + if inst.AccountMetaSlice[3] == nil { + return fmt.Errorf("accounts.PayerTokenAccount is not set") + } + if inst.AccountMetaSlice[4] == nil { + return fmt.Errorf("accounts.PayeeTokenAccount is not set") + } + if inst.AccountMetaSlice[5] == nil { + return fmt.Errorf("accounts.TreasuryTokenAccount is not set") + } + if inst.AccountMetaSlice[6] == nil { + return fmt.Errorf("accounts.Mint is not set") + } + if inst.AccountMetaSlice[7] == nil { + return fmt.Errorf("accounts.TokenProgram is not set") + } + if inst.AccountMetaSlice[8] == nil { + return fmt.Errorf("accounts.EventAuthority is not set") + } + if inst.AccountMetaSlice[9] == nil { + return fmt.Errorf("accounts.SelfProgram is not set") + } + return nil +} + +func (inst *Distribute) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *Distribute) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 10 { + return fmt.Errorf("not enough accounts: expected at least 10, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:10] + return nil +} + +func (inst *Distribute) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(7)) + if err != nil { + return err + } + } + { + err := encoder.Encode(inst.DistributeArgs) + if err != nil { + return err + } + } + return nil +} + +func (inst *Distribute) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + { + err := decoder.Decode(&inst.DistributeArgs) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_emit_event.go b/go/protocols/programs/paymentchannels/instruction_emit_event.go new file mode 100644 index 000000000..0a4e70952 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_emit_event.go @@ -0,0 +1,97 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var EmitEventDiscriminator = 228 + +// EmitEvent is the `EmitEvent` instruction. +type EmitEvent struct { + // [0] = [SIGNER] EventAuthority + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` +} + +// NewEmitEventInstructionBuilder creates a new `EmitEvent` instruction builder. +func NewEmitEventInstructionBuilder() *EmitEvent { + nd := &EmitEvent{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 1) + return nd +} + +// SetEventAuthorityAccount sets the "event_authority" account. +func (inst *EmitEvent) SetEventAuthorityAccount(eventAuthority ag_solanago.PublicKey) *EmitEvent { + inst.AccountMetaSlice[0] = ag_solanago.Meta(eventAuthority).SIGNER() + return inst +} + +// GetEventAuthorityAccount gets the "event_authority" account. +func (inst *EmitEvent) GetEventAuthorityAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +func (inst EmitEvent) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst EmitEvent) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *EmitEvent) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.EventAuthority is not set") + } + return nil +} + +func (inst *EmitEvent) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *EmitEvent) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 1 { + return fmt.Errorf("not enough accounts: expected at least 1, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:1] + return nil +} + +func (inst *EmitEvent) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(228)) + if err != nil { + return err + } + } + return nil +} + +func (inst *EmitEvent) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_finalize.go b/go/protocols/programs/paymentchannels/instruction_finalize.go new file mode 100644 index 000000000..e43453295 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_finalize.go @@ -0,0 +1,97 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var FinalizeDiscriminator = 6 + +// Finalize is the `Finalize` instruction. +type Finalize struct { + // [0] = [WRITE] Channel + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` +} + +// NewFinalizeInstructionBuilder creates a new `Finalize` instruction builder. +func NewFinalizeInstructionBuilder() *Finalize { + nd := &Finalize{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 1) + return nd +} + +// SetChannelAccount sets the "channel" account. +func (inst *Finalize) SetChannelAccount(channel ag_solanago.PublicKey) *Finalize { + inst.AccountMetaSlice[0] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *Finalize) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +func (inst Finalize) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst Finalize) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *Finalize) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + return nil +} + +func (inst *Finalize) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *Finalize) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 1 { + return fmt.Errorf("not enough accounts: expected at least 1, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:1] + return nil +} + +func (inst *Finalize) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(6)) + if err != nil { + return err + } + } + return nil +} + +func (inst *Finalize) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_open.go b/go/protocols/programs/paymentchannels/instruction_open.go new file mode 100644 index 000000000..38406c3e7 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_open.go @@ -0,0 +1,299 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var OpenDiscriminator = 1 + +// Open is the `Open` instruction. +type Open struct { + // [0] = [WRITE, SIGNER] Payer + // [1] = [] Payee + // [2] = [] Mint + // [3] = [] AuthorizedSigner + // [4] = [WRITE] Channel + // [5] = [WRITE] PayerTokenAccount + // [6] = [WRITE] ChannelTokenAccount + // [7] = [] TokenProgram + // [8] = [] SystemProgram + // [9] = [] Rent + // [10] = [] AssociatedTokenProgram + // [11] = [] EventAuthority + // [12] = [] SelfProgram + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` + // OpenArgs holds the Borsh-encoded instruction arguments written after + // the open discriminator byte (1): PDA salt, initial deposit, close + // grace period, and the recipient split list committed at open. + OpenArgs OpenArgs +} + +// NewOpenInstructionBuilder creates a new `Open` instruction builder. +func NewOpenInstructionBuilder() *Open { + nd := &Open{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 13) + return nd +} + +// SetOpenArgs sets the "open_args" parameter. +func (inst *Open) SetOpenArgs(openArgs OpenArgs) *Open { + inst.OpenArgs = openArgs + return inst +} + +// SetPayerAccount sets the "payer" account. +func (inst *Open) SetPayerAccount(payer ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[0] = ag_solanago.Meta(payer).SIGNER().WRITE() + return inst +} + +// GetPayerAccount gets the "payer" account. +func (inst *Open) GetPayerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetPayeeAccount sets the "payee" account. +func (inst *Open) SetPayeeAccount(payee ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[1] = ag_solanago.Meta(payee) + return inst +} + +// GetPayeeAccount gets the "payee" account. +func (inst *Open) GetPayeeAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +// SetMintAccount sets the "mint" account. +func (inst *Open) SetMintAccount(mint ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[2] = ag_solanago.Meta(mint) + return inst +} + +// GetMintAccount gets the "mint" account. +func (inst *Open) GetMintAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +// SetAuthorizedSignerAccount sets the "authorized_signer" account. +func (inst *Open) SetAuthorizedSignerAccount(authorizedSigner ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[3] = ag_solanago.Meta(authorizedSigner) + return inst +} + +// GetAuthorizedSignerAccount gets the "authorized_signer" account. +func (inst *Open) GetAuthorizedSignerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[3] +} + +// SetChannelAccount sets the "channel" account. +func (inst *Open) SetChannelAccount(channel ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[4] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *Open) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[4] +} + +// SetPayerTokenAccountAccount sets the "payer_token_account" account. +func (inst *Open) SetPayerTokenAccountAccount(payerTokenAccount ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[5] = ag_solanago.Meta(payerTokenAccount).WRITE() + return inst +} + +// GetPayerTokenAccountAccount gets the "payer_token_account" account. +func (inst *Open) GetPayerTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[5] +} + +// SetChannelTokenAccountAccount sets the "channel_token_account" account. +func (inst *Open) SetChannelTokenAccountAccount(channelTokenAccount ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[6] = ag_solanago.Meta(channelTokenAccount).WRITE() + return inst +} + +// GetChannelTokenAccountAccount gets the "channel_token_account" account. +func (inst *Open) GetChannelTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[6] +} + +// SetTokenProgramAccount sets the "token_program" account. +func (inst *Open) SetTokenProgramAccount(tokenProgram ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[7] = ag_solanago.Meta(tokenProgram) + return inst +} + +// GetTokenProgramAccount gets the "token_program" account. +func (inst *Open) GetTokenProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[7] +} + +// SetSystemProgramAccount sets the "system_program" account. +func (inst *Open) SetSystemProgramAccount(systemProgram ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[8] = ag_solanago.Meta(systemProgram) + return inst +} + +// GetSystemProgramAccount gets the "system_program" account. +func (inst *Open) GetSystemProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[8] +} + +// SetRentAccount sets the "rent" account. +func (inst *Open) SetRentAccount(rent ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[9] = ag_solanago.Meta(rent) + return inst +} + +// GetRentAccount gets the "rent" account. +func (inst *Open) GetRentAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[9] +} + +// SetAssociatedTokenProgramAccount sets the "associated_token_program" account. +func (inst *Open) SetAssociatedTokenProgramAccount(associatedTokenProgram ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[10] = ag_solanago.Meta(associatedTokenProgram) + return inst +} + +// GetAssociatedTokenProgramAccount gets the "associated_token_program" account. +func (inst *Open) GetAssociatedTokenProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[10] +} + +// SetEventAuthorityAccount sets the "event_authority" account. +func (inst *Open) SetEventAuthorityAccount(eventAuthority ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[11] = ag_solanago.Meta(eventAuthority) + return inst +} + +// GetEventAuthorityAccount gets the "event_authority" account. +func (inst *Open) GetEventAuthorityAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[11] +} + +// SetSelfProgramAccount sets the "self_program" account. +func (inst *Open) SetSelfProgramAccount(selfProgram ag_solanago.PublicKey) *Open { + inst.AccountMetaSlice[12] = ag_solanago.Meta(selfProgram) + return inst +} + +// GetSelfProgramAccount gets the "self_program" account. +func (inst *Open) GetSelfProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[12] +} + +func (inst Open) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst Open) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *Open) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Payer is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Payee is not set") + } + if inst.AccountMetaSlice[2] == nil { + return fmt.Errorf("accounts.Mint is not set") + } + if inst.AccountMetaSlice[3] == nil { + return fmt.Errorf("accounts.AuthorizedSigner is not set") + } + if inst.AccountMetaSlice[4] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[5] == nil { + return fmt.Errorf("accounts.PayerTokenAccount is not set") + } + if inst.AccountMetaSlice[6] == nil { + return fmt.Errorf("accounts.ChannelTokenAccount is not set") + } + if inst.AccountMetaSlice[7] == nil { + return fmt.Errorf("accounts.TokenProgram is not set") + } + if inst.AccountMetaSlice[8] == nil { + return fmt.Errorf("accounts.SystemProgram is not set") + } + if inst.AccountMetaSlice[9] == nil { + return fmt.Errorf("accounts.Rent is not set") + } + if inst.AccountMetaSlice[10] == nil { + return fmt.Errorf("accounts.AssociatedTokenProgram is not set") + } + if inst.AccountMetaSlice[11] == nil { + return fmt.Errorf("accounts.EventAuthority is not set") + } + if inst.AccountMetaSlice[12] == nil { + return fmt.Errorf("accounts.SelfProgram is not set") + } + return nil +} + +func (inst *Open) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *Open) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 13 { + return fmt.Errorf("not enough accounts: expected at least 13, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:13] + return nil +} + +func (inst *Open) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(1)) + if err != nil { + return err + } + } + { + err := encoder.Encode(inst.OpenArgs) + if err != nil { + return err + } + } + return nil +} + +func (inst *Open) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + { + err := decoder.Decode(&inst.OpenArgs) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_request_close.go b/go/protocols/programs/paymentchannels/instruction_request_close.go new file mode 100644 index 000000000..9d4960483 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_request_close.go @@ -0,0 +1,112 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var RequestCloseDiscriminator = 5 + +// RequestClose is the `RequestClose` instruction. +type RequestClose struct { + // [0] = [SIGNER] Payer + // [1] = [WRITE] Channel + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` +} + +// NewRequestCloseInstructionBuilder creates a new `RequestClose` instruction builder. +func NewRequestCloseInstructionBuilder() *RequestClose { + nd := &RequestClose{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 2) + return nd +} + +// SetPayerAccount sets the "payer" account. +func (inst *RequestClose) SetPayerAccount(payer ag_solanago.PublicKey) *RequestClose { + inst.AccountMetaSlice[0] = ag_solanago.Meta(payer).SIGNER() + return inst +} + +// GetPayerAccount gets the "payer" account. +func (inst *RequestClose) GetPayerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetChannelAccount sets the "channel" account. +func (inst *RequestClose) SetChannelAccount(channel ag_solanago.PublicKey) *RequestClose { + inst.AccountMetaSlice[1] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *RequestClose) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +func (inst RequestClose) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst RequestClose) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *RequestClose) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Payer is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + return nil +} + +func (inst *RequestClose) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *RequestClose) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 2 { + return fmt.Errorf("not enough accounts: expected at least 2, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:2] + return nil +} + +func (inst *RequestClose) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(5)) + if err != nil { + return err + } + } + return nil +} + +func (inst *RequestClose) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_settle.go b/go/protocols/programs/paymentchannels/instruction_settle.go new file mode 100644 index 000000000..76a70c0fe --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_settle.go @@ -0,0 +1,135 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var SettleDiscriminator = 2 + +// Settle is the `Settle` instruction. +type Settle struct { + // [0] = [WRITE] Channel + // [1] = [] InstructionsSysvar + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` + // SettleArgs holds the Borsh-encoded instruction arguments written after + // the settle discriminator byte (2): the voucher whose Ed25519 signature + // is checked via a preceding precompile instruction referenced through + // the instructions sysvar. + SettleArgs SettleArgs +} + +// NewSettleInstructionBuilder creates a new `Settle` instruction builder. +func NewSettleInstructionBuilder() *Settle { + nd := &Settle{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 2) + return nd +} + +// SetSettleArgs sets the "settle_args" parameter. +func (inst *Settle) SetSettleArgs(settleArgs SettleArgs) *Settle { + inst.SettleArgs = settleArgs + return inst +} + +// SetChannelAccount sets the "channel" account. +func (inst *Settle) SetChannelAccount(channel ag_solanago.PublicKey) *Settle { + inst.AccountMetaSlice[0] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *Settle) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetInstructionsSysvarAccount sets the "instructions_sysvar" account. +func (inst *Settle) SetInstructionsSysvarAccount(instructionsSysvar ag_solanago.PublicKey) *Settle { + inst.AccountMetaSlice[1] = ag_solanago.Meta(instructionsSysvar) + return inst +} + +// GetInstructionsSysvarAccount gets the "instructions_sysvar" account. +func (inst *Settle) GetInstructionsSysvarAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +func (inst Settle) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst Settle) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *Settle) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.InstructionsSysvar is not set") + } + return nil +} + +func (inst *Settle) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *Settle) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 2 { + return fmt.Errorf("not enough accounts: expected at least 2, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:2] + return nil +} + +func (inst *Settle) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(2)) + if err != nil { + return err + } + } + { + err := encoder.Encode(inst.SettleArgs) + if err != nil { + return err + } + } + return nil +} + +func (inst *Settle) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + { + err := decoder.Decode(&inst.SettleArgs) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_settle_and_finalize.go b/go/protocols/programs/paymentchannels/instruction_settle_and_finalize.go new file mode 100644 index 000000000..3766e995a --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_settle_and_finalize.go @@ -0,0 +1,150 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var SettleAndFinalizeDiscriminator = 4 + +// SettleAndFinalize is the `SettleAndFinalize` instruction. +type SettleAndFinalize struct { + // [0] = [SIGNER] Merchant + // [1] = [WRITE] Channel + // [2] = [] InstructionsSysvar + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` + // SettleAndFinalizeArgs holds the Borsh-encoded instruction arguments + // written after the settleAndFinalize discriminator byte (4): the + // optional final voucher plus the hasVoucher flag selecting whether the + // Ed25519 precompile check applies. + SettleAndFinalizeArgs SettleAndFinalizeArgs +} + +// NewSettleAndFinalizeInstructionBuilder creates a new `SettleAndFinalize` instruction builder. +func NewSettleAndFinalizeInstructionBuilder() *SettleAndFinalize { + nd := &SettleAndFinalize{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 3) + return nd +} + +// SetSettleAndFinalizeArgs sets the "settle_and_finalize_args" parameter. +func (inst *SettleAndFinalize) SetSettleAndFinalizeArgs(settleAndFinalizeArgs SettleAndFinalizeArgs) *SettleAndFinalize { + inst.SettleAndFinalizeArgs = settleAndFinalizeArgs + return inst +} + +// SetMerchantAccount sets the "merchant" account. +func (inst *SettleAndFinalize) SetMerchantAccount(merchant ag_solanago.PublicKey) *SettleAndFinalize { + inst.AccountMetaSlice[0] = ag_solanago.Meta(merchant).SIGNER() + return inst +} + +// GetMerchantAccount gets the "merchant" account. +func (inst *SettleAndFinalize) GetMerchantAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetChannelAccount sets the "channel" account. +func (inst *SettleAndFinalize) SetChannelAccount(channel ag_solanago.PublicKey) *SettleAndFinalize { + inst.AccountMetaSlice[1] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *SettleAndFinalize) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +// SetInstructionsSysvarAccount sets the "instructions_sysvar" account. +func (inst *SettleAndFinalize) SetInstructionsSysvarAccount(instructionsSysvar ag_solanago.PublicKey) *SettleAndFinalize { + inst.AccountMetaSlice[2] = ag_solanago.Meta(instructionsSysvar) + return inst +} + +// GetInstructionsSysvarAccount gets the "instructions_sysvar" account. +func (inst *SettleAndFinalize) GetInstructionsSysvarAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +func (inst SettleAndFinalize) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst SettleAndFinalize) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *SettleAndFinalize) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Merchant is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[2] == nil { + return fmt.Errorf("accounts.InstructionsSysvar is not set") + } + return nil +} + +func (inst *SettleAndFinalize) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *SettleAndFinalize) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 3 { + return fmt.Errorf("not enough accounts: expected at least 3, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:3] + return nil +} + +func (inst *SettleAndFinalize) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(4)) + if err != nil { + return err + } + } + { + err := encoder.Encode(inst.SettleAndFinalizeArgs) + if err != nil { + return err + } + } + return nil +} + +func (inst *SettleAndFinalize) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + { + err := decoder.Decode(&inst.SettleAndFinalizeArgs) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_top_up.go b/go/protocols/programs/paymentchannels/instruction_top_up.go new file mode 100644 index 000000000..89249f633 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_top_up.go @@ -0,0 +1,194 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var TopUpDiscriminator = 3 + +// TopUp is the `TopUp` instruction. +type TopUp struct { + // [0] = [WRITE, SIGNER] Payer + // [1] = [WRITE] Channel + // [2] = [WRITE] PayerTokenAccount + // [3] = [WRITE] ChannelTokenAccount + // [4] = [] Mint + // [5] = [] TokenProgram + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` + // TopUpArgs holds the Borsh-encoded instruction arguments written after + // the topUp discriminator byte (3): the additional deposit amount in + // mint base units. + TopUpArgs TopUpArgs +} + +// NewTopUpInstructionBuilder creates a new `TopUp` instruction builder. +func NewTopUpInstructionBuilder() *TopUp { + nd := &TopUp{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 6) + return nd +} + +// SetTopUpArgs sets the "top_up_args" parameter. +func (inst *TopUp) SetTopUpArgs(topUpArgs TopUpArgs) *TopUp { + inst.TopUpArgs = topUpArgs + return inst +} + +// SetPayerAccount sets the "payer" account. +func (inst *TopUp) SetPayerAccount(payer ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[0] = ag_solanago.Meta(payer).SIGNER().WRITE() + return inst +} + +// GetPayerAccount gets the "payer" account. +func (inst *TopUp) GetPayerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetChannelAccount sets the "channel" account. +func (inst *TopUp) SetChannelAccount(channel ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[1] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *TopUp) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +// SetPayerTokenAccountAccount sets the "payer_token_account" account. +func (inst *TopUp) SetPayerTokenAccountAccount(payerTokenAccount ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[2] = ag_solanago.Meta(payerTokenAccount).WRITE() + return inst +} + +// GetPayerTokenAccountAccount gets the "payer_token_account" account. +func (inst *TopUp) GetPayerTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +// SetChannelTokenAccountAccount sets the "channel_token_account" account. +func (inst *TopUp) SetChannelTokenAccountAccount(channelTokenAccount ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[3] = ag_solanago.Meta(channelTokenAccount).WRITE() + return inst +} + +// GetChannelTokenAccountAccount gets the "channel_token_account" account. +func (inst *TopUp) GetChannelTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[3] +} + +// SetMintAccount sets the "mint" account. +func (inst *TopUp) SetMintAccount(mint ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[4] = ag_solanago.Meta(mint) + return inst +} + +// GetMintAccount gets the "mint" account. +func (inst *TopUp) GetMintAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[4] +} + +// SetTokenProgramAccount sets the "token_program" account. +func (inst *TopUp) SetTokenProgramAccount(tokenProgram ag_solanago.PublicKey) *TopUp { + inst.AccountMetaSlice[5] = ag_solanago.Meta(tokenProgram) + return inst +} + +// GetTokenProgramAccount gets the "token_program" account. +func (inst *TopUp) GetTokenProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[5] +} + +func (inst TopUp) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst TopUp) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *TopUp) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Payer is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[2] == nil { + return fmt.Errorf("accounts.PayerTokenAccount is not set") + } + if inst.AccountMetaSlice[3] == nil { + return fmt.Errorf("accounts.ChannelTokenAccount is not set") + } + if inst.AccountMetaSlice[4] == nil { + return fmt.Errorf("accounts.Mint is not set") + } + if inst.AccountMetaSlice[5] == nil { + return fmt.Errorf("accounts.TokenProgram is not set") + } + return nil +} + +func (inst *TopUp) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *TopUp) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 6 { + return fmt.Errorf("not enough accounts: expected at least 6, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:6] + return nil +} + +func (inst *TopUp) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(3)) + if err != nil { + return err + } + } + { + err := encoder.Encode(inst.TopUpArgs) + if err != nil { + return err + } + } + return nil +} + +func (inst *TopUp) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + { + err := decoder.Decode(&inst.TopUpArgs) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instruction_withdraw_payer.go b/go/protocols/programs/paymentchannels/instruction_withdraw_payer.go new file mode 100644 index 000000000..2f0a2c4a7 --- /dev/null +++ b/go/protocols/programs/paymentchannels/instruction_withdraw_payer.go @@ -0,0 +1,172 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +var WithdrawPayerDiscriminator = 8 + +// WithdrawPayer is the `WithdrawPayer` instruction. +type WithdrawPayer struct { + // [0] = [SIGNER] Payer + // [1] = [WRITE] Channel + // [2] = [WRITE] ChannelTokenAccount + // [3] = [WRITE] PayerTokenAccount + // [4] = [] Mint + // [5] = [] TokenProgram + ag_solanago.AccountMetaSlice `bin:"-" borsh_skip:"true"` +} + +// NewWithdrawPayerInstructionBuilder creates a new `WithdrawPayer` instruction builder. +func NewWithdrawPayerInstructionBuilder() *WithdrawPayer { + nd := &WithdrawPayer{} + nd.AccountMetaSlice = make(ag_solanago.AccountMetaSlice, 6) + return nd +} + +// SetPayerAccount sets the "payer" account. +func (inst *WithdrawPayer) SetPayerAccount(payer ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[0] = ag_solanago.Meta(payer).SIGNER() + return inst +} + +// GetPayerAccount gets the "payer" account. +func (inst *WithdrawPayer) GetPayerAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[0] +} + +// SetChannelAccount sets the "channel" account. +func (inst *WithdrawPayer) SetChannelAccount(channel ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[1] = ag_solanago.Meta(channel).WRITE() + return inst +} + +// GetChannelAccount gets the "channel" account. +func (inst *WithdrawPayer) GetChannelAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[1] +} + +// SetChannelTokenAccountAccount sets the "channel_token_account" account. +func (inst *WithdrawPayer) SetChannelTokenAccountAccount(channelTokenAccount ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[2] = ag_solanago.Meta(channelTokenAccount).WRITE() + return inst +} + +// GetChannelTokenAccountAccount gets the "channel_token_account" account. +func (inst *WithdrawPayer) GetChannelTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[2] +} + +// SetPayerTokenAccountAccount sets the "payer_token_account" account. +func (inst *WithdrawPayer) SetPayerTokenAccountAccount(payerTokenAccount ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[3] = ag_solanago.Meta(payerTokenAccount).WRITE() + return inst +} + +// GetPayerTokenAccountAccount gets the "payer_token_account" account. +func (inst *WithdrawPayer) GetPayerTokenAccountAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[3] +} + +// SetMintAccount sets the "mint" account. +func (inst *WithdrawPayer) SetMintAccount(mint ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[4] = ag_solanago.Meta(mint) + return inst +} + +// GetMintAccount gets the "mint" account. +func (inst *WithdrawPayer) GetMintAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[4] +} + +// SetTokenProgramAccount sets the "token_program" account. +func (inst *WithdrawPayer) SetTokenProgramAccount(tokenProgram ag_solanago.PublicKey) *WithdrawPayer { + inst.AccountMetaSlice[5] = ag_solanago.Meta(tokenProgram) + return inst +} + +// GetTokenProgramAccount gets the "token_program" account. +func (inst *WithdrawPayer) GetTokenProgramAccount() *ag_solanago.AccountMeta { + return inst.AccountMetaSlice[5] +} + +func (inst WithdrawPayer) Build() *Instruction { + return &Instruction{BaseVariant: ag_binary.BaseVariant{ + Impl: inst, + TypeID: ag_binary.NoTypeIDDefaultID, + }} +} + +// ValidateAndBuild validates the instruction parameters and accounts; +// if there is a validation error, it returns the error. +// Otherwise, it builds and returns the instruction. +func (inst WithdrawPayer) ValidateAndBuild() (*Instruction, error) { + if err := inst.Validate(); err != nil { + return nil, err + } + return inst.Build(), nil +} + +func (inst *WithdrawPayer) Validate() error { + if inst.AccountMetaSlice[0] == nil { + return fmt.Errorf("accounts.Payer is not set") + } + if inst.AccountMetaSlice[1] == nil { + return fmt.Errorf("accounts.Channel is not set") + } + if inst.AccountMetaSlice[2] == nil { + return fmt.Errorf("accounts.ChannelTokenAccount is not set") + } + if inst.AccountMetaSlice[3] == nil { + return fmt.Errorf("accounts.PayerTokenAccount is not set") + } + if inst.AccountMetaSlice[4] == nil { + return fmt.Errorf("accounts.Mint is not set") + } + if inst.AccountMetaSlice[5] == nil { + return fmt.Errorf("accounts.TokenProgram is not set") + } + return nil +} + +func (inst *WithdrawPayer) GetAccounts() (out []*ag_solanago.AccountMeta) { + return inst.AccountMetaSlice +} + +func (inst *WithdrawPayer) SetAccounts(accounts []*ag_solanago.AccountMeta) error { + if len(accounts) < 6 { + return fmt.Errorf("not enough accounts: expected at least 6, got %d", len(accounts)) + } + inst.AccountMetaSlice = accounts[:6] + return nil +} + +func (inst *WithdrawPayer) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + { + err := encoder.Encode(uint8(8)) + if err != nil { + return err + } + } + return nil +} + +func (inst *WithdrawPayer) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + { + var tmp uint8 + err := decoder.Decode(&tmp) + if err != nil { + return err + } + } + return nil +} diff --git a/go/protocols/programs/paymentchannels/instructions.go b/go/protocols/programs/paymentchannels/instructions.go new file mode 100644 index 000000000..621cba31f --- /dev/null +++ b/go/protocols/programs/paymentchannels/instructions.go @@ -0,0 +1,127 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + "bytes" + "fmt" + + ag_binary "github.com/gagliardetto/binary" + ag_solanago "github.com/gagliardetto/solana-go" +) + +const ProgramName = "PaymentChannels" + +var ProgramID ag_solanago.PublicKey = ag_solanago.MustPublicKeyFromBase58("CQAyft83tN1w2bRofB5PZ79eVDU2xZUVo43LU1qL4zRg") + +func SetProgramID(pubkey ag_solanago.PublicKey) { + ProgramID = pubkey + ag_solanago.RegisterInstructionDecoder(ProgramID, registryDecodeInstruction) +} + +var ( + Instruction_9_Count uint32 = 9 + + InstructionImplDef = ag_binary.NewVariantDefinition( + ag_binary.Uint8TypeIDEncoding, + []ag_binary.VariantType{ + { + "Distribute", + (*Distribute)(nil), + }, + { + "EmitEvent", + (*EmitEvent)(nil), + }, + { + "Finalize", + (*Finalize)(nil), + }, + { + "Open", + (*Open)(nil), + }, + { + "RequestClose", + (*RequestClose)(nil), + }, + { + "Settle", + (*Settle)(nil), + }, + { + "SettleAndFinalize", + (*SettleAndFinalize)(nil), + }, + { + "TopUp", + (*TopUp)(nil), + }, + { + "WithdrawPayer", + (*WithdrawPayer)(nil), + }, + }, + ) +) + +type Instruction struct { + // BaseVariant carries the decoded instruction variant: TypeID holds the + // single-byte instruction discriminator and Impl the concrete + // per-instruction struct (Open, TopUp, Settle, ...). + ag_binary.BaseVariant +} + +func (inst *Instruction) ProgramID() ag_solanago.PublicKey { + return ProgramID +} + +func (inst *Instruction) Accounts() (out []*ag_solanago.AccountMeta) { + return inst.Impl.(ag_solanago.AccountsGettable).GetAccounts() +} + +func (inst *Instruction) Data() ([]byte, error) { + buf := new(bytes.Buffer) + if err := ag_binary.NewBorshEncoder(buf).Encode(inst); err != nil { + return nil, fmt.Errorf("unable to encode instruction: %w", err) + } + return buf.Bytes(), nil +} + +func (inst *Instruction) UnmarshalWithDecoder(decoder *ag_binary.Decoder) error { + return inst.BaseVariant.UnmarshalBinaryVariant(decoder, InstructionImplDef) +} + +func (inst *Instruction) MarshalWithEncoder(encoder *ag_binary.Encoder) error { + err := encoder.WriteUint8(inst.TypeID.Uint8()) + if err != nil { + return fmt.Errorf("unable to write variant type: %w", err) + } + return encoder.Encode(inst.Impl) +} + +func registryDecodeInstruction(accounts []*ag_solanago.AccountMeta, data []byte) (interface{}, error) { + inst, err := DecodeInstruction(accounts, data) + if err != nil { + return nil, err + } + return inst, nil +} + +func DecodeInstruction(accounts []*ag_solanago.AccountMeta, data []byte) (*Instruction, error) { + inst := new(Instruction) + if err := ag_binary.NewBorshDecoder(data).Decode(inst); err != nil { + return nil, fmt.Errorf("unable to decode instruction: %w", err) + } + if v, ok := inst.Impl.(ag_solanago.AccountsSettable); ok { + err := v.SetAccounts(accounts) + if err != nil { + return nil, fmt.Errorf("unable to set accounts for instruction: %w", err) + } + } + return inst, nil +} diff --git a/go/protocols/programs/paymentchannels/type_account_discriminator.go b/go/protocols/programs/paymentchannels/type_account_discriminator.go new file mode 100644 index 000000000..8db4ecb43 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_account_discriminator.go @@ -0,0 +1,14 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type AccountDiscriminator uint8 + +const ( + AccountDiscriminator_Channel AccountDiscriminator = iota + AccountDiscriminator_ClosedChannel +) diff --git a/go/protocols/programs/paymentchannels/type_channel_status.go b/go/protocols/programs/paymentchannels/type_channel_status.go new file mode 100644 index 000000000..b9843701b --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_channel_status.go @@ -0,0 +1,15 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type ChannelStatus uint8 + +const ( + ChannelStatus_Open ChannelStatus = iota + ChannelStatus_Finalized + ChannelStatus_Closing +) diff --git a/go/protocols/programs/paymentchannels/type_distribute_args.go b/go/protocols/programs/paymentchannels/type_distribute_args.go new file mode 100644 index 000000000..0068f38ef --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_distribute_args.go @@ -0,0 +1,14 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type DistributeArgs struct { + // Recipients is the basis-point split list paid out at distribution, + // Borsh-encoded with a little-endian u32 length prefix; it must hash to + // the distribution hash committed in the channel account at open. + Recipients []DistributionEntry +} diff --git a/go/protocols/programs/paymentchannels/type_distribution_entry.go b/go/protocols/programs/paymentchannels/type_distribution_entry.go new file mode 100644 index 000000000..6e024a4e4 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_distribution_entry.go @@ -0,0 +1,21 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_solanago "github.com/gagliardetto/solana-go" +) + +type DistributionEntry struct { + // Recipient is the raw 32-byte public key of the split recipient + // wallet; its associated token account for the channel mint receives + // this entry's share at distribution. + Recipient ag_solanago.PublicKey + // Bps is the recipient's share of the settled amount in basis points + // (1 bps = 0.01%). + Bps uint16 +} diff --git a/go/protocols/programs/paymentchannels/type_open_args.go b/go/protocols/programs/paymentchannels/type_open_args.go new file mode 100644 index 000000000..ff152a7a1 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_open_args.go @@ -0,0 +1,23 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type OpenArgs struct { + // Salt is the caller-chosen u64 mixed little-endian into the channel + // PDA seeds, distinguishing channels that share payer, payee, mint, and + // authorized signer. + Salt uint64 + // Deposit is the initial amount transferred from the payer into the + // channel token account, in mint base units. + Deposit uint64 + // GracePeriod is the close grace period in seconds: the window after a + // close request during which outstanding vouchers can still be settled. + GracePeriod uint32 + // Recipients is the basis-point split list committed at open via the + // channel's distribution hash and enforced again at distribution. + Recipients []DistributionEntry +} diff --git a/go/protocols/programs/paymentchannels/type_opened.go b/go/protocols/programs/paymentchannels/type_opened.go new file mode 100644 index 000000000..883130a0a --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_opened.go @@ -0,0 +1,18 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_solanago "github.com/gagliardetto/solana-go" +) + +type Opened struct { + // Channel is the PDA of the channel account created by the open + // instruction, derived from ["channel", payer, payee, mint, + // authorizedSigner, salt as little-endian u64]. + Channel ag_solanago.PublicKey +} diff --git a/go/protocols/programs/paymentchannels/type_payout_beneficiary.go b/go/protocols/programs/paymentchannels/type_payout_beneficiary.go new file mode 100644 index 000000000..18cfceb14 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_payout_beneficiary.go @@ -0,0 +1,15 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type PayoutBeneficiary uint8 + +const ( + PayoutBeneficiary_Recipient PayoutBeneficiary = iota + PayoutBeneficiary_Payee + PayoutBeneficiary_Payer +) diff --git a/go/protocols/programs/paymentchannels/type_payout_redirected.go b/go/protocols/programs/paymentchannels/type_payout_redirected.go new file mode 100644 index 000000000..15b4b055f --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_payout_redirected.go @@ -0,0 +1,29 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_solanago "github.com/gagliardetto/solana-go" +) + +type PayoutRedirected struct { + // Channel is the PDA of the channel whose payout was redirected. + Channel ag_solanago.PublicKey + // Owner is the raw 32-byte public key of the wallet whose token account + // was the intended payout target but could not be paid. + Owner ag_solanago.PublicKey + // Amount is the redirected portion of the payout in mint base units + // (the amount of this transfer, not a cumulative total). + Amount uint64 + // Beneficiary identifies who actually received the redirected funds + // instead: a split recipient, the payee, or the payer. + Beneficiary PayoutBeneficiary + // Reason is the RedirectReason explaining why the intended token + // account was skipped (unsupported extension, closed or malformed, not + // initialized, or reassigned authority). + Reason RedirectReason +} diff --git a/go/protocols/programs/paymentchannels/type_redirect_reason.go b/go/protocols/programs/paymentchannels/type_redirect_reason.go new file mode 100644 index 000000000..86f70d4ce --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_redirect_reason.go @@ -0,0 +1,16 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type RedirectReason uint8 + +const ( + RedirectReason_UnsupportedExtension RedirectReason = iota + RedirectReason_ClosedOrMalformed + RedirectReason_NotInitialized + RedirectReason_ReassignedAuthority +) diff --git a/go/protocols/programs/paymentchannels/type_settle_and_finalize_args.go b/go/protocols/programs/paymentchannels/type_settle_and_finalize_args.go new file mode 100644 index 000000000..c26e8b8b9 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_settle_and_finalize_args.go @@ -0,0 +1,17 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type SettleAndFinalizeArgs struct { + // Voucher is the final voucher committed before the channel is + // finalized; its contents are only meaningful when HasVoucher is 1. + Voucher VoucherArgs + // HasVoucher is a Borsh bool flag: 1 settles Voucher (verified through + // a preceding Ed25519 precompile instruction referenced via the + // instructions sysvar) before finalizing, 0 finalizes without one. + HasVoucher uint8 +} diff --git a/go/protocols/programs/paymentchannels/type_settle_args.go b/go/protocols/programs/paymentchannels/type_settle_args.go new file mode 100644 index 000000000..7adc98608 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_settle_args.go @@ -0,0 +1,15 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type SettleArgs struct { + // Voucher is the signed voucher to commit, raising the channel's + // settled watermark to its cumulative amount; the Ed25519 signature is + // carried in a preceding precompile instruction referenced via the + // instructions sysvar, not in these args. + Voucher VoucherArgs +} diff --git a/go/protocols/programs/paymentchannels/type_settlement_watermarks.go b/go/protocols/programs/paymentchannels/type_settlement_watermarks.go new file mode 100644 index 000000000..3ab797dc0 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_settlement_watermarks.go @@ -0,0 +1,17 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type SettlementWatermarks struct { + // Settled is the cumulative amount committed by settled vouchers over + // the channel lifetime, in mint base units; it only ever increases. + Settled uint64 + // PayoutWatermark is the cumulative amount already paid out to + // recipients over the channel lifetime, in mint base units; the unpaid + // remainder is Settled minus PayoutWatermark. + PayoutWatermark uint64 +} diff --git a/go/protocols/programs/paymentchannels/type_top_up_args.go b/go/protocols/programs/paymentchannels/type_top_up_args.go new file mode 100644 index 000000000..05cfdf0d6 --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_top_up_args.go @@ -0,0 +1,14 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +type TopUpArgs struct { + // Amount is the additional deposit transferred from the payer into the + // channel token account, in mint base units; a delta added to the + // channel's deposit total, not a cumulative figure. + Amount uint64 +} diff --git a/go/protocols/programs/paymentchannels/type_voucher_args.go b/go/protocols/programs/paymentchannels/type_voucher_args.go new file mode 100644 index 000000000..7e7a7366f --- /dev/null +++ b/go/protocols/programs/paymentchannels/type_voucher_args.go @@ -0,0 +1,25 @@ +// This code was AUTOGENERATED using the codama library. +// Please DO NOT EDIT THIS FILE, instead use visitors +// to add features, then rerun codama to update it. +// +// https://github.com/codama-idl/codama + +package payment_channels + +import ( + ag_solanago "github.com/gagliardetto/solana-go" +) + +type VoucherArgs struct { + // ChannelId is the raw 32-byte channel PDA the voucher settles against; + // it forms bytes 0..32 of the 48-byte Ed25519 signing preimage. + ChannelId ag_solanago.PublicKey + // CumulativeAmount is the total authorized over the channel lifetime in + // mint base units (a running total, not a per-voucher delta), encoded + // as a little-endian u64 at bytes 32..40 of the signing preimage. + CumulativeAmount uint64 + // ExpiresAt is the voucher expiry as a Unix timestamp in seconds, + // encoded as a little-endian i64 at bytes 40..48 of the signing + // preimage. + ExpiresAt int64 +} diff --git a/go/protocols/programs/paymentchannels_parity_test/parity_test.go b/go/protocols/programs/paymentchannels_parity_test/parity_test.go new file mode 100644 index 000000000..760f3abd1 --- /dev/null +++ b/go/protocols/programs/paymentchannels_parity_test/parity_test.go @@ -0,0 +1,140 @@ +// Package paymentchannels_parity guards the Codama-generated payment-channels +// Go client against the Rust spine byte-for-byte. +// +// It lives in a separate directory from the generated package because the Go +// codegen (`pnpm run payment-channels:go`) renders with +// deleteFolderBeforeRendering, which wipes everything under +// protocols/programs/paymentchannels/. Keeping the guard out-of-tree means +// regeneration never clobbers it. +// +// The frozen hex vectors are produced by `borsh::to_vec` over the identical +// OpenArgs, DistributionEntry, and VoucherArgs struct layouts, plus the u8=1 +// open discriminator the on-chain program declares +// (OPEN_DISCRIMINATOR: u8 = 1). If the upstream IDL changes the layout, both +// the regenerated client and these vectors must move together, and this test +// makes that break loud. +package paymentchannels_parity + +import ( + "bytes" + "encoding/hex" + "testing" + + bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + + pc "github.com/solana-foundation/pay-kit/go/protocols/programs/paymentchannels" +) + +// borshEncode serializes v with the gagliardetto Borsh encoder, matching the +// little-endian, length-prefixed-Vec layout that borsh::to_vec emits in Rust. +func borshEncode(t *testing.T, v any) []byte { + t.Helper() + var buf bytes.Buffer + if err := bin.NewBorshEncoder(&buf).Encode(v); err != nil { + t.Fatalf("borsh encode: %v", err) + } + return buf.Bytes() +} + +func mustHex(t *testing.T, s string) []byte { + t.Helper() + b, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("decode frozen vector %q: %v", s, err) + } + return b +} + +// TestOpenDiscriminator pins the single-byte Anchor-numeric discriminator. This +// program does NOT use the 8-byte sha256("global:open")[:8] convention; the +// on-chain program declares OPEN_DISCRIMINATOR: u8 = 1 and the IDL encodes it +// as a fieldDiscriminatorNode u8 at offset 0. Guard against a silent switch +// to the wide form. +func TestOpenDiscriminator(t *testing.T) { + if pc.OpenDiscriminator != 1 { + t.Fatalf("OpenDiscriminator = %d, want 1 (rust OPEN_DISCRIMINATOR: u8 = 1)", pc.OpenDiscriminator) + } +} + +// TestOpenArgsBorshParity asserts the OpenArgs Borsh layout +// {salt u64, deposit u64, grace_period u32, recipients Vec<{recipient pubkey, bps u16}>} +// matches the Rust spine for a frozen input. +func TestOpenArgsBorshParity(t *testing.T) { + // salt=1, deposit=1_000_000, grace_period=900, + // recipients=[{recipient=, bps=10000}] + args := pc.OpenArgs{ + Salt: 1, + Deposit: 1_000_000, + GracePeriod: 900, + Recipients: []pc.DistributionEntry{ + {Recipient: solana.PublicKey{}, Bps: 10000}, + }, + } + + // Frozen from `borsh::to_vec(&OpenArgs{...})` against the Rust spine layout. + const wantOpenArgs = "010000000000000040420f0000000000840300000100000000000000000000000000000000000000000000000000000000000000000000001027" + got := borshEncode(t, args) + if want := mustHex(t, wantOpenArgs); !bytes.Equal(got, want) { + t.Fatalf("OpenArgs borsh mismatch\n got: %s\nwant: %s", hex.EncodeToString(got), wantOpenArgs) + } +} + +// TestOpenInstructionDataParity asserts the full Open instruction data wire +// bytes (1-byte discriminator || Borsh OpenArgs) match the Rust spine, going +// through the generated Open.MarshalWithEncoder path the client actually uses. +func TestOpenInstructionDataParity(t *testing.T) { + open := pc.NewOpenInstructionBuilder().SetOpenArgs(pc.OpenArgs{ + Salt: 1, + Deposit: 1_000_000, + GracePeriod: 900, + Recipients: []pc.DistributionEntry{ + {Recipient: solana.PublicKey{}, Bps: 10000}, + }, + }) + + var buf bytes.Buffer + if err := open.MarshalWithEncoder(bin.NewBorshEncoder(&buf)); err != nil { + t.Fatalf("Open.MarshalWithEncoder: %v", err) + } + + // Frozen from `[1u8] ++ borsh::to_vec(&OpenArgs{...})` against the Rust spine. + const wantOpenIx = "01010000000000000040420f0000000000840300000100000000000000000000000000000000000000000000000000000000000000000000001027" + if want := mustHex(t, wantOpenIx); !bytes.Equal(buf.Bytes(), want) { + t.Fatalf("Open instruction data mismatch\n got: %s\nwant: %s", hex.EncodeToString(buf.Bytes()), wantOpenIx) + } + if buf.Bytes()[0] != 0x01 { + t.Fatalf("first byte = %#02x, want 0x01 discriminator", buf.Bytes()[0]) + } +} + +// TestVoucherPreimageParity asserts the 48-byte voucher preimage exposed by the +// generated VoucherArgs type matches the Rust spine layout +// channel_id(32) || cumulative_amount_le(8) || expires_at_le(8). This is the +// load-bearing off-chain Ed25519 signing preimage for the session phase. +func TestVoucherPreimageParity(t *testing.T) { + voucher := pc.VoucherArgs{ + ChannelId: solana.PublicKey{}, // all-zero channel id + CumulativeAmount: 1234567, + ExpiresAt: 4102444800, // DEFAULT_SESSION_EXPIRES_AT (2100-01-01) + } + + // Frozen from `borsh::to_vec(&VoucherArgs{...})` against the Rust spine. + const wantVoucher = "000000000000000000000000000000000000000000000000000000000000000087d6120000000000005786f400000000" + got := borshEncode(t, voucher) + if len(got) != 48 { + t.Fatalf("voucher preimage = %d bytes, want 48", len(got)) + } + if want := mustHex(t, wantVoucher); !bytes.Equal(got, want) { + t.Fatalf("voucher preimage mismatch\n got: %s\nwant: %s", hex.EncodeToString(got), wantVoucher) + } + + // Pin the field offsets: cumulative_amount little-endian at byte 32, + // expires_at little-endian at byte 40. + if v := bin.LE.Uint64(got[32:40]); v != 1234567 { + t.Fatalf("cumulative_amount@32 = %d, want 1234567", v) + } + if v := int64(bin.LE.Uint64(got[40:48])); v != 4102444800 { + t.Fatalf("expires_at@40 = %d, want 4102444800", v) + } +} diff --git a/harness/go-client/main.go b/harness/go-client/main.go index 301931cd4..be2dee4e0 100644 --- a/harness/go-client/main.go +++ b/harness/go-client/main.go @@ -33,36 +33,65 @@ import ( const fixtureSettlementHeader = "x-fixture-settlement" type adapterResult struct { - Type string `json:"type"` - Implementation string `json:"implementation"` - Role string `json:"role"` - OK bool `json:"ok"` - Status int `json:"status"` + // Type is the harness message discriminator; always "result" here. + Type string `json:"type"` + // Implementation identifies the SDK under test; always "go" here. + Implementation string `json:"implementation"` + // Role is the side this adapter exercises; always "client" here. + Role string `json:"role"` + // OK reports whether the paid request ended with a 2xx status. + OK bool `json:"ok"` + // Status is the final HTTP status code of the paid request. + Status int `json:"status"` + // ResponseHeaders holds the final response headers, names lower-cased + // and multi-value headers joined with ", ". ResponseHeaders map[string]string `json:"responseHeaders"` - ResponseBody any `json:"responseBody"` - Settlement string `json:"settlement,omitempty"` + // ResponseBody is the final response body, JSON-decoded when it parses, + // otherwise the raw string. + ResponseBody any `json:"responseBody"` + // Settlement echoes the x-fixture-settlement header the fixture server + // sets with its settlement outcome; omitted when absent. + Settlement string `json:"settlement,omitempty"` } func main() { - if os.Getenv("X402_HARNESS_TARGET_URL") != "" { + switch resolveProtocolMode(os.Getenv) { + case "x402": if err := runX402Adapter(os.Stdout); err != nil { fmt.Fprintf(os.Stderr, "FAIL: %v\n", err) os.Exit(1) } - return - } - if os.Getenv("MPP_HARNESS_TARGET_URL") != "" { + case "mpp": if err := runProcessAdapter(os.Stdout); err != nil { fmt.Fprintf(os.Stderr, "FAIL: %v\n", err) os.Exit(1) } - return + default: + runLegacyHarness() + } +} + +// resolveProtocolMode picks the adapter protocol. The harness matrix injects +// BOTH MPP_HARNESS_TARGET_URL and X402_HARNESS_TARGET_URL on every client +// run, so the namespace probe alone is ambiguous: the explicit +// PAY_KIT_HARNESS_PROTOCOL hint set per scenario wins first. The probe order +// is only reached on manual runs that export a single TARGET_URL. +func resolveProtocolMode(getenv func(string) string) string { + if mode := strings.ToLower(strings.TrimSpace(getenv("PAY_KIT_HARNESS_PROTOCOL"))); mode != "" { + return mode + } + switch { + case getenv("X402_HARNESS_TARGET_URL") != "": + return "x402" + case getenv("MPP_HARNESS_TARGET_URL") != "": + return "mpp" + default: + return "" } - runLegacyHarness() } // runX402Adapter drives the x402 (exact) client against the target. It -// mirrors the Rust x402 harness_client contract: read the offer from the +// follows the x402 harness client contract: read the offer from the // 402 challenge, select by preferred network + currency order, build and // submit the Payment-Signature credential, then report the JSON result. func runX402Adapter(stdout io.Writer) error { diff --git a/harness/go-client/main_test.go b/harness/go-client/main_test.go index 07bd51a19..29a6dd858 100644 --- a/harness/go-client/main_test.go +++ b/harness/go-client/main_test.go @@ -83,3 +83,62 @@ func TestRunProcessAdapterRequiresRPCURL(t *testing.T) { t.Fatal("expected missing RPC URL to fail") } } + +// TestResolveProtocolMode pins the adapter dispatch: the harness matrix sets +// both TARGET_URL namespaces on every client run, so the explicit +// PAY_KIT_HARNESS_PROTOCOL hint must win over the namespace probe. Without +// the hint taking precedence, MPP cells run the x402 adapter and every +// positive charge scenario dies on the unanswered MPP challenge. +func TestResolveProtocolMode(t *testing.T) { + cases := []struct { + name string + env map[string]string + want string + }{ + { + name: "hint mpp wins over both target urls", + env: map[string]string{ + "PAY_KIT_HARNESS_PROTOCOL": "mpp", + "MPP_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + "X402_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + }, + want: "mpp", + }, + { + name: "hint x402 wins over both target urls", + env: map[string]string{ + "PAY_KIT_HARNESS_PROTOCOL": "x402", + "MPP_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + "X402_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + }, + want: "x402", + }, + { + name: "no hint probes x402 namespace first", + env: map[string]string{ + "X402_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + }, + want: "x402", + }, + { + name: "no hint falls back to mpp namespace", + env: map[string]string{ + "MPP_HARNESS_TARGET_URL": "http://127.0.0.1/protected", + }, + want: "mpp", + }, + { + name: "no env selects the legacy harness", + env: map[string]string{}, + want: "", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := resolveProtocolMode(func(key string) string { return tc.env[key] }) + if got != tc.want { + t.Fatalf("resolveProtocolMode = %q, want %q", got, tc.want) + } + }) + } +} diff --git a/harness/runners/go.json b/harness/runners/go.json index 822e306a2..6fb39fa98 100644 --- a/harness/runners/go.json +++ b/harness/runners/go.json @@ -1,5 +1,6 @@ { "language": "go", "command": ["go", "run", "./cmd/conformance"], - "cwd": "go" + "cwd": "go", + "intents": ["charge", "x402-exact", "session"] } diff --git a/harness/src/conformance/contract-schema.ts b/harness/src/conformance/contract-schema.ts index 9637a5e67..799c4da26 100644 --- a/harness/src/conformance/contract-schema.ts +++ b/harness/src/conformance/contract-schema.ts @@ -131,7 +131,7 @@ export const conformanceVectorSchema = { required: ["id", "intent", "mode", "input", "expect"], properties: { id: { type: "string" }, - intent: { enum: ["charge", "x402-exact"] }, + intent: { enum: ["charge", "x402-exact", "session"] }, mode: { enum: ["build-transaction", "verify-transaction", "canonical-bytes"], }, diff --git a/harness/src/conformance/runners.ts b/harness/src/conformance/runners.ts index ae73bbd10..27a7ce3ba 100644 --- a/harness/src/conformance/runners.ts +++ b/harness/src/conformance/runners.ts @@ -22,8 +22,18 @@ export type RunnerManifest = { command: string[]; // Working directory relative to the repo root. Defaults to the repo root. cwd?: string; + // Intents this runner can exercise. When omitted, the runner is assumed to + // support the original cross-SDK intents ("charge", "x402-exact"); a vector + // whose intent is not listed is skipped for this runner rather than failed. + // This lets a new intent (e.g. "session") land with only the SDKs that + // implement it, without editing every other language's runner. + intents?: string[]; }; +// The intents every runner is assumed to support when its manifest does not +// declare an explicit `intents` list. +const DEFAULT_INTENTS = ["charge", "x402-exact"]; + const here = dirname(fileURLToPath(import.meta.url)); const repoRoot = join(here, "..", "..", ".."); const manifestsDir = join(here, "..", "..", "runners"); @@ -35,6 +45,12 @@ function isRunnerManifest(value: unknown): value is RunnerManifest { if (!Array.isArray(m.command) || m.command.length === 0) return false; if (!m.command.every((c) => typeof c === "string")) return false; if (m.cwd !== undefined && typeof m.cwd !== "string") return false; + if ( + m.intents !== undefined && + (!Array.isArray(m.intents) || !m.intents.every((i) => typeof i === "string")) + ) { + return false; + } return true; } @@ -43,6 +59,8 @@ export type DiscoveredRunner = { command: string[]; // Absolute working directory the driver spawns the runner in. cwd: string; + // Resolved intent capabilities (manifest `intents` or the default set). + intents: string[]; }; // Discover every runner manifest under harness/runners/, validate it, and @@ -64,6 +82,7 @@ export function discoverRunners(): DiscoveredRunner[] { language: parsed.language, command: parsed.command, cwd: parsed.cwd ? join(repoRoot, parsed.cwd) : repoRoot, + intents: parsed.intents ?? DEFAULT_INTENTS, }); } return runners; diff --git a/harness/src/conformance/schema.ts b/harness/src/conformance/schema.ts index bb4544ab4..20f9e6d0a 100644 --- a/harness/src/conformance/schema.ts +++ b/harness/src/conformance/schema.ts @@ -236,6 +236,17 @@ export type VectorInput = { opaque?: string; }; + // canonical-bytes (session): the 48-byte Ed25519 voucher preimage + // `channelId(32, base58) || cumulativeAmount LE u64 || expiresAt LE i64`. + // The runner emits it as exactBytes (hex/bytes/base64Url). This pins the + // single most load-bearing session invariant byte-for-byte across SDKs. + // Mirrors the program voucher_message_bytes layout. + voucherPreimage?: { + channelId: string; + cumulativeAmount: string; + expiresAt: number; + }; + // ── x402-exact inputs ──────────────────────────────────────────────── // build-transaction (x402): the offer the client selects + wraps into a // payment header. The runner emits the decoded X402EnvelopeShape. @@ -280,7 +291,7 @@ export type VectorInput = { export type ConformanceVector = { id: string; - intent: "charge" | "x402-exact"; + intent: "charge" | "x402-exact" | "session"; mode: VectorMode; description?: string; input: VectorInput; diff --git a/harness/test/conformance.test.ts b/harness/test/conformance.test.ts index 22d01b435..aa5e86f09 100644 --- a/harness/test/conformance.test.ts +++ b/harness/test/conformance.test.ts @@ -272,10 +272,18 @@ describe("cross-SDK conformance vectors", () => { expect(modes.has("canonical-bytes")).toBe(true); }); - for (const { language, command, cwd: runnerCwd } of RUNNERS) { + for (const { language, command, cwd: runnerCwd, intents } of RUNNERS) { describe(`${language} reference runner`, () => { for (const vector of vectors) { it(`${vector.id} (${vector.mode}) -> ${vector.expect.outcome}`, async (ctx) => { + // Skip vectors for an intent this runner does not declare. Lets a new + // intent (e.g. "session") land with only the SDKs that implement it; + // runners without an explicit `intents` list default to the original + // cross-SDK set ("charge", "x402-exact"). + if (!intents.includes(vector.intent)) { + ctx.skip(); + return; + } const result = await runVector(command, vector, runnerCwd); expect(result.id).toBe(vector.id); diff --git a/harness/vectors/session-voucher.json b/harness/vectors/session-voucher.json new file mode 100644 index 000000000..ee97e9685 --- /dev/null +++ b/harness/vectors/session-voucher.json @@ -0,0 +1,41 @@ +[ + { + "id": "session-voucher-preimage-frozen", + "intent": "session", + "mode": "canonical-bytes", + "description": "48-byte Ed25519 voucher preimage channelId(32)||cumulative LE u64||expiresAt LE i64. Frozen cross-SDK vector (matches the rust/Go/Python unit vectors). Pins the single most load-bearing session invariant byte-for-byte.", + "input": { + "voucherPreimage": { + "channelId": "cGfHiC6Kgg3FpFZvgwGcswsCRtp4aBP2fzuXRQPizuN", + "cumulativeAmount": "42", + "expiresAt": 1234 + } + }, + "expect": { + "outcome": "accept", + "exactBytes": { + "bytes": [9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,42,0,0,0,0,0,0,0,210,4,0,0,0,0,0,0], + "base64Url": "CQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkJCQkqAAAAAAAAANIEAAAAAAAA" + } + } + }, + { + "id": "session-voucher-preimage-large-cumulative", + "intent": "session", + "mode": "canonical-bytes", + "description": "Voucher preimage with a near-u64-max cumulative, asserting little-endian u64 packing has no precision loss.", + "input": { + "voucherPreimage": { + "channelId": "cGfHiC6Kgg3FpFZvgwGcswsCRtp4aBP2fzuXRQPizuN", + "cumulativeAmount": "18446744073709551607", + "expiresAt": 4102444800 + } + }, + "expect": { + "outcome": "accept", + "exactBytes": { + "bytes": [9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,247,255,255,255,255,255,255,255,0,87,134,244,0,0,0,0] + } + } + } +] diff --git a/skills/pay-sdk-implementation/codegen/generate-payment-channels-client-go.ts b/skills/pay-sdk-implementation/codegen/generate-payment-channels-client-go.ts new file mode 100644 index 000000000..3e42f89d4 --- /dev/null +++ b/skills/pay-sdk-implementation/codegen/generate-payment-channels-client-go.ts @@ -0,0 +1,57 @@ +/** + * Generate the pay-kit payment-channels Go client from the upstream + * `Moonsong-Labs/solana-payment-channels` Codama IDL. + * + * Mirrors generate-payment-channels-client.ts (the Rust path) — both scripts + * vendor the IDL at `/idl/payment-channels.json` and render a client + * into the matching SDK tree. This one targets the Go SDK via + * `@codama/renderers-go`, which emits a flat Go package using + * github.com/gagliardetto/{solana-go,binary} (already pay-kit Go deps). + * + * The renderer derives the Go package name from the IDL program name + * (`paymentChannels` → `payment_channels`). We render into a directory named + * `paymentchannels/` to keep a clean, import-friendly path that mirrors the + * rust `crates/programs/payment-channels/generated` layout. + * + * Output: + * go/protocols/programs/paymentchannels/ (rendered by Codama) + */ +import type { AnchorIdl } from '@codama/nodes-from-anchor'; +import { renderVisitor as renderGoVisitor } from '@codama/renderers-go'; +import { createFromJson } from 'codama'; +import fs from 'node:fs'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const __dirname = path.dirname(fileURLToPath(import.meta.url)); +// Script lives at skills/pay-sdk-implementation/codegen/ — climb three +// levels to land at the repository root. +const repoRoot = path.resolve(__dirname, '..', '..', '..'); + +const idlPath = path.join(repoRoot, 'idl', 'payment-channels.json'); +const goClientDir = path.join(repoRoot, 'go', 'protocols', 'programs', 'paymentchannels'); + +if (!fs.existsSync(idlPath)) { + console.error(`[codegen] IDL not found at ${idlPath}`); + console.error(`[codegen] Run \`just payment-channels-pull-idl\` first to fetch it from upstream.`); + process.exit(1); +} + +const idl = JSON.parse(fs.readFileSync(idlPath, 'utf-8')) as AnchorIdl; +const codama = createFromJson(JSON.stringify(idl)); + +console.log(`[codegen] Rendering Go client from ${path.relative(repoRoot, idlPath)}`); +console.log(`[codegen] → ${path.relative(repoRoot, goClientDir)}/`); + +void codama.accept( + renderGoVisitor(goClientDir, { + // Codama re-renders into the target folder on every run; pre-clearing + // means a removed instruction in the upstream IDL also disappears here + // on regeneration. + deleteFolderBeforeRendering: true, + // gofmt the emitted Go so `gofmt -l` stays clean on the generated tree. + formatCode: true, + }), +); + +console.log(`[codegen] Done.`); diff --git a/skills/pay-sdk-implementation/codegen/package.json b/skills/pay-sdk-implementation/codegen/package.json index d80787217..12ba96ecc 100644 --- a/skills/pay-sdk-implementation/codegen/package.json +++ b/skills/pay-sdk-implementation/codegen/package.json @@ -5,10 +5,12 @@ "description": "Codama codegen tooling for pay-kit. Pulls IDL files from upstream Solana programs and generates per-language clients into the appropriate SDK trees.", "scripts": { "subscriptions:rust": "tsx ./generate-subscriptions-client.ts", - "payment-channels:rust": "tsx ./generate-payment-channels-client.ts" + "payment-channels:rust": "tsx ./generate-payment-channels-client.ts", + "payment-channels:go": "tsx ./generate-payment-channels-client-go.ts" }, "dependencies": { "@codama/nodes-from-anchor": "^1.4.1", + "@codama/renderers-go": "^2.0.0", "@codama/renderers-rust": "^3.1.0", "codama": "^1.6.0" }, diff --git a/skills/pay-sdk-implementation/codegen/pnpm-lock.yaml b/skills/pay-sdk-implementation/codegen/pnpm-lock.yaml index 614854c0d..3de1d8dbd 100644 --- a/skills/pay-sdk-implementation/codegen/pnpm-lock.yaml +++ b/skills/pay-sdk-implementation/codegen/pnpm-lock.yaml @@ -11,6 +11,9 @@ importers: '@codama/nodes-from-anchor': specifier: ^1.4.1 version: 1.5.0(typescript@5.9.3) + '@codama/renderers-go': + specifier: ^2.0.0 + version: 2.0.0(typescript@5.9.3) '@codama/renderers-rust': specifier: ^3.1.0 version: 3.1.0(typescript@5.9.3) @@ -50,6 +53,10 @@ packages: '@codama/renderers-core@1.3.8': resolution: {integrity: sha512-xy9Qb5BLYTi1OyvlRhRD7n0HUevOQ3QcHSPq9N3kqoUOgL2ziXPXvoejzzLC0OkvA16M7WvK3ihNx/nf4UEClQ==} + '@codama/renderers-go@2.0.0': + resolution: {integrity: sha512-RL/S2uLogQoa8uceassQhQseapOU+Cv46rI/OqBySV+6ncORboBbViYSA4/YXHKvFnslVwMv94DmDbTA/6dzRg==} + engines: {node: '>=20.18.0'} + '@codama/renderers-rust@3.1.0': resolution: {integrity: sha512-E/GSUCuiIpFj+ij3NbduH/h3sNDo39Bq14vj2atxdbbrPmu4clWvIEjXtbmP03qhudH73TxbYO8dWg/NwRi18A==} engines: {node: '>=20.18.0'} @@ -534,6 +541,19 @@ snapshots: '@codama/nodes': 1.7.0 '@codama/visitors-core': 1.7.0 + '@codama/renderers-go@2.0.0(typescript@5.9.3)': + dependencies: + '@codama/errors': 1.7.0 + '@codama/nodes': 1.7.0 + '@codama/renderers-core': 1.3.8 + '@codama/visitors-core': 1.7.0 + '@solana/codecs-strings': 6.9.0(typescript@5.9.3) + nunjucks: 3.2.4 + transitivePeerDependencies: + - chokidar + - fastestsmallesttextencoderdecoder + - typescript + '@codama/renderers-rust@3.1.0(typescript@5.9.3)': dependencies: '@codama/errors': 1.7.0