From ce918ecf15e3874a5b6edebbc0494ad240f5ba6d Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Thu, 12 Mar 2026 08:38:33 +0100 Subject: [PATCH 1/4] feat: add SQLite backend support with feature-gated database selection Add a new `sqlite` feature flag alongside the existing `postgres` feature, allowing compile-time selection of the database backend. This includes: - New `es-entity-macros-sqlite` proc macro crate generating SQLite-compatible SQL (positional `?N` params, `IS` instead of `IS NOT DISTINCT FROM`, etc.) - `src/db.rs` module with cfg-gated type aliases as the single chokepoint for database-specific types (Pool, Row, Db, etc.) - SQLite migration for test infrastructure with in-memory databases - All existing integration tests adapted to work with both backends - 82 macro unit tests + full integration test suite passing on SQLite Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 15 + Cargo.toml | 22 +- es-entity-macros-sqlite/Cargo.toml | 28 + es-entity-macros-sqlite/src/entity.rs | 242 +++ .../src/es_event_context.rs | 206 +++ es-entity-macros-sqlite/src/event.rs | 111 ++ es-entity-macros-sqlite/src/lib.rs | 146 ++ es-entity-macros-sqlite/src/query/input.rs | 194 +++ es-entity-macros-sqlite/src/query/mod.rs | 259 ++++ es-entity-macros-sqlite/src/repo/begin.rs | 60 + .../src/repo/combo_cursor.rs | 337 ++++ .../src/repo/create_all_fn.rs | 286 ++++ es-entity-macros-sqlite/src/repo/create_fn.rs | 393 +++++ es-entity-macros-sqlite/src/repo/delete_fn.rs | 330 ++++ .../src/repo/error_types.rs | 1349 +++++++++++++++++ .../src/repo/find_all_fn.rs | 237 +++ .../src/repo/find_by_fn.rs | 628 ++++++++ .../src/repo/list_by_fn.rs | 852 +++++++++++ .../src/repo/list_for_filters_fn.rs | 959 ++++++++++++ .../src/repo/list_for_fn.rs | 491 ++++++ es-entity-macros-sqlite/src/repo/mod.rs | 271 ++++ es-entity-macros-sqlite/src/repo/nested.rs | 147 ++ .../src/repo/options/columns.rs | 695 +++++++++ .../src/repo/options/delete.rs | 42 + .../src/repo/options/mod.rs | 447 ++++++ 
.../src/repo/persist_events_batch_fn.rs | 284 ++++ .../src/repo/persist_events_fn.rs | 236 +++ .../src/repo/populate_nested.rs | 101 ++ .../src/repo/post_hydrate_hook.rs | 113 ++ .../src/repo/post_persist_hook.rs | 136 ++ .../src/repo/update_all_fn.rs | 408 +++++ es-entity-macros-sqlite/src/repo/update_fn.rs | 347 +++++ .../src/retry_on_concurrent_modification.rs | 210 +++ flake.nix | 9 +- .../20250718092455_test_setup.sql | 157 ++ src/context/sqlx.rs | 98 +- src/db.rs | 55 +- src/lib.rs | 27 +- src/one_time_executor.rs | 16 +- src/operation/hooks.rs | 4 +- src/operation/mod.rs | 6 +- src/operation/with_time.rs | 2 +- tests/context.rs | 2 + tests/es_query.rs | 13 +- tests/from_async_trait.rs | 11 +- tests/helpers.rs | 13 + tests/hooks.rs | 6 +- tests/nested_entities.rs | 9 +- tests/repo_bulk.rs | 5 +- tests/repo_clock.rs | 17 +- tests/repo_crud.rs | 16 +- tests/repo_errors.rs | 2 + tests/repo_hooks.rs | 12 +- 53 files changed, 10965 insertions(+), 97 deletions(-) create mode 100644 es-entity-macros-sqlite/Cargo.toml create mode 100644 es-entity-macros-sqlite/src/entity.rs create mode 100644 es-entity-macros-sqlite/src/es_event_context.rs create mode 100644 es-entity-macros-sqlite/src/event.rs create mode 100644 es-entity-macros-sqlite/src/lib.rs create mode 100644 es-entity-macros-sqlite/src/query/input.rs create mode 100644 es-entity-macros-sqlite/src/query/mod.rs create mode 100644 es-entity-macros-sqlite/src/repo/begin.rs create mode 100644 es-entity-macros-sqlite/src/repo/combo_cursor.rs create mode 100644 es-entity-macros-sqlite/src/repo/create_all_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/create_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/delete_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/error_types.rs create mode 100644 es-entity-macros-sqlite/src/repo/find_all_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/find_by_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/list_by_fn.rs create mode 
100644 es-entity-macros-sqlite/src/repo/list_for_filters_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/list_for_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/mod.rs create mode 100644 es-entity-macros-sqlite/src/repo/nested.rs create mode 100644 es-entity-macros-sqlite/src/repo/options/columns.rs create mode 100644 es-entity-macros-sqlite/src/repo/options/delete.rs create mode 100644 es-entity-macros-sqlite/src/repo/options/mod.rs create mode 100644 es-entity-macros-sqlite/src/repo/persist_events_batch_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/persist_events_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/populate_nested.rs create mode 100644 es-entity-macros-sqlite/src/repo/post_hydrate_hook.rs create mode 100644 es-entity-macros-sqlite/src/repo/post_persist_hook.rs create mode 100644 es-entity-macros-sqlite/src/repo/update_all_fn.rs create mode 100644 es-entity-macros-sqlite/src/repo/update_fn.rs create mode 100644 es-entity-macros-sqlite/src/retry_on_concurrent_modification.rs create mode 100644 migrations-sqlite/20250718092455_test_setup.sql diff --git a/Cargo.lock b/Cargo.lock index 9ef34440..d01c59a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -486,6 +486,7 @@ dependencies = [ "chrono", "derive_builder", "es-entity-macros", + "es-entity-macros-sqlite", "futures", "im", "opentelemetry", @@ -516,6 +517,19 @@ dependencies = [ "syn", ] +[[package]] +name = "es-entity-macros-sqlite" +version = "0.10.28-dev" +dependencies = [ + "convert_case", + "darling 0.23.0", + "pluralizer", + "proc-macro2", + "quote", + "regex", + "syn", +] + [[package]] name = "etcetera" version = "0.8.0" @@ -1027,6 +1041,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ + "cc", "pkg-config", "vcpkg", ] diff --git a/Cargo.toml b/Cargo.toml index cbe6d92e..ec5f23f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ 
-13,17 +13,22 @@ inherits = "dev" [features] -fail-on-warnings = ["es-entity-macros/fail-on-warnings"] +default = ["postgres"] +postgres = ["sqlx/postgres", "dep:es-entity-macros"] +sqlite = ["sqlx/sqlite", "dep:es-entity-macros-sqlite"] + +fail-on-warnings = ["es-entity-macros?/fail-on-warnings", "es-entity-macros-sqlite?/fail-on-warnings"] tracing-context = ["dep:tracing", "dep:tracing-opentelemetry", "dep:opentelemetry", "dep:opentelemetry_sdk"] -graphql = ["es-entity-macros/graphql", "dep:async-graphql", "dep:base64"] -event-context = ["es-entity-macros/event-context", "event-context-enabled"] -event-context-enabled = ["es-entity-macros/event-context-enabled"] +graphql = ["es-entity-macros?/graphql", "es-entity-macros-sqlite?/graphql", "dep:async-graphql", "dep:base64"] +event-context = ["es-entity-macros?/event-context", "es-entity-macros-sqlite?/event-context", "event-context-enabled"] +event-context-enabled = ["es-entity-macros?/event-context-enabled", "es-entity-macros-sqlite?/event-context-enabled"] json-schema = ["dep:schemars"] mdbook-test = ["dep:anyhow"] -instrument = ["es-entity-macros/instrument", "dep:tracing"] +instrument = ["es-entity-macros?/instrument", "es-entity-macros-sqlite?/instrument", "dep:tracing"] [dependencies] -es-entity-macros = { workspace = true } +es-entity-macros = { workspace = true, optional = true } +es-entity-macros-sqlite = { workspace = true, optional = true } base64 = { workspace = true, optional = true } sqlx = { workspace = true } @@ -53,16 +58,19 @@ tokio = { workspace = true } anyhow = { workspace = true } async-trait = { workspace = true } futures = { workspace = true } +sqlx = { workspace = true, features = ["migrate"] } [workspace] resolver = "2" members = [ "es-entity-macros", + "es-entity-macros-sqlite", ] [workspace.dependencies] es-entity-macros = { path = "es-entity-macros", version = "0.10.28-dev" } +es-entity-macros-sqlite = { path = "es-entity-macros-sqlite", version = "0.10.28-dev" } anyhow = "1.0" 
async-graphql = { version = "8.0.0-rc.3", default-features = false } @@ -74,7 +82,7 @@ schemars = { version = "1.0", features = ["uuid1"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_with = "3.15" -sqlx = { version = "0.8", default-features = false, features = ["macros", "runtime-tokio-rustls", "postgres", "uuid", "chrono", "json" ] } +sqlx = { version = "0.8", default-features = false, features = ["macros", "runtime-tokio-rustls", "uuid", "chrono", "json" ] } tokio = { version = "1.50", features = ["rt-multi-thread", "macros", "time"] } thiserror = "2.0" uuid = { version = "1.22", features = ["serde", "v7"] } diff --git a/es-entity-macros-sqlite/Cargo.toml b/es-entity-macros-sqlite/Cargo.toml new file mode 100644 index 00000000..06970b44 --- /dev/null +++ b/es-entity-macros-sqlite/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "es-entity-macros-sqlite" +description = "Proc macros for es-entity (SQLite backend)" +repository = "https://github.com/GaloyMoney/cala" +version = "0.10.28-dev" +edition = "2024" +license = "Apache-2.0" +categories = ["data-structures", "database"] + +[features] + +fail-on-warnings = [] +graphql = [] +event-context = ["event-context-enabled"] +event-context-enabled = [] +instrument = [] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +syn = "2.0" +quote = "1.0" +darling = "0.23" +pluralizer = "0.5" +convert_case = "0.11" +regex = "1.12" diff --git a/es-entity-macros-sqlite/src/entity.rs b/es-entity-macros-sqlite/src/entity.rs new file mode 100644 index 00000000..6a808dbe --- /dev/null +++ b/es-entity-macros-sqlite/src/entity.rs @@ -0,0 +1,242 @@ +use darling::{FromDeriveInput, FromField, ToTokens}; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; +use syn::Type; + +#[derive(Debug, FromField)] +#[darling(attributes(es_entity))] +struct Field { + ident: Option, + ty: Type, + #[darling(default)] + events: bool, + #[darling(default)] + nested: bool, +} + +impl Field { + fn 
is_events_field(&self) -> bool { + self.events || self.ident.as_ref().is_some_and(|i| i == "events") + } + + fn extract_nested_entity_type(&self) -> &Type { + if let Type::Path(type_path) = &self.ty + && let Some(segment) = type_path.path.segments.last() + && segment.ident == "Nested" + && let syn::PathArguments::AngleBracketed(generic_args) = &segment.arguments + && let Some(syn::GenericArgument::Type(inner_type)) = generic_args.args.first() + { + return inner_type; + } + panic!("Field must be of type Nested"); + } +} + +#[derive(Debug, FromDeriveInput)] +#[darling(supports(struct_named), attributes(es_entity))] +pub struct EsEntity { + ident: syn::Ident, + #[darling(default, rename = "new")] + new_entity_ident: Option, + #[darling(default, rename = "event")] + event_ident: Option, + data: darling::ast::Data<(), Field>, +} + +impl EsEntity { + fn find_events_field(&self) -> Option<&Field> { + match &self.data { + darling::ast::Data::Struct(fields) => { + fields.iter().find(|field| field.is_events_field()) + } + _ => None, + } + } + + fn nested_fields(&self) -> Vec<&Field> { + match &self.data { + darling::ast::Data::Struct(fields) => { + fields.iter().filter(|field| field.nested).collect() + } + _ => Vec::new(), + } + } +} + +pub fn derive(ast: syn::DeriveInput) -> darling::Result { + let entity = EsEntity::from_derive_input(&ast)?; + Ok(quote!(#entity)) +} + +impl ToTokens for EsEntity { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + let events_field = self + .find_events_field() + .expect("Struct must have a field marked with #[es_entity(events)]") + .ident + .as_ref() + .expect("Not ident on #[events]"); + + let event = self.event_ident.clone().unwrap_or_else(|| { + syn::Ident::new( + &format!("{}Event", self.ident), + proc_macro2::Span::call_site(), + ) + }); + let new = self.new_entity_ident.clone().unwrap_or_else(|| { + syn::Ident::new( + &format!("New{}", self.ident), + proc_macro2::Span::call_site(), + ) + }); + + let nested 
= self.nested_fields().into_iter().map(|f| { + let field = &f.ident; + let ty = f.extract_nested_entity_type(); + quote! { + impl es_entity::Parent<#ty> for #ident { + fn new_children_mut(&mut self) -> &mut Vec<<#ty as es_entity::EsEntity>::New> { + self.#field.new_entities_mut() + } + + fn inject_children(&mut self, children: impl IntoIterator) { + self.#field.load(children) + } + + fn iter_persisted_children_mut( + &mut self + ) -> std::collections::hash_map::ValuesMut<'_, <<#ty as EsEntity>::Event as EsEvent>::EntityId, #ty> + { + self.#field.iter_persisted_mut() + } + } + } + }); + + tokens.append_all(quote! { + impl es_entity::EsEntity for #ident { + type Event = #event; + type New = #new; + + fn events_mut(&mut self) -> &mut es_entity::EntityEvents<#event> { + &mut self.#events_field + } + fn events(&self) -> &es_entity::EntityEvents<#event> { + &self.#events_field + } + } + + #(#nested)* + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quote::quote; + use syn::parse_quote; + + #[test] + fn test_derive_es_entity() { + let input: syn::DeriveInput = parse_quote! { + #[derive(EsEntity)] + pub struct User { + pub id: UserId, + pub email: String, + #[es_entity(events)] + the_events: EntityEvents + } + }; + + let output = derive(input).unwrap(); + let expected = quote! { + impl es_entity::EsEntity for User { + type Event = UserEvent; + type New = NewUser; + fn events_mut(&mut self) -> &mut es_entity::EntityEvents { + &mut self.the_events + } + fn events(&self) -> &es_entity::EntityEvents { + &self.the_events + } + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn test_derive_without_events_attr() { + let input: syn::DeriveInput = parse_quote! { + #[derive(EsEntity)] + pub struct User { + pub id: UserId, + events: EntityEvents + } + }; + + let output = derive(input).unwrap(); + let expected = quote! 
{ + impl es_entity::EsEntity for User { + type Event = UserEvent; + type New = NewUser; + fn events_mut(&mut self) -> &mut es_entity::EntityEvents { + &mut self.events + } + fn events(&self) -> &es_entity::EntityEvents { + &self.events + } + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn test_derive_with_nested() { + let input: syn::DeriveInput = parse_quote! { + #[derive(EsEntity)] + pub struct User { + pub id: UserId, + #[es_entity(nested)] + children: Nested, + events: EntityEvents + } + }; + + let output = derive(input).unwrap(); + let expected = quote! { + impl es_entity::EsEntity for User { + type Event = UserEvent; + type New = NewUser; + fn events_mut(&mut self) -> &mut es_entity::EntityEvents { + &mut self.events + } + fn events(&self) -> &es_entity::EntityEvents { + &self.events + } + } + + impl es_entity::Parent for User { + fn new_children_mut(&mut self) -> &mut Vec<::New> { + self.children.new_entities_mut() + } + + fn inject_children(&mut self, children: impl IntoIterator) { + self.children.load(children) + } + + fn iter_persisted_children_mut( + &mut self + ) -> std::collections::hash_map::ValuesMut<'_, <::Event as EsEvent>::EntityId, ChildEntity> + { + self.children.iter_persisted_mut() + } + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/es_event_context.rs b/es-entity-macros-sqlite/src/es_event_context.rs new file mode 100644 index 00000000..2f0c6110 --- /dev/null +++ b/es-entity-macros-sqlite/src/es_event_context.rs @@ -0,0 +1,206 @@ +use proc_macro2::TokenStream as TokenStream2; +use syn::{Ident, ItemFn, Token, parse::Parse, parse::ParseStream, punctuated::Punctuated}; + +struct MacroArgs { + args: Vec, +} + +impl Parse for MacroArgs { + fn parse(input: ParseStream) -> syn::Result { + let args = Punctuated::::parse_terminated(input)?; + Ok(MacroArgs { + args: args.into_iter().collect(), + }) + } +} + +// Wrapper for the proc macro that 
converts between TokenStream types +pub fn make( + args: proc_macro::TokenStream, + input: ItemFn, +) -> darling::Result { + make_internal(args.into(), input) +} + +pub fn make_internal(args: TokenStream2, input: ItemFn) -> darling::Result { + let macro_args: MacroArgs = + syn::parse2(args).map_err(|e| darling::Error::custom(e.to_string()))?; + + let ItemFn { + attrs, + vis, + sig, + block, + } = input; + + let is_async = sig.asyncness.is_some(); + + let insert_stmts: Vec<_> = macro_args + .args + .iter() + .map(|arg| { + let arg_name = arg.to_string(); + quote::quote! { + let _ = ctx.insert(#arg_name, &#arg); + } + }) + .collect(); + + let inserts = if !insert_stmts.is_empty() { + quote::quote! { + { + let mut ctx = es_entity::context::EventContext::current(); + #(#insert_stmts)* + } + } + } else { + quote::quote! {} + }; + + let wrapped_body = if is_async { + quote::quote! { + use es_entity::context::WithEventContext; + let data = es_entity::context::EventContext::current().data(); + async { + #inserts + #block + }.with_event_context(data).await + } + } else { + quote::quote! { + let __es_event_context_guard = es_entity::context::EventContext::fork(); + #inserts + #block + } + }; + + Ok(quote::quote! { + #(#attrs)* + #vis #sig { + #wrapped_body + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use quote::quote; + use syn::parse_quote; + + #[test] + fn no_async_no_args() { + let input: ItemFn = parse_quote! { + pub fn no_async_no_args(&self, a: u32) { + unimplemented!() + } + }; + + // Create empty args + let args = TokenStream2::new(); + + let output = make_internal(args, input).unwrap(); + + let expected = quote! { + pub fn no_async_no_args(&self, a: u32) { + let __es_event_context_guard = es_entity::context::EventContext::fork(); + { + unimplemented!() + } + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn no_async_with_args() { + let input: ItemFn = parse_quote! 
{ + pub fn no_async_with_args(&self, arg_one: u32, arg_two: u64) { + unimplemented!() + } + }; + + // Create args with some parameters + let args = quote! { arg_one, arg_two }; + + let output = make_internal(args, input).unwrap(); + + let expected = quote! { + pub fn no_async_with_args(&self, arg_one: u32, arg_two: u64) { + let __es_event_context_guard = es_entity::context::EventContext::fork(); + { + let mut ctx = es_entity::context::EventContext::current(); + let _ = ctx.insert("arg_one", &arg_one); + let _ = ctx.insert("arg_two", &arg_two); + } + { + unimplemented!() + } + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn async_no_args() { + let input: ItemFn = parse_quote! { + pub async fn async_no_args(&self, a: u32) { + unimplemented!() + } + }; + + // Create empty args + let args = TokenStream2::new(); + + let output = make_internal(args, input).unwrap(); + + let expected = quote! { + pub async fn async_no_args(&self, a: u32) { + use es_entity::context::WithEventContext; + let data = es_entity::context::EventContext::current().data(); + async { + { + unimplemented!() + } + }.with_event_context(data).await + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } + + #[test] + fn async_with_args() { + let input: ItemFn = parse_quote! { + pub async fn async_with_args(&self, arg_one: u32, arg_two: u64) { + unimplemented!() + } + }; + + // Create args with some parameters + let args = quote! { arg_one, arg_two }; + + let output = make_internal(args, input).unwrap(); + + let expected = quote! 
{ + pub async fn async_with_args(&self, arg_one: u32, arg_two: u64) { + use es_entity::context::WithEventContext; + let data = es_entity::context::EventContext::current().data(); + async { + { + let mut ctx = es_entity::context::EventContext::current(); + let _ = ctx.insert("arg_one", &arg_one); + let _ = ctx.insert("arg_two", &arg_two); + } + { + unimplemented!() + } + }.with_event_context(data).await + } + }; + + assert_eq!(output.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/event.rs b/es-entity-macros-sqlite/src/event.rs new file mode 100644 index 00000000..16562780 --- /dev/null +++ b/es-entity-macros-sqlite/src/event.rs @@ -0,0 +1,111 @@ +use convert_case::{Case, Casing}; +use darling::{FromDeriveInput, ToTokens}; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +#[derive(Debug, Clone, FromDeriveInput)] +#[darling(attributes(es_event))] +pub struct EsEvent { + ident: syn::Ident, + data: darling::ast::Data, + id: syn::Type, + #[darling(default, rename = "event_context")] + event_ctx: Option, +} + +pub fn derive(ast: syn::DeriveInput) -> darling::Result { + let event = EsEvent::from_derive_input(&ast)?; + Ok(quote!(#event)) +} + +impl ToTokens for EsEvent { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.ident; + let id = &self.id; + let event_context = { + #[cfg(feature = "event-context")] + { + self.event_ctx.unwrap_or(true) + } + #[cfg(not(feature = "event-context"))] + { + self.event_ctx.unwrap_or(false) + } + }; + + let match_arms = match &self.data { + darling::ast::Data::Enum(variants) => { + let arms: Vec<_> = variants + .iter() + .map(|v| { + let variant_ident = &v.ident; + let snake_name = variant_ident.to_string().to_case(Case::Snake); + quote! { + Self::#variant_ident { .. } => #snake_name, + } + }) + .collect(); + quote! { #(#arms)* } + } + _ => panic!("EsEvent can only be derived for enums"), + }; + + tokens.append_all(quote! 
{ + impl es_entity::EsEvent for #ident { + type EntityId = #id; + + fn event_context() -> bool { + #event_context + } + + fn event_type(&self) -> &'static str { + match self { + #match_arms + } + } + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generates_event_type_match() { + let input: syn::DeriveInput = syn::parse_quote! { + #[es_event(id = "UserId")] + enum UserEvent { + Initialized { id: UserId, name: String }, + NameUpdated { name: String }, + Deactivated { reason: String }, + AccountClosed {}, + } + }; + let event = EsEvent::from_derive_input(&input).unwrap(); + let mut tokens = TokenStream::new(); + event.to_tokens(&mut tokens); + + let expected = quote! { + impl es_entity::EsEvent for UserEvent { + type EntityId = UserId; + + fn event_context() -> bool { + false + } + + fn event_type(&self) -> &'static str { + match self { + Self::Initialized { .. } => "initialized", + Self::NameUpdated { .. } => "name_updated", + Self::Deactivated { .. } => "deactivated", + Self::AccountClosed { .. 
} => "account_closed", + } + } + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/lib.rs b/es-entity-macros-sqlite/src/lib.rs new file mode 100644 index 00000000..5beace39 --- /dev/null +++ b/es-entity-macros-sqlite/src/lib.rs @@ -0,0 +1,146 @@ +#![cfg_attr(feature = "fail-on-warnings", deny(warnings))] +#![cfg_attr(feature = "fail-on-warnings", deny(clippy::all))] +#![forbid(unsafe_code)] + +mod entity; +mod es_event_context; +mod event; +mod query; +mod repo; +mod retry_on_concurrent_modification; + +use proc_macro::TokenStream; +use syn::parse_macro_input; + +#[proc_macro_derive(EsEvent, attributes(es_event))] +pub fn es_event_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + match event::derive(ast) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} + +#[proc_macro_attribute] +pub fn retry_on_concurrent_modification(args: TokenStream, input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::ItemFn); + match retry_on_concurrent_modification::make(args, ast) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} + +/// Automatically captures function arguments into the event context. +/// +/// This attribute macro wraps functions to automatically insert specified arguments +/// into the current [`EventContext`](es_entity::context::EventContext), making them +/// available for audit trails when events are persisted. 
+/// +/// # Behavior +/// +/// - **For async functions**: Uses the [`WithEventContext`](es_entity::context::WithEventContext) +/// trait to propagate context across async boundaries +/// - **For sync functions**: Uses [`EventContext::fork()`](es_entity::context::EventContext::fork) +/// to create an isolated child context +/// +/// # Syntax +/// +/// ```rust,ignore +/// #[es_event_context] // No arguments captured +/// #[es_event_context(arg1)] // Capture single argument +/// #[es_event_context(arg1, arg2)] // Capture multiple arguments +/// ``` +/// +/// # Examples +/// +/// ## Async function with argument capture +/// ```rust,ignore +/// use es_entity_macros::es_event_context; +/// +/// impl UserService { +/// #[es_event_context(user_id, operation)] +/// async fn update_user(&self, user_id: UserId, operation: &str, data: UserData) -> Result<()> { +/// // user_id and operation are automatically added to context +/// // They will be included when events are persisted +/// self.repo.update(data).await +/// } +/// } +/// ``` +/// +/// ## Sync function with context isolation +/// ```rust,ignore +/// use es_entity_macros::es_event_context; +/// +/// impl Calculator { +/// #[es_event_context(transaction_id)] +/// fn process(&mut self, transaction_id: u64, amount: i64) { +/// // transaction_id is captured in an isolated context +/// // Parent context is restored when function exits +/// self.apply_transaction(amount); +/// } +/// } +/// ``` +/// +/// ## Manual context additions +/// ```rust,ignore +/// use es_entity_macros::es_event_context; +/// use es_entity::context::EventContext; +/// +/// #[es_event_context(request_id)] +/// async fn handle_request(request_id: String, data: RequestData) { +/// // request_id is automatically captured +/// +/// // You can still manually add more context +/// let mut ctx = EventContext::current(); +/// ctx.insert("timestamp", &chrono::Utc::now()).unwrap(); +/// +/// process_data(data).await; +/// } +/// ``` +/// +/// # Context Keys 
+/// +/// Arguments are captured using their parameter names as keys. For example, +/// `user_id: UserId` will be stored with key `"user_id"` in the context. +/// +/// # See Also +/// +/// - [`EventContext`](es_entity::context::EventContext) - The context management system +/// - [`WithEventContext`](es_entity::context::WithEventContext) - Async context propagation +/// - Event Context chapter in the book for complete usage patterns +#[proc_macro_attribute] +pub fn es_event_context(args: TokenStream, input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::ItemFn); + match es_event_context::make(args, ast) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} + +#[proc_macro_derive(EsEntity, attributes(es_entity))] +pub fn es_entity_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + match entity::derive(ast) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} + +#[proc_macro_derive(EsRepo, attributes(es_repo))] +pub fn es_repo_derive(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + match repo::derive(ast) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} + +#[proc_macro] +#[doc(hidden)] +pub fn expand_es_query(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as query::QueryInput); + match query::expand(input) { + Ok(tokens) => tokens.into(), + Err(e) => e.write_errors().into(), + } +} diff --git a/es-entity-macros-sqlite/src/query/input.rs b/es-entity-macros-sqlite/src/query/input.rs new file mode 100644 index 00000000..5292fe95 --- /dev/null +++ b/es-entity-macros-sqlite/src/query/input.rs @@ -0,0 +1,194 @@ +use proc_macro2::Span; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, +}; + +pub struct QueryInput { + pub(super) tbl_prefix: Option, + pub(super) sql: String, + pub(super) sql_span: Span, + pub(super) arg_exprs: 
Vec, + pub(super) entity: Option, +} + +impl QueryInput { + pub(super) fn table_name(&self) -> darling::Result { + let query = self.sql.to_lowercase(); + let words: Vec<&str> = query.split_whitespace().collect(); + let from_pos = words.iter().position(|&word| word == "from").ok_or( + darling::Error::custom("Could not identify table name - no 'FROM' clause") + .with_span(&self.sql_span), + )?; + let table_name = words.get(from_pos + 1).ok_or( + darling::Error::custom("No word after 'FROM' clause").with_span(&self.sql_span), + )?; + let table_name = table_name.trim_end_matches(|c: char| !c.is_alphanumeric()); + Ok(table_name.to_string()) + } + + pub(super) fn table_name_without_prefix(&self) -> darling::Result { + let table_name = self.table_name()?; + if let Some(ignore_prefix) = &self.tbl_prefix + && table_name.starts_with(ignore_prefix) + { + return Ok(table_name[ignore_prefix.len() + 1..].to_string()); + } + Ok(table_name) + } + + pub(super) fn order_by(&self) -> String { + let columns = self.order_by_columns(); + if columns.is_empty() { + "i.id,".to_string() + } else { + columns.join(", ") + ", i.id," + } + } + + fn order_by_columns(&self) -> Vec { + use regex::Regex; + let re = Regex::new(r"(?i)ORDER\s+BY\s+(.+?)(?:\s+(?:LIMIT|OFFSET)|\s*;?\s*$)").unwrap(); + + if let Some(captures) = re.captures(&self.sql.to_lowercase()) + && let Some(order_by_clause) = captures.get(1) + { + return order_by_clause + .as_str() + .split(',') + .map(|s| { + let trimmed = s.trim(); + // Strip any existing alias prefix (e.g., "a.id" -> "id") + let column = if let Some(dot_pos) = trimmed.rfind('.') { + &trimmed[dot_pos + 1..] 
+ } else { + trimmed + }; + format!("i.{}", column) + }) + .filter(|s| !s.is_empty()) + .collect(); + } + + Vec::new() + } +} + +impl Parse for QueryInput { + fn parse(input: ParseStream) -> syn::Result { + let mut sql: Option<(String, Span)> = None; + let mut args: Option> = None; + let mut expect_comma = false; + let mut tbl_prefix = None; + let mut entity = None; + + while !input.is_empty() { + if expect_comma { + let _ = input.parse::()?; + } + let key: syn::Ident = input.parse()?; + + let _ = input.parse::()?; + + if key == "tbl_prefix" { + tbl_prefix = Some(input.parse::()?.value()); + } else if key == "sql" { + sql = Some(( + Punctuated::::parse_separated_nonempty(input)? + .iter() + .map(syn::LitStr::value) + .collect(), + input.span(), + )); + } else if key == "args" { + let exprs = input.parse::()?; + args = Some(exprs.elems.into_iter().collect()) + } else if key == "entity" { + entity = Some(input.parse::()?); + } else { + let message = format!("unexpected input key: {key}"); + return Err(syn::Error::new_spanned(key, message)); + } + + expect_comma = true; + } + + let (sql, sql_span) = sql.ok_or_else(|| input.error("expected `sql` key"))?; + + Ok(QueryInput { + tbl_prefix, + sql, + sql_span, + arg_exprs: args.unwrap_or_default(), + entity, + }) + } +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + #[test] + fn parse_input() { + let input: QueryInput = parse_quote!( + tbl_prefix = "ignore_prefix", + sql = "SELECT * FROM ignore_prefix_users WHERE name = $1", + args = [id] + ); + assert_eq!(input.tbl_prefix, Some("ignore_prefix".to_string())); + assert_eq!( + input.sql, + "SELECT * FROM ignore_prefix_users WHERE name = $1" + ); + assert_eq!(input.arg_exprs[0], parse_quote!(id)); + assert_eq!(input.table_name_without_prefix().unwrap(), "users"); + } + + #[test] + fn test_order_by_columns() { + let test_cases = vec![ + ( + "SELECT id FROM entities WHERE (id > $2) OR $2 IS NULL ORDER BY id LIMIT $1", + vec!["i.id"], + ), + ( + 
"select id from entities order by name asc, date desc", + vec!["i.name asc", "i.date desc"], + ), + ("SELECT TOP 10 id FROM entities Order By id", vec!["i.id"]), + ( + "select id from entities ORDER BY id offset 10", + vec!["i.id"], + ), + ("select a.id from entities a ORDER BY a.id", vec!["i.id"]), + ("SELECT id FROM entities orDer bY id;", vec!["i.id"]), + ( + "SELECT * FROM users WHERE age > 18 ORDER BY last_name, first_name DESC LIMIT 10", + vec!["i.last_name", "i.first_name desc"], + ), + ( + "SELECT * FROM products ORDER BY price ASC, stock DESC, name", + vec!["i.price asc", "i.stock desc", "i.name"], + ), + ("SELECT * FROM orders", vec![]), + ( + "SELECT * FROM orders ORDER BY orders NULLS FIRST, id", + vec!["i.orders nulls first", "i.id"], + ), + ]; + + for (sql, expected) in test_cases { + let input = QueryInput { + tbl_prefix: None, + sql: sql.to_string(), + sql_span: Span::call_site(), + arg_exprs: vec![], + entity: None, + }; + assert_eq!(input.order_by_columns(), expected, "Failed for SQL: {sql}",); + } + } +} diff --git a/es-entity-macros-sqlite/src/query/mod.rs b/es-entity-macros-sqlite/src/query/mod.rs new file mode 100644 index 00000000..3c0f8169 --- /dev/null +++ b/es-entity-macros-sqlite/src/query/mod.rs @@ -0,0 +1,259 @@ +mod input; + +use convert_case::{Case, Casing}; +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +pub use input::QueryInput; + +pub fn expand(input: QueryInput) -> darling::Result { + let query = EsQuery::from(input); + Ok(quote!(#query)) +} + +pub struct EsQuery { + input: QueryInput, +} + +impl From for EsQuery { + fn from(input: QueryInput) -> Self { + Self { input } + } +} + +/// Convert `$N` bind parameters in SQL to SQLite-style `?N`. 
+fn pg_to_sqlite_params(sql: &str) -> String { + use regex::Regex; + let re = Regex::new(r"\$(\d+)").unwrap(); + re.replace_all(sql, "?$1").to_string() +} + +/// Strip sqlx-style type annotations (`as CustomType`) from a cast +/// expression while preserving actual Rust primitive casts (`as i64`). +/// +/// In `query_as!`, `id as UserId` is a type annotation for sqlx, not a +/// Rust cast. We strip those. But `(first + 1) as i64` is a genuine +/// type conversion we must keep. +/// +/// Heuristic: if the target type is a single-segment path whose ident +/// starts with a lowercase letter, it's a primitive cast – keep it. +/// Otherwise strip the cast. +fn strip_cast(expr: &syn::Expr) -> &syn::Expr { + match expr { + syn::Expr::Cast(cast) => { + if is_primitive_cast(&cast.ty) { + expr // keep `(first + 1) as i64` + } else { + &cast.expr // strip `id as UserId` + } + } + other => other, + } +} + +fn is_primitive_cast(ty: &syn::Type) -> bool { + if let syn::Type::Path(path) = ty { + if let Some(ident) = path.path.get_ident() { + let s = ident.to_string(); + s.starts_with(|c: char| c.is_ascii_lowercase()) + } else { + false + } + } else { + false + } +} + +impl ToTokens for EsQuery { + fn to_tokens(&self, tokens: &mut TokenStream) { + let singular = pluralizer::pluralize( + &self + .input + .table_name() + .expect("Could not identify table name"), + 1, + false, + ); + let entity = if let Some(entity_ty) = &self.input.entity { + entity_ty.clone() + } else { + let singular_without_prefix = pluralizer::pluralize( + &self + .input + .table_name_without_prefix() + .expect("Could not identify table name"), + 1, + false, + ); + syn::Ident::new( + &singular_without_prefix.to_case(Case::UpperCamel), + Span::call_site(), + ) + }; + + let entity_snake = entity.to_string().to_case(Case::Snake); + let repo_types_mod = + syn::Ident::new(&format!("{entity_snake}_repo_types"), Span::call_site()); + let order_by = self.input.order_by(); + + let events_table = 
format!("{singular}_events"); + let args = &self.input.arg_exprs; + let context_arg_num = args.len() + 1; + + // Convert $N to ?N in the user-provided SQL + let user_sql = pg_to_sqlite_params(&self.input.sql); + + let query = format!( + "WITH entities AS ({}) SELECT i.id AS entity_id, e.sequence, e.event, CASE WHEN ?{} THEN e.context ELSE NULL END AS context, e.recorded_at FROM {} e JOIN entities i ON i.id = e.id ORDER BY {} e.sequence", + user_sql, context_arg_num, events_table, order_by + ); + + // Generate .bind() calls for each arg, stripping `as Type` casts + let bind_exprs: Vec<&syn::Expr> = args.iter().map(strip_cast).collect(); + + tokens.append_all(quote! { + { + use #repo_types_mod::*; + use es_entity::prelude::sqlx::Row as _; + + es_entity::EsQuery::::EsQueryFlavor, _, _>::new( + sqlx::query(#query) + #(.bind(#bind_exprs))* + .bind(<<::Entity as EsEntity>::Event>::event_context()) + .try_map(|row: es_entity::db::Row| -> Result { + Ok(Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }) + ) + } + }); + } +} + +#[cfg(test)] +mod tests { + use syn::parse_quote; + + use super::*; + + #[test] + fn query() { + let input: QueryInput = parse_quote!( + sql = "SELECT * FROM users WHERE id = $1", + args = [id as UserId] + ); + + let query = EsQuery::from(input); + let mut tokens = TokenStream::new(); + query.to_tokens(&mut tokens); + + let expected = quote! 
{ + { + use user_repo_types::*; + use es_entity::prelude::sqlx::Row as _; + + es_entity::EsQuery::::EsQueryFlavor, _, _>::new( + sqlx::query("WITH entities AS (SELECT * FROM users WHERE id = ?1) SELECT i.id AS entity_id, e.sequence, e.event, CASE WHEN ?2 THEN e.context ELSE NULL END AS context, e.recorded_at FROM user_events e JOIN entities i ON i.id = e.id ORDER BY i.id, e.sequence") + .bind(id) + .bind(<<::Entity as EsEntity>::Event>::event_context()) + .try_map(|row: es_entity::db::Row| -> Result { + Ok(Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }) + ) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn query_with_entity_ty() { + let input: QueryInput = parse_quote!( + entity = MyCustomEntity, + sql = "SELECT * FROM my_custom_table WHERE id = $1", + args = [id as MyCustomEntityId] + ); + + let query = EsQuery::from(input); + let mut tokens = TokenStream::new(); + query.to_tokens(&mut tokens); + + let expected = quote! 
{ + { + use my_custom_entity_repo_types::*; + use es_entity::prelude::sqlx::Row as _; + + es_entity::EsQuery::::EsQueryFlavor, _, _>::new( + sqlx::query("WITH entities AS (SELECT * FROM my_custom_table WHERE id = ?1) SELECT i.id AS entity_id, e.sequence, e.event, CASE WHEN ?2 THEN e.context ELSE NULL END AS context, e.recorded_at FROM my_custom_table_events e JOIN entities i ON i.id = e.id ORDER BY i.id, e.sequence") + .bind(id) + .bind(<<::Entity as EsEntity>::Event>::event_context()) + .try_map(|row: es_entity::db::Row| -> Result { + Ok(Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }) + ) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn query_with_order() { + let input: QueryInput = parse_quote!( + sql = "SELECT name, id FROM entities WHERE ((name, id) > ($3, $2)) OR $2 IS NULL ORDER BY name, id LIMIT $1", + args = [ + (first + 1) as i64, + id as Option, + name as Option + ] + ); + + let query = EsQuery::from(input); + let mut tokens = TokenStream::new(); + query.to_tokens(&mut tokens); + + let expected = quote! 
{ + { + use entity_repo_types::*; + use es_entity::prelude::sqlx::Row as _; + + es_entity::EsQuery::::EsQueryFlavor, _, _>::new( + sqlx::query("WITH entities AS (SELECT name, id FROM entities WHERE ((name, id) > (?3, ?2)) OR ?2 IS NULL ORDER BY name, id LIMIT ?1) SELECT i.id AS entity_id, e.sequence, e.event, CASE WHEN ?4 THEN e.context ELSE NULL END AS context, e.recorded_at FROM entity_events e JOIN entities i ON i.id = e.id ORDER BY i.name, i.id, i.id, e.sequence") + .bind((first + 1) as i64) + .bind(id) + .bind(name) + .bind(<<::Entity as EsEntity>::Event>::event_context()) + .try_map(|row: es_entity::db::Row| -> Result { + Ok(Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }) + ) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/begin.rs b/es-entity-macros-sqlite/src/repo/begin.rs new file mode 100644 index 00000000..c9339bd2 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/begin.rs @@ -0,0 +1,60 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::{ClockFieldInfo, RepositoryOptions}; + +pub struct Begin<'a> { + clock_field: ClockFieldInfo<'a>, +} + +impl<'a> From<&'a RepositoryOptions> for Begin<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + clock_field: opts.clock_field(), + } + } +} + +impl ToTokens for Begin<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let begin_op_body = match &self.clock_field { + ClockFieldInfo::None => { + // No clock field - always use global clock + quote! { + self.begin_op_with_clock(es_entity::clock::Clock::handle()).await + } + } + ClockFieldInfo::Optional(clock_field) => { + // Optional clock field - use if Some, fallback to global + quote! 
{ + match &self.#clock_field { + Some(clock) => self.begin_op_with_clock(clock).await, + None => self.begin_op_with_clock(es_entity::clock::Clock::handle()).await, + } + } + } + ClockFieldInfo::Required(clock_field) => { + // Required clock field - always use it + quote! { + self.begin_op_with_clock(&self.#clock_field).await + } + } + }; + + tokens.append_all(quote! { + #[inline(always)] + pub async fn begin_op(&self) -> Result, sqlx::Error> { + #begin_op_body + } + + #[inline(always)] + pub async fn begin_op_with_clock( + &self, + clock: &es_entity::clock::ClockHandle, + ) -> Result, sqlx::Error> { + es_entity::DbOp::init_with_clock(self.pool(), clock).await + } + }); + } +} diff --git a/es-entity-macros-sqlite/src/repo/combo_cursor.rs b/es-entity-macros-sqlite/src/repo/combo_cursor.rs new file mode 100644 index 00000000..e80a6dc7 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/combo_cursor.rs @@ -0,0 +1,337 @@ +use convert_case::{Case, Casing}; +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +use super::{list_by_fn::CursorStruct, options::*}; + +pub struct ComboCursor<'a> { + entity: &'a syn::Ident, + cursors: Vec>, +} + +impl<'a> ComboCursor<'a> { + pub fn new(opts: &'a RepositoryOptions, cursors: Vec>) -> Self { + Self { + entity: opts.entity(), + cursors, + } + } + + #[cfg(test)] + pub fn new_test(entity: &'a syn::Ident, cursors: Vec>) -> Self { + Self { entity, cursors } + } + + pub fn ident(&self) -> syn::Ident { + let entity_name = pluralizer::pluralize(&format!("{}", self.entity), 2, false); + syn::Ident::new( + &format!("{entity_name}_cursor").to_case(Case::UpperCamel), + Span::call_site(), + ) + } + + pub fn tag(column: &Column) -> syn::Ident { + let tag_name = format!("By{}", column.name()); + syn::Ident::new(&tag_name, Span::call_site()) + } + + pub fn variants(&self) -> TokenStream { + let variants = self + .cursors + .iter() + .map(|cursor| { + let tag = Self::tag(cursor.column); + let ident 
= cursor.ident(); + quote! { + #tag(#ident), + } + }) + .collect::(); + + quote! { + #variants + } + } + + pub fn trait_impls(&self) -> TokenStream { + let self_ident = self.ident(); + let trait_impls = self + .cursors + .iter() + .map(|cursor| { + let tag = + syn::Ident::new(&format!("By{}", cursor.column.name()), Span::call_site()); + let ident = cursor.ident(); + quote! { + impl From<#ident> for #self_ident { + fn from(cursor: #ident) -> Self { + Self::#tag(cursor) + } + } + + impl TryFrom<#self_ident> for #ident { + type Error = es_entity::CursorDestructureError; + + fn try_from(cursor: #self_ident) -> Result { + match cursor { + #self_ident::#tag(cursor) => Ok(cursor), + _ => Err(es_entity::CursorDestructureError::from((stringify!(#self_ident), stringify!(#ident)))), + } + } + } + } + }) + .collect::(); + + quote! { + #trait_impls + } + } + + pub fn sort_by_name(&self) -> syn::Ident { + let entity_name = pluralizer::pluralize(&format!("{}", self.entity), 2, false); + syn::Ident::new( + &format!("{entity_name}_sort_by").to_case(Case::UpperCamel), + Span::call_site(), + ) + } + + pub fn sort_by(&self) -> TokenStream { + let mut default = true; + let variants = self.cursors.iter().map(|cursor| { + let name = syn::Ident::new( + &format!("{}", cursor.column.name()).to_case(Case::UpperCamel), + Span::call_site(), + ); + if default { + default = false; + quote! { + #[default] + #name + } + } else { + quote! { + #name + } + } + }); + let name = self.sort_by_name(); + #[cfg(feature = "graphql")] + let mod_name = syn::Ident::new(&format!("{name}").to_case(Case::Snake), Span::call_site()); + #[cfg(feature = "graphql")] + let sort_by_enum = quote! { + mod #mod_name { + #[derive(es_entity::graphql::async_graphql::Enum, Default, Debug, Clone, Copy, PartialEq, Eq)] + #[graphql(crate = "es_entity::graphql::async_graphql")] + pub enum #name { + #(#variants),* + } + } + pub use #mod_name::#name; + }; + #[cfg(not(feature = "graphql"))] + let sort_by_enum = quote! 
{ + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] + pub enum #name { + #(#variants),* + } + }; + quote! { + #sort_by_enum + } + } + + #[cfg(feature = "graphql")] + pub fn gql_cursor(&self) -> TokenStream { + let ident = self.ident(); + quote! { + impl es_entity::graphql::async_graphql::connection::CursorType for #ident { + type Error = String; + + fn encode_cursor(&self) -> String { + use es_entity::graphql::base64::{engine::general_purpose, Engine as _}; + let json = es_entity::prelude::serde_json::to_string(&self).expect("could not serialize token"); + general_purpose::STANDARD_NO_PAD.encode(json.as_bytes()) + } + + fn decode_cursor(s: &str) -> Result { + use es_entity::graphql::base64::{engine::general_purpose, Engine as _}; + let bytes = general_purpose::STANDARD_NO_PAD + .decode(s.as_bytes()) + .map_err(|e| e.to_string())?; + let json = String::from_utf8(bytes).map_err(|e| e.to_string())?; + es_entity::prelude::serde_json::from_str(&json).map_err(|e| e.to_string()) + } + } + } + } +} + +impl ToTokens for ComboCursor<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = self.ident(); + let variants = self.variants(); + let trait_impls = self.trait_impls(); + + tokens.append_all(quote! 
{ + #[derive(Debug, serde::Serialize, serde::Deserialize)] + #[allow(clippy::enum_variant_names)] + #[serde(tag = "type")] + pub enum #ident { + #variants + } + + #trait_impls + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::repo::list_by_fn::CursorStruct; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn combo_cursor_generation() { + let entity = Ident::new("User", Span::call_site()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + let id = syn::Ident::new("UserId", Span::call_site()); + + let id_column = Column::for_id(syn::parse_str("UserId").unwrap()); + let name_column = Column::new( + syn::Ident::new("name", proc_macro2::Span::call_site()), + syn::parse_str("String").unwrap(), + ); + + let id_cursor = CursorStruct { + column: &id_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let name_cursor = CursorStruct { + column: &name_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let cursors = vec![id_cursor, name_cursor]; + + let combo_cursor = ComboCursor { + entity: &entity, + cursors, + }; + + let mut tokens = TokenStream::new(); + combo_cursor.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[derive(Debug, serde::Serialize, serde::Deserialize)] + #[allow(clippy::enum_variant_names)] + #[serde(tag = "type")] + pub enum UsersCursor { + Byid(UsersByIdCursor), + Byname(UsersByNameCursor), + } + + impl From for UsersCursor { + fn from(cursor: UsersByIdCursor) -> Self { + Self::Byid(cursor) + } + } + + impl TryFrom for UsersByIdCursor { + type Error = es_entity::CursorDestructureError; + + fn try_from(cursor: UsersCursor) -> Result { + match cursor { + UsersCursor::Byid(cursor) => Ok(cursor), + _ => Err(es_entity::CursorDestructureError::from((stringify!(UsersCursor), stringify!(UsersByIdCursor)))), + } + } + } + impl From for UsersCursor { + fn from(cursor: UsersByNameCursor) -> Self { + Self::Byname(cursor) + } + } + + impl TryFrom for UsersByNameCursor { + type Error = es_entity::CursorDestructureError; + + fn try_from(cursor: UsersCursor) -> Result { + match cursor { + UsersCursor::Byname(cursor) => Ok(cursor), + _ => Err(es_entity::CursorDestructureError::from((stringify!(UsersCursor), stringify!(UsersByNameCursor)))), + } + } + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn combo_cursor_sort_by_generation() { + let entity = Ident::new("Order", Span::call_site()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + let id = syn::Ident::new("OrderId", Span::call_site()); + + let id_column = Column::for_id(syn::parse_str("OrderId").unwrap()); + let status_column = Column::new( + syn::Ident::new("status", proc_macro2::Span::call_site()), + syn::parse_str("String").unwrap(), + ); + let created_at_column = Column::new( + syn::Ident::new("created_at", proc_macro2::Span::call_site()), + syn::parse_str("chrono::DateTime").unwrap(), + ); + + let id_cursor = CursorStruct { + column: &id_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let status_cursor = CursorStruct { + column: &status_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let created_at_cursor 
= CursorStruct { + column: &created_at_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let cursors = vec![id_cursor, status_cursor, created_at_cursor]; + + let combo_cursor = ComboCursor { + entity: &entity, + cursors, + }; + + let sort_by_tokens = combo_cursor.sort_by(); + + let expected = quote! { + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq)] + pub enum OrdersSortBy { + #[default] + Id, + Status, + CreatedAt + } + }; + + assert_eq!(sort_by_tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/create_all_fn.rs b/es-entity-macros-sqlite/src/repo/create_all_fn.rs new file mode 100644 index 00000000..a61a6bb9 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/create_all_fn.rs @@ -0,0 +1,286 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct CreateAllFn<'a> { + entity: &'a syn::Ident, + table_name: &'a str, + columns: &'a Columns, + create_error: syn::Ident, + nested_fn_names: Vec, + post_hydrate_error: Option<&'a syn::Type>, + post_persist_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> From<&'a RepositoryOptions> for CreateAllFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + table_name: opts.table_name(), + entity: opts.entity(), + create_error: opts.create_error(), + nested_fn_names: opts + .all_nested() + .map(|f| f.create_nested_fn_name()) + .collect(), + columns: &opts.columns, + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + post_persist_error: opts.post_persist_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for CreateAllFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let create_error = &self.create_error; + + let nested = self.nested_fn_names.iter().map(|f| { + 
quote! { + self.#f(op, &mut entity).await?; + } + }); + let maybe_mut_entity = if self.nested_fn_names.is_empty() { + quote! { entity } + } else { + quote! { mut entity } + }; + + let table_name = self.table_name; + + let column_names = self.columns.insert_column_names(); + let placeholders = self.columns.insert_placeholders(0); + let args = self.columns.create_query_args(); + + let query = format!( + "INSERT INTO {} ({}, created_at) VALUES ({}, COALESCE(?{}, datetime('now')))", + table_name, + column_names.join(", "), + placeholders, + column_names.len() + 1, + ); + + let assignments = self + .columns + .variable_assignments_for_create(syn::parse_quote! { new_entity }); + + #[cfg(feature = "instrument")] + let (instrument_attr, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.create_all", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = new_entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, error_recording) = (quote! {}, quote! {}); + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + self.execute_post_hydrate_hook(&entity).map_err(#create_error::PostHydrateError)?; + } + } else { + quote! {} + }; + + let post_persist_check = if self.post_persist_error.is_some() { + quote! { + self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#create_error::PostPersistHookError)?; + } + } else { + quote! 
{} + }; + + tokens.append_all(quote! { + pub async fn create_all( + &self, + new_entities: Vec<<#entity as es_entity::EsEntity>::New> + ) -> Result, #create_error> { + let mut op = self.begin_op().await?; + let res = self.create_all_in_op(&mut op, new_entities).await?; + op.commit().await?; + Ok(res) + } + + #instrument_attr + pub async fn create_all_in_op( + &self, + op: &mut OP, + new_entities: Vec<<#entity as es_entity::EsEntity>::New> + ) -> Result, #create_error> + where + OP: es_entity::AtomicOperation + { + let __result: Result, #create_error> = async { + let mut res = Vec::new(); + if new_entities.is_empty() { + return Ok(res); + } + + let now = op.maybe_now(); + for new_entity in new_entities.iter() { + #assignments + + sqlx::query(#query) + #(#args)* + .bind(now) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + #create_error::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => #create_error::Sqlx(e), + })?; + } + + let mut all_events: Vec::Event>> = new_entities.into_iter().map(Self::convert_new).collect(); + let mut n_persisted = Self::extract_concurrent_modification( + self.persist_events_batch(op, &mut all_events).await, + #create_error::ConcurrentModification, + )?; + + for events in all_events.into_iter() { + let n_events = n_persisted.remove(events.id()).expect("n_events exists"); + let #maybe_mut_entity = Self::hydrate_entity(events)?; + + #(#nested)* + + #post_hydrate_check + #post_persist_check + res.push(entity); + } + + Ok(res) + }.await; + + #error_recording + __result + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn create_all_fn() { + let entity = Ident::new("Entity", Span::call_site()); + let create_error = syn::Ident::new("EntityCreateError", Span::call_site()); 
+ + use darling::FromMeta; + let input: syn::Meta = syn::parse_quote!(columns(id = "EntityId", name = "String",)); + let columns = Columns::from_meta(&input).expect("Failed to parse Fields"); + + let create_fn = CreateAllFn { + table_name: "entities", + entity: &entity, + create_error, + columns: &columns, + nested_fn_names: Vec::new(), + post_hydrate_error: None, + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + create_fn.to_tokens(&mut tokens); + + let mut tokens = TokenStream::new(); + create_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn create_all( + &self, + new_entities: Vec<::New> + ) -> Result, EntityCreateError> { + let mut op = self.begin_op().await?; + let res = self.create_all_in_op(&mut op, new_entities).await?; + op.commit().await?; + Ok(res) + } + + pub async fn create_all_in_op( + &self, + op: &mut OP, + new_entities: Vec<::New> + ) -> Result, EntityCreateError> + where + OP: es_entity::AtomicOperation + { + let __result: Result, EntityCreateError> = async { + let mut res = Vec::new(); + if new_entities.is_empty() { + return Ok(res); + } + + let now = op.maybe_now(); + for new_entity in new_entities.iter() { + let id = &new_entity.id; + let name = &new_entity.name; + + sqlx::query("INSERT INTO entities (id, name, created_at) VALUES (?1, ?2, COALESCE(?3, datetime('now')))") + .bind(id) + .bind(name) + .bind(now) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityCreateError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityCreateError::Sqlx(e), + })?; + } + + let mut all_events: Vec::Event>> = new_entities.into_iter().map(Self::convert_new).collect(); + let mut n_persisted = 
Self::extract_concurrent_modification( + self.persist_events_batch(op, &mut all_events).await, + EntityCreateError::ConcurrentModification, + )?; + + for events in all_events.into_iter() { + let n_events = n_persisted.remove(events.id()).expect("n_events exists"); + let entity = Self::hydrate_entity(events)?; + + res.push(entity); + } + + Ok(res) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/create_fn.rs b/es-entity-macros-sqlite/src/repo/create_fn.rs new file mode 100644 index 00000000..45a384cf --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/create_fn.rs @@ -0,0 +1,393 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct CreateFn<'a> { + entity: &'a syn::Ident, + table_name: &'a str, + columns: &'a Columns, + create_error: syn::Ident, + nested_fn_names: Vec, + post_hydrate_error: Option<&'a syn::Type>, + post_persist_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> From<&'a RepositoryOptions> for CreateFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + table_name: opts.table_name(), + entity: opts.entity(), + create_error: opts.create_error(), + nested_fn_names: opts + .all_nested() + .map(|f| f.create_nested_fn_name()) + .collect(), + columns: &opts.columns, + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + post_persist_error: opts.post_persist_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for CreateFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let create_error = &self.create_error; + + let nested = self.nested_fn_names.iter().map(|f| { + quote! 
{ + self.#f(op, &mut entity).await?; + } + }); + let maybe_mut_entity = if self.nested_fn_names.is_empty() { + quote! { entity } + } else { + quote! { mut entity } + }; + let assignments = self + .columns + .variable_assignments_for_create(syn::parse_quote! { new_entity }); + + let table_name = self.table_name; + + let column_names = self.columns.insert_column_names(); + let placeholders = self.columns.insert_placeholders(0); + let args = self.columns.create_query_args(); + + let query = format!( + "INSERT INTO {} ({}, created_at) VALUES ({}, COALESCE(?{}, datetime('now')))", + table_name, + column_names.join(", "), + placeholders, + column_names.len() + 1, + ); + + #[cfg(feature = "instrument")] + let (instrument_attr, record_id, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.create", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, id = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + tracing::Span::current().record("id", tracing::field::debug(&id)); + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, record_id, error_recording) = (quote! {}, quote! {}, quote! {}); + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + self.execute_post_hydrate_hook(&entity).map_err(#create_error::PostHydrateError)?; + } + } else { + quote! {} + }; + + let post_persist_check = if self.post_persist_error.is_some() { + quote! 
{ + self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#create_error::PostPersistHookError)?; + } + } else { + quote! {} + }; + + tokens.append_all(quote! { + #[inline(always)] + fn convert_new(item: Entity) -> es_entity::EntityEvents + where + Entity: es_entity::IntoEvents, + Event: es_entity::EsEvent, + { + item.into_events() + } + + #[inline(always)] + fn hydrate_entity(events: es_entity::EntityEvents) -> Result + where + Entity: es_entity::TryFromEvents, + Event: es_entity::EsEvent, + { + Entity::try_from_events(events) + } + + pub async fn create( + &self, + new_entity: <#entity as es_entity::EsEntity>::New + ) -> Result<#entity, #create_error> { + let mut op = self.begin_op().await?; + let res = self.create_in_op(&mut op, new_entity).await?; + op.commit().await?; + Ok(res) + } + + #instrument_attr + pub async fn create_in_op( + &self, + op: &mut OP, + new_entity: <#entity as es_entity::EsEntity>::New + ) -> Result<#entity, #create_error> + where + OP: es_entity::AtomicOperation + { + let __result: Result<#entity, #create_error> = async { + #assignments + #record_id + + sqlx::query(#query) + #(#args)* + .bind(op.maybe_now()) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + #create_error::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => #create_error::Sqlx(e), + })?; + + let mut events = Self::convert_new(new_entity); + let n_events = Self::extract_concurrent_modification( + self.persist_events(op, &mut events).await, + #create_error::ConcurrentModification, + )?; + let #maybe_mut_entity = Self::hydrate_entity(events)?; + + #(#nested)* + + #post_hydrate_check + #post_persist_check + Ok(entity) + }.await; + + #error_recording + __result + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use 
proc_macro2::Span; + use syn::Ident; + + #[test] + fn create_fn() { + let entity = Ident::new("Entity", Span::call_site()); + let create_error = syn::Ident::new("EntityCreateError", Span::call_site()); + let id = Ident::new("EntityId", Span::call_site()); + let mut columns = Columns::default(); + columns.set_id_column(&id); + + let create_fn = CreateFn { + table_name: "entities", + entity: &entity, + create_error, + columns: &columns, + nested_fn_names: Vec::new(), + post_hydrate_error: None, + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + create_fn.to_tokens(&mut tokens); + + let expected = quote! { + #[inline(always)] + fn convert_new(item: Entity) -> es_entity::EntityEvents + where + Entity: es_entity::IntoEvents, + Event: es_entity::EsEvent, + { + item.into_events() + } + + #[inline(always)] + fn hydrate_entity(events: es_entity::EntityEvents) -> Result + where + Entity: es_entity::TryFromEvents, + Event: es_entity::EsEvent, + { + Entity::try_from_events(events) + } + + pub async fn create( + &self, + new_entity: ::New + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.create_in_op(&mut op, new_entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn create_in_op( + &self, + op: &mut OP, + new_entity: ::New + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + let id = &new_entity.id; + + sqlx::query("INSERT INTO entities (id, created_at) VALUES (?1, COALESCE(?2, datetime('now')))") + .bind(id) + .bind(op.maybe_now()) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityCreateError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityCreateError::Sqlx(e), + })?; + + let 
mut events = Self::convert_new(new_entity); + let n_events = Self::extract_concurrent_modification( + self.persist_events(op, &mut events).await, + EntityCreateError::ConcurrentModification, + )?; + let entity = Self::hydrate_entity(events)?; + + Ok(entity) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn create_fn_with_columns() { + let entity = Ident::new("Entity", Span::call_site()); + let create_error = syn::Ident::new("EntityCreateError", Span::call_site()); + + use darling::FromMeta; + let input: syn::Meta = syn::parse_quote!(columns( + id = "EntityId", + name(ty = "String", create(accessor = "name()")) + )); + let columns = Columns::from_meta(&input).expect("Failed to parse Fields"); + + let create_fn = CreateFn { + table_name: "entities", + entity: &entity, + create_error, + columns: &columns, + nested_fn_names: Vec::new(), + post_hydrate_error: None, + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + create_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[inline(always)] + fn convert_new(item: Entity) -> es_entity::EntityEvents + where + Entity: es_entity::IntoEvents, + Event: es_entity::EsEvent, + { + item.into_events() + } + + #[inline(always)] + fn hydrate_entity(events: es_entity::EntityEvents) -> Result + where + Entity: es_entity::TryFromEvents, + Event: es_entity::EsEvent, + { + Entity::try_from_events(events) + } + + pub async fn create( + &self, + new_entity: ::New + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.create_in_op(&mut op, new_entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn create_in_op( + &self, + op: &mut OP, + new_entity: ::New + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + let id = &new_entity.id; + let name = &new_entity.name(); + + sqlx::query("INSERT INTO entities (id, name, created_at) VALUES (?1, ?2, COALESCE(?3, datetime('now')))") + .bind(id) + .bind(name) + .bind(op.maybe_now()) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityCreateError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityCreateError::Sqlx(e), + })?; + + let mut events = Self::convert_new(new_entity); + let n_events = Self::extract_concurrent_modification( + self.persist_events(op, &mut events).await, + EntityCreateError::ConcurrentModification, + )?; + let entity = Self::hydrate_entity(events)?; + + Ok(entity) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/delete_fn.rs b/es-entity-macros-sqlite/src/repo/delete_fn.rs new file mode 100644 index 00000000..9ecd98df --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/delete_fn.rs @@ -0,0 +1,330 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use 
quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct DeleteFn<'a> { + modify_error: syn::Ident, + entity: &'a syn::Ident, + table_name: &'a str, + columns: &'a Columns, + delete_option: &'a DeleteOption, + post_persist_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> DeleteFn<'a> { + pub fn from(opts: &'a RepositoryOptions) -> Self { + Self { + entity: opts.entity(), + modify_error: opts.modify_error(), + columns: &opts.columns, + table_name: opts.table_name(), + delete_option: &opts.delete, + post_persist_error: opts.post_persist_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for DeleteFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + if !self.delete_option.is_soft() { + return; + } + + let entity = self.entity; + let modify_error = &self.modify_error; + + let assignments = self + .columns + .variable_assignments_for_update(syn::parse_quote! { entity }); + let column_updates = self.columns.sql_updates(); + let query = format!( + "UPDATE {} SET {}{}deleted = TRUE WHERE id = ?1", + self.table_name, + column_updates, + if column_updates.is_empty() { "" } else { ", " } + ); + let args = self.columns.update_query_args(); + + #[cfg(feature = "instrument")] + let (instrument_attr, record_id, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.delete", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, id = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + tracing::Span::current().record("id", tracing::field::debug(&entity.id)); + }, + quote! 
{ + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, record_id, error_recording) = (quote! {}, quote! {}, quote! {}); + + let post_persist_check = if self.post_persist_error.is_some() { + quote! { + self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#modify_error::PostPersistHookError)?; + } + } else { + quote! {} + }; + + tokens.append_all(quote! { + pub async fn delete( + &self, + entity: #entity + ) -> Result<(), #modify_error> { + let mut op = self.begin_op().await?; + let res = self.delete_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + #instrument_attr + pub async fn delete_in_op(&self, + op: &mut OP, + mut entity: #entity + ) -> Result<(), #modify_error> + where + OP: es_entity::AtomicOperation + { + let __result: Result<(), #modify_error> = async { + #assignments + #record_id + + sqlx::query(#query) + #(#args)* + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + #modify_error::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => #modify_error::Sqlx(e), + })?; + + let new_events = { + let events = Self::extract_events(&mut entity); + events.any_new() + }; + + if new_events { + let n_events = { + let events = Self::extract_events(&mut entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + #modify_error::ConcurrentModification, + )? 
+ }; + + #post_persist_check + } + + Ok(()) + }.await; + + #error_recording + __result + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn delete_fn() { + let id = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let mut columns = Columns::default(); + columns.set_id_column(&id); + + let delete_fn = DeleteFn { + entity: &entity, + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + table_name: "entities", + columns: &columns, + delete_option: &DeleteOption::Soft, + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + delete_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn delete( + &self, + entity: Entity + ) -> Result<(), EntityModifyError> { + let mut op = self.begin_op().await?; + let res = self.delete_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn delete_in_op( + &self, + op: &mut OP, + mut entity: Entity + ) -> Result<(), EntityModifyError> + where + OP: es_entity::AtomicOperation + { + let __result: Result<(), EntityModifyError> = async { + let id = &entity.id; + + sqlx::query("UPDATE entities SET deleted = TRUE WHERE id = ?1") + .bind(id) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityModifyError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityModifyError::Sqlx(e), + })?; + + let new_events = { + let events = Self::extract_events(&mut entity); + events.any_new() + }; + + if new_events { + let n_events = { + let events = Self::extract_events(&mut entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + 
EntityModifyError::ConcurrentModification, + )? + }; + } + + Ok(()) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn delete_fn_with_update_columns() { + let id = syn::parse_str("EntityId").unwrap(); + let entity = Ident::new("Entity", Span::call_site()); + + let columns = Columns::new( + &id, + [Column::new( + Ident::new("name", Span::call_site()), + syn::parse_str("String").unwrap(), + )], + ); + + let delete_fn = DeleteFn { + entity: &entity, + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + table_name: "entities", + columns: &columns, + delete_option: &DeleteOption::Soft, + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + delete_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn delete( + &self, + entity: Entity + ) -> Result<(), EntityModifyError> { + let mut op = self.begin_op().await?; + let res = self.delete_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn delete_in_op( + &self, + op: &mut OP, + mut entity: Entity + ) -> Result<(), EntityModifyError> + where + OP: es_entity::AtomicOperation + { + let __result: Result<(), EntityModifyError> = async { + let id = &entity.id; + let name = &entity.name; + + sqlx::query("UPDATE entities SET name = ?2, deleted = TRUE WHERE id = ?1") + .bind(id) + .bind(name) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityModifyError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityModifyError::Sqlx(e), + })?; + + let new_events = { + let events = Self::extract_events(&mut entity); + events.any_new() + }; + + if new_events { + let n_events = { + let events = 
Self::extract_events(&mut entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + EntityModifyError::ConcurrentModification, + )? + }; + } + + Ok(()) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/error_types.rs b/es-entity-macros-sqlite/src/repo/error_types.rs new file mode 100644 index 00000000..0d6c4bc2 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/error_types.rs @@ -0,0 +1,1349 @@ +use convert_case::{Case, Casing}; +use proc_macro2::{Span, TokenStream}; +use quote::{ToTokens, quote}; + +use super::options::{PostHydrateHookConfig, PostPersistHookConfig, RepositoryOptions}; + +pub struct ErrorTypes<'a> { + entity: &'a syn::Ident, + column_enum: syn::Ident, + create_error: syn::Ident, + modify_error: syn::Ident, + find_error: syn::Ident, + query_error: syn::Ident, + column_variants: Vec, + nested: Vec, + post_hydrate_hook: &'a Option, + post_persist_hook: &'a Option, +} + +struct ColumnVariant { + variant_name: syn::Ident, + column_name: String, + constraint_names: Vec, +} + +struct NestedErrorInfo { + child_repo_ty: syn::Type, + variant_name: syn::Ident, + /// When set, error types are referenced by convention-based concrete names + /// (e.g., `FooCreateError`) instead of associated type projections + /// (`::CreateError`). This avoids generic params leaking + /// into module-level error enums. + nested_entity: Option, +} + +impl NestedErrorInfo { + fn create_error_ty(&self) -> TokenStream { + if let Some(entity) = &self.nested_entity { + let error_ident = syn::Ident::new(&format!("{entity}CreateError"), Span::call_site()); + quote! { #error_ident } + } else { + let child_repo_ty = &self.child_repo_ty; + quote! 
{ <#child_repo_ty as es_entity::EsRepo>::CreateError } + } + } + + fn modify_error_ty(&self) -> TokenStream { + if let Some(entity) = &self.nested_entity { + let error_ident = syn::Ident::new(&format!("{entity}ModifyError"), Span::call_site()); + quote! { #error_ident } + } else { + let child_repo_ty = &self.child_repo_ty; + quote! { <#child_repo_ty as es_entity::EsRepo>::ModifyError } + } + } +} + +impl<'a> ErrorTypes<'a> { + pub fn new(opts: &'a RepositoryOptions) -> Self { + let table_name = opts.table_name(); + let column_variants: Vec = opts + .columns + .column_enum_columns() + .map(|col| { + let col_name = col.name().to_string(); + let variant_name = + syn::Ident::new(&col_name.to_case(Case::UpperCamel), Span::call_site()); + let mut constraint_names = vec![format!("{table_name}_{col_name}_key")]; + if col.is_id() { + constraint_names.push(format!("{table_name}_pkey")); + } + if let Some(custom) = col.custom_constraint() { + constraint_names.push(custom.to_string()); + } + ColumnVariant { + variant_name, + column_name: col_name, + constraint_names, + } + }) + .collect(); + + let type_param_idents: Vec<&syn::Ident> = + opts.generics.type_params().map(|p| &p.ident).collect(); + + let nested: Vec = opts + .all_nested() + .map(|f| { + let nested_entity = f.entity.clone().or_else(|| { + // Auto-derive entity name when the nested repo type uses parent generics. + // Conventions tried in order: + // 1. Strip "Repo" suffix: `ObligationRepo` → "Obligation" + // 2. Singularize: `OrderItems` → "OrderItem" + // Override with `#[es_repo(nested, entity = "...")]` if neither matches. 
+ if !type_param_idents.is_empty() + && type_uses_any_generic(&f.ty, &type_param_idents) + { + derive_entity_from_repo_type(&f.ty) + } else { + None + } + }); + NestedErrorInfo { + child_repo_ty: f.ty.clone(), + variant_name: f.nested_variant_name(), + nested_entity, + } + }) + .collect(); + + Self { + entity: opts.entity(), + column_enum: opts.column_enum(), + create_error: opts.create_error(), + modify_error: opts.modify_error(), + find_error: opts.find_error(), + query_error: opts.query_error(), + column_variants, + nested, + post_hydrate_hook: &opts.post_hydrate_hook, + post_persist_hook: &opts.post_persist_hook, + } + } + + pub fn generate(&self) -> TokenStream { + let column_enum = self.generate_column_enum(); + let create_error = self.generate_create_error(); + let modify_error = self.generate_modify_error(); + let find_error = self.generate_find_error(); + let query_error = self.generate_query_error(); + + quote! { + #column_enum + #create_error + #modify_error + #find_error + #query_error + } + } + + pub fn generate_map_constraint_fn(&self) -> TokenStream { + self.generate_map_constraint_column() + } + + fn generate_column_enum(&self) -> TokenStream { + let column_enum = &self.column_enum; + let variants: Vec<_> = self + .column_variants + .iter() + .map(|v| &v.variant_name) + .collect(); + let display_arms: Vec<_> = self + .column_variants + .iter() + .map(|v| { + let variant = &v.variant_name; + let name = &v.column_name; + quote! { Self::#variant => write!(f, #name), } + }) + .collect(); + + quote! 
{ + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum #column_enum { + #(#variants,)* + } + + impl std::fmt::Display for #column_enum { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #(#display_arms)* + } + } + } + } + } + + fn generate_map_constraint_column(&self) -> TokenStream { + let column_enum = &self.column_enum; + let match_arms: Vec<_> = self + .column_variants + .iter() + .flat_map(|v| { + let variant = &v.variant_name; + v.constraint_names.iter().map(move |name| { + quote! { Some(#name) => Some(#column_enum::#variant), } + }) + }) + .collect(); + + quote! { + #[inline(always)] + fn map_constraint_column(constraint: Option<&str>) -> Option<#column_enum> { + match constraint { + #(#match_arms)* + _ => None, + } + } + } + } + + fn generate_create_error(&self) -> TokenStream { + let create_error = &self.create_error; + let column_enum = &self.column_enum; + let entity = self.entity; + + // Nested child variants + let nested_variants: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + let child_error_ty = n.create_error_ty(); + quote! { #variant(#child_error_ty), } + }) + .collect(); + let nested_display_arms: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! { Self::#variant(e) => write!(f, "{}: {}", stringify!(#variant), e), } + }) + .collect(); + let nested_source_arms: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! { Self::#variant(e) => Some(e), } + }) + .collect(); + let nested_from_impls: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + let child_error_ty = n.create_error_ty(); + quote! { + impl From<#child_error_ty> for #create_error { + fn from(e: #child_error_ty) -> Self { + Self::#variant(e) + } + } + } + }) + .collect(); + let nested_cm_checks: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! 
{ Self::#variant(e) => e.was_concurrent_modification(), } + }) + .collect(); + let nested_wd_checks: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! { Self::#variant(e) => e.was_duplicate(), } + }) + .collect(); + let nested_dv_checks: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! { Self::#variant(e) => e.duplicate_value(), } + }) + .collect(); + let nested_ph_checks: Vec<_> = self + .nested + .iter() + .map(|n| { + let variant = &n.variant_name; + quote! { Self::#variant(e) => e.was_post_hydrate_error(), } + }) + .collect(); + let create_ph_self_check = if self.post_hydrate_hook.is_some() { + quote! { Self::PostHydrateError(..) => true, } + } else { + quote! {} + }; + + let entity_name = entity.to_string(); + + let (ph_variant, ph_display_arm, ph_source_arm) = if let Some(config) = + &self.post_hydrate_hook + { + let error_ty = &config.error; + ( + quote! { PostHydrateError(#error_ty), }, + quote! { Self::PostHydrateError(e) => write!(f, "{}CreateError - PostHydrateError: {}", #entity_name, e), }, + quote! { Self::PostHydrateError(e) => Some(e), }, + ) + } else { + (quote! {}, quote! {}, quote! {}) + }; + + let (pp_variant, pp_display_arm, pp_source_arm) = if let Some(config) = + &self.post_persist_hook + { + let error_ty = &config.error; + ( + quote! { PostPersistHookError(#error_ty), }, + quote! { Self::PostPersistHookError(e) => write!(f, "{}CreateError - PostPersistHookError: {}", #entity_name, e), }, + quote! { Self::PostPersistHookError(e) => Some(e), }, + ) + } else { + (quote! {}, quote! {}, quote! {}) + }; + + quote! 
{ + #[derive(Debug)] + pub enum #create_error { + Sqlx(sqlx::Error), + ConstraintViolation { column: Option<#column_enum>, value: Option, inner: sqlx::Error }, + ConcurrentModification, + HydrationError(es_entity::EntityHydrationError), + #pp_variant + #ph_variant + #(#nested_variants)* + } + + impl std::fmt::Display for #create_error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Sqlx(e) => write!(f, "{}CreateError - Sqlx: {}", #entity_name, e), + Self::ConstraintViolation { column, value, inner } => write!(f, "{}CreateError - ConstraintViolation({:?}, {:?}): {}", #entity_name, column, value, inner), + Self::ConcurrentModification => write!(f, "{}CreateError - ConcurrentModification", #entity_name), + Self::HydrationError(e) => write!(f, "{}CreateError - HydrationError: {}", #entity_name, e), + #pp_display_arm + #ph_display_arm + #(#nested_display_arms)* + } + } + } + + impl std::error::Error for #create_error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Sqlx(e) => Some(e), + Self::ConstraintViolation { inner, .. } => Some(inner), + Self::ConcurrentModification => None, + Self::HydrationError(e) => Some(e), + #pp_source_arm + #ph_source_arm + #(#nested_source_arms)* + } + } + } + + impl From for #create_error { + fn from(e: sqlx::Error) -> Self { + Self::Sqlx(e) + } + } + + impl From for #create_error { + fn from(e: es_entity::EntityHydrationError) -> Self { + Self::HydrationError(e) + } + } + + #(#nested_from_impls)* + + impl #create_error { + pub fn was_concurrent_modification(&self) -> bool { + match self { + Self::ConcurrentModification => true, + #(#nested_cm_checks)* + _ => false, + } + } + + pub fn was_duplicate(&self) -> bool { + match self { + Self::ConstraintViolation { .. } => true, + #(#nested_wd_checks)* + _ => false, + } + } + + pub fn was_duplicate_by(&self, column: #column_enum) -> bool { + matches!(self, Self::ConstraintViolation { column: Some(c), .. 
} if *c == column) + } + + pub fn duplicate_value(&self) -> Option<&str> { + match self { + Self::ConstraintViolation { value: Some(v), .. } => Some(v.as_str()), + #(#nested_dv_checks)* + _ => None, + } + } + + pub fn was_post_hydrate_error(&self) -> bool { + match self { + #create_ph_self_check + #(#nested_ph_checks)* + _ => false, + } + } + } + } + } + + fn generate_modify_error(&self) -> TokenStream { + let modify_error = &self.modify_error; + let column_enum = &self.column_enum; + let entity = self.entity; + + // Nested variants: both Modify and Create for each child + let nested_variants: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + let child_modify_ty = n.modify_error_ty(); + let child_create_ty = n.create_error_ty(); + vec![ + quote! { #modify_variant(#child_modify_ty), }, + quote! { #create_variant(#child_create_ty), }, + ] + }) + .collect(); + let nested_display_arms: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = syn::Ident::new( + &format!("{}Modify", n.variant_name), + Span::call_site(), + ); + let create_variant = syn::Ident::new( + &format!("{}Create", n.variant_name), + Span::call_site(), + ); + vec![ + quote! { Self::#modify_variant(e) => write!(f, "{}: {}", stringify!(#modify_variant), e), }, + quote! { Self::#create_variant(e) => write!(f, "{}: {}", stringify!(#create_variant), e), }, + ] + }) + .collect(); + let nested_source_arms: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + vec![ + quote! { Self::#modify_variant(e) => Some(e), }, + quote! 
{ Self::#create_variant(e) => Some(e), }, + ] + }) + .collect(); + let nested_cm_checks: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + vec![ + quote! { Self::#modify_variant(e) => e.was_concurrent_modification(), }, + quote! { Self::#create_variant(e) => e.was_concurrent_modification(), }, + ] + }) + .collect(); + + let modify_nested_wd_checks: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + vec![ + quote! { Self::#modify_variant(e) => e.was_duplicate(), }, + quote! { Self::#create_variant(e) => e.was_duplicate(), }, + ] + }) + .collect(); + let nested_dv_checks: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + vec![ + quote! { Self::#modify_variant(e) => e.duplicate_value(), }, + quote! { Self::#create_variant(e) => e.duplicate_value(), }, + ] + }) + .collect(); + + let modify_nested_ph_checks: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + vec![ + quote! { Self::#modify_variant(e) => e.was_post_hydrate_error(), }, + quote! 
{ Self::#create_variant(e) => e.was_post_hydrate_error(), }, + ] + }) + .collect(); + + let nested_from_impls: Vec<_> = self + .nested + .iter() + .flat_map(|n| { + let modify_variant = + syn::Ident::new(&format!("{}Modify", n.variant_name), Span::call_site()); + let create_variant = + syn::Ident::new(&format!("{}Create", n.variant_name), Span::call_site()); + let child_modify_ty = n.modify_error_ty(); + let child_create_ty = n.create_error_ty(); + vec![ + quote! { + impl From<#child_modify_ty> for #modify_error { + fn from(e: #child_modify_ty) -> Self { + Self::#modify_variant(e) + } + } + }, + quote! { + impl From<#child_create_ty> for #modify_error { + fn from(e: #child_create_ty) -> Self { + Self::#create_variant(e) + } + } + }, + ] + }) + .collect(); + + let entity_name = entity.to_string(); + + let (pp_variant, pp_display_arm, pp_source_arm) = if let Some(config) = + &self.post_persist_hook + { + let error_ty = &config.error; + ( + quote! { PostPersistHookError(#error_ty), }, + quote! { Self::PostPersistHookError(e) => write!(f, "{}ModifyError - PostPersistHookError: {}", #entity_name, e), }, + quote! { Self::PostPersistHookError(e) => Some(e), }, + ) + } else { + (quote! {}, quote! {}, quote! {}) + }; + + quote! 
{ + #[derive(Debug)] + pub enum #modify_error { + Sqlx(sqlx::Error), + ConstraintViolation { column: Option<#column_enum>, value: Option, inner: sqlx::Error }, + ConcurrentModification, + #pp_variant + #(#nested_variants)* + } + + impl std::fmt::Display for #modify_error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Sqlx(e) => write!(f, "{}ModifyError - Sqlx: {}", #entity_name, e), + Self::ConstraintViolation { column, value, inner } => write!(f, "{}ModifyError - ConstraintViolation({:?}, {:?}): {}", #entity_name, column, value, inner), + Self::ConcurrentModification => write!(f, "{}ModifyError - ConcurrentModification", #entity_name), + #pp_display_arm + #(#nested_display_arms)* + } + } + } + + impl std::error::Error for #modify_error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Sqlx(e) => Some(e), + Self::ConstraintViolation { inner, .. } => Some(inner), + Self::ConcurrentModification => None, + #pp_source_arm + #(#nested_source_arms)* + } + } + } + + impl From for #modify_error { + fn from(e: sqlx::Error) -> Self { + Self::Sqlx(e) + } + } + + #(#nested_from_impls)* + + impl #modify_error { + pub fn was_concurrent_modification(&self) -> bool { + match self { + Self::ConcurrentModification => true, + #(#nested_cm_checks)* + _ => false, + } + } + + pub fn was_duplicate(&self) -> bool { + match self { + Self::ConstraintViolation { .. } => true, + #(#modify_nested_wd_checks)* + _ => false, + } + } + + pub fn was_duplicate_by(&self, column: #column_enum) -> bool { + matches!(self, Self::ConstraintViolation { column: Some(c), .. } if *c == column) + } + + pub fn duplicate_value(&self) -> Option<&str> { + match self { + Self::ConstraintViolation { value: Some(v), .. 
} => Some(v.as_str()), + #(#nested_dv_checks)* + _ => None, + } + } + + pub fn was_post_hydrate_error(&self) -> bool { + match self { + #(#modify_nested_ph_checks)* + _ => false, + } + } + } + } + } + + fn generate_find_error(&self) -> TokenStream { + let find_error = &self.find_error; + let query_error = &self.query_error; + let column_enum = &self.column_enum; + let entity = self.entity; + let entity_name = entity.to_string(); + + let (ph_variant, ph_display_arm, ph_source_arm, ph_from_arm) = if let Some(config) = + &self.post_hydrate_hook + { + let error_ty = &config.error; + ( + quote! { PostHydrateError(#error_ty), }, + quote! { Self::PostHydrateError(e) => write!(f, "{}FindError - PostHydrateError: {}", #entity_name, e), }, + quote! { Self::PostHydrateError(e) => Some(e), }, + quote! { #query_error::PostHydrateError(e) => Self::PostHydrateError(e), }, + ) + } else { + (quote! {}, quote! {}, quote! {}, quote! {}) + }; + let find_ph_self_check = if self.post_hydrate_hook.is_some() { + quote! { Self::PostHydrateError(..) => true, } + } else { + quote! {} + }; + + quote! 
{ + #[derive(Debug)] + pub enum #find_error { + Sqlx(sqlx::Error), + NotFound { entity: &'static str, column: Option<#column_enum>, value: String }, + HydrationError(es_entity::EntityHydrationError), + #ph_variant + } + + impl std::fmt::Display for #find_error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Sqlx(e) => write!(f, "{}FindError - Sqlx: {}", #entity_name, e), + Self::NotFound { entity, column: Some(column), value } => write!(f, "{}FindError - NotFound({column}={value})", entity), + Self::NotFound { entity, column: None, value } => write!(f, "{}FindError - NotFound({})", entity, value), + Self::HydrationError(e) => write!(f, "{}FindError - HydrationError: {}", #entity_name, e), + #ph_display_arm + } + } + } + + impl std::error::Error for #find_error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Sqlx(e) => Some(e), + Self::NotFound { .. } => None, + Self::HydrationError(e) => Some(e), + #ph_source_arm + } + } + } + + impl From for #find_error { + fn from(e: sqlx::Error) -> Self { + Self::Sqlx(e) + } + } + + impl From for #find_error { + fn from(e: es_entity::EntityHydrationError) -> Self { + Self::HydrationError(e) + } + } + + impl From<#query_error> for #find_error { + fn from(e: #query_error) -> Self { + match e { + #query_error::Sqlx(e) => Self::Sqlx(e), + #query_error::HydrationError(e) => Self::HydrationError(e), + #query_error::CursorDestructureError(_) => unreachable!("CursorDestructureError cannot occur in find operations"), + #ph_from_arm + } + } + } + + impl #find_error { + pub fn was_not_found(&self) -> bool { + matches!(self, Self::NotFound { .. }) + } + + pub fn was_not_found_by(&self, column: #column_enum) -> bool { + matches!(self, Self::NotFound { column: Some(c), .. } if *c == column) + } + + pub fn not_found_value(&self) -> Option<&str> { + match self { + Self::NotFound { value, .. 
} => Some(value.as_str()), + _ => None, + } + } + + pub fn was_post_hydrate_error(&self) -> bool { + match self { + #find_ph_self_check + _ => false, + } + } + } + } + } + + fn generate_query_error(&self) -> TokenStream { + let query_error = &self.query_error; + let entity = self.entity; + let entity_name = entity.to_string(); + + let (ph_variant, ph_display_arm, ph_source_arm) = if let Some(config) = + &self.post_hydrate_hook + { + let error_ty = &config.error; + ( + quote! { PostHydrateError(#error_ty), }, + quote! { Self::PostHydrateError(e) => write!(f, "{}QueryError - PostHydrateError: {}", #entity_name, e), }, + quote! { Self::PostHydrateError(e) => Some(e), }, + ) + } else { + (quote! {}, quote! {}, quote! {}) + }; + let query_ph_self_check = if self.post_hydrate_hook.is_some() { + quote! { Self::PostHydrateError(..) => true, } + } else { + quote! {} + }; + + quote! { + #[derive(Debug)] + pub enum #query_error { + Sqlx(sqlx::Error), + HydrationError(es_entity::EntityHydrationError), + CursorDestructureError(es_entity::CursorDestructureError), + #ph_variant + } + + impl std::fmt::Display for #query_error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Sqlx(e) => write!(f, "{}QueryError - Sqlx: {}", #entity_name, e), + Self::HydrationError(e) => write!(f, "{}QueryError - HydrationError: {}", #entity_name, e), + Self::CursorDestructureError(e) => write!(f, "{}QueryError - CursorDestructureError: {}", #entity_name, e), + #ph_display_arm + } + } + } + + impl std::error::Error for #query_error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Sqlx(e) => Some(e), + Self::HydrationError(e) => Some(e), + Self::CursorDestructureError(e) => Some(e), + #ph_source_arm + } + } + } + + impl From for #query_error { + fn from(e: sqlx::Error) -> Self { + Self::Sqlx(e) + } + } + + impl From for #query_error { + fn from(e: es_entity::EntityHydrationError) -> Self { + 
Self::HydrationError(e) + } + } + + impl From<es_entity::CursorDestructureError> for #query_error { + fn from(e: es_entity::CursorDestructureError) -> Self { + Self::CursorDestructureError(e) + } + } + + impl #query_error { + pub fn was_post_hydrate_error(&self) -> bool { + match self { + #query_ph_self_check + _ => false, + } + } + } + } + } +} + +/// Check if a type references any of the given idents (generic type params). +fn type_uses_any_generic(ty: &syn::Type, idents: &[&syn::Ident]) -> bool { + let ts = ty.to_token_stream(); + token_stream_contains_any(ts, idents) +} + +fn token_stream_contains_any(ts: proc_macro2::TokenStream, idents: &[&syn::Ident]) -> bool { + for tt in ts { + match tt { + proc_macro2::TokenTree::Ident(ref i) => { + if idents.iter().any(|id| *i == **id) { + return true; + } + } + proc_macro2::TokenTree::Group(g) => { + if token_stream_contains_any(g.stream(), idents) { + return true; + } + } + _ => {} + } + } + false +} + +/// Derive the entity name from a repo type using conventions: +/// 1. Strip `Repo` suffix: `ObligationRepo` → `Obligation` +/// 2. Singularize: `OrderItems` → `OrderItem` +/// +/// Returns `None` if neither convention matches.
+fn derive_entity_from_repo_type(ty: &syn::Type) -> Option<syn::Ident> { + if let syn::Type::Path(type_path) = ty + && let Some(segment) = type_path.path.segments.last() + { + let name = segment.ident.to_string(); + + // Convention 1: strip "Repo" suffix + if let Some(entity_name) = name.strip_suffix("Repo") + && !entity_name.is_empty() + { + return Some(syn::Ident::new(entity_name, segment.ident.span())); + } + + // Convention 2: singularize plural name (e.g., OrderItems → OrderItem) + let singular = pluralizer::pluralize(&name, 1, false); + if singular != name { + return Some(syn::Ident::new(&singular, segment.ident.span())); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::{Ident, parse_quote}; + + fn make_error_types(nested: Vec<NestedErrorInfo>) -> ErrorTypes<'static> { + // Leak entity ident to get a 'static reference for tests + let entity: &'static syn::Ident = + Box::leak(Box::new(Ident::new("Order", Span::call_site()))); + let post_hydrate_hook: &'static Option<PostHydrateHookConfig> = Box::leak(Box::new(None)); + let post_persist_hook: &'static Option<PostPersistHookConfig> = Box::leak(Box::new(None)); + ErrorTypes { + entity, + column_enum: Ident::new("OrderColumn", Span::call_site()), + create_error: Ident::new("OrderCreateError", Span::call_site()), + modify_error: Ident::new("OrderModifyError", Span::call_site()), + find_error: Ident::new("OrderFindError", Span::call_site()), + query_error: Ident::new("OrderQueryError", Span::call_site()), + column_variants: vec![], + nested, + post_hydrate_hook, + post_persist_hook, + } + } + + #[test] + fn non_generic_nested_uses_associated_type() { + let error_types = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote!
{ ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: None, + }]); + + let tokens = error_types.generate_create_error(); + let output = tokens.to_string(); + + // Should use associated type projection (existing behavior) + assert!( + output.contains("< ItemRepo as es_entity :: EsRepo > :: CreateError"), + "Expected associated type projection, got: {}", + output + ); + } + + #[test] + fn generic_nested_with_entity_uses_concrete_name() { + let error_types = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote! { ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: Some(Ident::new("InterestAccrualCycle", Span::call_site())), + }]); + + let tokens = error_types.generate_create_error(); + let output = tokens.to_string(); + + // Should use concrete error type name, NOT associated type projection + assert!( + output.contains("InterestAccrualCycleCreateError"), + "Expected concrete error type name, got: {}", + output + ); + assert!( + !output.contains("ItemRepo"), + "Should not reference the generic repo type, got: {}", + output + ); + } + + #[test] + fn generic_nested_modify_error_uses_concrete_names() { + let error_types = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote! 
{ ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: Some(Ident::new("InterestAccrualCycle", Span::call_site())), + }]); + + let tokens = error_types.generate_modify_error(); + let output = tokens.to_string(); + + // Should use concrete error type names for both Modify and Create variants + assert!( + output.contains("InterestAccrualCycleModifyError"), + "Expected concrete modify error type, got: {}", + output + ); + assert!( + output.contains("InterestAccrualCycleCreateError"), + "Expected concrete create error type, got: {}", + output + ); + assert!( + !output.contains("ItemRepo"), + "Should not reference the generic repo type, got: {}", + output + ); + } + + #[test] + fn mixed_nested_repos() { + let error_types = make_error_types(vec![ + NestedErrorInfo { + child_repo_ty: parse_quote! { ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: None, + }, + NestedErrorInfo { + child_repo_ty: parse_quote! { AccrualRepo }, + variant_name: Ident::new("Accruals", Span::call_site()), + nested_entity: Some(Ident::new("Accrual", Span::call_site())), + }, + ]); + + let tokens = error_types.generate_create_error(); + let output = tokens.to_string(); + + // Non-generic nested should use associated type projection + assert!( + output.contains("< ItemRepo as es_entity :: EsRepo > :: CreateError"), + "Expected associated type projection for non-generic repo, got: {}", + output + ); + // Generic nested with entity should use concrete name + assert!( + output.contains("AccrualCreateError"), + "Expected concrete error type for generic repo, got: {}", + output + ); + } + + #[test] + fn auto_derive_entity_from_repo_type_name() { + // When nested_entity is derived automatically (via derive_entity_from_repo_type), + // it strips the "Repo" suffix: ObligationRepo → Obligation + let error_types = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote! 
{ ObligationRepo }, + variant_name: Ident::new("Obligations", Span::call_site()), + nested_entity: derive_entity_from_repo_type(&parse_quote! { ObligationRepo }), + }]); + + let tokens = error_types.generate_create_error(); + let output = tokens.to_string(); + + assert!( + output.contains("ObligationCreateError"), + "Expected auto-derived concrete error type, got: {}", + output + ); + assert!( + !output.contains("ObligationRepo"), + "Should not reference the generic repo type, got: {}", + output + ); + } + + #[test] + fn type_uses_any_generic_detects_params() { + let evt = Ident::new("Evt", Span::call_site()); + let idents = vec![&evt]; + + // Type with generic param + let ty: syn::Type = parse_quote! { SomeRepo }; + assert!(type_uses_any_generic(&ty, &idents)); + + // Type without generic param + let ty: syn::Type = parse_quote! { SomeRepo }; + assert!(!type_uses_any_generic(&ty, &idents)); + + // Type with different generic param + let ty: syn::Type = parse_quote! { SomeRepo }; + assert!(!type_uses_any_generic(&ty, &idents)); + } + + #[test] + fn derive_entity_strips_repo_suffix() { + let ty: syn::Type = parse_quote! { ObligationRepo }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "Obligation"); + + let ty: syn::Type = parse_quote! { InterestAccrualRepo }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "InterestAccrual"); + + // Non-generic also works + let ty: syn::Type = parse_quote! { ItemRepo }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "Item"); + } + + #[test] + fn derive_entity_singularizes_plural_name() { + // Plural → singular convention + let ty: syn::Type = parse_quote! { OrderItems }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "OrderItem"); + + let ty: syn::Type = parse_quote! 
{ BillingPeriods }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "BillingPeriod"); + + // Non-generic plural also works + let ty: syn::Type = parse_quote! { Users }; + let entity = derive_entity_from_repo_type(&ty); + assert_eq!(entity.unwrap().to_string(), "User"); + } + + #[test] + fn derive_entity_returns_none_for_unrecognized() { + // Neither Repo suffix nor plural → None + let ty: syn::Type = parse_quote! { SomeType }; + assert!(derive_entity_from_repo_type(&ty).is_none()); + + // Singular name without Repo suffix → None + let ty: syn::Type = parse_quote! { Obligation }; + assert!(derive_entity_from_repo_type(&ty).is_none()); + } + + // ----------------------------------------------------------------------- + // Hook variant and helper method generation tests + // ----------------------------------------------------------------------- + + fn make_error_types_with_hooks( + nested: Vec, + post_hydrate_hook: Option, + post_persist_hook: Option, + ) -> ErrorTypes<'static> { + let entity: &'static syn::Ident = + Box::leak(Box::new(Ident::new("Order", Span::call_site()))); + let ph: &'static Option = Box::leak(Box::new(post_hydrate_hook)); + let pp: &'static Option = Box::leak(Box::new(post_persist_hook)); + ErrorTypes { + entity, + column_enum: Ident::new("OrderColumn", Span::call_site()), + create_error: Ident::new("OrderCreateError", Span::call_site()), + modify_error: Ident::new("OrderModifyError", Span::call_site()), + find_error: Ident::new("OrderFindError", Span::call_site()), + query_error: Ident::new("OrderQueryError", Span::call_site()), + column_variants: vec![], + nested, + post_hydrate_hook: ph, + post_persist_hook: pp, + } + } + + fn ph_hook() -> PostHydrateHookConfig { + PostHydrateHookConfig { + method: Ident::new("validate", Span::call_site()), + error: syn::parse_str("MyHydrateError").unwrap(), + } + } + + fn pp_hook() -> PostPersistHookConfig { + PostPersistHookConfig { + method: 
Ident::new("on_persist", Span::call_site()), + error: syn::parse_str("MyPersistError").unwrap(), + } + } + + #[test] + fn create_error_without_hooks_omits_hook_variants() { + let et = make_error_types_with_hooks(vec![], None, None); + let output = et.generate_create_error().to_string(); + + assert!( + !output.contains("PostHydrateError"), + "should not contain PostHydrateError variant without hook: {output}" + ); + assert!( + !output.contains("PostPersistHookError"), + "should not contain PostPersistHookError variant without hook: {output}" + ); + } + + #[test] + fn create_error_without_hooks_still_generates_was_post_hydrate_error() { + let et = make_error_types_with_hooks(vec![], None, None); + let output = et.generate_create_error().to_string(); + + assert!( + output.contains("was_post_hydrate_error"), + "was_post_hydrate_error should always be generated: {output}" + ); + } + + #[test] + fn create_error_with_post_hydrate_hook_has_variant_and_self_check() { + let et = make_error_types_with_hooks(vec![], Some(ph_hook()), None); + let output = et.generate_create_error().to_string(); + + assert!( + output.contains("PostHydrateError (MyHydrateError)"), + "should contain PostHydrateError variant with custom type: {output}" + ); + assert!( + output.contains("was_post_hydrate_error"), + "should contain was_post_hydrate_error helper: {output}" + ); + } + + #[test] + fn create_error_with_post_persist_hook_has_variant() { + let et = make_error_types_with_hooks(vec![], None, Some(pp_hook())); + let output = et.generate_create_error().to_string(); + + assert!( + output.contains("PostPersistHookError (MyPersistError)"), + "should contain PostPersistHookError variant with custom type: {output}" + ); + } + + #[test] + fn create_error_nested_cascades_was_duplicate() { + let et = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote! 
{ ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: None, + }]); + let output = et.generate_create_error().to_string(); + + // was_duplicate should cascade into nested Items variant + assert!( + output.contains("Self :: Items (e) => e . was_duplicate ()"), + "was_duplicate should cascade into nested variant: {output}" + ); + } + + #[test] + fn create_error_nested_cascades_was_post_hydrate_error() { + let et = make_error_types_with_hooks( + vec![NestedErrorInfo { + child_repo_ty: parse_quote! { ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: None, + }], + Some(ph_hook()), + None, + ); + let output = et.generate_create_error().to_string(); + + // was_post_hydrate_error should cascade into nested Items variant + assert!( + output.contains("Self :: Items (e) => e . was_post_hydrate_error ()"), + "was_post_hydrate_error should cascade into nested variant: {output}" + ); + } + + #[test] + fn modify_error_with_post_persist_hook_has_variant() { + let et = make_error_types_with_hooks(vec![], None, Some(pp_hook())); + let output = et.generate_modify_error().to_string(); + + assert!( + output.contains("PostPersistHookError (MyPersistError)"), + "should contain PostPersistHookError variant with custom type: {output}" + ); + assert!( + !output.contains("PostHydrateError"), + "modify error should never have PostHydrateError: {output}" + ); + } + + #[test] + fn modify_error_without_hooks_still_generates_was_post_hydrate_error() { + let et = make_error_types_with_hooks(vec![], None, None); + let output = et.generate_modify_error().to_string(); + + assert!( + output.contains("was_post_hydrate_error"), + "was_post_hydrate_error should always be generated on ModifyError: {output}" + ); + } + + #[test] + fn modify_error_nested_cascades_was_duplicate_and_was_post_hydrate_error() { + let et = make_error_types(vec![NestedErrorInfo { + child_repo_ty: parse_quote! 
{ ItemRepo }, + variant_name: Ident::new("Items", Span::call_site()), + nested_entity: None, + }]); + let output = et.generate_modify_error().to_string(); + + // was_duplicate cascades into both Modify and Create nested variants + assert!( + output.contains("Self :: ItemsModify (e) => e . was_duplicate ()"), + "was_duplicate should cascade into nested Modify variant: {output}" + ); + assert!( + output.contains("Self :: ItemsCreate (e) => e . was_duplicate ()"), + "was_duplicate should cascade into nested Create variant: {output}" + ); + // was_post_hydrate_error cascades into both + assert!( + output.contains("Self :: ItemsModify (e) => e . was_post_hydrate_error ()"), + "was_post_hydrate_error should cascade into nested Modify variant: {output}" + ); + assert!( + output.contains("Self :: ItemsCreate (e) => e . was_post_hydrate_error ()"), + "was_post_hydrate_error should cascade into nested Create variant: {output}" + ); + } + + #[test] + fn find_error_with_post_hydrate_hook_has_variant() { + let et = make_error_types_with_hooks(vec![], Some(ph_hook()), None); + let output = et.generate_find_error().to_string(); + + assert!( + output.contains("PostHydrateError (MyHydrateError)"), + "should contain PostHydrateError variant with custom type: {output}" + ); + assert!( + output.contains("was_post_hydrate_error"), + "should contain was_post_hydrate_error helper: {output}" + ); + } + + #[test] + fn find_error_without_hooks_still_generates_was_post_hydrate_error() { + let et = make_error_types_with_hooks(vec![], None, None); + let output = et.generate_find_error().to_string(); + + assert!( + output.contains("was_post_hydrate_error"), + "was_post_hydrate_error should always be generated on FindError: {output}" + ); + assert!( + !output.contains("PostHydrateError"), + "should not contain PostHydrateError variant without hook: {output}" + ); + } + + #[test] + fn query_error_with_post_hydrate_hook_has_variant() { + let et = make_error_types_with_hooks(vec![], 
Some(ph_hook()), None); + let output = et.generate_query_error().to_string(); + + assert!( + output.contains("PostHydrateError (MyHydrateError)"), + "should contain PostHydrateError variant with custom type: {output}" + ); + assert!( + output.contains("was_post_hydrate_error"), + "should contain was_post_hydrate_error helper: {output}" + ); + } + + #[test] + fn query_error_without_hooks_still_generates_was_post_hydrate_error() { + let et = make_error_types_with_hooks(vec![], None, None); + let output = et.generate_query_error().to_string(); + + assert!( + output.contains("was_post_hydrate_error"), + "was_post_hydrate_error should always be generated on QueryError: {output}" + ); + assert!( + !output.contains("PostHydrateError"), + "should not contain PostHydrateError variant without hook: {output}" + ); + } +} diff --git a/es-entity-macros-sqlite/src/repo/find_all_fn.rs b/es-entity-macros-sqlite/src/repo/find_all_fn.rs new file mode 100644 index 00000000..832f91ec --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/find_all_fn.rs @@ -0,0 +1,237 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct FindAllFn<'a> { + id: &'a syn::Ident, + entity: &'a syn::Ident, + table_name: &'a str, + events_table_name: &'a str, + query_error: syn::Ident, + any_nested: bool, + post_hydrate_error: Option<&'a syn::Type>, + repo_types_mod: syn::Ident, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> From<&'a RepositoryOptions> for FindAllFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + id: opts.id(), + entity: opts.entity(), + table_name: opts.table_name(), + events_table_name: opts.events_table_name(), + query_error: opts.query_error(), + any_nested: opts.any_nested(), + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + repo_types_mod: opts.repo_types_mod(), + #[cfg(feature = "instrument")] + repo_name_snake: 
opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for FindAllFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let id = self.id; + let entity = self.entity; + let query_error = &self.query_error; + let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested); + let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested); + let repo_types_mod = &self.repo_types_mod; + let table_name = self.table_name; + let events_table_name = self.events_table_name; + + let generics = if self.any_nested { + quote! { > } + } else { + quote! { <'a, Out: From<#entity>> } + }; + + let op_param = if self.any_nested { + quote! { op: &mut impl #query_fn_op_traits } + } else { + quote! { op: impl #query_fn_op_traits } + }; + + #[cfg(feature = "instrument")] + let instrument_attr = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.find_all", repo_name); + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = ids.len(), ids = tracing::field::debug(ids)), err)] + } + }; + #[cfg(not(feature = "instrument"))] + let instrument_attr = quote! {}; + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + for __entity in &entities { + self.execute_post_hydrate_hook(__entity).map_err(#query_error::PostHydrateError)?; + } + } + } else { + quote! {} + }; + + let fetch_and_load = if self.any_nested { + quote! { + let db_events = (&mut *op).into_executor().fetch_all(query).await?; + let n = db_events.len(); + let (mut entities, _) = es_entity::EntityEvents::load_n::<#entity>(db_events.into_iter(), n)?; + Self::load_all_nested_in_op::<_, #query_error>(op, &mut entities).await?; + } + } else { + quote! { + let db_events = op.into_executor().fetch_all(query).await?; + let n = db_events.len(); + let (mut entities, _) = es_entity::EntityEvents::load_n::<#entity>(db_events.into_iter(), n)?; + } + }; + + tokens.append_all(quote! 
{ + pub async fn find_all>( + &self, + ids: &[#id] + ) -> Result, #query_error> { + self.find_all_in_op(#query_fn_get_op, ids).await + } + + #instrument_attr + pub async fn find_all_in_op #generics( + &self, + #op_param, + ids: &[#id] + ) -> Result, #query_error> { + if ids.is_empty() { + return Ok(std::collections::HashMap::new()); + } + let placeholders: String = (1..=ids.len()) + .map(|i| format!("?{i}")) + .collect::>() + .join(", "); + let ctx_param = ids.len() + 1; + let query_str = format!( + "WITH entities AS (SELECT * FROM {} WHERE id IN ({})) \ + SELECT i.id AS entity_id, e.sequence, e.event, \ + CASE WHEN ?{} THEN e.context ELSE NULL END AS context, \ + e.recorded_at \ + FROM entities i JOIN {} e ON i.id = e.id ORDER BY e.id, e.sequence", + #table_name, + placeholders, + ctx_param, + #events_table_name, + ); + let mut query = es_entity::prelude::sqlx::query(&query_str); + for id in ids { + query = query.bind(id); + } + query = query.bind(<#repo_types_mod::Repo__Event as EsEvent>::event_context()); + let query = query.try_map(|row: es_entity::db::Row| -> Result<#repo_types_mod::Repo__DbEvent, sqlx::Error> { + use es_entity::prelude::sqlx::Row as _; + Ok(#repo_types_mod::Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }); + #fetch_and_load + #post_hydrate_check + Ok(entities.into_iter().map(|u| (u.id.clone(), Out::from(u))).collect()) + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn find_all_fn() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + + let persist_fn = FindAllFn { + id: &id_type, + entity: &entity, + table_name: "entities", + events_table_name: "entity_events", + 
query_error, + any_nested: false, + post_hydrate_error: None, + repo_types_mod: syn::Ident::new("entity_repo_types", Span::call_site()), + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn find_all>( + &self, + ids: &[EntityId] + ) -> Result, EntityQueryError> { + self.find_all_in_op(self.pool(), ids).await + } + + pub async fn find_all_in_op<'a, Out: From>( + &self, + op: impl es_entity::IntoOneTimeExecutor<'a>, + ids: &[EntityId] + ) -> Result, EntityQueryError> { + if ids.is_empty() { + return Ok(std::collections::HashMap::new()); + } + let placeholders: String = (1..=ids.len()) + .map(|i| format!("?{i}")) + .collect::>() + .join(", "); + let ctx_param = ids.len() + 1; + let query_str = format!( + "WITH entities AS (SELECT * FROM {} WHERE id IN ({})) \ + SELECT i.id AS entity_id, e.sequence, e.event, \ + CASE WHEN ?{} THEN e.context ELSE NULL END AS context, \ + e.recorded_at \ + FROM entities i JOIN {} e ON i.id = e.id ORDER BY e.id, e.sequence", + "entities", + placeholders, + ctx_param, + "entity_events", + ); + let mut query = es_entity::prelude::sqlx::query(&query_str); + for id in ids { + query = query.bind(id); + } + query = query.bind(::event_context()); + let query = query.try_map(|row: es_entity::db::Row| -> Result { + use es_entity::prelude::sqlx::Row as _; + Ok(entity_repo_types::Repo__DbEvent { + entity_id: row.try_get("entity_id")?, + sequence: row.try_get("sequence")?, + event: row.try_get("event")?, + context: row.try_get("context")?, + recorded_at: row.try_get("recorded_at")?, + }) + }); + let db_events = op.into_executor().fetch_all(query).await?; + let n = db_events.len(); + let (mut entities, _) = es_entity::EntityEvents::load_n::(db_events.into_iter(), n)?; + Ok(entities.into_iter().map(|u| (u.id.clone(), Out::from(u))).collect()) + } + }; + + assert_eq!(tokens.to_string(), 
expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/find_by_fn.rs b/es-entity-macros-sqlite/src/repo/find_by_fn.rs new file mode 100644 index 00000000..fa022c98 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/find_by_fn.rs @@ -0,0 +1,628 @@ +use convert_case::{Case, Casing}; +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct FindByFn<'a> { + prefix: Option<&'a syn::LitStr>, + entity: &'a syn::Ident, + column: &'a Column, + table_name: &'a str, + column_enum: syn::Ident, + find_error: syn::Ident, + query_error: syn::Ident, + delete: DeleteOption, + any_nested: bool, + post_hydrate_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> FindByFn<'a> { + pub fn new(column: &'a Column, opts: &'a RepositoryOptions) -> Self { + Self { + prefix: opts.table_prefix(), + column, + entity: opts.entity(), + table_name: opts.table_name(), + column_enum: opts.column_enum(), + find_error: opts.find_error(), + query_error: opts.query_error(), + delete: opts.delete, + any_nested: opts.any_nested(), + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for FindByFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let column_name = &self.column.name(); + let (column_type, impl_expr, access_expr) = &self.column.ty_for_find_by(); + let query_fn_generics = RepositoryOptions::query_fn_generics(self.any_nested); + let query_fn_op_arg = RepositoryOptions::query_fn_op_arg(self.any_nested); + let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested); + let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested); + + for maybe in ["", "maybe_"] { + let error = if maybe.is_empty() { + &self.find_error + } else { + &self.query_error 
+ }; + + let result_type = if maybe.is_empty() { + quote! { #entity } + } else { + quote! { Option<#entity> } + }; + + for delete in [DeleteOption::No, DeleteOption::Soft] { + let fn_name = syn::Ident::new( + &format!( + "{}find_by_{}{}", + maybe, + column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + let fn_in_op = syn::Ident::new( + &format!( + "{}find_by_{}{}_in_op", + maybe, + column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + + let query = format!( + r#"SELECT id FROM {} WHERE {} = $1{}"#, + self.table_name, + column_name, + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + } + ); + + let es_query_call = if let Some(prefix) = self.prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #query, + #column_name as &#column_type, + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #query, + #column_name as &#column_type, + ) + } + }; + + let fetch_and_validate = if maybe.is_empty() { + let entity_name_str = entity.to_string(); + let column_enum = &self.column_enum; + let column_variant = syn::Ident::new( + &column_name.to_string().to_case(Case::UpperCamel), + Span::call_site(), + ); + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + self.execute_post_hydrate_hook(&__entity).map_err(#error::PostHydrateError)?; + } + } else { + quote! {} + }; + quote! { + let __entity = #es_query_call.fetch_optional(op).await?.ok_or_else(|| #error::NotFound { + entity: #entity_name_str, + column: Some(#column_enum::#column_variant), + value: { + use es_entity::ToNotFoundValueFallback; + es_entity::NotFoundValue(#column_name).to_not_found_value() + }, + })?; + #post_hydrate_check + Ok(__entity) + } + } else { + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! 
{ + if let Some(ref __entity) = __result { + self.execute_post_hydrate_hook(__entity).map_err(#error::PostHydrateError)?; + } + } + } else { + quote! {} + }; + quote! { + let __result = #es_query_call.fetch_optional(op).await?; + #post_hydrate_check + Ok(__result) + } + }; + + #[cfg(feature = "instrument")] + let (instrument_attr_in_op, record_field, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.{}find_by_{}", repo_name, maybe, column_name); + let field_name = format!("query_{}", column_name); + let field_ident = syn::Ident::new(&field_name, proc_macro2::Span::call_site()); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, #field_ident = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + tracing::Span::current().record(#field_name, tracing::field::debug(&#column_name)); + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr_in_op, record_field, error_recording) = + (quote! {}, quote! {}, quote! {}); + + tokens.append_all(quote! 
{ + pub async fn #fn_name( + &self, + #column_name: #impl_expr + ) -> Result<#result_type, #error> { + self.#fn_in_op(#query_fn_get_op, #column_name).await + } + + #instrument_attr_in_op + pub async fn #fn_in_op #query_fn_generics( + &self, + #query_fn_op_arg, + #column_name: #impl_expr + ) -> Result<#result_type, #error> + where + OP: #query_fn_op_traits + { + let __result: Result<#result_type, #error> = async { + let #column_name = #column_name.#access_expr; + #record_field + #fetch_and_validate + }.await; + + #error_recording + __result + } + }); + + if delete == self.delete || self.delete == DeleteOption::SoftWithoutQueries { + break; + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn find_by_fn() { + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", Span::call_site()), + delete: DeleteOption::No, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + pub async fn find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result { + self.find_by_id_in_op(self.pool(), id).await + } + + pub async fn find_by_id_in_op<'a, OP>( + &self, + op: OP, + id: impl std::borrow::Borrow + ) -> Result + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result = async { + let id = id.borrow(); + let __entity = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1", + id as &EntityId, + ) + .fetch_optional(op).await?.ok_or_else(|| EntityFindError::NotFound { + entity: "Entity", + column: Some(EntityColumn::Id), + value: { + use es_entity::ToNotFoundValueFallback; + es_entity::NotFoundValue(id).to_not_found_value() + }, + })?; + Ok(__entity) + }.await; + + __result + } + + pub async fn maybe_find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> { + self.maybe_find_by_id_in_op(self.pool(), id).await + } + + pub async fn maybe_find_by_id_in_op<'a, OP>( + &self, + op: OP, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let id = id.borrow(); + let __result = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1", + id as &EntityId, + ) + .fetch_optional(op).await?; + Ok(__result) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn find_by_fn_string_arg() { + let column = Column::new( + syn::Ident::new("email", proc_macro2::Span::call_site()), + syn::parse_str("String").unwrap(), + ); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", 
Span::call_site()), + delete: DeleteOption::No, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn find_by_email( + &self, + email: impl std::convert::AsRef + ) -> Result { + self.find_by_email_in_op(self.pool(), email).await + } + + pub async fn find_by_email_in_op<'a, OP>( + &self, + op: OP, + email: impl std::convert::AsRef + ) -> Result + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result = async { + let email = email.as_ref(); + let __entity = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE email = $1", + email as &str, + ) + .fetch_optional(op).await?.ok_or_else(|| EntityFindError::NotFound { + entity: "Entity", + column: Some(EntityColumn::Email), + value: { + use es_entity::ToNotFoundValueFallback; + es_entity::NotFoundValue(email).to_not_found_value() + }, + })?; + Ok(__entity) + }.await; + + __result + } + + pub async fn maybe_find_by_email( + &self, + email: impl std::convert::AsRef + ) -> Result, EntityQueryError> { + self.maybe_find_by_email_in_op(self.pool(), email).await + } + + pub async fn maybe_find_by_email_in_op<'a, OP>( + &self, + op: OP, + email: impl std::convert::AsRef + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let email = email.as_ref(); + let __result = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE email = $1", + email as &str, + ) + .fetch_optional(op).await?; + Ok(__result) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn find_by_fn_with_soft_delete() { + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + 
prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", Span::call_site()), + delete: DeleteOption::SoftWithoutQueries, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result { + self.find_by_id_in_op(self.pool(), id).await + } + + pub async fn find_by_id_in_op<'a, OP>( + &self, + op: OP, + id: impl std::borrow::Borrow + ) -> Result + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result = async { + let id = id.borrow(); + let __entity = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1 AND deleted = FALSE", + id as &EntityId, + ) + .fetch_optional(op).await?.ok_or_else(|| EntityFindError::NotFound { + entity: "Entity", + column: Some(EntityColumn::Id), + value: { + use es_entity::ToNotFoundValueFallback; + es_entity::NotFoundValue(id).to_not_found_value() + }, + })?; + Ok(__entity) + }.await; + + __result + } + + pub async fn maybe_find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> { + self.maybe_find_by_id_in_op(self.pool(), id).await + } + + pub async fn maybe_find_by_id_in_op<'a, OP>( + &self, + op: OP, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let id = id.borrow(); + let __result = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1 AND deleted = FALSE", + id as &EntityId, + ) + .fetch_optional(op).await?; + Ok(__result) + }.await; + + __result + } + }; + + 
assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn find_by_fn_with_soft_delete_include_deleted() { + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", Span::call_site()), + delete: DeleteOption::Soft, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let token_str = tokens.to_string(); + assert!(token_str.contains("find_by_id_include_deleted")); + assert!(token_str.contains("maybe_find_by_id_include_deleted")); + } + + #[test] + fn find_by_fn_nested() { + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", Span::call_site()), + delete: DeleteOption::No, + any_nested: true, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + pub async fn find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result { + self.find_by_id_in_op(&mut self.pool().begin().await?, id).await + } + + pub async fn find_by_id_in_op( + &self, + op: &mut OP, + id: impl std::borrow::Borrow + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + let id = id.borrow(); + let __entity = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1", + id as &EntityId, + ) + .fetch_optional(op).await?.ok_or_else(|| EntityFindError::NotFound { + entity: "Entity", + column: Some(EntityColumn::Id), + value: { + use es_entity::ToNotFoundValueFallback; + es_entity::NotFoundValue(id).to_not_found_value() + }, + })?; + Ok(__entity) + }.await; + + __result + } + + pub async fn maybe_find_by_id( + &self, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> { + self.maybe_find_by_id_in_op(&mut self.pool().begin().await?, id).await + } + + pub async fn maybe_find_by_id_in_op( + &self, + op: &mut OP, + id: impl std::borrow::Borrow + ) -> Result, EntityQueryError> + where + OP: es_entity::AtomicOperation + { + let __result: Result, EntityQueryError> = async { + let id = id.borrow(); + let __result = es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE id = $1", + id as &EntityId, + ) + .fetch_optional(op).await?; + Ok(__result) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/list_by_fn.rs b/es-entity-macros-sqlite/src/repo/list_by_fn.rs new file mode 100644 index 00000000..c4462151 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/list_by_fn.rs @@ -0,0 +1,852 @@ +use convert_case::{Case, Casing}; +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct CursorStruct<'a> { + pub id: &'a syn::Ident, + pub entity: &'a syn::Ident, + pub column: &'a Column, + pub 
cursor_mod: &'a syn::Ident, +} + +impl CursorStruct<'_> { + fn name(&self) -> String { + let entity_name = pluralizer::pluralize(&format!("{}", self.entity), 2, false); + format!("{}_by_{}_cursor", entity_name, self.column.name()).to_case(Case::UpperCamel) + } + + pub fn ident(&self) -> syn::Ident { + syn::Ident::new(&self.name(), Span::call_site()) + } + + pub fn cursor_mod(&self) -> &syn::Ident { + self.cursor_mod + } + + pub fn select_columns(&self, for_column: Option<&syn::Ident>) -> String { + let mut for_column_str = String::new(); + if let Some(for_column) = for_column + && self.column.name() != for_column + { + for_column_str = format!("{for_column}, "); + } + if self.column.is_id() { + format!("{for_column_str}id") + } else { + format!("{}{}, id", for_column_str, self.column.name()) + } + } + + pub fn order_by(&self, ascending: bool) -> String { + let dir = if ascending { "ASC" } else { "DESC" }; + let nulls = if ascending { "FIRST" } else { "LAST" }; + if self.column.is_id() { + format!("id {dir}") + } else if self.column.is_optional() { + format!("{0} {dir} NULLS {nulls}, id {dir}", self.column.name()) + } else { + format!("{} {dir}, id {dir}", self.column.name()) + } + } + + pub fn condition(&self, offset: u32, ascending: bool) -> String { + let comp = if ascending { ">" } else { "<" }; + let id_offset = offset + 2; + let column_offset = offset + 3; + + if self.column.is_id() { + format!("COALESCE(id {comp} ${id_offset}, true)") + } else if self.column.is_optional() { + format!( + "({0} IS ${column_offset}) AND COALESCE(id {comp} ${id_offset}, true) OR COALESCE({0} {comp} ${column_offset}, {0} IS NOT NULL)", + self.column.name(), + ) + } else { + format!( + "COALESCE(({0}, id) {comp} (${column_offset}, ${id_offset}), ${id_offset} IS NULL)", + self.column.name(), + ) + } + } + + pub fn query_arg_tokens(&self) -> TokenStream { + let id = self.id; + + if self.column.is_id() { + quote! 
{ + (first + 1) as i64, + id as Option<#id>, + } + } else if self.column.is_optional() { + let column_name = self.column.name(); + let column_type = self.column.ty(); + quote! { + (first + 1) as i64, + id as Option<#id>, + #column_name as #column_type, + } + } else { + let column_name = self.column.name(); + let column_type = self.column.ty(); + quote! { + (first + 1) as i64, + id as Option<#id>, + #column_name as Option<#column_type>, + } + } + } + + pub fn destructure_tokens(&self) -> TokenStream { + let column_name = self.column.name(); + + let mut after_args = quote! { + (id, #column_name) + }; + let mut after_destruction = quote! { + (Some(after.id), Some(after.#column_name)) + }; + let mut after_default = quote! { + (None, None) + }; + + if self.column.is_id() { + after_args = quote! { + id + }; + after_destruction = quote! { + Some(after.id) + }; + after_default = quote! { + None + }; + } else if self.column.is_optional() { + after_destruction = quote! { + (Some(after.id), after.#column_name) + }; + } + + quote! { + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let #after_args = if let Some(after) = after { + #after_destruction + } else { + #after_default + }; + } + } + + #[cfg(feature = "graphql")] + pub fn gql_cursor(&self) -> TokenStream { + let ident = self.ident(); + quote! 
{ + impl es_entity::graphql::async_graphql::connection::CursorType for #ident { + type Error = String; + + fn encode_cursor(&self) -> String { + use es_entity::graphql::base64::{engine::general_purpose, Engine as _}; + let json = es_entity::prelude::serde_json::to_string(&self).expect("could not serialize token"); + general_purpose::STANDARD_NO_PAD.encode(json.as_bytes()) + } + + fn decode_cursor(s: &str) -> Result { + use es_entity::graphql::base64::{engine::general_purpose, Engine as _}; + let bytes = general_purpose::STANDARD_NO_PAD + .decode(s.as_bytes()) + .map_err(|e| e.to_string())?; + let json = String::from_utf8(bytes).map_err(|e| e.to_string())?; + es_entity::prelude::serde_json::from_str(&json).map_err(|e| e.to_string()) + } + } + } + } +} + +impl ToTokens for CursorStruct<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let accessor = &self.column.accessor(); + let ident = self.ident(); + let id = &self.id; + + let (field, from_impl) = if self.column.is_id() { + (quote! {}, quote! {}) + } else { + let column_name = self.column.name(); + let column_type = self.column.ty(); + ( + quote! { + pub #column_name: #column_type, + }, + quote! { + #column_name: entity.#accessor.clone(), + }, + ) + }; + + tokens.append_all(quote! 
{ + #[derive(Debug, serde::Serialize, serde::Deserialize)] + pub struct #ident { + pub id: #id, + #field + } + + impl From<&#entity> for #ident { + fn from(entity: &#entity) -> Self { + Self { + id: entity.id.clone(), + #from_impl + } + } + } + }); + } +} + +pub struct ListByFn<'a> { + ignore_prefix: Option<&'a syn::LitStr>, + id: &'a syn::Ident, + entity: &'a syn::Ident, + column: &'a Column, + table_name: &'a str, + query_error: syn::Ident, + delete: DeleteOption, + cursor_mod: syn::Ident, + any_nested: bool, + post_hydrate_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> ListByFn<'a> { + pub fn new(column: &'a Column, opts: &'a RepositoryOptions) -> Self { + Self { + ignore_prefix: opts.table_prefix(), + column, + id: opts.id(), + entity: opts.entity(), + table_name: opts.table_name(), + query_error: opts.query_error(), + delete: opts.delete, + cursor_mod: opts.cursor_mod(), + any_nested: opts.any_nested(), + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } + + pub fn cursor(&'a self) -> CursorStruct<'a> { + CursorStruct { + column: self.column, + id: self.id, + entity: self.entity, + cursor_mod: &self.cursor_mod, + } + } +} + +impl ToTokens for ListByFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let column_name = self.column.name(); + let cursor = self.cursor(); + let cursor_ident = cursor.ident(); + let cursor_mod = cursor.cursor_mod(); + let query_error = &self.query_error; + let query_fn_generics = RepositoryOptions::query_fn_generics(self.any_nested); + let query_fn_op_arg = RepositoryOptions::query_fn_op_arg(self.any_nested); + let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested); + let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested); + + let destructure_tokens = self.cursor().destructure_tokens(); + 
let select_columns = cursor.select_columns(None); + let arg_tokens = cursor.query_arg_tokens(); + + for delete in [DeleteOption::No, DeleteOption::Soft] { + let fn_name = syn::Ident::new( + &format!( + "list_by_{}{}", + column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + let fn_in_op = syn::Ident::new( + &format!( + "list_by_{}{}_in_op", + column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + + let asc_query = format!( + r#"SELECT {} FROM {} WHERE ({}){} ORDER BY {} LIMIT $1"#, + select_columns, + self.table_name, + cursor.condition(0, true), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor.order_by(true), + ); + let desc_query = format!( + r#"SELECT {} FROM {} WHERE ({}){} ORDER BY {} LIMIT $1"#, + select_columns, + self.table_name, + cursor.condition(0, false), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor.order_by(false), + ); + + let es_query_asc_call = if let Some(prefix) = self.ignore_prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #asc_query, + #arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #asc_query, + #arg_tokens + ) + } + }; + + let es_query_desc_call = if let Some(prefix) = self.ignore_prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #desc_query, + #arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #desc_query, + #arg_tokens + ) + } + }; + + #[cfg(feature = "instrument")] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.list_by_{}", repo_name, column_name); + ( + quote! 
{ + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, first, has_cursor, direction = tracing::field::debug(&direction), count = tracing::field::Empty, has_next_page = tracing::field::Empty, ids = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + let has_cursor = cursor.after.is_some(); + }, + quote! { + tracing::Span::current().record("first", first); + tracing::Span::current().record("has_cursor", has_cursor); + }, + quote! { + let result_ids: Vec<_> = entities.iter().map(|e| &e.id).collect(); + tracing::Span::current().record("count", result_ids.len()); + tracing::Span::current().record("has_next_page", has_next_page); + tracing::Span::current().record("ids", tracing::field::debug(&result_ids)); + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = (quote! {}, quote! {}, quote! {}, quote! {}, quote! {}); + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + for __entity in &entities { + self.execute_post_hydrate_hook(__entity).map_err(#query_error::PostHydrateError)?; + } + } + } else { + quote! {} + }; + + tokens.append_all(quote! 
{ + pub async fn #fn_name( + &self, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #query_error> { + self.#fn_in_op(#query_fn_get_op, cursor, direction).await + } + + #instrument_attr + pub async fn #fn_in_op #query_fn_generics( + &self, + #query_fn_op_arg, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #query_error> + where + OP: #query_fn_op_traits + { + let __result: Result, #query_error> = async { + #extract_has_cursor + #destructure_tokens + #record_fields + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + #es_query_asc_call.fetch_n(op, first).await? + }, + es_entity::ListDirection::Descending => { + #es_query_desc_call.fetch_n(op, first).await? + }, + }; + + #post_hydrate_check + #record_results + + let end_cursor = entities.last().map(#cursor_mod::#cursor_ident::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + #error_recording + __result + } + }); + + if delete == self.delete || self.delete == DeleteOption::SoftWithoutQueries { + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn cursor_struct_by_id() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let by_column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let cursor = CursorStruct { + column: &by_column, + id: &id_type, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let mut tokens = TokenStream::new(); + cursor.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[derive(Debug, serde::Serialize, serde::Deserialize)] + pub struct EntitiesByIdCursor { + pub id: EntityId, + } + + impl From<&Entity> for EntitiesByIdCursor { + fn from(entity: &Entity) -> Self { + Self { + id: entity.id.clone(), + } + } + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn cursor_struct_by_created_at() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let by_column = Column::for_created_at(); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let cursor = CursorStruct { + column: &by_column, + id: &id_type, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let mut tokens = TokenStream::new(); + cursor.to_tokens(&mut tokens); + + let expected = quote! { + #[derive(Debug, serde::Serialize, serde::Deserialize)] + pub struct EntitiesByCreatedAtCursor { + pub id: EntityId, + pub created_at: es_entity::prelude::chrono::DateTime, + } + + impl From<&Entity> for EntitiesByCreatedAtCursor { + fn from(entity: &Entity) -> Self { + Self { + id: entity.id.clone(), + created_at: entity.events() + .entity_first_persisted_at() + .expect("entity not persisted") + .clone(), + } + } + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_by_fn() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListByFn { + ignore_prefix: None, + column: &column, + id: &id_type, + entity: &entity, + table_name: "entities", + query_error, + delete: DeleteOption::SoftWithoutQueries, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + 
let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn list_by_id( + &self, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> { + self.list_by_id_in_op(self.pool(), cursor, direction).await + } + + pub async fn list_by_id_in_op<'a, OP>( + &self, + op: OP, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let id = if let Some(after) = after { + Some(after.id) + } else { + None + }; + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE (COALESCE(id > $2, true)) AND deleted = FALSE ORDER BY id ASC LIMIT $1", + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Entity, + "SELECT id FROM entities WHERE (COALESCE(id < $2, true)) AND deleted = FALSE ORDER BY id DESC LIMIT $1", + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? 
+ }, + }; + + let end_cursor = entities.last().map(cursor_mod::EntitiesByIdCursor::from); + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_by_fn_with_soft_delete_include_deleted() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListByFn { + ignore_prefix: None, + column: &column, + id: &id_type, + entity: &entity, + table_name: "entities", + query_error, + delete: DeleteOption::Soft, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let token_str = tokens.to_string(); + assert!(token_str.contains("list_by_id_include_deleted")); + } + + #[test] + fn list_by_fn_name() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let column = Column::new( + syn::Ident::new("name", proc_macro2::Span::call_site()), + syn::parse_str("String").unwrap(), + ); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListByFn { + ignore_prefix: None, + column: &column, + id: &id_type, + entity: &entity, + table_name: "entities", + query_error, + delete: DeleteOption::No, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let 
expected = quote! { + pub async fn list_by_name( + &self, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> { + self.list_by_name_in_op(self.pool(), cursor, direction).await + } + + pub async fn list_by_name_in_op<'a, OP>( + &self, + op: OP, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let (id, name) = if let Some(after) = after { + (Some(after.id), Some(after.name)) + } else { + (None, None) + }; + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Entity, + "SELECT name, id FROM entities WHERE (COALESCE((name, id) > ($3, $2), $2 IS NULL)) ORDER BY name ASC, id ASC LIMIT $1", + (first + 1) as i64, + id as Option, + name as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Entity, + "SELECT name, id FROM entities WHERE (COALESCE((name, id) < ($3, $2), $2 IS NULL)) ORDER BY name DESC, id DESC LIMIT $1", + (first + 1) as i64, + id as Option, + name as Option, + ) + .fetch_n(op, first) + .await? 
+ }, + }; + + let end_cursor = entities.last().map(cursor_mod::EntitiesByNameCursor::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_by_fn_optional_column() { + let id_type = Ident::new("EntityId", Span::call_site()); + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let column = Column::new( + syn::Ident::new("value", proc_macro2::Span::call_site()), + syn::parse_str("Option").unwrap(), + ); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListByFn { + ignore_prefix: None, + column: &column, + id: &id_type, + entity: &entity, + table_name: "entities", + query_error, + delete: DeleteOption::No, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + pub async fn list_by_value( + &self, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> { + self.list_by_value_in_op(self.pool(), cursor, direction).await + } + + pub async fn list_by_value_in_op<'a, OP>( + &self, + op: OP, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let (id, value) = if let Some(after) = after { + (Some(after.id), after.value) + } else { + (None, None) + }; + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Entity, + "SELECT value, id FROM entities WHERE ((value IS $3) AND COALESCE(id > $2, true) OR COALESCE(value > $3, value IS NOT NULL)) ORDER BY value ASC NULLS FIRST, id ASC LIMIT $1", + (first + 1) as i64, + id as Option, + value as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Entity, + "SELECT value, id FROM entities WHERE ((value IS $3) AND COALESCE(id < $2, true) OR COALESCE(value < $3, value IS NOT NULL)) ORDER BY value DESC NULLS LAST, id DESC LIMIT $1", + (first + 1) as i64, + id as Option, + value as Option, + ) + .fetch_n(op, first) + .await? 
+ }, + }; + + let end_cursor = entities.last().map(cursor_mod::EntitiesByValueCursor::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/list_for_filters_fn.rs b/es-entity-macros-sqlite/src/repo/list_for_filters_fn.rs new file mode 100644 index 00000000..379dec43 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/list_for_filters_fn.rs @@ -0,0 +1,959 @@ +use convert_case::{Case, Casing}; +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +use super::{combo_cursor::ComboCursor, list_by_fn::CursorStruct, options::*}; + +pub struct FiltersStruct<'a> { + columns: Vec<&'a Column>, + entity: &'a syn::Ident, +} + +impl<'a> FiltersStruct<'a> { + pub fn new(opts: &'a RepositoryOptions, columns: Vec<&'a Column>) -> Self { + Self { + entity: opts.entity(), + columns, + } + } + + #[cfg(test)] + fn new_test(entity: &'a syn::Ident, columns: Vec<&'a Column>) -> Self { + Self { entity, columns } + } + + pub fn ident(&self) -> syn::Ident { + let entity_name = pluralizer::pluralize(&format!("{}", self.entity), 2, false); + syn::Ident::new( + &format!("{entity_name}_filters").to_case(Case::UpperCamel), + Span::call_site(), + ) + } + + fn fields(&self) -> TokenStream { + self.columns + .iter() + .map(|column| { + let name = column.name(); + let ty = column.ty(); + quote! 
{ + pub #name: Option<#ty>, + } + }) + .collect() + } + + fn where_clause_fragment(column: &Column, idx: u32) -> String { + let col_name = column.name(); + let param = format!("${idx}"); + format!("COALESCE({col_name} = {param}, {param} IS NULL)") + } + + fn filter_arg_tokens(column: &Column) -> TokenStream { + let name = syn::Ident::new(&format!("filter_{}", column.name()), Span::call_site()); + let ty = column.ty(); + if let syn::Type::Path(type_path) = ty + && type_path.path.is_ident("String") + { + quote! { + #name as Option, + } + } else { + quote! { + #name as Option<#ty>, + } + } + } +} + +impl ToTokens for FiltersStruct<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = self.ident(); + let fields = self.fields(); + + tokens.append_all(quote! { + #[derive(Debug, Default)] + pub struct #ident { + #fields + } + }); + } +} + +pub struct ListForFiltersFn<'a> { + pub filters_struct: FiltersStruct<'a>, + entity: &'a syn::Ident, + query_error: syn::Ident, + for_columns: Vec<&'a Column>, + by_columns: Vec<&'a Column>, + cursor: &'a ComboCursor<'a>, + delete: DeleteOption, + cursor_mod: syn::Ident, + table_name: &'a str, + ignore_prefix: Option<&'a syn::LitStr>, + id: &'a syn::Ident, + any_nested: bool, + post_hydrate_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> ListForFiltersFn<'a> { + pub fn new( + opts: &'a RepositoryOptions, + for_columns: Vec<&'a Column>, + by_columns: Vec<&'a Column>, + cursor: &'a ComboCursor<'a>, + ) -> Self { + Self { + filters_struct: FiltersStruct::new(opts, for_columns.clone()), + entity: opts.entity(), + query_error: opts.query_error(), + for_columns, + by_columns, + cursor, + delete: opts.delete, + cursor_mod: opts.cursor_mod(), + table_name: opts.table_name(), + ignore_prefix: opts.table_prefix(), + id: opts.id(), + any_nested: opts.any_nested(), + post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + 
repo_name_snake: opts.repo_name_snake_case(), + } + } + + fn generate_proxy_body(&self, by_col: &Column, delete: DeleteOption) -> TokenStream { + let by_col_name = by_col.name(); + let delete_postfix = delete.include_deletion_fn_postfix(); + + let list_by_fn = syn::Ident::new( + &format!("list_by_{}{}", by_col_name, delete_postfix), + Span::call_site(), + ); + + if self.for_columns.is_empty() { + return quote! { self.#list_by_fn(query, direction).await? }; + } + + let all_none_checks: Vec<_> = self + .for_columns + .iter() + .map(|c| { + let name = c.name(); + quote! { filters.#name.is_none() } + }) + .collect(); + + // Determine which for_columns have individual methods for this by_col. + let paired_for_columns: Vec<_> = self + .for_columns + .iter() + .filter(|fc| fc.list_for_by_columns().iter().any(|n| n == by_col_name)) + .collect(); + + let single_filter_branches: TokenStream = paired_for_columns + .iter() + .map(|for_col| { + let others_none: Vec<_> = self + .for_columns + .iter() + .filter(|c| c.name() != for_col.name()) + .map(|c| { + let name = c.name(); + quote! { filters.#name.is_none() } + }) + .collect(); + + let for_col_name = for_col.name(); + let fn_name = syn::Ident::new( + &format!( + "list_for_{}_by_{}{}", + for_col_name, by_col_name, delete_postfix + ), + Span::call_site(), + ); + + if others_none.is_empty() { + quote! { + else { + self.#fn_name(filters.#for_col_name.unwrap(), query, direction).await? + } + } + } else { + quote! { + else if #(#others_none)&&* { + self.#fn_name(filters.#for_col_name.unwrap(), query, direction).await? 
+ } + } + } + }) + .collect(); + + // Need a fallback when: + // - there are unpaired for_columns (they need COALESCE) + // - there are 2+ paired columns (multi-filter case) + // - there are 2+ for_columns total (multi-filter case) + let has_unpaired = paired_for_columns.len() < self.for_columns.len(); + let needs_fallback = has_unpaired || self.for_columns.len() >= 2; + let multi_filter_fallback = if needs_fallback { + let list_for_filters_fn = syn::Ident::new( + &format!("list_for_filters_by_{}{}", by_col_name, delete_postfix), + Span::call_site(), + ); + quote! { + else { + self.#list_for_filters_fn(filters, query, direction).await? + } + } + } else { + quote! {} + }; + + quote! { + if #(#all_none_checks)&&* { + self.#list_by_fn(query, direction).await? + } + #single_filter_branches + #multi_filter_fallback + } + } + + fn generate_by_fn(&self, by_column: &'a Column, delete: DeleteOption) -> TokenStream { + let entity = self.entity; + let error = &self.query_error; + let cursor_mod = &self.cursor_mod; + let query_fn_generics = RepositoryOptions::query_fn_generics(self.any_nested); + let query_fn_op_arg = RepositoryOptions::query_fn_op_arg(self.any_nested); + let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested); + let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested); + + let by_column_name = by_column.name(); + let cursor_struct = CursorStruct { + column: by_column, + id: self.id, + entity: self.entity, + cursor_mod: &self.cursor_mod, + }; + let cursor_ident = cursor_struct.ident(); + + let n_filters = self.for_columns.len() as u32; + + let destructure_tokens = cursor_struct.destructure_tokens(); + let select_columns = cursor_struct.select_columns(None); + let cursor_arg_tokens = cursor_struct.query_arg_tokens(); + + let fn_name = syn::Ident::new( + &format!( + "list_for_filters_by_{}{}", + by_column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + let fn_in_op = syn::Ident::new( + 
&format!( + "list_for_filters_by_{}{}_in_op", + by_column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + + let filters_ident = self.filters_struct.ident(); + + // Generate filter destructuring + let filter_field_names: Vec<_> = self + .for_columns + .iter() + .map(|c| { + let col_name = c.name(); + let filter_name = + syn::Ident::new(&format!("filter_{}", col_name), Span::call_site()); + (col_name.clone(), filter_name) + }) + .collect(); + + let destructure_filters: TokenStream = filter_field_names + .iter() + .map(|(col_name, filter_name)| { + quote! { + let #filter_name = filters.#col_name; + } + }) + .collect(); + + // Generate WHERE clause fragments + let where_fragments: Vec = self + .for_columns + .iter() + .enumerate() + .map(|(i, col)| FiltersStruct::where_clause_fragment(col, (i + 1) as u32)) + .collect(); + + let filter_where = if where_fragments.is_empty() { + String::new() + } else { + format!("{} AND ", where_fragments.join(" AND ")) + }; + + // Generate filter arg bindings for es_query! + let filter_arg_bindings: TokenStream = self + .for_columns + .iter() + .map(|col| FiltersStruct::filter_arg_tokens(col)) + .collect(); + + let asc_query = format!( + r#"SELECT {} FROM {} WHERE {}({}){} ORDER BY {} LIMIT ${}"#, + select_columns, + self.table_name, + filter_where, + cursor_struct.condition(n_filters, true), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor_struct.order_by(true), + n_filters + 1, + ); + let desc_query = format!( + r#"SELECT {} FROM {} WHERE {}({}){} ORDER BY {} LIMIT ${}"#, + select_columns, + self.table_name, + filter_where, + cursor_struct.condition(n_filters, false), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor_struct.order_by(false), + n_filters + 1, + ); + + let es_query_asc_call = if let Some(prefix) = self.ignore_prefix { + quote! 
{ + es_entity::es_query!( + tbl_prefix = #prefix, + #asc_query, + #filter_arg_bindings + #cursor_arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #asc_query, + #filter_arg_bindings + #cursor_arg_tokens + ) + } + }; + + let es_query_desc_call = if let Some(prefix) = self.ignore_prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #desc_query, + #filter_arg_bindings + #cursor_arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #desc_query, + #filter_arg_bindings + #cursor_arg_tokens + ) + } + }; + + #[cfg(feature = "instrument")] + let (instrument_attr, extract_has_cursor, record_fields, record_results, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.list_for_filters_by_{}", repo_name, by_column_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, filters = tracing::field::debug(&filters), first, has_cursor, direction = tracing::field::debug(&direction), count = tracing::field::Empty, has_next_page = tracing::field::Empty, ids = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + let has_cursor = cursor.after.is_some(); + }, + quote! { + tracing::Span::current().record("first", first); + tracing::Span::current().record("has_cursor", has_cursor); + }, + quote! { + let result_ids: Vec<_> = entities.iter().map(|e| &e.id).collect(); + tracing::Span::current().record("count", result_ids.len()); + tracing::Span::current().record("has_next_page", has_next_page); + tracing::Span::current().record("ids", tracing::field::debug(&result_ids)); + }, + quote! 
{ + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, extract_has_cursor, record_fields, record_results, error_recording) = + (quote! {}, quote! {}, quote! {}, quote! {}, quote! {}); + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + for __entity in &entities { + self.execute_post_hydrate_hook(__entity).map_err(#error::PostHydrateError)?; + } + } + } else { + quote! {} + }; + + quote! { + pub async fn #fn_name( + &self, + filters: #filters_ident, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #error> { + self.#fn_in_op(#query_fn_get_op, filters, cursor, direction).await + } + + #instrument_attr + pub async fn #fn_in_op #query_fn_generics( + &self, + #query_fn_op_arg, + filters: #filters_ident, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #error> + where + OP: #query_fn_op_traits + { + let __result: Result, #error> = async { + #extract_has_cursor + #destructure_filters + #destructure_tokens + #record_fields + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + #es_query_asc_call.fetch_n(op, first).await? + }, + es_entity::ListDirection::Descending => { + #es_query_desc_call.fetch_n(op, first).await? 
+ } + }; + + #post_hydrate_check + #record_results + + let end_cursor = entities.last().map(#cursor_mod::#cursor_ident::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + #error_recording + __result + } + } + } +} + +impl ToTokens for ListForFiltersFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let filters_name = self.filters_struct.ident(); + let sort_by_name = self.cursor.sort_by_name(); + let cursor_ident = self.cursor.ident(); + + let entity = self.entity; + let error = &self.query_error; + let cursor_mod = &self.cursor_mod; + + for delete in [DeleteOption::No, DeleteOption::Soft] { + // Generate per-sort-column functions + let by_fns: TokenStream = self + .by_columns + .iter() + .map(|by_col| self.generate_by_fn(by_col, delete)) + .collect(); + + tokens.append_all(by_fns); + + // Generate dispatch function + let dispatch_arms: TokenStream = self + .by_columns + .iter() + .map(|by_col| { + let by_variant = syn::Ident::new( + &format!("{}", by_col.name()).to_case(Case::UpperCamel), + Span::call_site(), + ); + let inner_cursor_ident = { + let entity_name = + pluralizer::pluralize(&format!("{}", self.entity), 2, false); + syn::Ident::new( + &format!("{}_by_{}_cursor", entity_name, by_col.name()) + .to_case(Case::UpperCamel), + Span::call_site(), + ) + }; + let proxy_body = self.generate_proxy_body(by_col, delete); + quote! 
{ + #sort_by_name::#by_variant => { + let after = after.map(#cursor_mod::#inner_cursor_ident::try_from).transpose()?; + let query = es_entity::PaginatedQueryArgs { first, after }; + + let es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + } = #proxy_body; + es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor: end_cursor.map(#cursor_mod::#cursor_ident::from) + } + } + } + }) + .collect(); + + let fn_name = syn::Ident::new( + &format!("list_for_filters{}", delete.include_deletion_fn_postfix()), + Span::call_site(), + ); + + #[cfg(feature = "instrument")] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = { + let entity_name = self.entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.list_for_filters", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, filters = tracing::field::debug(&filters), sort_by = tracing::field::debug(&sort.by), direction = tracing::field::debug(&sort.direction), first, has_cursor, count = tracing::field::Empty, has_next_page = tracing::field::Empty, ids = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + let has_cursor = cursor.after.is_some(); + }, + quote! { + tracing::Span::current().record("first", first); + tracing::Span::current().record("has_cursor", has_cursor); + }, + quote! { + let result_ids: Vec<_> = res.entities.iter().map(|e| &e.id).collect(); + tracing::Span::current().record("count", result_ids.len()); + tracing::Span::current().record("has_next_page", res.has_next_page); + tracing::Span::current().record("ids", tracing::field::debug(&result_ids)); + }, + quote! 
{ + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = (quote! {}, quote! {}, quote! {}, quote! {}, quote! {}); + + tokens.append_all(quote! { + #instrument_attr + pub async fn #fn_name( + &self, + filters: #filters_name, + sort: es_entity::Sort<#sort_by_name>, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + ) -> Result, #error> + { + let __result: Result, #error> = async { + #extract_has_cursor + let es_entity::Sort { by, direction } = sort; + let es_entity::PaginatedQueryArgs { first, after } = cursor; + #record_fields + + use #cursor_mod::#cursor_ident; + let res = match by { + #dispatch_arms + }; + + #record_results + + Ok(res) + }.await; + + #error_recording + __result + } + }); + + if delete == self.delete || self.delete == DeleteOption::SoftWithoutQueries { + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn filters_struct() { + let entity = Ident::new("Order", Span::call_site()); + let customer_id_column = Column::new( + syn::Ident::new("customer_id", proc_macro2::Span::call_site()), + syn::parse_str("CustomerId").unwrap(), + ); + let status_column = Column::new( + syn::Ident::new("status", proc_macro2::Span::call_site()), + syn::parse_str("OrderStatus").unwrap(), + ); + + let filters = FiltersStruct::new_test(&entity, vec![&customer_id_column, &status_column]); + + let mut tokens = TokenStream::new(); + filters.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[derive(Debug, Default)] + pub struct OrdersFilters { + pub customer_id: Option, + pub status: Option, + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_for_filters_function_generation() { + let entity = Ident::new("Order", Span::call_site()); + let query_error = syn::Ident::new("OrderQueryError", Span::call_site()); + let id = syn::Ident::new("OrderId", proc_macro2::Span::call_site()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let id_column = Column::for_id(syn::parse_str("OrderId").unwrap()); + let id_ident = syn::Ident::new("id", proc_macro2::Span::call_site()); + let customer_id_column = Column::new_list_for( + syn::Ident::new("customer_id", proc_macro2::Span::call_site()), + syn::parse_str("CustomerId").unwrap(), + vec![id_ident.clone()], + ); + let status_column = Column::new_list_for( + syn::Ident::new("status", proc_macro2::Span::call_site()), + syn::parse_str("OrderStatus").unwrap(), + vec![id_ident], + ); + + let for_columns = vec![&customer_id_column, &status_column]; + let by_columns = vec![&id_column]; + + let id_cursor = CursorStruct { + column: &id_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let combo_cursor = ComboCursor::new_test(&entity, vec![id_cursor]); + + let list_for_filters_fn = ListForFiltersFn { + filters_struct: FiltersStruct::new_test(&entity, for_columns.clone()), + entity: &entity, + query_error, + for_columns, + by_columns, + cursor: &combo_cursor, + delete: DeleteOption::No, + cursor_mod: cursor_mod.clone(), + table_name: "orders", + ignore_prefix: None, + id: &id, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + list_for_filters_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + pub async fn list_for_filters_by_id( + &self, + filters: OrdersFilters, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, OrderQueryError> { + self.list_for_filters_by_id_in_op(self.pool(), filters, cursor, direction).await + } + + pub async fn list_for_filters_by_id_in_op<'a, OP>( + &self, + op: OP, + filters: OrdersFilters, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, OrderQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, OrderQueryError> = async { + let filter_customer_id = filters.customer_id; + let filter_status = filters.status; + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let id = if let Some(after) = after { + Some(after.id) + } else { + None + }; + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Order, + "SELECT id FROM orders WHERE COALESCE(customer_id = $1, $1 IS NULL) AND COALESCE(status = $2, $2 IS NULL) AND (COALESCE(id > $4, true)) ORDER BY id ASC LIMIT $3", + filter_customer_id as Option, + filter_status as Option, + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Order, + "SELECT id FROM orders WHERE COALESCE(customer_id = $1, $1 IS NULL) AND COALESCE(status = $2, $2 IS NULL) AND (COALESCE(id < $4, true)) ORDER BY id DESC LIMIT $3", + filter_customer_id as Option, + filter_status as Option, + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? 
+ } + }; + + let end_cursor = entities.last().map(cursor_mod::OrdersByIdCursor::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + + pub async fn list_for_filters( + &self, + filters: OrdersFilters, + sort: es_entity::Sort, + cursor: es_entity::PaginatedQueryArgs, + ) -> Result, OrderQueryError> + { + let __result: Result, OrderQueryError> = async { + let es_entity::Sort { by, direction } = sort; + let es_entity::PaginatedQueryArgs { first, after } = cursor; + + use cursor_mod::OrdersCursor; + let res = match by { + OrdersSortBy::Id => { + let after = after.map(cursor_mod::OrdersByIdCursor::try_from).transpose()?; + let query = es_entity::PaginatedQueryArgs { first, after }; + + let es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + } = if filters.customer_id.is_none() && filters.status.is_none() { + self.list_by_id(query, direction).await? + } else if filters.status.is_none() { + self.list_for_customer_id_by_id(filters.customer_id.unwrap(), query, direction).await? + } else if filters.customer_id.is_none() { + self.list_for_status_by_id(filters.status.unwrap(), query, direction).await? + } else { + self.list_for_filters_by_id(filters, query, direction).await? 
+ }; + es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor: end_cursor.map(cursor_mod::OrdersCursor::from) + } + } + }; + + Ok(res) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_for_filters_bare_list_for_defaults_to_by_id() { + // Bare list_for defaults to by(id) only + let entity = Ident::new("Order", Span::call_site()); + let query_error = syn::Ident::new("OrderQueryError", Span::call_site()); + let id = syn::Ident::new("OrderId", proc_macro2::Span::call_site()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let id_column = Column::for_id(syn::parse_str("OrderId").unwrap()); + let id_ident = syn::Ident::new("id", proc_macro2::Span::call_site()); + let customer_id_column = Column::new_list_for( + syn::Ident::new("customer_id", proc_macro2::Span::call_site()), + syn::parse_str("CustomerId").unwrap(), + vec![id_ident.clone()], + ); + let status_column = Column::new_list_for( + syn::Ident::new("status", proc_macro2::Span::call_site()), + syn::parse_str("OrderStatus").unwrap(), + vec![id_ident], + ); + + let for_columns = vec![&customer_id_column, &status_column]; + let by_columns = vec![&id_column]; + + let id_cursor = CursorStruct { + column: &id_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let combo_cursor = ComboCursor::new_test(&entity, vec![id_cursor]); + + let list_for_filters_fn = ListForFiltersFn { + filters_struct: FiltersStruct::new_test(&entity, for_columns.clone()), + entity: &entity, + query_error, + for_columns, + by_columns, + cursor: &combo_cursor, + delete: DeleteOption::No, + cursor_mod: cursor_mod.clone(), + table_name: "orders", + ignore_prefix: None, + id: &id, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + list_for_filters_fn.to_tokens(&mut tokens); + + let token_str = 
tokens.to_string(); + + // Bare list_for defaults to by(id), so should dispatch to individual methods for id + assert!(token_str.contains("list_for_customer_id_by_id")); + assert!(token_str.contains("list_for_status_by_id")); + assert!(token_str.contains("list_for_filters_by_id")); + assert!(token_str.contains("list_by_id")); + } + + #[test] + fn list_for_filters_mixed_by_columns() { + // Test: customer_id has list_for(by(id)), status has list_for(by(created_at)) + // Only customer_id should dispatch to individual method for by_id sort + let entity = Ident::new("Order", Span::call_site()); + let query_error = syn::Ident::new("OrderQueryError", Span::call_site()); + let id = syn::Ident::new("OrderId", proc_macro2::Span::call_site()); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let id_column = Column::for_id(syn::parse_str("OrderId").unwrap()); + let id_ident = syn::Ident::new("id", proc_macro2::Span::call_site()); + let created_at_ident = syn::Ident::new("created_at", proc_macro2::Span::call_site()); + // customer_id has by(id) - gets individual method for id sort + let customer_id_column = Column::new_list_for( + syn::Ident::new("customer_id", proc_macro2::Span::call_site()), + syn::parse_str("CustomerId").unwrap(), + vec![id_ident], + ); + // status has by(created_at) - NOT paired with id sort + let status_column = Column::new_list_for( + syn::Ident::new("status", proc_macro2::Span::call_site()), + syn::parse_str("OrderStatus").unwrap(), + vec![created_at_ident], + ); + + let for_columns = vec![&customer_id_column, &status_column]; + let by_columns = vec![&id_column]; + + let id_cursor = CursorStruct { + column: &id_column, + id: &id, + entity: &entity, + cursor_mod: &cursor_mod, + }; + + let combo_cursor = ComboCursor::new_test(&entity, vec![id_cursor]); + + let list_for_filters_fn = ListForFiltersFn { + filters_struct: FiltersStruct::new_test(&entity, for_columns.clone()), + entity: &entity, + query_error, + for_columns, + by_columns, + 
cursor: &combo_cursor, + delete: DeleteOption::No, + cursor_mod: cursor_mod.clone(), + table_name: "orders", + ignore_prefix: None, + id: &id, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + list_for_filters_fn.to_tokens(&mut tokens); + + let token_str = tokens.to_string(); + + // customer_id has by(id), so dispatch should use list_for_customer_id_by_id + assert!(token_str.contains("list_for_customer_id_by_id")); + // status has by(created_at) not by(id), so no individual dispatch for id sort + assert!(!token_str.contains("list_for_status_by_id")); + // Should still have unified fallback + assert!(token_str.contains("list_for_filters_by_id")); + } +} diff --git a/es-entity-macros-sqlite/src/repo/list_for_fn.rs b/es-entity-macros-sqlite/src/repo/list_for_fn.rs new file mode 100644 index 00000000..14febfc7 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/list_for_fn.rs @@ -0,0 +1,491 @@ +use darling::ToTokens; +use proc_macro2::{Span, TokenStream}; +use quote::{TokenStreamExt, quote}; + +use super::{list_by_fn::CursorStruct, options::*}; + +pub struct ListForFn<'a> { + ignore_prefix: Option<&'a syn::LitStr>, + pub for_column: &'a Column, + pub by_column: &'a Column, + entity: &'a syn::Ident, + id: &'a syn::Ident, + table_name: &'a str, + query_error: syn::Ident, + delete: DeleteOption, + cursor_mod: syn::Ident, + any_nested: bool, + post_hydrate_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> ListForFn<'a> { + pub fn new(for_column: &'a Column, by_column: &'a Column, opts: &'a RepositoryOptions) -> Self { + Self { + ignore_prefix: opts.table_prefix(), + for_column, + by_column, + id: opts.id(), + entity: opts.entity(), + table_name: opts.table_name(), + query_error: opts.query_error(), + delete: opts.delete, + cursor_mod: opts.cursor_mod(), + any_nested: opts.any_nested(), + 
post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } + + pub fn cursor(&'a self) -> CursorStruct<'a> { + CursorStruct { + column: self.by_column, + id: self.id, + entity: self.entity, + cursor_mod: &self.cursor_mod, + } + } +} + +impl ToTokens for ListForFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let cursor = self.cursor(); + let cursor_ident = cursor.ident(); + let cursor_mod = cursor.cursor_mod(); + let error = &self.query_error; + let query_fn_generics = RepositoryOptions::query_fn_generics(self.any_nested); + let query_fn_op_arg = RepositoryOptions::query_fn_op_arg(self.any_nested); + let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested); + let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested); + + let by_column_name = self.by_column.name(); + + let for_column_name = self.for_column.name(); + let filter_arg_name = syn::Ident::new( + &format!("filter_{}", self.for_column.name()), + Span::call_site(), + ); + let (for_column_type, for_impl_expr, for_access_expr) = self.for_column.ty_for_find_by(); + + let destructure_tokens = self.cursor().destructure_tokens(); + let select_columns = cursor.select_columns(Some(for_column_name)); + let arg_tokens = cursor.query_arg_tokens(); + + for delete in [DeleteOption::No, DeleteOption::Soft] { + let fn_name = syn::Ident::new( + &format!( + "list_for_{}_by_{}{}", + for_column_name, + by_column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + let fn_in_op = syn::Ident::new( + &format!( + "list_for_{}_by_{}{}_in_op", + for_column_name, + by_column_name, + delete.include_deletion_fn_postfix() + ), + Span::call_site(), + ); + + let asc_query = format!( + r#"SELECT {} FROM {} WHERE (({} = $1) AND ({})){} ORDER BY {} LIMIT $2"#, + select_columns, + self.table_name, + for_column_name, + cursor.condition(1, 
true), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor.order_by(true) + ); + let desc_query = format!( + r#"SELECT {} FROM {} WHERE (({} = $1) AND ({})){} ORDER BY {} LIMIT $2"#, + select_columns, + self.table_name, + for_column_name, + cursor.condition(1, false), + if delete == DeleteOption::No { + self.delete.not_deleted_condition() + } else { + "" + }, + cursor.order_by(false) + ); + + let es_query_asc_call = if let Some(prefix) = self.ignore_prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #asc_query, + #filter_arg_name as &#for_column_type, + #arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #asc_query, + #filter_arg_name as &#for_column_type, + #arg_tokens + ) + } + }; + + let es_query_desc_call = if let Some(prefix) = self.ignore_prefix { + quote! { + es_entity::es_query!( + tbl_prefix = #prefix, + #desc_query, + #filter_arg_name as &#for_column_type, + #arg_tokens + ) + } + } else { + quote! { + es_entity::es_query!( + entity = #entity, + #desc_query, + #filter_arg_name as &#for_column_type, + #arg_tokens + ) + } + }; + + #[cfg(feature = "instrument")] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!( + "{}.list_for_{}_by_{}", + repo_name, for_column_name, by_column_name + ); + let filter_field_name = format!("query_{}", filter_arg_name); + let filter_field_ident = + syn::Ident::new(&filter_field_name, proc_macro2::Span::call_site()); + ( + quote! 
{ + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, #filter_field_ident = tracing::field::Empty, first, has_cursor, direction = tracing::field::debug(&direction), count = tracing::field::Empty, has_next_page = tracing::field::Empty, ids = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + let has_cursor = cursor.after.is_some(); + }, + quote! { + tracing::Span::current().record(#filter_field_name, tracing::field::debug(&#filter_arg_name)); + tracing::Span::current().record("first", first); + tracing::Span::current().record("has_cursor", has_cursor); + }, + quote! { + let result_ids: Vec<_> = entities.iter().map(|e| &e.id).collect(); + tracing::Span::current().record("count", result_ids.len()); + tracing::Span::current().record("has_next_page", has_next_page); + tracing::Span::current().record("ids", tracing::field::debug(&result_ids)); + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let ( + instrument_attr, + extract_has_cursor, + record_fields, + record_results, + error_recording, + ) = (quote! {}, quote! {}, quote! {}, quote! {}, quote! {}); + + let post_hydrate_check = if self.post_hydrate_error.is_some() { + quote! { + for __entity in &entities { + self.execute_post_hydrate_hook(__entity).map_err(#error::PostHydrateError)?; + } + } + } else { + quote! {} + }; + + tokens.append_all(quote! 
{ + pub async fn #fn_name( + &self, + #filter_arg_name: #for_impl_expr, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #error> { + self.#fn_in_op(#query_fn_get_op, #filter_arg_name, cursor, direction).await + } + + #instrument_attr + pub async fn #fn_in_op #query_fn_generics( + &self, + #query_fn_op_arg, + #filter_arg_name: #for_impl_expr, + cursor: es_entity::PaginatedQueryArgs<#cursor_mod::#cursor_ident>, + direction: es_entity::ListDirection, + ) -> Result, #error> + where + OP: #query_fn_op_traits + { + let __result: Result, #error> = async { + #extract_has_cursor + let #filter_arg_name = #filter_arg_name.#for_access_expr; + #destructure_tokens + #record_fields + + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + #es_query_asc_call.fetch_n(op, first).await? + }, + es_entity::ListDirection::Descending => { + #es_query_desc_call.fetch_n(op, first).await? + } + }; + + #post_hydrate_check + #record_results + + let end_cursor = entities.last().map(#cursor_mod::#cursor_ident::from); + + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + #error_recording + __result + } + }); + + if delete == self.delete || self.delete == DeleteOption::SoftWithoutQueries { + break; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn list_for_fn() { + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let id = syn::Ident::new("EntityId", proc_macro2::Span::call_site()); + let by_column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let for_column = Column::new( + syn::Ident::new("customer_id", proc_macro2::Span::call_site()), + syn::parse_str("Uuid").unwrap(), + ); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListForFn { + 
ignore_prefix: None, + entity: &entity, + id: &id, + for_column: &for_column, + by_column: &by_column, + table_name: "entities", + query_error, + delete: DeleteOption::No, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn list_for_customer_id_by_id( + &self, + filter_customer_id: impl std::borrow::Borrow, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> { + self.list_for_customer_id_by_id_in_op(self.pool(), filter_customer_id, cursor, direction).await + } + + pub async fn list_for_customer_id_by_id_in_op<'a, OP>( + &self, + op: OP, + filter_customer_id: impl std::borrow::Borrow, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let filter_customer_id = filter_customer_id.borrow(); + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let id = if let Some(after) = after { + Some(after.id) + } else { + None + }; + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Entity, + "SELECT customer_id, id FROM entities WHERE ((customer_id = $1) AND (COALESCE(id > $3, true))) ORDER BY id ASC LIMIT $2", + filter_customer_id as &Uuid, + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Entity, + "SELECT customer_id, id FROM entities WHERE ((customer_id = $1) AND (COALESCE(id < $3, true))) ORDER BY id DESC LIMIT $2", + filter_customer_id as &Uuid, + (first + 1) as i64, + id as Option, + ) + .fetch_n(op, first) + .await? 
+ } + }; + + let end_cursor = entities.last().map(cursor_mod::EntitiesByIdCursor::from); + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn list_same_column() { + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let id = syn::Ident::new("EntityId", proc_macro2::Span::call_site()); + let column = Column::new( + syn::Ident::new("email", proc_macro2::Span::call_site()), + syn::parse_str("String").unwrap(), + ); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListForFn { + ignore_prefix: None, + entity: &entity, + id: &id, + for_column: &column, + by_column: &column, + table_name: "entities", + query_error, + delete: DeleteOption::No, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + pub async fn list_for_email_by_email( + &self, + filter_email: impl std::convert::AsRef, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> { + self.list_for_email_by_email_in_op(self.pool(), filter_email, cursor, direction).await + } + + pub async fn list_for_email_by_email_in_op<'a, OP>( + &self, + op: OP, + filter_email: impl std::convert::AsRef, + cursor: es_entity::PaginatedQueryArgs, + direction: es_entity::ListDirection, + ) -> Result, EntityQueryError> + where + OP: es_entity::IntoOneTimeExecutor<'a> + { + let __result: Result, EntityQueryError> = async { + let filter_email = filter_email.as_ref(); + let es_entity::PaginatedQueryArgs { first, after } = cursor; + let (id, email) = if let Some(after) = after { + (Some(after.id), Some(after.email)) + } else { + (None, None) + }; + let (entities, has_next_page) = match direction { + es_entity::ListDirection::Ascending => { + es_entity::es_query!( + entity = Entity, + "SELECT email, id FROM entities WHERE ((email = $1) AND (COALESCE((email, id) > ($4, $3), $3 IS NULL))) ORDER BY email ASC, id ASC LIMIT $2", + filter_email as &str, + (first + 1) as i64, + id as Option, + email as Option, + ) + .fetch_n(op, first) + .await? + }, + es_entity::ListDirection::Descending => { + es_entity::es_query!( + entity = Entity, + "SELECT email, id FROM entities WHERE ((email = $1) AND (COALESCE((email, id) < ($4, $3), $3 IS NULL))) ORDER BY email DESC, id DESC LIMIT $2", + filter_email as &str, + (first + 1) as i64, + id as Option, + email as Option, + ) + .fetch_n(op, first) + .await? 
+ } + }; + + let end_cursor = entities.last().map(cursor_mod::EntitiesByEmailCursor::from); + Ok(es_entity::PaginatedQueryRet { + entities, + has_next_page, + end_cursor, + }) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/mod.rs b/es-entity-macros-sqlite/src/repo/mod.rs new file mode 100644 index 00000000..fe051d2b --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/mod.rs @@ -0,0 +1,271 @@ +mod begin; +mod combo_cursor; +mod create_all_fn; +mod create_fn; +mod delete_fn; +mod error_types; +mod find_all_fn; +mod find_by_fn; +mod list_by_fn; +mod list_for_filters_fn; +mod list_for_fn; +mod nested; +mod options; +mod persist_events_batch_fn; +mod persist_events_fn; +mod populate_nested; +mod post_hydrate_hook; +mod post_persist_hook; +mod update_all_fn; +mod update_fn; + +use darling::{FromDeriveInput, ToTokens}; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use options::RepositoryOptions; + +pub fn derive(ast: syn::DeriveInput) -> darling::Result<TokenStream> { + let opts = RepositoryOptions::from_derive_input(&ast)?; + opts.columns.validate_list_for_by_columns()?; + let repo = EsRepo::from(&opts); + Ok(quote!(#repo)) +} +pub struct EsRepo<'a> { + repo: &'a syn::Ident, + generics: &'a syn::Generics, + persist_events_fn: persist_events_fn::PersistEventsFn<'a>, + persist_events_batch_fn: persist_events_batch_fn::PersistEventsBatchFn<'a>, + update_fn: update_fn::UpdateFn<'a>, + update_all_fn: update_all_fn::UpdateAllFn<'a>, + create_fn: create_fn::CreateFn<'a>, + create_all_fn: create_all_fn::CreateAllFn<'a>, + delete_fn: delete_fn::DeleteFn<'a>, + find_by_fns: Vec<find_by_fn::FindByFn<'a>>, + find_all_fn: find_all_fn::FindAllFn<'a>, + post_hydrate_hook: post_hydrate_hook::PostHydrateHook<'a>, + post_persist_hook: post_persist_hook::PostPersistHook<'a>, + begin: begin::Begin<'a>, + list_by_fns: Vec<list_by_fn::ListByFn<'a>>, + list_for_fns: Vec<list_for_fn::ListForFn<'a>>, + nested_fns: Vec<syn::Ident>, + nested: Vec<nested::Nested<'a>>, + populate_nested: Option<populate_nested::PopulateNested<'a>>, 
+ error_types: error_types::ErrorTypes<'a>, + opts: &'a RepositoryOptions, +} + +impl<'a> From<&'a RepositoryOptions> for EsRepo<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + let find_by_fns = opts + .columns + .all_find_by() + .map(|c| find_by_fn::FindByFn::new(c, opts)) + .collect(); + let list_by_fns = opts + .columns + .all_list_by() + .map(|c| list_by_fn::ListByFn::new(c, opts)) + .collect(); + let list_for_fns = opts + .columns + .all_list_for() + .flat_map(|for_col| { + for_col + .list_for_by_columns() + .iter() + .filter_map(|by_name| { + opts.columns + .find_list_by(by_name) + .map(|by_col| list_for_fn::ListForFn::new(for_col, by_col, opts)) + }) + .collect::>() + }) + .collect(); + let populate_nested = opts + .columns + .parent() + .map(|c| populate_nested::PopulateNested::new(c, opts)); + let (nested_fns, nested): (Vec<_>, Vec<_>) = opts + .all_nested() + .map(|n| (n.find_nested_fn_name(), nested::Nested::new(n, opts))) + .unzip(); + + Self { + repo: &opts.ident, + generics: &opts.generics, + persist_events_fn: persist_events_fn::PersistEventsFn::from(opts), + persist_events_batch_fn: persist_events_batch_fn::PersistEventsBatchFn::from(opts), + update_fn: update_fn::UpdateFn::from(opts), + update_all_fn: update_all_fn::UpdateAllFn::from(opts), + create_fn: create_fn::CreateFn::from(opts), + create_all_fn: create_all_fn::CreateAllFn::from(opts), + delete_fn: delete_fn::DeleteFn::from(opts), + find_by_fns, + find_all_fn: find_all_fn::FindAllFn::from(opts), + post_hydrate_hook: post_hydrate_hook::PostHydrateHook::from(opts), + post_persist_hook: post_persist_hook::PostPersistHook::from(opts), + begin: begin::Begin::from(opts), + list_by_fns, + list_for_fns, + nested_fns, + nested, + populate_nested, + error_types: error_types::ErrorTypes::new(opts), + opts, + } + } +} + +impl ToTokens for EsRepo<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let repo = &self.repo; + let persist_events_fn = &self.persist_events_fn; + let 
persist_events_batch_fn = &self.persist_events_batch_fn; + let update_fn = &self.update_fn; + let update_all_fn = &self.update_all_fn; + let create_fn = &self.create_fn; + let create_all_fn = &self.create_all_fn; + let delete_fn = &self.delete_fn; + let find_by_fns = &self.find_by_fns; + let find_all_fn = &self.find_all_fn; + let post_hydrate_hook = &self.post_hydrate_hook; + let post_persist_hook = &self.post_persist_hook; + let begin = &self.begin; + let cursors = self.list_by_fns.iter().map(|l| l.cursor()); + let combo_cursor = combo_cursor::ComboCursor::new( + self.opts, + self.list_by_fns.iter().map(|l| l.cursor()).collect(), + ); + let sort_by = combo_cursor.sort_by(); + let list_for_filters = list_for_filters_fn::ListForFiltersFn::new( + self.opts, + self.opts.columns.all_list_for().collect(), + self.opts.columns.all_list_by().collect(), + &combo_cursor, + ); + let list_for_filters_struct = &list_for_filters.filters_struct; + #[cfg(feature = "graphql")] + let gql_combo_cursor = combo_cursor.gql_cursor(); + #[cfg(not(feature = "graphql"))] + let gql_combo_cursor = TokenStream::new(); + #[cfg(feature = "graphql")] + let gql_cursors: Vec<_> = self + .list_by_fns + .iter() + .map(|l| l.cursor().gql_cursor()) + .collect(); + #[cfg(not(feature = "graphql"))] + let gql_cursors: Vec = Vec::new(); + let list_by_fns = &self.list_by_fns; + let list_for_fns = &self.list_for_fns; + + let entity = self.opts.entity(); + let event = self.opts.event(); + let id = self.opts.id(); + + let cursor_mod = self.opts.cursor_mod(); + let types_mod = self.opts.repo_types_mod(); + + let nested_fns = &self.nested_fns; + let nested = &self.nested; + let populate_nested = &self.populate_nested; + + let pool_field = self.opts.pool_field(); + let es_query_flavor = if nested_fns.is_empty() { + quote! { + es_entity::EsQueryFlavorFlat + } + } else { + quote! 
{ es_entity::EsQueryFlavorNested } + }; + + let create_error = self.opts.create_error(); + let modify_error = self.opts.modify_error(); + let find_error = self.opts.find_error(); + let query_error = self.opts.query_error(); + let error_types = self.error_types.generate(); + let map_constraint_fn = self.error_types.generate_map_constraint_fn(); + + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + + tokens.append_all(quote! { + pub mod #cursor_mod { + use super::*; + + #(#cursors)* + #(#gql_cursors)* + + #combo_cursor + #gql_combo_cursor + } + + mod #types_mod { + + use super::*; + + #[allow(non_camel_case_types)] + pub(super) type Repo__Id = #id; + #[allow(non_camel_case_types)] + pub(super) type Repo__Event = #event; + #[allow(non_camel_case_types)] + pub(super) type Repo__Entity = #entity; + #[allow(non_camel_case_types)] + pub(super) type Repo__DbEvent = es_entity::GenericEvent<#id>; + } + + #error_types + + #list_for_filters_struct + #sort_by + + impl #impl_generics #repo #ty_generics #where_clause { + #[inline(always)] + pub fn pool(&self) -> &es_entity::db::Pool { + &self.#pool_field + } + + #map_constraint_fn + #begin + #post_hydrate_hook + #post_persist_hook + #persist_events_fn + #persist_events_batch_fn + #create_fn + #create_all_fn + #update_fn + #update_all_fn + #delete_fn + #(#find_by_fns)* + #find_all_fn + #list_for_filters + #(#list_by_fns)* + #(#list_for_fns)* + #(#nested)* + } + + #populate_nested + + impl #impl_generics es_entity::EsRepo for #repo #ty_generics #where_clause { + type Entity = #entity; + type CreateError = #create_error; + type ModifyError = #modify_error; + type FindError = #find_error; + type QueryError = #query_error; + type EsQueryFlavor = #es_query_flavor; + + #[inline(always)] + async fn load_all_nested_in_op( + op: &mut OP, entities: &mut [#entity] + ) -> Result<(), __EsErr> + where + OP: es_entity::AtomicOperation, + __EsErr: From + From + Send, + { + #(Self::#nested_fns::<_, _, __EsErr>(op, 
entities).await?;)* + Ok(()) + } + } + }); + } +} diff --git a/es-entity-macros-sqlite/src/repo/nested.rs b/es-entity-macros-sqlite/src/repo/nested.rs new file mode 100644 index 00000000..8dea346d --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/nested.rs @@ -0,0 +1,147 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::{RepoField, RepositoryOptions}; + +pub struct Nested<'a> { + field: &'a RepoField, + parent_modify_error: syn::Ident, +} + +impl<'a> Nested<'a> { + pub fn new(field: &'a RepoField, opts: &'a RepositoryOptions) -> Nested<'a> { + Nested { + field, + parent_modify_error: opts.modify_error(), + } + } +} + +impl ToTokens for Nested<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let parent_modify_error = &self.parent_modify_error; + let repo_field = self.field.ident(); + + let nested_repo_ty = &self.field.ty; + let create_fn_name = self.field.create_nested_fn_name(); + let update_fn_name = self.field.update_nested_fn_name(); + let find_fn_name = self.field.find_nested_fn_name(); + + tokens.append_all(quote! 
{ + async fn #create_fn_name(&self, op: &mut OP, entity: &mut P) -> Result<(), <#nested_repo_ty as es_entity::EsRepo>::CreateError> + where + P: es_entity::Parent<<#nested_repo_ty as EsRepo>::Entity>, + OP: es_entity::AtomicOperation + { + let new_children = entity.new_children_mut(); + if new_children.is_empty() { + return Ok(()); + } + + let new_children = new_children.drain(..).collect(); + let children = self.#repo_field.create_all_in_op(op, new_children).await?; + entity.inject_children(children); + Ok(()) + } + + async fn #update_fn_name(&self, op: &mut OP, entity: &mut P) -> Result<(), #parent_modify_error> + where + P: es_entity::Parent<<#nested_repo_ty as EsRepo>::Entity>, + OP: es_entity::AtomicOperation + { + for entity in entity.iter_persisted_children_mut() { + self.#repo_field.update_in_op(op, entity).await?; + } + self.#create_fn_name(op, entity).await?; + Ok(()) + } + + async fn #find_fn_name(op: &mut OP, entities: &mut [P]) -> Result<(), __EsErr> + where + OP: es_entity::AtomicOperation, + P: es_entity::Parent<<#nested_repo_ty as es_entity::EsRepo>::Entity> + es_entity::EsEntity, + #nested_repo_ty: es_entity::PopulateNested<<
<P as es_entity::EsEntity>
::Event as es_entity::EsEvent>::EntityId>, + __EsErr: From + From + Send, + { + let lookup = entities.iter_mut().map(|e| (e.events().entity_id.clone(), e)).collect(); + <#nested_repo_ty>::populate_in_op::<_, _, __EsErr>(op, lookup).await?; + Ok(()) + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::{Ident, parse_quote}; + + #[test] + fn nested() { + let field = RepoField { + ident: Some(Ident::new("users", Span::call_site())), + ty: parse_quote! { UserRepo }, + nested: true, + pool: false, + clock: false, + entity: None, + }; + + let cursor = Nested { + field: &field, + parent_modify_error: syn::Ident::new( + "ParentModifyError", + proc_macro2::Span::call_site(), + ), + }; + + let mut tokens = TokenStream::new(); + cursor.to_tokens(&mut tokens); + + let expected = quote! { + async fn create_nested_users_in_op(&self, op: &mut OP, entity: &mut P) -> Result<(), ::CreateError> + where + P: es_entity::Parent<::Entity>, + OP: es_entity::AtomicOperation + { + let new_children = entity.new_children_mut(); + if new_children.is_empty() { + return Ok(()); + } + + let new_children = new_children.drain(..).collect(); + let children = self.users.create_all_in_op(op, new_children).await?; + entity.inject_children(children); + Ok(()) + } + + async fn update_nested_users_in_op(&self, op: &mut OP, entity: &mut P) -> Result<(), ParentModifyError> + where + P: es_entity::Parent<::Entity>, + OP: es_entity::AtomicOperation + { + for entity in entity.iter_persisted_children_mut() { + self.users.update_in_op(op, entity).await?; + } + self.create_nested_users_in_op(op, entity).await?; + Ok(()) + } + + async fn find_nested_users_in_op(op: &mut OP, entities: &mut [P]) -> Result<(), __EsErr> + where + OP: es_entity::AtomicOperation, + P: es_entity::Parent<::Entity> + es_entity::EsEntity, + UserRepo: es_entity::PopulateNested<<
<P as es_entity::EsEntity>
::Event as es_entity::EsEvent>::EntityId>, + __EsErr: From + From + Send, + { + let lookup = entities.iter_mut().map(|e| (e.events().entity_id.clone(), e)).collect(); + ::populate_in_op::<_, _, __EsErr>(op, lookup).await?; + Ok(()) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/options/columns.rs b/es-entity-macros-sqlite/src/repo/options/columns.rs new file mode 100644 index 00000000..ad381f56 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/options/columns.rs @@ -0,0 +1,695 @@ +use darling::FromMeta; +use quote::quote; + +#[derive(Default)] +pub struct Columns { + all: Vec, +} + +impl Columns { + #[cfg(test)] + pub fn new(id: &syn::Ident, columns: impl IntoIterator) -> Self { + let all = columns.into_iter().collect(); + let mut res = Columns { all }; + res.set_id_column(id); + res + } + + pub fn set_id_column(&mut self, ty: &syn::Ident) { + let mut all = vec![ + Column::for_created_at(), + Column::for_id(syn::parse_str(&ty.to_string()).unwrap()), + ]; + all.append(&mut self.all); + self.all = all; + } + + pub fn all_find_by(&self) -> impl Iterator { + self.all.iter().filter(|c| c.opts.find_by()) + } + + pub fn all_list_by(&self) -> impl Iterator { + self.all.iter().filter(|c| c.opts.list_by()) + } + + pub fn all_list_for(&self) -> impl Iterator { + self.all.iter().filter(|c| c.opts.list_for()) + } + + pub fn find_list_by(&self, name: &syn::Ident) -> Option<&Column> { + self.all + .iter() + .find(|c| c.name() == name && c.opts.list_by()) + } + + pub fn validate_list_for_by_columns(&self) -> darling::Result<()> { + let mut errors = darling::Error::accumulator(); + for col in self.all.iter().filter(|c| c.opts.list_for()) { + for by_name in col.list_for_by_columns() { + if self.find_list_by(by_name).is_none() { + let available: Vec<_> = + self.all_list_by().map(|c| c.name().to_string()).collect(); + errors.push(darling::Error::custom(format!( + "column '{}' in list_for(by(...)) on '{}' is 
not a list_by column. Available list_by columns: {}", + by_name, + col.name(), + available.join(", "), + ))); + } + } + } + errors.finish() + } + + /// Returns columns for the Column enum (id + user columns, not created_at) + pub fn column_enum_columns(&self) -> impl Iterator { + self.all.iter().filter(|c| *c.name() != "created_at") + } + + pub fn parent(&self) -> Option<&Column> { + self.all.iter().find(|c| c.opts.parent_opts.is_some()) + } + + pub fn updates_needed(&self) -> bool { + self.all.iter().any(|c| c.opts.persist_on_update()) + } + + pub fn variable_assignments_for_update(&self, ident: syn::Ident) -> proc_macro2::TokenStream { + let assignments = self.all.iter().filter_map(|c| { + if c.opts.persist_on_update() || c.opts.is_id { + Some(c.variable_assignment_for_update(&ident)) + } else { + None + } + }); + quote! { + #(#assignments)* + } + } + + pub fn variable_assignments_for_create(&self, ident: syn::Ident) -> proc_macro2::TokenStream { + let assignments = self.all.iter().filter_map(|c| { + if c.opts.persist_on_create() { + Some(c.variable_assignment_for_create(&ident)) + } else { + None + } + }); + quote! { + #(#assignments)* + } + } + + pub fn create_query_args(&self) -> Vec { + self.all + .iter() + .filter(|c| c.opts.persist_on_create()) + .map(|column| { + let ident = &column.name; + quote! 
{ + .bind(#ident) + } + }) + .collect() + } + + pub fn insert_column_names(&self) -> Vec { + self.all + .iter() + .filter_map(|c| { + if c.opts.persist_on_create() { + Some(c.name.to_string()) + } else { + None + } + }) + .collect() + } + + pub fn insert_placeholders(&self, offset: usize) -> String { + let count = self + .all + .iter() + .filter(|c| c.opts.persist_on_create()) + .count(); + ((1 + offset)..=(count + offset)) + .map(|i| format!("?{i}")) + .collect::>() + .join(", ") + } + + pub fn sql_updates(&self) -> String { + self.all + .iter() + .skip(1) + .filter(|c| c.opts.persist_on_update()) + .enumerate() + .map(|(idx, column)| format!("{} = ?{}", column.name, idx + 2)) + .collect::>() + .join(", ") + } + + pub fn update_query_args(&self) -> Vec { + self.all + .iter() + .filter(|c| c.opts.persist_on_update() || c.opts.is_id) + .map(|column| { + let ident = &column.name; + quote! { + .bind(#ident) + } + }) + .collect() + } +} + +impl FromMeta for Columns { + fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result { + let all = items + .iter() + .map(Column::from_nested_meta) + .collect::, _>>()?; + Ok(Columns { all }) + } +} + +#[derive(PartialEq)] +pub struct Column { + name: syn::Ident, + opts: ColumnOpts, +} + +impl FromMeta for Column { + fn from_nested_meta(item: &darling::ast::NestedMeta) -> darling::Result { + match item { + darling::ast::NestedMeta::Meta( + meta @ syn::Meta::NameValue(syn::MetaNameValue { + value: + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }), + .. 
+ }), + ) => { + let name = meta.path().get_ident().cloned().ok_or_else(|| { + darling::Error::custom("Expected identifier").with_span(meta.path()) + })?; + Ok(Column::new(name, syn::parse_str(&lit_str.value())?)) + } + darling::ast::NestedMeta::Meta(meta @ syn::Meta::List(_)) => { + let name = meta.path().get_ident().cloned().ok_or_else(|| { + darling::Error::custom("Expected identifier").with_span(meta.path()) + })?; + let column = Column { + name, + opts: ColumnOpts::from_meta(meta)?, + }; + Ok(column) + } + _ => Err( + darling::Error::custom("Expected name-value pair or attribute list") + .with_span(item), + ), + } + } +} + +impl Column { + pub fn new(name: syn::Ident, ty: syn::Type) -> Self { + Column { + name, + opts: ColumnOpts::new(ty), + } + } + + #[cfg(test)] + pub fn new_list_for(name: syn::Ident, ty: syn::Type, by_columns: Vec) -> Self { + Column { + name, + opts: ColumnOpts { + list_for_opts: Some(ListForOpts { by_columns }), + ..ColumnOpts::new(ty) + }, + } + } + + pub fn for_id(ty: syn::Type) -> Self { + Column { + name: syn::Ident::new("id", proc_macro2::Span::call_site()), + opts: ColumnOpts { + ty, + is_id: true, + list_by: Some(true), + find_by: Some(true), + list_for_opts: None, + parent_opts: None, + create_opts: Some(CreateOpts { + persist: Some(true), + accessor: None, + }), + update_opts: Some(UpdateOpts { + persist: Some(false), + accessor: None, + }), + constraint: None, + }, + } + } + + pub fn for_created_at() -> Self { + Column { + name: syn::Ident::new("created_at", proc_macro2::Span::call_site()), + opts: ColumnOpts { + ty: syn::parse_quote!( + es_entity::prelude::chrono::DateTime + ), + is_id: false, + list_by: Some(true), + find_by: Some(false), + list_for_opts: None, + parent_opts: None, + create_opts: Some(CreateOpts { + persist: Some(false), + accessor: None, + }), + update_opts: Some(UpdateOpts { + persist: Some(false), + accessor: Some(syn::parse_quote!( + events() + .entity_first_persisted_at() + .expect("entity not persisted") 
+ )), + }), + constraint: None, + }, + } + } + + pub fn list_for_by_columns(&self) -> &[syn::Ident] { + self.opts.list_for_by_columns() + } + + pub fn custom_constraint(&self) -> Option<&str> { + self.opts.constraint.as_deref() + } + + pub fn is_id(&self) -> bool { + self.opts.is_id + } + + pub fn is_optional(&self) -> bool { + if let syn::Type::Path(type_path) = self.ty() + && type_path.path.segments.len() == 1 + { + let segment = &type_path.path.segments[0]; + if segment.ident == "Option" { + return true; + } + } + false + } + + pub fn name(&self) -> &syn::Ident { + &self.name + } + + pub fn ty(&self) -> &syn::Type { + &self.opts.ty + } + + pub fn ty_for_find_by( + &self, + ) -> ( + syn::Type, + proc_macro2::TokenStream, + proc_macro2::TokenStream, + ) { + if let syn::Type::Path(type_path) = self.ty() + && type_path.path.is_ident("String") + { + ( + syn::parse_quote! { str }, + quote! { impl std::convert::AsRef }, + quote! { as_ref() }, + ) + } else { + let ty = &self.ty(); + ( + self.ty().clone(), + quote! { impl std::borrow::Borrow<#ty> }, + quote! { borrow() }, + ) + } + } + + pub fn accessor(&self) -> proc_macro2::TokenStream { + self.opts.update_accessor(&self.name) + } + + pub fn parent_accessor(&self) -> proc_macro2::TokenStream { + self.opts.parent_accessor(&self.name) + } + + fn variable_assignment_for_create(&self, ident: &syn::Ident) -> proc_macro2::TokenStream { + let name = &self.name; + let accessor = self.opts.create_accessor(name); + quote! { + let #name = &#ident.#accessor; + } + } + + fn variable_assignment_for_update(&self, ident: &syn::Ident) -> proc_macro2::TokenStream { + let name = &self.name; + let accessor = self.opts.update_accessor(name); + quote! 
{ + let #name = &#ident.#accessor; + } + } +} + +#[derive(PartialEq, FromMeta)] +struct ColumnOpts { + ty: syn::Type, + #[darling(default, skip)] + is_id: bool, + #[darling(default)] + find_by: Option<bool>, + #[darling(default)] + list_by: Option<bool>, + #[darling(default, rename = "list_for")] + list_for_opts: Option<ListForOpts>, + #[darling(default, rename = "parent")] + parent_opts: Option<ParentOpts>, + #[darling(default, rename = "create")] + create_opts: Option<CreateOpts>, + #[darling(default, rename = "update")] + update_opts: Option<UpdateOpts>, + #[darling(default)] + constraint: Option<String>, +} + +impl ColumnOpts { + fn new(ty: syn::Type) -> Self { + ColumnOpts { + ty, + is_id: false, + find_by: None, + list_by: None, + list_for_opts: None, + parent_opts: None, + create_opts: None, + update_opts: None, + constraint: None, + } + } + + fn find_by(&self) -> bool { + self.find_by.unwrap_or(true) + } + + fn list_by(&self) -> bool { + self.list_by.unwrap_or(false) + } + + fn list_for(&self) -> bool { + self.list_for_opts.is_some() + } + + fn list_for_by_columns(&self) -> &[syn::Ident] { + self.list_for_opts + .as_ref() + .map(|o| o.by_columns.as_slice()) + .unwrap_or(&[]) + } + + fn persist_on_create(&self) -> bool { + self.create_opts + .as_ref() + .is_none_or(|o| o.persist.unwrap_or(true)) + } + + fn create_accessor(&self, name: &syn::Ident) -> proc_macro2::TokenStream { + if let Some(accessor) = &self.create_opts.as_ref().and_then(|o| o.accessor.as_ref()) { + quote! { + #accessor + } + } else { + quote! { + #name + } + } + } + + fn persist_on_update(&self) -> bool { + self.update_opts + .as_ref() + .is_none_or(|o| o.persist.unwrap_or(true)) + } + + fn update_accessor(&self, name: &syn::Ident) -> proc_macro2::TokenStream { + if let Some(accessor) = &self.update_opts.as_ref().and_then(|o| o.accessor.as_ref()) { + quote! { + #accessor + } + } else { + quote!
{ + #name + } + } + } + + fn parent_accessor(&self, name: &syn::Ident) -> proc_macro2::TokenStream { + if let Some(accessor) = &self.parent_opts.as_ref().and_then(|o| o.accessor.as_ref()) { + quote! { + #accessor + } + } else { + self.update_accessor(name) + } + } +} + +#[derive(Default, PartialEq, FromMeta)] +struct CreateOpts { + persist: Option, + accessor: Option, +} + +#[derive(Default, PartialEq, FromMeta)] +struct UpdateOpts { + persist: Option, + accessor: Option, +} + +#[derive(PartialEq, Debug, Default)] +struct ListForOpts { + by_columns: Vec, +} + +impl FromMeta for ListForOpts { + fn from_word() -> darling::Result { + Ok(ListForOpts { + by_columns: vec![syn::Ident::new("id", proc_macro2::Span::call_site())], + }) + } + + fn from_bool(value: bool) -> darling::Result { + if value { + Self::from_word() + } else { + Err(darling::Error::custom( + "list_for = false is not supported; remove list_for entirely to disable", + )) + } + } + + fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result { + let mut by_columns = Vec::new(); + for item in items { + match item { + darling::ast::NestedMeta::Meta(syn::Meta::List(list)) + if list.path.is_ident("by") => + { + let inner: syn::punctuated::Punctuated = + list.parse_args_with(syn::punctuated::Punctuated::parse_terminated)?; + by_columns.extend(inner); + } + _ => { + return Err( + darling::Error::custom("Expected `by(col1, col2, ...)`").with_span(item) + ); + } + } + } + Ok(ListForOpts { by_columns }) + } +} + +#[derive(PartialEq, Debug, Default)] +struct ParentOpts { + accessor: Option, +} + +impl FromMeta for ParentOpts { + fn from_word() -> darling::Result { + Ok(ParentOpts::default()) + } + + fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result { + #[derive(FromMeta)] + struct Inner { + #[darling(default)] + accessor: Option, + } + + let inner = Inner::from_list(items)?; + Ok(ParentOpts { + accessor: inner.accessor, + }) + } +} + +#[cfg(test)] +mod tests { + use darling::FromMeta; + 
use syn::parse_quote; + + use super::*; + + #[test] + fn column_opts_from_list() { + let input: syn::Meta = parse_quote!(thing( + ty = "crate::module::Thing", + list_by = false, + create(persist = true, accessor = accessor_fn()), + )); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert_eq!(values.ty, parse_quote!(crate::module::Thing)); + assert!(!values.list_by()); + assert!(values.find_by()); + // assert!(values.update()); + assert_eq!( + values.create_opts.unwrap().accessor.unwrap(), + parse_quote!(accessor_fn()) + ); + } + + #[test] + fn columns_from_list() { + let input: syn::Meta = parse_quote!(columns( + name = "String", + email( + ty = "String", + list_by = false, + create(accessor = "email()"), + update(persist = false) + ) + )); + let columns = Columns::from_meta(&input).expect("Failed to parse Fields"); + assert_eq!(columns.all.len(), 2); + + assert_eq!(columns.all[0].name.to_string(), "name"); + + assert_eq!(columns.all[1].name.to_string(), "email"); + assert!(!columns.all[1].opts.list_by()); + assert_eq!( + columns.all[1] + .opts + .create_accessor(&parse_quote!(email)) + .to_string(), + quote!(email()).to_string() + ); + assert!(!columns.all[1].opts.persist_on_update()); + } + + #[test] + fn parent_opts_from_list() { + let input: syn::Meta = parse_quote!(thing(ty = "String", parent)); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert_eq!(values.ty, parse_quote!(String)); + assert!(values.parent_opts.is_some()); + + let input: syn::Meta = parse_quote!(thing(ty = "String", parent(accessor = "parent_id()"))); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert_eq!(values.ty, parse_quote!(String)); + assert!(values.parent_opts.is_some()); + assert_eq!( + values.parent_accessor(&parse_quote!(thing)).to_string(), + quote!(parent_id()).to_string() + ); + } + + #[test] + fn list_for_bare_word() { + let input: syn::Meta = parse_quote!(thing(ty = 
"String", list_for)); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert!(values.list_for()); + assert_eq!(values.list_for_by_columns().len(), 1); + assert_eq!(values.list_for_by_columns()[0].to_string(), "id"); + } + + #[test] + fn list_for_with_by_columns() { + let input: syn::Meta = parse_quote!(thing(ty = "String", list_for(by(created_at)))); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert!(values.list_for()); + assert_eq!(values.list_for_by_columns().len(), 1); + assert_eq!(values.list_for_by_columns()[0].to_string(), "created_at"); + } + + #[test] + fn list_for_by_column_must_be_list_by() { + let id_ident: syn::Ident = parse_quote!(TestId); + let col = Column::new_list_for( + parse_quote!(status), + syn::parse_str("String").unwrap(), + vec![parse_quote!(nonexistent)], + ); + let columns = Columns::new(&id_ident, vec![col]); + let result = columns.validate_list_for_by_columns(); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("nonexistent"), + "error should mention the invalid column name: {err}" + ); + assert!( + err.contains("list_by"), + "error should mention list_by: {err}" + ); + } + + #[test] + fn list_for_by_valid_column_passes_validation() { + let id_ident: syn::Ident = parse_quote!(TestId); + let col = Column::new_list_for( + parse_quote!(status), + syn::parse_str("String").unwrap(), + vec![parse_quote!(id)], + ); + let columns = Columns::new(&id_ident, vec![col]); + let result = columns.validate_list_for_by_columns(); + assert!(result.is_ok()); + } + + #[test] + fn list_for_with_multiple_by_columns() { + let input: syn::Meta = parse_quote!(thing(ty = "String", list_for(by(created_at, id)))); + let values = ColumnOpts::from_meta(&input).expect("Failed to parse Field"); + assert!(values.list_for()); + assert_eq!(values.list_for_by_columns().len(), 2); + assert_eq!(values.list_for_by_columns()[0].to_string(), "created_at"); + 
assert_eq!(values.list_for_by_columns()[1].to_string(), "id"); + } + + #[test] + fn custom_constraint() { + let input: syn::Meta = + parse_quote!(job_type(ty = "String", constraint = "idx_unique_job_type")); + let column = Column::from_nested_meta(&darling::ast::NestedMeta::Meta(input)) + .expect("Failed to parse Column"); + assert_eq!(column.name().to_string(), "job_type"); + assert_eq!(column.custom_constraint(), Some("idx_unique_job_type")); + } +} diff --git a/es-entity-macros-sqlite/src/repo/options/delete.rs b/es-entity-macros-sqlite/src/repo/options/delete.rs new file mode 100644 index 00000000..7ef88d9f --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/options/delete.rs @@ -0,0 +1,42 @@ +use darling::FromMeta; + +#[derive(Debug, Default, Clone, Copy, FromMeta, PartialEq)] +pub enum DeleteOption { + #[default] + No, + Soft, + SoftWithoutQueries, +} + +impl DeleteOption { + pub fn include_deletion_fn_postfix(&self) -> &'static str { + match self { + DeleteOption::Soft | DeleteOption::SoftWithoutQueries => "_include_deleted", + DeleteOption::No => "", + } + } + + pub fn not_deleted_condition(&self) -> &'static str { + match self { + DeleteOption::Soft | DeleteOption::SoftWithoutQueries => " AND deleted = FALSE", + DeleteOption::No => "", + } + } + + pub fn is_soft(&self) -> bool { + matches!(self, DeleteOption::Soft | DeleteOption::SoftWithoutQueries) + } +} + +impl std::str::FromStr for DeleteOption { + type Err = darling::Error; + + fn from_str(s: &str) -> Result { + match s { + "no" => Ok(DeleteOption::No), + "soft" => Ok(DeleteOption::Soft), + "soft_without_queries" => Ok(DeleteOption::SoftWithoutQueries), + _ => Err(darling::Error::unknown_value(s)), + } + } +} diff --git a/es-entity-macros-sqlite/src/repo/options/mod.rs b/es-entity-macros-sqlite/src/repo/options/mod.rs new file mode 100644 index 00000000..8f0670ec --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/options/mod.rs @@ -0,0 +1,447 @@ +mod columns; +mod delete; + +use 
convert_case::{Case, Casing}; +use darling::{FromDeriveInput, FromField, FromMeta}; +use proc_macro2::Span; +use quote::quote; + +pub use columns::*; +pub use delete::*; + +#[derive(Debug, Clone)] +pub struct PostPersistHookConfig { + pub method: syn::Ident, + pub error: syn::Type, +} + +impl FromMeta for PostPersistHookConfig { + /// Old syntax: `post_persist_hook = "method_name"` → defaults error to `sqlx::Error` + fn from_string(value: &str) -> darling::Result { + Ok(PostPersistHookConfig { + method: syn::Ident::new(value, Span::call_site()), + error: syn::parse_str("sqlx::Error") + .map_err(|e| darling::Error::custom(format!("invalid error type: {e}")))?, + }) + } + + /// New syntax: `post_persist_hook(method = "...", error = "...")` + /// `error` defaults to `sqlx::Error` if omitted + fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result { + let mut method: Option = None; + let mut error: Option = None; + + for item in items { + if let darling::ast::NestedMeta::Meta(syn::Meta::NameValue(nv)) = item { + if nv.path.is_ident("method") + && let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. + }) = &nv.value + { + method = Some(syn::Ident::new(&s.value(), s.span())); + } else if nv.path.is_ident("error") + && let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. 
+ }) = &nv.value + { + error = + Some(syn::parse_str(&s.value()).map_err(|e| { + darling::Error::custom(format!("invalid error type: {e}")) + })?); + } + } + } + + let error = error + .unwrap_or_else(|| syn::parse_str("sqlx::Error").expect("sqlx::Error is a valid type")); + + Ok(PostPersistHookConfig { + method: method + .ok_or_else(|| darling::Error::custom("missing `method` in post_persist_hook"))?, + error, + }) + } +} + +#[derive(Debug, Clone)] +pub struct PostHydrateHookConfig { + pub method: syn::Ident, + pub error: syn::Type, +} + +impl FromMeta for PostHydrateHookConfig { + fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result { + let mut method: Option = None; + let mut error: Option = None; + + for item in items { + if let darling::ast::NestedMeta::Meta(syn::Meta::NameValue(nv)) = item { + if nv.path.is_ident("method") { + if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. + }) = &nv.value + { + method = Some(syn::Ident::new(&s.value(), s.span())); + } + } else if nv.path.is_ident("error") + && let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(s), + .. 
+ }) = &nv.value + { + error = + Some(syn::parse_str(&s.value()).map_err(|e| { + darling::Error::custom(format!("invalid error type: {e}")) + })?); + } + } + } + + Ok(PostHydrateHookConfig { + method: method + .ok_or_else(|| darling::Error::custom("missing `method` in post_hydrate_hook"))?, + error: error + .ok_or_else(|| darling::Error::custom("missing `error` in post_hydrate_hook"))?, + }) + } +} + +/// Information about the clock field in a repository +#[derive(Debug, Clone)] +pub enum ClockFieldInfo<'a> { + /// No clock field present + None, + /// Clock field is `Option` - use if Some, fallback to global + Optional(&'a syn::Ident), + /// Clock field is `ClockHandle` - always use it + Required(&'a syn::Ident), +} + +#[derive(FromField)] +#[darling(attributes(es_repo))] +pub struct RepoField { + pub ident: Option, + pub ty: syn::Type, + #[darling(default)] + pub pool: bool, + #[darling(default)] + pub clock: bool, + #[darling(default)] + pub nested: bool, + /// For nested fields whose repo type is generic, specify the child entity name + /// so error types can be referenced concretely (e.g., `entity = "InterestAccrualCycle"` + /// generates `InterestAccrualCycleCreateError` instead of + /// ` as EsRepo>::CreateError`). 
+ #[darling(default)] + pub entity: Option, +} + +impl RepoField { + pub fn ident(&self) -> &syn::Ident { + self.ident.as_ref().expect("Field must have an identifier") + } + + fn is_pool_field(&self) -> bool { + self.pool || self.ident.as_ref().is_some_and(|i| i == "pool") + } + + fn is_clock_field(&self) -> bool { + self.clock || self.ident.as_ref().is_some_and(|i| i == "clock") + } + + /// Check if the field type is `Option<...>` + fn is_option_type(&self) -> bool { + if let syn::Type::Path(type_path) = &self.ty + && let Some(segment) = type_path.path.segments.last() + { + return segment.ident == "Option"; + } + false + } + + pub fn create_nested_fn_name(&self) -> syn::Ident { + syn::Ident::new( + &format!("create_nested_{}_in_op", self.ident()), + proc_macro2::Span::call_site(), + ) + } + + pub fn update_nested_fn_name(&self) -> syn::Ident { + syn::Ident::new( + &format!("update_nested_{}_in_op", self.ident()), + proc_macro2::Span::call_site(), + ) + } + + pub fn find_nested_fn_name(&self) -> syn::Ident { + syn::Ident::new( + &format!("find_nested_{}_in_op", self.ident()), + proc_macro2::Span::call_site(), + ) + } + + /// PascalCase variant name derived from field name (e.g. 
`line_items` -> `LineItems`) + pub fn nested_variant_name(&self) -> syn::Ident { + syn::Ident::new( + &self.ident().to_string().to_case(Case::UpperCamel), + Span::call_site(), + ) + } +} + +#[derive(FromDeriveInput)] +#[darling(attributes(es_repo), map = "Self::update_defaults")] +pub struct RepositoryOptions { + pub ident: syn::Ident, + pub generics: syn::Generics, + #[darling(default)] + pub columns: Columns, + #[darling(default)] + pub post_persist_hook: Option, + #[darling(default)] + pub post_hydrate_hook: Option, + #[darling(default)] + pub delete: DeleteOption, + + data: darling::ast::Data<(), RepoField>, + + #[darling(rename = "entity")] + entity_ident: syn::Ident, + #[darling(default, rename = "event")] + event_ident: Option, + #[darling(default, rename = "id")] + id_ty: Option, + #[darling(default, rename = "tbl_prefix")] + prefix: Option, + #[darling(default, rename = "tbl")] + table_name: Option, + #[darling(default, rename = "events_tbl")] + events_table_name: Option, + + #[darling(default)] + persist_event_context: Option, +} + +impl RepositoryOptions { + fn update_defaults(mut self) -> Self { + let entity_name = self.entity_ident.to_string(); + if self.event_ident.is_none() { + self.event_ident = Some(syn::Ident::new( + &format!("{entity_name}Event"), + proc_macro2::Span::call_site(), + )); + } + if self.id_ty.is_none() { + self.id_ty = Some(syn::Ident::new( + &format!("{entity_name}Id"), + proc_macro2::Span::call_site(), + )); + } + let prefix = if let Some(prefix) = &self.prefix { + format!("{}_", prefix.value()) + } else { + String::new() + }; + if self.table_name.is_none() { + self.table_name = Some(format!( + "{prefix}{}", + pluralizer::pluralize(&entity_name, 2, false).to_case(Case::Snake) + )); + } + if self.events_table_name.is_none() { + self.events_table_name = + Some(format!("{prefix}{entity_name}Events").to_case(Case::Snake)); + } + + self.columns + .set_id_column(self.id_ty.as_ref().expect("Id not set")); + + self + } + + pub fn 
entity(&self) -> &syn::Ident { + &self.entity_ident + } + + pub fn table_name(&self) -> &str { + self.table_name.as_ref().expect("Table name is not set") + } + + pub fn table_prefix(&self) -> Option<&syn::LitStr> { + self.prefix.as_ref() + } + + pub fn id(&self) -> &syn::Ident { + self.id_ty.as_ref().expect("ID identifier is not set") + } + + pub fn event(&self) -> &syn::Ident { + self.event_ident + .as_ref() + .expect("Event identifier is not set") + } + + pub fn event_context_enabled(&self) -> bool { + #[cfg(feature = "event-context-enabled")] + { + self.persist_event_context.unwrap_or(true) + } + #[cfg(not(feature = "event-context-enabled"))] + { + self.persist_event_context.unwrap_or(false) + } + } + + pub fn events_table_name(&self) -> &str { + self.events_table_name + .as_ref() + .expect("Events table name is not set") + } + + pub fn cursor_mod(&self) -> syn::Ident { + let name = format!("{}Cursor", self.entity_ident).to_case(Case::Snake); + syn::Ident::new(&name, proc_macro2::Span::call_site()) + } + + pub fn repo_types_mod(&self) -> syn::Ident { + let name = format!("{}RepoTypes", self.entity_ident).to_case(Case::Snake); + syn::Ident::new(&name, proc_macro2::Span::call_site()) + } + + #[cfg(feature = "instrument")] + pub fn repo_name_snake_case(&self) -> String { + self.ident.to_string().to_case(Case::Snake) + } + + pub fn pool_field(&self) -> &syn::Ident { + let field = match &self.data { + darling::ast::Data::Struct(fields) => fields.iter().find_map(|field| { + if field.is_pool_field() { + Some(field.ident.as_ref().unwrap()) + } else { + None + } + }), + _ => None, + }; + field.expect("Repo must have a field named 'pool' or marked with #[es_repo(pool)]") + } + + pub fn clock_field(&self) -> ClockFieldInfo<'_> { + match &self.data { + darling::ast::Data::Struct(fields) => { + for field in fields.iter() { + if field.is_clock_field() { + let ident = field.ident.as_ref().unwrap(); + return if field.is_option_type() { + ClockFieldInfo::Optional(ident) + } else 
{ + ClockFieldInfo::Required(ident) + }; + } + } + ClockFieldInfo::None + } + _ => ClockFieldInfo::None, + } + } + + pub fn any_nested(&self) -> bool { + if let darling::ast::Data::Struct(fields) = &self.data { + fields.iter().any(|f| f.nested) + } else { + panic!("Repository must be a struct") + } + } + + pub fn all_nested(&self) -> impl Iterator { + if let darling::ast::Data::Struct(fields) = &self.data { + fields.iter().filter(|f| f.nested) + } else { + panic!("Repository must be a struct") + } + } + + pub fn query_fn_generics(nested: bool) -> proc_macro2::TokenStream { + if nested { + quote! { + + } + } else { + quote! { + <'a, OP> + } + } + } + + pub fn query_fn_op_arg(nested: bool) -> proc_macro2::TokenStream { + if nested { + quote! { + op: &mut OP + } + } else { + quote! { + op: OP + } + } + } + + pub fn query_fn_op_traits(nested: bool) -> proc_macro2::TokenStream { + if nested { + quote! { + es_entity::AtomicOperation + } + } else { + quote! { + es_entity::IntoOneTimeExecutor<'a> + } + } + } + + pub fn create_error(&self) -> syn::Ident { + syn::Ident::new( + &format!("{}CreateError", self.entity_ident), + Span::call_site(), + ) + } + + pub fn modify_error(&self) -> syn::Ident { + syn::Ident::new( + &format!("{}ModifyError", self.entity_ident), + Span::call_site(), + ) + } + + pub fn find_error(&self) -> syn::Ident { + syn::Ident::new( + &format!("{}FindError", self.entity_ident), + Span::call_site(), + ) + } + + pub fn query_error(&self) -> syn::Ident { + syn::Ident::new( + &format!("{}QueryError", self.entity_ident), + Span::call_site(), + ) + } + + pub fn column_enum(&self) -> syn::Ident { + syn::Ident::new(&format!("{}Column", self.entity_ident), Span::call_site()) + } + + pub fn query_fn_get_op(nested: bool) -> proc_macro2::TokenStream { + if nested { + quote! { + &mut self.pool().begin().await? + } + } else { + quote! 
{ + self.pool() + } + } + } +} diff --git a/es-entity-macros-sqlite/src/repo/persist_events_batch_fn.rs b/es-entity-macros-sqlite/src/repo/persist_events_batch_fn.rs new file mode 100644 index 00000000..ea407df1 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/persist_events_batch_fn.rs @@ -0,0 +1,284 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct PersistEventsBatchFn<'a> { + id: &'a syn::Ident, + event: &'a syn::Ident, + events_table_name: &'a str, + event_ctx: bool, +} + +impl<'a> From<&'a RepositoryOptions> for PersistEventsBatchFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + id: opts.id(), + event: opts.event(), + events_table_name: opts.events_table_name(), + event_ctx: opts.event_context_enabled(), + } + } +} + +impl ToTokens for PersistEventsBatchFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let id_type = &self.id; + let event_type = &self.event; + let events_table_name = self.events_table_name; + + let (insert_query, ctx_var, ctx_extend, ctx_bind) = if self.event_ctx { + ( + format!( + "INSERT INTO {} (id, recorded_at, sequence, event_type, event, context) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + events_table_name + ), + quote! { + let mut all_contexts: Vec> = Vec::new(); + }, + quote! { + let contexts = events.serialize_new_event_contexts(); + if let Some(contexts) = contexts { + all_contexts.extend(contexts.into_iter().map(Some)); + } else { + all_contexts.extend(std::iter::repeat(None).take(n_events)); + } + }, + quote! { + let context = all_contexts.get(i).and_then(|c| c.as_ref()); + query = query.bind(context); + }, + ) + } else { + ( + format!( + "INSERT INTO {} (id, recorded_at, sequence, event_type, event) VALUES (?1, ?2, ?3, ?4, ?5)", + events_table_name + ), + quote! {}, + quote! {}, + quote! {}, + ) + }; + + tokens.append_all(quote! 
{ + async fn persist_events_batch( + &self, + op: &mut OP, + all_events: &mut [B] + ) -> Result, sqlx::Error> + where + OP: es_entity::AtomicOperation, + B: std::borrow::BorrowMut>, + { + let mut all_serialized = Vec::new(); + #ctx_var + let mut all_types = Vec::new(); + let mut all_ids: Vec<&#id_type> = Vec::new(); + let mut all_sequences = Vec::new(); + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + let mut n_events_map = std::collections::HashMap::new(); + for item in all_events.iter() { + let events: &es_entity::EntityEvents<#event_type> = item.borrow(); + let id = events.id(); + let offset = events.len_persisted() + 1; + let types = events.new_event_types(); + let serialized = events.serialize_new_events(); + + let n_events = serialized.len(); + #ctx_extend + all_serialized.extend(serialized); + all_types.extend(types); + all_ids.extend(std::iter::repeat(id).take(n_events)); + all_sequences.extend((offset..).take(n_events).map(|i| i as i64)); + n_events_map.insert(id.clone(), n_events); + } + + for (i, ((id, sequence), (event_type, event_json))) in all_ids.iter() + .zip(all_sequences.iter()) + .zip(all_types.iter().zip(all_serialized.iter())) + .enumerate() + { + let mut query = sqlx::query(#insert_query) + .bind(*id as &#id_type) + .bind(recorded_at) + .bind(*sequence) + .bind(event_type) + .bind(event_json); + #ctx_bind + query.execute(op.as_executor()).await?; + } + + for item in all_events.iter_mut() { + let events: &mut es_entity::EntityEvents<#event_type> = item.borrow_mut(); + events.mark_new_events_persisted_at(recorded_at); + } + + Ok(n_events_map) + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn persist_events_fn() { + let id = syn::parse_str("EntityId").unwrap(); + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let persist_fn = PersistEventsBatchFn { + id: &id, + event: &event, + events_table_name: "entity_events", + 
event_ctx: true, + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + async fn persist_events_batch( + &self, + op: &mut OP, + all_events: &mut [B] + ) -> Result, sqlx::Error> + where + OP: es_entity::AtomicOperation, + B: std::borrow::BorrowMut>, + { + let mut all_serialized = Vec::new(); + let mut all_contexts: Vec> = Vec::new(); + let mut all_types = Vec::new(); + let mut all_ids: Vec<&EntityId> = Vec::new(); + let mut all_sequences = Vec::new(); + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + let mut n_events_map = std::collections::HashMap::new(); + for item in all_events.iter() { + let events: &es_entity::EntityEvents = item.borrow(); + let id = events.id(); + let offset = events.len_persisted() + 1; + let types = events.new_event_types(); + let serialized = events.serialize_new_events(); + + let n_events = serialized.len(); + let contexts = events.serialize_new_event_contexts(); + if let Some(contexts) = contexts { + all_contexts.extend(contexts.into_iter().map(Some)); + } else { + all_contexts.extend(std::iter::repeat(None).take(n_events)); + } + all_serialized.extend(serialized); + all_types.extend(types); + all_ids.extend(std::iter::repeat(id).take(n_events)); + all_sequences.extend((offset..).take(n_events).map(|i| i as i64)); + n_events_map.insert(id.clone(), n_events); + } + + for (i, ((id, sequence), (event_type, event_json))) in all_ids.iter() + .zip(all_sequences.iter()) + .zip(all_types.iter().zip(all_serialized.iter())) + .enumerate() + { + let mut query = sqlx::query("INSERT INTO entity_events (id, recorded_at, sequence, event_type, event, context) VALUES (?1, ?2, ?3, ?4, ?5, ?6)") + .bind(*id as &EntityId) + .bind(recorded_at) + .bind(*sequence) + .bind(event_type) + .bind(event_json); + let context = all_contexts.get(i).and_then(|c| c.as_ref()); + query = query.bind(context); + query.execute(op.as_executor()).await?; + 
} + + for item in all_events.iter_mut() { + let events: &mut es_entity::EntityEvents = item.borrow_mut(); + events.mark_new_events_persisted_at(recorded_at); + } + + Ok(n_events_map) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn persist_events_fn_without_event_context() { + let id = syn::parse_str("EntityId").unwrap(); + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let persist_fn = PersistEventsBatchFn { + id: &id, + event: &event, + events_table_name: "entity_events", + event_ctx: false, + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! { + async fn persist_events_batch( + &self, + op: &mut OP, + all_events: &mut [B] + ) -> Result, sqlx::Error> + where + OP: es_entity::AtomicOperation, + B: std::borrow::BorrowMut>, + { + let mut all_serialized = Vec::new(); + let mut all_types = Vec::new(); + let mut all_ids: Vec<&EntityId> = Vec::new(); + let mut all_sequences = Vec::new(); + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + let mut n_events_map = std::collections::HashMap::new(); + for item in all_events.iter() { + let events: &es_entity::EntityEvents = item.borrow(); + let id = events.id(); + let offset = events.len_persisted() + 1; + let types = events.new_event_types(); + let serialized = events.serialize_new_events(); + + let n_events = serialized.len(); + all_serialized.extend(serialized); + all_types.extend(types); + all_ids.extend(std::iter::repeat(id).take(n_events)); + all_sequences.extend((offset..).take(n_events).map(|i| i as i64)); + n_events_map.insert(id.clone(), n_events); + } + + for (i, ((id, sequence), (event_type, event_json))) in all_ids.iter() + .zip(all_sequences.iter()) + .zip(all_types.iter().zip(all_serialized.iter())) + .enumerate() + { + let mut query = sqlx::query("INSERT INTO entity_events (id, recorded_at, sequence, event_type, event) 
VALUES (?1, ?2, ?3, ?4, ?5)") + .bind(*id as &EntityId) + .bind(recorded_at) + .bind(*sequence) + .bind(event_type) + .bind(event_json); + query.execute(op.as_executor()).await?; + } + + for item in all_events.iter_mut() { + let events: &mut es_entity::EntityEvents = item.borrow_mut(); + events.mark_new_events_persisted_at(recorded_at); + } + + Ok(n_events_map) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/persist_events_fn.rs b/es-entity-macros-sqlite/src/repo/persist_events_fn.rs new file mode 100644 index 00000000..00656f71 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/persist_events_fn.rs @@ -0,0 +1,236 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct PersistEventsFn<'a> { + id: &'a syn::Ident, + event: &'a syn::Ident, + events_table_name: &'a str, + event_ctx: bool, +} + +impl<'a> From<&'a RepositoryOptions> for PersistEventsFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + id: opts.id(), + event: opts.event(), + events_table_name: opts.events_table_name(), + event_ctx: opts.event_context_enabled(), + } + } +} + +impl ToTokens for PersistEventsFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let events_table_name = self.events_table_name; + + let (insert_query, ctx_var, ctx_bind) = if self.event_ctx { + ( + format!( + "INSERT INTO {} (id, recorded_at, sequence, event_type, event, context) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + events_table_name + ), + quote! { let contexts = events.serialize_new_event_contexts(); }, + quote! { + let context = contexts.as_ref().and_then(|c| c.get(i)); + query = query.bind(context); + }, + ) + } else { + ( + format!( + "INSERT INTO {} (id, recorded_at, sequence, event_type, event) VALUES (?1, ?2, ?3, ?4, ?5)", + events_table_name + ), + quote! {}, + quote! 
{}, + ) + }; + let id_type = &self.id; + let event_type = &self.event; + + tokens.append_all(quote! { + fn extract_concurrent_modification>( + res: Result, + concurrent_modification: __EsErr, + ) -> Result { + match res { + Ok(v) => Ok(v), + Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => { + Err(concurrent_modification) + } + Err(e) => Err(__EsErr::from(e)), + } + } + + async fn persist_events( + &self, + op: &mut OP, + events: &mut es_entity::EntityEvents<#event_type> + ) -> Result + where + OP: es_entity::AtomicOperation, + { + let id = events.id(); + let offset = events.len_persisted(); + let events_types = events.new_event_types(); + let serialized_events = events.serialize_new_events(); + #ctx_var + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + for (i, (event_type, event_json)) in events_types.iter().zip(serialized_events.iter()).enumerate() { + let sequence = offset as i64 + i as i64 + 1; + let mut query = sqlx::query(#insert_query) + .bind(id as &#id_type) + .bind(recorded_at) + .bind(sequence) + .bind(event_type) + .bind(event_json); + #ctx_bind + query.execute(op.as_executor()).await?; + } + + let n_events = events.mark_new_events_persisted_at(recorded_at); + + Ok(n_events) + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn persist_events_fn() { + let id = syn::parse_str("EntityId").unwrap(); + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let persist_fn = PersistEventsFn { + id: &id, + event: &event, + events_table_name: "entity_events", + event_ctx: true, + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + fn extract_concurrent_modification>( + res: Result, + concurrent_modification: __EsErr, + ) -> Result { + match res { + Ok(v) => Ok(v), + Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => { + Err(concurrent_modification) + } + Err(e) => Err(__EsErr::from(e)), + } + } + + async fn persist_events( + &self, + op: &mut OP, + events: &mut es_entity::EntityEvents + ) -> Result + where + OP: es_entity::AtomicOperation, + { + let id = events.id(); + let offset = events.len_persisted(); + let events_types = events.new_event_types(); + let serialized_events = events.serialize_new_events(); + let contexts = events.serialize_new_event_contexts(); + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + for (i, (event_type, event_json)) in events_types.iter().zip(serialized_events.iter()).enumerate() { + let sequence = offset as i64 + i as i64 + 1; + let mut query = sqlx::query("INSERT INTO entity_events (id, recorded_at, sequence, event_type, event, context) VALUES (?1, ?2, ?3, ?4, ?5, ?6)") + .bind(id as &EntityId) + .bind(recorded_at) + .bind(sequence) + .bind(event_type) + .bind(event_json); + let context = contexts.as_ref().and_then(|c| c.get(i)); + query = query.bind(context); + query.execute(op.as_executor()).await?; + } + + let n_events = events.mark_new_events_persisted_at(recorded_at); + + Ok(n_events) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn persist_events_fn_without_event_context() { + let id = syn::parse_str("EntityId").unwrap(); + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let persist_fn = PersistEventsFn { + id: &id, + event: &event, + events_table_name: "entity_events", + event_ctx: false, + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + fn extract_concurrent_modification>( + res: Result, + concurrent_modification: __EsErr, + ) -> Result { + match res { + Ok(v) => Ok(v), + Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => { + Err(concurrent_modification) + } + Err(e) => Err(__EsErr::from(e)), + } + } + + async fn persist_events( + &self, + op: &mut OP, + events: &mut es_entity::EntityEvents + ) -> Result + where + OP: es_entity::AtomicOperation, + { + let id = events.id(); + let offset = events.len_persisted(); + let events_types = events.new_event_types(); + let serialized_events = events.serialize_new_events(); + let now = op.maybe_now(); + let recorded_at = now.unwrap_or_else(|| es_entity::prelude::chrono::Utc::now()); + + for (i, (event_type, event_json)) in events_types.iter().zip(serialized_events.iter()).enumerate() { + let sequence = offset as i64 + i as i64 + 1; + let mut query = sqlx::query("INSERT INTO entity_events (id, recorded_at, sequence, event_type, event) VALUES (?1, ?2, ?3, ?4, ?5)") + .bind(id as &EntityId) + .bind(recorded_at) + .bind(sequence) + .bind(event_type) + .bind(event_json); + query.execute(op.as_executor()).await?; + } + + let n_events = events.mark_new_events_persisted_at(recorded_at); + + Ok(n_events) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/populate_nested.rs b/es-entity-macros-sqlite/src/repo/populate_nested.rs new file mode 100644 index 00000000..6f1ea864 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/populate_nested.rs @@ -0,0 +1,101 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct PopulateNested<'a> { + column: &'a Column, + ident: &'a syn::Ident, + generics: &'a syn::Generics, + table_name: &'a str, + events_table_name: &'a str, + repo_types_mod: syn::Ident, +} + +impl<'a> PopulateNested<'a> { + pub fn new(column: &'a Column, opts: &'a RepositoryOptions) -> Self { 
+ Self { + column, + ident: &opts.ident, + generics: &opts.generics, + table_name: opts.table_name(), + events_table_name: opts.events_table_name(), + repo_types_mod: opts.repo_types_mod(), + } + } +} + +impl ToTokens for PopulateNested<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ty = self.column.ty(); + let ident = self.ident; + let repo_types_mod = &self.repo_types_mod; + let accessor = self.column.parent_accessor(); + let table_name = self.table_name; + let column_name = self.column.name().to_string(); + let events_table_name = self.events_table_name; + + let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl(); + + tokens.append_all(quote! { + impl #impl_generics es_entity::PopulateNested<#ty> for #ident #ty_generics #where_clause { + async fn populate_in_op( + op: &mut OP, + mut lookup: std::collections::HashMap<#ty, &mut P>, + ) -> Result<(), __EsErr> + where + OP: es_entity::AtomicOperation, + P: Parent<::Entity>, + __EsErr: From + From + Send, + { + let parent_ids: Vec<_> = lookup.keys().collect(); + if parent_ids.is_empty() { + return Ok(()); + } + let placeholders: String = (1..=parent_ids.len()) + .map(|i| format!("?{i}")) + .collect::>() + .join(", "); + let ctx_param = parent_ids.len() + 1; + let query_str = format!( + "WITH entities AS (SELECT * FROM {} WHERE ({} IN ({}))) \ + SELECT i.id AS entity_id, e.sequence, e.event, \ + CASE WHEN ?{} THEN e.context ELSE NULL END AS context, \ + e.recorded_at \ + FROM entities i JOIN {} e ON i.id = e.id ORDER BY e.id, e.sequence", + #table_name, + #column_name, + placeholders, + ctx_param, + #events_table_name, + ); + let mut query = es_entity::prelude::sqlx::query(&query_str); + for id in &parent_ids { + query = query.bind(id); + } + query = query.bind(<#repo_types_mod::Repo__Event as EsEvent>::event_context()); + let rows = query.fetch_all(op.as_executor()).await?; + use es_entity::prelude::sqlx::Row as _; + let db_events: Vec<#repo_types_mod::Repo__DbEvent> = 
rows.iter().map(|row| { + #repo_types_mod::Repo__DbEvent { + entity_id: row.try_get("entity_id").expect("entity_id"), + sequence: row.try_get("sequence").expect("sequence"), + event: row.try_get("event").expect("event"), + context: row.try_get("context").expect("context"), + recorded_at: row.try_get("recorded_at").expect("recorded_at"), + } + }).collect(); + let n = db_events.len(); + let (mut res, _) = es_entity::EntityEvents::load_n::<::Entity>(db_events.into_iter(), n)?; + Self::load_all_nested_in_op::<_, __EsErr>(op, &mut res).await?; + for entity in res.into_iter() { + let parent = lookup.get_mut(&entity.#accessor).expect("parent not present"); + parent.inject_children(std::iter::once(entity)); + } + Ok(()) + } + } + }); + } +} diff --git a/es-entity-macros-sqlite/src/repo/post_hydrate_hook.rs b/es-entity-macros-sqlite/src/repo/post_hydrate_hook.rs new file mode 100644 index 00000000..7bebc2b4 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/post_hydrate_hook.rs @@ -0,0 +1,113 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::{PostHydrateHookConfig, RepositoryOptions}; + +pub struct PostHydrateHook<'a> { + entity: &'a syn::Ident, + hook: &'a Option, +} + +impl<'a> From<&'a RepositoryOptions> for PostHydrateHook<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + entity: opts.entity(), + hook: &opts.post_hydrate_hook, + } + } +} + +impl ToTokens for PostHydrateHook<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = &self.entity; + + let (return_type, hook) = if let Some(config) = self.hook { + let method = &config.method; + let error_ty = &config.error; + ( + quote! { #error_ty }, + quote! { + self.#method(entity) + }, + ) + } else { + ( + quote! { std::convert::Infallible }, + quote! { + Ok(()) + }, + ) + }; + + tokens.append_all(quote! 
{ + #[inline(always)] + fn execute_post_hydrate_hook( + &self, + entity: &#entity, + ) -> Result<(), #return_type> { + #hook + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn post_hydrate_hook_none() { + let entity = syn::Ident::new("Entity", proc_macro2::Span::call_site()); + let hook = None; + + let hook = PostHydrateHook { + entity: &entity, + hook: &hook, + }; + + let mut tokens = TokenStream::new(); + hook.to_tokens(&mut tokens); + + let expected = quote! { + #[inline(always)] + fn execute_post_hydrate_hook( + &self, + entity: &Entity, + ) -> Result<(), std::convert::Infallible> { + Ok(()) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn post_hydrate_hook_some() { + let entity = syn::Ident::new("Entity", proc_macro2::Span::call_site()); + let hook = Some(PostHydrateHookConfig { + method: syn::Ident::new("validate_entity", proc_macro2::Span::call_site()), + error: syn::parse_str("EntityPostHydrateError").unwrap(), + }); + + let hook = PostHydrateHook { + entity: &entity, + hook: &hook, + }; + + let mut tokens = TokenStream::new(); + hook.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[inline(always)] + fn execute_post_hydrate_hook( + &self, + entity: &Entity, + ) -> Result<(), EntityPostHydrateError> { + self.validate_entity(entity) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/post_persist_hook.rs b/es-entity-macros-sqlite/src/repo/post_persist_hook.rs new file mode 100644 index 00000000..e1be9560 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/post_persist_hook.rs @@ -0,0 +1,136 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::RepositoryOptions; +use super::options::PostPersistHookConfig; + +pub struct PostPersistHook<'a> { + event: &'a syn::Ident, + entity: &'a syn::Ident, + hook: &'a Option, +} + +impl<'a> From<&'a RepositoryOptions> for PostPersistHook<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + event: opts.event(), + entity: opts.entity(), + hook: &opts.post_persist_hook, + } + } +} + +impl ToTokens for PostPersistHook<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let event = &self.event; + let entity = &self.entity; + + let (error_ty, hook) = if let Some(config) = self.hook { + let method = &config.method; + let error = &config.error; + ( + quote! { #error }, + quote! { + self.#method(op, entity, new_events).await?; + Ok(()) + }, + ) + } else { + ( + quote! { sqlx::Error }, + quote! { + Ok(()) + }, + ) + }; + + tokens.append_all(quote! 
{ + #[inline(always)] + async fn execute_post_persist_hook( + &self, + op: &mut OP, + entity: &#entity, + new_events: es_entity::LastPersisted<'_, #event> + ) -> Result<(), #error_ty> + where + OP: es_entity::AtomicOperation + { + #hook + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn post_persist_hook_none() { + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let entity = syn::Ident::new("Entity", proc_macro2::Span::call_site()); + let hook = None; + + let hook = PostPersistHook { + event: &event, + entity: &entity, + hook: &hook, + }; + + let mut tokens = TokenStream::new(); + hook.to_tokens(&mut tokens); + + let expected = quote! { + #[inline(always)] + async fn execute_post_persist_hook(&self, + op: &mut OP, + entity: &Entity, + new_events: es_entity::LastPersisted<'_, EntityEvent> + ) -> Result<(), sqlx::Error> + where + OP: es_entity::AtomicOperation + { + Ok(()) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn post_persist_hook_some() { + let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site()); + let entity = syn::Ident::new("Entity", proc_macro2::Span::call_site()); + let config = Some(PostPersistHookConfig { + method: syn::Ident::new("on_persist", proc_macro2::Span::call_site()), + error: syn::parse_str("MyPersistError").unwrap(), + }); + + let hook = PostPersistHook { + event: &event, + entity: &entity, + hook: &config, + }; + + let mut tokens = TokenStream::new(); + hook.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[inline(always)] + async fn execute_post_persist_hook(&self, + op: &mut OP, + entity: &Entity, + new_events: es_entity::LastPersisted<'_, EntityEvent> + ) -> Result<(), MyPersistError> + where + OP: es_entity::AtomicOperation + { + self.on_persist(op, entity, new_events).await?; + Ok(()) + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/update_all_fn.rs b/es-entity-macros-sqlite/src/repo/update_all_fn.rs new file mode 100644 index 00000000..c37f7710 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/update_all_fn.rs @@ -0,0 +1,408 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct UpdateAllFn<'a> { + entity: &'a syn::Ident, + table_name: &'a str, + columns: &'a Columns, + modify_error: syn::Ident, + nested_fn_names: Vec, + post_persist_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> From<&'a RepositoryOptions> for UpdateAllFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + entity: opts.entity(), + modify_error: opts.modify_error(), + columns: &opts.columns, + table_name: opts.table_name(), + nested_fn_names: opts + .all_nested() + .map(|f| f.update_nested_fn_name()) + .collect(), + post_persist_error: opts.post_persist_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for UpdateAllFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let modify_error = &self.modify_error; + + let nested = self.nested_fn_names.iter().map(|f| { + quote! { + self.#f(op, entity).await?; + } + }); + + let nested_phase = if self.nested_fn_names.is_empty() { + None + } else { + let nested = nested.collect::>(); + Some(quote! 
{ + for entity in entities.iter_mut() { + #(#nested)* + } + }) + }; + + let update_tokens = if self.columns.updates_needed() { + let assignments = self + .columns + .variable_assignments_for_update(syn::parse_quote! { entity }); + let column_updates = self.columns.sql_updates(); + let table_name = self.table_name; + let query = format!("UPDATE {} SET {} WHERE id = ?1", table_name, column_updates,); + let args = self.columns.update_query_args(); + + Some(quote! { + for entity in entities.iter() { + if !entity.events().any_new() { + continue; + } + + #assignments + sqlx::query(#query) + #(#args)* + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + #modify_error::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => #modify_error::Sqlx(e), + })?; + } + }) + } else { + None + }; + + #[cfg(feature = "instrument")] + let (instrument_attr, error_recording) = { + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + let span_name = format!("{}.update_all", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, error_recording) = (quote! {}, quote! {}); + + let post_persist_check = if self.post_persist_error.is_some() { + quote! 
{ + self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#modify_error::PostPersistHookError)?; + } + } else { + quote! {} + }; + + tokens.append_all(quote! { + pub async fn update_all( + &self, + entities: &mut [#entity] + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_all_in_op(&mut op, entities).await?; + op.commit().await?; + Ok(res) + } + + #instrument_attr + pub async fn update_all_in_op( + &self, + op: &mut OP, + entities: &mut [#entity] + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + if entities.is_empty() { + return Ok(0); + } + + #nested_phase + + let mut has_new_events = false; + for entity in entities.iter() { + if !entity.events().any_new() { + continue; + } + has_new_events = true; + } + + if !has_new_events { + return Ok(0); + } + + #update_tokens + + let mut all_event_refs: Vec<_> = entities.iter_mut() + .filter_map(|entity| { + let events = Self::extract_events(entity); + if events.any_new() { Some(events) } else { None } + }) + .collect(); + let n_persisted = Self::extract_concurrent_modification( + self.persist_events_batch(op, &mut all_event_refs).await, + #modify_error::ConcurrentModification, + )?; + drop(all_event_refs); + + let mut total_events = 0usize; + for entity in entities.iter_mut() { + if let Some(&n_events) = n_persisted.get(&entity.id) { + if n_events > 0 { + #post_persist_check + total_events += n_events; + } + } + } + + Ok(total_events) + }.await; + + #error_recording + __result + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn update_all_fn() { + let id = syn::parse_str("EntityId").unwrap(); + let entity = Ident::new("Entity", Span::call_site()); + + let columns = Columns::new( + &id, + [Column::new( + Ident::new("name", Span::call_site()), + syn::parse_str("String").unwrap(), + )], + ); + + let update_all_fn = UpdateAllFn { + entity: 
&entity, + table_name: "entities", + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + columns: &columns, + nested_fn_names: Vec::new(), + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + update_all_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn update_all( + &self, + entities: &mut [Entity] + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_all_in_op(&mut op, entities).await?; + op.commit().await?; + Ok(res) + } + + pub async fn update_all_in_op( + &self, + op: &mut OP, + entities: &mut [Entity] + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + if entities.is_empty() { + return Ok(0); + } + + let mut has_new_events = false; + for entity in entities.iter() { + if !entity.events().any_new() { + continue; + } + has_new_events = true; + } + + if !has_new_events { + return Ok(0); + } + + for entity in entities.iter() { + if !entity.events().any_new() { + continue; + } + + let id = &entity.id; + let name = &entity.name; + sqlx::query("UPDATE entities SET name = ?2 WHERE id = ?1") + .bind(id) + .bind(name) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityModifyError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityModifyError::Sqlx(e), + })?; + } + + let mut all_event_refs: Vec<_> = entities.iter_mut() + .filter_map(|entity| { + let events = Self::extract_events(entity); + if events.any_new() { Some(events) } else { None } + }) + .collect(); + let n_persisted = Self::extract_concurrent_modification( + self.persist_events_batch(op, &mut all_event_refs).await, + EntityModifyError::ConcurrentModification, + )?; + drop(all_event_refs); + 
+ let mut total_events = 0usize; + for entity in entities.iter_mut() { + if let Some(&n_events) = n_persisted.get(&entity.id) { + if n_events > 0 { + total_events += n_events; + } + } + } + + Ok(total_events) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn update_all_fn_no_columns() { + let id = syn::parse_str("EntityId").unwrap(); + let entity = Ident::new("Entity", Span::call_site()); + + let mut columns = Columns::default(); + columns.set_id_column(&id); + + let update_all_fn = UpdateAllFn { + entity: &entity, + table_name: "entities", + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + columns: &columns, + nested_fn_names: Vec::new(), + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + update_all_fn.to_tokens(&mut tokens); + + let expected = quote! { + pub async fn update_all( + &self, + entities: &mut [Entity] + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_all_in_op(&mut op, entities).await?; + op.commit().await?; + Ok(res) + } + + pub async fn update_all_in_op( + &self, + op: &mut OP, + entities: &mut [Entity] + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + if entities.is_empty() { + return Ok(0); + } + + let mut has_new_events = false; + for entity in entities.iter() { + if !entity.events().any_new() { + continue; + } + has_new_events = true; + } + + if !has_new_events { + return Ok(0); + } + + let mut all_event_refs: Vec<_> = entities.iter_mut() + .filter_map(|entity| { + let events = Self::extract_events(entity); + if events.any_new() { Some(events) } else { None } + }) + .collect(); + let n_persisted = Self::extract_concurrent_modification( + self.persist_events_batch(op, &mut all_event_refs).await, + EntityModifyError::ConcurrentModification, + )?; + drop(all_event_refs); + + let mut 
total_events = 0usize; + for entity in entities.iter_mut() { + if let Some(&n_events) = n_persisted.get(&entity.id) { + if n_events > 0 { + total_events += n_events; + } + } + } + + Ok(total_events) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/repo/update_fn.rs b/es-entity-macros-sqlite/src/repo/update_fn.rs new file mode 100644 index 00000000..03d336a0 --- /dev/null +++ b/es-entity-macros-sqlite/src/repo/update_fn.rs @@ -0,0 +1,347 @@ +use darling::ToTokens; +use proc_macro2::TokenStream; +use quote::{TokenStreamExt, quote}; + +use super::options::*; + +pub struct UpdateFn<'a> { + entity: &'a syn::Ident, + table_name: &'a str, + columns: &'a Columns, + modify_error: syn::Ident, + nested_fn_names: Vec, + post_persist_error: Option<&'a syn::Type>, + #[cfg(feature = "instrument")] + repo_name_snake: String, +} + +impl<'a> From<&'a RepositoryOptions> for UpdateFn<'a> { + fn from(opts: &'a RepositoryOptions) -> Self { + Self { + entity: opts.entity(), + modify_error: opts.modify_error(), + columns: &opts.columns, + table_name: opts.table_name(), + nested_fn_names: opts + .all_nested() + .map(|f| f.update_nested_fn_name()) + .collect(), + post_persist_error: opts.post_persist_hook.as_ref().map(|h| &h.error), + #[cfg(feature = "instrument")] + repo_name_snake: opts.repo_name_snake_case(), + } + } +} + +impl ToTokens for UpdateFn<'_> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let entity = self.entity; + let modify_error = &self.modify_error; + + let nested = self.nested_fn_names.iter().map(|f| { + quote! { + self.#f(op, entity).await?; + } + }); + + let update_tokens = if self.columns.updates_needed() { + let assignments = self + .columns + .variable_assignments_for_update(syn::parse_quote! 
{ entity }); + let column_updates = self.columns.sql_updates(); + let query = format!( + "UPDATE {} SET {} WHERE id = ?1", + self.table_name, column_updates, + ); + let args = self.columns.update_query_args(); + Some(quote! { + #assignments + sqlx::query(#query) + #(#args)* + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + #modify_error::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => #modify_error::Sqlx(e), + })?; + }) + } else { + None + }; + + #[cfg(feature = "instrument")] + let (instrument_attr, record_id, error_recording) = { + use convert_case::{Case, Casing}; + + let entity_name = entity.to_string(); + let repo_name = &self.repo_name_snake; + + let id_ident = quote::format_ident!("{}_id", entity.to_string().to_case(Case::Snake)); + + let span_name = format!("{}.update", repo_name); + ( + quote! { + #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, #id_ident = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))] + }, + quote! { + tracing::Span::current().record(stringify!(#id_ident), tracing::field::display(&entity.id)); + }, + quote! { + if let Err(ref e) = __result { + tracing::Span::current().record("error", true); + tracing::Span::current().record("exception.message", tracing::field::display(e)); + tracing::Span::current().record("exception.type", std::any::type_name_of_val(e)); + } + }, + ) + }; + #[cfg(not(feature = "instrument"))] + let (instrument_attr, record_id, error_recording) = (quote! {}, quote! {}, quote! {}); + + let post_persist_check = if self.post_persist_error.is_some() { + quote! 
{ + self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#modify_error::PostPersistHookError)?; + } + } else { + quote! {} + }; + + tokens.append_all(quote! { + #[inline(always)] + fn extract_events(entity: &mut Entity) -> &mut es_entity::EntityEvents + where + Entity: es_entity::EsEntity, + Event: es_entity::EsEvent, + { + entity.events_mut() + } + + pub async fn update( + &self, + entity: &mut #entity + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + #instrument_attr + pub async fn update_in_op( + &self, + op: &mut OP, + entity: &mut #entity + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + #record_id + #(#nested)* + + if !Self::extract_events(entity).any_new() { + return Ok(0); + } + + #update_tokens + let n_events = { + let events = Self::extract_events(entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + #modify_error::ConcurrentModification, + )? + }; + + #post_persist_check + + Ok(n_events) + }.await; + + #error_recording + __result + } + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use syn::Ident; + + #[test] + fn update_fn() { + let id = syn::parse_str("EntityId").unwrap(); + let entity = Ident::new("Entity", Span::call_site()); + + let columns = Columns::new( + &id, + [Column::new( + Ident::new("name", Span::call_site()), + syn::parse_str("String").unwrap(), + )], + ); + + let update_fn = UpdateFn { + entity: &entity, + table_name: "entities", + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + columns: &columns, + nested_fn_names: Vec::new(), + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + update_fn.to_tokens(&mut tokens); + + let expected = quote! 
{ + #[inline(always)] + fn extract_events(entity: &mut Entity) -> &mut es_entity::EntityEvents + where + Entity: es_entity::EsEntity, + Event: es_entity::EsEvent, + { + entity.events_mut() + } + + pub async fn update( + &self, + entity: &mut Entity + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn update_in_op( + &self, + op: &mut OP, + entity: &mut Entity + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + if !Self::extract_events(entity).any_new() { + return Ok(0); + } + + let id = &entity.id; + let name = &entity.name; + sqlx::query("UPDATE entities SET name = ?2 WHERE id = ?1") + .bind(id) + .bind(name) + .execute(op.as_executor()) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.is_unique_violation() => { + EntityModifyError::ConstraintViolation { + column: Self::map_constraint_column(db_err.constraint()), + value: es_entity::db::extract_constraint_value(db_err.as_ref()), + inner: e, + } + } + _ => EntityModifyError::Sqlx(e), + })?; + + let n_events = { + let events = Self::extract_events(entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + EntityModifyError::ConcurrentModification, + )? 
+ }; + + Ok(n_events) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } + + #[test] + fn update_fn_no_columns() { + let id = syn::parse_str("EntityId").unwrap(); + let entity = Ident::new("Entity", Span::call_site()); + + let mut columns = Columns::default(); + columns.set_id_column(&id); + + let update_fn = UpdateFn { + entity: &entity, + table_name: "entities", + modify_error: syn::Ident::new("EntityModifyError", Span::call_site()), + columns: &columns, + nested_fn_names: Vec::new(), + post_persist_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + update_fn.to_tokens(&mut tokens); + + let expected = quote! { + #[inline(always)] + fn extract_events(entity: &mut Entity) -> &mut es_entity::EntityEvents + where + Entity: es_entity::EsEntity, + Event: es_entity::EsEvent, + { + entity.events_mut() + } + + pub async fn update( + &self, + entity: &mut Entity + ) -> Result { + let mut op = self.begin_op().await?; + let res = self.update_in_op(&mut op, entity).await?; + op.commit().await?; + Ok(res) + } + + pub async fn update_in_op( + &self, + op: &mut OP, + entity: &mut Entity + ) -> Result + where + OP: es_entity::AtomicOperation + { + let __result: Result = async { + if !Self::extract_events(entity).any_new() { + return Ok(0); + } + + let n_events = { + let events = Self::extract_events(entity); + Self::extract_concurrent_modification( + self.persist_events(op, events).await, + EntityModifyError::ConcurrentModification, + )? 
+ }; + + Ok(n_events) + }.await; + + __result + } + }; + + assert_eq!(tokens.to_string(), expected.to_string()); + } +} diff --git a/es-entity-macros-sqlite/src/retry_on_concurrent_modification.rs b/es-entity-macros-sqlite/src/retry_on_concurrent_modification.rs new file mode 100644 index 00000000..816d180e --- /dev/null +++ b/es-entity-macros-sqlite/src/retry_on_concurrent_modification.rs @@ -0,0 +1,210 @@ +use darling::{FromMeta, ast::NestedMeta}; +use syn::ItemFn; + +#[derive(FromMeta)] +struct MacroArgs { + any_error: Option, + max_retries: Option, +} + +pub fn make( + args: proc_macro::TokenStream, + input: ItemFn, +) -> darling::Result { + let attr_args = NestedMeta::parse_meta_list(args.into())?; + let args = MacroArgs::from_list(&attr_args)?; + + let mut inner_fn = input.clone(); + let inner_ident = syn::Ident::new( + &format!("{}_exec_one", &input.sig.ident), + input.sig.ident.span(), + ); + inner_fn.sig.ident = inner_ident.clone(); + inner_fn.vis = syn::Visibility::Inherited; + // Keep user-provided attributes (like #[instrument]) on the inner function + // inner_fn.attrs is preserved + + // Filter out tracing-related attributes for the outer function + // (they should only be on the inner function) + let outer_attrs: Vec<_> = input + .attrs + .iter() + .filter(|attr| { + // Keep non-tracing attributes on outer function + !(attr.path().is_ident("instrument") + || (attr.path().segments.len() == 2 + && attr.path().segments[0].ident == "tracing" + && attr.path().segments[1].ident == "instrument")) + }) + .collect(); + + let vis = &input.vis; + let sig = &input.sig; + + let any_error = args.any_error.unwrap_or(false); + + #[cfg(feature = "instrument")] + let err_match = if any_error { + quote::quote! { + if result.is_err() { + tracing::warn!( + attempt = n, + max_retries = max_retries, + "Error detected, retrying" + ); + continue; + } + } + } else { + quote::quote! 
{ + if let Err(e) = result.as_ref() { + if e.was_concurrent_modification() { + tracing::warn!( + attempt = n, + max_retries = max_retries, + "Concurrent modification detected, retrying" + ); + continue; + } + } + } + }; + + #[cfg(not(feature = "instrument"))] + let err_match = if any_error { + quote::quote! { + if result.is_err() { + continue; + } + } + } else { + quote::quote! { + if let Err(e) = result.as_ref() { + if e.was_concurrent_modification() { + continue; + } + } + } + }; + + let inputs: Vec<_> = input + .sig + .inputs + .iter() + .filter_map(|input| match input { + syn::FnArg::Receiver(_) => None, + syn::FnArg::Typed(pat_type) => Some(&pat_type.pat), + }) + .collect(); + + let max_retries = args.max_retries.unwrap_or(3); + + #[cfg(feature = "instrument")] + let outer_fn = { + let fn_name = input.sig.ident.to_string(); + let retry_span_name = format!("{}.retry_wrapper", fn_name); + + quote::quote! { + #( #outer_attrs )* + #[tracing::instrument( + name = #retry_span_name, + skip_all, + fields( + max_retries = #max_retries, + attempt = tracing::field::Empty, + retried = false + ) + )] + #vis #sig { + let max_retries = #max_retries; + for n in 1..=max_retries { + tracing::Span::current().record("attempt", n); + if n > 1 { + tracing::Span::current().record("retried", true); + } + + let result = self.#inner_ident(#(#inputs),*).await; + if n == max_retries { + return result; + } + #err_match + return result; + } + unreachable!(); + } + } + }; + + #[cfg(not(feature = "instrument"))] + let outer_fn = { + quote::quote! { + #( #outer_attrs )* + #vis #sig { + let max_retries = #max_retries; + for n in 1..=max_retries { + let result = self.#inner_ident(#(#inputs),*).await; + if n == max_retries { + return result; + } + #err_match + return result; + } + unreachable!(); + } + } + }; + + let output = quote::quote! 
{ + #inner_fn + #outer_fn + }; + Ok(output) +} + +// It's working - just need to figure out how to parse the attribute args for testing + +// #[cfg(test)] +// mod tests { +// use super::*; +// use syn::parse_quote; + +// #[test] +// fn retry_on_concurrent_modification() { +// let input = parse_quote! { +// #[retry_on_concurrent_modification] +// #[instrument(name = "test")] +// pub async fn test(&self, a: u32) -> Result<(), es_entity::EsRepoError> { +// self.repo.update().await?; +// Ok(()) +// } +// }; + +// let output = make(input).unwrap(); +// let expected = quote::quote! { +// async fn test_exec_one(&self, a: u32) -> Result<(), es_entity::EsRepoError> { +// self.repo.update().await?; +// Ok(()) +// } + +// #[retry_on_concurrent_modification] +// #[instrument(name = "test")] +// pub async fn test(&self, a: u32) -> Result<(), es_entity::EsRepoError> { +// let max_retries = 3; +// for n in 1..=max_retries { +// let result = self.test_exec_one(a).await; +// if n == max_retries { +// return result; +// } +// if let Err(e) = result.as_ref() { +// if e.was_concurrent_modification() { +// continue; +// } +// } +// return result; +// } +// unreachable!(); +// } +// }; +// assert_eq!(output.to_string(), expected.to_string()); +// } +// } diff --git a/flake.nix b/flake.nix index 61fce191..8b71ea8c 100644 --- a/flake.nix +++ b/flake.nix @@ -145,10 +145,15 @@ checks = { workspace-fmt = craneLib.cargoFmt commonArgs; - workspace-clippy = craneLib.cargoClippy (commonArgs + workspace-clippy-pg = craneLib.cargoClippy (commonArgs // { inherit cargoArtifacts; - cargoClippyExtraArgs = "--all-features -- --deny warnings"; + cargoClippyExtraArgs = "--workspace --features postgres,graphql,event-context,instrument,tracing-context,json-schema -- --deny warnings"; + }); + workspace-clippy-sqlite = craneLib.cargoClippy (commonArgs + // { + inherit cargoArtifacts; + cargoClippyExtraArgs = "--workspace --no-default-features --features sqlite,graphql,event-context,instrument,json-schema -- 
--deny warnings"; }); workspace-audit = craneLib.cargoAudit { inherit advisory-db; diff --git a/migrations-sqlite/20250718092455_test_setup.sql b/migrations-sqlite/20250718092455_test_setup.sql new file mode 100644 index 00000000..c9a85ad9 --- /dev/null +++ b/migrations-sqlite/20250718092455_test_setup.sql @@ -0,0 +1,157 @@ +CREATE TABLE users ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + deleted INTEGER DEFAULT 0, + created_at TEXT NOT NULL +); +CREATE INDEX idx_users_name ON users (name); + +CREATE TABLE user_events ( + id TEXT NOT NULL REFERENCES users(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE user_documents ( + id TEXT PRIMARY KEY NOT NULL, + user_id TEXT, + created_at TEXT NOT NULL +); + +CREATE TABLE user_document_events ( + id TEXT NOT NULL REFERENCES user_documents(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE ignore_prefix_users ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + created_at TEXT NOT NULL +); +CREATE INDEX idx_ignore_prefix_users_name ON ignore_prefix_users (name); + +CREATE TABLE ignore_prefix_user_events ( + id TEXT NOT NULL REFERENCES ignore_prefix_users(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE custom_name_for_users ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + created_at TEXT NOT NULL +); +CREATE INDEX idx_custom_name_for_users_name ON custom_name_for_users (name); + +CREATE TABLE custom_name_for_user_events ( + id TEXT NOT NULL REFERENCES custom_name_for_users(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at 
TEXT NOT NULL, + UNIQUE(id, sequence) +); + +-- Tables for nested entities test +CREATE TABLE orders ( + id TEXT PRIMARY KEY NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE order_events ( + id TEXT NOT NULL REFERENCES orders(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE order_items ( + id TEXT PRIMARY KEY NOT NULL, + order_id TEXT NOT NULL REFERENCES orders(id), + created_at TEXT NOT NULL +); + +CREATE TABLE order_item_events ( + id TEXT NOT NULL REFERENCES order_items(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +-- Tables for subscription/billing period example +CREATE TABLE subscriptions ( + id TEXT PRIMARY KEY NOT NULL, + created_at TEXT NOT NULL +); + +CREATE TABLE subscription_events ( + id TEXT NOT NULL REFERENCES subscriptions(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE billing_periods ( + id TEXT PRIMARY KEY NOT NULL, + subscription_id TEXT NOT NULL REFERENCES subscriptions(id), + created_at TEXT NOT NULL +); + +CREATE TABLE billing_period_events ( + id TEXT NOT NULL REFERENCES billing_periods(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); + +CREATE TABLE hook_events ( + entity_id TEXT NOT NULL, + event_type TEXT NOT NULL, + created_at TEXT NOT NULL +); + +-- Tables for custom accessor tests +CREATE TABLE profiles ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + display_name TEXT NOT NULL, + email TEXT NOT NULL, + created_at TEXT NOT NULL +); +CREATE UNIQUE INDEX profiles_email_key ON profiles (email); + +CREATE 
TABLE profile_events ( + id TEXT NOT NULL REFERENCES profiles(id), + sequence INTEGER NOT NULL, + event_type TEXT NOT NULL, + event TEXT NOT NULL, + context TEXT DEFAULT NULL, + recorded_at TEXT NOT NULL, + UNIQUE(id, sequence) +); diff --git a/src/context/sqlx.rs b/src/context/sqlx.rs index 271d584a..6d76b254 100644 --- a/src/context/sqlx.rs +++ b/src/context/sqlx.rs @@ -1,37 +1,85 @@ -use sqlx::postgres::{PgHasArrayType, PgValueRef}; +use super::ContextData; -use crate::db; +// ── Postgres implementation ────────────────────────────────────────────── -use super::ContextData; +#[cfg(feature = "postgres")] +mod pg { + use sqlx::{ + Postgres, + postgres::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueRef}, + }; -impl sqlx::Type for ContextData { - fn type_info() -> db::TypeInfo { - >::type_info() + use super::ContextData; + + impl sqlx::Type for ContextData { + fn type_info() -> PgTypeInfo { + >::type_info() + } } -} -impl<'q> sqlx::Encode<'q, db::Db> for ContextData { - fn encode_by_ref( - &self, - buf: &mut db::ArgumentBuffer, - ) -> Result> { - let json_value = serde_json::to_value(&self.0)?; - >::encode_by_ref(&json_value, buf) + impl<'q> sqlx::Encode<'q, Postgres> for ContextData { + fn encode_by_ref( + &self, + buf: &mut PgArgumentBuffer, + ) -> Result> + { + let json_value = serde_json::to_value(&self.0)?; + >::encode_by_ref(&json_value, buf) + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for ContextData { + fn decode( + value: PgValueRef<'r>, + ) -> Result> { + let json_value = >::decode(value)?; + let res: ContextData = serde_json::from_value(json_value)?; + Ok(res) + } } -} -impl<'r> sqlx::Decode<'r, db::Db> for ContextData { - fn decode( - value: PgValueRef<'r>, - ) -> Result> { - let json_value = >::decode(value)?; - let res: ContextData = serde_json::from_value(json_value)?; - Ok(res) + impl PgHasArrayType for ContextData { + fn array_type_info() -> PgTypeInfo { + ::array_type_info() + } } } -impl PgHasArrayType for ContextData { - fn 
array_type_info() -> db::TypeInfo { - ::array_type_info() +// ── SQLite implementation ──────────────────────────────────────────────── + +#[cfg(feature = "sqlite")] +mod sqlite { + use sqlx::{ + Sqlite, + sqlite::{SqliteTypeInfo, SqliteValueRef}, + }; + + use super::ContextData; + + impl sqlx::Type for ContextData { + fn type_info() -> SqliteTypeInfo { + >::type_info() + } + } + + impl<'q> sqlx::Encode<'q, Sqlite> for ContextData { + fn encode_by_ref( + &self, + buf: &mut Vec>, + ) -> Result> + { + let json_str = serde_json::to_string(self)?; + >::encode(json_str, buf) + } + } + + impl<'r> sqlx::Decode<'r, Sqlite> for ContextData { + fn decode( + value: SqliteValueRef<'r>, + ) -> Result> { + let json_str = >::decode(value)?; + let res: ContextData = serde_json::from_str(&json_str)?; + Ok(res) + } } } diff --git a/src/db.rs b/src/db.rs index 9ecde873..efe44429 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,16 +1,31 @@ //! Centralized database type aliases. //! -//! Re-exports PostgreSQL-specific types from [`sqlx`] under shorter names, +//! Re-exports database-specific types from [`sqlx`] under shorter names, //! giving the rest of the crate a single place to reference them. +//! +//! Exactly one of `postgres` or `sqlite` must be enabled. 
+ +#[cfg(all(feature = "postgres", feature = "sqlite"))] +compile_error!("features `postgres` and `sqlite` are mutually exclusive — enable only one"); + +#[cfg(not(any(feature = "postgres", feature = "sqlite")))] +compile_error!("one of features `postgres` or `sqlite` must be enabled"); + +// ── Postgres ────────────────────────────────────────────────────────────── +#[cfg(feature = "postgres")] pub use sqlx::PgConnection as Connection; +#[cfg(feature = "postgres")] pub use sqlx::PgPool as Pool; +#[cfg(feature = "postgres")] pub use sqlx::Postgres as Db; +#[cfg(feature = "postgres")] pub use sqlx::postgres::{ PgArgumentBuffer as ArgumentBuffer, PgRow as Row, PgTypeInfo as TypeInfo, }; /// Fetches the current timestamp from the database via `SELECT NOW()`. +#[cfg(feature = "postgres")] pub async fn database_now( executor: impl sqlx::Executor<'_, Database = Db>, ) -> Result, sqlx::Error> { @@ -18,3 +33,41 @@ pub async fn database_now( .fetch_one(executor) .await } + +/// Extract the conflicting value from a database constraint violation, if possible. +#[cfg(feature = "postgres")] +pub fn extract_constraint_value(db_err: &dyn sqlx::error::DatabaseError) -> Option { + db_err + .try_downcast_ref::() + .and_then(|pg_err| crate::error::parse_constraint_detail_value(pg_err.detail())) +} + +// ── SQLite ──────────────────────────────────────────────────────────────── + +#[cfg(feature = "sqlite")] +pub use sqlx::Sqlite as Db; +#[cfg(feature = "sqlite")] +pub use sqlx::SqliteConnection as Connection; +#[cfg(feature = "sqlite")] +pub use sqlx::SqlitePool as Pool; +#[cfg(feature = "sqlite")] +pub use sqlx::sqlite::SqliteRow as Row; + +/// Obtain the current database time. +/// +/// SQLite does not have a native `NOW()` that returns a proper timestamp type, +/// so we fall back to `chrono::Utc::now()` on the application side. 
+#[cfg(feature = "sqlite")] +pub async fn database_now( + _executor: impl sqlx::Executor<'_, Database = Db>, +) -> Result, sqlx::Error> { + Ok(chrono::Utc::now()) +} + +/// Extract the conflicting value from a database constraint violation, if possible. +/// +/// SQLite does not provide detail messages like PostgreSQL, so this always returns `None`. +#[cfg(feature = "sqlite")] +pub fn extract_constraint_value(_db_err: &dyn sqlx::error::DatabaseError) -> Option { + None +} diff --git a/src/lib.rs b/src/lib.rs index fe351766..59eb1b19 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,7 @@ -//! A Rust library for persisting Event Sourced entities to PostgreSQL +//! A Rust library for persisting Event Sourced entities to PostgreSQL or SQLite //! //! This crate simplifies Event Sourcing persistence by automatically generating type-safe -//! queries and operations for PostgreSQL. It decouples domain logic from persistence +//! queries and operations. It decouples domain logic from persistence //! concerns while ensuring compile-time query verification via [sqlx](https://crates.io/crates/sqlx). //! //! 
# Documentation @@ -54,12 +54,35 @@ pub mod prelude { pub use context::*; #[doc(inline)] pub use error::*; + +// ── Derive macro re-exports (backend-gated) ────────────────────────────── + +#[cfg(feature = "postgres")] pub use es_entity_macros::EsEntity; +#[cfg(feature = "postgres")] pub use es_entity_macros::EsEvent; +#[cfg(feature = "postgres")] pub use es_entity_macros::EsRepo; +#[cfg(feature = "postgres")] pub use es_entity_macros::es_event_context; +#[cfg(feature = "postgres")] pub use es_entity_macros::expand_es_query; +#[cfg(feature = "postgres")] pub use es_entity_macros::retry_on_concurrent_modification; + +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::EsEntity; +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::EsEvent; +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::EsRepo; +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::es_event_context; +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::expand_es_query; +#[cfg(feature = "sqlite")] +pub use es_entity_macros_sqlite::retry_on_concurrent_modification; + #[doc(inline)] pub use events::*; #[doc(inline)] diff --git a/src/one_time_executor.rs b/src/one_time_executor.rs index d1e9975b..1998c2a4 100644 --- a/src/one_time_executor.rs +++ b/src/one_time_executor.rs @@ -11,7 +11,7 @@ use crate::{db, operation::AtomicOperation}; /// /// In order to make the consumption of the executor work we have to pass the query to the /// executor: -/// ```rust +/// ```rust,ignore /// async fn query(ex: impl es_entity::IntoOneTimeExecutor<'_>) -> Result<(), sqlx::Error> { /// ex.into_executor().fetch_optional( /// sqlx::query!( @@ -62,12 +62,12 @@ where /// Proxy call to `query.fetch_all` but guarantees the inner executor will only be used once. 
pub async fn fetch_all<'q, F, O, A>( self, - query: sqlx::query::Map<'q, sqlx::Postgres, F, A>, + query: sqlx::query::Map<'q, db::Db, F, A>, ) -> Result, sqlx::Error> where - F: FnMut(sqlx::postgres::PgRow) -> Result + Send, + F: FnMut(db::Row) -> Result + Send, O: Send + Unpin, - A: 'q + Send + sqlx::IntoArguments<'q, sqlx::Postgres>, + A: 'q + Send + sqlx::IntoArguments<'q, db::Db>, { query.fetch_all(self.executor).await } @@ -75,18 +75,18 @@ where /// Proxy call to `query.fetch_optional` but guarantees the inner executor will only be used once. pub async fn fetch_optional<'q, F, O, A>( self, - query: sqlx::query::Map<'q, sqlx::Postgres, F, A>, + query: sqlx::query::Map<'q, db::Db, F, A>, ) -> Result, sqlx::Error> where - F: FnMut(sqlx::postgres::PgRow) -> Result + Send, + F: FnMut(db::Row) -> Result + Send, O: Send + Unpin, - A: 'q + Send + sqlx::IntoArguments<'q, sqlx::Postgres>, + A: 'q + Send + sqlx::IntoArguments<'q, db::Db>, { query.fetch_optional(self.executor).await } } -/// Marker trait for [`IntoOneTimeExecutorAt<'a> + 'a`](`IntoOneTimeExecutorAt`). Do not implement directly. +/// Marker trait for [`IntoOneTimeExecutorAt<'a> + 'a`](`IntoOneTimeExecutorAt`). Do not implement directly. /// /// Used as sugar to avoid writing: /// ```rust,ignore diff --git a/src/operation/hooks.rs b/src/operation/hooks.rs index a070a507..dced6ea3 100644 --- a/src/operation/hooks.rs +++ b/src/operation/hooks.rs @@ -29,7 +29,7 @@ //! fire-and-forget operations like sending to channels. A background task can then //! handle the async work of publishing to external systems. //! -//! ``` +//! ```ignore //! use es_entity::{AtomicOperation, operation::hooks::{CommitHook, HookOperation, PreCommitRet}}; //! //! #[derive(Debug, Clone)] @@ -98,7 +98,7 @@ //! //! ## Usage //! -//! ```no_run +//! ```ignore //! # use es_entity::{AtomicOperation, DbOp, operation::hooks::{CommitHook, HookOperation, PreCommitRet}}; //! # use es_entity::db; //! 
# #[derive(Debug, Clone)] diff --git a/src/operation/mod.rs b/src/operation/mod.rs index 9ed60c26..ac700aef 100644 --- a/src/operation/mod.rs +++ b/src/operation/mod.rs @@ -79,7 +79,7 @@ impl<'c> DbOp<'c> { /// Priority order: /// 1. Cached time if present /// 2. Artificial clock time if the clock is artificial (and hasn't transitioned) - /// 3. Database time via `SELECT NOW()` + /// 3. Database time via `SELECT NOW()` (Postgres) or application time (SQLite) pub async fn with_db_time(mut self) -> Result, sqlx::Error> { let time = if let Some(time) = self.now { time @@ -223,7 +223,7 @@ pub trait AtomicOperation: Send { /// The desired way to represent this would actually be as a GAT: /// ```rust /// trait AtomicOperation { - /// type Executor<'c>: sqlx::PgExecutor<'c> + /// type Executor<'c>: sqlx::Executor<'c> /// where Self: 'c; /// /// fn as_executor<'c>(&'c mut self) -> Self::Executor<'c>; @@ -235,7 +235,7 @@ pub trait AtomicOperation: Send { /// /// Since this trait is generally applied to types that wrap a [`sqlx::Transaction`] /// there is no variance in the return type - so its fine. - fn as_executor(&mut self) -> &mut sqlx::PgConnection; + fn as_executor(&mut self) -> &mut db::Connection; /// Registers a commit hook that will run pre_commit before and post_commit after the transaction commits. /// Returns Ok(()) if the hook was registered, Err(hook) if hooks are not supported. diff --git a/src/operation/with_time.rs b/src/operation/with_time.rs index dc680a38..a5a4d3ee 100644 --- a/src/operation/with_time.rs +++ b/src/operation/with_time.rs @@ -24,7 +24,7 @@ impl<'a, Op: AtomicOperation + ?Sized> OpWithTime<'a, Op> { /// Priority order: /// 1. Cached time from operation /// 2. Artificial clock time if the operation's clock is artificial (and hasn't transitioned) - /// 3. Database time via `SELECT NOW()` + /// 3. 
Database time (Postgres: `SELECT NOW()`, SQLite: application time) pub async fn cached_or_db_time(op: &'a mut Op) -> Result { let now = if let Some(time) = op.maybe_now() { time diff --git a/tests/context.rs b/tests/context.rs index c5f3c132..1d8b1b63 100644 --- a/tests/context.rs +++ b/tests/context.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "postgres")] + use es_entity::context::EventContext; use es_entity_macros::es_event_context; use serde_json::json; diff --git a/tests/es_query.rs b/tests/es_query.rs index c0afe1f9..bf6d1cad 100644 --- a/tests/es_query.rs +++ b/tests/es_query.rs @@ -4,7 +4,6 @@ mod helpers; use entities::user::*; use es_entity::*; use helpers::init_pool; -use sqlx::PgPool; mod tbl_prefix_param { use super::*; @@ -16,11 +15,11 @@ mod tbl_prefix_param { columns(name(ty = "String")) )] struct UsersTblPrefix { - pool: PgPool, + pool: es_entity::db::Pool, } impl UsersTblPrefix { - fn new(pool: PgPool) -> Self { + fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } @@ -100,11 +99,11 @@ mod entity_param { columns(name(ty = "String")) )] struct UsersEntity { - pool: PgPool, + pool: es_entity::db::Pool, } impl UsersEntity { - fn new(pool: PgPool) -> Self { + fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } @@ -178,11 +177,11 @@ mod no_params { #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String")))] struct UsersNoParams { - pool: PgPool, + pool: es_entity::db::Pool, } impl UsersNoParams { - fn new(pool: PgPool) -> Self { + fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } diff --git a/tests/from_async_trait.rs b/tests/from_async_trait.rs index e92d53b6..fde46d83 100644 --- a/tests/from_async_trait.rs +++ b/tests/from_async_trait.rs @@ -3,7 +3,6 @@ mod helpers; use es_entity::*; use helpers::init_pool; -use sqlx::PgPool; use entities::order::*; @@ -18,7 +17,7 @@ trait RunJob { } struct TestJob { - pool: PgPool, + pool: es_entity::db::Pool, } #[async_trait::async_trait] @@ -43,13 +42,13 @@ impl RunJob for 
TestJob { #[derive(EsRepo, Debug)] #[es_repo(entity = "Order")] pub struct Orders { - pool: PgPool, + pool: es_entity::db::Pool, #[es_repo(nested)] items: OrderItems, } impl Orders { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool: pool.clone(), items: OrderItems::new(pool), @@ -63,11 +62,11 @@ impl Orders { columns(order_id(ty = "OrderId", update(persist = false), parent)) )] pub struct OrderItems { - pool: PgPool, + pool: es_entity::db::Pool, } impl OrderItems { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } } diff --git a/tests/helpers.rs b/tests/helpers.rs index e0ecd2c0..f16442de 100644 --- a/tests/helpers.rs +++ b/tests/helpers.rs @@ -1,6 +1,19 @@ +#[cfg(feature = "postgres")] pub async fn init_pool() -> anyhow::Result { let pg_host = std::env::var("PG_HOST").unwrap_or("localhost".to_string()); let pg_con = format!("postgres://user:password@{pg_host}:5432/pg"); let pool = sqlx::PgPool::connect(&pg_con).await?; Ok(pool) } + +#[cfg(feature = "sqlite")] +pub async fn init_pool() -> anyhow::Result { + use std::sync::atomic::{AtomicU64, Ordering}; + + static COUNTER: AtomicU64 = AtomicU64::new(0); + let db_id = COUNTER.fetch_add(1, Ordering::Relaxed); + let url = format!("sqlite:file:memdb_{db_id}?mode=memory&cache=shared"); + let pool = sqlx::SqlitePool::connect(&url).await?; + sqlx::migrate!("./migrations-sqlite").run(&pool).await?; + Ok(pool) +} diff --git a/tests/hooks.rs b/tests/hooks.rs index fba37195..edc831e1 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -18,10 +18,8 @@ impl CommitHook for FullCommitHook { self, mut op: HookOperation<'_>, ) -> Result, sqlx::Error> { - let result = sqlx::query!("SELECT NOW() as now") - .fetch_one(op.as_executor()) - .await?; - *self.pre_result.lock().unwrap() = result.now; + let now = es_entity::db::database_now(op.as_executor()).await?; + *self.pre_result.lock().unwrap() = Some(now); PreCommitRet::ok(self, op) } 
diff --git a/tests/nested_entities.rs b/tests/nested_entities.rs index 7a09116f..836c35c4 100644 --- a/tests/nested_entities.rs +++ b/tests/nested_entities.rs @@ -3,19 +3,18 @@ mod helpers; use entities::order::*; use es_entity::*; -use sqlx::PgPool; #[derive(EsRepo, Debug)] #[es_repo(entity = "Order")] pub struct Orders { - pool: PgPool, + pool: es_entity::db::Pool, #[es_repo(nested)] items: OrderItems, } impl Orders { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool: pool.clone(), items: OrderItems::new(pool), @@ -29,11 +28,11 @@ impl Orders { columns(order_id(ty = "OrderId", update(persist = false), parent)) )] pub struct OrderItems { - pool: PgPool, + pool: es_entity::db::Pool, } impl OrderItems { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } } diff --git a/tests/repo_bulk.rs b/tests/repo_bulk.rs index 7c3d56bf..c9c11331 100644 --- a/tests/repo_bulk.rs +++ b/tests/repo_bulk.rs @@ -3,7 +3,6 @@ mod helpers; use entities::profile::*; use es_entity::*; -use sqlx::PgPool; /// Profiles repo with custom accessors: /// - `name`: field-path accessor (`data.name`) — accesses nested struct field @@ -23,11 +22,11 @@ use sqlx::PgPool; ) )] pub struct Profiles { - pool: PgPool, + pool: es_entity::db::Pool, } impl Profiles { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } } diff --git a/tests/repo_clock.rs b/tests/repo_clock.rs index f7173be5..6576310d 100644 --- a/tests/repo_clock.rs +++ b/tests/repo_clock.rs @@ -3,16 +3,15 @@ mod helpers; use entities::user::*; use es_entity::{clock::*, *}; -use sqlx::PgPool; #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String", list_for)))] pub struct Users { - pool: PgPool, + pool: es_entity::db::Pool, } impl Users { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } } @@ -20,7 +19,6 @@ impl Users { /// 
A separate module for the clock field repo to avoid type conflicts mod users_with_clock { use es_entity::{EsEntity, EsEvent, EsRepo, clock::ClockHandle}; - use sqlx::PgPool; use crate::entities::user::*; @@ -28,16 +26,16 @@ mod users_with_clock { #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String", list_for)))] pub struct UsersWithClock { - pool: PgPool, + pool: es_entity::db::Pool, clock: Option, } impl UsersWithClock { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool, clock: None } } - pub fn with_clock(pool: PgPool, clock: ClockHandle) -> Self { + pub fn with_clock(pool: es_entity::db::Pool, clock: ClockHandle) -> Self { Self { pool, clock: Some(clock), @@ -51,7 +49,6 @@ use users_with_clock::UsersWithClock; /// A separate module for the required clock field repo mod users_with_required_clock { use es_entity::{EsEntity, EsEvent, EsRepo, clock::ClockHandle}; - use sqlx::PgPool; use crate::entities::user::*; @@ -59,12 +56,12 @@ mod users_with_required_clock { #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String", list_for)))] pub struct UsersWithRequiredClock { - pool: PgPool, + pool: es_entity::db::Pool, clock: ClockHandle, } impl UsersWithRequiredClock { - pub fn new(pool: PgPool, clock: ClockHandle) -> Self { + pub fn new(pool: es_entity::db::Pool, clock: ClockHandle) -> Self { Self { pool, clock } } } diff --git a/tests/repo_crud.rs b/tests/repo_crud.rs index efc097e5..401b6f89 100644 --- a/tests/repo_crud.rs +++ b/tests/repo_crud.rs @@ -3,7 +3,6 @@ mod helpers; use entities::user::*; use es_entity::*; -use sqlx::PgPool; /// Regression test: repo structs with a generic parameter named 'E' must compile /// without conflicting with the macro's internal error generic. 
@@ -12,7 +11,6 @@ mod generic_e_repo { #![allow(dead_code)] use es_entity::*; - use sqlx::PgPool; use crate::entities::user::*; @@ -21,7 +19,7 @@ mod generic_e_repo { #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String", list_for)))] pub struct UsersWithGenericE { - pool: PgPool, + pool: es_entity::db::Pool, _marker: std::marker::PhantomData, } } @@ -29,11 +27,11 @@ mod generic_e_repo { #[derive(EsRepo, Debug)] #[es_repo(entity = "User", columns(name(ty = "String", list_for)))] pub struct Users { - pool: PgPool, + pool: es_entity::db::Pool, } impl Users { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: es_entity::db::Pool) -> Self { Self { pool } } } @@ -104,6 +102,14 @@ async fn list_for_filters() -> anyhow::Result<()> { let users = Users::new(pool); + // Seed a user so the table isn't empty + let seed_user = NewUser::builder() + .id(UserId::new()) + .name("SeedUser") + .build() + .unwrap(); + users.create(seed_user).await?; + // Test with default filters (no filter) - should return all entities let PaginatedQueryRet { entities, diff --git a/tests/repo_errors.rs b/tests/repo_errors.rs index 45c72f72..dc6bb474 100644 --- a/tests/repo_errors.rs +++ b/tests/repo_errors.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "postgres")] + mod entities; mod helpers; diff --git a/tests/repo_hooks.rs b/tests/repo_hooks.rs index 5b10573a..e60f5ee4 100644 --- a/tests/repo_hooks.rs +++ b/tests/repo_hooks.rs @@ -34,8 +34,8 @@ impl std::error::Error for UserPersistAuditError {} // --------------------------------------------------------------------------- mod users_with_hydrate_hook { + use es_entity::db::Pool; use es_entity::*; - use sqlx::PgPool; use crate::UserHydrateValidationError; use crate::entities::user::*; @@ -47,11 +47,11 @@ mod users_with_hydrate_hook { post_hydrate_hook(method = "validate_hydrated", error = "UserHydrateValidationError") )] pub struct UsersWithHydrateHook { - pool: PgPool, + pool: Pool, } impl UsersWithHydrateHook { - pub fn 
new(pool: PgPool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } @@ -73,8 +73,8 @@ mod users_with_hydrate_hook { // --------------------------------------------------------------------------- mod users_with_persist_hook { + use es_entity::db::Pool; use es_entity::*; - use sqlx::PgPool; use crate::UserPersistAuditError; use crate::entities::user::*; @@ -86,11 +86,11 @@ mod users_with_persist_hook { post_persist_hook(method = "audit_persist", error = "UserPersistAuditError") )] pub struct UsersWithPersistHook { - pool: PgPool, + pool: Pool, } impl UsersWithPersistHook { - pub fn new(pool: PgPool) -> Self { + pub fn new(pool: Pool) -> Self { Self { pool } } From 7b2b5aed8d1e89f04ce3de8d1adac068a9659721 Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Thu, 12 Mar 2026 11:27:04 +0100 Subject: [PATCH 2/4] fix: feature-gate extract_constraint_value for SQLite compatibility The function in error.rs referenced sqlx::postgres::PgDatabaseError without a feature gate. Add #[cfg(feature = "postgres")] and a SQLite no-op variant. Co-Authored-By: Claude Opus 4.6 --- src/error.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/error.rs b/src/error.rs index 002fb2b3..32d74f36 100644 --- a/src/error.rs +++ b/src/error.rs @@ -44,11 +44,19 @@ pub fn parse_constraint_detail_value(detail: Option<&str>) -> Option { /// /// Downcasts to [`sqlx::postgres::PgDatabaseError`], reads its `detail()`, /// and parses the conflicting value. +#[cfg(feature = "postgres")] pub fn extract_constraint_value(db_err: &dyn sqlx::error::DatabaseError) -> Option { let pg_err = db_err.try_downcast_ref::()?; parse_constraint_detail_value(pg_err.detail()) } +#[doc(hidden)] +/// SQLite does not provide detail messages like PostgreSQL, so this always returns `None`. +#[cfg(feature = "sqlite")] +pub fn extract_constraint_value(_db_err: &dyn sqlx::error::DatabaseError) -> Option { + None +} + #[doc(hidden)] /// Wrapper used by generated code to format not-found values. 
/// Prefers `Display` over `Debug` via inherent-vs-trait method resolution. From d9fa03e245618d86044016bfe84307ceae21123f Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Thu, 12 Mar 2026 11:57:01 +0100 Subject: [PATCH 3/4] ci: add SQLite test runner to nix flake and GitHub Actions Add nextest-sqlite-runner script that runs macro and integration tests with SQLite features (in-memory databases, no infrastructure needed). Add sqlite-tests job to test.yml workflow for CI visibility. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/test.yml | 14 ++++++++++++++ flake.nix | 25 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4fe44cbd..1bdbc5f7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,3 +18,17 @@ jobs: - uses: actions/checkout@v3 - name: Run integration tests run: nix run .#nextest + + sqlite-tests: + name: sqlite tests + runs-on: ubuntu-latest + steps: + - name: Install Nix + uses: DeterminateSystems/nix-installer-action@v16 + - uses: cachix/cachix-action@v15 + with: + name: lana-ci + authToken: ${{ env.CACHIX_AUTH_TOKEN }} + - uses: actions/checkout@v3 + - name: Run SQLite tests + run: nix run .#nextest-sqlite diff --git a/flake.nix b/flake.nix index 8b71ea8c..dfb36190 100644 --- a/flake.nix +++ b/flake.nix @@ -137,10 +137,35 @@ echo "Tests completed successfully!" ''; + + nextest-sqlite-runner = pkgs.writeShellScriptBin "nextest-sqlite-runner" '' + set -e + + export PATH="${pkgs.lib.makeBinPath [ + pkgs.cargo-nextest + pkgs.coreutils + rustToolchain + pkgs.stdenv.cc + ]}:$PATH" + + export SQLX_OFFLINE=true + + echo "Running SQLite macro tests..." + cargo nextest run -p es-entity-macros-sqlite --verbose + + echo "Running SQLite integration tests..." + cargo nextest run -p es-entity --no-default-features --features sqlite,graphql,event-context,instrument,json-schema --verbose + + echo "Running SQLite doc tests..." 
+ cargo test --doc -p es-entity --no-default-features --features sqlite,graphql,event-context,instrument,json-schema + + echo "SQLite tests completed successfully!" + ''; in with pkgs; { packages = { nextest = nextest-runner; + nextest-sqlite = nextest-sqlite-runner; }; checks = { From ff1808d3bf6b6a1df1eb690737b90b26ed8e46b4 Mon Sep 17 00:00:00 2001 From: bodymindarts Date: Fri, 13 Mar 2026 09:26:48 +0100 Subject: [PATCH 4/4] fix(sqlite): use NULL-safe equality for Option columns in find_by and list_for The generated SQL for find_by and list_for queries used bare `= ?` equality, which fails when the bound parameter is NULL because `column = NULL` evaluates to NULL (falsy) in SQL. This caused zero rows to be returned when querying for NULL values on Option columns. Switch to SQLite's `IS` operator for Option columns, which handles NULL-safe equality (`NULL IS NULL` returns true). The list_for_filters variant already handled this correctly via its COALESCE pattern. Co-Authored-By: Claude Opus 4.6 --- .../src/repo/find_by_fn.rs | 41 ++++++- .../src/repo/list_for_fn.rs | 52 ++++++++- tests/entities/mod.rs | 1 + tests/entities/user_document.rs | 70 +++++++++++ tests/optional_column.rs | 110 ++++++++++++++++++ 5 files changed, 271 insertions(+), 3 deletions(-) create mode 100644 tests/entities/user_document.rs create mode 100644 tests/optional_column.rs diff --git a/es-entity-macros-sqlite/src/repo/find_by_fn.rs b/es-entity-macros-sqlite/src/repo/find_by_fn.rs index fa022c98..2a1d8211 100644 --- a/es-entity-macros-sqlite/src/repo/find_by_fn.rs +++ b/es-entity-macros-sqlite/src/repo/find_by_fn.rs @@ -82,8 +82,9 @@ impl ToTokens for FindByFn<'_> { Span::call_site(), ); + let eq_op = if self.column.is_optional() { "IS" } else { "=" }; let query = format!( - r#"SELECT id FROM {} WHERE {} = $1{}"#, + r#"SELECT id FROM {} WHERE {} {eq_op} $1{}"#, self.table_name, column_name, if delete == DeleteOption::No { @@ -412,6 +413,44 @@ mod tests { assert_eq!(tokens.to_string(), 
expected.to_string()); } + #[test] + fn find_by_fn_optional_column_uses_is() { + let column = Column::new( + syn::Ident::new("project_id", proc_macro2::Span::call_site()), + syn::parse_str("Option").unwrap(), + ); + let entity = Ident::new("Entity", Span::call_site()); + + let persist_fn = FindByFn { + prefix: None, + column: &column, + entity: &entity, + table_name: "entities", + column_enum: syn::Ident::new("EntityColumn", Span::call_site()), + find_error: syn::Ident::new("EntityFindError", Span::call_site()), + query_error: syn::Ident::new("EntityQueryError", Span::call_site()), + delete: DeleteOption::No, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let token_str = tokens.to_string(); + assert!( + token_str.contains("project_id IS $1"), + "Expected 'IS' for Option column, got: {}", + token_str + ); + assert!( + !token_str.contains("project_id = $1"), + "Should not use '=' for Option column" + ); + } + #[test] fn find_by_fn_with_soft_delete() { let column = Column::for_id(syn::parse_str("EntityId").unwrap()); diff --git a/es-entity-macros-sqlite/src/repo/list_for_fn.rs b/es-entity-macros-sqlite/src/repo/list_for_fn.rs index 14febfc7..f88e05dc 100644 --- a/es-entity-macros-sqlite/src/repo/list_for_fn.rs +++ b/es-entity-macros-sqlite/src/repo/list_for_fn.rs @@ -94,8 +94,13 @@ impl ToTokens for ListForFn<'_> { Span::call_site(), ); + let eq_op = if self.for_column.is_optional() { + "IS" + } else { + "=" + }; let asc_query = format!( - r#"SELECT {} FROM {} WHERE (({} = $1) AND ({})){} ORDER BY {} LIMIT $2"#, + r#"SELECT {} FROM {} WHERE (({} {eq_op} $1) AND ({})){} ORDER BY {} LIMIT $2"#, select_columns, self.table_name, for_column_name, @@ -108,7 +113,7 @@ impl ToTokens for ListForFn<'_> { cursor.order_by(true) ); let desc_query = format!( - r#"SELECT {} FROM {} WHERE (({} = $1) AND ({})){} ORDER BY 
{} LIMIT $2"#, + r#"SELECT {} FROM {} WHERE (({} {eq_op} $1) AND ({})){} ORDER BY {} LIMIT $2"#, select_columns, self.table_name, for_column_name, @@ -389,6 +394,49 @@ mod tests { assert_eq!(tokens.to_string(), expected.to_string()); } + #[test] + fn list_for_optional_column_uses_is() { + let entity = Ident::new("Entity", Span::call_site()); + let query_error = syn::Ident::new("EntityQueryError", Span::call_site()); + let id = syn::Ident::new("EntityId", proc_macro2::Span::call_site()); + let by_column = Column::for_id(syn::parse_str("EntityId").unwrap()); + let for_column = Column::new( + syn::Ident::new("project_id", proc_macro2::Span::call_site()), + syn::parse_str("Option").unwrap(), + ); + let cursor_mod = Ident::new("cursor_mod", Span::call_site()); + + let persist_fn = ListForFn { + ignore_prefix: None, + entity: &entity, + id: &id, + for_column: &for_column, + by_column: &by_column, + table_name: "entities", + query_error, + delete: DeleteOption::No, + cursor_mod, + any_nested: false, + post_hydrate_error: None, + #[cfg(feature = "instrument")] + repo_name_snake: "test_repo".to_string(), + }; + + let mut tokens = TokenStream::new(); + persist_fn.to_tokens(&mut tokens); + + let token_str = tokens.to_string(); + assert!( + token_str.contains("project_id IS $1"), + "Expected 'IS' for Option column, got: {}", + token_str + ); + assert!( + !token_str.contains("project_id = $1"), + "Should not use '=' for Option column" + ); + } + #[test] fn list_same_column() { let entity = Ident::new("Entity", Span::call_site()); diff --git a/tests/entities/mod.rs b/tests/entities/mod.rs index ee709c51..8bec606d 100644 --- a/tests/entities/mod.rs +++ b/tests/entities/mod.rs @@ -1,3 +1,4 @@ pub mod order; pub mod profile; pub mod user; +pub mod user_document; diff --git a/tests/entities/user_document.rs b/tests/entities/user_document.rs new file mode 100644 index 00000000..6022ef51 --- /dev/null +++ b/tests/entities/user_document.rs @@ -0,0 +1,70 @@ +#![allow(dead_code)] + +use 
derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use es_entity::*; + +es_entity::entity_id! { UserDocumentId } + +use super::user::UserId; + +#[derive(EsEvent, Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +#[es_event(id = "UserDocumentId")] +pub enum UserDocumentEvent { + Initialized { + id: UserDocumentId, + user_id: Option<UserId>, + }, +} + +#[derive(EsEntity, Builder)] +#[builder(pattern = "owned", build_fn(error = "EntityHydrationError"))] +pub struct UserDocument { + pub id: UserDocumentId, + pub user_id: Option<UserId>, + + events: EntityEvents<UserDocumentEvent>, +} + +impl TryFromEvents<UserDocumentEvent> for UserDocument { + fn try_from_events( + events: EntityEvents<UserDocumentEvent>, + ) -> Result<Self, EntityHydrationError> { + let mut builder = UserDocumentBuilder::default(); + for event in events.iter_all() { + match event { + UserDocumentEvent::Initialized { id, user_id } => { + builder = builder.id(*id).user_id(*user_id); + } + } + } + builder.events(events).build() + } +} + +#[derive(Debug, Builder)] +pub struct NewUserDocument { + #[builder(setter(into))] + pub id: UserDocumentId, + pub user_id: Option<UserId>, +} + +impl NewUserDocument { + pub fn builder() -> NewUserDocumentBuilder { + NewUserDocumentBuilder::default() + } +} + +impl IntoEvents<UserDocumentEvent> for NewUserDocument { + fn into_events(self) -> EntityEvents<UserDocumentEvent> { + EntityEvents::init( + self.id, + [UserDocumentEvent::Initialized { + id: self.id, + user_id: self.user_id, + }], + ) + } +} diff --git a/tests/optional_column.rs new file mode 100644 index 00000000..46ab7c63 --- /dev/null +++ b/tests/optional_column.rs @@ -0,0 +1,110 @@ +mod entities; +mod helpers; + +use entities::{user::*, user_document::*}; +use es_entity::*; + +#[derive(EsRepo, Debug)] +#[es_repo( + entity = "UserDocument", + columns(user_id(ty = "Option<UserId>", list_for, find_by)) +)] +pub struct UserDocuments { + pool: es_entity::db::Pool, +} + +impl UserDocuments { + pub fn new(pool: es_entity::db::Pool) -> Self { + Self { pool } + } +} + +/// Regression test: 
list_for on an Option column with None must return rows +/// where the column IS NULL. Previously the generated SQL used `= ?` which +/// evaluates to NULL (falsy) when the bound value is NULL. +#[tokio::test] +async fn list_for_optional_column_with_none() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let docs = UserDocuments::new(pool); + + // Insert a document with user_id = NULL + let null_doc = NewUserDocument::builder() + .id(UserDocumentId::new()) + .user_id(None) + .build() + .unwrap(); + docs.create(null_doc).await?; + + // Insert a document with user_id = Some(...) + let some_user_id = UserId::new(); + let non_null_doc = NewUserDocument::builder() + .id(UserDocumentId::new()) + .user_id(Some(some_user_id)) + .build() + .unwrap(); + docs.create(non_null_doc).await?; + + // Query for documents where user_id IS NULL + let result = docs + .list_for_user_id_by_id( + None, + PaginatedQueryArgs { + first: 10, + after: None, + }, + ListDirection::Ascending, + ) + .await?; + + assert_eq!( + result.entities.len(), + 1, + "Expected 1 row with NULL user_id, got {}", + result.entities.len() + ); + assert_eq!(result.entities[0].user_id, None); + + // Query for documents where user_id = some_user_id (non-NULL still works) + let result = docs + .list_for_user_id_by_id( + Some(some_user_id), + PaginatedQueryArgs { + first: 10, + after: None, + }, + ListDirection::Ascending, + ) + .await?; + + assert_eq!(result.entities.len(), 1); + assert_eq!(result.entities[0].user_id, Some(some_user_id)); + + Ok(()) +} + +/// Regression test: find_by on an Option column with None must return the +/// row where the column IS NULL. 
+#[tokio::test] +async fn find_by_optional_column_with_none() -> anyhow::Result<()> { + let pool = helpers::init_pool().await?; + let docs = UserDocuments::new(pool); + + // Insert a document with user_id = NULL + let null_doc_id = UserDocumentId::new(); + let null_doc = NewUserDocument::builder() + .id(null_doc_id) + .user_id(None) + .build() + .unwrap(); + docs.create(null_doc).await?; + + // find_by_user_id(None) should find the row + let found = docs.maybe_find_by_user_id(None).await?; + assert!( + found.is_some(), + "Expected to find document with NULL user_id" + ); + assert_eq!(found.unwrap().id, null_doc_id); + + Ok(()) +}