diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 46bf20ec..8f584362 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -109,7 +109,7 @@ impl Parsed { #[must_use] pub fn parse(src: &str) -> Parsed { let tokens = tokenize(src); - let (import_spans, typedef_spans, relation_spans, index_spans, errors) = + let (import_spans, typedef_spans, relation_spans, index_spans, function_spans, errors) = parse_tokens(&tokens, src); let green = build_green_tree( @@ -119,6 +119,7 @@ pub fn parse(src: &str) -> Parsed { &typedef_spans, &relation_spans, &index_spans, + &function_spans, ); let root = ast::Root::from_green(green.clone()); @@ -129,17 +130,17 @@ pub fn parse(src: &str) -> Parsed { } } -/// Identifies and collects the spans of `import`, `typedef`, `relation`, and -/// `index` statements in a token stream. +/// Identifies and collects the spans of `import`, `typedef`, `relation`, +/// `index`, and `function` statements in a token stream. /// /// Returns tuples containing the spans of `import` statements, `typedef`/`extern type` declarations, -/// relation declarations, index declarations, and any parse errors encountered +/// relation declarations, index declarations, function declarations or definitions, and any parse errors encountered /// during span collection. /// /// # Examples /// /// ```no_run -/// let (imports, typedefs, relations, indexes, errors) = parse_tokens(&tokens, src); +/// let (imports, typedefs, relations, indexes, functions, errors) = parse_tokens(&tokens, src); /// assert!(imports.iter().all(|span| span.start < span.end)); /// ``` #[expect(clippy::type_complexity, reason = "returning multiple span lists")] @@ -151,21 +152,25 @@ fn parse_tokens( Vec, Vec, Vec, + Vec, Vec>, ) { let (import_spans, errors) = collect_import_spans(tokens, src); let typedef_spans = collect_typedef_spans(tokens, src); let relation_spans = collect_relation_spans(tokens, src); let (index_spans, index_errors) = collect_index_spans(tokens, src); + let (function_spans, func_errors) = collect_function_spans(tokens, src); let mut all_errors = errors; all_errors.extend(index_errors); + all_errors.extend(func_errors); ( import_spans, typedef_spans, relation_spans, index_spans, + function_spans, all_errors, ) } @@ -253,6 +258,112 @@ fn collect_import_spans( st.into_parts() } +fn function_params() -> impl Parser> { + use std::cell::Cell; + + let depth = Cell::new(0usize); + + just(SyntaxKind::T_LPAREN) + .padded_by(inline_ws().repeated()) + .ignore_then( + filter_map(move |span, kind| match kind { + SyntaxKind::T_LPAREN => { + depth.set(depth.get() + 1); + Ok(()) + } + SyntaxKind::T_RPAREN => { + if depth.get() == 0 { + Err(Simple::custom(span, "unexpected ')'")) + } else { + depth.set(depth.get() - 1); + Ok(()) + } + } + _ => Ok(()), + }) + .padded_by(inline_ws().repeated()) + .repeated(), + ) + .then_ignore(just(SyntaxKind::T_RPAREN)) + .ignored() +} + +fn function_body() -> impl Parser> { + use std::cell::Cell; + + let depth = Cell::new(0usize); + + just(SyntaxKind::T_LBRACE) + .padded_by(inline_ws().repeated()) + .ignore_then( + filter_map(move |span, kind| match kind { + SyntaxKind::T_LBRACE => { + depth.set(depth.get() + 1); + Ok(()) + } + SyntaxKind::T_RBRACE => { + if depth.get() == 0 { + Err(Simple::custom(span, "unexpected '}'")) + } else { + depth.set(depth.get() - 1); + Ok(()) + } + } + _ => Ok(()), + }) + .padded_by(inline_ws().repeated()) + .repeated(), + ) + .then_ignore(just(SyntaxKind::T_RBRACE)) + .ignored() +} + +fn function_return_ty() -> impl Parser> { + just(SyntaxKind::T_COLON) + .padded_by(inline_ws().repeated()) + .ignore_then( + filter(|kind: &SyntaxKind| { + !matches!(kind, SyntaxKind::T_LBRACE | SyntaxKind::T_RBRACE) + }) + .padded_by(inline_ws().repeated()) + .repeated() + .at_least(1), + ) + .ignored() +} + +fn extern_function_decl() -> impl Parser> { + let ident = just(SyntaxKind::T_IDENT) + .ignored() + .padded_by(inline_ws().repeated()); + + just(SyntaxKind::K_EXTERN) + .padded_by(inline_ws().repeated()) + .ignore_then(just(SyntaxKind::K_FUNCTION)) + .padded_by(inline_ws().repeated()) + .ignore_then(ident) + .then(function_params()) + .then(function_return_ty().or_not()) + .padded_by(inline_ws().repeated()) + .map_with_span(|_, sp: Span| sp) +} + +fn normal_function_def() -> impl Parser> { + let ident = just(SyntaxKind::T_IDENT) + .ignored() + .padded_by(inline_ws().repeated()); + + just(SyntaxKind::K_FUNCTION) + .padded_by(inline_ws().repeated()) + .ignore_then(ident) + .then(function_params()) + .then(function_return_ty().or_not()) + .padded_by(inline_ws().repeated()) + .then(function_body()) + .padded_by(inline_ws().repeated()) + .map_with_span(|_, sp: Span| sp) +} + /// Collects the spans of `typedef` and `extern type` declarations in the token stream. /// /// Each span covers the entire line of the declaration, enabling grouping of tokens into @@ -547,6 +658,84 @@ fn collect_index_spans( st.into_parts() } +/// Collects the spans of function declarations and definitions. +/// +/// The parser recognises `extern function` declarations without a body and +/// regular `function` definitions with a body enclosed in braces. Any syntax +/// errors are collected for later reporting and the cursor skips past the +/// offending span. +fn collect_function_spans( + tokens: &[(SyntaxKind, Span)], + src: &str, +) -> (Vec, Vec>) { + type State<'a> = SpanCollector<'a, Vec>>; + + fn parse_into_span( + st: &mut State<'_>, + parser: impl Parser>, + start: usize, + ) { + let iter = st.stream.tokens().iter().skip(st.stream.cursor()).cloned(); + let sub = Stream::from_iter(start..st.stream.src().len(), iter); + let (res, err) = parser.parse_recovery(sub); + if let Some(sp) = res { + let end = sp.end; + st.spans.push(sp); + st.stream.skip_until(end); + } else { + st.extra.extend(err); + let end = st.stream.line_end(st.stream.cursor()); + st.stream.skip_until(end); + } + } + + fn handle_extern(st: &mut State<'_>, span: Span) { + let mut idx = st.stream.cursor() + 1; + while let Some((kind, sp)) = st.stream.tokens().get(idx) { + if matches!(kind, SyntaxKind::T_WHITESPACE | SyntaxKind::T_COMMENT) + && !st + .stream + .src() + .get(sp.clone()) + .is_some_and(|t| t.contains('\n')) + { + idx += 1; + continue; + } + break; + } + + let is_func = st + .stream + .tokens() + .get(idx) + .is_some_and(|(kind, _)| *kind == SyntaxKind::K_FUNCTION); + + if is_func { + let parser = extern_function_decl(); + parse_into_span(st, parser, span.start); + } else { + st.stream.advance(); + let end = st.stream.line_end(st.stream.cursor()); + st.stream.skip_until(end); + } + } + + fn handle_function(st: &mut State<'_>, span: Span) { + let parser = normal_function_def(); + parse_into_span(st, parser, span.start); + } + + let mut st = State::new(tokens, src, Vec::new()); + + token_dispatch!(st, { + SyntaxKind::K_EXTERN => handle_extern, + SyntaxKind::K_FUNCTION => handle_function, + }); + + st.into_parts() +} + /// Construct the CST from the token stream and recorded statement spans. /// /// `imports` and `typedefs` must be sorted and non-overlapping so that tokens @@ -559,11 +748,13 @@ fn build_green_tree( typedefs: &[Span], relations: &[Span], indexes: &[Span], + functions: &[Span], ) -> GreenNode { assert_spans_sorted(imports); assert_spans_sorted(typedefs); assert_spans_sorted(relations); assert_spans_sorted(indexes); + assert_spans_sorted(functions); let mut builder = GreenNodeBuilder::new(); builder.start_node(DdlogLanguage::kind_to_raw(SyntaxKind::N_DATALOG_PROGRAM)); @@ -571,12 +762,14 @@ fn build_green_tree( let mut typedef_iter = typedefs.iter().peekable(); let mut relation_iter = relations.iter().peekable(); let mut index_iter = indexes.iter().peekable(); + let mut function_iter = functions.iter().peekable(); for (kind, span) in tokens { advance_span_iter(&mut import_iter, span.start); advance_span_iter(&mut typedef_iter, span.start); advance_span_iter(&mut relation_iter, span.start); advance_span_iter(&mut index_iter, span.start); + advance_span_iter(&mut function_iter, span.start); maybe_start( &mut builder, @@ -602,6 +795,12 @@ fn build_green_tree( span.start, SyntaxKind::N_INDEX, ); + maybe_start( + &mut builder, + &mut function_iter, + span.start, + SyntaxKind::N_FUNCTION, + ); push_token(&mut builder, kind, &span, src); @@ -609,6 +808,7 @@ fn build_green_tree( maybe_finish(&mut builder, &mut typedef_iter, span.end); maybe_finish(&mut builder, &mut relation_iter, span.end); maybe_finish(&mut builder, &mut index_iter, span.end); + maybe_finish(&mut builder, &mut function_iter, span.end); } builder.finish_node(); @@ -777,6 +977,16 @@ pub mod ast { .map(|syntax| Index { syntax }) .collect() } + + /// Collect all function declarations and definitions under this root. + #[must_use] + pub fn functions(&self) -> Vec { + self.syntax + .children() + .filter(|n| n.kind() == SyntaxKind::N_FUNCTION) + .map(|syntax| Function { syntax }) + .collect() + } } /// Typed wrapper for an `import` statement. @@ -1193,6 +1403,43 @@ pub mod ast { cols } } + + /// Typed wrapper for a function declaration or definition. + #[derive(Debug, Clone)] + pub struct Function { + pub(crate) syntax: SyntaxNode, + } + + impl Function { + /// Access the underlying syntax node. + #[must_use] + pub fn syntax(&self) -> &SyntaxNode { + &self.syntax + } + + /// Name of the function if present. + #[must_use] + pub fn name(&self) -> Option { + self.syntax + .children_with_tokens() + .skip_while(|e| !matches!(e.kind(), SyntaxKind::K_FUNCTION)) + .skip(1) + .find_map(|e| match e { + rowan::NodeOrToken::Token(t) if t.kind() == SyntaxKind::T_IDENT => { + Some(t.text().to_string()) + } + _ => None, + }) + } + + /// Whether this is an `extern function` declaration. + #[must_use] + pub fn is_extern(&self) -> bool { + self.syntax + .children_with_tokens() + .any(|e| e.kind() == SyntaxKind::K_EXTERN) + } + } } #[cfg(test)] diff --git a/tests/parser.rs b/tests/parser.rs index 82faa811..3c199159 100644 --- a/tests/parser.rs +++ b/tests/parser.rs @@ -111,6 +111,21 @@ fn index_whitespace_variations() -> &'static str { " index Idx_User_ws \t on\n User (\n username ) " } +#[fixture] +fn extern_function() -> &'static str { + "extern function hash(data: string): u64" +} + +#[fixture] +fn function_with_body() -> &'static str { + "function to_uppercase(s: string): string { }" +} + +#[fixture] +fn function_no_return() -> &'static str { + "function log_message(msg: string) { }" +} + /// Verifies that parsing and pretty-printing preserves the original input text /// and produces the expected root node kind. #[rstest] @@ -515,3 +530,42 @@ fn index_declaration_whitespace_variations(#[case] src: &str) { assert_eq!(idx.relation(), Some("User".into())); assert_eq!(idx.columns(), vec![String::from("username")]); } + +#[rstest] +fn extern_function_parsed(extern_function: &str) { + let parsed = parse(extern_function); + assert!(parsed.errors().is_empty()); + let funcs = parsed.root().functions(); + assert_eq!(funcs.len(), 1); + let Some(func) = funcs.first() else { + panic!("function should exist"); + }; + assert_eq!(func.name(), Some("hash".into())); + assert!(func.is_extern()); +} + +#[rstest] +fn function_with_body_parsed(function_with_body: &str) { + let parsed = parse(function_with_body); + assert!(parsed.errors().is_empty()); + let funcs = parsed.root().functions(); + assert_eq!(funcs.len(), 1); + let Some(func) = funcs.first() else { + panic!("function should exist"); + }; + assert_eq!(func.name(), Some("to_uppercase".into())); + assert!(!func.is_extern()); +} + +#[rstest] +fn function_no_return_parsed(function_no_return: &str) { + let parsed = parse(function_no_return); + assert!(parsed.errors().is_empty()); + let funcs = parsed.root().functions(); + assert_eq!(funcs.len(), 1); + let Some(func) = funcs.first() else { + panic!("function should exist"); + }; + assert_eq!(func.name(), Some("log_message".into())); + assert!(!func.is_extern()); +}