diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index fddae907b2c0..bb35dc2350c7 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -1498,7 +1498,7 @@ def extract_preimage_from_htlc_txin(self, txin: TxInput, *, is_deeply_mined: boo error_bytes=None, failure_message=failure) - def balance(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None) -> int: + def balance(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = None, ctn: int = None) -> int: assert type(whose) is HTLCOwner initial = self.config[whose].initial_msat return self.hm.get_balance_msat(whose=whose, diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 9a5950758e97..e55114725538 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -507,19 +507,35 @@ def all_htlcs_ever(self) -> Sequence[Tuple[Direction, UpdateAddHtlc]]: return sent + received @with_lock - def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: int = None, + def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner: HTLCOwner = None, ctn: int = None, initial_balance_msat: int) -> int: """Returns the balance of 'whose' in 'ctx' at 'ctn'. Only HTLCs that have been settled by that ctn are counted. """ - if ctn is None: - ctn = self.ctn_oldest_unrevoked(ctx_owner) + if ctx_owner is None: + # if ctx_owner is None, we want result to be consistent with get_lightning_history + # thus, we consider that htlcs are settled as soon as their preimage is released, + # because get_lightning_history calls self.was_htlc_preimage_released + ctx_owner = whose + ctx_owner_sent = whose + ctx_owner_recv = -whose + else: + ctx_owner_sent = ctx_owner + ctx_owner_recv = ctx_owner + if ctn is None: + ctn = self.ctn_oldest_unrevoked(ctx_owner) + balance = initial_balance_msat - if ctn >= self.ctn_oldest_unrevoked(ctx_owner): + if ctn is None or ctn >= self.ctn_oldest_unrevoked(ctx_owner): balance += self._balance_delta * whose considered_sent_htlc_ids = self._maybe_active_htlc_ids[whose] considered_recv_htlc_ids = self._maybe_active_htlc_ids[-whose] - else: # ctn is too old; need to consider full log (slow...) + elif ctn == 0: + considered_sent_htlc_ids = [] + considered_recv_htlc_ids = [] + else: + # ctn is too old; need to consider full log (slow...) + # used in sync_with_remote_watchtower considered_sent_htlc_ids = self.log[whose]['settles'] considered_recv_htlc_ids = self.log[-whose]['settles'] # sent htlcs @@ -527,7 +543,7 @@ def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: ctns = self.log[whose]['settles'].get(htlc_id, None) if ctns is None: continue - if ctns[ctx_owner] is not None and ctns[ctx_owner] <= ctn: + if ctns[ctx_owner_sent] is not None and (ctn is None or ctns[ctx_owner_sent] <= ctn): htlc = self.log[whose]['adds'][htlc_id] balance -= htlc.amount_msat # recv htlcs @@ -535,7 +551,7 @@ def get_balance_msat(self, whose: HTLCOwner, *, ctx_owner=HTLCOwner.LOCAL, ctn: ctns = self.log[-whose]['settles'].get(htlc_id, None) if ctns is None: continue - if ctns[ctx_owner] is not None and ctns[ctx_owner] <= ctn: + if ctns[ctx_owner_recv] is not None and (ctn is None or ctns[ctx_owner_recv] <= ctn): htlc = self.log[-whose]['adds'][htlc_id] balance += htlc.amount_msat return balance diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 4270dad722c4..92fb0357eb81 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1349,7 +1349,6 @@ def get_lightning_history(self) -> Dict[str, LightningHistoryItem]: lb = sum(chan.balance(LOCAL) if not chan.is_closed_or_closing() else 0 for chan in self.channels.values()) if balance_msat != lb: - # this typically happens when a channel is recently force closed self.logger.info(f'get_lightning_history: balance mismatch {balance_msat - lb}') return out diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 2081dd7585db..ff62bd62e883 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -573,7 +573,7 @@ async def test_SimpleAddSettleWorkflow(self): tx4 = str(alice_channel.force_close_tx()) self.assertNotEqual(tx3, tx4) - self.assertEqual(alice_channel.balance(LOCAL), 500000000000) + self.assertEqual(alice_channel.balance(LOCAL, ctx_owner=LOCAL), 500000000000) self.assertEqual(1, alice_channel.get_oldest_unrevoked_ctn(LOCAL)) self.assertEqual(len(alice_channel.included_htlcs(LOCAL, RECEIVED, ctn=2)), 0) aliceRevocation2 = alice_channel.revoke_current_commitment()