Skip to content

Commit 7600b10

Browse files
committed
fix: dbt handle indirect cycles
1 parent cc2e1f1 commit 7600b10

File tree

4 files changed

+315
-27
lines changed

4 files changed

+315
-27
lines changed

sqlmesh/dbt/basemodel.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sqlmesh.dbt.test import TestConfig
3131
from sqlmesh.dbt.util import DBT_VERSION
3232
from sqlmesh.utils import AttributeDict
33+
from sqlmesh.utils.dag import find_path_with_dfs
3334
from sqlmesh.utils.errors import ConfigError
3435
from sqlmesh.utils.pydantic import field_validator
3536

@@ -270,9 +271,10 @@ def remove_tests_with_invalid_refs(self, context: DbtContext) -> None:
270271

271272
def fix_circular_test_refs(self, context: DbtContext) -> None:
272273
"""
273-
Checks for direct circular references between two models and moves the test to the downstream
274-
model if found. This addresses the most common circular reference - relationship tests in both
275-
directions. In the future, we may want to increase coverage by checking for indirect circular references.
274+
Checks for circular references between models and moves tests to break cycles.
275+
This handles both direct circular references (A -> B -> A) and indirect circular
276+
references (A -> B -> C -> A). Tests are moved to the model that appears latest
277+
in the dependency chain to ensure the cycle is broken.
276278
277279
Args:
278280
context: The dbt context this model resides within.
@@ -284,16 +286,91 @@ def fix_circular_test_refs(self, context: DbtContext) -> None:
284286
for ref in test.dependencies.refs:
285287
if ref == self.name or ref in self.dependencies.refs:
286288
continue
287-
model = context.refs[ref]
288-
if (
289-
self.name in model.dependencies.refs
290-
or self.name in model.tests_ref_source_dependencies.refs
291-
):
289+
290+
# Check if moving this test would create or maintain a cycle
291+
cycle_path = self._find_circular_path(ref, context, set())
292+
if cycle_path:
293+
# Find the model in the cycle that should receive the test
294+
# We want to move to the model that appears latest in the dependency chain
295+
target_model_name = self._select_target_model_for_test(cycle_path, context)
296+
target_model = context.refs[target_model_name]
297+
292298
logger.info(
293-
f"Moving test '{test.name}' from model '{self.name}' to '{model.name}' to avoid circular reference."
299+
f"Moving test '{test.name}' from model '{self.name}' to '{target_model_name}' "
300+
f"to avoid circular reference through path: {' -> '.join(cycle_path)}"
294301
)
295-
model.tests.append(test)
302+
target_model.tests.append(test)
296303
self.tests.remove(test)
304+
break
305+
306+
def _find_circular_path(
307+
self, ref: str, context: DbtContext, visited: t.Set[str]
308+
) -> t.Optional[t.List[str]]:
309+
"""
310+
Find if there's a circular dependency path from ref back to this model.
311+
312+
Args:
313+
ref: The model name to start searching from
314+
context: The dbt context
315+
visited: Set of model names already visited in this path
316+
317+
Returns:
318+
List of model names forming the circular path, or None if no cycle exists
319+
"""
320+
# Build a graph of all models and their dependencies from the context
321+
graph: t.Dict[str, t.Set[str]] = {}
322+
323+
def build_graph_from_node(node_name: str, current_visited: t.Set[str]) -> None:
324+
if node_name in current_visited or node_name in graph:
325+
return
326+
current_visited.add(node_name)
327+
328+
model = context.refs[node_name]
329+
# Include both direct model dependencies and test dependencies
330+
all_refs = model.dependencies.refs | model.tests_ref_source_dependencies.refs
331+
graph[node_name] = all_refs.copy()
332+
333+
# Recursively build graph for dependencies
334+
for dep in all_refs:
335+
build_graph_from_node(dep, current_visited)
336+
337+
# Build the graph starting from the ref, including visited nodes to avoid infinite recursion
338+
build_graph_from_node(ref, visited.copy())
339+
340+
# Add self.name to the graph if it's not already there
341+
if self.name not in graph:
342+
graph[self.name] = set()
343+
344+
# Use the shared DFS function to find path from ref to self.name
345+
return find_path_with_dfs(graph, start_node=ref, target_node=self.name)
346+
347+
def _select_target_model_for_test(self, cycle_path: t.List[str], context: DbtContext) -> str:
348+
"""
349+
Select which model in the cycle should receive the test.
350+
We select the model that has the most downstream dependencies in the cycle
351+
352+
Args:
353+
cycle_path: List of model names in the circular dependency path
354+
context: The dbt context
355+
356+
Returns:
357+
Name of the model that should receive the test
358+
"""
359+
# Count how many other models in the cycle each model depends on
360+
dependency_counts = {}
361+
362+
for model_name in cycle_path:
363+
model = context.refs[model_name]
364+
all_refs = model.dependencies.refs | model.tests_ref_source_dependencies.refs
365+
count = len([ref for ref in all_refs if ref in cycle_path])
366+
dependency_counts[model_name] = count
367+
368+
# Return the model with the fewest dependencies within the cycle
369+
# (i.e., the most downstream model in the cycle)
370+
if dependency_counts:
371+
return min(dependency_counts, key=dependency_counts.get) # type: ignore
372+
# Fallback to the last model in the path
373+
return cycle_path[-1]
297374

298375
@property
299376
def sqlmesh_config_fields(self) -> t.Set[str]:

sqlmesh/utils/dag.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,82 @@
1515
T = t.TypeVar("T", bound=t.Hashable)
1616

1717

18+
def find_path_with_dfs(
19+
graph: t.Dict[T, t.Set[T]],
20+
start_node: t.Optional[T] = None,
21+
target_node: t.Optional[T] = None,
22+
) -> t.Optional[t.List[T]]:
23+
"""
24+
Find a path in a graph using depth-first search.
25+
26+
This function can be used for two main purposes:
27+
1. Find any cycle in a cyclic subgraph (when target_node is None)
28+
2. Find a specific path from start_node to target_node
29+
30+
Args:
31+
graph: Dictionary mapping nodes to their dependencies/neighbors
32+
start_node: Optional specific node to start the search from
33+
target_node: Optional target node to search for. If None, finds any cycle
34+
35+
Returns:
36+
List of nodes forming the path, or None if no path/cycle found
37+
"""
38+
if not graph:
39+
return None
40+
41+
visited: t.Set[T] = set()
42+
rec_stack: t.Set[T] = set()
43+
path: t.List[T] = []
44+
45+
def dfs(node: T) -> t.Optional[t.List[T]]:
46+
if target_node is None:
47+
# Cycle detection mode: look for any node in recursion stack
48+
if node in rec_stack:
49+
cycle_start = path.index(node)
50+
return path[cycle_start:] + [node]
51+
else:
52+
# Target search mode: look for specific target
53+
if node == target_node:
54+
return [node]
55+
56+
if node in visited:
57+
return None
58+
59+
visited.add(node)
60+
rec_stack.add(node)
61+
path.append(node)
62+
63+
# Follow edges to neighbors
64+
for neighbor in graph.get(node, set()):
65+
if neighbor in graph: # Only follow edges to nodes in our subgraph
66+
result = dfs(neighbor)
67+
if result:
68+
if target_node is None:
69+
# Cycle detection: return the cycle as-is
70+
return result
71+
# Target search: prepend current node to path
72+
return [node] + result
73+
74+
rec_stack.remove(node)
75+
path.pop()
76+
return None
77+
78+
# Determine which nodes to try as starting points
79+
start_nodes = [start_node] if start_node is not None else list(graph.keys())
80+
81+
for node in start_nodes:
82+
if node not in visited and node in graph:
83+
result = dfs(node)
84+
if result:
85+
if target_node is None:
86+
# Cycle detection: remove duplicate node at end
87+
return result[:-1] if len(result) > 1 and result[0] == result[-1] else result
88+
# Target search: return path as-is
89+
return result
90+
91+
return None
92+
93+
1894
class DAG(t.Generic[T]):
1995
def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None):
2096
self._dag: t.Dict[T, t.Set[T]] = {}
@@ -99,6 +175,17 @@ def upstream(self, node: T) -> t.Set[T]:
99175

100176
return self._upstream[node]
101177

178+
def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]:
179+
"""Find the exact cycle path using DFS when a cycle is detected.
180+
181+
Args:
182+
nodes_in_cycle: Dictionary of nodes that are part of the cycle and their dependencies
183+
184+
Returns:
185+
List of nodes forming the cycle path, or None if no cycle found
186+
"""
187+
return find_path_with_dfs(nodes_in_cycle)
188+
102189
@property
103190
def roots(self) -> t.Set[T]:
104191
"""Returns all nodes in the graph without any upstream dependencies."""
@@ -125,23 +212,28 @@ def sorted(self) -> t.List[T]:
125212
next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps}
126213

127214
if not next_nodes:
128-
# Sort cycle candidates to make the order deterministic
129-
cycle_candidates_msg = (
130-
"\nPossible candidates to check for circular references: "
131-
+ ", ".join(str(node) for node in sorted(cycle_candidates))
132-
)
215+
# A cycle was detected - find the exact cycle path
216+
cycle_path = self._find_cycle_path(unprocessed_nodes)
133217

134-
if last_processed_nodes:
135-
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
136-
str(node) for node in last_processed_nodes
137-
)
218+
last_processed_msg = ""
219+
if cycle_path:
220+
cycle_msg = f"\nCycle: {' -> '.join(str(node) for node in cycle_path)} -> {cycle_path[0]}"
138221
else:
139-
last_processed_msg = ""
222+
# Fallback message in case a cycle can't be found
223+
cycle_candidates_msg = (
224+
"\nPossible candidates to check for circular references: "
225+
+ ", ".join(str(node) for node in sorted(cycle_candidates))
226+
)
227+
cycle_msg = cycle_candidates_msg
228+
if last_processed_nodes:
229+
last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join(
230+
str(node) for node in last_processed_nodes
231+
)
140232

141233
raise SQLMeshError(
142234
"Detected a cycle in the DAG. "
143235
"Please make sure there are no circular references between nodes."
144-
f"{last_processed_msg}{cycle_candidates_msg}"
236+
f"{last_processed_msg}{cycle_msg}"
145237
)
146238

147239
for node in next_nodes:

tests/dbt/test_model.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,126 @@ def test_model_test_circular_references() -> None:
116116
assert downstream_model.tests == [downstream_test, upstream_test]
117117

118118

119+
def test_model_test_indirect_circular_references() -> None:
120+
"""Test detection and resolution of indirect circular references through test dependencies"""
121+
model_a = ModelConfig(name="model_a") # No dependencies
122+
model_b = ModelConfig(
123+
name="model_b", dependencies=Dependencies(refs={"model_a"})
124+
) # B depends on A
125+
model_c = ModelConfig(
126+
name="model_c", dependencies=Dependencies(refs={"model_b"})
127+
) # C depends on B
128+
129+
context = DbtContext(_refs={"model_a": model_a, "model_b": model_b, "model_c": model_c})
130+
131+
# Test on model_a that references model_c (creates indirect cycle through test dependencies)
132+
# The cycle would be: model_a (via test) -> model_c -> model_b -> model_a
133+
test_a_refs_c = TestConfig(
134+
name="test_a_refs_c",
135+
sql="",
136+
dependencies=Dependencies(refs={"model_a", "model_c"}), # Test references both A and C
137+
)
138+
139+
# Place tests that would create indirect cycles when combined with model dependencies
140+
model_a.tests = [test_a_refs_c]
141+
assert model_b.tests == []
142+
assert model_c.tests == []
143+
144+
# Fix circular references on model_a
145+
model_a.fix_circular_test_refs(context)
146+
# The test should be moved from model_a to break the indirect cycle down to model c
147+
assert model_a.tests == []
148+
assert test_a_refs_c in model_c.tests
149+
150+
151+
def test_model_test_complex_indirect_circular_references() -> None:
152+
"""Test detection and resolution of more complex indirect circular references through test dependencies"""
153+
# Create models with a longer linear dependency chain (no cycles in models themselves)
154+
# A -> B -> C -> D (B depends on A, C depends on B, D depends on C)
155+
model_a = ModelConfig(name="model_a") # No dependencies
156+
model_b = ModelConfig(
157+
name="model_b", dependencies=Dependencies(refs={"model_a"})
158+
) # B depends on A
159+
model_c = ModelConfig(
160+
name="model_c", dependencies=Dependencies(refs={"model_b"})
161+
) # C depends on B
162+
model_d = ModelConfig(
163+
name="model_d", dependencies=Dependencies(refs={"model_c"})
164+
) # D depends on C
165+
166+
context = DbtContext(
167+
_refs={"model_a": model_a, "model_b": model_b, "model_c": model_c, "model_d": model_d}
168+
)
169+
170+
# Test on model_a that references model_d (creates long indirect cycle through test dependencies)
171+
# The cycle would be: model_a (via test) -> model_d -> model_c -> model_b -> model_a
172+
test_a_refs_d = TestConfig(
173+
name="test_a_refs_d",
174+
sql="",
175+
dependencies=Dependencies(refs={"model_a", "model_d"}), # Test references both A and D
176+
)
177+
178+
# Place tests that would create indirect cycles when combined with model dependencies
179+
model_a.tests = [test_a_refs_d]
180+
model_b.tests = []
181+
assert model_c.tests == []
182+
assert model_d.tests == []
183+
184+
# Fix circular references on model_a
185+
model_a.fix_circular_test_refs(context)
186+
# The test should be moved from model_a to break the long indirect cycle down to model_d
187+
assert model_a.tests == []
188+
assert model_d.tests == [test_a_refs_d]
189+
190+
# Test on model_b that references model_d (creates indirect cycle through test dependencies)
191+
# The cycle would be: model_b (via test) -> model_d -> model_c -> model_b
192+
test_b_refs_d = TestConfig(
193+
name="test_b_refs_d",
194+
sql="",
195+
dependencies=Dependencies(refs={"model_b", "model_d"}), # Test references both B and D
196+
)
197+
model_a.tests = []
198+
model_b.tests = [test_b_refs_d]
199+
model_c.tests = []
200+
model_d.tests = []
201+
202+
model_b.fix_circular_test_refs(context)
203+
assert model_a.tests == []
204+
assert model_b.tests == []
205+
assert model_c.tests == []
206+
assert model_d.tests == [test_b_refs_d]
207+
208+
# Do both at the same time
209+
model_a.tests = [test_a_refs_d]
210+
model_b.tests = [test_b_refs_d]
211+
model_c.tests = []
212+
model_d.tests = []
213+
214+
model_a.fix_circular_test_refs(context)
215+
model_b.fix_circular_test_refs(context)
216+
assert model_a.tests == []
217+
assert model_b.tests == []
218+
assert model_c.tests == []
219+
assert model_d.tests == [test_a_refs_d, test_b_refs_d]
220+
221+
# Test A -> B -> C cycle and make sure test ends up with C
222+
test_a_refs_c = TestConfig(
223+
name="test_a_refs_c",
224+
sql="",
225+
dependencies=Dependencies(refs={"model_a", "model_c"}), # Test references both A and C
226+
)
227+
model_a.tests = [test_a_refs_c]
228+
model_b.tests = []
229+
model_c.tests = []
230+
model_d.tests = []
231+
232+
model_a.fix_circular_test_refs(context)
233+
assert model_a.tests == []
234+
assert model_b.tests == []
235+
assert model_c.tests == [test_a_refs_c]
236+
assert model_d.tests == []
237+
238+
119239
@pytest.mark.slow
120240
def test_load_invalid_ref_audit_constraints(
121241
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project

0 commit comments

Comments
 (0)