From 57c38b4815d8b7d81f3f4ba7fb89b1f968e78718 Mon Sep 17 00:00:00 2001 From: "Gavin M. Roy" Date: Mon, 15 Jun 2026 14:04:28 -0400 Subject: [PATCH] Fall back to SQL formatting for non-PL/pgSQL bodies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit format_plpgsql errored on LANGUAGE sql function bodies (a bare SQL statement such as WITH … SELECT). When the PL/pgSQL parse has errors, try SQL formatting and use it if it parses cleanly; otherwise report the original PL/pgSQL syntax error. This lets callers that route all function bodies through format_plpgsql handle SQL-language functions. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/lib.rs | 7 +++++++ tests/plpgsql_test.rs | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 9b2f468..751ab8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -151,6 +151,13 @@ pub fn format_plpgsql(code: &str, style: Style) -> Result { .ok_or_else(|| FormatError::Parser("Failed to parse PL/pgSQL".into()))?; let root = tree.root_node(); if root.has_error() { + // The body may not be PL/pgSQL at all — e.g. a LANGUAGE sql function + // body, which is a bare SQL statement (WITH … SELECT, etc.). Fall back + // to SQL formatting; if that parses cleanly, use it. Otherwise report + // the original PL/pgSQL syntax error. + if let Ok(sql) = format(trimmed, style) { + return Ok(sql); + } return Err(FormatError::Syntax(find_error_message(&root, trimmed))); } let fmt = Formatter::new(trimmed, style); diff --git a/tests/plpgsql_test.rs b/tests/plpgsql_test.rs index bd22e4a..6aaed63 100644 --- a/tests/plpgsql_test.rs +++ b/tests/plpgsql_test.rs @@ -72,3 +72,13 @@ fn for_over_query_keeps_query() { ); assert!(result.contains("RETURN NEXT r;"), "\nGot:\n{result}"); } + +// Regression: format_plpgsql falls back to SQL formatting when the body is not +// PL/pgSQL (e.g. a LANGUAGE sql function body), rather than erroring. +#[test] +fn sql_body_fallback() { + let body = "WITH t AS (SELECT 1 AS n) SELECT n FROM t"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + assert!(result.contains("SELECT"), "\nGot:\n{result}"); + assert!(result.trim_end().ends_with(';'), "\nGot:\n{result}"); +}