diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py index ac97a3f..1085e1a 100644 --- a/pymongosql/__init__.py +++ b/pymongosql/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from .connection import Connection -__version__: str = "0.4.2" +__version__: str = "0.4.3" # Globals https://www.python.org/dev/peps/pep-0249/#globals apilevel: str = "2.0" diff --git a/pymongosql/sql/query_handler.py b/pymongosql/sql/query_handler.py index fe95e21..52ef408 100644 --- a/pymongosql/sql/query_handler.py +++ b/pymongosql/sql/query_handler.py @@ -168,6 +168,18 @@ def can_handle(self, ctx: Any) -> bool: """Check if this is a from context""" return hasattr(ctx, "tableReference") + @staticmethod + def _strip_collection_quotes(name: str) -> str: + """Strip surrounding double quotes from collection name if present. + + Args: + name: Collection name, potentially quoted + + Returns: + Collection name with quotes removed + """ + return re.sub(r'^"([^"]+)"$', r"\1", name) + def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]: """ Detect and parse aggregate() function calls in FROM clause. @@ -196,13 +208,17 @@ def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]: # Pattern: [qualifier.]functionName(arg1, arg2) # We need to match: (optional_collection.)aggregate('...', '...') - pattern = r"^(?:(\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$" + # Support collection names with double quotes for special characters like hyphens + pattern = r"^(?:(\"[^\"]+\"|\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$" match = re.match(pattern, text, re.IGNORECASE | re.DOTALL) if not match: return None collection = match.group(1) # Can be None for unqualified aggregate() + # Strip quotes from collection name if present + if collection: + collection = self._strip_collection_quotes(collection) pipeline = match.group(2) options = match.group(3) @@ -245,7 +261,7 @@ def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "Qu # Regular collection reference table_text = ctx.tableReference().getText() # Strip surrounding quotes from collection name (e.g., "user.accounts" -> user.accounts) - collection_name = re.sub(r'^"([^"]+)"$', r"\1", table_text) + collection_name = self._strip_collection_quotes(table_text) parse_result.collection = collection_name _logger.debug(f"Parsed regular collection: {collection_name}") return collection_name diff --git a/tests/test_cursor_aggregate.py b/tests/test_cursor_aggregate.py index b14b555..a74e105 100644 --- a/tests/test_cursor_aggregate.py +++ b/tests/test_cursor_aggregate.py @@ -327,3 +327,30 @@ def test_aggregate_multiple_stages(self, conn): total_users_idx = col_names.index("total_users") assert row[avg_age_idx] is not None and isinstance(row[avg_age_idx], (int, float)) assert row[total_users_idx] is not None and isinstance(row[total_users_idx], (int, float)) + + def test_aggregate_collection_name_with_hyphen(self, conn): + """Test aggregate function with collection name containing hyphen (user-orders)""" + pipeline = json.dumps([{"$match": {"customer_type": "premium"}}]) + + # Test collection name with hyphen + sql = f""" + SELECT * + FROM "user-orders".aggregate('{pipeline}', '{{}}') + """ + + cursor = conn.cursor() + result = cursor.execute(sql) + + assert result == cursor + assert isinstance(cursor.result_set, ResultSet) + + rows = cursor.result_set.fetchall() + assert len(rows) > 0, "Should have results from user-orders collection" + + # Verify all returned rows are premium customers + col_names = [desc[0] for desc in cursor.result_set.description] + assert "customer_type" in col_names, "customer_type should be in result columns" + + customer_type_idx = col_names.index("customer_type") + for row in rows: + assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"