diff --git a/src/parser/ast/rule.rs b/src/parser/ast/rule.rs index 79ecbfb4..606c4940 100644 --- a/src/parser/ast/rule.rs +++ b/src/parser/ast/rule.rs @@ -25,10 +25,15 @@ //! let rule = first_rule("R(x) :- S(x), T(x)."); //! assert_eq!(rule.head().as_deref(), Some("R(x)")); //! assert_eq!(rule.body_literals(), vec!["S(x)".into(), "T(x)".into()]); +//! let exprs = rule.body_expressions(); +//! assert_eq!(exprs.len(), 2); +//! assert!(exprs.into_iter().all(|expr| expr.is_ok())); //! ``` -use super::AstNode; -use crate::{DdlogLanguage, SyntaxKind}; +use chumsky::error::Simple; + +use super::{AstNode, Expr}; +use crate::{DdlogLanguage, SyntaxKind, parser::expression::parse_expression}; /// Typed wrapper for a rule declaration. #[derive(Debug, Clone)] @@ -64,85 +69,28 @@ impl Rule { /// Text of each body literal in order of appearance. #[must_use] pub fn body_literals(&self) -> Vec { - let mut iter = self - .syntax - .children_with_tokens() - .skip_while(|e| e.kind() != SyntaxKind::T_IMPLIES); - - if matches!(iter.next().map(|e| e.kind()), Some(SyntaxKind::T_IMPLIES)) { - Self::extract_literals_from_body(iter) - } else { - Vec::new() - } - } - - /// Iterate over the body and collect literals. - fn extract_literals_from_body(iter: I) -> Vec - where - I: Iterator>, - { - let mut buf = String::new(); - let mut lits = Vec::new(); - - for e in iter { - if Self::process_body_element(e, &mut buf, &mut lits) { - break; - } - } - - lits - } - - /// Process a single syntax element of the rule body. - /// - /// Returns `true` if the body has ended. - fn process_body_element( - element: rowan::SyntaxElement, - buf: &mut String, - lits: &mut Vec, - ) -> bool { - use rowan::NodeOrToken; - - match element { - NodeOrToken::Token(t) => Self::process_token(&t, buf, lits), - NodeOrToken::Node(n) => { - buf.push_str(&n.text().to_string()); - false - } - } + self.body_expression_texts() } - /// Handle a token in the rule body. - /// - /// Returns `true` when the terminating `.` is encountered. - fn process_token( - token: &rowan::SyntaxToken, - buf: &mut String, - lits: &mut Vec, - ) -> bool { - match token.kind() { - SyntaxKind::T_COMMA => { - Self::add_literal_if_not_empty(buf, lits); - buf.clear(); - false - } - SyntaxKind::T_DOT => { - Self::add_literal_if_not_empty(buf, lits); - true - } - _ => { - buf.push_str(token.text()); - false - } - } + /// Parse the rule body into structured expressions. + #[must_use] + pub fn body_expressions(&self) -> Vec>>> { + self.body_expression_texts() + .into_iter() + .map(|text| parse_expression(&text)) + .collect() } - /// Add a literal to the list if it contains non-whitespace text. - fn add_literal_if_not_empty(buf: &str, lits: &mut Vec) { - let lit = buf.trim(); - if !lit.is_empty() { - lits.push(lit.to_string()); - } + fn body_expression_texts(&self) -> Vec { + self.syntax + .children() + .filter(|node| node.kind() == SyntaxKind::N_EXPR_NODE) + .map(|node| { + let text = node.text().to_string(); + text.trim().to_string() + }) + .filter(|text| !text.is_empty()) + .collect() } } @@ -151,7 +99,7 @@ impl_ast_node!(Rule); #[cfg(test)] mod tests { - use crate::parse; + use crate::{parse, parser::ast::{BinaryOp, Expr}}; #[expect(clippy::expect_used, reason = "Using expect for clearer test failures")] #[test] @@ -166,4 +114,32 @@ mod tests { assert_eq!(rule.head().as_deref(), Some("A(x)")); assert_eq!(rule.body_literals(), vec!["B(x)".to_string()]); } + + #[expect(clippy::expect_used, reason = "tests expect parsed rule bodies")] + #[test] + fn rule_body_expressions_parse_structures() { + let parsed = parse("R(x) :- 1 + 2 * 3, if (cond) Foo(cond) else Bar()."); + let rule = parsed + .root() + .rules() + .first() + .cloned() + .expect("rule missing"); + let exprs = rule.body_expressions(); + assert_eq!(exprs.len(), 2); + let first = exprs[0].as_ref().expect("first literal failed to parse"); + match first { + Expr::Binary { op: BinaryOp::Add, lhs, rhs } => match rhs.as_ref() { + Expr::Binary { op: BinaryOp::Mul, .. } => { + assert!(matches!(lhs.as_ref(), Expr::Literal(_))); + } + other => panic!("expected multiplication RHS, got {other:?}"), + }, + other => panic!("unexpected expression: {other:?}"), + } + let second = exprs[1] + .as_ref() + .expect("second literal failed to parse"); + assert!(matches!(second, Expr::IfElse { .. })); + } } diff --git a/src/parser/expression_span.rs b/src/parser/expression_span.rs index e3ee5219..0a7f1096 100644 --- a/src/parser/expression_span.rs +++ b/src/parser/expression_span.rs @@ -11,43 +11,6 @@ use thiserror::Error; use crate::parser::expression::parse_expression; use crate::{Span, SyntaxKind}; -/// Find the span of a rule body expression within a slice of tokens. -/// -/// The function scans the tokens starting at `start_idx` until `end`. It -/// returns the range of bytes between the token following `T_IMPLIES` and the -/// preceding `T_DOT`. `None` is returned if a well formed range cannot be -/// determined. -#[must_use] -pub(crate) fn rule_body_span( - tokens: &[(SyntaxKind, Span)], - start_idx: usize, - end: usize, -) -> Option { - let mut expr_start = None; - let mut expr_end = None; - let mut idx = start_idx; - while let Some(tok) = tokens.get(idx) { - if tok.1.start >= end { - break; - } - match tok.0 { - SyntaxKind::T_IMPLIES => { - expr_start = tokens.get(idx + 1).map(|t| t.1.start); - } - SyntaxKind::T_DOT => { - expr_end = Some(tok.1.start); - break; - } - _ => {} - } - idx += 1; - } - match (expr_start, expr_end) { - (Some(s), Some(e)) if s < e => Some(s..e), - _ => None, - } -} - /// Parse the text within `span` using the expression parser. /// /// Any syntax errors reported by the parser are returned for the caller to diff --git a/src/parser/span_scanner.rs b/src/parser/span_scanner.rs index f2241f30..a32b75fc 100644 --- a/src/parser/span_scanner.rs +++ b/src/parser/span_scanner.rs @@ -7,7 +7,7 @@ use chumsky::prelude::*; -use crate::parser::expression_span::{rule_body_span, validate_expression}; +use crate::parser::expression_span::validate_expression; use crate::{Span, SyntaxKind}; use super::{ @@ -483,9 +483,9 @@ fn rule_decl() -> impl Parser> { let atom_p = atom(); - let statement = rule_statement(ws.clone()); + let literal = rule_body_literal(ws.clone()); - let body = statement + let body = literal .separated_by(just(SyntaxKind::T_COMMA).padded_by(ws.clone())) .allow_trailing() .at_least(1) @@ -504,91 +504,125 @@ fn rule_decl() -> impl Parser> { .map_with_span(|_, sp: Span| sp) } -fn rule_statement( - ws: impl Parser> + Clone + 'static, +fn rule_body_literal( + ws: impl Parser> + Clone, ) -> impl Parser> + Clone { - recursive(|stmt| { - let ws = ws.clone(); - let block = balanced_block(SyntaxKind::T_LBRACE, SyntaxKind::T_RBRACE); - let skip_stmt = just(SyntaxKind::K_SKIP).ignored(); - let atom_stmt = atom(); - - let condition = balanced_block(SyntaxKind::T_LPAREN, SyntaxKind::T_RPAREN); + let grouped = choice(( + balanced_block(SyntaxKind::T_LPAREN, SyntaxKind::T_RPAREN), + balanced_block(SyntaxKind::T_LBRACE, SyntaxKind::T_RBRACE), + balanced_block(SyntaxKind::T_LBRACKET, SyntaxKind::T_RBRACKET), + )); - let if_stmt = just(SyntaxKind::K_IF) - .padded_by(ws.clone()) - .ignore_then(condition.clone()) - .then(stmt.clone().padded_by(ws.clone())) - .then( - just(SyntaxKind::K_ELSE) - .padded_by(ws.clone()) - .ignore_then(stmt.clone()) - .or_not(), + let token = grouped.or( + filter(|kind: &SyntaxKind| { + !matches!( + kind, + SyntaxKind::T_COMMA + | SyntaxKind::T_DOT + | SyntaxKind::T_WHITESPACE + | SyntaxKind::T_COMMENT ) - .ignored(); - - let header_expr = for_header(ws.clone()); - - let for_stmt = just(SyntaxKind::K_FOR) - .padded_by(ws.clone()) - .ignore_then(header_expr) - .then(stmt.clone().padded_by(ws.clone())) - .ignored(); - - choice((for_stmt, if_stmt, block, skip_stmt, atom_stmt)) - .padded_by(ws.clone()) - .ignored() - }) -} + }) + .ignored(), + ); -fn for_header( - ws: impl Parser> + Clone, -) -> impl Parser> + Clone { - just(SyntaxKind::T_LPAREN) + token .padded_by(ws.clone()) - .ignore_then(for_binding_complete(ws.clone())) - .then_ignore(just(SyntaxKind::T_RPAREN)) + .repeated() + .at_least(1) .padded_by(ws) .ignored() } -fn for_binding_complete( - ws: impl Parser> + Clone, -) -> impl Parser> + Clone { - let nested = choice(( - balanced_block(SyntaxKind::T_LPAREN, SyntaxKind::T_RPAREN), - balanced_block(SyntaxKind::T_LBRACKET, SyntaxKind::T_RBRACKET), - balanced_block(SyntaxKind::T_LBRACE, SyntaxKind::T_RBRACE), - )); +fn is_trivia_kind(kind: SyntaxKind) -> bool { + matches!(kind, SyntaxKind::T_WHITESPACE | SyntaxKind::T_COMMENT) +} - let binding_token = nested.clone().or(filter(|kind: &SyntaxKind| { - matches!( - kind, - SyntaxKind::T_IDENT - | SyntaxKind::K_UNDERSCORE - | SyntaxKind::T_COMMA - | SyntaxKind::T_COLON - | SyntaxKind::T_COLON_COLON - | SyntaxKind::T_DOT - | SyntaxKind::T_AT - | SyntaxKind::T_HASH - | SyntaxKind::T_NUMBER - | SyntaxKind::T_STRING - | SyntaxKind::T_APOSTROPHE - ) - }) - .ignored()); +fn is_top_level(paren_depth: usize, brace_depth: usize, bracket_depth: usize) -> bool { + paren_depth == 0 && brace_depth == 0 && bracket_depth == 0 +} - let header_token = - nested.or(filter(|kind: &SyntaxKind| *kind != SyntaxKind::T_RPAREN).ignored()); +fn collect_rule_literal_spans( + tokens: &[(SyntaxKind, Span)], + start_idx: usize, + end: usize, +) -> Vec { + fn push_literal_span(spans: &mut Vec, start: &mut Option, end: &mut Option) { + if let (Some(s), Some(e)) = (start.take(), end.take()) { + if s < e { + spans.push(s..e); + } + } else { + *start = None; + *end = None; + } + } - binding_token - .padded_by(ws.clone()) - .repeated() - .at_least(1) - .then_ignore(just(SyntaxKind::K_IN).padded_by(ws.clone())) - .then(header_token.padded_by(ws).repeated().at_least(1)) - .ignored() + let mut idx = start_idx; + let mut found_body = false; + while let Some((kind, span)) = tokens.get(idx) { + if span.start >= end { + break; + } + if *kind == SyntaxKind::T_IMPLIES { + found_body = true; + idx += 1; + break; + } + idx += 1; + } + + if !found_body { + return Vec::new(); + } + + let mut spans = Vec::new(); + let mut paren_depth = 0usize; + let mut brace_depth = 0usize; + let mut bracket_depth = 0usize; + let mut literal_start: Option = None; + let mut literal_end: Option = None; + + while let Some((kind, span)) = tokens.get(idx) { + if span.start >= end { + break; + } + + if *kind == SyntaxKind::T_DOT && is_top_level(paren_depth, brace_depth, bracket_depth) { + push_literal_span(&mut spans, &mut literal_start, &mut literal_end); + break; + } + + if *kind == SyntaxKind::T_COMMA && is_top_level(paren_depth, brace_depth, bracket_depth) { + push_literal_span(&mut spans, &mut literal_start, &mut literal_end); + idx += 1; + continue; + } + + if is_trivia_kind(*kind) { + idx += 1; + continue; + } + + literal_start.get_or_insert(span.start); + literal_end = Some(span.end); + + match kind { + SyntaxKind::T_LPAREN => paren_depth += 1, + SyntaxKind::T_RPAREN => paren_depth = paren_depth.saturating_sub(1), + SyntaxKind::T_LBRACE => brace_depth += 1, + SyntaxKind::T_RBRACE => brace_depth = brace_depth.saturating_sub(1), + SyntaxKind::T_LBRACKET => bracket_depth += 1, + SyntaxKind::T_RBRACKET => bracket_depth = bracket_depth.saturating_sub(1), + _ => {} + } + + idx += 1; + } + + push_literal_span(&mut spans, &mut literal_start, &mut literal_end); + + spans } /// Return `true` if `span` begins a new line in the source. @@ -631,7 +665,7 @@ fn parse_rule_at_line_start(st: &mut State<'_>, span: Span, exprs: &mut Vec { @@ -682,17 +716,8 @@ mod tests { //! Tests for the span scanner helper utilities. use super::*; use crate::test_util::tokenize; - use chumsky::Stream; use rstest::rstest; - fn parse_rule_statement_input(src: &str) -> (Option<()>, Vec>) { - let tokens = tokenize(src); - let ws = inline_ws().repeated().ignored(); - let parser = rule_statement(ws); - let stream = Stream::from_iter(0..src.len(), tokens.into_iter()); - parser.parse_recovery(stream) - } - #[rstest] #[case("import foo\n", vec![0..11], true)] #[case("import\n", vec![], false)] @@ -707,37 +732,43 @@ mod tests { assert_eq!(errs.is_empty(), errs_empty); } - #[rstest] - #[case("if (cond) if (nested) Process(nested) else Skip() else Handle()")] - #[case("for (a in A(a)) for (b in B(b)) ProcessPair(a, b)")] - #[case("for (item in if cond { Items(item) } else { Others(item) }) Process(item)")] - #[case("for (item in Items(item)) if (item > 10) Process(item)")] - #[case("for (item in Items(item) if item.active) Process(item)")] - #[case( - "if (outer) { for (item in Items(item)) if (should(item)) Process(item) } else { Skip() }" - )] - fn rule_statement_parses_control_flow(#[case] src: &str) { - let (res, errs) = parse_rule_statement_input(src); - assert!(res.is_some(), "expected successful parse for {src}"); - assert!(errs.is_empty(), "unexpected errors for {src:?}: {errs:?}"); + fn literal_texts(src: &str, spans: &[Span]) -> Vec { + spans + .iter() + .map(|sp| src.get(sp.clone()).unwrap_or_default().trim().to_string()) + .collect() } #[rstest] - #[case("if (cond { Process(cond) }")] - #[case("for (item in Items(item) Process(item)")] - #[case("for (item in Items(item) if item > 10 Process(item)")] - fn rule_statement_reports_errors(#[case] src: &str) { - let (res, errs) = parse_rule_statement_input(src); - if res.is_some() { - assert!( - !errs.is_empty(), - "expected errors for {src}, but parser recovered without diagnostics", - ); - } else { - assert!( - !errs.is_empty(), - "expected errors for {src}, but parser reported none", - ); - } + #[case( + "R(x) :- A(x).", + vec!["A(x)"], + )] + #[case( + "R(x) :- A(x), B(x).", + vec!["A(x)", "B(x)"], + )] + #[case( + "R(x) :- if (cond) Do(cond) else Skip(), for (entry in Items(entry)) Process(entry).", + vec![ + "if (cond) Do(cond) else Skip()", + "for (entry in Items(entry)) Process(entry)", + ], + )] + #[case( + "R(x) :- match (value) { 1 -> One(), _ -> Other() }.", + vec!["match (value) { 1 -> One(), _ -> Other() }"], + )] + #[case( + "R(x) :- 1 + 2 * 3, tuple(First(x), Second(x, y[0, 1])).", + vec!["1 + 2 * 3", "tuple(First(x), Second(x, y[0, 1]))"], + )] + fn collect_rule_spans_extracts_literals(#[case] src: &str, #[case] expected: Vec<&str>) { + let tokens = tokenize(src); + let (_rule_spans, expr_spans, errs) = collect_rule_spans(&tokens, src); + assert!(errs.is_empty(), "unexpected errors: {errs:?}"); + let texts = literal_texts(src, &expr_spans); + let expected_texts: Vec = expected.into_iter().map(str::to_string).collect(); + assert_eq!(texts, expected_texts); } } diff --git a/src/parser/tests/expression_integration.rs b/src/parser/tests/expression_integration.rs index e27c7017..73f3cbd7 100644 --- a/src/parser/tests/expression_integration.rs +++ b/src/parser/tests/expression_integration.rs @@ -5,10 +5,7 @@ use crate::{SyntaxKind, parse}; fn rule_body_expression_creates_node() { let src = "R(x) :- 1 + 2 * 3."; let parsed = parse(src); - // The span scanner does not currently recover from invalid rule bodies, so - // errors are expected. The expression parser should still process the body - // tokens. - assert!(!parsed.errors().is_empty()); + assert!(parsed.errors().is_empty(), "unexpected parse errors" ); let root = parsed.root().syntax(); let mut expr_nodes = root .descendants() @@ -18,3 +15,20 @@ fn rule_body_expression_creates_node() { let text = node.text().to_string(); assert_eq!(text.trim(), "1 + 2 * 3"); } + +#[expect(clippy::expect_used, reason = "Using expect for clearer test failures")] +#[test] +fn multiple_literals_create_multiple_nodes() { + let src = "R(x) :- Foo(x), if (cond) Bar() else Baz()."; + let parsed = parse(src); + assert!(parsed.errors().is_empty(), "unexpected parse errors"); + let root = parsed.root().syntax(); + let mut expr_nodes = root + .descendants() + .filter(|n| n.kind() == SyntaxKind::N_EXPR_NODE); + let first = expr_nodes.next().expect("first expr missing"); + let second = expr_nodes.next().expect("second expr missing"); + assert!(expr_nodes.next().is_none()); + assert_eq!(first.text().trim(), "Foo(x)"); + assert_eq!(second.text().trim(), "if (cond) Bar() else Baz()"); +} diff --git a/src/parser/tests/rules.rs b/src/parser/tests/rules.rs index 572c45e2..7acdda4d 100644 --- a/src/parser/tests/rules.rs +++ b/src/parser/tests/rules.rs @@ -43,9 +43,8 @@ fn nested_for_loop_rule() -> &'static str { } #[rstest] -#[case::simple_rule(simple_rule(), true)] -// TODO: rules with multiple body literals should parse without errors once supported -#[case::multi_literal_rule(multi_literal_rule(), true)] +#[case::simple_rule(simple_rule(), false)] +#[case::multi_literal_rule(multi_literal_rule(), false)] #[case::fact_rule(fact_rule(), false)] #[case::for_loop_rule(for_loop_rule(), false)] #[case::for_loop_if_iterable(for_loop_with_if_iterable(), false)]