diff --git a/src/lib.rs b/src/lib.rs index 9b2f468..751ab8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -151,6 +151,13 @@ pub fn format_plpgsql(code: &str, style: Style) -> Result { .ok_or_else(|| FormatError::Parser("Failed to parse PL/pgSQL".into()))?; let root = tree.root_node(); if root.has_error() { + // The body may not be PL/pgSQL at all — e.g. a LANGUAGE sql function + // body, which is a bare SQL statement (WITH … SELECT, etc.). Fall back + // to SQL formatting; if that parses cleanly, use it. Otherwise report + // the original PL/pgSQL syntax error. + if let Ok(sql) = format(trimmed, style) { + return Ok(sql); + } return Err(FormatError::Syntax(find_error_message(&root, trimmed))); } let fmt = Formatter::new(trimmed, style); diff --git a/tests/plpgsql_test.rs b/tests/plpgsql_test.rs index bd22e4a..6aaed63 100644 --- a/tests/plpgsql_test.rs +++ b/tests/plpgsql_test.rs @@ -72,3 +72,13 @@ fn for_over_query_keeps_query() { ); assert!(result.contains("RETURN NEXT r;"), "\nGot:\n{result}"); } + +// Regression: format_plpgsql falls back to SQL formatting when the body is not +// PL/pgSQL (e.g. a LANGUAGE sql function body), rather than erroring. +#[test] +fn sql_body_fallback() { + let body = "WITH t AS (SELECT 1 AS n) SELECT n FROM t"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + assert!(result.contains("SELECT"), "\nGot:\n{result}"); + assert!(result.trim_end().ends_with(';'), "\nGot:\n{result}"); +}