diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 4a9a51914e1..8fa460a852c 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -6,6 +6,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/netann" @@ -115,7 +116,10 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // First, we'll query for all the set of channels that have an // update that falls within the specified horizon. chansInHorizon := c.graph.ChanUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), graphdb.ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) for channel, err := range chansInHorizon { @@ -181,7 +185,10 @@ func (c *ChanSeries) UpdatesInHorizon(chain chainhash.Hash, // update within the horizon as well. We send these second to // ensure that they follow any active channels they have. nodeAnnsInHorizon := c.graph.NodeUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), graphdb.NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, graphdb.WithIterPublicNodesOnly(), ) for nodeAnn, err := range nodeAnnsInHorizon { diff --git a/discovery/gossiper.go b/discovery/gossiper.go index ac3a6d861a2..1d17dc3213e 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -2187,9 +2187,12 @@ func (d *AuthenticatedGossiper) addNode(ctx context.Context, err) } - return d.cfg.Graph.AddNode( - ctx, models.NodeFromWireAnnouncement(msg), op..., - ) + node, err := models.NodeFromWireAnnouncement(msg) + if err != nil { + return err + } + + return d.cfg.Graph.AddNode(ctx, node, op...) 
} // isPremature decides whether a given network message has a block height+delta diff --git a/docs/release-notes/release-notes-0.21.0.md b/docs/release-notes/release-notes-0.21.0.md index c9d458d1857..1449db1454c 100644 --- a/docs/release-notes/release-notes-0.21.0.md +++ b/docs/release-notes/release-notes-0.21.0.md @@ -242,6 +242,17 @@ [4](https://github.com/lightningnetwork/lnd/pull/10542), [5](https://github.com/lightningnetwork/lnd/pull/10572), [6](https://github.com/lightningnetwork/lnd/pull/10582). +* Make the [graph `Store` interface + cross-version](https://github.com/lightningnetwork/lnd/pull/10656) so that + query methods (`ForEachNode`, `ForEachChannel`, `NodeUpdatesInHorizon`, + `ChanUpdatesInHorizon`, `FilterKnownChanIDs`) work across gossip v1 and v2. + Add `PreferHighest` fetch helpers and `GetVersions` queries so callers can + retrieve channels without knowing which gossip version announced them. +* Add [v2 model and store + support](https://github.com/lightningnetwork/lnd/pull/10657) to the graph + database: wire conversion helpers for node announcements, channel auth + proofs, edge info, and edge policies; `VersionedGraph` zombie wrappers; + SQL queries for v2 node traversal; and v3-only onion address filtering. 
* Updated waiting proof persistence for gossip upgrades by introducing typed waiting proof keys and payloads, with a DB migration to rewrite legacy waiting proof records to the new key/value format diff --git a/graph/builder.go b/graph/builder.go index 614c1102fc8..31d2199e6a1 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn/v2" graphdb "github.com/lightningnetwork/lnd/graph/db" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/lnutils" @@ -648,7 +649,11 @@ func (b *Builder) pruneZombieChans() error { startTime := time.Unix(0, 0) endTime := time.Now().Add(-1 * chanExpiry) oldEdgesIter := b.cfg.Graph.ChanUpdatesInHorizon( - context.TODO(), startTime, endTime, + context.TODO(), lnwire.GossipVersion1, + graphdb.ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) for u, err := range oldEdgesIter { diff --git a/graph/db/benchmark_test.go b/graph/db/benchmark_test.go index 19d6a134a1f..062d4f52391 100644 --- a/graph/db/benchmark_test.go +++ b/graph/db/benchmark_test.go @@ -372,7 +372,7 @@ func TestPopulateDBs(t *testing.T) { numPolicies = 0 ) err := graph.ForEachChannel( - ctx, lnwire.GossipVersion1, + ctx, func(info *models.ChannelEdgeInfo, policy, policy2 *models.ChannelEdgePolicy) error { @@ -500,7 +500,7 @@ func syncGraph(t *testing.T, src, dest *ChannelGraph) { } var wgChans sync.WaitGroup - err = src.ForEachChannel(ctx, lnwire.GossipVersion1, + err = src.ForEachChannel(ctx, func(info *models.ChannelEdgeInfo, policy1, policy2 *models.ChannelEdgePolicy) error { @@ -624,7 +624,7 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "ForEachNode", fn: func(b testing.TB, store Store) { err := store.ForEachNode( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.Node) error { // Increment the counter to // ensure the 
callback is doing @@ -640,12 +640,11 @@ func BenchmarkGraphReadMethods(b *testing.B) { { name: "ForEachChannel", fn: func(b testing.TB, store Store) { - //nolint:ll - err := store.ForEachChannel( - ctx, lnwire.GossipVersion1, + err := store.ForEachChannel(ctx, func(_ *models.ChannelEdgeInfo, + _, _ *models.ChannelEdgePolicy, - _ *models.ChannelEdgePolicy) error { + ) error { // Increment the counter to // ensure the callback is doing @@ -662,7 +661,13 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "NodeUpdatesInHorizon", fn: func(b testing.TB, store Store) { iter := store.NodeUpdatesInHorizon( - ctx, time.Unix(0, 0), time.Now(), + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some( + time.Unix(0, 0), + ), + EndTime: fn.Some(time.Now()), + }, ) _, err := fn.CollectErr(iter) require.NoError(b, err) @@ -713,7 +718,13 @@ func BenchmarkGraphReadMethods(b *testing.B) { name: "ChanUpdatesInHorizon", fn: func(b testing.TB, store Store) { iter := store.ChanUpdatesInHorizon( - ctx, time.Unix(0, 0), time.Now(), + ctx, lnwire.GossipVersion1, + ChanUpdateRange{ + StartTime: fn.Some( + time.Unix(0, 0), + ), + EndTime: fn.Some(time.Now()), + }, ) _, err := fn.CollectErr(iter) require.NoError(b, err) @@ -817,7 +828,7 @@ func BenchmarkFindOptimalSQLQueryConfig(b *testing.B) { ) err := store.ForEachNode( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.Node) error { numNodes++ @@ -828,7 +839,7 @@ func BenchmarkFindOptimalSQLQueryConfig(b *testing.B) { //nolint:ll err = store.ForEachChannel( - ctx, lnwire.GossipVersion1, + ctx, func(_ *models.ChannelEdgeInfo, _, _ *models.ChannelEdgePolicy) error { diff --git a/graph/db/graph.go b/graph/db/graph.go index 5e74bfcb710..e0abaaeb3fc 100644 --- a/graph/db/graph.go +++ b/graph/db/graph.go @@ -171,8 +171,9 @@ func (c *ChannelGraph) populateCache(ctx context.Context) error { for _, v := range []lnwire.GossipVersion{ gossipV1, gossipV2, } { - // TODO(elle): If we have both v1 and v2 entries for the same - // 
node/channel, prefer v2 when merging. + // We iterate v1 first, then v2. Since AddNodeFeatures and + // AddChannel overwrite on key collision, v2 data naturally + // takes precedence when both versions exist. err := c.db.ForEachNodeCacheable(ctx, v, func(node route.Vertex, features *lnwire.FeatureVector) error { @@ -230,12 +231,37 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(ctx context.Context, return c.graphCache.ForEachChannel(node, cb) } - // TODO(elle): once the no-cache path needs to support - // pathfinding across gossip versions, this should iterate - // across all versions rather than defaulting to v1. - return c.db.ForEachNodeDirectedChannel( - ctx, gossipV1, node, cb, reset, - ) + // Iterate across all gossip versions (highest first) so that + // channels announced via v2 are preferred over v1. Each + // version runs in its own Store transaction. We use a + // per-version reset that clears the dedup map rather than + // passing the caller's reset, because ExecTx fires reset on + // every attempt (including the first) which would clear + // results accumulated from earlier versions. + seen := make(map[uint64]struct{}) + for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + err := c.db.ForEachNodeDirectedChannel( + ctx, v, node, + func(channel *DirectedChannel) error { + if _, ok := seen[channel.ChannelID]; ok { + return nil + } + seen[channel.ChannelID] = struct{}{} + + return cb(channel) + }, + func() { + seen = make(map[uint64]struct{}) + }, + ) + if err != nil && + !errors.Is(err, ErrVersionNotSupportedForKVDB) { + + return err + } + } + + return nil } // FetchNodeFeatures returns the features of the given node. If no features are @@ -251,7 +277,22 @@ func (c *ChannelGraph) FetchNodeFeatures(ctx context.Context, return c.graphCache.GetFeatures(node), nil } - return c.db.FetchNodeFeatures(ctx, lnwire.GossipVersion1, node) + // Try v2 first, fall back to v1 if the v2 features are empty. 
+ for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + features, err := c.db.FetchNodeFeatures(ctx, v, node) + if errors.Is(err, ErrVersionNotSupportedForKVDB) { + continue + } + if err != nil { + return nil, err + } + + if !features.IsEmpty() { + return features, nil + } + } + + return lnwire.EmptyFeatureVector(), nil } // GraphSession will provide the call-back with access to a NodeTraverser @@ -516,56 +557,6 @@ func (c *ChannelGraph) PruneGraphNodes(ctx context.Context) error { return nil } -// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan -// ID's that we don't know and are not known zombies of the passed set. In other -// words, we perform a set difference of our set of chan ID's and the ones -// passed in. This method can be used by callers to determine the set of -// channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(ctx context.Context, - chansInfo []ChannelUpdateInfo, - isZombieChan func(ChannelUpdateInfo) bool) ([]uint64, error) { - - unknown, knownZombies, err := c.db.FilterKnownChanIDs(ctx, chansInfo) - if err != nil { - return nil, err - } - - for _, info := range knownZombies { - // TODO(ziggie): Make sure that for the strict pruning case we - // compare the pubkeys and whether the right timestamp is not - // older than the `ChannelPruneExpiry`. - // - // NOTE: The timestamp data has no verification attached to it - // in the `ReplyChannelRange` msg so we are trusting this data - // at this point. However it is not critical because we are just - // removing the channel from the db when the timestamps are more - // recent. During the querying of the gossip msg verification - // happens as usual. However we should start punishing peers - // when they don't provide us honest data ? 
- if isZombieChan(info) { - continue - } - - // If we have marked it as a zombie but the latest update - // info could bring it back from the dead, then we mark it - // alive, and we let it be added to the set of IDs to query our - // peer for. - err := c.db.MarkEdgeLive( - ctx, info.Version, - info.ShortChannelID.ToUint64(), - ) - // Since there is a chance that the edge could have been marked - // as "live" between the FilterKnownChanIDs call and the - // MarkEdgeLive call, we ignore the error if the edge is already - // marked as live. - if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) { - return nil, err - } - } - - return unknown, nil -} - // MarkEdgeZombie attempts to mark a channel identified by its channel ID as a // zombie for the given gossip version. This method is used on an ad-hoc basis, // when channels need to be marked as zombies outside the normal pruning cycle. @@ -634,12 +625,12 @@ func (c *ChannelGraph) ForEachNodeCacheable(ctx context.Context, } // NodeUpdatesInHorizon returns all known lightning nodes with updates in the -// range. +// range for the given gossip version. func (c *ChannelGraph) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { - return c.db.NodeUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.NodeUpdatesInHorizon(ctx, v, r, opts...) } // HasV1Node determines if the graph has a vertex identified by the target node @@ -650,13 +641,22 @@ func (c *ChannelGraph) HasV1Node(ctx context.Context, return c.db.HasV1Node(ctx, nodePub) } -// ForEachChannel iterates through all channel edges stored within the graph. +// ForEachNode iterates through all nodes in the graph across all gossip +// versions, yielding each unique node exactly once. 
+func (c *ChannelGraph) ForEachNode(ctx context.Context, + cb func(*models.Node) error, reset func()) error { + + return c.db.ForEachNode(ctx, cb, reset) +} + +// ForEachChannel iterates through all channel edges stored within the graph +// across all gossip versions. func (c *ChannelGraph) ForEachChannel(ctx context.Context, - v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { - return c.db.ForEachChannel(ctx, v, cb, reset) + return c.db.ForEachChannel(ctx, cb, reset) } // DisabledChannelIDs returns the channel ids of disabled channels. @@ -696,12 +696,12 @@ func (c *ChannelGraph) HighestChanID(ctx context.Context, } // ChanUpdatesInHorizon returns all known channel edges with updates in the -// horizon. +// range for the given gossip version. func (c *ChannelGraph) ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { - return c.db.ChanUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.ChanUpdatesInHorizon(ctx, v, r, opts...) } // FilterChannelRange returns channel IDs within the passed block height range @@ -734,26 +734,39 @@ func (c *ChannelGraph) FetchChanInfos(ctx context.Context, } // FetchChannelEdgesByOutpoint attempts to lookup directed edges by funding -// outpoint. +// outpoint, returning the highest available gossip version. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(ctx context.Context, op *wire.OutPoint) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { - return c.db.FetchChannelEdgesByOutpoint( - ctx, lnwire.GossipVersion1, op, - ) + return c.db.FetchChannelEdgesByOutpointPreferHighest(ctx, op) } -// FetchChannelEdgesByID attempts to lookup directed edges by channel ID. 
+// FetchChannelEdgesByID attempts to lookup directed edges by channel ID, +// returning the highest available gossip version. func (c *ChannelGraph) FetchChannelEdgesByID(ctx context.Context, chanID uint64) ( *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { - return c.db.FetchChannelEdgesByID( - ctx, lnwire.GossipVersion1, chanID, - ) + return c.db.FetchChannelEdgesByIDPreferHighest(ctx, chanID) +} + +// GetVersionsBySCID returns the list of gossip versions for which a channel +// with the given SCID exists in the database. +func (c *ChannelGraph) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + return c.db.GetVersionsBySCID(ctx, chanID) +} + +// GetVersionsByOutpoint returns the list of gossip versions for which a channel +// with the given funding outpoint exists in the database. +func (c *ChannelGraph) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + return c.db.GetVersionsByOutpoint(ctx, op) } // PutClosedScid stores a SCID for a closed channel in the database. @@ -848,7 +861,7 @@ func (c *VersionedGraph) ForEachNodeCached(ctx context.Context, func (c *VersionedGraph) ForEachNode(ctx context.Context, cb func(*models.Node) error, reset func()) error { - return c.db.ForEachNode(ctx, c.v, cb, reset) + return c.db.ForEachNode(ctx, cb, reset) } // NumZombies returns the current number of zombie channels in the graph. @@ -856,13 +869,74 @@ func (c *VersionedGraph) NumZombies(ctx context.Context) (uint64, error) { return c.db.NumZombies(ctx, c.v) } -// NodeUpdatesInHorizon returns all known lightning nodes which have an update -// timestamp within the passed range. +// NodeUpdatesInHorizon returns all known lightning nodes which have updates +// within the passed range. 
func (c *VersionedGraph) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { - return c.db.NodeUpdatesInHorizon(ctx, startTime, endTime, opts...) + return c.db.NodeUpdatesInHorizon(ctx, c.v, r, opts...) +} + +// FilterKnownChanIDs takes a set of channel IDs and returns the subset of chan +// ID's that we don't know and are not known zombies of the passed set. In +// other words, we perform a set difference of our set of chan ID's and the ones +// passed in. This method can be used by callers to determine the set of +// channels another peer knows of that we don't. +func (c *VersionedGraph) FilterKnownChanIDs(ctx context.Context, + chansInfo []ChannelUpdateInfo, + isZombieChan func(ChannelUpdateInfo) bool) ([]uint64, error) { + + unknown, knownZombies, err := c.db.FilterKnownChanIDs( + ctx, c.v, chansInfo, + ) + if err != nil { + return nil, err + } + + for _, info := range knownZombies { + // TODO(ziggie): Make sure that for the strict pruning case we + // compare the pubkeys and whether the right timestamp is not + // older than the `ChannelPruneExpiry`. + // + // NOTE: The timestamp data has no verification attached to it + // in the `ReplyChannelRange` msg so we are trusting this data + // at this point. However it is not critical because we are just + // removing the channel from the db when the timestamps are more + // recent. During the querying of the gossip msg verification + // happens as usual. However we should start punishing peers + // when they don't provide us honest data ? + if isZombieChan(info) { + continue + } + + // If we have marked it as a zombie but the latest update + // info could bring it back from the dead, then we mark it + // alive, and we let it be added to the set of IDs to query + // our peer for. 
+ err := c.db.MarkEdgeLive( + ctx, c.v, + info.ShortChannelID.ToUint64(), + ) + // Since there is a chance that the edge could have been + // marked as "live" between the FilterKnownChanIDs call and + // the MarkEdgeLive call, we ignore the error if the edge is + // already marked as live. + if err != nil && !errors.Is(err, ErrZombieEdgeNotFound) { + return nil, err + } + } + + return unknown, nil +} + +// ChanUpdatesInHorizon returns all known channel edges with updates in the +// range. +func (c *VersionedGraph) ChanUpdatesInHorizon(ctx context.Context, + r ChanUpdateRange, + opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { + + return c.db.ChanUpdatesInHorizon(ctx, c.v, r, opts...) } // ChannelView returns the verifiable edge information for each active channel. @@ -985,6 +1059,23 @@ func (c *VersionedGraph) DeleteChannelEdges(ctx context.Context, ) } +// MarkEdgeZombie marks a channel as a zombie for this version. +func (c *VersionedGraph) MarkEdgeZombie(ctx context.Context, chanID uint64, + pubKey1, pubKey2 [33]byte) error { + + return c.ChannelGraph.MarkEdgeZombie( + ctx, c.v, chanID, pubKey1, pubKey2, + ) +} + +// MarkEdgeLive clears an edge from our zombie index for this version, deeming +// it as live. +func (c *VersionedGraph) MarkEdgeLive(ctx context.Context, + chanID uint64) error { + + return c.ChannelGraph.MarkEdgeLive(ctx, c.v, chanID) +} + // HasChannelEdge returns true if the database knows of a channel edge with the // passed channel ID and this graph's gossip version, and false otherwise. 
If it // is not found, then the zombie index is checked and its result is returned as @@ -1017,7 +1108,7 @@ func (c *VersionedGraph) ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { - return c.db.ForEachChannel(ctx, c.v, cb, reset) + return c.db.ForEachChannel(ctx, cb, reset) } // ForEachNodeCacheable iterates through all stored vertices/nodes in the graph. diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 501ff10be41..cb13f215860 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1789,7 +1789,7 @@ func TestGraphTraversal(t *testing.T) { // Iterate through all the known channels within the graph DB, once // again if the map is empty that indicates that all edges have // properly been reached. - err = graph.ForEachChannel(ctx, lnwire.GossipVersion1, + err = graph.ForEachChannel(ctx, func(ei *models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy, _ *models.ChannelEdgePolicy) error { @@ -2138,7 +2138,7 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 err := graph.ForEachChannel( - t.Context(), lnwire.GossipVersion1, + t.Context(), func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error { @@ -2419,7 +2419,10 @@ func TestChanUpdatesInHorizon(t *testing.T) { // If we issue an arbitrary query before any channel updates are // inserted in the database, we should get zero results. 
chanIter := graph.ChanUpdatesInHorizon( - ctx, time.Unix(999, 0), time.Unix(9999, 0), + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(time.Unix(999, 0)), + EndTime: fn.Some(time.Unix(9999, 0)), + }, ) chanUpdates, err := fn.CollectErr(chanIter) @@ -2526,7 +2529,10 @@ func TestChanUpdatesInHorizon(t *testing.T) { } for _, queryCase := range queryCases { respIter := graph.ChanUpdatesInHorizon( - ctx, queryCase.start, queryCase.end, + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(queryCase.start), + EndTime: fn.Some(queryCase.end), + }, ) resp, err := fn.CollectErr(respIter) @@ -2563,7 +2569,10 @@ func TestNodeUpdatesInHorizon(t *testing.T) { // If we issue an arbitrary query before we insert any nodes into the // database, then we shouldn't get any results back. nodeUpdatesIter := graph.NodeUpdatesInHorizon( - ctx, time.Unix(999, 0), time.Unix(9999, 0), + ctx, lnwire.GossipVersion1, NodeUpdateRange{ + StartTime: fn.Some(time.Unix(999, 0)), + EndTime: fn.Some(time.Unix(9999, 0)), + }, ) nodeUpdates, err := fn.CollectErr(nodeUpdatesIter) require.NoError(t, err, "unable to query for node updates") @@ -2638,7 +2647,10 @@ func TestNodeUpdatesInHorizon(t *testing.T) { } for _, queryCase := range queryCases { iter := graph.NodeUpdatesInHorizon( - ctx, queryCase.start, queryCase.end, + ctx, lnwire.GossipVersion1, NodeUpdateRange{ + StartTime: fn.Some(queryCase.start), + EndTime: fn.Some(queryCase.end), + }, ) resp, err := fn.CollectErr(iter) @@ -2766,7 +2778,11 @@ func testNodeUpdatesWithBatchSize(t *testing.T, ctx context.Context, for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { iter := testGraph.NodeUpdatesInHorizon( - ctx, tc.start, tc.end, + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some(tc.start), + EndTime: fn.Some(tc.end), + }, WithNodeUpdateIterBatchSize( batchSize, ), @@ -2839,7 +2855,13 @@ func TestNodeUpdatesInHorizonEarlyTermination(t *testing.T) { for _, stopAt := range 
terminationPoints { t.Run(fmt.Sprintf("StopAt%d", stopAt), func(t *testing.T) { iter := graph.NodeUpdatesInHorizon( - ctx, startTime, startTime.Add(200*time.Hour), + ctx, lnwire.GossipVersion1, + NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some( + startTime.Add(200 * time.Hour), + ), + }, WithNodeUpdateIterBatchSize(10), ) @@ -2928,7 +2950,12 @@ func TestChanUpdatesInHorizonBoundaryConditions(t *testing.T) { // Now we'll run the main query, and verify that we get // back the expected number of channels. iter := graph.ChanUpdatesInHorizon( - ctx, startTime, startTime.Add(26*time.Hour), + ctx, lnwire.GossipVersion1, ChanUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some( + startTime.Add(26 * time.Hour), + ), + }, WithChanUpdateIterBatchSize(batchSize), ) @@ -2954,7 +2981,9 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph( + MakeTestGraph(t), lnwire.GossipVersion1, + ) var ( scid1 = lnwire.ShortChannelID{BlockHeight: 1} @@ -2962,9 +2991,8 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { scid3 = lnwire.ShortChannelID{BlockHeight: 3} ) - v1Graph := NewVersionedGraph(graph, lnwire.GossipVersion1) isZombie := func(scid lnwire.ShortChannelID) bool { - zombie, _, _, err := v1Graph.IsZombieEdge(ctx, scid.ToUint64()) + zombie, _, _, err := graph.IsZombieEdge(ctx, scid.ToUint64()) require.NoError(t, err) return zombie @@ -2972,13 +3000,11 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { // Mark channel 1 and 2 as zombies. 
err := graph.MarkEdgeZombie( - ctx, lnwire.GossipVersion1, scid1.ToUint64(), - [33]byte{}, [33]byte{}, + ctx, scid1.ToUint64(), [33]byte{}, [33]byte{}, ) require.NoError(t, err) err = graph.MarkEdgeZombie( - ctx, lnwire.GossipVersion1, scid2.ToUint64(), - [33]byte{}, [33]byte{}, + ctx, scid2.ToUint64(), [33]byte{}, [33]byte{}, ) require.NoError(t, err) @@ -2988,13 +3014,15 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { // Call FilterKnownChanIDs with an isStillZombie call-back that would // result in the current zombies still be considered as zombies. - _, err = graph.FilterKnownChanIDs(ctx, []ChannelUpdateInfo{ - {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, - {ShortChannelID: scid2, Version: lnwire.GossipVersion1}, - {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, - }, func(_ ChannelUpdateInfo) bool { - return true - }) + _, err = graph.FilterKnownChanIDs( + ctx, []ChannelUpdateInfo{ + {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, + {ShortChannelID: scid2, Version: lnwire.GossipVersion1}, + {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, + }, func(_ ChannelUpdateInfo) bool { + return true + }, + ) require.NoError(t, err) require.True(t, isZombie(scid1)) @@ -3004,17 +3032,19 @@ func TestFilterKnownChanIDsZombieRevival(t *testing.T) { // Now call it again but this time with a isStillZombie call-back that // would result in channel with SCID 2 no longer being considered a // zombie. 
- _, err = graph.FilterKnownChanIDs(ctx, []ChannelUpdateInfo{ - {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, - { - ShortChannelID: scid2, - Version: lnwire.GossipVersion1, - Node1Freshness: lnwire.UnixTimestamp(1000), + _, err = graph.FilterKnownChanIDs( + ctx, []ChannelUpdateInfo{ + {ShortChannelID: scid1, Version: lnwire.GossipVersion1}, + { + ShortChannelID: scid2, + Version: lnwire.GossipVersion1, + Node1Freshness: lnwire.UnixTimestamp(1000), + }, + {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, + }, func(info ChannelUpdateInfo) bool { + return info.Node1Freshness != lnwire.UnixTimestamp(1000) }, - {ShortChannelID: scid3, Version: lnwire.GossipVersion1}, - }, func(info ChannelUpdateInfo) bool { - return info.Node1Freshness != lnwire.UnixTimestamp(1000) - }) + ) require.NoError(t, err) // Show that SCID 2 has been marked as live. @@ -3030,7 +3060,9 @@ func TestFilterKnownChanIDs(t *testing.T) { t.Parallel() ctx := t.Context() - graph := MakeTestGraph(t) + graph := NewVersionedGraph( + MakeTestGraph(t), lnwire.GossipVersion1, + ) isZombieUpdate := func(_ ChannelUpdateInfo) bool { return true @@ -3090,8 +3122,7 @@ func TestFilterKnownChanIDs(t *testing.T) { ) require.NoError(t, graph.AddChannelEdge(ctx, channel)) err := graph.DeleteChannelEdges( - ctx, lnwire.GossipVersion1, false, true, - channel.ChannelID, + ctx, false, true, channel.ChannelID, ) require.NoError(t, err) @@ -3356,7 +3387,10 @@ func TestStressTestChannelGraphAPI(t *testing.T) { chanIDs = append(chanIDs, info) } - _, err := graph.FilterKnownChanIDs( + vg := NewVersionedGraph( + graph, lnwire.GossipVersion1, + ) + _, err := vg.FilterKnownChanIDs( ctx, chanIDs, func(_ ChannelUpdateInfo) bool { return rand.Intn(2) == 0 @@ -3406,8 +3440,17 @@ func TestStressTestChannelGraphAPI(t *testing.T) { name: "ChanUpdateInHorizon", fn: func() error { iter := graph.ChanUpdatesInHorizon( - ctx, time.Now().Add(-time.Hour), - time.Now(), + ctx, lnwire.GossipVersion1, + ChanUpdateRange{ + 
StartTime: fn.Some( + time.Now().Add( + -time.Hour, + ), + ), + EndTime: fn.Some( + time.Now(), + ), + }, ) _, err := fn.CollectErr(iter) @@ -4203,7 +4246,10 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { startTime := time.Unix(9, 0) endTime := node1.LastUpdate.Add(time.Minute) nodesInHorizonIter := graph.NodeUpdatesInHorizon( - ctx, startTime, endTime, + ctx, NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) // We should only have a single node, and that node should exactly @@ -4221,7 +4267,10 @@ func TestNodePruningUpdateIndexDeletion(t *testing.T) { // Now that the node has been deleted, we'll again query the nodes in // the horizon. This time we should have no nodes at all. nodesInHorizonIter = graph.NodeUpdatesInHorizon( - ctx, startTime, endTime, + ctx, NodeUpdateRange{ + StartTime: fn.Some(startTime), + EndTime: fn.Some(endTime), + }, ) nodesInHorizon, err = fn.CollectErr(nodesInHorizonIter) require.NoError(t, err, "unable to fetch nodes in horizon") @@ -5222,7 +5271,8 @@ func TestLightningNodePersistence(t *testing.T) { require.True(t, ok) // Convert the wire message to our internal node representation. - node := models.NodeFromWireAnnouncement(na) + node, err := models.NodeFromWireAnnouncement(na) + require.NoError(t, err) // Persist the node to disk. err = graph.AddNode(ctx, node) @@ -5243,3 +5293,484 @@ func TestLightningNodePersistence(t *testing.T) { require.Equal(t, nodeAnnBytes, b.Bytes()) } + +// TestValidateForVersion verifies that ChanUpdateRange and NodeUpdateRange +// reject invalid field combinations for each gossip version. 
+func TestValidateForVersion(t *testing.T) { + t.Parallel() + + now := time.Now() + + tests := []struct { + name string + fn func() error + wantErr string + }{ + { + name: "v1 chan range with time - ok", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + }, + { + name: "v1 chan range with height - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "v1 chan update range must use time", + }, + { + name: "v2 chan range with height - ok", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + }, + { + name: "v2 chan range with time - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "v2 chan update range must use blocks", + }, + { + name: "mixed chan range - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + StartHeight: fn.Some(uint32(1)), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "both time and block", + }, + { + name: "v1 node range with time - ok", + fn: func() error { + r := NodeUpdateRange{ + StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + }, + { + name: "v2 node range with height - ok", + fn: func() error { + r := NodeUpdateRange{ + StartHeight: fn.Some(uint32(1)), + EndHeight: fn.Some(uint32(100)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + }, + { + name: "v2 node range with time - rejected", + fn: func() error { + r := NodeUpdateRange{ + 
StartTime: fn.Some(now), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "v2 node update range must use height", + }, + { + name: "v1 chan range missing bounds - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "missing time bounds", + }, + { + name: "v1 chan range inverted - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartTime: fn.Some(now.Add(time.Hour)), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "start time after end time", + }, + { + name: "v2 chan range inverted - rejected", + fn: func() error { + r := ChanUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(50)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "start height after end height", + }, + { + name: "v1 node range inverted - rejected", + fn: func() error { + r := NodeUpdateRange{ + StartTime: fn.Some(now.Add(time.Hour)), + EndTime: fn.Some(now), + } + + return r.validateForVersion( + lnwire.GossipVersion1, + ) + }, + wantErr: "start time after end time", + }, + { + name: "v2 node range inverted - rejected", + fn: func() error { + r := NodeUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(50)), + } + + return r.validateForVersion( + lnwire.GossipVersion2, + ) + }, + wantErr: "start height after end height", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.fn() + if tc.wantErr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, + tc.wantErr) + } + }) + } +} + +// TestV2HorizonQueries tests that NodeUpdatesInHorizon and +// ChanUpdatesInHorizon work with v2 gossip (block-height ranges). This test +// only runs on SQL backends since KV does not support v2. 
+func TestV2HorizonQueries(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("v2 horizon queries only supported on SQL backends") + } + + ctx := t.Context() + graph := MakeTestGraph(t) + + // Create two v2 nodes with specific block heights. + node1 := createTestVertex(t, lnwire.GossipVersion2) + node1.LastBlockHeight = 100 + require.NoError(t, graph.AddNode(ctx, node1)) + + node2 := createTestVertex(t, lnwire.GossipVersion2) + node2.LastBlockHeight = 200 + require.NoError(t, graph.AddNode(ctx, node2)) + + // Create a third node outside the query range. + node3 := createTestVertex(t, lnwire.GossipVersion2) + node3.LastBlockHeight = 500 + require.NoError(t, graph.AddNode(ctx, node3)) + + // --- NodeUpdatesInHorizon v2 --- + + // Query for nodes in block range [50, 250]. + nodeIter := graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(50)), + EndHeight: fn.Some(uint32(250)), + }, + ) + nodes, err := fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Query for nodes in block range [150, 600] should return node2 and + // node3. + nodeIter = graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(150)), + EndHeight: fn.Some(uint32(600)), + }, + ) + nodes, err = fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Len(t, nodes, 2) + + // Query for nodes in block range [300, 400] should return nothing. + nodeIter = graph.NodeUpdatesInHorizon( + ctx, lnwire.GossipVersion2, NodeUpdateRange{ + StartHeight: fn.Some(uint32(300)), + EndHeight: fn.Some(uint32(400)), + }, + ) + nodes, err = fn.CollectErr(nodeIter) + require.NoError(t, err) + require.Empty(t, nodes) + + // --- ChanUpdatesInHorizon v2 --- + + // Create a v2 channel between node1 and node2. 
+ edgeInfo, _ := createEdge( + lnwire.GossipVersion2, 100, 1, 0, 10, node1, node2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo)) + + // Add v2 policies with specific block heights. + edge1 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion2, + ChannelID: edgeInfo.ChannelID, + LastBlockHeight: 150, + TimeLockDelta: 14, + MinHTLC: 1000, + MaxHTLC: 1000000, + FeeBaseMSat: 1000, + FeeProportionalMillionths: 200, + } + edge2 := &models.ChannelEdgePolicy{ + Version: lnwire.GossipVersion2, + SecondPeer: true, + ChannelID: edgeInfo.ChannelID, + LastBlockHeight: 160, + TimeLockDelta: 14, + MinHTLC: 1000, + MaxHTLC: 1000000, + FeeBaseMSat: 1000, + FeeProportionalMillionths: 200, + } + require.NoError(t, graph.UpdateEdgePolicy(ctx, edge1)) + require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2)) + + // Query for channel updates in block range [100, 200]. + chanIter := graph.ChanUpdatesInHorizon( + ctx, lnwire.GossipVersion2, ChanUpdateRange{ + StartHeight: fn.Some(uint32(100)), + EndHeight: fn.Some(uint32(200)), + }, + ) + channels, err := fn.CollectErr(chanIter) + require.NoError(t, err) + require.Len(t, channels, 1) + require.Equal(t, edgeInfo.ChannelID, channels[0].Info.ChannelID) + + // Query for channel updates in block range [200, 300] should return + // nothing since policies are at heights 150 and 160. + chanIter = graph.ChanUpdatesInHorizon( + ctx, lnwire.GossipVersion2, ChanUpdateRange{ + StartHeight: fn.Some(uint32(200)), + EndHeight: fn.Some(uint32(300)), + }, + ) + channels, err = fn.CollectErr(chanIter) + require.NoError(t, err) + require.Empty(t, channels) +} + +// TestPreferHighestAndGetVersions tests the four new Store methods: +// FetchChannelEdgesByIDPreferHighest, FetchChannelEdgesByOutpointPreferHighest, +// GetVersionsBySCID, and GetVersionsByOutpoint. 
+func TestPreferHighestAndGetVersions(t *testing.T) { + t.Parallel() + ctx := t.Context() + + graph := MakeTestGraph(t) + store := graph.db + + // Create two nodes that will anchor our test channel. + node1 := createTestVertex(t, lnwire.GossipVersion1) + require.NoError(t, graph.AddNode(ctx, node1)) + + node2 := createTestVertex(t, lnwire.GossipVersion1) + require.NoError(t, graph.AddNode(ctx, node2)) + + // Create and add a v1 channel edge. + edgeInfo, scid := createEdge( + lnwire.GossipVersion1, 100, 1, 0, 1, node1, node2, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edgeInfo)) + + chanID := scid.ToUint64() + op := edgeInfo.ChannelPoint + + // FetchChannelEdgesByIDPreferHighest should return the v1 channel. + info, _, _, err := store.FetchChannelEdgesByIDPreferHighest( + ctx, chanID, + ) + require.NoError(t, err) + require.Equal(t, chanID, info.ChannelID) + + // FetchChannelEdgesByOutpointPreferHighest should also return it. + info, _, _, err = store.FetchChannelEdgesByOutpointPreferHighest( + ctx, &op, + ) + require.NoError(t, err) + require.Equal(t, chanID, info.ChannelID) + + // Querying a non-existent channel should return an error. + _, _, _, err = store.FetchChannelEdgesByIDPreferHighest(ctx, 999999) + require.Error(t, err) + + // GetVersionsBySCID should report v1. + versions, err := store.GetVersionsBySCID(ctx, chanID) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + }, versions) + + // GetVersionsByOutpoint should also report v1. + versions, err = store.GetVersionsByOutpoint(ctx, &op) + require.NoError(t, err) + require.Equal(t, []lnwire.GossipVersion{ + lnwire.GossipVersion1, + }, versions) + + // GetVersions for a non-existent SCID should return empty. 
+ versions, err = store.GetVersionsBySCID(ctx, 999999) + require.NoError(t, err) + require.Empty(t, versions) +} + +// TestPreferHighestNodeTraversal verifies that ChannelGraph's +// ForEachNodeDirectedChannel and FetchNodeFeatures correctly prefer v2 over v1 +// when the graph cache is disabled (exercising the no-cache code paths). +func TestPreferHighestNodeTraversal(t *testing.T) { + t.Parallel() + + if !isSQLDB { + t.Skip("prefer-highest requires SQL backend") + } + + ctx := t.Context() + + // Disable the cache so we exercise the no-cache code paths in + // ChannelGraph.ForEachNodeDirectedChannel and FetchNodeFeatures. + graph := MakeTestGraph(t, WithUseGraphCache(false)) + + // --- FetchNodeFeatures --- + + // Create a v1-only node and verify its features are returned. + privV1, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeV1 := createNode(t, lnwire.GossipVersion1, privV1) + require.NoError(t, graph.AddNode(ctx, nodeV1)) + + features, err := graph.FetchNodeFeatures(ctx, nodeV1.PubKeyBytes) + require.NoError(t, err) + require.False(t, features.IsEmpty(), + "v1-only node should have features") + + // Create a v2-only node and verify its features are returned + // (exercises the v2 fallback). + privV2, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeV2 := createNode(t, lnwire.GossipVersion2, privV2) + require.NoError(t, graph.AddNode(ctx, nodeV2)) + + features, err = graph.FetchNodeFeatures(ctx, nodeV2.PubKeyBytes) + require.NoError(t, err) + require.False(t, features.IsEmpty(), + "v2-only node should have features") + + // Create a node with both v1 and v2 announcements. 
+ privBoth, err := btcec.NewPrivateKey() + require.NoError(t, err) + + nodeBothV1 := createNode(t, lnwire.GossipVersion1, privBoth) + require.NoError(t, graph.AddNode(ctx, nodeBothV1)) + + nodeBothV2 := createNode(t, lnwire.GossipVersion2, privBoth) + require.NoError(t, graph.AddNode(ctx, nodeBothV2)) + + features, err = graph.FetchNodeFeatures( + ctx, nodeBothV1.PubKeyBytes, + ) + require.NoError(t, err) + require.False(t, features.IsEmpty(), + "both-version node should have features") + + // --- ForEachNodeDirectedChannel --- + + // Add a v1 channel between nodeV1 and nodeBothV1. + edge, _ := createEdge( + lnwire.GossipVersion1, 100, 0, 0, 0, + nodeV1, nodeBothV1, + ) + require.NoError(t, graph.AddChannelEdge(ctx, edge)) + + pol := newEdgePolicy( + lnwire.GossipVersion1, edge.ChannelID, 1000, true, + ) + pol.ToNode = nodeBothV1.PubKeyBytes + pol.SigBytes = testSig.Serialize() + require.NoError(t, graph.UpdateEdgePolicy(ctx, pol)) + + // ForEachNodeDirectedChannel should find the channel. + var foundChannels int + err = graph.ForEachNodeDirectedChannel( + ctx, nodeV1.PubKeyBytes, + func(_ *DirectedChannel) error { + foundChannels++ + return nil + }, func() { + foundChannels = 0 + }, + ) + require.NoError(t, err) + require.Equal(t, 1, foundChannels, + "expected 1 channel for v1 node") +} diff --git a/graph/db/interfaces.go b/graph/db/interfaces.go index a725cb8ee46..33af4c1e006 100644 --- a/graph/db/interfaces.go +++ b/graph/db/interfaces.go @@ -29,8 +29,8 @@ type NodeTraverser interface { nodePub route.Vertex) (*lnwire.FeatureVector, error) } -// Store represents the main interface for the channel graph database for all -// channels and nodes gossiped via the V1 gossip protocol as defined in BOLT 7. +// Store represents the main interface for the channel graph database. It +// supports channels and nodes from multiple gossip protocol versions. 
type Store interface { //nolint:interfacebloat // ForEachNodeDirectedChannel calls the callback for every channel of // the given node. @@ -95,11 +95,11 @@ type Store interface { //nolint:interfacebloat chans map[uint64]*DirectedChannel) error, reset func()) error - // ForEachNode iterates through all the stored vertices/nodes in the - // graph, executing the passed callback with each node encountered. If - // the callback returns an error, then the transaction is aborted and - // the iteration stops early. - ForEachNode(ctx context.Context, v lnwire.GossipVersion, + // ForEachNode iterates through all nodes in the graph across all + // gossip versions, yielding each unique node exactly once. The + // callback receives the best available Node (highest advertised + // version preferred, falling back to shell nodes). + ForEachNode(ctx context.Context, cb func(*models.Node) error, reset func()) error // ForEachNodeCacheable iterates through all the stored vertices/nodes @@ -120,11 +120,12 @@ type Store interface { //nolint:interfacebloat DeleteNode(ctx context.Context, v lnwire.GossipVersion, nodePub route.Vertex) error - // NodeUpdatesInHorizon returns all the known lightning node which have - // an update timestamp within the passed range. This method can be used - // by two nodes to quickly determine if they have the same set of up to - // date node announcements. - NodeUpdatesInHorizon(ctx context.Context, startTime, endTime time.Time, + // NodeUpdatesInHorizon returns all the known lightning nodes which have + // updates within the passed range for the given gossip version. This + // method can be used by two nodes to quickly determine if they have the + // same set of up to date node announcements. 
+ NodeUpdatesInHorizon(ctx context.Context, v lnwire.GossipVersion, + r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] // FetchNode attempts to look up a target node by its identity @@ -160,21 +161,16 @@ type Store interface { //nolint:interfacebloat GraphSession(ctx context.Context, cb func(graph NodeTraverser) error, reset func()) error - // ForEachChannel iterates through all the channel edges stored within - // the graph and invokes the passed callback for each edge. The callback - // takes two edges as since this is a directed graph, both the in/out - // edges are visited. If the callback returns an error, then the - // transaction is aborted and the iteration stops early. - // - // NOTE: If an edge can't be found, or wasn't advertised, then a nil - // pointer for that particular channel edge routing policy will be - // passed into the callback. - // - // TODO(elle): add a cross-version iteration API and make this iterate - // over all versions. - ForEachChannel(ctx context.Context, v lnwire.GossipVersion, + // ForEachChannel iterates through all channel edges stored within the + // graph across all gossip versions, yielding each unique channel + // exactly once. The callback receives the edge info and both + // directional policies. When both versions are present, v2 is + // preferred. Nil pointers are passed for policies that haven't been + // advertised. + ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error + *models.ChannelEdgePolicy) error, + reset func()) error // ForEachChannelCacheable iterates through all the channel edges stored // within the graph and invokes the passed callback for each edge. 
The @@ -256,10 +252,10 @@ type Store interface { //nolint:interfacebloat uint64, error) // ChanUpdatesInHorizon returns all the known channel edges which have - // at least one edge that has an update timestamp within the specified - // horizon. - ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + // at least one edge update within the specified range for the given + // gossip version. + ChanUpdatesInHorizon(ctx context.Context, v lnwire.GossipVersion, + r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] // FilterKnownChanIDs takes a set of channel IDs and return the subset @@ -269,9 +265,9 @@ type Store interface { //nolint:interfacebloat // callers to determine the set of channels another peer knows of that // we don't. The ChannelUpdateInfos for the known zombies is also // returned. - FilterKnownChanIDs(ctx context.Context, - chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, - error) + FilterKnownChanIDs(ctx context.Context, v lnwire.GossipVersion, + chansInfo []ChannelUpdateInfo) ([]uint64, + []ChannelUpdateInfo, error) // FilterChannelRange returns the channel ID's of all known channels // which were mined in a block height within the passed range for the @@ -321,6 +317,35 @@ type Store interface { //nolint:interfacebloat *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) + // FetchChannelEdgesByIDPreferHighest behaves like FetchChannelEdgesByID + // but is version-agnostic: if the channel exists under multiple gossip + // versions it returns the record with the highest version number. 
+ FetchChannelEdgesByIDPreferHighest(ctx context.Context, + chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // FetchChannelEdgesByOutpointPreferHighest behaves like + // FetchChannelEdgesByOutpoint but is version-agnostic: if the channel + // exists under multiple gossip versions it returns the record with the + // highest version number. + FetchChannelEdgesByOutpointPreferHighest(ctx context.Context, + op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) + + // GetVersionsBySCID returns the list of gossip versions for which a + // channel with the given SCID exists in the database, ordered + // ascending. + GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) + + // GetVersionsByOutpoint returns the list of gossip versions for which + // a channel with the given funding outpoint exists in the database, + // ordered ascending. + GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) + // ChannelView returns the verifiable edge information for each active // channel within the known channel graph for the given gossip version. // The set of UTXO's (along with their scripts) returned are the ones diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 2a993d73bea..1e2da601341 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -408,13 +408,10 @@ func (c *KVStore) AddrsForNode(ctx context.Context, v lnwire.GossipVersion, // NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer // for that particular channel edge routing policy will be passed into the // callback. 
-func (c *KVStore) ForEachChannel(_ context.Context, v lnwire.GossipVersion, +func (c *KVStore) ForEachChannel(_ context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, - *models.ChannelEdgePolicy) error, reset func()) error { - - if v != lnwire.GossipVersion1 { - return ErrVersionNotSupportedForKVDB - } + *models.ChannelEdgePolicy) error, + reset func()) error { return forEachChannel(c.db, cb, reset) } @@ -836,13 +833,9 @@ func (c *KVStore) DisabledChannelIDs( // early. // // NOTE: this is part of the Store interface. -func (c *KVStore) ForEachNode(_ context.Context, v lnwire.GossipVersion, +func (c *KVStore) ForEachNode(_ context.Context, cb func(*models.Node) error, reset func()) error { - if v != lnwire.GossipVersion1 { - return ErrVersionNotSupportedForKVDB - } - return forEachNode(c.db, func(tx kvdb.RTx, node *models.Node) error { @@ -2384,11 +2377,19 @@ func (c *KVStore) fetchNextChanUpdateBatch( } // ChanUpdatesInHorizon returns all the known channel edges which have at least -// one edge that has an update timestamp within the specified horizon. +// one edge that has an update within the specified range for the given gossip +// version. 
func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { + if v != lnwire.GossipVersion1 { + return chanUpdateRangeErrIter(ErrVersionNotSupportedForKVDB) + } + if err := r.validateForVersion(v); err != nil { + return chanUpdateRangeErrIter(err) + } + cfg := defaultIteratorConfig() for _, opt := range opts { opt(cfg) @@ -2396,7 +2397,9 @@ func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, return func(yield func(ChannelEdge, error) bool) { iterState := newChanUpdatesIterator( - cfg.chanUpdateIterBatchSize, startTime, endTime, + cfg.chanUpdateIterBatchSize, + r.StartTime.UnwrapOr(time.Time{}), + r.EndTime.UnwrapOr(time.Time{}), ) for { @@ -2445,8 +2448,8 @@ func (c *KVStore) ChanUpdatesInHorizon(_ context.Context, float64(iterState.total), iterState.hits, iterState.total) } else { - log.Tracef("ChanUpdatesInHorizon returned no edges "+ - "in horizon (%s, %s)", startTime, endTime) + log.Tracef("ChanUpdatesInHorizon(v%d) returned "+ + "no edges in horizon", v) } } } @@ -2635,9 +2638,9 @@ func (c *KVStore) fetchNextNodeBatch( } // NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. -func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, - endTime time.Time, +// update timestamp within the passed range for the given gossip version. 
+func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { cfg := defaultIteratorConfig() @@ -2646,10 +2649,20 @@ func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, } return func(yield func(*models.Node, error) bool) { + if v != lnwire.GossipVersion1 { + yield(nil, ErrVersionNotSupportedForKVDB) + return + } + if err := r.validateForVersion(v); err != nil { + yield(nil, err) + return + } + // Initialize iterator state. state := newNodeUpdatesIterator( cfg.nodeUpdateIterBatchSize, - startTime, endTime, + r.StartTime.UnwrapOr(time.Time{}), + r.EndTime.UnwrapOr(time.Time{}), cfg.iterPublicNodes, ) @@ -2686,8 +2699,13 @@ func (c *KVStore) NodeUpdatesInHorizon(_ context.Context, startTime, // channels another peer knows of that we don't. The ChannelUpdateInfos for the // known zombies is also returned. func (c *KVStore) FilterKnownChanIDs(_ context.Context, + v lnwire.GossipVersion, chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, error) { + if v != lnwire.GossipVersion1 { + return nil, nil, ErrVersionNotSupportedForKVDB + } + var ( newChanIDs []uint64 knownZombies []ChannelUpdateInfo @@ -3453,7 +3471,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy) ( // or second edge policy is being updated. var fromNode, toNode []byte var isUpdate1 bool - if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + if edge.IsNode1() { fromNode = nodeInfo[:33] toNode = nodeInfo[33:66] isUpdate1 = true @@ -4156,7 +4174,86 @@ func (c *KVStore) FetchChannelEdgesByID(_ context.Context, return edgeInfo, policy1, policy2, nil } -// IsPublicNode is a helper method that determines whether the node with the +// FetchChannelEdgesByIDPreferHighest looks up the channel by SCID. The KV +// store only supports gossip v1, so this simply delegates to the versioned +// fetch. +// +// NOTE: part of the Store interface. 
+func (c *KVStore) FetchChannelEdgesByIDPreferHighest(ctx context.Context, + chanID uint64) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + return c.FetchChannelEdgesByID(ctx, lnwire.GossipVersion1, chanID) +} + +// FetchChannelEdgesByOutpointPreferHighest looks up the channel by funding +// outpoint. The KV store only supports gossip v1, so this simply delegates to +// the versioned fetch. +// +// NOTE: part of the Store interface. +func (c *KVStore) FetchChannelEdgesByOutpointPreferHighest( + ctx context.Context, op *wire.OutPoint) ( + *models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy, error) { + + return c.FetchChannelEdgesByOutpoint( + ctx, lnwire.GossipVersion1, op, + ) +} + +// GetVersionsBySCID returns the gossip versions for which a channel with the +// given SCID exists. The KV store only supports gossip v1, so at most one +// version is returned. +// +// NOTE: part of the Store interface. +func (c *KVStore) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + _, _, _, err := c.FetchChannelEdgesByID( + ctx, lnwire.GossipVersion1, chanID, + ) + switch { + case errors.Is(err, ErrEdgeNotFound): + return nil, nil + + case errors.Is(err, ErrZombieEdge): + return nil, nil + + case err != nil: + return nil, err + + default: + return []lnwire.GossipVersion{lnwire.GossipVersion1}, nil + } +} + +// GetVersionsByOutpoint returns the gossip versions for which a channel with +// the given funding outpoint exists. The KV store only supports gossip v1, so +// at most one version is returned. +// +// NOTE: part of the Store interface. 
+func (c *KVStore) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + _, _, _, err := c.FetchChannelEdgesByOutpoint( + ctx, lnwire.GossipVersion1, op, + ) + switch { + case errors.Is(err, ErrEdgeNotFound): + return nil, nil + + case errors.Is(err, ErrZombieEdge): + return nil, nil + + case err != nil: + return nil, err + + default: + return []lnwire.GossipVersion{lnwire.GossipVersion1}, nil + } +} + // given public key is seen as a public node in the graph from the graph's // source node's point of view. func (c *KVStore) IsPublicNode(_ context.Context, v lnwire.GossipVersion, @@ -5245,7 +5342,7 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy, err = updateEdgePolicyDisabledIndex( edges, edge.ChannelID, - edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, + !edge.IsNode1(), edge.IsDisabled(), ) if err != nil { diff --git a/graph/db/models/channel_auth_proof.go b/graph/db/models/channel_auth_proof.go index e4acc2f8b3b..0e8964495ec 100644 --- a/graph/db/models/channel_auth_proof.go +++ b/graph/db/models/channel_auth_proof.go @@ -1,10 +1,22 @@ package models import ( + "errors" + "fmt" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" ) +var ( + // ErrV2AnnSigProofAssemblyPending is returned when trying to derive a + // channel auth proof from v2 announce signatures. This will be + // supported once v2 proof assembly from announce_signatures_2 halves is + // implemented. + ErrV2AnnSigProofAssemblyPending = errors.New("v2 announce signatures " + + "proof assembly not yet implemented") +) + // ChannelAuthProof is the authentication proof (the signature portion) for a // channel. // @@ -115,3 +127,74 @@ func (c *ChannelAuthProof) BitcoinSig2() []byte { func (c *ChannelAuthProof) Sig() []byte { return c.Signature.UnwrapOr(nil) } + +// ChannelAuthProofFromWireAnnouncement constructs a channel auth proof from a +// wire channel announcement message. 
+func ChannelAuthProofFromWireAnnouncement( + ann lnwire.ChannelAnnouncement) (*ChannelAuthProof, error) { + + switch ann := ann.(type) { + case *lnwire.ChannelAnnouncement1: + return NewV1ChannelAuthProof( + ann.NodeSig1.ToSignatureBytes(), + ann.NodeSig2.ToSignatureBytes(), + ann.BitcoinSig1.ToSignatureBytes(), + ann.BitcoinSig2.ToSignatureBytes(), + ), nil + + case *lnwire.ChannelAnnouncement2: + return NewV2ChannelAuthProof( + ann.Signature.Val.ToSignatureBytes(), + ), nil + + default: + return nil, fmt.Errorf("unsupported channel announcement: %T", + ann) + } +} + +// ChannelAuthProofFromAnnounceSignatures derives a channel auth proof from two +// opposing announce signatures messages. +func ChannelAuthProofFromAnnounceSignatures(ann, oppositeAnn lnwire.AnnounceSignatures, + isFirstNode bool) (*ChannelAuthProof, error) { + + if ann == nil || oppositeAnn == nil { + return nil, fmt.Errorf("announce signatures cannot be nil") + } + + if ann.GossipVersion() != oppositeAnn.GossipVersion() { + return nil, fmt.Errorf("announce signatures version mismatch: %v "+ + "!= %v", ann.GossipVersion(), oppositeAnn.GossipVersion()) + } + + switch annSig := ann.(type) { + case *lnwire.AnnounceSignatures1: + oppSig, ok := oppositeAnn.(*lnwire.AnnounceSignatures1) + if !ok { + return nil, fmt.Errorf("unexpected opposite announce "+ + "signatures type: %T", oppositeAnn) + } + + if isFirstNode { + return NewV1ChannelAuthProof( + annSig.NodeSignature.ToSignatureBytes(), + oppSig.NodeSignature.ToSignatureBytes(), + annSig.BitcoinSignature.ToSignatureBytes(), + oppSig.BitcoinSignature.ToSignatureBytes(), + ), nil + } + + return NewV1ChannelAuthProof( + oppSig.NodeSignature.ToSignatureBytes(), + annSig.NodeSignature.ToSignatureBytes(), + oppSig.BitcoinSignature.ToSignatureBytes(), + annSig.BitcoinSignature.ToSignatureBytes(), + ), nil + + case *lnwire.AnnounceSignatures2: + return nil, ErrV2AnnSigProofAssemblyPending + + default: + return nil, fmt.Errorf("unsupported announce 
signatures: %T", ann) + } +} diff --git a/graph/db/models/channel_auth_proof_test.go b/graph/db/models/channel_auth_proof_test.go new file mode 100644 index 00000000000..dd75383b70b --- /dev/null +++ b/graph/db/models/channel_auth_proof_test.go @@ -0,0 +1,64 @@ +package models + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +func TestChannelAuthProofFromAnnounceSignaturesV1(t *testing.T) { + t.Parallel() + + sig := func(b byte) lnwire.Sig { + raw := bytes.Repeat([]byte{b}, 64) + s, err := lnwire.NewSigFromWireECDSA(raw) + require.NoError(t, err) + return s + } + + ann1 := &lnwire.AnnounceSignatures1{ + NodeSignature: sig(1), + BitcoinSignature: sig(2), + } + ann2 := &lnwire.AnnounceSignatures1{ + NodeSignature: sig(3), + BitcoinSignature: sig(4), + } + + proof, err := ChannelAuthProofFromAnnounceSignatures(ann1, ann2, true) + require.NoError(t, err) + require.Equal(t, ann1.NodeSignature.ToSignatureBytes(), proof.NodeSig1()) + require.Equal(t, ann2.NodeSignature.ToSignatureBytes(), proof.NodeSig2()) + require.Equal(t, ann1.BitcoinSignature.ToSignatureBytes(), + proof.BitcoinSig1()) + require.Equal(t, ann2.BitcoinSignature.ToSignatureBytes(), + proof.BitcoinSig2()) +} + +func TestChannelAuthProofFromAnnounceSignaturesV2Pending(t *testing.T) { + t.Parallel() + + ann := lnwire.NewAnnSigs2( + lnwire.ChannelID{}, lnwire.ShortChannelID{}, lnwire.PartialSig{}, + ) + + proof, err := ChannelAuthProofFromAnnounceSignatures(ann, ann, true) + require.ErrorIs(t, err, ErrV2AnnSigProofAssemblyPending) + require.Nil(t, proof) +} + +func TestChannelAuthProofFromAnnounceSignaturesVersionMismatch(t *testing.T) { + t.Parallel() + + ann1 := &lnwire.AnnounceSignatures1{} + ann2 := lnwire.NewAnnSigs2( + lnwire.ChannelID{}, lnwire.ShortChannelID{}, lnwire.PartialSig{}, + ) + + proof, err := ChannelAuthProofFromAnnounceSignatures(ann1, ann2, true) + require.Error(t, err) + require.Contains(t, err.Error(), "version 
mismatch") + require.Nil(t, proof) +} diff --git a/graph/db/models/channel_edge_info.go b/graph/db/models/channel_edge_info.go index aae7b91f0b7..91c9654dab4 100644 --- a/graph/db/models/channel_edge_info.go +++ b/graph/db/models/channel_edge_info.go @@ -13,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" ) // ChannelEdgeInfo represents a fully authenticated channel along with all its @@ -256,6 +257,59 @@ func NewV2Channel(chanID uint64, chainHash chainhash.Hash, node1, return edge, nil } +// ChannelEdgeInfoFromWireAnnouncement constructs a ChannelEdgeInfo from a wire +// channel announcement message. +func ChannelEdgeInfoFromWireAnnouncement(msg lnwire.ChannelAnnouncement, + proof *ChannelAuthProof) (*ChannelEdgeInfo, error) { + + switch msg := msg.(type) { + case *lnwire.ChannelAnnouncement1: + return NewV1Channel( + msg.ShortChannelID.ToUint64(), msg.ChainHash, + msg.NodeID1, msg.NodeID2, &ChannelV1Fields{ + BitcoinKey1Bytes: msg.BitcoinKey1, + BitcoinKey2Bytes: msg.BitcoinKey2, + ExtraOpaqueData: msg.ExtraOpaqueData, + }, + WithChanProof(proof), WithFeatures(msg.Features), + ) + + case *lnwire.ChannelAnnouncement2: + bitcoinKey1 := fn.MapOption( + func(key [33]byte) route.Vertex { + return key + }, + )(msg.BitcoinKey1.ValOpt()) + bitcoinKey2 := fn.MapOption( + func(key [33]byte) route.Vertex { + return key + }, + )(msg.BitcoinKey2.ValOpt()) + merkleRootHash := fn.MapOption( + func(hash [32]byte) chainhash.Hash { + return hash + }, + )(msg.MerkleRootHash.ValOpt()) + + return NewV2Channel( + msg.ShortChannelID.Val.ToUint64(), msg.ChainHash.Val, + msg.NodeID1.Val, msg.NodeID2.Val, &ChannelV2Fields{ + BitcoinKey1Bytes: bitcoinKey1, + BitcoinKey2Bytes: bitcoinKey2, + MerkleRootHash: merkleRootHash, + ExtraSignedFields: msg.ExtraSignedFields, + }, + WithChanProof(proof), WithFeatures(&msg.Features.Val), + 
WithCapacity(btcutil.Amount(msg.Capacity.Val)), + WithChannelPoint(wire.OutPoint(msg.Outpoint.Val)), + ) + + default: + return nil, fmt.Errorf("unsupported channel announcement: %T", + msg) + } +} + // NodeKey1 is the identity public key of the "first" node that was involved in // the creation of this channel. A node is considered "first" if the // lexicographical ordering the its serialized public key is "smaller" than @@ -452,3 +506,84 @@ func (c *ChannelEdgeInfo) ToChannelAnnouncement() ( return chanAnn, nil } + +// ToWireAnnouncement converts the ChannelEdgeInfo to a version-aware wire +// channel announcement. +func (c *ChannelEdgeInfo) ToWireAnnouncement() ( + lnwire.ChannelAnnouncement, error) { + + switch c.Version { + case lnwire.GossipVersion1: + return c.ToChannelAnnouncement() + + case lnwire.GossipVersion2: + return c.toChannelAnnouncement2() + + default: + return nil, fmt.Errorf("unsupported channel version: %d", + c.Version) + } +} + +// toChannelAnnouncement2 converts the ChannelEdgeInfo to a +// lnwire.ChannelAnnouncement2 message. +func (c *ChannelEdgeInfo) toChannelAnnouncement2() ( + *lnwire.ChannelAnnouncement2, error) { + + // If there's no auth proof, we can't create a full channel + // announcement. 
+ if c.AuthProof == nil { + return nil, fmt.Errorf("cannot create channel announcement " + + "without auth proof") + } + + if c.AuthProof.Version != lnwire.GossipVersion2 { + return nil, fmt.Errorf("invalid channel auth proof version: %d", + c.AuthProof.Version) + } + + sigBytes := c.AuthProof.Sig() + if len(sigBytes) == 0 { + return nil, fmt.Errorf("missing signature for v2 channel " + + "announcement") + } + + sig, err := lnwire.NewSigFromSchnorrRawSignature(sigBytes) + if err != nil { + return nil, err + } + + features := lnwire.RawFeatureVector{} + if c.Features != nil && c.Features.RawFeatureVector != nil { + features = *c.Features.RawFeatureVector + } + + var chanAnn lnwire.ChannelAnnouncement2 + chanAnn.ChainHash.Val = c.ChainHash + chanAnn.Features.Val = features + chanAnn.ShortChannelID.Val = lnwire.NewShortChanIDFromInt(c.ChannelID) + chanAnn.Capacity.Val = uint64(c.Capacity) + chanAnn.NodeID1.Val = [33]byte(c.NodeKey1Bytes) + chanAnn.NodeID2.Val = [33]byte(c.NodeKey2Bytes) + chanAnn.Outpoint.Val = lnwire.OutPoint(c.ChannelPoint) + chanAnn.Signature.Val = sig + chanAnn.ExtraSignedFields = c.ExtraSignedFields + + c.BitcoinKey1Bytes.WhenSome(func(key route.Vertex) { + btcKey1 := tlv.ZeroRecordT[tlv.TlvType12, [33]byte]() + btcKey1.Val = [33]byte(key) + chanAnn.BitcoinKey1 = tlv.SomeRecordT(btcKey1) + }) + c.BitcoinKey2Bytes.WhenSome(func(key route.Vertex) { + btcKey2 := tlv.ZeroRecordT[tlv.TlvType14, [33]byte]() + btcKey2.Val = [33]byte(key) + chanAnn.BitcoinKey2 = tlv.SomeRecordT(btcKey2) + }) + c.MerkleRootHash.WhenSome(func(hash chainhash.Hash) { + merkleRoot := tlv.ZeroRecordT[tlv.TlvType16, [32]byte]() + merkleRoot.Val = [32]byte(hash) + chanAnn.MerkleRootHash = tlv.SomeRecordT(merkleRoot) + }) + + return &chanAnn, nil +} diff --git a/graph/db/models/channel_edge_policy.go b/graph/db/models/channel_edge_policy.go index 067c7861a7a..b840104251f 100644 --- a/graph/db/models/channel_edge_policy.go +++ b/graph/db/models/channel_edge_policy.go @@ -151,11 
+151,11 @@ func ChanEdgePolicyFromWire(scid uint64, // IsNode1 returns true if this policy was announced by the channel's node_1. func (c *ChannelEdgePolicy) IsNode1() bool { - if c.Version == lnwire.GossipVersion1 { - return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 + if c.Version == lnwire.GossipVersion2 { + return !c.SecondPeer } - return !c.SecondPeer + return c.ChannelFlags&lnwire.ChanUpdateDirection == 0 } // IsDisabled determines whether the edge has the disabled bit set. diff --git a/graph/db/models/node.go b/graph/db/models/node.go index 7f3ce777ab9..bbff70fdc15 100644 --- a/graph/db/models/node.go +++ b/graph/db/models/node.go @@ -10,6 +10,8 @@ import ( "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" ) // Node represents an individual vertex/node within the channel graph. @@ -22,6 +24,7 @@ type Node struct { // PubKeyBytes is the raw bytes of the public key of the target node. PubKeyBytes [33]byte + pubKey *btcec.PublicKey // LastUpdate is the last time the vertex information for this node has // been updated. @@ -185,20 +188,59 @@ func (n *Node) HaveAnnouncement() bool { // PubKey is the node's long-term identity public key. This key will be used to // authenticated any advertisements/updates sent by the node. +// +// NOTE: By having this method to access the attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. func (n *Node) PubKey() (*btcec.PublicKey, error) { - return btcec.ParsePubKey(n.PubKeyBytes[:]) + if n.pubKey != nil { + return n.pubKey, nil + } + + key, err := btcec.ParsePubKey(n.PubKeyBytes[:]) + if err != nil { + return nil, err + } + n.pubKey = key + + return key, nil } -// NodeAnnouncement retrieves the latest node announcement of the node. +// NodeAnnouncement retrieves the v1 node announcement for this node. 
func (n *Node) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement1, error) { - // Error out if we request the signed announcement, but we don't have - // a signature for this announcement. - if !n.HaveAnnouncement() && signed { + return n.toNodeAnnouncement1(signed) +} + +// WireNodeAnnouncement reconstructs the version-appropriate wire node +// announcement for this node. If signed is true, the returned announcement +// will include the node's signature. Returns an error if signed is true but +// no signature is stored. +func (n *Node) WireNodeAnnouncement(signed bool) (lnwire.NodeAnnouncement, + error) { + + if signed && !n.HaveAnnouncement() { return nil, fmt.Errorf("node does not have node announcement") } + switch n.Version { + case lnwire.GossipVersion1: + return n.toNodeAnnouncement1(signed) + + case lnwire.GossipVersion2: + return n.toNodeAnnouncement2(signed) + + default: + return nil, fmt.Errorf("unsupported node version: %d", + n.Version) + } +} + +// toNodeAnnouncement1 constructs a v1 node announcement from the node's +// stored fields. +func (n *Node) toNodeAnnouncement1(signed bool) (*lnwire.NodeAnnouncement1, + error) { + alias, err := lnwire.NewNodeAlias(n.Alias.UnwrapOr("")) if err != nil { return nil, err @@ -218,31 +260,173 @@ func (n *Node) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement1, return nodeAnn, nil } - sig, err := lnwire.NewSigFromECDSARawSignature(n.AuthSigBytes) + nodeAnn.Signature, err = lnwire.NewSigFromECDSARawSignature( + n.AuthSigBytes, + ) if err != nil { return nil, err } - nodeAnn.Signature = sig - return nodeAnn, nil } -// NodeFromWireAnnouncement creates a Node instance from an -// lnwire.NodeAnnouncement1 message. 
-func NodeFromWireAnnouncement(msg *lnwire.NodeAnnouncement1) *Node { - timestamp := time.Unix(int64(msg.Timestamp), 0) - - return NewV1Node( - msg.NodeID, - &NodeV1Fields{ - LastUpdate: timestamp, - Addresses: msg.Addresses, - Alias: msg.Alias.String(), - AuthSigBytes: msg.Signature.ToSignatureBytes(), - Features: msg.Features, - Color: msg.RGBColor, - ExtraOpaqueData: msg.ExtraOpaqueData, - }, +// toNodeAnnouncement2 constructs a v2 node announcement from the node's +// stored fields. +func (n *Node) toNodeAnnouncement2(signed bool) (*lnwire.NodeAnnouncement2, + error) { + + nodeAnn := &lnwire.NodeAnnouncement2{ + Features: tlv.NewRecordT[tlv.TlvType0]( + *n.Features.RawFeatureVector, + ), + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType2]( + n.LastBlockHeight, + ), + NodeID: tlv.NewPrimitiveRecord[tlv.TlvType4, [33]byte]( + n.PubKeyBytes, + ), + ExtraSignedFields: n.ExtraSignedFields, + } + + n.Alias.WhenSome(func(s string) { + aliasRecord := tlv.ZeroRecordT[tlv.TlvType3, lnwire.NodeAlias2]() + aliasRecord.Val = lnwire.NodeAlias2(s) + nodeAnn.Alias = tlv.SomeRecordT(aliasRecord) + }) + + n.Color.WhenSome(func(rgba color.RGBA) { + colorRecord := tlv.ZeroRecordT[tlv.TlvType1, lnwire.Color]() + colorRecord.Val = lnwire.Color(rgba) + nodeAnn.Color = tlv.SomeRecordT(colorRecord) + }) + + // Categorise addresses by type for the separate TLV fields. + var ( + ipv4 lnwire.IPV4Addrs + ipv6 lnwire.IPV6Addrs + torV3 lnwire.TorV3Addrs ) + for _, addr := range n.Addresses { + switch a := addr.(type) { + case *net.TCPAddr: + if a.IP.To4() != nil { + ipv4 = append(ipv4, a) + } else { + ipv6 = append(ipv6, a) + } + + case *tor.OnionAddr: + // Only v3 onion addresses are supported in gossip v2. 
+ if len(a.OnionService) == tor.V3Len { + torV3 = append(torV3, a) + } + + case *lnwire.DNSAddress: + nodeAnn.DNSHostName = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType11](*a), + ) + } + } + if len(ipv4) > 0 { + nodeAnn.IPV4Addrs = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType5](ipv4), + ) + } + if len(ipv6) > 0 { + nodeAnn.IPV6Addrs = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType7](ipv6), + ) + } + if len(torV3) > 0 { + nodeAnn.TorV3Addrs = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType9](torV3), + ) + } + + if !signed { + return nodeAnn, nil + } + + var err error + nodeAnn.Signature.Val, err = lnwire.NewSigFromSchnorrRawSignature( + n.AuthSigBytes, + ) + if err != nil { + return nil, err + } + + return nodeAnn, nil +} + +// NodeFromWireAnnouncement creates a Node instance from a node announcement +// wire message. +func NodeFromWireAnnouncement(msg lnwire.NodeAnnouncement) (*Node, error) { + switch msg := msg.(type) { + case *lnwire.NodeAnnouncement1: + timestamp := time.Unix(int64(msg.Timestamp), 0) + authSig := msg.Signature.ToSignatureBytes() + + return NewV1Node( + msg.NodeID, + &NodeV1Fields{ + LastUpdate: timestamp, + Addresses: msg.Addresses, + Alias: msg.Alias.String(), + AuthSigBytes: authSig, + Features: msg.Features, + Color: msg.RGBColor, + ExtraOpaqueData: msg.ExtraOpaqueData, + }, + ), nil + + case *lnwire.NodeAnnouncement2: + var addrs []net.Addr + ipv4Opt := msg.IPV4Addrs.ValOpt() + ipv4Opt.WhenSome(func(ipv4Addrs lnwire.IPV4Addrs) { + for _, addr := range ipv4Addrs { + addrs = append(addrs, addr) + } + }) + ipv6Opt := msg.IPV6Addrs.ValOpt() + ipv6Opt.WhenSome(func(ipv6Addrs lnwire.IPV6Addrs) { + for _, addr := range ipv6Addrs { + addrs = append(addrs, addr) + } + }) + torOpt := msg.TorV3Addrs.ValOpt() + torOpt.WhenSome(func(torAddrs lnwire.TorV3Addrs) { + for _, addr := range torAddrs { + addrs = append(addrs, addr) + } + }) + dnsOpt := msg.DNSHostName.ValOpt() + dnsOpt.WhenSome(func(dnsAddr lnwire.DNSAddress) { + dns := dnsAddr + addrs = 
append(addrs, &dns) + }) + + nodeColor := fn.MapOption(func(c lnwire.Color) color.RGBA { + return color.RGBA(c) + })(msg.Color.ValOpt()) + + nodeAlias := fn.MapOption(func(a lnwire.NodeAlias2) string { + return string(a) + })(msg.Alias.ValOpt()) + + sig := msg.Signature.Val.ToSignatureBytes() + + return NewV2Node( + msg.NodeID.Val, &NodeV2Fields{ + LastBlockHeight: msg.BlockHeight.Val, + Addresses: addrs, + Color: nodeColor, + Alias: nodeAlias, + Signature: sig, + Features: &msg.Features.Val, + ExtraSignedFields: msg.ExtraSignedFields, + }, + ), nil + } + + return nil, fmt.Errorf("unsupported node announcement: %T", msg) } diff --git a/graph/db/models/node_test.go b/graph/db/models/node_test.go new file mode 100644 index 00000000000..136a0963b18 --- /dev/null +++ b/graph/db/models/node_test.go @@ -0,0 +1,343 @@ +package models + +import ( + "image/color" + "net" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" + "github.com/stretchr/testify/require" +) + +// newTestKey generates a fresh private key for tests. +func newTestKey(t *testing.T) *btcec.PrivateKey { + t.Helper() + key, err := btcec.NewPrivateKey() + require.NoError(t, err) + return key +} + +// pubVertex returns the compressed-public-key route.Vertex for a key. +func pubVertex(key *btcec.PrivateKey) route.Vertex { + var v route.Vertex + copy(v[:], key.PubKey().SerializeCompressed()) + return v +} + +// ecdsaSigBytes signs a 32-byte hash with an ECDSA key and returns DER bytes. 
+func ecdsaSigBytes(t *testing.T, key *btcec.PrivateKey, hash [32]byte) []byte { + t.Helper() + sig := ecdsa.Sign(key, hash[:]) + return sig.Serialize() +} + +// schnorrSigBytes signs a 32-byte hash with a Schnorr key and returns 64 bytes. +func schnorrSigBytes(t *testing.T, key *btcec.PrivateKey, hash [32]byte) []byte { + t.Helper() + sig, err := schnorr.Sign(key, hash[:]) + require.NoError(t, err) + return sig.Serialize() +} + +// TestNodeAnnouncementV1UnsignedRoundTrip verifies that a v1 Node can +// reconstruct an unsigned NodeAnnouncement1 with all fields intact. +func TestNodeAnnouncementV1UnsignedRoundTrip(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + addr := &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 9735} + ts := time.Unix(1_000_000, 0).UTC() + + node := NewV1Node(pub, &NodeV1Fields{ + LastUpdate: ts, + Addresses: []net.Addr{addr}, + Alias: "alice", + Color: color.RGBA{R: 1, G: 2, B: 3}, + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + ExtraOpaqueData: []byte{0xde, 0xad}, + }) + + ann, err := node.WireNodeAnnouncement(false) + require.NoError(t, err) + + ann1, ok := ann.(*lnwire.NodeAnnouncement1) + require.True(t, ok, "expected *NodeAnnouncement1") + require.Equal(t, pub, route.Vertex(ann1.NodeID)) + require.Equal(t, "alice", ann1.Alias.String()) + require.Equal(t, color.RGBA{R: 1, G: 2, B: 3}, ann1.RGBColor) + require.Equal(t, uint32(ts.Unix()), ann1.Timestamp) + require.EqualValues(t, []byte{0xde, 0xad}, ann1.ExtraOpaqueData) + require.Len(t, ann1.Addresses, 1) +} + +// TestNodeAnnouncementV1SignedRoundTrip verifies that NodeAnnouncement includes +// the stored ECDSA signature for v1 when signed=true. 
+func TestNodeAnnouncementV1SignedRoundTrip(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + sigBytes := ecdsaSigBytes(t, key, [32]byte{0x01}) + + node := NewV1Node(pub, &NodeV1Fields{ + LastUpdate: time.Now(), + AuthSigBytes: sigBytes, + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + }) + + ann, err := node.WireNodeAnnouncement(true) + require.NoError(t, err) + + ann1, ok := ann.(*lnwire.NodeAnnouncement1) + require.True(t, ok) + require.Equal(t, sigBytes, ann1.Signature.ToSignatureBytes()) +} + +// TestNodeAnnouncementV2UnsignedRoundTrip verifies that a v2 Node can +// reconstruct an unsigned NodeAnnouncement2 with all core fields intact. +func TestNodeAnnouncementV2UnsignedRoundTrip(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + + node := NewV2Node(pub, &NodeV2Fields{ + LastBlockHeight: 42, + Alias: fn.Some("bob"), + Color: fn.Some(color.RGBA{R: 10, G: 20, B: 30}), + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + ExtraSignedFields: map[uint64][]byte{9999: {0xab}}, + }) + + ann, err := node.WireNodeAnnouncement(false) + require.NoError(t, err) + + ann2, ok := ann.(*lnwire.NodeAnnouncement2) + require.True(t, ok, "expected *NodeAnnouncement2") + require.Equal(t, pub, route.Vertex(ann2.NodeID.Val)) + require.EqualValues(t, 42, ann2.BlockHeight.Val) + require.EqualValues(t, map[uint64][]byte{9999: {0xab}}, ann2.ExtraSignedFields) + + aliasOpt := ann2.Alias.ValOpt() + require.True(t, aliasOpt.IsSome()) + aliasOpt.WhenSome(func(a lnwire.NodeAlias2) { + require.Equal(t, "bob", string(a)) + }) + + colorOpt := ann2.Color.ValOpt() + require.True(t, colorOpt.IsSome()) + colorOpt.WhenSome(func(c lnwire.Color) { + require.Equal(t, color.RGBA{R: 10, G: 20, B: 30}, color.RGBA(c)) + }) +} + +// TestNodeAnnouncementV2SignedRoundTrip verifies that NodeAnnouncement2 +// includes the stored Schnorr signature when signed=true. 
+func TestNodeAnnouncementV2SignedRoundTrip(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + sigBytes := schnorrSigBytes(t, key, [32]byte{0x02}) + + node := NewV2Node(pub, &NodeV2Fields{ + LastBlockHeight: 10, + Signature: sigBytes, + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + }) + + ann, err := node.WireNodeAnnouncement(true) + require.NoError(t, err) + + ann2, ok := ann.(*lnwire.NodeAnnouncement2) + require.True(t, ok) + require.Equal(t, sigBytes, ann2.Signature.Val.ToSignatureBytes()) +} + +// TestNodeAnnouncementV2AddressCategorization verifies that addresses of +// different types are placed into the correct TLV fields on NodeAnnouncement2. +func TestNodeAnnouncementV2AddressCategorization(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + + ipv4Addr := &net.TCPAddr{IP: net.ParseIP("1.2.3.4").To4(), Port: 9735} + ipv6Addr := &net.TCPAddr{IP: net.ParseIP("::1"), Port: 9736} + torAddr := &tor.OnionAddr{ + OnionService: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.onion", + Port: 9737, + } + dnsAddr := &lnwire.DNSAddress{Hostname: "example.com", Port: 9738} + + node := NewV2Node(pub, &NodeV2Fields{ + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + Addresses: []net.Addr{ipv4Addr, ipv6Addr, torAddr, dnsAddr}, + }) + + ann, err := node.WireNodeAnnouncement(false) + require.NoError(t, err) + + ann2, ok := ann.(*lnwire.NodeAnnouncement2) + require.True(t, ok) + + ann2.IPV4Addrs.ValOpt().WhenSome(func(addrs lnwire.IPV4Addrs) { + require.Len(t, addrs, 1) + require.Equal(t, ipv4Addr.Port, addrs[0].Port) + }) + require.True(t, ann2.IPV4Addrs.ValOpt().IsSome(), "expected ipv4 addrs") + + ann2.IPV6Addrs.ValOpt().WhenSome(func(addrs lnwire.IPV6Addrs) { + require.Len(t, addrs, 1) + require.Equal(t, ipv6Addr.Port, addrs[0].Port) + }) + require.True(t, ann2.IPV6Addrs.ValOpt().IsSome(), "expected ipv6 addrs") + + ann2.TorV3Addrs.ValOpt().WhenSome(func(addrs lnwire.TorV3Addrs) { 
+ require.Len(t, addrs, 1) + require.Equal(t, torAddr.Port, addrs[0].Port) + }) + require.True(t, ann2.TorV3Addrs.ValOpt().IsSome(), "expected tor addrs") + + ann2.DNSHostName.ValOpt().WhenSome(func(d lnwire.DNSAddress) { + require.Equal(t, dnsAddr.Hostname, d.Hostname) + require.Equal(t, dnsAddr.Port, d.Port) + }) + require.True(t, ann2.DNSHostName.ValOpt().IsSome(), "expected dns addr") +} + +// TestNodeFromWireAnnouncementV1 verifies a NodeAnnouncement1 is parsed into +// a v1 Node with all fields correctly mapped. +func TestNodeFromWireAnnouncementV1(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + var pubBytes [33]byte + copy(pubBytes[:], key.PubKey().SerializeCompressed()) + + rawSig := ecdsaSigBytes(t, key, [32]byte{0x03}) + wireSig, err := lnwire.NewSigFromECDSARawSignature(rawSig) + require.NoError(t, err) + + alias, err := lnwire.NewNodeAlias("carol") + require.NoError(t, err) + + ann1 := &lnwire.NodeAnnouncement1{ + Signature: wireSig, + Features: lnwire.EmptyFeatureVector().RawFeatureVector, + Timestamp: uint32(1_000_000), + NodeID: pubBytes, + RGBColor: color.RGBA{R: 5, G: 6, B: 7}, + Alias: alias, + Addresses: []net.Addr{&net.TCPAddr{IP: net.ParseIP("9.8.7.6"), Port: 9735}}, + ExtraOpaqueData: []byte{0xff}, + } + + node, err := NodeFromWireAnnouncement(ann1) + require.NoError(t, err) + require.Equal(t, lnwire.GossipVersion1, node.Version) + require.EqualValues(t, pubBytes, node.PubKeyBytes) + require.Equal(t, "carol", node.Alias.UnwrapOr("")) + require.Equal(t, color.RGBA{R: 5, G: 6, B: 7}, node.Color.UnwrapOr(color.RGBA{})) + require.Equal(t, rawSig, node.AuthSigBytes) + require.True(t, node.HaveAnnouncement()) + + // Round-trip back to wire should yield a v1 announcement. 
+ roundTripped, err := node.WireNodeAnnouncement(true) + require.NoError(t, err) + _, ok := roundTripped.(*lnwire.NodeAnnouncement1) + require.True(t, ok, "round-trip should produce *NodeAnnouncement1") +} + +// TestNodeFromWireAnnouncementV2 verifies a NodeAnnouncement2 is parsed into +// a v2 Node with all fields correctly mapped. +func TestNodeFromWireAnnouncementV2(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + var pubBytes [33]byte + copy(pubBytes[:], key.PubKey().SerializeCompressed()) + + rawSig := schnorrSigBytes(t, key, [32]byte{0x04}) + wireSig, err := lnwire.NewSigFromSchnorrRawSignature(rawSig) + require.NoError(t, err) + + ann2 := &lnwire.NodeAnnouncement2{ + Features: tlv.NewRecordT[tlv.TlvType0]( + *lnwire.EmptyFeatureVector().RawFeatureVector, + ), + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType2](uint32(100)), + NodeID: tlv.NewPrimitiveRecord[tlv.TlvType4, [33]byte](pubBytes), + Signature: tlv.NewRecordT[tlv.TlvType160](wireSig), + } + + node, err := NodeFromWireAnnouncement(ann2) + require.NoError(t, err) + require.Equal(t, lnwire.GossipVersion2, node.Version) + require.EqualValues(t, pubBytes, node.PubKeyBytes) + require.EqualValues(t, 100, node.LastBlockHeight) + require.Equal(t, rawSig, node.AuthSigBytes) + require.True(t, node.HaveAnnouncement()) + + // Round-trip back to wire should yield a v2 announcement. + roundTripped, err := node.WireNodeAnnouncement(true) + require.NoError(t, err) + _, ok := roundTripped.(*lnwire.NodeAnnouncement2) + require.True(t, ok, "round-trip should produce *NodeAnnouncement2") +} + +// TestNodeHaveAnnouncementAndPubKeyCache verifies HaveAnnouncement behaviour +// and that PubKey caches the parsed key across repeated calls. 
+func TestNodeHaveAnnouncementAndPubKeyCache(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + + shell := NewShellNode(lnwire.GossipVersion1, pub) + require.False(t, shell.HaveAnnouncement()) + + // First PubKey call parses bytes and caches the result. + first, err := shell.PubKey() + require.NoError(t, err) + require.True(t, key.PubKey().IsEqual(first)) + + // Second call returns the same cached pointer. + second, err := shell.PubKey() + require.NoError(t, err) + require.Same(t, first, second) + + // After adding a signature, HaveAnnouncement returns true. + shell.AuthSigBytes = []byte{0x01} + require.True(t, shell.HaveAnnouncement()) +} + +// TestNodeAnnouncementUnsignedNoSigRequired verifies that unsigned +// announcements succeed even when no signature is stored, and that +// signed=true correctly fails in that case. +func TestNodeAnnouncementUnsignedNoSigRequired(t *testing.T) { + t.Parallel() + + key := newTestKey(t) + pub := pubVertex(key) + + // Shell node has no signature. + node := NewShellNode(lnwire.GossipVersion1, pub) + require.False(t, node.HaveAnnouncement()) + + _, err := node.WireNodeAnnouncement(false) + require.NoError(t, err) + + // Requesting signed=true must fail when no sig is stored. + _, err = node.WireNodeAnnouncement(true) + require.ErrorContains(t, err, "does not have node announcement") +} diff --git a/graph/db/notifications.go b/graph/db/notifications.go index 1c8a8898602..1597abab6cf 100644 --- a/graph/db/notifications.go +++ b/graph/db/notifications.go @@ -424,7 +424,7 @@ func (c *ChannelGraph) addToTopologyChange(ctx context.Context, // the second node. 
sourceNode := edgeInfo.NodeKey1 connectingNode := edgeInfo.NodeKey2 - if m.ChannelFlags&lnwire.ChanUpdateDirection == 1 { + if !m.IsNode1() { sourceNode = edgeInfo.NodeKey2 connectingNode = edgeInfo.NodeKey1 } @@ -449,7 +449,7 @@ func (c *ChannelGraph) addToTopologyChange(ctx context.Context, FeeRate: m.FeeProportionalMillionths, AdvertisingNode: aNode, ConnectingNode: cNode, - Disabled: m.ChannelFlags&lnwire.ChanUpdateDisabled != 0, + Disabled: m.IsDisabled(), InboundFee: m.InboundFee, ExtraOpaqueData: m.ExtraOpaqueData, } diff --git a/graph/db/options.go b/graph/db/options.go index 15ea6f4ee85..58ee381f4a2 100644 --- a/graph/db/options.go +++ b/graph/db/options.go @@ -1,6 +1,13 @@ package graphdb -import "time" +import ( + "fmt" + "iter" + "time" + + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/lnwire" +) const ( // DefaultRejectCacheSize is the default number of rejectCacheEntries to @@ -39,6 +46,181 @@ type iterConfig struct { iterPublicNodes bool } +// ChanUpdateRange describes a range for channel updates. Only one of the time +// or height ranges should be set depending on the gossip version. +type ChanUpdateRange struct { + // StartTime is the inclusive lower time bound (v1 gossip only). + StartTime fn.Option[time.Time] + + // EndTime is the exclusive upper time bound (v1 gossip only). + EndTime fn.Option[time.Time] + + // StartHeight is the inclusive lower block-height bound (v2 gossip + // only). + StartHeight fn.Option[uint32] + + // EndHeight is the exclusive upper block-height bound (v2 gossip + // only). + EndHeight fn.Option[uint32] +} + +// validateForVersion checks that the range fields are consistent with the +// given gossip version: v1 requires time bounds, v2 requires block-height +// bounds, and mixing the two is rejected. 
+func (r ChanUpdateRange) validateForVersion(v lnwire.GossipVersion) error { + hasStartTime := r.StartTime.IsSome() + hasEndTime := r.EndTime.IsSome() + hasTimeRange := hasStartTime || hasEndTime + + hasStartHeight := r.StartHeight.IsSome() + hasEndHeight := r.EndHeight.IsSome() + hasBlockRange := hasStartHeight || hasEndHeight + + if hasTimeRange && hasBlockRange { + return fmt.Errorf("chan update range has both " + + "time and block ranges") + } + + switch v { + case lnwire.GossipVersion1: + if hasBlockRange { + return fmt.Errorf("v1 chan update range must use time") + } + if !hasTimeRange { + return fmt.Errorf("v1 chan update range missing time") + } + if !hasStartTime || !hasEndTime { + return fmt.Errorf("v1 chan update range " + + "missing time bounds") + } + + start := r.StartTime.UnwrapOr(time.Time{}) + end := r.EndTime.UnwrapOr(time.Time{}) + if start.After(end) { + return fmt.Errorf("v1 chan update range: " + + "start time after end time") + } + + case lnwire.GossipVersion2: + if hasTimeRange { + return fmt.Errorf("v2 chan update range " + + "must use blocks") + } + if !hasBlockRange { + return fmt.Errorf("v2 chan update range " + + "missing block range") + } + if !hasStartHeight || !hasEndHeight { + return fmt.Errorf("v2 chan update range " + + "missing block bounds") + } + + start := r.StartHeight.UnwrapOr(0) + end := r.EndHeight.UnwrapOr(0) + if start > end { + return fmt.Errorf("v2 chan update range: " + + "start height after end height") + } + + default: + return fmt.Errorf("unknown gossip version: %v", v) + } + + return nil +} + +// chanUpdateRangeErrIter returns an iterator that yields a single error. +func chanUpdateRangeErrIter(err error) iter.Seq2[ChannelEdge, error] { + return func(yield func(ChannelEdge, error) bool) { + _ = yield(ChannelEdge{}, err) + } +} + +// NodeUpdateRange describes a range for node updates. Only one of the time or +// height ranges should be set depending on the gossip version. 
+type NodeUpdateRange struct { + // StartTime is the inclusive lower time bound (v1 gossip only). + StartTime fn.Option[time.Time] + + // EndTime is the inclusive upper time bound (v1 gossip only). + EndTime fn.Option[time.Time] + + // StartHeight is the inclusive lower block-height bound (v2 gossip + // only). + StartHeight fn.Option[uint32] + + // EndHeight is the inclusive upper block-height bound (v2 gossip + // only). + EndHeight fn.Option[uint32] +} + +// validateForVersion checks that the range fields are consistent with the +// given gossip version: v1 requires time bounds, v2 requires block-height +// bounds, and mixing the two is rejected. +func (r NodeUpdateRange) validateForVersion(v lnwire.GossipVersion) error { + hasStartTime := r.StartTime.IsSome() + hasEndTime := r.EndTime.IsSome() + hasTimeRange := hasStartTime || hasEndTime + + hasStartHeight := r.StartHeight.IsSome() + hasEndHeight := r.EndHeight.IsSome() + hasBlockRange := hasStartHeight || hasEndHeight + + if hasTimeRange && hasBlockRange { + return fmt.Errorf("node update range has both " + + "time and block ranges") + } + + switch v { + case lnwire.GossipVersion1: + if hasBlockRange { + return fmt.Errorf("v1 node update range " + + "must use time") + } + if !hasTimeRange { + return fmt.Errorf("v1 node update range " + + "missing time") + } + if !hasStartTime || !hasEndTime { + return fmt.Errorf("v1 node update range " + + "missing time bounds") + } + + start := r.StartTime.UnwrapOr(time.Time{}) + end := r.EndTime.UnwrapOr(time.Time{}) + if start.After(end) { + return fmt.Errorf("v1 node update range: " + + "start time after end time") + } + + case lnwire.GossipVersion2: + if hasTimeRange { + return fmt.Errorf("v2 node update range " + + "must use height") + } + if !hasBlockRange { + return fmt.Errorf("v2 node update range " + + "missing height") + } + if !hasStartHeight || !hasEndHeight { + return fmt.Errorf("v2 node update range " + + "missing height bounds") + } + + start := 
r.StartHeight.UnwrapOr(0) + end := r.EndHeight.UnwrapOr(0) + if start > end { + return fmt.Errorf("v2 node update range: " + + "start height after end height") + } + + default: + return fmt.Errorf("unknown gossip version: %d", v) + } + + return nil +} + // defaultIteratorConfig returns the default configuration. func defaultIteratorConfig() *iterConfig { return &iterConfig{ diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 0e3907599e2..eebf8204dfb 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -52,6 +52,7 @@ type SQLQueries interface { GetNodesByIDs(ctx context.Context, ids []int64) ([]sqlc.GraphNode, error) GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error) GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.GraphNode, error) + GetNodesByBlockHeightRange(ctx context.Context, arg sqlc.GetNodesByBlockHeightRangeParams) ([]sqlc.GraphNode, error) ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.GraphNode, error) ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error) IsPublicV1Node(ctx context.Context, pubKey []byte) (bool, error) @@ -107,6 +108,7 @@ type SQLQueries interface { ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error) ListChannelsPaginatedV2(ctx context.Context, arg sqlc.ListChannelsPaginatedV2Params) ([]sqlc.ListChannelsPaginatedV2Row, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) + GetChannelsByPolicyBlockRange(ctx context.Context, arg sqlc.GetChannelsByPolicyBlockRangeParams) ([]sqlc.GetChannelsByPolicyBlockRangeRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) 
(sqlc.GetChannelByOutpointWithPoliciesRow, error) GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.GraphChannel, error) GetPublicV2ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV2ChannelsBySCIDParams) ([]sqlc.GraphChannel, error) @@ -606,14 +608,14 @@ func (s *SQLStore) SetSourceNode(ctx context.Context, }, sqldb.NoOpReset) } -// NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. This method can be used by two -// nodes to quickly determine if they have the same set of up to date node -// announcements. +// NodeUpdatesInHorizon returns all the known lightning nodes which have +// updates within the passed range for the given gossip version. This method can +// be used by two nodes to quickly determine if they have the same set of +// up-to-date node announcements. // // NOTE: This is part of the Store interface. func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r NodeUpdateRange, opts ...IteratorOption) iter.Seq2[*models.Node, error] { cfg := defaultIteratorConfig() @@ -621,27 +623,33 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, opt(cfg) } + batchSize := cfg.nodeUpdateIterBatchSize + return func(yield func(*models.Node, error) bool) { var ( lastUpdateTime sql.NullInt64 + lastBlock sql.NullInt64 lastPubKey = make([]byte, 33) hasMore = true ) - // Each iteration, we'll read a batch amount of nodes, yield - // them, then decide is we have more or not. 
- for hasMore { - var batch []*models.Node + if err := r.validateForVersion(v); err != nil { + yield(nil, err) + return + } - //nolint:ll - err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - //nolint:ll - params := sqlc.GetNodesByLastUpdateRangeParams{ + queryV1 := func(db SQLQueries) ([]sqlc.GraphNode, error) { + return db.GetNodesByLastUpdateRange( + ctx, sqlc.GetNodesByLastUpdateRangeParams{ StartTime: sqldb.SQLInt64( - startTime.Unix(), + r.StartTime.UnwrapOr( + time.Time{}, + ).Unix(), ), EndTime: sqldb.SQLInt64( - endTime.Unix(), + r.EndTime.UnwrapOr( + time.Time{}, + ).Unix(), ), LastUpdate: lastUpdateTime, LastPubKey: lastPubKey, @@ -652,44 +660,106 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, MaxResults: sqldb.SQLInt32( cfg.nodeUpdateIterBatchSize, ), - } - rows, err := db.GetNodesByLastUpdateRange( - ctx, params, - ) - if err != nil { - return err - } + }, + ) + } - hasMore = len(rows) == cfg.nodeUpdateIterBatchSize + queryV2 := func(db SQLQueries) ([]sqlc.GraphNode, error) { + //nolint:ll + return db.GetNodesByBlockHeightRange( + ctx, sqlc.GetNodesByBlockHeightRangeParams{ + Version: int16(v), + StartHeight: sqldb.SQLInt64( + int64(r.StartHeight.UnwrapOr(0)), //nolint:ll + ), + EndHeight: sqldb.SQLInt64( + int64(r.EndHeight.UnwrapOr(0)), //nolint:ll + ), + LastBlockHeight: lastBlock, + LastPubKey: lastPubKey, + OnlyPublic: sql.NullBool{ + Bool: cfg.iterPublicNodes, + Valid: true, + }, + MaxResults: sqldb.SQLInt32( + cfg.nodeUpdateIterBatchSize, + ), + }, + ) + } - err = forEachNodeInBatch( - ctx, s.cfg.QueryCfg, db, rows, - func(_ int64, node *models.Node) error { - batch = append(batch, node) + // queryNodes fetches the next batch of nodes in the + // horizon range, dispatching to the version-appropriate + // query. + queryNodes := func(db SQLQueries) ([]sqlc.GraphNode, error) { + switch v { + case gossipV1: + return queryV1(db) - // Update pagination cursors - // based on the last processed - // node. 
- lastUpdateTime = sql.NullInt64{ - Int64: node.LastUpdate. - Unix(), - Valid: true, - } - lastPubKey = node.PubKeyBytes[:] + case gossipV2: + return queryV2(db) - return nil - }, - ) - if err != nil { - return fmt.Errorf("unable to build "+ - "nodes: %w", err) + default: + return nil, fmt.Errorf("unknown gossip "+ + "version: %v", v) + } + } + + // processNode is called for each node in a batch to + // accumulate results and update pagination cursors. + processNode := func(_ int64, + node *models.Node, batch *[]*models.Node) error { + + *batch = append(*batch, node) + + switch v { + case gossipV1: + lastUpdateTime = sql.NullInt64{ + Int64: node.LastUpdate.Unix(), + Valid: true, + } + case gossipV2: + lastBlock = sql.NullInt64{ + Int64: int64(node.LastBlockHeight), + Valid: true, } + } + lastPubKey = node.PubKeyBytes[:] - return nil - }, func() { - batch = []*models.Node{} - }) + return nil + } + + // Each iteration, we'll read a batch amount of nodes, + // yield them, then decide if we have more or not. + for hasMore { + var batch []*models.Node + + err := s.db.ExecTx( + ctx, sqldb.ReadTxOpt(), + func(db SQLQueries) error { + rows, err := queryNodes(db) + if err != nil { + return err + } + hasMore = len(rows) == batchSize + + //nolint:ll + return forEachNodeInBatch( + ctx, s.cfg.QueryCfg, db, + rows, func(id int64, + node *models.Node) error { + + return processNode( + id, node, + &batch, + ) + }, + ) + }, func() { + batch = []*models.Node{} + }, + ) if err != nil { log.Errorf("NodeUpdatesInHorizon batch "+ "error: %v", err) @@ -705,7 +775,8 @@ func (s *SQLStore) NodeUpdatesInHorizon(ctx context.Context, } } - // If the batch didn't yield anything, then we're done. + // If the batch didn't yield anything, then + // we're done. if len(batch) == 0 { break } @@ -985,18 +1056,60 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context, // early. // // NOTE: part of the Store interface. 
-func (s *SQLStore) ForEachNode(ctx context.Context, v lnwire.GossipVersion, +func (s *SQLStore) ForEachNode(ctx context.Context, cb func(node *models.Node) error, reset func()) error { return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachNodePaginated( - ctx, s.cfg.QueryCfg, db, - v, func(_ context.Context, _ int64, - node *models.Node) error { + // Collect nodes across all versions, preferring the highest + // version's data. + type nodeEntry struct { + node *models.Node + } + nodesByPub := make(map[route.Vertex]*nodeEntry) + var order []route.Vertex + + for _, v := range []lnwire.GossipVersion{ + gossipV1, gossipV2, + } { + err := forEachNodePaginated( + ctx, s.cfg.QueryCfg, db, + v, func(_ context.Context, _ int64, + node *models.Node) error { + + pub := node.PubKeyBytes + entry, exists := nodesByPub[pub] + if !exists { + entry = &nodeEntry{} + nodesByPub[pub] = entry + order = append(order, pub) + } - return cb(node) - }, - ) + // Prefer highest version with an + // announcement, fall back to shell + // nodes. A node has been announced + // if it carries a signature. + hasAnn := len(node.AuthSigBytes) > 0 + if entry.node == nil || hasAnn { + entry.node = node + } + + return nil + }, + ) + if err != nil { + return err + } + } + + for _, pub := range order { + entry := nodesByPub[pub] + err := cb(entry.node) + if err != nil { + return err + } + } + + return nil }, reset) } @@ -1092,6 +1205,26 @@ func extractMaxUpdateTime( } } +// extractMaxBlockHeight returns the maximum of the two policy block heights. +// This is used for pagination cursor tracking in v2 gossip queries. 
+func extractMaxBlockHeight( + row sqlc.GetChannelsByPolicyLastUpdateRangeRow) int64 { + + switch { + case row.Policy1BlockHeight.Valid && + row.Policy2BlockHeight.Valid: + + return max(row.Policy1BlockHeight.Int64, + row.Policy2BlockHeight.Int64) + case row.Policy1BlockHeight.Valid: + return row.Policy1BlockHeight.Int64 + case row.Policy2BlockHeight.Valid: + return row.Policy2BlockHeight.Int64 + default: + return 0 + } +} + // buildChannelFromRow constructs a ChannelEdge from a database row. // This includes building the nodes, channel info, and policies. func (s *SQLStore) buildChannelFromRow(ctx context.Context, db SQLQueries, @@ -1173,127 +1306,216 @@ func (s *SQLStore) updateChanCacheBatch(v lnwire.GossipVersion, // // NOTE: This is part of the Store interface. func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, - startTime, endTime time.Time, + v lnwire.GossipVersion, r ChanUpdateRange, opts ...IteratorOption) iter.Seq2[ChannelEdge, error] { + if err := r.validateForVersion(v); err != nil { + return chanUpdateRangeErrIter(err) + } + // Apply options. cfg := defaultIteratorConfig() for _, opt := range opts { opt(cfg) } + batchSize := cfg.chanUpdateIterBatchSize + return func(yield func(ChannelEdge, error) bool) { var ( - edgesSeen = make(map[uint64]struct{}) - edgesToCache = make(map[uint64]ChannelEdge) - hits int - total int - lastUpdateTime sql.NullInt64 - lastID sql.NullInt64 - hasMore = true + edgesSeen = make(map[uint64]struct{}) + edgesToCache = make(map[uint64]ChannelEdge) + hits int + total int + lastUpdateTime sql.NullInt64 + lastBlockHeight sql.NullInt64 + lastID sql.NullInt64 + hasMore = true ) - // Each iteration, we'll read a batch amount of channel updates - // (consulting the cache along the way), yield them, then loop - // back to decide if we have any more updates to read out. 
+ queryV1 := func(db SQLQueries) ( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) { + + return db.GetChannelsByPolicyLastUpdateRange( + ctx, + sqlc.GetChannelsByPolicyLastUpdateRangeParams{ + StartTime: sqldb.SQLInt64( + r.StartTime.UnwrapOr( + time.Time{}, + ).Unix(), + ), + EndTime: sqldb.SQLInt64( + r.EndTime.UnwrapOr( + time.Time{}, + ).Unix(), + ), + LastUpdateTime: lastUpdateTime, + LastID: lastID, + MaxResults: sql.NullInt32{ + Int32: int32(batchSize), + Valid: true, + }, + }, + ) + } + + queryV2 := func(db SQLQueries) ( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) { + + blockRows, err := db.GetChannelsByPolicyBlockRange( + ctx, + sqlc.GetChannelsByPolicyBlockRangeParams{ + Version: int16(v), + StartHeight: sqldb.SQLInt64( + int64(r.StartHeight.UnwrapOr(0)), //nolint:ll + ), + EndHeight: sqldb.SQLInt64( + int64(r.EndHeight.UnwrapOr(0)), //nolint:ll + ), + LastBlockHeight: lastBlockHeight, + LastID: lastID, + MaxResults: sql.NullInt32{ + Int32: int32(batchSize), + Valid: true, + }, + }, + ) + if err != nil { + return nil, err + } + + rows := make( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, + 0, len(blockRows), + ) + for _, br := range blockRows { + //nolint:ll + rows = append( + rows, + sqlc.GetChannelsByPolicyLastUpdateRangeRow(br), + ) + } + + return rows, nil + } + + // queryChannels fetches the next batch of channels whose + // policies fall within the horizon range. + queryChannels := func(db SQLQueries) ( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) { + + switch v { + case gossipV1: + return queryV1(db) + + case gossipV2: + return queryV2(db) + + default: + return nil, fmt.Errorf("unknown gossip "+ + "version: %v", v) + } + } + + // processRow handles a single channel row: updates + // pagination cursors, checks the seen set and cache, and + // builds the channel edge if needed. 
+ processRow := func(ctx context.Context, db SQLQueries, + row sqlc.GetChannelsByPolicyLastUpdateRangeRow, + batch *[]ChannelEdge) error { + + switch v { + case gossipV1: + lastUpdateTime = sql.NullInt64{ + Int64: extractMaxUpdateTime(row), + Valid: true, + } + case gossipV2: + lastBlockHeight = sql.NullInt64{ + Int64: extractMaxBlockHeight(row), + Valid: true, + } + } + lastID = sql.NullInt64{ + Int64: row.GraphChannel.ID, + Valid: true, + } + + chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid) + + if _, ok := edgesSeen[chanIDInt]; ok { + return nil + } + + // Check cache (we already hold shared read + // lock). + channel, ok := s.chanCache.get(v, chanIDInt) + if ok { + hits++ + total++ + edgesSeen[chanIDInt] = struct{}{} + *batch = append(*batch, channel) + + return nil + } + + chanEdge, err := s.buildChannelFromRow( + ctx, db, row, + ) + if err != nil { + return err + } + + edgesSeen[chanIDInt] = struct{}{} + edgesToCache[chanIDInt] = chanEdge + *batch = append(*batch, chanEdge) + total++ + + return nil + } + + // Each iteration, we'll read a batch amount of channel + // updates (consulting the cache along the way), yield + // them, then loop back to decide if we have any more + // updates to read out. for hasMore { var batch []ChannelEdge - // Acquire read lock before starting transaction to - // ensure consistent lock ordering (cacheMu -> DB) and - // prevent deadlock with write operations. + // Acquire read lock before starting transaction + // to ensure consistent lock ordering + // (cacheMu -> DB) and prevent deadlock with + // write operations. 
s.cacheMu.RLock() err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - //nolint:ll - params := sqlc.GetChannelsByPolicyLastUpdateRangeParams{ - Version: int16(lnwire.GossipVersion1), - StartTime: sqldb.SQLInt64( - startTime.Unix(), - ), - EndTime: sqldb.SQLInt64( - endTime.Unix(), - ), - LastUpdateTime: lastUpdateTime, - LastID: lastID, - MaxResults: sql.NullInt32{ - Int32: int32( - cfg.chanUpdateIterBatchSize, - ), - Valid: true, - }, - } - //nolint:ll - rows, err := db.GetChannelsByPolicyLastUpdateRange( - ctx, params, - ) + rows, err := queryChannels(db) if err != nil { return err } - //nolint:ll - hasMore = len(rows) == cfg.chanUpdateIterBatchSize + hasMore = len(rows) == batchSize - //nolint:ll for _, row := range rows { - lastUpdateTime = sql.NullInt64{ - Int64: extractMaxUpdateTime(row), - Valid: true, - } - lastID = sql.NullInt64{ - Int64: row.GraphChannel.ID, - Valid: true, - } - - // Skip if we've already - // processed this channel. - chanIDInt := byteOrder.Uint64( - row.GraphChannel.Scid, - ) - _, ok := edgesSeen[chanIDInt] - if ok { - continue - } - - // Check cache (we already hold - // shared read lock). - channel, ok := s.chanCache.get( - lnwire.GossipVersion1, - chanIDInt, - ) - if ok { - hits++ - total++ - edgesSeen[chanIDInt] = struct{}{} - batch = append(batch, channel) - - continue - } - - chanEdge, err := s.buildChannelFromRow( - ctx, db, row, + err := processRow( + ctx, db, row, &batch, ) if err != nil { return err } - - edgesSeen[chanIDInt] = struct{}{} - edgesToCache[chanIDInt] = chanEdge - - batch = append(batch, chanEdge) - - total++ } return nil }, func() { batch = nil - edgesSeen = make(map[uint64]struct{}) + edgesSeen = make( + map[uint64]struct{}, + ) edgesToCache = make( map[uint64]ChannelEdge, ) - }) + }, + ) // Release read lock after transaction completes. 
s.cacheMu.RUnlock() @@ -1313,11 +1535,10 @@ func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, } } - // Update cache after successful batch yield, setting - // the cache lock only once for the entire batch. - s.updateChanCacheBatch( - lnwire.GossipVersion1, edgesToCache, - ) + // Update cache after successful batch yield, + // setting the cache lock only once for the + // entire batch. + s.updateChanCacheBatch(v, edgesToCache) edgesToCache = make(map[uint64]ChannelEdge) // If the batch didn't yield anything, then we're done. @@ -1327,12 +1548,12 @@ func (s *SQLStore) ChanUpdatesInHorizon(ctx context.Context, } if total > 0 { - log.Debugf("ChanUpdatesInHorizon hit percentage: "+ - "%.2f (%d/%d)", + log.Debugf("ChanUpdatesInHorizon(v%d) hit "+ + "percentage: %.2f (%d/%d)", v, float64(hits)*100/float64(total), hits, total) } else { - log.Debugf("ChanUpdatesInHorizon returned no edges "+ - "in horizon (%s, %s)", startTime, endTime) + log.Debugf("ChanUpdatesInHorizon(v%d) returned "+ + "no edges in horizon", v) } } } @@ -1401,7 +1622,7 @@ func (s *SQLStore) ForEachNodeCached(ctx context.Context, // page. allChannels, err := db.ListChannelsForNodeIDs( ctx, sqlc.ListChannelsForNodeIDsParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), Node1Ids: nodeIDs, Node2Ids: nodeIDs, }, @@ -1645,16 +1866,75 @@ func (s *SQLStore) ForEachChannelCacheable(ctx context.Context, // // NOTE: part of the Store interface. 
func (s *SQLStore) ForEachChannel(ctx context.Context, - v lnwire.GossipVersion, cb func(*models.ChannelEdgeInfo, - *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, + cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error, reset func()) error { - if !isKnownGossipVersion(v) { - return fmt.Errorf("unsupported gossip version: %d", v) - } - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachChannelWithPolicies(ctx, db, s.cfg, v, cb) + // Collect channels across all versions, preferring the + // highest version's data. + type chanEntry struct { + info *models.ChannelEdgeInfo + p1, p2 *models.ChannelEdgePolicy + } + chansByID := make(map[uint64]*chanEntry) + var order []uint64 + + for _, v := range []lnwire.GossipVersion{ + gossipV1, gossipV2, + } { + if !isKnownGossipVersion(v) { + continue + } + + err := forEachChannelWithPolicies( + ctx, db, s.cfg, v, + func(info *models.ChannelEdgeInfo, + p1, + p2 *models.ChannelEdgePolicy) error { + + id := info.ChannelID + entry, exists := chansByID[id] + if !exists { + entry = &chanEntry{} + chansByID[id] = entry + order = append(order, id) + } + + // Prefer highest version, but only + // overwrite if the new entry has at + // least one policy or the existing + // entry has none. This prevents a + // v2 channel with no policies from + // hiding a v1 channel that had + // valid policy data. 
+ hasPolicies := p1 != nil || p2 != nil + existingEmpty := entry.p1 == nil && + entry.p2 == nil + + if hasPolicies || existingEmpty { + entry.info = info + entry.p1 = p1 + entry.p2 = p2 + } + + return nil + }, + ) + if err != nil { + return err + } + } + + for _, id := range order { + entry := chansByID[id] + err := cb(entry.info, entry.p1, entry.p2) + if err != nil { + return err + } + } + + return nil }, reset) } @@ -1755,17 +2035,15 @@ func (s *SQLStore) FilterChannelRange(ctx context.Context, return fmt.Errorf("unable to fetch node1 "+ "policy: %w", err) } else if err == nil { - n1Update := node1Policy.LastUpdate.Int64 - n1Height := node1Policy.BlockHeight.Int64 - switch v { case gossipV1: - chanInfo.Node1Freshness = - lnwire.UnixTimestamp(n1Update) + chanInfo.Node1Freshness = lnwire.UnixTimestamp( + node1Policy.LastUpdate.Int64, + ) case gossipV2: chanInfo.Node1Freshness = lnwire.BlockHeightTimestamp( - n1Height, + uint32(node1Policy.BlockHeight.Int64), ) } } @@ -1782,17 +2060,15 @@ func (s *SQLStore) FilterChannelRange(ctx context.Context, return fmt.Errorf("unable to fetch node2 "+ "policy: %w", err) } else if err == nil { - n2Update := node2Policy.LastUpdate.Int64 - n2Height := node2Policy.BlockHeight.Int64 - switch v { case gossipV1: - chanInfo.Node2Freshness = - lnwire.UnixTimestamp(n2Update) + chanInfo.Node2Freshness = lnwire.UnixTimestamp( + node2Policy.LastUpdate.Int64, + ) case gossipV2: chanInfo.Node2Freshness = lnwire.BlockHeightTimestamp( - n2Height, + uint32(node2Policy.BlockHeight.Int64), ) } } @@ -2296,6 +2572,114 @@ func (s *SQLStore) FetchChannelEdgesByOutpoint(ctx context.Context, return edge, policy1, policy2, nil } +// gossipVersionsDescending lists gossip versions from highest to lowest for +// prefer-highest iteration. +var gossipVersionsDescending = []lnwire.GossipVersion{gossipV2, gossipV1} + +// FetchChannelEdgesByIDPreferHighest tries each known gossip version from +// highest to lowest and returns the first result found. 
+//
+// NOTE: part of the Store interface.
+func (s *SQLStore) FetchChannelEdgesByIDPreferHighest(ctx context.Context,
+	chanID uint64) (
+	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
+	*models.ChannelEdgePolicy, error) {
+
+	for _, v := range gossipVersionsDescending {
+		info, p1, p2, err := s.FetchChannelEdgesByID(ctx, v, chanID)
+		if errors.Is(err, ErrEdgeNotFound) ||
+			errors.Is(err, ErrZombieEdge) {
+
+			continue
+		}
+		if err != nil {
+			return nil, nil, nil, err
+		}
+
+		return info, p1, p2, nil
+	}
+
+	return nil, nil, nil, ErrEdgeNotFound
+}
+
+// FetchChannelEdgesByOutpointPreferHighest tries each known gossip version
+// from highest to lowest and returns the first non-zombie result found.
+//
+// NOTE: part of the Store interface.
+func (s *SQLStore) FetchChannelEdgesByOutpointPreferHighest(
+	ctx context.Context, op *wire.OutPoint) (
+	*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
+	*models.ChannelEdgePolicy, error) {
+
+	for _, v := range gossipVersionsDescending {
+		info, p1, p2, err := s.FetchChannelEdgesByOutpoint(
+			ctx, v, op,
+		)
+		if errors.Is(err, ErrEdgeNotFound) ||
+			errors.Is(err, ErrZombieEdge) {
+
+			continue
+		}
+		if err != nil {
+			return nil, nil, nil, err
+		}
+
+		return info, p1, p2, nil
+	}
+
+	return nil, nil, nil, ErrEdgeNotFound
+}
+
+// GetVersionsBySCID returns the gossip versions for which a channel with the
+// given SCID exists in the database (zombie channels are excluded).
+//
+// NOTE: part of the Store interface. 
+func (s *SQLStore) GetVersionsBySCID(ctx context.Context, + chanID uint64) ([]lnwire.GossipVersion, error) { + + var versions []lnwire.GossipVersion + for _, v := range []lnwire.GossipVersion{gossipV1, gossipV2} { + _, _, _, err := s.FetchChannelEdgesByID(ctx, v, chanID) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, err + } + + versions = append(versions, v) + } + + return versions, nil +} + +// GetVersionsByOutpoint returns the gossip versions for which a channel with +// the given funding outpoint exists in the database. +// +// NOTE: part of the Store interface. +func (s *SQLStore) GetVersionsByOutpoint(ctx context.Context, + op *wire.OutPoint) ([]lnwire.GossipVersion, error) { + + var versions []lnwire.GossipVersion + for _, v := range []lnwire.GossipVersion{gossipV1, gossipV2} { + _, _, _, err := s.FetchChannelEdgesByOutpoint(ctx, v, op) + if errors.Is(err, ErrEdgeNotFound) || + errors.Is(err, ErrZombieEdge) { + + continue + } + if err != nil { + return nil, err + } + + versions = append(versions, v) + } + + return versions, nil +} + // HasV1ChannelEdge returns true if the database knows of a channel edge // with the passed channel ID, and false otherwise. If an edge with that ID // is found within the graph, then two time stamps representing the last time @@ -2740,6 +3124,7 @@ func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context, // // NOTE: part of the Store interface. 
func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, + v lnwire.GossipVersion, chansInfo []ChannelUpdateInfo) ([]uint64, []ChannelUpdateInfo, error) { var ( @@ -2769,7 +3154,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, return nil } - err := s.forEachChanInSCIDList(ctx, db, cb, chansInfo) + err := s.forEachChanInSCIDList(ctx, db, v, cb, chansInfo) if err != nil { return fmt.Errorf("unable to iterate through "+ "channels: %w", err) @@ -2788,7 +3173,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, isZombie, err := db.IsZombieChannel( ctx, sqlc.IsZombieChannelParams{ Scid: channelIDToBytes(channelID), - Version: int16(lnwire.GossipVersion1), + Version: int16(v), }, ) if err != nil { @@ -2827,6 +3212,7 @@ func (s *SQLStore) FilterKnownChanIDs(ctx context.Context, // ChannelUpdateInfo slice. The callback function is called for each channel // that is found. func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, + v lnwire.GossipVersion, cb func(ctx context.Context, channel sqlc.GraphChannel) error, chansInfo []ChannelUpdateInfo) error { @@ -2835,7 +3221,7 @@ func (s *SQLStore) forEachChanInSCIDList(ctx context.Context, db SQLQueries, return db.GetChannelsBySCIDs( ctx, sqlc.GetChannelsBySCIDsParams{ - Version: int16(lnwire.GossipVersion1), + Version: int16(v), Scids: scids, }, ) @@ -3485,9 +3871,27 @@ func (s *sqlNodeTraverser) ForEachNodeDirectedChannel( ctx context.Context, nodePub route.Vertex, cb func(channel *DirectedChannel) error, _ func()) error { - return forEachNodeDirectedChannel( - ctx, s.db, lnwire.GossipVersion1, nodePub, cb, - ) + // Iterate across all gossip versions (highest first) so that + // channels announced via v2 are preferred over v1. 
+ seen := make(map[uint64]struct{}) + for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + err := forEachNodeDirectedChannel( + ctx, s.db, v, nodePub, + func(channel *DirectedChannel) error { + if _, ok := seen[channel.ChannelID]; ok { + return nil + } + seen[channel.ChannelID] = struct{}{} + + return cb(channel) + }, + ) + if err != nil { + return err + } + } + + return nil } // FetchNodeFeatures returns the features of the given node. If the node is @@ -3498,7 +3902,21 @@ func (s *sqlNodeTraverser) FetchNodeFeatures(ctx context.Context, nodePub route.Vertex) ( *lnwire.FeatureVector, error) { - return fetchNodeFeatures(ctx, s.db, lnwire.GossipVersion1, nodePub) + // Try v2 first, fall back to v1 if the v2 features are empty. + for _, v := range []lnwire.GossipVersion{gossipV2, gossipV1} { + features, err := fetchNodeFeatures( + ctx, s.db, v, nodePub, + ) + if err != nil { + return nil, err + } + + if !features.IsEmpty() { + return features, nil + } + } + + return lnwire.EmptyFeatureVector(), nil } // forEachNodeDirectedChannel iterates through all channels of a given diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index dc0a0641ef8..18f02a27a42 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -374,6 +374,162 @@ func (q *Queries) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAn return i, err } +const getChannelByOutpointPreferHighestVersionWithPolicies = `-- name: GetChannelByOutpointPreferHighestVersionWithPolicies :one +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey, + + -- Node 1 policy + cp1.id AS policy_1_id, + cp1.node_id AS policy_1_node_id, + cp1.version AS policy_1_version, + cp1.timelock AS policy_1_timelock, + cp1.fee_ppm AS 
policy_1_fee_ppm, + cp1.base_fee_msat AS policy_1_base_fee_msat, + cp1.min_htlc_msat AS policy_1_min_htlc_msat, + cp1.max_htlc_msat AS policy_1_max_htlc_msat, + cp1.last_update AS policy_1_last_update, + cp1.disabled AS policy_1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy_1_message_flags, + cp1.channel_flags AS policy_1_channel_flags, + cp1.signature AS policy_1_signature, + cp1.block_height AS policy_1_block_height, + cp1.disable_flags AS policy_1_disable_flags, + + -- Node 2 policy + cp2.id AS policy_2_id, + cp2.node_id AS policy_2_node_id, + cp2.version AS policy_2_version, + cp2.timelock AS policy_2_timelock, + cp2.fee_ppm AS policy_2_fee_ppm, + cp2.base_fee_msat AS policy_2_base_fee_msat, + cp2.min_htlc_msat AS policy_2_min_htlc_msat, + cp2.max_htlc_msat AS policy_2_max_htlc_msat, + cp2.last_update AS policy_2_last_update, + cp2.disabled AS policy_2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.outpoint = $1 +ORDER BY c.version DESC +LIMIT 1 +` + +type GetChannelByOutpointPreferHighestVersionWithPoliciesRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + 
Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) GetChannelByOutpointPreferHighestVersionWithPolicies(ctx context.Context, outpoint string) (GetChannelByOutpointPreferHighestVersionWithPoliciesRow, error) { + row := q.db.QueryRowContext(ctx, getChannelByOutpointPreferHighestVersionWithPolicies, outpoint) + var i GetChannelByOutpointPreferHighestVersionWithPoliciesRow + err := row.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.Node1Pubkey, + &i.Node2Pubkey, + &i.Policy1ID, + 
&i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ) + return i, err +} + const getChannelByOutpointWithPolicies = `-- name: GetChannelByOutpointWithPolicies :one SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, @@ -567,7 +723,7 @@ func (q *Queries) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDPara return i, err } -const getChannelBySCIDWithPolicies = `-- name: GetChannelBySCIDWithPolicies :one +const getChannelBySCIDPreferHighestVersionWithPolicies = `-- name: GetChannelBySCIDPreferHighestVersionWithPolicies :one SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, n1.block_height, @@ -619,15 +775,11 @@ FROM graph_channels c LEFT JOIN graph_channel_policies cp2 ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND 
cp2.version = c.version WHERE c.scid = $1 - AND c.version = $2 +ORDER BY c.version DESC +LIMIT 1 ` -type GetChannelBySCIDWithPoliciesParams struct { - Scid []byte - Version int16 -} - -type GetChannelBySCIDWithPoliciesRow struct { +type GetChannelBySCIDPreferHighestVersionWithPoliciesRow struct { GraphChannel GraphChannel GraphNode GraphNode GraphNode_2 GraphNode @@ -667,9 +819,9 @@ type GetChannelBySCIDWithPoliciesRow struct { Policy2DisableFlags sql.NullInt16 } -func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) { - row := q.db.QueryRowContext(ctx, getChannelBySCIDWithPolicies, arg.Scid, arg.Version) - var i GetChannelBySCIDWithPoliciesRow +func (q *Queries) GetChannelBySCIDPreferHighestVersionWithPolicies(ctx context.Context, scid []byte) (GetChannelBySCIDPreferHighestVersionWithPoliciesRow, error) { + row := q.db.QueryRowContext(ctx, getChannelBySCIDPreferHighestVersionWithPolicies, scid) + var i GetChannelBySCIDPreferHighestVersionWithPoliciesRow err := row.Scan( &i.GraphChannel.ID, &i.GraphChannel.Version, @@ -741,51 +893,225 @@ func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChann return i, err } -const getChannelExtrasBatch = `-- name: GetChannelExtrasBatch :many +const getChannelBySCIDWithPolicies = `-- name: GetChannelBySCIDWithPolicies :one SELECT - channel_id, - type, - value -FROM graph_channel_extra_types -WHERE channel_id IN (/*SLICE:chan_ids*/?) 
-ORDER BY channel_id, type + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, n1.block_height, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, n2.block_height, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM 
graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.scid = $1 + AND c.version = $2 ` -func (q *Queries) GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error) { - query := getChannelExtrasBatch - var queryParams []interface{} - if len(chanIds) > 0 { - for _, v := range chanIds { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1) - } else { - query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1) - } - rows, err := q.db.QueryContext(ctx, query, queryParams...) - if err != nil { - return nil, err - } - defer rows.Close() - var items []GraphChannelExtraType - for rows.Next() { - var i GraphChannelExtraType - if err := rows.Scan(&i.ChannelID, &i.Type, &i.Value); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil +type GetChannelBySCIDWithPoliciesParams struct { + Scid []byte + Version int16 } -const getChannelFeaturesBatch = `-- name: GetChannelFeaturesBatch :many -SELECT +type GetChannelBySCIDWithPoliciesRow struct { + GraphChannel GraphChannel + GraphNode GraphNode + GraphNode_2 GraphNode + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + 
Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) { + row := q.db.QueryRowContext(ctx, getChannelBySCIDWithPolicies, arg.Scid, arg.Version) + var i GetChannelBySCIDWithPoliciesRow + err := row.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode.BlockHeight, + &i.GraphNode_2.ID, + &i.GraphNode_2.Version, + &i.GraphNode_2.PubKey, + &i.GraphNode_2.Alias, + &i.GraphNode_2.LastUpdate, + &i.GraphNode_2.Color, + &i.GraphNode_2.Signature, + &i.GraphNode_2.BlockHeight, + &i.Policy1ID, + 
&i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ) + return i, err +} + +const getChannelExtrasBatch = `-- name: GetChannelExtrasBatch :many +SELECT + channel_id, + type, + value +FROM graph_channel_extra_types +WHERE channel_id IN (/*SLICE:chan_ids*/?) +ORDER BY channel_id, type +` + +func (q *Queries) GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error) { + query := getChannelExtrasBatch + var queryParams []interface{} + if len(chanIds) > 0 { + for _, v := range chanIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphChannelExtraType + for rows.Next() { + var i GraphChannelExtraType + if err := rows.Scan(&i.ChannelID, &i.Type, &i.Value); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelFeaturesBatch = `-- name: GetChannelFeaturesBatch :many +SELECT channel_id, feature_bit FROM graph_channel_features @@ -915,17 +1241,271 @@ func (q *Queries) GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds return items, nil } -const getChannelsByIDs = `-- name: GetChannelsByIDs :many +const getChannelsByIDs = `-- name: GetChannelsByIDs :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + + -- Minimal node data. 
+ n1.id AS node1_id, + n1.pub_key AS node1_pub_key, + n2.id AS node2_id, + n2.pub_key AS node2_pub_key, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.id IN (/*SLICE:ids*/?) 
+` + +type GetChannelsByIDsRow struct { + GraphChannel GraphChannel + Node1ID int64 + Node1PubKey []byte + Node2ID int64 + Node2PubKey []byte + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy1BlockHeight sql.NullInt64 + Policy1DisableFlags sql.NullInt16 + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte + Policy2BlockHeight sql.NullInt64 + Policy2DisableFlags sql.NullInt16 +} + +func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChannelsByIDsRow, error) { + query := getChannelsByIDs + var queryParams []interface{} + if len(ids) > 0 { + for _, v := range ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) 
+ if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByIDsRow + for rows.Next() { + var i GetChannelsByIDsRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.Node1ID, + &i.Node1PubKey, + &i.Node2ID, + &i.Node2PubKey, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy1BlockHeight, + &i.Policy1DisableFlags, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + &i.Policy2BlockHeight, + &i.Policy2DisableFlags, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, 
c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + n1.pub_key AS node1_pubkey, + n2.pub_key AS node2_pubkey +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id +WHERE c.outpoint IN + (/*SLICE:outpoints*/?) +` + +type GetChannelsByOutpointsRow struct { + GraphChannel GraphChannel + Node1Pubkey []byte + Node2Pubkey []byte +} + +func (q *Queries) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) { + query := getChannelsByOutpoints + var queryParams []interface{} + if len(outpoints) > 0 { + for _, v := range outpoints { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:outpoints*/?", makeQueryParams(len(queryParams), len(outpoints)), 1) + } else { + query = strings.Replace(query, "/*SLICE:outpoints*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsByOutpointsRow + for rows.Next() { + var i GetChannelsByOutpointsRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphChannel.Signature, + &i.GraphChannel.FundingPkScript, + &i.GraphChannel.MerkleRootHash, + &i.Node1Pubkey, + &i.Node2Pubkey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChannelsByPolicyBlockRange = `-- name: GetChannelsByPolicyBlockRange :many SELECT c.id, 
c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, n1.block_height, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, n2.block_height, - -- Minimal node data. - n1.id AS node1_id, - n1.pub_key AS node1_pub_key, - n2.id AS node2_id, - n2.pub_key AS node2_pub_key, - - -- Policy 1 + -- Policy 1 (node_id_1) cp1.id AS policy1_id, cp1.node_id AS policy1_node_id, cp1.version AS policy1_version, @@ -944,7 +1524,7 @@ SELECT cp1.block_height AS policy1_block_height, cp1.disable_flags AS policy1_disable_flags, - -- Policy 2 + -- Policy 2 (node_id_2) cp2.id AS policy2_id, cp2.node_id AS policy2_node_id, cp2.version AS policy2_version, @@ -970,15 +1550,51 @@ FROM graph_channels c ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version LEFT JOIN graph_channel_policies cp2 ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version -WHERE c.id IN (/*SLICE:ids*/?) +WHERE c.version = $1 + AND ( + (cp1.block_height >= $2 AND cp1.block_height < $3) + OR + (cp2.block_height >= $2 AND cp2.block_height < $3) + ) + -- Pagination using compound cursor (max_block_height, id). + -- We use COALESCE with -1 as sentinel since heights are always positive. 
+ AND ( + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END > COALESCE($4, -1)) + OR + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END = COALESCE($4, -1) + AND c.id > COALESCE($5, -1)) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END ASC, + c.id ASC +LIMIT COALESCE($6, 999999999) ` -type GetChannelsByIDsRow struct { +type GetChannelsByPolicyBlockRangeParams struct { + Version int16 + StartHeight sql.NullInt64 + EndHeight sql.NullInt64 + LastBlockHeight sql.NullInt64 + LastID sql.NullInt64 + MaxResults interface{} +} + +type GetChannelsByPolicyBlockRangeRow struct { GraphChannel GraphChannel - Node1ID int64 - Node1PubKey []byte - Node2ID int64 - Node2PubKey []byte + GraphNode GraphNode + GraphNode_2 GraphNode Policy1ID sql.NullInt64 Policy1NodeID sql.NullInt64 Policy1Version sql.NullInt16 @@ -1015,25 +1631,22 @@ type GetChannelsByIDsRow struct { Policy2DisableFlags sql.NullInt16 } -func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChannelsByIDsRow, error) { - query := getChannelsByIDs - var queryParams []interface{} - if len(ids) > 0 { - for _, v := range ids { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:ids*/?", makeQueryParams(len(queryParams), len(ids)), 1) - } else { - query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) - } - rows, err := q.db.QueryContext(ctx, query, queryParams...) 
+func (q *Queries) GetChannelsByPolicyBlockRange(ctx context.Context, arg GetChannelsByPolicyBlockRangeParams) ([]GetChannelsByPolicyBlockRangeRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelsByPolicyBlockRange, + arg.Version, + arg.StartHeight, + arg.EndHeight, + arg.LastBlockHeight, + arg.LastID, + arg.MaxResults, + ) if err != nil { return nil, err } defer rows.Close() - var items []GetChannelsByIDsRow + var items []GetChannelsByPolicyBlockRangeRow for rows.Next() { - var i GetChannelsByIDsRow + var i GetChannelsByPolicyBlockRangeRow if err := rows.Scan( &i.GraphChannel.ID, &i.GraphChannel.Version, @@ -1051,10 +1664,22 @@ func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChann &i.GraphChannel.Signature, &i.GraphChannel.FundingPkScript, &i.GraphChannel.MerkleRootHash, - &i.Node1ID, - &i.Node1PubKey, - &i.Node2ID, - &i.Node2PubKey, + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode.BlockHeight, + &i.GraphNode_2.ID, + &i.GraphNode_2.Version, + &i.GraphNode_2.PubKey, + &i.GraphNode_2.Alias, + &i.GraphNode_2.LastUpdate, + &i.GraphNode_2.Color, + &i.GraphNode_2.Signature, + &i.GraphNode_2.BlockHeight, &i.Policy1ID, &i.Policy1NodeID, &i.Policy1Version, @@ -1103,76 +1728,6 @@ func (q *Queries) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChann return items, nil } -const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many -SELECT - c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, - n1.pub_key AS node1_pubkey, - n2.pub_key AS node2_pubkey -FROM graph_channels c - JOIN graph_nodes n1 ON c.node_id_1 = n1.id - JOIN graph_nodes n2 ON c.node_id_2 = n2.id -WHERE c.outpoint IN - 
(/*SLICE:outpoints*/?) -` - -type GetChannelsByOutpointsRow struct { - GraphChannel GraphChannel - Node1Pubkey []byte - Node2Pubkey []byte -} - -func (q *Queries) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) { - query := getChannelsByOutpoints - var queryParams []interface{} - if len(outpoints) > 0 { - for _, v := range outpoints { - queryParams = append(queryParams, v) - } - query = strings.Replace(query, "/*SLICE:outpoints*/?", makeQueryParams(len(queryParams), len(outpoints)), 1) - } else { - query = strings.Replace(query, "/*SLICE:outpoints*/?", "NULL", 1) - } - rows, err := q.db.QueryContext(ctx, query, queryParams...) - if err != nil { - return nil, err - } - defer rows.Close() - var items []GetChannelsByOutpointsRow - for rows.Next() { - var i GetChannelsByOutpointsRow - if err := rows.Scan( - &i.GraphChannel.ID, - &i.GraphChannel.Version, - &i.GraphChannel.Scid, - &i.GraphChannel.NodeID1, - &i.GraphChannel.NodeID2, - &i.GraphChannel.Outpoint, - &i.GraphChannel.Capacity, - &i.GraphChannel.BitcoinKey1, - &i.GraphChannel.BitcoinKey2, - &i.GraphChannel.Node1Signature, - &i.GraphChannel.Node2Signature, - &i.GraphChannel.Bitcoin1Signature, - &i.GraphChannel.Bitcoin2Signature, - &i.GraphChannel.Signature, - &i.GraphChannel.FundingPkScript, - &i.GraphChannel.MerkleRootHash, - &i.Node1Pubkey, - &i.Node2Pubkey, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getChannelsByPolicyLastUpdateRange = `-- name: GetChannelsByPolicyLastUpdateRange :many SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, c.signature, c.funding_pk_script, c.merkle_root_hash, @@ -1224,11 +1779,11 @@ FROM graph_channels c 
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version LEFT JOIN graph_channel_policies cp2 ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version -WHERE c.version = $1 +WHERE c.version = 1 AND ( - (cp1.last_update >= $2 AND cp1.last_update < $3) + (cp1.last_update >= $1 AND cp1.last_update < $2) OR - (cp2.last_update >= $2 AND cp2.last_update < $3) + (cp2.last_update >= $1 AND cp2.last_update < $2) ) -- Pagination using compound cursor (max_update_time, id). -- We use COALESCE with -1 as sentinel since timestamps are always positive. @@ -1237,14 +1792,14 @@ WHERE c.version = $1 WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) THEN COALESCE(cp1.last_update, 0) ELSE COALESCE(cp2.last_update, 0) - END > COALESCE($4, -1)) + END > COALESCE($3, -1)) OR (CASE WHEN COALESCE(cp1.last_update, 0) >= COALESCE(cp2.last_update, 0) THEN COALESCE(cp1.last_update, 0) ELSE COALESCE(cp2.last_update, 0) - END = COALESCE($4, -1) - AND c.id > COALESCE($5, -1)) + END = COALESCE($3, -1) + AND c.id > COALESCE($4, -1)) ) ORDER BY CASE @@ -1253,11 +1808,10 @@ ORDER BY ELSE COALESCE(cp2.last_update, 0) END ASC, c.id ASC -LIMIT COALESCE($6, 999999999) +LIMIT COALESCE($5, 999999999) ` type GetChannelsByPolicyLastUpdateRangeParams struct { - Version int16 StartTime sql.NullInt64 EndTime sql.NullInt64 LastUpdateTime sql.NullInt64 @@ -1307,7 +1861,6 @@ type GetChannelsByPolicyLastUpdateRangeRow struct { func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) { rows, err := q.db.QueryContext(ctx, getChannelsByPolicyLastUpdateRange, - arg.Version, arg.StartTime, arg.EndTime, arg.LastUpdateTime, @@ -2070,6 +2623,92 @@ func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyPa return id, err } +const getNodesByBlockHeightRange = `-- name: GetNodesByBlockHeightRange :many +SELECT id, version, 
pub_key, alias, last_update, color, signature, block_height +FROM graph_nodes +WHERE graph_nodes.version = $1 + AND block_height >= $2 + AND block_height <= $3 + -- Pagination: We use (block_height, pub_key) as a compound cursor. + -- This ensures stable ordering and allows us to resume from where we left off. + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + -- Include rows with block_height greater than cursor (or all rows if cursor is -1). + block_height > COALESCE($4, -1) + OR + -- For rows with same block_height, use pub_key as tiebreaker. + (block_height = COALESCE($4, -1) + AND pub_key > $5) + ) + -- Optional filter for public nodes only. + AND ( + -- If only_public is false or not provided, include all nodes. + COALESCE($6, FALSE) IS FALSE + OR + -- For V2 protocol, a node is public if it has at least one public channel. + -- A public channel has signature set (channel announcement received). + EXISTS ( + SELECT 1 + FROM graph_channels c + WHERE c.version = 2 + AND COALESCE(length(c.signature), 0) > 0 + AND (c.node_id_1 = graph_nodes.id OR c.node_id_2 = graph_nodes.id) + ) + ) +ORDER BY block_height ASC, pub_key ASC +LIMIT COALESCE($7, 999999999) +` + +type GetNodesByBlockHeightRangeParams struct { + Version int16 + StartHeight sql.NullInt64 + EndHeight sql.NullInt64 + LastBlockHeight sql.NullInt64 + LastPubKey []byte + OnlyPublic interface{} + MaxResults interface{} +} + +func (q *Queries) GetNodesByBlockHeightRange(ctx context.Context, arg GetNodesByBlockHeightRangeParams) ([]GraphNode, error) { + rows, err := q.db.QueryContext(ctx, getNodesByBlockHeightRange, + arg.Version, + arg.StartHeight, + arg.EndHeight, + arg.LastBlockHeight, + arg.LastPubKey, + arg.OnlyPublic, + arg.MaxResults, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphNode + for rows.Next() { + var i GraphNode + if err := rows.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + 
&i.Signature, + &i.BlockHeight, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getNodesByIDs = `-- name: GetNodesByIDs :many SELECT id, version, pub_key, alias, last_update, color, signature, block_height FROM graph_nodes @@ -2121,7 +2760,8 @@ func (q *Queries) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many SELECT id, version, pub_key, alias, last_update, color, signature, block_height FROM graph_nodes -WHERE last_update >= $1 +WHERE graph_nodes.version = 1 + AND last_update >= $1 AND last_update <= $2 -- Pagination: We use (last_update, pub_key) as a compound cursor. -- This ensures stable ordering and allows us to resume from where we left off. @@ -2200,7 +2840,6 @@ func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByL } return items, nil } - const getPruneEntriesForHeights = `-- name: GetPruneEntriesForHeights :many SELECT block_height, block_hash FROM graph_prune_log @@ -2502,6 +3141,66 @@ func (q *Queries) GetV2DisabledSCIDs(ctx context.Context) ([][]byte, error) { return items, nil } +const getVersionsByOutpoint = `-- name: GetVersionsByOutpoint :many +SELECT version +FROM graph_channels +WHERE outpoint = $1 +ORDER BY version +` + +func (q *Queries) GetVersionsByOutpoint(ctx context.Context, outpoint string) ([]int16, error) { + rows, err := q.db.QueryContext(ctx, getVersionsByOutpoint, outpoint) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int16 + for rows.Next() { + var version int16 + if err := rows.Scan(&version); err != nil { + return nil, err + } + items = append(items, version) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const 
getVersionsBySCID = `-- name: GetVersionsBySCID :many +SELECT version +FROM graph_channels +WHERE scid = $1 +ORDER BY version +` + +func (q *Queries) GetVersionsBySCID(ctx context.Context, scid []byte) ([]int16, error) { + rows, err := q.db.QueryContext(ctx, getVersionsBySCID, scid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int16 + for rows.Next() { + var version int16 + if err := rows.Scan(&version); err != nil { + return nil, err + } + items = append(items, version) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getZombieChannel = `-- name: GetZombieChannel :one SELECT scid, version, node_key_1, node_key_2 FROM graph_zombie_channels diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index c148d66b4c3..6a40ca386c5 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -93,8 +93,10 @@ type Querier interface { FilterPayments(ctx context.Context, arg FilterPaymentsParams) ([]FilterPaymentsRow, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) + GetChannelByOutpointPreferHighestVersionWithPolicies(ctx context.Context, outpoint string) (GetChannelByOutpointPreferHighestVersionWithPoliciesRow, error) GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (GraphChannel, error) + GetChannelBySCIDPreferHighestVersionWithPolicies(ctx context.Context, scid []byte) (GetChannelBySCIDPreferHighestVersionWithPoliciesRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) GetChannelExtrasBatch(ctx context.Context, chanIds []int64) 
([]GraphChannelExtraType, error) GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]GraphChannelFeature, error) @@ -102,6 +104,7 @@ type Querier interface { GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error) GetChannelsByIDs(ctx context.Context, ids []int64) ([]GetChannelsByIDsRow, error) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) + GetChannelsByPolicyBlockRange(ctx context.Context, arg GetChannelsByPolicyBlockRangeParams) ([]GetChannelsByPolicyBlockRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) @@ -127,6 +130,7 @@ type Querier interface { GetNodeFeaturesBatch(ctx context.Context, ids []int64) ([]GraphNodeFeature, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) + GetNodesByBlockHeightRange(ctx context.Context, arg GetNodesByBlockHeightRangeParams) ([]GraphNode, error) GetNodesByIDs(ctx context.Context, ids []int64) ([]GraphNode, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]GraphNode, error) GetPruneEntriesForHeights(ctx context.Context, heights []int64) ([]GraphPruneLog, error) @@ -144,6 +148,8 @@ type Querier interface { // NOTE: this is V2 specific since V2 uses a disable flag // bit vector instead of a single boolean. 
GetV2DisabledSCIDs(ctx context.Context) ([][]byte, error) + GetVersionsByOutpoint(ctx context.Context, outpoint string) ([]int16, error) + GetVersionsBySCID(ctx context.Context, scid []byte) ([]int16, error) GetZombieChannel(ctx context.Context, arg GetZombieChannelParams) (GraphZombieChannel, error) GetZombieChannelsSCIDs(ctx context.Context, arg GetZombieChannelsSCIDsParams) ([]GraphZombieChannel, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 78c1ebe7c9f..bc3411b5fd6 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -227,7 +227,8 @@ ORDER BY node_id, type, position; -- name: GetNodesByLastUpdateRange :many SELECT * FROM graph_nodes -WHERE last_update >= @start_time +WHERE graph_nodes.version = 1 + AND last_update >= @start_time AND last_update <= @end_time -- Pagination: We use (last_update, pub_key) as a compound cursor. -- This ensures stable ordering and allows us to resume from where we left off. @@ -258,6 +259,41 @@ WHERE last_update >= @start_time ORDER BY last_update ASC, pub_key ASC LIMIT COALESCE(sqlc.narg('max_results'), 999999999); +-- name: GetNodesByBlockHeightRange :many +SELECT * +FROM graph_nodes +WHERE graph_nodes.version = @version + AND block_height >= @start_height + AND block_height <= @end_height + -- Pagination: We use (block_height, pub_key) as a compound cursor. + -- This ensures stable ordering and allows us to resume from where we left off. + -- We use COALESCE with -1 as sentinel since heights are always positive. + AND ( + -- Include rows with block_height greater than cursor (or all rows if cursor is -1). + block_height > COALESCE(sqlc.narg('last_block_height'), -1) + OR + -- For rows with same block_height, use pub_key as tiebreaker. + (block_height = COALESCE(sqlc.narg('last_block_height'), -1) + AND pub_key > sqlc.narg('last_pub_key')) + ) + -- Optional filter for public nodes only. 
+ AND ( + -- If only_public is false or not provided, include all nodes. + COALESCE(sqlc.narg('only_public'), FALSE) IS FALSE + OR + -- For V2 protocol, a node is public if it has at least one public channel. + -- A public channel has signature set (channel announcement received). + EXISTS ( + SELECT 1 + FROM graph_channels c + WHERE c.version = 2 + AND COALESCE(length(c.signature), 0) > 0 + AND (c.node_id_1 = graph_nodes.id OR c.node_id_2 = graph_nodes.id) + ) + ) +ORDER BY block_height ASC, pub_key ASC +LIMIT COALESCE(sqlc.narg('max_results'), 999999999); + -- name: DeleteNodeAddresses :exec DELETE FROM graph_node_addresses WHERE node_id = $1; @@ -548,7 +584,7 @@ FROM graph_channels c ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version LEFT JOIN graph_channel_policies cp2 ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version -WHERE c.version = @version +WHERE c.version = 1 AND ( (cp1.last_update >= @start_time AND cp1.last_update < @end_time) OR @@ -579,6 +615,88 @@ ORDER BY c.id ASC LIMIT COALESCE(sqlc.narg('max_results'), 999999999); +-- name: GetChannelsByPolicyBlockRange :many +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 (node_id_1) + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 
(node_id_2) + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy2_message_flags, + cp2.channel_flags AS policy2_channel_flags, + cp2.signature AS policy2_signature, + cp2.block_height AS policy2_block_height, + cp2.disable_flags AS policy2_disable_flags + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.version = @version + AND ( + (cp1.block_height >= @start_height AND cp1.block_height < @end_height) + OR + (cp2.block_height >= @start_height AND cp2.block_height < @end_height) + ) + -- Pagination using compound cursor (max_block_height, id). + -- We use COALESCE with -1 as sentinel since heights are always positive. 
+ AND ( + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END > COALESCE(sqlc.narg('last_block_height'), -1)) + OR + (CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END = COALESCE(sqlc.narg('last_block_height'), -1) + AND c.id > COALESCE(sqlc.narg('last_id'), -1)) + ) +ORDER BY + CASE + WHEN COALESCE(cp1.block_height, 0) >= COALESCE(cp2.block_height, 0) + THEN COALESCE(cp1.block_height, 0) + ELSE COALESCE(cp2.block_height, 0) + END ASC, + c.id ASC +LIMIT COALESCE(sqlc.narg('max_results'), 999999999); + -- name: GetChannelByOutpointWithPolicies :one SELECT sqlc.embed(c), @@ -1049,6 +1167,128 @@ FROM graph_channels c WHERE c.scid = @scid AND c.version = @version; +-- name: GetChannelBySCIDPreferHighestVersionWithPolicies :one +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + cp1.block_height AS policy1_block_height, + cp1.disable_flags AS policy1_disable_flags, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS 
policy2_min_htlc_msat,
+    cp2.max_htlc_msat AS policy2_max_htlc_msat,
+    cp2.last_update AS policy2_last_update,
+    cp2.disabled AS policy2_disabled,
+    cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
+    cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
+    cp2.message_flags AS policy2_message_flags,
+    cp2.channel_flags AS policy2_channel_flags,
+    cp2.signature AS policy2_signature,
+    cp2.block_height AS policy2_block_height,
+    cp2.disable_flags AS policy2_disable_flags
+
+FROM graph_channels c
+    JOIN graph_nodes n1 ON c.node_id_1 = n1.id
+    JOIN graph_nodes n2 ON c.node_id_2 = n2.id
+    LEFT JOIN graph_channel_policies cp1
+        ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
+    LEFT JOIN graph_channel_policies cp2
+        ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
+WHERE c.scid = @scid
+ORDER BY c.version DESC
+LIMIT 1;
+
+-- name: GetChannelByOutpointPreferHighestVersionWithPolicies :one
+SELECT
+    sqlc.embed(c),
+
+    n1.pub_key AS node1_pubkey,
+    n2.pub_key AS node2_pubkey,
+
+    -- Node 1 policy
+    cp1.id AS policy_1_id,
+    cp1.node_id AS policy_1_node_id,
+    cp1.version AS policy_1_version,
+    cp1.timelock AS policy_1_timelock,
+    cp1.fee_ppm AS policy_1_fee_ppm,
+    cp1.base_fee_msat AS policy_1_base_fee_msat,
+    cp1.min_htlc_msat AS policy_1_min_htlc_msat,
+    cp1.max_htlc_msat AS policy_1_max_htlc_msat,
+    cp1.last_update AS policy_1_last_update,
+    cp1.disabled AS policy_1_disabled,
+    cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
+    cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
+    cp1.message_flags AS policy_1_message_flags,
+    cp1.channel_flags AS policy_1_channel_flags,
+    cp1.signature AS policy_1_signature,
+    cp1.block_height AS policy_1_block_height,
+    cp1.disable_flags AS policy_1_disable_flags,
+
+    -- Node 2 policy
+    cp2.id AS policy_2_id,
+    cp2.node_id AS policy_2_node_id,
+    cp2.version AS policy_2_version,
+    cp2.timelock AS 
policy_2_timelock, + cp2.fee_ppm AS policy_2_fee_ppm, + cp2.base_fee_msat AS policy_2_base_fee_msat, + cp2.min_htlc_msat AS policy_2_min_htlc_msat, + cp2.max_htlc_msat AS policy_2_max_htlc_msat, + cp2.last_update AS policy_2_last_update, + cp2.disabled AS policy_2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy_2_signature, + cp2.block_height AS policy_2_block_height, + cp2.disable_flags AS policy_2_disable_flags +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE c.outpoint = @outpoint +ORDER BY c.version DESC +LIMIT 1; + +-- name: GetVersionsBySCID :many +SELECT version +FROM graph_channels +WHERE scid = @scid +ORDER BY version; + +-- name: GetVersionsByOutpoint :many +SELECT version +FROM graph_channels +WHERE outpoint = @outpoint +ORDER BY version; + /* ───────────────────────────────────────────── graph_channel_policy_extra_types table queries ─────────────────────────────────────────────