From 3936bdb4f5e33d24f3f404b416a7bbc5dd21016a Mon Sep 17 00:00:00 2001 From: Sebastien Tardif Date: Thu, 2 Jul 2026 12:29:14 -0700 Subject: [PATCH] fix: support mixed ASC/DESC directions in ORDER BY Replace single reverse flag with per-column comparator using functools.cmp_to_key. Each ORDER BY column now respects its own direction independently. Closes #3 --- minidb/query.py | 40 +++++++++++++++++++++++++++------------- tests/test_queries.py | 17 +++++++++++++++++ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/minidb/query.py b/minidb/query.py index abf8b76..27e7d5d 100644 --- a/minidb/query.py +++ b/minidb/query.py @@ -1,5 +1,6 @@ """Query execution engine for MiniDB.""" +import functools import re from collections import defaultdict from typing import Any @@ -364,29 +365,42 @@ def _project_columns( return results def _execute_order_by(self, rows: list[Row], order_by: list[OrderByItem]) -> list[Row]: - """Sort rows by ORDER BY columns.""" + """Sort rows by ORDER BY columns with per-column direction.""" - def sort_key(row): - keys = [] + def _compare_rows(a: Row, b: Row) -> int: for item in order_by: col_name = item.column if item.table_alias: col_name = f'{item.table_alias}.{item.column}' - val = row.get(col_name, row.get(item.column)) + val_a = a.get(col_name, a.get(item.column)) + val_b = b.get(col_name, b.get(item.column)) - # Handle None values - if val is None: - val = (1, None) # Sort NULLs last + # NULLs sort last regardless of direction + if val_a is None and val_b is None: + continue + if val_a is None: + return 1 + if val_b is None: + return -1 + + # Compare values + if val_a < val_b: + cmp = -1 + elif val_a > val_b: + cmp = 1 else: - val = (0, val) + continue + + # Reverse for DESC + if item.direction == 'DESC': + cmp = -cmp + + return cmp - keys.append(val) - return keys + return 0 - # Sort with direction handling - reverse = bool(order_by and order_by[0].direction == 'DESC') - return sorted(rows, key=sort_key, reverse=reverse) + return sorted(rows, key=functools.cmp_to_key(_compare_rows)) def execute_insert(self, query: InsertQuery) -> int: """Execute an INSERT query.""" diff --git a/tests/test_queries.py b/tests/test_queries.py index 0b726d2..bf0f4dc 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -222,6 +222,23 @@ def test_order_by_desc(self, db): salaries = [r['salary'] for r in results] assert salaries == [55000.0, 52000.0, 50000.0, 48000.0, 45000.0] + def test_order_by_mixed_directions(self, db): + """Test ORDER BY with mixed ASC/DESC on different columns.""" + # active has true/false, age varies within each group + results = db.query('SELECT * FROM users ORDER BY active DESC, age ASC') + + # active=true first (DESC), then within each group sorted by age ASC + active_flags = [r['active'] for r in results] + assert active_flags == [True, True, True, False, False] + + # Ages within active=true group should be ascending + active_ages = [r['age'] for r in results if r['active'] is True] + assert active_ages == sorted(active_ages) + + # Ages within active=false group should be ascending + inactive_ages = [r['age'] for r in results if r['active'] is False] + assert inactive_ages == sorted(inactive_ages) + def test_order_by_with_where(self, db): """Test ORDER BY combined with WHERE.""" results = db.query('SELECT * FROM users WHERE active = true ORDER BY age DESC')