Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions minidb/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Query execution engine for MiniDB."""

import functools
import re
from collections import defaultdict
from typing import Any
Expand Down Expand Up @@ -364,29 +365,42 @@ def _project_columns(
return results

def _execute_order_by(self, rows: list[Row], order_by: list[OrderByItem]) -> list[Row]:
"""Sort rows by ORDER BY columns."""
"""Sort rows by ORDER BY columns with per-column direction."""

def sort_key(row):
keys = []
def _compare_rows(a: Row, b: Row) -> int:
for item in order_by:
col_name = item.column
if item.table_alias:
col_name = f'{item.table_alias}.{item.column}'

val = row.get(col_name, row.get(item.column))
val_a = a.get(col_name, a.get(item.column))
val_b = b.get(col_name, b.get(item.column))

# Handle None values
if val is None:
val = (1, None) # Sort NULLs last
# NULLs sort last regardless of direction
if val_a is None and val_b is None:
continue
if val_a is None:
return 1
if val_b is None:
return -1

# Compare values
if val_a < val_b:
cmp = -1
elif val_a > val_b:
cmp = 1
else:
val = (0, val)
continue

# Reverse for DESC
if item.direction == 'DESC':
cmp = -cmp

return cmp

keys.append(val)
return keys
return 0

# Sort with direction handling
reverse = bool(order_by and order_by[0].direction == 'DESC')
return sorted(rows, key=sort_key, reverse=reverse)
return sorted(rows, key=functools.cmp_to_key(_compare_rows))

def execute_insert(self, query: InsertQuery) -> int:
"""Execute an INSERT query."""
Expand Down
17 changes: 17 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ def test_order_by_desc(self, db):
salaries = [r['salary'] for r in results]
assert salaries == [55000.0, 52000.0, 50000.0, 48000.0, 45000.0]

def test_order_by_mixed_directions(self, db):
"""Test ORDER BY with mixed ASC/DESC on different columns."""
# active has true/false, age varies within each group
results = db.query('SELECT * FROM users ORDER BY active DESC, age ASC')

# active=true first (DESC), then within each group sorted by age ASC
active_flags = [r['active'] for r in results]
assert active_flags == [True, True, True, False, False]

# Ages within active=true group should be ascending
active_ages = [r['age'] for r in results if r['active'] is True]
assert active_ages == sorted(active_ages)

# Ages within active=false group should be ascending
inactive_ages = [r['age'] for r in results if r['active'] is False]
assert inactive_ages == sorted(inactive_ages)

def test_order_by_with_where(self, db):
"""Test ORDER BY combined with WHERE."""
results = db.query('SELECT * FROM users WHERE active = true ORDER BY age DESC')
Expand Down