-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfenwick_tree.py
More file actions
376 lines (288 loc) · 10.9 KB
/
fenwick_tree.py
File metadata and controls
376 lines (288 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
"""
Fenwick tree (Binary Indexed Tree) for efficient range sum queries and point updates.
A Fenwick tree maintains cumulative frequency information and supports three main operations:
* update(i, delta): add delta to the element at index i
* query(i): return the sum of elements from index 0 to i (inclusive)
* range_query(left, right): return the sum of elements from left to right (inclusive)
The tree uses an indexing scheme based on the binary representation of indices
(each node covers a range whose length is the lowest set bit of its 1-based index)
to achieve logarithmic time complexity for each of these operations.
Time complexity: O(log n) per update, query, and range_query; O(n) to build from an array.
Space complexity: O(n) where n is the size of the array.
"""
from __future__ import annotations
# Don't use annotations during contest
from typing import Final, Generic, Protocol, TypeVar
from typing_extensions import Self
class Summable(Protocol):
    """Structural type for values a FenwickTree can aggregate.

    ``+`` accumulates values, ``-`` turns two prefix sums into a range sum,
    and ``<=`` is used only by ``FenwickTree.first_nonzero_index``.
    """

    def __add__(self, other: Self, /) -> Self: ...
    def __sub__(self, other: Self, /) -> Self: ...
    def __le__(self, other: Self, /) -> bool: ...


# Element type stored in a FenwickTree (e.g. int or float).
ValueT = TypeVar("ValueT", bound=Summable)


class FenwickTree(Generic[ValueT]):
    """Fenwick tree (binary indexed tree) over ``size`` elements, 0-indexed externally.

    Internally ``tree`` is 1-indexed: ``tree[i]`` holds the sum of the
    elements at 1-based positions ``i - (i & -i) + 1 .. i``. ``tree[0]``
    is never read.
    """

    def __init__(self, size: int, zero: ValueT) -> None:
        """Create a tree of ``size`` elements, all equal to ``zero``.

        Args:
            size: Number of elements; must be non-negative.
            zero: Additive identity for ValueT (e.g. ``0`` or ``0.0``).

        Raises:
            ValueError: If ``size`` is negative (the tree would be unusable
                and ``len()`` on it would fail).
        """
        if size < 0:
            msg = f"Size must be non-negative, got {size}"
            raise ValueError(msg)
        self.size: Final = size
        self.zero: Final = zero
        # 1-indexed tree for easier bit manipulation; slot 0 is never read.
        self.tree: list[ValueT] = [zero] * (size + 1)

    @classmethod
    def from_array(cls, arr: list[ValueT], zero: ValueT) -> Self:
        """Create a Fenwick tree holding the elements of ``arr`` in O(n) time.

        Each element is first placed in its own slot, then (in increasing
        index order) every node adds its now-complete sum into its parent
        ``i + (i & -i)``. Each node is therefore summed directly from its own
        range rather than as a difference of two global prefix sums, which
        avoids float cancellation against a large preceding prefix and needs
        no auxiliary prefix array.
        """
        n = len(arr)
        tree = cls(n, zero)
        nodes = tree.tree
        # Seed the leaves; ``zero + value`` mirrors what update() would store.
        for pos, value in enumerate(arr, start=1):
            nodes[pos] = zero + value
        for pos in range(1, n + 1):
            parent = pos + (pos & -pos)
            if parent <= n:
                # All children of ``pos`` have smaller indices, so nodes[pos]
                # is already complete when it is pushed upward.
                nodes[parent] = nodes[parent] + nodes[pos]
        return tree

    def _check_index(self, index: int) -> None:
        """Raise IndexError unless ``0 <= index < size``."""
        if not (0 <= index < self.size):
            msg = f"Index {index} out of bounds for size {self.size}"
            raise IndexError(msg)

    def update(self, index: int, delta: ValueT) -> None:
        """Add ``delta`` to the element at ``index`` in O(log n).

        Raises:
            IndexError: If ``index`` is outside ``[0, size)``.
        """
        self._check_index(index)
        # Convert to 1-indexed
        index += 1
        while index <= self.size:
            self.tree[index] = self.tree[index] + delta
            # Move to next covering node by adding the lowest set bit
            index += index & (-index)

    def query(self, index: int) -> ValueT:
        """Return the sum of elements 0..``index`` (inclusive) in O(log n).

        Raises:
            IndexError: If ``index`` is outside ``[0, size)``.
        """
        self._check_index(index)
        # Convert to 1-indexed
        index += 1
        result = self.zero
        while index > 0:
            result = result + self.tree[index]
            # Drop the lowest set bit to jump to the preceding disjoint range
            index -= index & (-index)
        return result

    def range_query(self, left: int, right: int) -> ValueT:
        """Sum of elements from left to right (inclusive). Returns zero for invalid ranges.

        A range is invalid when ``left > right``, ``left < 0`` or
        ``right >= size``; no exception is raised in that case.
        """
        if left > right or left < 0 or right >= self.size:
            return self.zero
        if left == 0:
            return self.query(right)
        return self.query(right) - self.query(left - 1)

    # Optional functionality (not always needed during competition)
    def get_value(self, index: int) -> ValueT:
        """Return the current value of the element at ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, size)``.
        """
        self._check_index(index)
        if index == 0:
            return self.query(0)
        return self.query(index) - self.query(index - 1)

    def first_nonzero_index(self, start_index: int) -> int | None:
        """Return the smallest index >= ``start_index`` whose value is > zero, or None.

        A negative ``start_index`` is clamped to 0; a start at or past
        ``size`` returns None.

        REQUIRES: every element is non-negative (so prefix sums never
        decrease) and ValueT is totally ordered (e.g. int, float).
        """
        start_index = max(start_index, 0)
        if start_index >= self.size:
            return None
        prefix_before = self.query(start_index - 1) if start_index > 0 else self.zero
        total = self.query(self.size - 1)
        if total == prefix_before:
            # Nothing from start_index onward adds to the sum, so (with
            # non-negative elements) every remaining element is zero.
            return None
        # Fenwick binary descent: find the largest 1-based position whose
        # prefix sum is still <= prefix_before.
        idx = 0  # 1-based cursor
        cur = self.zero  # prefix sum of positions 1..idx
        bit = 1 << (self.size.bit_length() - 1)  # highest power of two <= size
        while bit:
            nxt = idx + bit
            if nxt <= self.size:
                cand = cur + self.tree[nxt]
                if cand <= prefix_before:  # move right while prefix <= target
                    cur = cand
                    idx = nxt
            bit >>= 1
        # 1-based position idx + 1 is the first whose prefix exceeds
        # prefix_before; as a 0-based index that is exactly idx.
        return idx

    def __len__(self) -> int:
        """Return the number of elements in the tree."""
        return self.size
def test_main() -> None:
    """Smoke-test the core API: point updates, prefix/range sums, bulk build."""
    tree = FenwickTree(5, 0)
    for position, amount in ((0, 7), (2, 13), (4, 19)):
        tree.update(position, amount)
    assert tree.query(4) == 39  # 7 + 13 + 19
    assert tree.range_query(1, 3) == 13
    # Optional functionality (not always needed during competition)
    assert tree.get_value(2) == 13

    built = FenwickTree.from_array([1, 2, 3, 4, 5], 0)
    assert built.query(4) == 15
# Don't write tests below during competition.
def test_basic() -> None:
    """Exercise point updates, prefix sums, range sums and lookups on ints."""
    tree = FenwickTree(5, 0)

    # A fresh tree sums to zero everywhere.
    assert tree.query(0) == 0
    assert tree.query(4) == 0
    assert tree.range_query(1, 3) == 0

    # Build up the array [5, 0, 3, 0, 7].
    for position, amount in ((0, 5), (2, 3), (4, 7)):
        tree.update(position, amount)

    prefix_expectations = ((0, 5), (2, 8), (4, 15))
    for position, expected in prefix_expectations:
        assert tree.query(position) == expected

    range_expectations = (((0, 2), 8), ((2, 4), 10), ((1, 3), 3))
    for (lo, hi), expected in range_expectations:
        assert tree.range_query(lo, hi) == expected

    value_expectations = ((0, 5), (2, 3), (4, 7))
    for position, expected in value_expectations:
        assert tree.get_value(position) == expected
def test_from_array() -> None:
    """Bulk construction matches the source array, before and after an update."""
    values = [1, 3, 5, 7, 9, 11]
    tree = FenwickTree.from_array(values, 0)

    running = 0
    for position, value in enumerate(values):
        running += value
        assert tree.query(position) == running

    assert tree.range_query(1, 3) == 3 + 5 + 7  # 15
    assert tree.range_query(2, 4) == 5 + 7 + 9  # 21

    tree.update(2, 10)  # element 2 is now 15 inside the tree
    assert tree.get_value(2) == 15
    assert tree.range_query(1, 3) == 3 + 15 + 7  # 25
def test_edge_cases() -> None:
    """A one-element tree, and an inverted range treated as empty."""
    single = FenwickTree(1, 0)
    single.update(0, 42)
    assert single.query(0) == 42
    assert single.range_query(0, 0) == 42
    assert single.get_value(0) == 42

    # left > right yields the zero value rather than an error.
    wide = FenwickTree(10, 0)
    assert wide.range_query(5, 3) == 0
def test_bounds_checking() -> None:
    """Out-of-bounds indices raise IndexError; range_query returns zero instead.

    Failures are signalled with ``raise AssertionError`` in a
    ``try/except/else``-style loop rather than ``assert False`` inside the
    ``try`` body, so the "did not raise" check survives ``python -O``.
    """
    ft = FenwickTree(5, 0)

    # (callable, positional args, description used in the failure message)
    must_raise = (
        (ft.update, (-1, 10), "negative index"),
        (ft.update, (5, 10), "index >= size"),
        (ft.query, (-1,), "negative index"),
        (ft.query, (5,), "index >= size"),
        (ft.get_value, (-1,), "negative index"),
        (ft.get_value, (5,), "index >= size"),
    )
    for method, args, reason in must_raise:
        try:
            method(*args)
        except IndexError:
            continue
        raise AssertionError(f"Should raise IndexError for {reason}")

    # range_query deliberately returns zero for invalid ranges.
    assert ft.range_query(-1, 2) == 0
    assert ft.range_query(0, 5) == 0
def test_first_nonzero_bounds() -> None:
    """first_nonzero_index clamps negative starts and rejects starts past the end."""
    tree = FenwickTree(10, 0)
    tree.update(5, 1)

    # A negative start is clamped to 0; starting on the element finds it.
    for start in (-5, 5):
        assert tree.first_nonzero_index(start) == 5

    # Starts at or beyond the size find nothing.
    for start in (10, 100):
        assert tree.first_nonzero_index(start) is None

    # An all-zero tree has no non-zero element at all.
    blank = FenwickTree(10, 0)
    assert blank.first_nonzero_index(0) is None
def test_negative_values() -> None:
    """Mixed-sign elements and negative deltas are summed correctly."""
    tree = FenwickTree(4, 0)
    for position, amount in enumerate((10, -5, 8, -3)):
        tree.update(position, amount)

    assert tree.query(3) == 10  # 10 - 5 + 8 - 3
    assert tree.range_query(1, 2) == 3  # -5 + 8

    # A negative delta subtracts from the stored element.
    tree.update(0, -5)
    assert tree.get_value(0) == 5
    assert tree.query(3) == 5  # 5 - 5 + 8 - 3
def test_linear_from_array() -> None:
    """Check from_array against the source arrays and report its build time.

    The timing line is informational only; nothing is compared against it.
    """
    import time
    from itertools import accumulate

    test_cases = [
        [1, 3, 5, 7, 9, 11],
        [10, -5, 8, -3, 15, 2, -7, 12],
        list(range(100)),
    ]
    for arr in test_cases:
        ft = FenwickTree.from_array(arr, 0)
        # Every prefix sum and every individual element must match the input.
        for i, (value, expected_sum) in enumerate(zip(arr, accumulate(arr))):
            assert ft.query(i) == expected_sum, f"Mismatch at index {i}"
            assert ft.get_value(i) == value, f"Value mismatch at index {i}"
        if len(arr) >= 3:
            assert ft.range_query(1, 2) == sum(arr[1:3])

    # Larger input: spot-check correctness and print the build time.
    large_arr = list(range(1000))
    start = time.perf_counter()
    ft_large = FenwickTree.from_array(large_arr, 0)
    elapsed = time.perf_counter() - start
    large_prefix = list(accumulate(large_arr))
    for i in (0, 100, 500, 999):
        assert ft_large.query(i) == large_prefix[i]
    print(f"Linear from_array time for 1000 elements: {elapsed:.6f}s")
def test_first_nonzero_index() -> None:
    """Find the next non-zero element from several starting positions."""
    tree = FenwickTree(10, 0)
    for position in (2, 8):
        tree.update(position, 1)

    expectations = ((5, 8), (8, 8), (0, 2), (9, None))
    for start, expected in expectations:
        found = tree.first_nonzero_index(start)
        if expected is None:
            assert found is None
        else:
            assert found == expected
def main() -> None:
    """Run every Fenwick tree test in this module, then report success."""
    suite = (
        test_basic,
        test_from_array,
        test_edge_cases,
        test_bounds_checking,
        test_first_nonzero_bounds,
        test_negative_values,
        test_linear_from_array,
        test_first_nonzero_index,
        test_main,
    )
    for run_test in suite:
        run_test()
    print("All Fenwick tree tests passed!")


if __name__ == "__main__":
    main()