diff --git a/routing/pathfind.go b/routing/pathfind.go index e2ae5d4caa3..9d90e19bed1 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -50,11 +50,45 @@ const ( fakeHopHintCapacity = btcutil.Amount(10 * btcutil.SatoshiPerBitcoin) ) -// pathFinder defines the interface of a path finding algorithm. +// RouteOrigin determines where routes can originate from. The backward +// Dijkstra terminates when it reaches any origin vertex. This is the +// source-end counterpart to AdditionalEdge, which extends the graph at the +// destination end. Standard lnd uses singleOrigin (one source node). A +// multi-backend payment service can provide a multi-source implementation +// that terminates at any of its gateway nodes. +// +// NOTE: Only include vertices the caller can actually dispatch payments from. +// Circular self-payments (route-to-self) are only supported with the built-in +// singleOrigin. +type RouteOrigin interface { + // IsOrigin reports whether the given vertex is a valid route starting + // point. + // + // NOTE: Implementations should be O(1). findPath calls IsOrigin once + // per heap pop and once per edge relaxation, so any per-call cost + // directly contributes to path-finding latency. + IsOrigin(v route.Vertex) bool +} + +// singleOrigin is the default RouteOrigin: a single source vertex. +type singleOrigin struct { + source route.Vertex +} + +// IsOrigin reports whether v is the source vertex. +func (s *singleOrigin) IsOrigin(v route.Vertex) bool { + return v == s.source +} + +// pathFinder defines the interface of a path finding algorithm. The first +// return value is the source vertex of the computed path. This is typically +// the node's own key, but it may be an arbitrary source or, for multi-origin +// callers, whichever origin provides the cheapest path. type pathFinder = func(g *graphParams, r *RestrictParams, - cfg *PathFindingConfig, self, source, target route.Vertex, - amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) ( - []*unifiedEdge, float64, error) + cfg *PathFindingConfig, self route.Vertex, origin RouteOrigin, + target route.Vertex, amt lnwire.MilliSatoshi, timePref float64, + finalHtlcExpiry int32) ( + route.Vertex, []*unifiedEdge, float64, error) var ( // DefaultEstimator is the default estimator used for computing @@ -601,9 +635,9 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, // path and accurately check the amount to forward at every node against the // available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, - self, source, target route.Vertex, amt lnwire.MilliSatoshi, - timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64, - error) { + self route.Vertex, origin RouteOrigin, target route.Vertex, + amt lnwire.MilliSatoshi, timePref float64, + finalHtlcExpiry int32) (route.Vertex, []*unifiedEdge, float64, error) { // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to @@ -626,7 +660,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, context.TODO(), target, ) if err != nil { - return nil, 0, err + return route.Vertex{}, nil, 0, err } } @@ -635,14 +669,14 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, err := feature.ValidateRequired(features) if err != nil { log.Warnf("Pathfinding destination node features: %v", err) - return nil, 0, errUnknownRequiredFeature + return route.Vertex{}, nil, 0, errUnknownRequiredFeature } // Ensure that all transitive dependencies are set. err = feature.ValidateDeps(features) if err != nil { log.Warnf("Pathfinding destination node features: %v", err) - return nil, 0, errMissingDependentFeature + return route.Vertex{}, nil, 0, errMissingDependentFeature } // Now that we know the feature vector is well-formed, we'll proceed in @@ -652,7 +686,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, if r.PaymentAddr.IsSome() && !features.HasFeature(lnwire.PaymentAddrOptional) { - return nil, 0, errNoPaymentAddr + return route.Vertex{}, nil, 0, errNoPaymentAddr } // Set up outgoing channel map for quicker access. @@ -665,13 +699,15 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } // If we are routing from ourselves, check that we have enough local - // balance available. - if source == self { + // balance available. This check is skipped when self is not in the + // origin set (e.g. multi-origin), since local balance information is + // not available for remote origin nodes. + if origin.IsOrigin(self) { max, total, err := getOutgoingBalance( self, outgoingChanMap, g.bandwidthHints, g.graph, ) if err != nil { - return nil, 0, err + return route.Vertex{}, nil, 0, err } // If the total outgoing balance isn't sufficient, it will be @@ -681,13 +717,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, "htlc of amount: %v, only have local "+ "balance: %v", amt, total) - return nil, 0, errInsufficientBalance + return route.Vertex{}, nil, 0, errInsufficientBalance } // If there is only not enough capacity on a single route, it // may still be possible to complete the payment by splitting. if max < amt { - return nil, 0, errNoPathFound + return route.Vertex{}, nil, 0, errNoPathFound } } @@ -729,7 +765,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // and depends on whether the destination is blinded or not. lastHopPayloadSize, err := lastHopPayloadSize(r, finalHtlcExpiry, amt) if err != nil { - return nil, 0, err + return route.Vertex{}, nil, 0, err } // We can't always assume that the end destination is publicly @@ -763,8 +799,9 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Validate time preference value. if math.Abs(timePref) > 1 { - return nil, 0, fmt.Errorf("time preference %v out of range "+ - "[-1, 1]", timePref) + return route.Vertex{}, nil, 0, fmt.Errorf( + "time preference %v out of range [-1, 1]", timePref, + ) } // Scale to avoid the extremes -1 and 1 which run into infinity issues. @@ -857,7 +894,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, outboundFee int64 ) - if fromVertex != source { + if !origin.IsOrigin(fromVertex) { outboundFee = int64( edge.policy.ComputeFee(amountToSend), ) @@ -956,7 +993,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // little inaccuracy here because we are over estimating by // 1 hop. var payloadSize uint64 - if fromVertex != source { + if !origin.IsOrigin(fromVertex) { // In case the unifiedEdge does not have a payload size // function supplied we request a graceful shutdown // because this should never happen. @@ -1051,7 +1088,13 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, return fromFeatures, nil } - routeToSelf := source == target + // Allow circular routes only for single-origin self-payments + // (e.g., rebalancing). This lets Dijkstra explore past the target + // on first visit rather than terminating immediately. For + // multi-origin, the target may happen to be in the origin set + // but we still want a direct route from another origin. + _, isSingle := origin.(*singleOrigin) + routeToSelf := isSingle && origin.IsOrigin(target) for { nodesVisited++ @@ -1066,7 +1109,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, err := u.addGraphPolicies(g.graph) if err != nil { - return nil, 0, err + return route.Vertex{}, nil, 0, err } // We add hop hints that were supplied externally. @@ -1127,7 +1170,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Get feature vector for fromNode. fromFeatures, err := getGraphFeatures(fromNode) if err != nil { - return nil, 0, err + return route.Vertex{}, nil, 0, err } // If there are no valid features, skip this node. @@ -1148,14 +1191,21 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // from the heap. partialPath = heap.Pop(&nodeHeap).(*nodeWithDist) - // If we've reached our source (or we don't have any incoming - // edges), then we're done here and can exit the graph - // traversal early. - if partialPath.node == source { + // If we've reached a valid origin (or we don't have any + // incoming edges), then we're done here and can exit the + // graph traversal early. + if origin.IsOrigin(partialPath.node) { break } } + // The path finding loop exits either when it reaches a valid origin or + // when the heap empties. In the latter case, no path exists. + source := partialPath.node + if !origin.IsOrigin(source) { + return route.Vertex{}, nil, 0, errNoPathFound + } + // Use the distance map to unravel the forward path from source to // target. var pathEdges []*unifiedEdge @@ -1166,7 +1216,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, if !ok { // If the node doesn't have a next hop it means we // didn't find a path. - return nil, 0, errNoPathFound + return route.Vertex{}, nil, 0, errNoPathFound } // Add the next hop to the list of path edges. @@ -1200,7 +1250,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, distance[source].probability, len(pathEdges), distance[source].netAmountReceived-amt) - return pathEdges, distance[source].probability, nil + return source, pathEdges, distance[source].probability, nil } // blindedPathRestrictions are a set of constraints to adhere to when diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 85689ef9ed1..ae244f8f72a 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -889,6 +889,12 @@ func TestPathFinding(t *testing.T) { }, { name: "route to self", fn: runRouteToSelf, + }, { + name: "multi origin", + fn: runMultiOrigin, + }, { + name: "multi origin cheapest path", + fn: runMultiOriginCheapestPath, }, { name: "with metadata", fn: runFindPathWithMetadata, @@ -3073,6 +3079,352 @@ func runRouteToSelf(t *testing.T, useCache bool) { ctx.assertPath(path, []uint64{1, 3, 2}) } +// multiOrigin is a RouteOrigin that terminates at any vertex in the set. This +// is the multi-source variant for external payment controllers that dispatch +// from multiple gateway nodes. +type multiOrigin struct { + sources map[route.Vertex]struct{} +} + +func (m *multiOrigin) IsOrigin(v route.Vertex) bool { + _, ok := m.sources[v] + return ok +} + +// findPathWithOrigin is a test helper that runs findPath with a given +// RouteOrigin and returns the settled source vertex alongside the path. +func findPathWithOrigin(t *testing.T, ctx *pathFindingTestContext, + origin RouteOrigin, target route.Vertex, + amt lnwire.MilliSatoshi) (route.Vertex, []*unifiedEdge, error) { + + t.Helper() + + sourceNode, err := ctx.v1Graph.SourceNode(t.Context()) + require.NoError(t, err) + + var ( + source route.Vertex + path []*unifiedEdge + ) + err = ctx.v1Graph.GraphSession( + t.Context(), + func(graph graphdb.NodeTraverser) error { + source, path, _, err = findPath( + &graphParams{ + bandwidthHints: ctx.bandwidthHints, + graph: graph, + }, + &ctx.restrictParams, + &ctx.pathFindingConfig, + sourceNode.PubKeyBytes, + origin, target, + amt, 0, 0, + ) + + return err + }, func() { + path = nil + }, + ) + + return source, path, err +} + +// runMultiOrigin tests that the pathfinder correctly terminates at the nearest +// origin when given a RouteOrigin containing multiple valid source vertices. +// This exercises the multi-source Dijkstra behavior needed by an external +// payment controller that dispatches from multiple gateway nodes. +func runMultiOrigin(t *testing.T, useCache bool) { + // Build a diamond-shaped network with two possible origins: + // + // gw1 ---- alice ---- dest + // gw2 ---- bob ------/ + // + // Both gw1 and gw2 are valid origins. Since origins are fee-exempt + // (the sender doesn't pay its own forwarding fee), we differentiate + // the paths by the intermediate hop's fee: alice charges 500 msat + // while bob charges 2000 msat. + testChannels := []*testChannel{ + symmetricTestChannel("gw1", "alice", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 500, + }, 1, + ), + symmetricTestChannel("gw2", "bob", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 500, + }, 2, + ), + // alice->dest is cheap (500 msat). + symmetricTestChannel("alice", "dest", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 500, + }, 3, + ), + // bob->dest is expensive (2000 msat). + symmetricTestChannel("bob", "dest", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 2000, + }, 4, + ), + } + + ctx := newPathFindingTestContext(t, useCache, testChannels, "gw1") + + gw1 := ctx.keyFromAlias("gw1") + gw2 := ctx.keyFromAlias("gw2") + target := ctx.keyFromAlias("dest") + paymentAmt := lnwire.NewMSatFromSatoshis(100) + + // With both gateways available, the pathfinder should select + // gw1->alice->dest since alice charges less than bob. + bothOrigins := &multiOrigin{sources: map[route.Vertex]struct{}{ + gw1: {}, + gw2: {}, + }} + source, path, err := findPathWithOrigin( + t, ctx, bothOrigins, target, paymentAmt, + ) + require.NoError(t, err, "unable to find multi-origin path") + require.Equal(t, gw1, source, "expected gw1 as selected route source") + assertExpectedPath( + t, ctx.testGraphInstance.aliasMap, path, "alice", "dest", + ) + + // Simulate gw1 going offline by removing it from the origin set. + // The pathfinder should fall back to gw2->bob->dest. + gw2Only := &multiOrigin{sources: map[route.Vertex]struct{}{ + gw2: {}, + }} + source, path, err = findPathWithOrigin( + t, ctx, gw2Only, target, paymentAmt, + ) + require.NoError(t, err, "unable to find path via gw2") + require.Equal(t, gw2, source, "expected gw2 as selected route source") + assertExpectedPath( + t, ctx.testGraphInstance.aliasMap, path, "bob", "dest", + ) + + // An empty origin set should return errNoPathFound, since the + // path-finding loop will exhaust the heap without reaching any origin. + emptyOrigin := &multiOrigin{sources: map[route.Vertex]struct{}{}} + _, _, err = findPathWithOrigin( + t, ctx, emptyOrigin, target, paymentAmt, + ) + require.ErrorIs(t, err, errNoPathFound) + + // When the target is also an origin (circular payment scenario), the + // pathfinder should still find a valid route. + targetIsOrigin := &multiOrigin{sources: map[route.Vertex]struct{}{ + gw1: {}, + target: {}, + }} + source, path, err = findPathWithOrigin( + t, ctx, targetIsOrigin, target, paymentAmt, + ) + require.NoError(t, err, "unable to find path when target is origin") + require.Equal(t, gw1, source, "expected gw1 as selected route source") + assertExpectedPath( + t, ctx.testGraphInstance.aliasMap, path, "alice", "dest", + ) +} + +// runMultiOriginCheapestPath proves that multi-origin termination finds the +// globally cheapest path by cross-validating against exhaustive single-origin +// searches. This addresses the question: "is it enough to halt when you've +// found one of the origin nodes?" The answer is yes, because Dijkstra's +// min-heap ordering guarantees the first origin popped has the minimum cost. +// +// The test builds a network with three gateways at varying distances and fee +// levels, runs findPath once with all three as a multi-origin set, then runs +// findPath separately for each gateway as a single origin. The multi-origin +// result must match the cheapest individual result. +func runMultiOriginCheapestPath(t *testing.T, useCache bool) { + // Build a network where the cheapest origin is NOT the one with the + // fewest intermediate hops: + // + // gw1 ---- cheap1 ---- cheap2 ---- dest (2 intermediaries) + // gw2 ---- expensive ---- dest (1 intermediary) + // gw3 ---- medium1 ---- medium2 ---- dest (2 intermediaries) + // + // gw2's path has the fewest hops, but gw1's path is cheapest in fees. + // Origins are fee-exempt (the sender doesn't pay its own forwarding + // fee), so cost differentiation comes entirely from the intermediate + // hops. The cheap1/cheap2 edges each charge 100 msat base fee, the + // expensive edge charges 5000, and the medium edges charge 800 each. + testChannels := []*testChannel{ + // gw1's path: cheap intermediaries. + symmetricTestChannel("gw1", "cheap1", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 100, + }, 1, + ), + symmetricTestChannel("cheap1", "cheap2", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 100, + }, 2, + ), + symmetricTestChannel("cheap2", "dest", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 100, + }, 3, + ), + // gw2's path: fewest hops but expensive intermediary. + symmetricTestChannel("gw2", "expensive", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 5000, + }, 4, + ), + symmetricTestChannel("expensive", "dest", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 5000, + }, 5, + ), + // gw3's path: medium-cost intermediaries. + symmetricTestChannel("gw3", "medium1", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 800, + }, 6, + ), + symmetricTestChannel("medium1", "medium2", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 800, + }, 7, + ), + symmetricTestChannel("medium2", "dest", 100000, + &testChannelPolicy{ + Expiry: 144, + FeeBaseMsat: 800, + }, 8, + ), + } + + ctx := newPathFindingTestContext(t, useCache, testChannels, "gw1") + + gw1 := ctx.keyFromAlias("gw1") + gw2 := ctx.keyFromAlias("gw2") + gw3 := ctx.keyFromAlias("gw3") + target := ctx.keyFromAlias("dest") + paymentAmt := lnwire.NewMSatFromSatoshis(1000) + + const ( + startingHeight = 100 + finalHopCLTV = 1 + ) + + // buildRoute converts a findPath result into a route so we can + // compare TotalFees across origins. + buildRoute := func(source route.Vertex, + path []*unifiedEdge) *route.Route { + + r, err := newRoute( + source, path, startingHeight, + finalHopParams{ + amt: paymentAmt, + cltvDelta: finalHopCLTV, + records: nil, + }, nil, + ) + require.NoError(t, err) + + return r + } + + // Run multi-origin search with all three gateways. + allOrigins := &multiOrigin{sources: map[route.Vertex]struct{}{ + gw1: {}, + gw2: {}, + gw3: {}, + }} + multiSource, multiPath, err := findPathWithOrigin( + t, ctx, allOrigins, target, paymentAmt, + ) + require.NoError(t, err, "multi-origin findPath failed") + multiRoute := buildRoute(multiSource, multiPath) + + // Run single-origin search for each gateway independently + // and record the total fees for each. + type singleResult struct { + name string + vertex route.Vertex + fees lnwire.MilliSatoshi + } + gateways := []struct { + name string + vertex route.Vertex + }{ + {"gw1", gw1}, + {"gw2", gw2}, + {"gw3", gw3}, + } + + var results []singleResult + for _, gw := range gateways { + origin := &singleOrigin{source: gw.vertex} + src, path, err := findPathWithOrigin( + t, ctx, origin, target, paymentAmt, + ) + require.NoError(t, err, "single-origin findPath failed for %s", + gw.name) + + r := buildRoute(src, path) + results = append(results, singleResult{ + name: gw.name, + vertex: gw.vertex, + fees: r.TotalFees(), + }) + } + + // Find the cheapest single-origin result. + cheapest := results[0] + for _, r := range results[1:] { + if r.fees < cheapest.fees { + cheapest = r + } + } + + // Assert multi-origin picked the same gateway as the cheapest + // individual search. This is the core property: early termination in + // multi-origin Dijkstra finds the globally optimal origin. + require.Equal(t, cheapest.vertex, multiSource, + "multi-origin selected %s but cheapest individual origin is "+ + "%s (fees: multi=%v, gw1=%v, gw2=%v, gw3=%v)", + multiSource, cheapest.name, multiRoute.TotalFees(), + results[0].fees, results[1].fees, results[2].fees, + ) + + // Assert the total fees match exactly. + require.Equal(t, cheapest.fees, multiRoute.TotalFees(), + "multi-origin fees should equal cheapest single-origin fees", + ) + + // Sanity check the expected winner and path. + require.Equal(t, gw1, multiSource, + "expected gw1 as cheapest origin") + assertExpectedPath( + t, ctx.testGraphInstance.aliasMap, multiPath, + "cheap1", "cheap2", "dest", + ) + + // Log the fee comparison for visibility. + t.Logf("Multi-origin fees: %v (via %s)", multiRoute.TotalFees(), + cheapest.name) + for _, r := range results { + t.Logf(" %s single-origin fees: %v", r.name, r.fees) + } +} + // runInboundFees tests whether correct routes are built when inbound fees // apply. func runInboundFees(t *testing.T, useCache bool) { @@ -3327,13 +3679,14 @@ func dbFindPath(graph *graphdb.VersionedGraph, var route []*unifiedEdge err = graph.GraphSession(ctx, func(graph graphdb.NodeTraverser) error { - route, _, err = findPath( + _, route, _, err = findPath( &graphParams{ additionalEdges: additionalEdges, bandwidthHints: bandwidthHints, graph: graph, }, - r, cfg, sourceNode.PubKeyBytes, source, target, amt, + r, cfg, sourceNode.PubKeyBytes, + &singleOrigin{source}, target, amt, timePref, finalHtlcExpiry, ) diff --git a/routing/payment_session.go b/routing/payment_session.go index 4cddfa2eaca..d07c0bd3e76 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -6,6 +6,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btclog/v2" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" @@ -195,6 +196,35 @@ type paymentSession struct { // log is a payment session-specific logger. log btclog.Logger + + // opts holds optional configuration for the payment session. + opts sessionOptions +} + +// sessionOptions holds optional configuration for a payment session. +type sessionOptions struct { + // origin is an optional RouteOrigin that determines where routes can + // start. When set, the pathfinder terminates at any vertex for which + // IsOrigin returns true. + origin fn.Option[RouteOrigin] +} + +// defaultSessionOptions returns sessionOptions with default values. +func defaultSessionOptions() sessionOptions { + return sessionOptions{} +} + +// sessionOption is a functional option for configuring a payment session. +type sessionOption func(*sessionOptions) + +// withOrigin sets the RouteOrigin for this payment session. +func withOrigin(o RouteOrigin) sessionOption { + return func(opts *sessionOptions) { + if o == nil { + return + } + opts.origin = fn.Some(o) + } } // newPaymentSession instantiates a new payment session. @@ -202,7 +232,8 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, getBandwidthHints func(Graph) (bandwidthHints, error), graphSessFactory GraphSessionFactory, missionControl MissionControlQuerier, - pathFindingConfig PathFindingConfig) (*paymentSession, error) { + pathFindingConfig PathFindingConfig, + options ...sessionOption) (*paymentSession, error) { edges, err := RouteHintsToEdges(p.RouteHints, p.Target) if err != nil { @@ -223,6 +254,11 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, logPrefix := fmt.Sprintf("PaymentSession(%x):", p.Identifier()) + opts := defaultSessionOptions() + for _, o := range options { + o(&opts) + } + return &paymentSession{ selfNode: selfNode, additionalEdges: edges, @@ -234,6 +270,7 @@ func newPaymentSession(p *LightningPayment, selfNode route.Vertex, missionControl: missionControl, minShardAmt: DefaultShardMinAmt, log: log.WithPrefix(logPrefix), + opts: opts, }, nil } @@ -308,7 +345,10 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, maxAmt = *p.payment.MaxShardAmt } - var path []*unifiedEdge + var ( + sourceVertex route.Vertex + path []*unifiedEdge + ) findPath := func(graph graphdb.NodeTraverser) error { // We'll also obtain a set of bandwidthHints from the lower // layer for each of our outbound channels. This will allow the @@ -323,15 +363,21 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, p.log.Debugf("pathfinding for amt=%v", maxAmt) + // Use the configured origin if one was provided, + // otherwise default to the session's own node. + origin := p.opts.origin.UnwrapOr( + &singleOrigin{p.selfNode}, + ) + // Find a route for the current amount. - path, _, err = p.pathFinder( + sourceVertex, path, _, err = p.pathFinder( &graphParams{ additionalEdges: p.additionalEdges, bandwidthHints: bandwidthHints, graph: graph, }, restrictions, &p.pathFindingConfig, - p.selfNode, p.selfNode, p.payment.Target, + p.selfNode, origin, p.payment.Target, maxAmt, p.payment.TimePref, finalHtlcExpiry, ) if err != nil { @@ -347,6 +393,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, err := p.graphSessFactory.GraphSession( context.TODO(), findPath, func() { + sourceVertex = route.Vertex{} path = nil }, ) @@ -440,7 +487,7 @@ func (p *paymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi, // this into a route by applying the time-lock and fee // requirements. route, err := newRoute( - p.selfNode, path, height, + sourceVertex, path, height, finalHopParams{ amt: maxAmt, totalAmt: p.payment.Amount, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index 15820059d1b..c7eec543b49 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -44,6 +44,11 @@ type SessionSource struct { // PathFindingConfig defines global parameters that control the // trade-off in path finding between fees and probability. PathFindingConfig PathFindingConfig + + // Origin is an optional RouteOrigin that determines where routes can + // start. When set, the pathfinder terminates at any vertex for which + // IsOrigin returns true. When unset, routes originate from SourceNode. + Origin fn.Option[RouteOrigin] } // NewPaymentSession creates a new payment session backed by the latest prune @@ -62,9 +67,15 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment, ) } + var options []sessionOption + m.Origin.WhenSome(func(o RouteOrigin) { + options = append(options, withOrigin(o)) + }) + session, err := newPaymentSession( p, m.SourceNode.PubKeyBytes, getBandwidthHints, m.GraphSessionFactory, m.MissionControl, m.PathFindingConfig, + options..., ) if err != nil { return nil, err diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 0bc0b6dcbd6..ead48a7dff1 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -209,9 +209,9 @@ func TestRequestRoute(t *testing.T) { // Override pathfinder with a mock. session.pathFinder = func(_ *graphParams, r *RestrictParams, - _ *PathFindingConfig, _, _, _ route.Vertex, - _ lnwire.MilliSatoshi, _ float64, _ int32) ([]*unifiedEdge, - float64, error) { + _ *PathFindingConfig, self route.Vertex, _ RouteOrigin, + _ route.Vertex, _ lnwire.MilliSatoshi, _ float64, + _ int32) (route.Vertex, []*unifiedEdge, float64, error) { // We expect find path to receive a cltv limit excluding the // final cltv delta (including the block padding). @@ -232,7 +232,7 @@ func TestRequestRoute(t *testing.T) { }, } - return path, 1.0, nil + return self, path, 1.0, nil } route, err := session.RequestRoute( @@ -253,6 +253,97 @@ func TestRequestRoute(t *testing.T) { } } +// TestRequestRouteWithOrigin verifies that a custom RouteOrigin provided via +// the withOrigin option is forwarded to the pathfinder and that the source +// vertex it returns is used as the route's SourcePubKey. +func TestRequestRouteWithOrigin(t *testing.T) { + t.Parallel() + + // selfNode is the local coordinator node. It is deliberately different + // from the gateway vertex that the custom origin will select, so we + // can verify the returned route uses the gateway, not self. + selfNode := route.Vertex{0xaa} + gatewayNode := route.Vertex{0xbb} + + payment := &LightningPayment{ + Amount: 1000, + FeeLimit: 1000, + } + + var paymentHash [32]byte + err := payment.SetPaymentHash(paymentHash) + require.NoError(t, err, "unable to set payment hash") + + // Build a trivial one-hop path for the mock to return. + mockPath := []*unifiedEdge{ + { + policy: &models.CachedEdgePolicy{ + ToNodePubKey: func() route.Vertex { + return route.Vertex{0xcc} + }, + ToNodeFeatures: lnwire.NewFeatureVector( + nil, nil, + ), + }, + }, + } + + // Create a custom origin that only accepts gatewayNode. + customOrigin := &singleOrigin{source: gatewayNode} + + // Create a session with the withOrigin option. + session, err := newPaymentSession( + payment, selfNode, + func(Graph) (bandwidthHints, error) { + return &mockBandwidthHints{}, nil + }, + &sessionGraph{}, + &MissionControl{}, + PathFindingConfig{}, + withOrigin(customOrigin), + ) + require.NoError(t, err, "unable to create payment session") + + // Override the pathfinder with a mock that asserts the origin it + // receives is the custom one (not the default singleOrigin{selfNode}) + // and returns gatewayNode as the settled source. + session.pathFinder = func(_ *graphParams, _ *RestrictParams, + _ *PathFindingConfig, self route.Vertex, origin RouteOrigin, + _ route.Vertex, _ lnwire.MilliSatoshi, _ float64, + _ int32) (route.Vertex, []*unifiedEdge, float64, error) { + + // The self parameter should still be the session's own node. + require.Equal(t, selfNode, self, "self should be selfNode") + + // The origin must accept gatewayNode and reject selfNode. + require.True( + t, origin.IsOrigin(gatewayNode), + "origin should accept gatewayNode", + ) + require.False( + t, origin.IsOrigin(selfNode), + "origin should reject selfNode", + ) + + return gatewayNode, mockPath, 1.0, nil + } + + rt, err := session.RequestRoute( + payment.Amount, payment.FeeLimit, 0, 10, + lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType + 123: []byte{1, 2, 3}, + }, + ) + require.NoError(t, err) + + // The route's SourcePubKey must be the gateway returned by the + // pathfinder, not the session's selfNode. + require.Equal( + t, gatewayNode, rt.SourcePubKey, + "SourcePubKey should be the gateway, not selfNode", + ) +} + type sessionGraph struct { Graph } diff --git a/routing/router.go b/routing/router.go index 0f9288fde12..5de3759d165 100644 --- a/routing/router.go +++ b/routing/router.go @@ -477,6 +477,11 @@ type RouteRequest struct { // parameters used to reach a target node blinded paths. This field is // mutually exclusive with the Target field. BlindedPathSet *BlindedPaymentPathSet + + // Origin is an optional RouteOrigin that determines where the route + // can start. When set, it overrides Source for path-finding + // termination. When nil, a singleOrigin wrapping Source is used. + Origin RouteOrigin } // RouteHints is an alias type for a set of route hints, with the source node @@ -600,15 +605,20 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, return nil, 0, errors.New("time preference out of range") } - path, probability, err := findPath( + origin := RouteOrigin(&singleOrigin{req.Source}) + if req.Origin != nil { + origin = req.Origin + } + + source, path, probability, err := findPath( &graphParams{ additionalEdges: req.RouteHints, bandwidthHints: bandwidthHints, graph: r.cfg.RoutingGraph, }, req.Restrictions, &r.cfg.PathFindingConfig, - r.cfg.SelfNode, req.Source, req.Target, req.Amount, - req.TimePreference, finalHtlcExpiry, + r.cfg.SelfNode, origin, req.Target, + req.Amount, req.TimePreference, finalHtlcExpiry, ) if err != nil { return nil, 0, err @@ -616,7 +626,7 @@ func (r *ChannelRouter) FindRoute(req *RouteRequest) (*route.Route, float64, // Create the route with absolute time lock values. route, err := newRoute( - req.Source, path, uint32(currentHeight), + source, path, uint32(currentHeight), finalHopParams{ amt: req.Amount, totalAmt: req.Amount,