Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/formatter/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,15 +1112,40 @@ impl<'a> Formatter<'a> {
self.text(node).to_string()
}

/// Flatten the wrapper nodes the grammar inserts for character/bit types
/// (`CharacterWithLength > character > kw_character + opt_varying`, etc.)
/// so the keyword, VARYING qualifier, and length token are all visible to
/// the single-pass type renderer below.
fn flatten_type_children(&self, node: Node<'a>) -> Vec<Node<'a>> {
const WRAPPERS: &[&str] = &[
"CharacterWithLength",
"CharacterWithoutLength",
"character",
"BitWithLength",
"BitWithoutLength",
"bit",
"opt_varying",
];
let mut out = Vec::new();
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.is_named() && WRAPPERS.contains(&child.kind()) {
out.extend(self.flatten_type_children(child));
} else {
out.push(child);
}
}
out
}

fn format_typename_inner(&self, node: Node<'a>) -> String {
// Get the base type name.
let mut base = String::new();
let mut modifiers = String::new();
let mut extra_keywords = Vec::new();
let mut timezone_keywords = Vec::new();

let mut cursor = node.walk();
for child in node.children(&mut cursor) {
for child in self.flatten_type_children(node) {
if child.is_named() {
match child.kind() {
"kw_integer" | "kw_int" | "kw_smallint" | "kw_bigint" | "kw_real"
Expand Down Expand Up @@ -1169,6 +1194,19 @@ impl<'a> Formatter<'a> {
"opt_type_modifiers" => {
modifiers = self.format_type_modifiers(child);
}
"Iconst" => {
// Bare length token inside CharacterWithLength.
modifiers = format!("({})", self.text(child).trim());
}
"expr_list" => {
// Parenthesized length/precision args (e.g. BitWithLength
// renders the length as ( expr_list )).
let items: Vec<String> = flatten_list(child, "expr_list")
.iter()
.map(|e| self.format_expr(*e))
.collect();
modifiers = format!("({})", items.join(", "));
}
"attrs" => {
base.push_str(&self.format_attrs(child));
}
Expand Down
28 changes: 28 additions & 0 deletions tests/smoke_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,31 @@ fn typed_literal_constants_river() {
assert_eq!(result, expected, "\nInput: {sql}\nGot:\n{result}");
}
}

#[test]
fn cast_multiword_type_names() {
// Regression: multi-word type names in :: casts must not be truncated
// (e.g. `::character varying` previously dropped `varying`).
let cases = [
(
"SELECT a::character varying",
"SELECT a::CHARACTER VARYING;",
),
("SELECT a::varchar(50)", "SELECT a::VARCHAR(50);"),
("SELECT a::char(10)", "SELECT a::CHAR(10);"),
(
"SELECT a::character varying(50)",
"SELECT a::CHARACTER VARYING(50);",
),
("SELECT a::double precision", "SELECT a::DOUBLE PRECISION;"),
("SELECT a::bit varying(8)", "SELECT a::BIT VARYING(8);"),
(
"SELECT a::timestamp with time zone",
"SELECT a::TIMESTAMP WITH TIME ZONE;",
),
];
for (sql, expected) in cases {
let result = format(sql, Style::River).unwrap();
assert_eq!(result, expected, "\nInput: {sql}\nGot:\n{result}");
}
}
Loading