diff --git a/src/formatter/expr.rs b/src/formatter/expr.rs index d913c91..0c9a142 100644 --- a/src/formatter/expr.rs +++ b/src/formatter/expr.rs @@ -1112,6 +1112,32 @@ impl<'a> Formatter<'a> { self.text(node).to_string() } + /// Flatten the wrapper nodes the grammar inserts for character/bit types + /// (`CharacterWithLength > character > kw_character + opt_varying`, etc.) + /// so the keyword, VARYING qualifier, and length token are all visible to + /// the single-pass type renderer below. + fn flatten_type_children(&self, node: Node<'a>) -> Vec> { + const WRAPPERS: &[&str] = &[ + "CharacterWithLength", + "CharacterWithoutLength", + "character", + "BitWithLength", + "BitWithoutLength", + "bit", + "opt_varying", + ]; + let mut out = Vec::new(); + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + if child.is_named() && WRAPPERS.contains(&child.kind()) { + out.extend(self.flatten_type_children(child)); + } else { + out.push(child); + } + } + out + } + fn format_typename_inner(&self, node: Node<'a>) -> String { // Get the base type name. let mut base = String::new(); @@ -1119,8 +1145,7 @@ impl<'a> Formatter<'a> { let mut extra_keywords = Vec::new(); let mut timezone_keywords = Vec::new(); - let mut cursor = node.walk(); - for child in node.children(&mut cursor) { + for child in self.flatten_type_children(node) { if child.is_named() { match child.kind() { "kw_integer" | "kw_int" | "kw_smallint" | "kw_bigint" | "kw_real" @@ -1169,6 +1194,19 @@ impl<'a> Formatter<'a> { "opt_type_modifiers" => { modifiers = self.format_type_modifiers(child); } + "Iconst" => { + // Bare length token inside CharacterWithLength. + modifiers = format!("({})", self.text(child).trim()); + } + "expr_list" => { + // Parenthesized length/precision args (e.g. BitWithLength + // renders the length as ( expr_list )). + let items: Vec = flatten_list(child, "expr_list") + .iter() + .map(|e| self.format_expr(*e)) + .collect(); + modifiers = format!("({})", items.join(", ")); + } "attrs" => { base.push_str(&self.format_attrs(child)); } diff --git a/tests/smoke_test.rs b/tests/smoke_test.rs index e8c74af..596db86 100644 --- a/tests/smoke_test.rs +++ b/tests/smoke_test.rs @@ -109,3 +109,31 @@ fn typed_literal_constants_river() { assert_eq!(result, expected, "\nInput: {sql}\nGot:\n{result}"); } } + +#[test] +fn cast_multiword_type_names() { + // Regression: multi-word type names in :: casts must not be truncated + // (e.g. `::character varying` previously dropped `varying`). + let cases = [ + ( + "SELECT a::character varying", + "SELECT a::CHARACTER VARYING;", + ), + ("SELECT a::varchar(50)", "SELECT a::VARCHAR(50);"), + ("SELECT a::char(10)", "SELECT a::CHAR(10);"), + ( + "SELECT a::character varying(50)", + "SELECT a::CHARACTER VARYING(50);", + ), + ("SELECT a::double precision", "SELECT a::DOUBLE PRECISION;"), + ("SELECT a::bit varying(8)", "SELECT a::BIT VARYING(8);"), + ( + "SELECT a::timestamp with time zone", + "SELECT a::TIMESTAMP WITH TIME ZONE;", + ), + ]; + for (sql, expected) in cases { + let result = format(sql, Style::River).unwrap(); + assert_eq!(result, expected, "\nInput: {sql}\nGot:\n{result}"); + } +}