Skip to content

Commit 5d9f482

Browse files
committed
use range instead of single boundary
1 parent 7a775c8 commit 5d9f482

File tree

6 files changed

+396
-195
lines changed

6 files changed

+396
-195
lines changed

sqlmesh/core/state_sync/base.py

Lines changed: 11 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from sqlmesh import migrations
1212
from sqlmesh.core.environment import (
1313
Environment,
14-
EnvironmentNamingInfo,
1514
EnvironmentStatements,
1615
EnvironmentSummary,
1716
)
@@ -21,17 +20,20 @@
2120
SnapshotIdLike,
2221
SnapshotIdAndVersionLike,
2322
SnapshotInfoLike,
24-
SnapshotTableCleanupTask,
25-
SnapshotTableInfo,
2623
SnapshotNameVersion,
2724
SnapshotIdAndVersion,
2825
)
2926
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
3027
from sqlmesh.utils import major_minor
3128
from sqlmesh.utils.date import TimeLike
3229
from sqlmesh.utils.errors import SQLMeshError
33-
from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator
34-
from sqlmesh.core.state_sync.common import StateStream
30+
from sqlmesh.utils.pydantic import PydanticModel, field_validator
31+
from sqlmesh.core.state_sync.common import (
32+
StateStream,
33+
ExpiredSnapshotBatch,
34+
PromotionResult,
35+
ExpiredBatchRange,
36+
)
3537

3638
logger = logging.getLogger(__name__)
3739

@@ -72,64 +74,6 @@ def _schema_version_validator(cls, v: t.Any) -> int:
7274
SCHEMA_VERSION: int = MIN_SCHEMA_VERSION + len(MIGRATIONS) - 1
7375

7476

75-
class BatchBoundary(PydanticModel):
76-
updated_ts: int
77-
name: str
78-
identifier: str
79-
80-
def to_upper_batch_boundary(self) -> UpperBatchBoundary:
81-
return UpperBatchBoundary(
82-
updated_ts=self.updated_ts,
83-
name=self.name,
84-
identifier=self.identifier,
85-
)
86-
87-
def to_lower_batch_boundary(self, batch_size: int) -> LowerBatchBoundary:
88-
return LowerBatchBoundary(
89-
updated_ts=self.updated_ts,
90-
name=self.name,
91-
identifier=self.identifier,
92-
batch_size=batch_size,
93-
)
94-
95-
96-
class UpperBatchBoundary(BatchBoundary):
97-
@classmethod
98-
def include_all_boundary(cls) -> UpperBatchBoundary:
99-
# 9999-12-31T23:59:59.999Z in epoch milliseconds
100-
return UpperBatchBoundary(updated_ts=253_402_300_799_999, name="", identifier="")
101-
102-
103-
class LowerBatchBoundary(BatchBoundary):
104-
batch_size: int
105-
106-
@classmethod
107-
def init_batch_boundary(cls, batch_size: int) -> LowerBatchBoundary:
108-
return LowerBatchBoundary(updated_ts=0, name="", identifier="", batch_size=batch_size)
109-
110-
111-
class ExpiredSnapshotBatch(PydanticModel):
112-
"""A batch of expired snapshots to be cleaned up."""
113-
114-
expired_snapshot_ids: t.Set[SnapshotId]
115-
cleanup_tasks: t.List[SnapshotTableCleanupTask]
116-
batch_boundary: BatchBoundary
117-
118-
119-
class PromotionResult(PydanticModel):
120-
added: t.List[SnapshotTableInfo]
121-
removed: t.List[SnapshotTableInfo]
122-
removed_environment_naming_info: t.Optional[EnvironmentNamingInfo]
123-
124-
@field_validator("removed_environment_naming_info")
125-
def _validate_removed_environment_naming_info(
126-
cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo
127-
) -> t.Optional[EnvironmentNamingInfo]:
128-
if v and not info.data.get("removed"):
129-
raise ValueError("removed_environment_naming_info must be None if removed is empty")
130-
return v
131-
132-
13377
class StateReader(abc.ABC):
13478
"""Abstract base class for read-only operations on snapshot and environment state."""
13579

@@ -361,7 +305,7 @@ def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStre
361305
def get_expired_snapshots(
362306
self,
363307
*,
364-
batch_boundary: BatchBoundary,
308+
batch_range: ExpiredBatchRange,
365309
current_ts: t.Optional[int] = None,
366310
ignore_ttl: bool = False,
367311
) -> t.Optional[ExpiredSnapshotBatch]:
@@ -370,9 +314,7 @@ def get_expired_snapshots(
370314
Args:
371315
current_ts: Timestamp used to evaluate expiration.
372316
ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
373-
batch_boundary: If provided, gets snapshot relative to the given boundary.
374-
If lower boundary then snapshots later than that will be returned (exclusive).
375-
If upper boundary then snapshots earlier than that will be returned (inclusive).
317+
batch_range: The range of the batch to fetch.
376318
377319
Returns:
378320
A batch describing expired snapshots or None if no snapshots are pending cleanup.
@@ -418,7 +360,7 @@ def delete_expired_snapshots(
418360
self,
419361
ignore_ttl: bool = False,
420362
current_ts: t.Optional[int] = None,
421-
upper_batch_boundary: t.Optional[UpperBatchBoundary] = None,
363+
batch_range: t.Optional[ExpiredBatchRange] = None,
422364
) -> None:
423365
"""Removes expired snapshots.
424366
@@ -429,8 +371,7 @@ def delete_expired_snapshots(
429371
ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting
430372
all snapshots that are not referenced in any environment
431373
current_ts: Timestamp used to evaluate expiration.
432-
upper_batch_boundary: The upper boundary to delete expired snapshots till (inclusive). If not provided,
433-
deletes all expired snapshots.
374+
batch_range: The range of snapshots to delete in this batch. If None, all expired snapshots are deleted.
434375
"""
435376

436377
@abc.abstractmethod

sqlmesh/core/state_sync/cache.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
SnapshotInfoLike,
1212
)
1313
from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals
14-
from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync, UpperBatchBoundary
14+
from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync
15+
from sqlmesh.core.state_sync.common import ExpiredBatchRange
1516
from sqlmesh.utils.date import TimeLike, now_timestamp
1617

1718

@@ -111,11 +112,11 @@ def delete_expired_snapshots(
111112
self,
112113
ignore_ttl: bool = False,
113114
current_ts: t.Optional[int] = None,
114-
upper_batch_boundary: t.Optional[UpperBatchBoundary] = None,
115+
batch_range: t.Optional[ExpiredBatchRange] = None,
115116
) -> None:
116117
self.snapshot_cache.clear()
117118
self.state_sync.delete_expired_snapshots(
118-
upper_batch_boundary=upper_batch_boundary,
119+
batch_range=batch_range,
119120
ignore_ttl=ignore_ttl,
120121
current_ts=current_ts,
121122
)

sqlmesh/core/state_sync/common.py

Lines changed: 135 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,26 @@
77
import abc
88

99
from dataclasses import dataclass
10+
11+
from pydantic_core.core_schema import ValidationInfo
1012
from sqlglot import exp
1113

1214
from sqlmesh.core.console import Console
1315
from sqlmesh.core.dialect import schema_
14-
from sqlmesh.utils.pydantic import PydanticModel
15-
from sqlmesh.core.environment import Environment, EnvironmentStatements
16+
from sqlmesh.utils.pydantic import PydanticModel, field_validator
17+
from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentNamingInfo
1618
from sqlmesh.utils.errors import SQLMeshError
17-
from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator
19+
from sqlmesh.core.snapshot import (
20+
Snapshot,
21+
SnapshotEvaluator,
22+
SnapshotId,
23+
SnapshotTableCleanupTask,
24+
SnapshotTableInfo,
25+
)
1826

1927
if t.TYPE_CHECKING:
2028
from sqlmesh.core.engine_adapter.base import EngineAdapter
21-
from sqlmesh.core.state_sync.base import Versions, ExpiredSnapshotBatch, StateReader, StateSync
29+
from sqlmesh.core.state_sync.base import Versions, StateReader, StateSync
2230

2331
logger = logging.getLogger(__name__)
2432

@@ -219,6 +227,109 @@ def __iter__(self) -> t.Iterator[StateStreamContents]:
219227
return _StateStream()
220228

221229

230+
class ExpiredBatchRange(PydanticModel):
    """Range of expired snapshot rows covered by one batch.

    The range begins at a row position and is closed either by another row
    position or by a batch-size limit.
    """

    # Row position at which the range begins.
    start: RowBoundary
    # Either the row position that closes the range or a batch-size limit.
    end: t.Union[RowBoundary, LimitBoundary]

    @classmethod
    def init_batch_range(cls, batch_size: int) -> ExpiredBatchRange:
        """Build a range that begins at the lowest row boundary and is closed by a size limit.

        Args:
            batch_size: Value stored on the ``LimitBoundary`` that closes the range.

        Returns:
            A range whose start is ``RowBoundary.lowest_boundary()`` and whose end is a
            ``LimitBoundary`` carrying ``batch_size``.
        """
        first_row = RowBoundary.lowest_boundary()
        size_limit = LimitBoundary(batch_size=batch_size)
        return ExpiredBatchRange(start=first_row, end=size_limit)

    @classmethod
    def all_batch_range(cls) -> ExpiredBatchRange:
        """Build a range spanning from the lowest to the highest row boundary.

        Returns:
            A range whose start is ``RowBoundary.lowest_boundary()`` and whose end is
            ``RowBoundary.highest_boundary()``.
        """
        first_row = RowBoundary.lowest_boundary()
        last_row = RowBoundary.highest_boundary()
        return ExpiredBatchRange(start=first_row, end=last_row)
247+
248+
249+
class RowBoundary(PydanticModel):
    """Position of a single row, identified by ``(updated_ts, name, identifier)``."""

    # Timestamp component of the position; the sentinel used by highest_boundary()
    # is expressed in epoch milliseconds.
    updated_ts: int
    name: str
    identifier: str

    @classmethod
    def lowest_boundary(cls) -> RowBoundary:
        """Return the boundary with ``updated_ts=0`` and empty ``name`` / ``identifier``."""
        return RowBoundary(updated_ts=0, name="", identifier="")

    @classmethod
    def highest_boundary(cls) -> RowBoundary:
        """Return the boundary that uses the maximum-timestamp sentinel with empty ``name`` / ``identifier``."""
        # 9999-12-31T23:59:59.999Z in epoch milliseconds
        return RowBoundary(updated_ts=253_402_300_799_999, name="", identifier="")
277+
278+
279+
#
280+
#
281+
# class EndRowBoundary(RowBoundary):
282+
# @classmethod
283+
# def include_all_boundary(cls) -> EndRowBoundary:
284+
# # 9999-12-31T23:59:59.999Z in epoch milliseconds
285+
# return EndRowBoundary(updated_ts=253_402_300_799_999, name="", identifier="")
286+
#
287+
# def to_start_batch_boundary(self) -> StartRowBoundary:
288+
# return StartRowBoundary(
289+
# updated_ts=self.updated_ts,
290+
# name=self.name,
291+
# identifier=self.identifier,
292+
# )
293+
294+
295+
class LimitBoundary(PydanticModel):
    """End-of-range marker expressed as a batch size instead of a row position."""

    # Batch-size limit carried by this boundary.
    batch_size: int

    @classmethod
    def init_batch_boundary(cls, batch_size: int) -> LimitBoundary:
        """Build a limit boundary holding ``batch_size``.

        Args:
            batch_size: Value to store on the boundary.

        Returns:
            A new ``LimitBoundary`` with the given ``batch_size``.
        """
        boundary = LimitBoundary(batch_size=batch_size)
        return boundary
301+
302+
303+
#
304+
# class StartRowBoundary(RowBoundary):
305+
# @classmethod
306+
# def init_batch_boundary(cls) -> StartRowBoundary:
307+
# return StartRowBoundary(updated_ts=0, name="", identifier="")
308+
#
309+
310+
311+
class PromotionResult(PydanticModel):
    """Snapshot table infos reported as added and removed, plus naming info for the removed ones."""

    added: t.List[SnapshotTableInfo]
    removed: t.List[SnapshotTableInfo]
    # Must be None whenever `removed` is empty (enforced by the validator below).
    removed_environment_naming_info: t.Optional[EnvironmentNamingInfo]

    @field_validator("removed_environment_naming_info")
    def _validate_removed_environment_naming_info(
        cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo
    ) -> t.Optional[EnvironmentNamingInfo]:
        """Reject naming info that is supplied without any removed entries.

        Args:
            v: Candidate value for ``removed_environment_naming_info``.
            info: Validation context; ``info.data`` is consulted for the ``removed`` field.

        Returns:
            ``v`` unchanged when it is acceptable.

        Raises:
            ValueError: If ``v`` is truthy while ``removed`` is missing or empty.
        """
        # Nothing to cross-check when no naming info was supplied.
        if not v:
            return v
        # Naming info is allowed as long as something was actually removed.
        if info.data.get("removed"):
            return v
        raise ValueError("removed_environment_naming_info must be None if removed is empty")
323+
324+
325+
class ExpiredSnapshotBatch(PydanticModel):
    """A batch of expired snapshots to be cleaned up."""

    # IDs of the expired snapshots contained in this batch.
    expired_snapshot_ids: t.Set[SnapshotId]
    # Table cleanup tasks associated with this batch.
    cleanup_tasks: t.List[SnapshotTableCleanupTask]
    # Range describing this batch; callers read `batch_range.end` to continue
    # pagination and to bound the matching delete.
    batch_range: ExpiredBatchRange
331+
332+
222333
def iter_expired_snapshot_batches(
223334
state_reader: StateReader,
224335
*,
@@ -234,24 +345,29 @@ def iter_expired_snapshot_batches(
234345
ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced).
235346
batch_size: Maximum number of snapshots to fetch per batch.
236347
"""
237-
from sqlmesh.core.state_sync.base import LowerBatchBoundary
238348

239349
batch_size = batch_size if batch_size is not None else EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE
240-
batch_boundary = LowerBatchBoundary.init_batch_boundary(batch_size=batch_size)
350+
batch_range = ExpiredBatchRange.init_batch_range(batch_size=batch_size)
241351

242352
while True:
243353
batch = state_reader.get_expired_snapshots(
244354
current_ts=current_ts,
245355
ignore_ttl=ignore_ttl,
246-
batch_boundary=batch_boundary,
356+
batch_range=batch_range,
247357
)
248358

249359
if batch is None:
250360
return
251361

252362
yield batch
253363

254-
batch_boundary = batch.batch_boundary.to_lower_batch_boundary(batch_size=batch_size)
364+
assert isinstance(batch.batch_range.end, RowBoundary), (
365+
"Only RowBoundary is supported for pagination currently"
366+
)
367+
batch_range = ExpiredBatchRange(
368+
start=batch.batch_range.end,
369+
end=LimitBoundary(batch_size=batch_size),
370+
)
255371

256372

257373
def delete_expired_snapshots(
@@ -286,17 +402,25 @@ def delete_expired_snapshots(
286402
ignore_ttl=ignore_ttl,
287403
batch_size=batch_size,
288404
):
405+
end_info = (
406+
f"updated_ts={batch.batch_range.end.updated_ts}"
407+
if isinstance(batch.batch_range.end, RowBoundary)
408+
else f"limit={batch.batch_range.end.batch_size}"
409+
)
289410
logger.info(
290-
"Processing batch of size %s and max_updated_ts of %s",
411+
"Processing batch of size %s with end %s",
291412
len(batch.expired_snapshot_ids),
292-
batch.batch_boundary.updated_ts,
413+
end_info,
293414
)
294415
snapshot_evaluator.cleanup(
295416
target_snapshots=batch.cleanup_tasks,
296417
on_complete=console.update_cleanup_progress if console else None,
297418
)
298419
state_sync.delete_expired_snapshots(
299-
upper_batch_boundary=batch.batch_boundary.to_upper_batch_boundary(),
420+
batch_range=ExpiredBatchRange(
421+
start=RowBoundary.lowest_boundary(),
422+
end=batch.batch_range.end,
423+
),
300424
ignore_ttl=ignore_ttl,
301425
)
302426
logger.info("Cleaned up expired snapshots batch")

0 commit comments

Comments
 (0)