diff --git a/Directory.Packages.props b/Directory.Packages.props index 8cb98ab..6523ea8 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -16,7 +16,9 @@ + + diff --git a/SearchLite.sln b/SearchLite.sln index 8cf2a62..ba9fb65 100644 --- a/SearchLite.sln +++ b/SearchLite.sln @@ -32,6 +32,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SearchLite.Sqlite.Tests", " EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SearchLite.Postgres.Tests", "Tests\SearchLite.Postgres.Tests\SearchLite.Postgres.Tests.csproj", "{2C44B23D-0A5E-4E95-9741-E5A73CD7787E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SearchLite.MariaDb", "Source\SearchLite.MariaDb\SearchLite.MariaDb.csproj", "{242653EF-82A2-436A-B1E2-5164AE7B18E3}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SearchLite.MariaDb.Tests", "Tests\SearchLite.MariaDb.Tests\SearchLite.MariaDb.Tests.csproj", "{90590189-940C-47BC-ABB0-925178D12224}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -126,6 +130,30 @@ Global {2C44B23D-0A5E-4E95-9741-E5A73CD7787E}.Release|x64.Build.0 = Release|Any CPU {2C44B23D-0A5E-4E95-9741-E5A73CD7787E}.Release|x86.ActiveCfg = Release|Any CPU {2C44B23D-0A5E-4E95-9741-E5A73CD7787E}.Release|x86.Build.0 = Release|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|x64.ActiveCfg = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|x64.Build.0 = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|x86.ActiveCfg = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Debug|x86.Build.0 = Debug|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|Any CPU.Build.0 = Release|Any CPU + 
{242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|x64.ActiveCfg = Release|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|x64.Build.0 = Release|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|x86.ActiveCfg = Release|Any CPU + {242653EF-82A2-436A-B1E2-5164AE7B18E3}.Release|x86.Build.0 = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|Any CPU.Build.0 = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|x64.ActiveCfg = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|x64.Build.0 = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|x86.ActiveCfg = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Debug|x86.Build.0 = Debug|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|Any CPU.ActiveCfg = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|Any CPU.Build.0 = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|x64.ActiveCfg = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|x64.Build.0 = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|x86.ActiveCfg = Release|Any CPU + {90590189-940C-47BC-ABB0-925178D12224}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -135,5 +163,7 @@ Global {9DFACF5F-6514-459E-B942-0265965F240B} = {0AB3BF05-4346-4AA6-1389-037BE0695223} {FAACA2A7-8925-4303-B369-E70F0DF4952B} = {0AB3BF05-4346-4AA6-1389-037BE0695223} {2C44B23D-0A5E-4E95-9741-E5A73CD7787E} = {0AB3BF05-4346-4AA6-1389-037BE0695223} + {242653EF-82A2-436A-B1E2-5164AE7B18E3} = {B8EFCA5F-814F-285C-A8CB-F00F14650265} + {90590189-940C-47BC-ABB0-925178D12224} = {0AB3BF05-4346-4AA6-1389-037BE0695223} EndGlobalSection EndGlobal diff --git a/Source/SearchLite.MariaDb/Extensions.cs b/Source/SearchLite.MariaDb/Extensions.cs new file mode 100644 index 0000000..8c7c36e --- /dev/null +++ 
b/Source/SearchLite.MariaDb/Extensions.cs @@ -0,0 +1,27 @@ +using MySqlConnector; + +namespace SearchLite.MariaDb; + +internal static class Extensions +{ + public static string ToWhereClause(this IReadOnlyList clauses) + { + if (clauses.Count == 0) + { + return string.Empty; + } + + return "WHERE " + string.Join(" AND ", clauses.Select(c => c.Sql)); + } + + public static void AddParameters(this MySqlCommand command, IReadOnlyCollection clauses) + { + foreach (var clause in clauses) + { + foreach (var parameter in clause.Parameters) + { + command.Parameters.Add(parameter); + } + } + } +} diff --git a/Source/SearchLite.MariaDb/SearchIndex.cs b/Source/SearchLite.MariaDb/SearchIndex.cs new file mode 100644 index 0000000..3556223 --- /dev/null +++ b/Source/SearchLite.MariaDb/SearchIndex.cs @@ -0,0 +1,462 @@ +using System.Globalization; +using System.Diagnostics; +using System.Text; +using System.Text.Json; +using System.Text.RegularExpressions; +using MySqlConnector; + +namespace SearchLite.MariaDb; + +public partial class SearchIndex : ISearchIndex where T : ISearchableDocument +{ + private readonly string _connectionString; + private readonly SearchManager _manager; + public string TableName { get; } + public bool Initialized { get; private set; } + + public SearchIndex(string connectionString, string tableName, SearchManager manager) + { + _connectionString = connectionString; + _manager = manager; + TableName = tableName; + } + + public async Task Init(CancellationToken cancellationToken) + { + if (Initialized) + { + return; + } + + await EnsureTableExistsAsync(cancellationToken); + Initialized = true; + } + + public Task GetAsync(string docId, CancellationToken ct = default) + { + return GetDocumentAsync(docId, ct); + } + + private async Task GetDocumentAsync(string id, CancellationToken ct) + { + await using var conn = await CreateConnectionAsync(ct); + var sql = $""" + SELECT document + FROM {TableName} + WHERE id = @id; + """; + await using var cmd = new 
MySqlCommand(sql, conn); + cmd.Parameters.AddWithValue("id", id); + await using var reader = await cmd.ExecuteReaderAsync(ct); + if (!await reader.ReadAsync(ct)) + { + return default; + } + + var json = reader.GetString(0); + return JsonSerializer.Deserialize(json); + } + + public async Task IndexAsync(T document, CancellationToken ct = default) + { + await using var conn = await CreateConnectionAsync(ct); + var sql = $""" + INSERT INTO {TableName} (id, document, search_text, last_updated) + VALUES (@id, @doc, @text, CURRENT_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE + document = VALUES(document), + search_text = VALUES(search_text), + last_updated = CURRENT_TIMESTAMP(6); + """; + await using var cmd = new MySqlCommand(sql, conn); + cmd.Parameters.AddWithValue("id", document.Id); + cmd.Parameters.AddWithValue("doc", JsonSerializer.Serialize(document)); + cmd.Parameters.AddWithValue("text", document.GetSearchText()); + await cmd.ExecuteNonQueryAsync(ct); + } + + public async Task IndexManyAsync(IEnumerable documents, CancellationToken ct = default) + { + var docs = documents.ToList(); + if (docs.Count == 0) + { + return; + } + + await using var conn = await CreateConnectionAsync(ct); + await using var transaction = await conn.BeginTransactionAsync(ct); + try + { + const int batchSize = 500; + for (var offset = 0; offset < docs.Count; offset += batchSize) + { + var batch = docs.Skip(offset).Take(batchSize).ToList(); + var valueRows = new List(batch.Count); + await using var cmd = new MySqlCommand { Connection = conn, Transaction = transaction }; + + for (var i = 0; i < batch.Count; i++) + { + var doc = batch[i]; + valueRows.Add($"(@id{i}, @doc{i}, @text{i}, CURRENT_TIMESTAMP(6))"); + cmd.Parameters.AddWithValue($"id{i}", doc.Id); + cmd.Parameters.AddWithValue($"doc{i}", JsonSerializer.Serialize(doc)); + cmd.Parameters.AddWithValue($"text{i}", doc.GetSearchText()); + } + + cmd.CommandText = $""" + INSERT INTO {TableName} (id, document, search_text, last_updated) + VALUES 
{string.Join(", ", valueRows)} + ON DUPLICATE KEY UPDATE + document = VALUES(document), + search_text = VALUES(search_text), + last_updated = CURRENT_TIMESTAMP(6); + """; + await cmd.ExecuteNonQueryAsync(ct); + } + + await transaction.CommitAsync(ct); + } + catch (Exception) + { + await transaction.RollbackAsync(ct); + throw; + } + } + + public async Task> SearchAsync(SearchRequest request, CancellationToken ct = default) + { + await using var conn = await CreateConnectionAsync(ct); + var sw = Stopwatch.StartNew(); + + var hasQuery = !string.IsNullOrWhiteSpace(request.Query); + var booleanQuery = hasQuery ? BuildBooleanQuery(request.Query!, request.Options.IncludePartialMatches) : null; + // A query that tokenizes to nothing (e.g. only punctuation) cannot match anything via FTS. + hasQuery = hasQuery && !string.IsNullOrEmpty(booleanQuery); + + var scoreExpression = hasQuery + ? "MATCH(search_text) AGAINST(@Query IN BOOLEAN MODE)" + : "CAST(0 AS DOUBLE)"; + + var clauses = BuildWhereClauses(request, hasQuery); + var orderClause = BuildOrderByClause(request) ?? "ORDER BY score DESC"; + var offsetClause = request.Options.Skip < 1 ? 
"" : $"OFFSET {request.Options.Skip}"; + var limitClause = $"LIMIT {request.Options.Take}"; + + var sql = $""" + SELECT id, document, score, last_updated, COUNT(*) OVER() AS total + FROM ( + SELECT id, document, last_updated, + {scoreExpression} AS score + FROM {TableName} + {clauses.ToWhereClause()} + ) AS ranked + WHERE score >= @minScore + {orderClause} + {limitClause} + {offsetClause} + """; + + var results = new List>(); + long totalCount = 0; + float maxScore = 0; + + await using var cmd = new MySqlCommand(sql, conn); + if (hasQuery) + { + cmd.Parameters.AddWithValue("Query", booleanQuery); + } + cmd.Parameters.AddWithValue("minScore", request.Options.MinScore); + cmd.AddParameters(clauses); + + await using (var reader = await cmd.ExecuteReaderAsync(ct)) + { + while (await reader.ReadAsync(ct)) + { + // The score column is DOUBLE for FTS queries and CAST(0 AS DOUBLE) otherwise, but + // read it tolerantly so an unexpected provider/server numeric type can't throw. + var score = Convert.ToSingle(reader.GetValue(2), CultureInfo.InvariantCulture); + maxScore = Math.Max(maxScore, score); + var json = reader.GetString(1); + totalCount = reader.GetInt64(4); + results.Add(new SearchResult + { + Id = reader.GetString(0), + LastUpdated = reader.GetDateTime(3), + Score = score, + Document = request.Options.IncludeRawDocument && !string.IsNullOrEmpty(json) + ? JsonSerializer.Deserialize(json) + : default + }); + } + } + + // If no rows were returned, the window function gave us no total; compute it separately. 
+ if (results.Count == 0) + { + var countSql = $""" + SELECT COUNT(*) + FROM ( + SELECT {scoreExpression} AS score + FROM {TableName} + {clauses.ToWhereClause()} + ) AS ranked + WHERE score >= @minScore + """; + await using var countCmd = new MySqlCommand(countSql, conn); + if (hasQuery) + { + countCmd.Parameters.AddWithValue("Query", booleanQuery); + } + countCmd.Parameters.AddWithValue("minScore", request.Options.MinScore); + countCmd.AddParameters(clauses); + var countResult = await countCmd.ExecuteScalarAsync(ct); + totalCount = Convert.ToInt64(countResult); + } + + return new SearchResponse + { + Results = results, + TotalCount = totalCount, + MaxScore = maxScore, + SearchTime = sw.Elapsed + }; + } + + private static string? BuildOrderByClause(SearchRequest request) + { + if (request.OrderBys.Count == 0) + { + return null; + } + + var orderClauses = request.OrderBys.Select(order => + { + var direction = order.Direction == SortDirection.Ascending ? "ASC" : "DESC"; + return $"{WhereClauseBuilder.BuildOrderAccessor(order.PropertyName)} {direction}"; + }); + return $"ORDER BY {string.Join(", ", orderClauses)}"; + } + + public async Task DeleteAsync(string id, CancellationToken ct = default) + { + await using var conn = await CreateConnectionAsync(ct); + var sql = $""" + DELETE FROM {TableName} + WHERE id = @id; + """; + await using var cmd = new MySqlCommand(sql, conn); + cmd.Parameters.AddWithValue("id", id); + await cmd.ExecuteNonQueryAsync(ct); + } + + public async Task DeleteManyAsync(IEnumerable ids, CancellationToken ct = default) + { + var idsList = ids.ToList(); + if (idsList.Count == 0) return 0; + + await using var conn = await CreateConnectionAsync(ct); + var paramNames = new List(idsList.Count); + await using var cmd = new MySqlCommand { Connection = conn }; + for (var i = 0; i < idsList.Count; i++) + { + var name = $"@id{i}"; + paramNames.Add(name); + cmd.Parameters.AddWithValue($"id{i}", idsList[i]); + } + + cmd.CommandText = $""" + DELETE FROM 
{TableName} + WHERE id IN ({string.Join(", ", paramNames)}); + """; + return await cmd.ExecuteNonQueryAsync(ct); + } + + public async Task DeleteWhereAsync(SearchRequest request, CancellationToken ct = default) + { + if (request.Filters.Count == 0) + { + throw new InvalidOperationException("DeleteWhereAsync requires at least one filter. Use ClearAsync() to delete all documents."); + } + + await using var conn = await CreateConnectionAsync(ct); + var hasQuery = !string.IsNullOrWhiteSpace(request.Query); + var booleanQuery = hasQuery ? BuildBooleanQuery(request.Query!, request.Options.IncludePartialMatches) : null; + hasQuery = hasQuery && !string.IsNullOrEmpty(booleanQuery); + + var clauses = BuildWhereClauses(request, hasQuery); + var sql = $"DELETE FROM {TableName} {clauses.ToWhereClause()}"; + + await using var cmd = new MySqlCommand(sql, conn); + if (hasQuery) + { + cmd.Parameters.AddWithValue("Query", booleanQuery); + } + cmd.AddParameters(clauses); + return await cmd.ExecuteNonQueryAsync(ct); + } + + public async Task ClearAsync(CancellationToken ct = default) + { + await using var conn = await CreateConnectionAsync(ct); + var sql = $"DELETE FROM {TableName};"; + await using var cmd = new MySqlCommand(sql, conn); + return await cmd.ExecuteNonQueryAsync(ct); + } + + private async Task CreateConnectionAsync(CancellationToken ct) + { + var conn = new MySqlConnection(_connectionString); + await conn.OpenAsync(ct); + return conn; + } + + public async Task DropIndexAsync(CancellationToken ct = default) + { + await using var conn = await CreateConnectionAsync(ct); + var sql = $"DROP TABLE IF EXISTS {TableName}"; + await using var cmd = new MySqlCommand(sql, conn); + await cmd.ExecuteNonQueryAsync(ct); + _manager.Remove(this); + } + + private async Task EnsureTableExistsAsync(CancellationToken cancellationToken) + { + await using var conn = await CreateConnectionAsync(cancellationToken); + var sql = $""" + CREATE TABLE IF NOT EXISTS {TableName} ( + id VARCHAR(255) NOT 
NULL, + document JSON, + search_text LONGTEXT, + last_updated TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (id), + FULLTEXT KEY {TableName}_ft (search_text) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """; + await using var cmd = new MySqlCommand(sql, conn); + await cmd.ExecuteNonQueryAsync(cancellationToken); + } + + public async Task CountAsync(CancellationToken cancellationToken = default) + { + await using var conn = await CreateConnectionAsync(cancellationToken); + var sql = $"SELECT COUNT(*) FROM {TableName}"; + await using var cmd = new MySqlCommand(sql, conn); + var result = await cmd.ExecuteScalarAsync(cancellationToken); + return Convert.ToInt64(result); + } + + public async Task CountAsync(SearchRequest request, CancellationToken cancellationToken = default) + { + await using var conn = await CreateConnectionAsync(cancellationToken); + var hasQuery = !string.IsNullOrWhiteSpace(request.Query); + var booleanQuery = hasQuery ? BuildBooleanQuery(request.Query!, request.Options.IncludePartialMatches) : null; + hasQuery = hasQuery && !string.IsNullOrEmpty(booleanQuery); + + var clauses = BuildWhereClauses(request, hasQuery); + var sql = $""" + SELECT COUNT(*) + FROM {TableName} + {clauses.ToWhereClause()} + """; + await using var cmd = new MySqlCommand(sql, conn); + if (hasQuery) + { + cmd.Parameters.AddWithValue("Query", booleanQuery); + } + cmd.AddParameters(clauses); + var result = await cmd.ExecuteScalarAsync(cancellationToken); + return Convert.ToInt64(result); + } + + private static List BuildWhereClauses(SearchRequest request, bool hasQuery) + { + List clauses = []; + if (hasQuery) + { + clauses.Add(new Clause { Sql = "MATCH(search_text) AGAINST(@Query IN BOOLEAN MODE)" }); + } + + foreach (var clause in WhereClauseBuilder.BuildClauses(request.Filters)) + { + clauses.Add(clause); + } + + return clauses; + } + + /// + /// Builds a safe MariaDB boolean-mode full-text query from arbitrary user input. 
+ /// The raw query is tokenized into runs of letters, digits and underscores (so operator characters such as + /// + - * " ( ) ~ < > cannot reach the parser), then rebuilt: + /// * IncludePartialMatches == true -> terms are emitted bare and space separated, which + /// boolean mode treats as OR (a document matching ANY term qualifies). No '*' wildcard is + /// appended, so this is exact-token matching, not prefix matching. + /// * IncludePartialMatches == false -> each term is required ('+term'), so a document must + /// contain every term to match. + /// Returns an empty string when the input tokenizes to nothing. + /// + internal static string BuildBooleanQuery(string query, bool includePartialMatches) + { + var terms = TokenizeRegex().Matches(query) + .Select(m => m.Value) + .Where(t => t.Length > 0) + .ToList(); + + if (terms.Count == 0) + { + return string.Empty; + } + + var builder = new StringBuilder(); + foreach (var term in terms) + { + if (builder.Length > 0) + { + builder.Append(' '); + } + + if (includePartialMatches) + { + // Partial = match ANY term (boolean-mode OR of bare, exact tokens). This mirrors the + // SQLite/Postgres providers: it is term-level OR, NOT prefix matching — a query that + // tokenizes to "c" must match the token "c", not every word starting with c. + builder.Append(term); + } + else + { + // Non-partial = every term required (AND). 
+ builder.Append('+').Append(term); + } + } + + return builder.ToString(); + } + + /// <summary> + /// Derives the table name "searchlite_{type}_{collection}" for a collection. Both parts are + /// stripped of every character outside [a-zA-Z0-9_] and lower-cased; the type-name part is + /// shortened so the whole identifier fits MariaDB's 64-character limit. If the collection + /// name alone leaves no room (budget <= 0) the type name is kept whole and the final + /// truncation applies instead. + /// </summary> + public static string GetTableName(string collectionName) + { + var sanitizedTypeName = IdentifierRegex().Replace(typeof(T).Name, "").ToLowerInvariant(); + collectionName = IdentifierRegex().Replace(collectionName, "").TrimEnd('_').ToLowerInvariant(); + + // Reserve 12 characters of the 64-character limit for the fixed parts: the + // "searchlite_" prefix (11) plus the "_" separator (1). Reserving only 11 produced a + // 65-character name, and the final truncation then cut the last character off the + // collection name, so collections differing only in that character shared a table. + var budget = 64 - collectionName.Length - 12; + + if (budget > 0 && sanitizedTypeName.Length > budget) + { + sanitizedTypeName = sanitizedTypeName[..budget]; + } + + var sanitized = $"searchlite_{sanitizedTypeName}_{collectionName}"; + + // MariaDB has a 64-character limit for identifiers + if (sanitized.Length > 64) + { + sanitized = sanitized[..64]; + } + + return sanitized; + } + + [GeneratedRegex(@"[^a-zA-Z0-9_]")] + private static partial Regex IdentifierRegex(); + + [GeneratedRegex(@"[\p{L}\p{N}_]+")] + private static partial Regex TokenizeRegex(); +} diff --git a/Source/SearchLite.MariaDb/SearchLite.MariaDb.csproj b/Source/SearchLite.MariaDb/SearchLite.MariaDb.csproj new file mode 100644 index 0000000..c3805c1 --- /dev/null +++ b/Source/SearchLite.MariaDb/SearchLite.MariaDb.csproj @@ -0,0 +1,15 @@ + + + + SearchLite.MariaDb + + + + + + + + + + + diff --git a/Source/SearchLite.MariaDb/SearchManager.cs b/Source/SearchLite.MariaDb/SearchManager.cs new file mode 100644 index 0000000..33bab4e --- /dev/null +++ b/Source/SearchLite.MariaDb/SearchManager.cs @@ -0,0 +1,55 @@ +using System.Collections.Concurrent; + +namespace SearchLite.MariaDb; + +public class SearchManager : ISearchEngineManager +{ + private readonly string _connectionString; + private readonly ConcurrentDictionary _cache = new(); + private readonly SemaphoreSlim _lock = new(1, 1); + + public SearchManager(string connectionString) + { + _connectionString = connectionString; + } + + public async Task> Get(string collectionName, CancellationToken cancellationToken = default) + where T : ISearchableDocument + { + var tableName = SearchIndex.GetTableName(collectionName); 
+ var cached = _cache.GetOrAdd(tableName, Create); + + if (cached is not SearchIndex searchEngine) + { + throw new InvalidOperationException( + $"Unexpected type {cached.GetType().Name} in cache for {collectionName}"); + } + + if (searchEngine.Initialized) return searchEngine; + + await _lock.WaitAsync(cancellationToken); + try + { + if (!searchEngine.Initialized) + { + await searchEngine.Init(cancellationToken); + } + } + finally + { + _lock.Release(); + } + + return searchEngine; + } + + private SearchIndex Create(string tableName) where T : ISearchableDocument + { + return new SearchIndex(_connectionString, tableName, this); + } + + public void Remove(SearchIndex index) where T : ISearchableDocument + { + _cache.TryRemove(index.TableName, out _); + } +} diff --git a/Source/SearchLite.MariaDb/ServiceCollectionExtensions.cs b/Source/SearchLite.MariaDb/ServiceCollectionExtensions.cs new file mode 100644 index 0000000..0b5f44a --- /dev/null +++ b/Source/SearchLite.MariaDb/ServiceCollectionExtensions.cs @@ -0,0 +1,13 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace SearchLite.MariaDb; + +public static class ServiceCollectionExtensions +{ + public static IServiceCollection AddSearch(this IServiceCollection services, + string connectionString) where T : ISearchableDocument + { + services.AddSingleton(_ => new SearchManager(connectionString)); + return services; + } +} diff --git a/Source/SearchLite.MariaDb/WhereClauseBuilder.cs b/Source/SearchLite.MariaDb/WhereClauseBuilder.cs new file mode 100644 index 0000000..e480a38 --- /dev/null +++ b/Source/SearchLite.MariaDb/WhereClauseBuilder.cs @@ -0,0 +1,582 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using MySqlConnector; + +namespace SearchLite.MariaDb; + +public record Clause +{ + public required string Sql { get; init; } + public List Parameters { get; init; } = []; +} + +public static class WhereClauseBuilder +{ + public static IEnumerable BuildClauses(List> filters) + { + var 
globalParamCounter = 0; + return filters.Select(filter => BuildClause(filter, ref globalParamCounter)); + } + + private static Clause BuildClause(FilterNode filter, ref int globalParamCounter) + { + var parameters = new List(); + var sql = BuildSql(filter, ref globalParamCounter, parameters); + return new Clause + { + Sql = sql, + Parameters = parameters + }; + } + + private static string BuildSql(FilterNode node, ref int paramCounter, List parameters) + { + return node switch + { + FilterNode.Condition condition => BuildConditionSql(condition, ref paramCounter, parameters), + FilterNode.Group group => BuildGroupSql(group, ref paramCounter, parameters), + _ => throw new ArgumentException($"Unsupported node type: {node.GetType()}") + }; + } + + private static string BuildGroupSql(FilterNode.Group group, ref int paramCounter, + List parameters) + { + var op = group.Operator switch + { + LogicalOperator.And => " AND ", + LogicalOperator.Or => " OR ", + _ => throw new ArgumentException($"Unsupported logical operator: {group.Operator}") + }; + + var conditions = new List(); + + foreach (var condition in group.Conditions) + { + conditions.Add(BuildSql(condition, ref paramCounter, parameters)); + } + + return conditions.Count > 1 + ? $"({string.Join(op, conditions)})" + : conditions.FirstOrDefault() ?? "TRUE"; + } + + /// + /// Builds a JSON path string for the given dotted property path. + /// 1 segment -> $.seg + /// N segments -> $.seg1.seg2... + /// + private static string BuildJsonPath(string propertyName) + { + var segments = FieldPath.Split(propertyName); + return "$." + string.Join(".", segments); + } + + /// + /// Builds a scalar text accessor for the given dotted property path that yields SQL NULL for a + /// JSON null OR a missing key, and the unquoted scalar text otherwise (preserving the empty + /// string). 
This is deliberately NOT JSON_VALUE — MariaDB's JSON_VALUE collapses an empty + /// string to NULL and nulls out objects — and NOT a bare JSON_UNQUOTE(JSON_EXTRACT(...)), which + /// returns the literal text 'null' for a present-but-null field. Both break IS NULL / + /// IsNullOrEmpty / ordering / nested null-guard semantics. The CASE on JSON_TYPE handles every + /// case: missing -> JSON_EXTRACT is SQL NULL; JSON null -> JSON_TYPE 'NULL'; anything else + /// (incl. "") -> the unquoted value, and for an object the (non-null) object text so an + /// IS NOT NULL guard on a nested object still holds. + /// + private static string BuildTextAccessor(string propertyName) + { + var extract = $"JSON_EXTRACT(document, '{BuildJsonPath(propertyName)}')"; + return $"(CASE WHEN JSON_TYPE({extract}) = 'NULL' THEN NULL ELSE JSON_UNQUOTE({extract}) END)"; + } + + /// + /// Builds an ORDER BY accessor for a dotted property path. Uses the scalar accessor (so a JSON + /// null / missing key sorts as SQL NULL rather than as a JSON-null value), and casts numeric + /// fields so they sort numerically instead of lexically ("100" before "20"). Other types + /// (string, DateTime as ISO text, Guid, bool, enum) order on their text form. + /// + public static string BuildOrderAccessor(string propertyName) + { + var accessor = BuildTextAccessor(propertyName); + var leaf = ResolveLeafType(propertyName); + if (leaf == null) + { + return accessor; + } + + var underlying = Nullable.GetUnderlyingType(leaf) ?? leaf; + string? cast = null; + if (underlying == typeof(int) || underlying == typeof(long) || underlying == typeof(short) + || underlying == typeof(byte) || underlying == typeof(char)) + { + cast = "SIGNED"; + } + else if (underlying == typeof(double) || underlying == typeof(float) || underlying == typeof(decimal)) + { + cast = "DECIMAL(65,30)"; + } + + return cast == null ? accessor : $"CAST({accessor} AS {cast})"; + } + + private static Type? ResolveLeafType(string propertyName) + { + Type? 
current = typeof(T); + foreach (var segment in FieldPath.Split(propertyName)) + { + var prop = current?.GetProperty(segment, + System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.IgnoreCase); + if (prop == null) + { + return null; + } + + current = Nullable.GetUnderlyingType(prop.PropertyType) ?? prop.PropertyType; + } + + return current; + } + + private static string BuildConditionSql(FilterNode.Condition condition, ref int paramCounter, + List parameters) + { + if (IsCollectionOperator(condition.Operator)) + { + return BuildCollectionCondition(condition, ref paramCounter, parameters); + } + + if (IsNullOperator(condition.Operator)) + { + return BuildNullCondition(condition.PropertyName, condition.Operator); + } + + if (IsStringNullOrEmptyOperator(condition.Operator)) + { + return BuildStringNullOrEmptyCondition(condition.PropertyName, condition.Operator); + } + + if (IsSetOperator(condition.Operator)) + { + return BuildSetCondition(condition, ref paramCounter, parameters); + } + + if (IsStringOperator(condition.Operator)) + { + return BuildStringCondition(condition, ref paramCounter, parameters); + } + + var underlyingType = Nullable.GetUnderlyingType(condition.PropertyType) ?? condition.PropertyType; + + // Equal/NotEqual on containment-eligible types use JSON_CONTAINS so the value is matched + // against the stored JSON representation (mirrors Postgres @> containment). + if ((condition.Operator is Operator.Equal or Operator.NotEqual) + && IsContainmentEligible(underlyingType)) + { + return BuildContainmentEqualityCondition(condition, underlyingType, ref paramCounter, parameters); + } + + var fieldExpression = BuildCastAccessor(condition.PropertyName, condition.PropertyType); + var operatorString = GetOperatorString(condition.Operator); + var paramName = $"@p{paramCounter++}"; + + object? 
paramValue = condition.Value; + + if (underlyingType.IsEnum) + { + var format = ResolveEnumFormat(condition.PropertyName, underlyingType); + + paramValue = format == EnumSerializationFormat.String + ? condition.Value?.ToString() + : condition.Value != null + ? Convert.ChangeType(condition.Value, underlyingType.GetEnumUnderlyingType()) + : null; + } + else if (underlyingType == typeof(Guid)) + { + paramValue = condition.Value?.ToString(); + } + + parameters.Add(new MySqlParameter(paramName, paramValue ?? DBNull.Value)); + + return $"{fieldExpression} {operatorString} {paramName}"; + } + + /// + /// Builds a typed accessor that casts the extracted JSON scalar to the appropriate SQL type + /// for range/scalar comparisons. + /// + private static string BuildCastAccessor(string propertyName, Type propertyType) + { + var castType = GetMariaDbCastType(propertyType, propertyName); + var accessor = BuildTextAccessor(propertyName); + return castType == null + ? accessor + : $"CAST({accessor} AS {castType})"; + } + + /// + /// Underlying types for which JSON containment reliably matches the stored representation. + /// + private static bool IsContainmentEligible(Type underlyingType) + { + if (underlyingType.IsEnum) return true; + + return underlyingType == typeof(string) + || underlyingType == typeof(bool) + || underlyingType == typeof(Guid) + || underlyingType == typeof(int) + || underlyingType == typeof(long) + || underlyingType == typeof(short) + || underlyingType == typeof(byte); + } + + // Resolves how the enum at the given (possibly dotted) property path is serialized: the + // property-level format when the property can be found on T, otherwise the default format + // for the enum type. + private static EnumSerializationFormat ResolveEnumFormat(string propertyName, Type underlyingType) + { + // Walk the dotted path one segment at a time (as ResolveLeafType does) so a nested enum + // property such as "Address.Kind" resolves to its own PropertyInfo. A single GetProperty + // call with the whole path only finds top-level properties, so nested ones always fell + // back to the type default. + var flags = System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.IgnoreCase; + Type current = typeof(T); + System.Reflection.PropertyInfo? prop = null; + foreach (var segment in FieldPath.Split(propertyName)) + { + prop = current.GetProperty(segment, flags); + if (prop == null) + { + // Unknown segment: give up and use the type-level default below. + break; + } + + current = Nullable.GetUnderlyingType(prop.PropertyType) ?? prop.PropertyType; + } + + return prop != null + ? 
EnumSerializationAnalyzer.GetPropertyFormat(prop) + : EnumSerializationAnalyzer.GetDefaultFormat(underlyingType); + } + + /// + /// Produces the JSON node representation of a scalar leaf value, honoring enum format. + /// + private static JsonNode? BuildLeafJson(object? value, Type underlyingType, string propertyName) + { + if (value == null) return null; + + if (underlyingType.IsEnum) + { + var format = ResolveEnumFormat(propertyName, underlyingType); + if (format == EnumSerializationFormat.String) + { + return JsonValue.Create(value.ToString()); + } + + var numeric = Convert.ChangeType(value, underlyingType.GetEnumUnderlyingType()); + return JsonSerializer.SerializeToNode(numeric, numeric!.GetType()); + } + + if (underlyingType == typeof(Guid)) + { + return JsonValue.Create(((Guid)value).ToString("D")); + } + + if (underlyingType == typeof(string)) + { + return JsonValue.Create((string)value); + } + + if (underlyingType == typeof(bool)) + { + return JsonValue.Create((bool)value); + } + + // Integer types: serialize as JSON numbers. + return JsonSerializer.SerializeToNode(value, underlyingType); + } + + private static string BuildContainmentEqualityCondition(FilterNode.Condition condition, Type underlyingType, + ref int paramCounter, List parameters) + { + var leaf = BuildLeafJson(condition.Value, underlyingType, condition.PropertyName); + var json = leaf?.ToJsonString() ?? "null"; + var path = BuildJsonPath(condition.PropertyName); + + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, json)); + + // JSON_CONTAINS(document, candidate, path) returns 1 when the value at `path` contains + // the candidate value (or, for scalars, equals it). + return condition.Operator == Operator.Equal + ? 
$"JSON_CONTAINS(document, {paramName}, '{path}')" + : $"NOT JSON_CONTAINS(document, {paramName}, '{path}')"; + } + + private static bool IsCollectionOperator(Operator op) + { + return op is Operator.CollectionContains or Operator.CollectionNotContains; + } + + private static string BuildCollectionCondition(FilterNode.Condition condition, ref int paramCounter, + List parameters) + { + var elementType = GetElementType(condition.PropertyType); + var underlyingElementType = Nullable.GetUnderlyingType(elementType) ?? elementType; + + var leaf = BuildLeafJson(condition.Value, underlyingElementType, condition.PropertyName); + var json = leaf?.ToJsonString() ?? "null"; + var path = BuildJsonPath(condition.PropertyName); + + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, json)); + + // JSON_CONTAINS against an array path matches when the array contains the candidate element. + return condition.Operator == Operator.CollectionContains + ? $"JSON_CONTAINS(document, {paramName}, '{path}')" + : $"NOT JSON_CONTAINS(document, {paramName}, '{path}')"; + } + + private static Type GetElementType(Type collectionType) + { + if (collectionType.IsArray) + { + return collectionType.GetElementType()!; + } + + if (collectionType.IsGenericType) + { + return collectionType.GetGenericArguments()[0]; + } + + var enumerableInterface = collectionType.GetInterfaces() + .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>)); + if (enumerableInterface != null) + { + return enumerableInterface.GetGenericArguments()[0]; + } + + throw new NotSupportedException($"Cannot determine element type for collection type {collectionType}"); + } + + /// + /// Returns the MariaDB CAST target type for the given .NET type, or null when no cast is needed + /// (the value is compared as the unquoted text scalar). + /// + private static string? 
GetMariaDbCastType(Type type, string propertyName) + { + var underlyingType = Nullable.GetUnderlyingType(type) ?? type; + + if (underlyingType.IsEnum) + { + var format = ResolveEnumFormat(propertyName, underlyingType); + return format == EnumSerializationFormat.String ? "CHAR" : "SIGNED"; + } + + return underlyingType switch + { + { } t when t == typeof(int) => "SIGNED", + { } t when t == typeof(string) => "CHAR", + { } t when t == typeof(bool) => "SIGNED", + { } t when t == typeof(double) => "DECIMAL(65,30)", + { } t when t == typeof(decimal) => "DECIMAL(65,30)", + { } t when t == typeof(DateTime) => "DATETIME(6)", + { } t when t == typeof(DateTimeOffset) => "DATETIME(6)", + { } t when t == typeof(Guid) => "CHAR", + { } t when t == typeof(byte) => "SIGNED", + { } t when t == typeof(short) => "SIGNED", + { } t when t == typeof(long) => "SIGNED", + { } t when t == typeof(float) => "DECIMAL(65,30)", + { } t when t == typeof(char) => "SIGNED", + _ => throw new NotSupportedException($"Type {underlyingType} is not supported") + }; + } + + private static string GetOperatorString(Operator op) + { + return op switch + { + Operator.Equal => "=", + Operator.NotEqual => "!=", + Operator.GreaterThan => ">", + Operator.GreaterThanOrEqual => ">=", + Operator.LessThan => "<", + Operator.LessThanOrEqual => "<=", + _ => throw new NotSupportedException($"Operator {op} is not supported") + }; + } + + private static bool IsStringNullOrEmptyOperator(Operator op) + { + return op is Operator.IsNullOrEmpty or Operator.IsNotNullOrEmpty or + Operator.IsNullOrWhiteSpace or Operator.IsNotNullOrWhiteSpace; + } + + private static bool IsNullOperator(Operator op) + { + return op is Operator.IsNull or Operator.IsNotNull; + } + + private static string BuildNullCondition(string propertyName, Operator op) + { + var fieldExpression = BuildTextAccessor(propertyName); + + return op switch + { + Operator.IsNull => $"{fieldExpression} IS NULL", + Operator.IsNotNull => $"{fieldExpression} IS NOT NULL", + 
_ => throw new NotSupportedException($"Null operator {op} is not supported") + }; + } + + private static string BuildStringNullOrEmptyCondition(string propertyName, Operator op) + { + var fieldExpression = BuildTextAccessor(propertyName); + + // Emptiness is tested with CHAR_LENGTH, not `= ''`: MySQL's PAD SPACE collation treats a + // string of only spaces (" ") as equal to '', which would make IsNullOrEmpty wrongly match + // whitespace. CHAR_LENGTH counts the actual characters. + var trimmed = $"TRIM(REPLACE(REPLACE(REPLACE({fieldExpression}, CHAR(9), ' '), CHAR(10), ' '), CHAR(13), ' '))"; + return op switch + { + Operator.IsNullOrEmpty => $"({fieldExpression} IS NULL OR CHAR_LENGTH({fieldExpression}) = 0)", + Operator.IsNotNullOrEmpty => $"({fieldExpression} IS NOT NULL AND CHAR_LENGTH({fieldExpression}) > 0)", + Operator.IsNullOrWhiteSpace => $"({fieldExpression} IS NULL OR CHAR_LENGTH({trimmed}) = 0)", + Operator.IsNotNullOrWhiteSpace => $"({fieldExpression} IS NOT NULL AND CHAR_LENGTH({trimmed}) > 0)", + _ => throw new NotSupportedException($"String operator {op} is not supported") + }; + } + + private static bool IsSetOperator(Operator op) + { + return op is Operator.In or Operator.NotIn; + } + + private static bool IsStringOperator(Operator op) + { + return op is Operator.Contains or Operator.NotContains or + Operator.ContainsIgnoreCase or Operator.NotContainsIgnoreCase or + Operator.StartsWith or Operator.NotStartsWith or + Operator.StartsWithIgnoreCase or Operator.NotStartsWithIgnoreCase or + Operator.EndsWith or Operator.NotEndsWith or + Operator.EndsWithIgnoreCase or Operator.NotEndsWithIgnoreCase; + } + + private static string BuildSetCondition(FilterNode.Condition condition, ref int paramCounter, List parameters) + { + var fieldExpression = BuildCastAccessor(condition.PropertyName, condition.PropertyType); + + return condition.Operator switch + { + Operator.In => BuildInCondition(fieldExpression, condition.Value, ref paramCounter, parameters, 
condition.PropertyType, condition.PropertyName), + Operator.NotIn => $"NOT ({BuildInCondition(fieldExpression, condition.Value, ref paramCounter, parameters, condition.PropertyType, condition.PropertyName)})", + _ => throw new NotSupportedException($"Set operator {condition.Operator} is not supported") + }; + } + + private static string BuildStringCondition(FilterNode.Condition condition, ref int paramCounter, List parameters) + { + var fieldExpression = BuildTextAccessor(condition.PropertyName); + + return condition.Operator switch + { + Operator.Contains => BuildContainsCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotContains => $"NOT ({BuildContainsCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + Operator.ContainsIgnoreCase => BuildContainsIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotContainsIgnoreCase => $"NOT ({BuildContainsIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + Operator.StartsWith => BuildStartsWithCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotStartsWith => $"NOT ({BuildStartsWithCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + Operator.StartsWithIgnoreCase => BuildStartsWithIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotStartsWithIgnoreCase => $"NOT ({BuildStartsWithIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + Operator.EndsWith => BuildEndsWithCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotEndsWith => $"NOT ({BuildEndsWithCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + Operator.EndsWithIgnoreCase => BuildEndsWithIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters), + Operator.NotEndsWithIgnoreCase => $"NOT 
({BuildEndsWithIgnoreCaseCondition(fieldExpression, condition.Value, ref paramCounter, parameters)})", + _ => throw new NotSupportedException($"String operator {condition.Operator} is not supported") + }; + } + + private static string BuildContainsCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"%{value}%")); + return $"{fieldExpression} LIKE {paramName}"; + } + + private static string BuildInCondition(string fieldExpression, object? value, ref int paramCounter, List parameters, Type? propertyType = null, string? propertyName = null) + { + if (value is System.Collections.IEnumerable enumerable and not string) + { + var values = new List(); + foreach (var item in enumerable) + { + values.Add(item); + } + + if (values.Count == 0) + { + return "FALSE"; + } + + var paramNames = new List(); + foreach (var item in values) + { + var paramName = $"@p{paramCounter++}"; + object? paramValue = item; + + var underlyingType = propertyType != null ? Nullable.GetUnderlyingType(propertyType) ?? propertyType : null; + + if (underlyingType != null && underlyingType.IsEnum) + { + var prop = typeof(T).GetProperty(propertyName ?? string.Empty, System.Reflection.BindingFlags.Instance | System.Reflection.BindingFlags.Public | System.Reflection.BindingFlags.IgnoreCase); + var format = prop != null + ? EnumSerializationAnalyzer.GetPropertyFormat(prop) + : EnumSerializationAnalyzer.GetDefaultFormat(underlyingType); + + if (format == EnumSerializationFormat.String) + { + paramValue = item?.ToString(); + } + else if (item != null) + { + paramValue = Convert.ChangeType(item, underlyingType.GetEnumUnderlyingType()); + } + } + else if (underlyingType == typeof(Guid)) + { + paramValue = item?.ToString(); + } + + parameters.Add(new MySqlParameter(paramName, paramValue ?? 
DBNull.Value)); + paramNames.Add(paramName); + } + + return $"{fieldExpression} IN ({string.Join(", ", paramNames)})"; + } + + var singleParamName = $"@p{paramCounter++}"; + object? singleParamValue = value; + if (propertyType == typeof(Guid) || propertyType == typeof(Guid?)) + { + singleParamValue = value?.ToString() ?? ""; + } + + parameters.Add(new MySqlParameter(singleParamName, singleParamValue ?? DBNull.Value)); + return $"{fieldExpression} = {singleParamName}"; + } + + private static string BuildContainsIgnoreCaseCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"%{value}%")); + return $"LOWER({fieldExpression}) LIKE LOWER({paramName})"; + } + + private static string BuildStartsWithCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"{value}%")); + return $"{fieldExpression} LIKE {paramName}"; + } + + private static string BuildStartsWithIgnoreCaseCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"{value}%")); + return $"LOWER({fieldExpression}) LIKE LOWER({paramName})"; + } + + private static string BuildEndsWithCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"%{value}")); + return $"{fieldExpression} LIKE {paramName}"; + } + + private static string BuildEndsWithIgnoreCaseCondition(string fieldExpression, object value, ref int paramCounter, List parameters) + { + var paramName = $"@p{paramCounter++}"; + parameters.Add(new MySqlParameter(paramName, $"%{value}")); + return $"LOWER({fieldExpression}) LIKE LOWER({paramName})"; + } +} diff --git 
a/Tests/SearchLite.MariaDb.Tests/Fixtures/MariaDbFixture.cs b/Tests/SearchLite.MariaDb.Tests/Fixtures/MariaDbFixture.cs new file mode 100644 index 0000000..e383f78 --- /dev/null +++ b/Tests/SearchLite.MariaDb.Tests/Fixtures/MariaDbFixture.cs @@ -0,0 +1,62 @@ +using Testcontainers.MariaDb; + +namespace SearchLite.Tests.MariaDb.Fixtures; + +[Collection("mariadb")] +public sealed class MariaDbFixture : IAsyncLifetime +{ + private readonly MariaDbContainer? _container; + + public string ConnectionString { get; private set; } + + public MariaDbFixture() + { + // Allow pointing the suite at an already-running MariaDB (e.g. a local server started with + // innodb_ft_min_token_size=1) via an env var, so it can be run without Docker. When unset, + // a throwaway Testcontainers MariaDB is used (the CI path). + var external = Environment.GetEnvironmentVariable("SEARCHLITE_MARIADB_CONNSTR"); + if (!string.IsNullOrEmpty(external)) + { + ConnectionString = external; + return; + } + + ConnectionString = null!; + var userName = Guid.NewGuid().ToString("N"); + var password = Guid.NewGuid().ToString("N"); + var dbName = Guid.NewGuid().ToString("N"); + + _container = new MariaDbBuilder() + .WithImage("mariadb:11") + .WithUsername(userName) + .WithPassword(password) + .WithDatabase(dbName) + // InnoDB's default minimum FULLTEXT token size is 3, which would drop the short tokens + // ("c", "1", "doc"...) the conformance suite searches for. These settings must be in + // place before any FULLTEXT index is created, so they are passed as server startup + // arguments; the index is created afterwards by SearchIndex, so it picks up the change. 
+ .WithCommand( + "--innodb-ft-min-token-size=1", + "--ft-min-word-len=1") + .Build(); + } + + public async Task InitializeAsync() + { + if (_container is null) + { + return; + } + + await _container.StartAsync(); + ConnectionString = _container.GetConnectionString(); + } + + public async Task DisposeAsync() + { + if (_container is not null) + { + await _container.StopAsync(); + } + } +} diff --git a/Tests/SearchLite.MariaDb.Tests/IndexTests.cs b/Tests/SearchLite.MariaDb.Tests/IndexTests.cs new file mode 100644 index 0000000..f79893e --- /dev/null +++ b/Tests/SearchLite.MariaDb.Tests/IndexTests.cs @@ -0,0 +1,41 @@ +using FluentAssertions; +using SearchLite.MariaDb; +using SearchLite.Tests.MariaDb.Fixtures; + +namespace SearchLite.Tests.MariaDb; + +public class IndexTests(MariaDbFixture fixture) + : Tests.IndexTests(new SearchManager(fixture.ConnectionString)), IClassFixture +{ + [Theory] + // MariaDB boolean-mode relevance scores are not comparable to other providers, so we only + // assert the deterministic threshold-of-zero case: every full-text match is returned. + // Query "Exact match testing" (partial matches on) matches docs 1 ("exact"/"match") and + // 2 ("match"); doc 3 ("Unrelated document") shares no term. 
+ [InlineData(2, 0.0)] + public async Task SearchAsync_WithMinScore_ShouldFilterLowScores(int expectedCount, float minScore) + { + var docs = new[] + { + new TestDocument { Id = "1", Title = "Exact match test" }, + new TestDocument { Id = "2", Title = "Somewhat related match" }, + new TestDocument { Id = "3", Title = "Unrelated document" } + }; + + await Index.IndexManyAsync(docs); + + var request = new SearchRequest + { + Query = "Exact match testing", + Options = new SearchOptions + { + MinScore = minScore, + IncludePartialMatches = true + } + }; + + var result = await Index.SearchAsync(request); + + result.Results.Should().HaveCount(expectedCount); + } +} diff --git a/Tests/SearchLite.MariaDb.Tests/SearchLite.MariaDb.Tests.csproj b/Tests/SearchLite.MariaDb.Tests/SearchLite.MariaDb.Tests.csproj new file mode 100644 index 0000000..9436879 --- /dev/null +++ b/Tests/SearchLite.MariaDb.Tests/SearchLite.MariaDb.Tests.csproj @@ -0,0 +1,31 @@ + + + net8.0;net9.0;net10.0 + enable + enable + false + preview + SearchLite.Tests.MariaDb + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + diff --git a/Tests/SearchLite.MariaDb.Tests/TableNameTests.cs b/Tests/SearchLite.MariaDb.Tests/TableNameTests.cs new file mode 100644 index 0000000..856d689 --- /dev/null +++ b/Tests/SearchLite.MariaDb.Tests/TableNameTests.cs @@ -0,0 +1,45 @@ +using FluentAssertions; +using SearchLite.MariaDb; + +namespace SearchLite.Tests.MariaDb; + +public class TableNameTests +{ + [Fact] + public void GetTableName_ShouldReturnCorrectTableName() + { + var result = SearchIndex.GetTableName("test"); + + result.Should().Be("searchlite_testdocument_test"); + } + + [Fact] + public void GetTableName_ShouldReturnCorrectTableNameGeneric() + { + var result = SearchIndex>.GetTableName("test"); + + result.Should().Be("searchlite_genericandlongtestdocument1_test"); + } + + [Fact] + public 
void GetTableName_ShouldReturnCorrectTableNameWhenTooLong() + { + var result = SearchIndex>.GetTableName("a".PadRight(32, 'a')); + + result.Should().Be("searchlite_genericandlongtestdoc_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + result.Length.Should().Be(64); + } + + public class Foo{} + + public class GenericAndLongTestDocument: ISearchableDocument + { + public required string Id { get; init; } + public required string Title { get; init; } + public string Content { get; init; } = ""; + public int Views { get; init; } + public DateTime CreatedAt { get; set; } + + public string GetSearchText() => $"{Title} {Content}"; + } +} diff --git a/Tests/SearchLite.MariaDb.Tests/WhereClauseTests.cs b/Tests/SearchLite.MariaDb.Tests/WhereClauseTests.cs new file mode 100644 index 0000000..397b9e5 --- /dev/null +++ b/Tests/SearchLite.MariaDb.Tests/WhereClauseTests.cs @@ -0,0 +1,352 @@ +using System.Linq.Expressions; +using System.Text.Json.Serialization; +using FluentAssertions; +using SearchLite.MariaDb; + +namespace SearchLite.Tests.MariaDb; + +public class WhereClauseTests +{ + public enum TestEnum + { + Value1, + Value2 + } + + [JsonConverter(typeof(JsonStringEnumConverter))] + public enum TestStringEnum + { + String1, + String2 + } + + public class Address + { + public required string City { get; set; } + } + + public class Author + { + public required string Name { get; set; } + public int Rank { get; set; } + public required Address Address { get; set; } + public required List Roles { get; set; } + } + + public class TestModel + { + public int Age { get; set; } + public required string Name { get; set; } + public bool IsActive { get; set; } + public double Score { get; set; } + public decimal Price { get; set; } + public DateTime CreatedAt { get; set; } + public TestEnum EnumValue { get; set; } + public TestStringEnum StringEnumValue { get; set; } + public required Author Author { get; set; } + public required List Labels { get; set; } + } + + [Fact] + public void 
Should_Handle_Simple_Integer_Comparison() + { + var clause = BuildClause(x => x.Age > 18); + + clause.Sql.Should().Be($"{Cast("$.Age", "SIGNED")} > @p0"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(18); + } + + [Fact] + public void Should_Handle_Enum_Comparison() + { + var clause = BuildClause(x => x.EnumValue == TestEnum.Value2); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.EnumValue')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be($"{(int)TestEnum.Value2}"); + } + + [Fact] + public void Should_Handle_String_Enum_Comparison() + { + var clause = BuildClause(x => x.StringEnumValue == TestStringEnum.String2); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.StringEnumValue')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be($"\"{nameof(TestStringEnum.String2)}\""); + } + + [Fact] + public void Should_Handle_String_Equality() + { + var clause = BuildClause(x => x.Name == "John"); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Name')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"John\""); + } + + [Fact] + public void Should_Handle_Boolean_Comparison() + { + var clause = BuildClause(x => x.IsActive == true); + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.IsActive')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("true"); + } + + [Fact] + public void Should_Handle_Double_Comparison() + { + var clause = BuildClause(x => x.Score >= 95.5); + clause.Sql.Should().Be($"{Cast("$.Score", "DECIMAL(65,30)")} >= @p0"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(95.5); + } + + [Fact] + public void Should_Handle_Decimal_Comparison() + { + var clause = BuildClause(x => x.Price < 199.99m); + + clause.Sql.Should().Be($"{Cast("$.Price", "DECIMAL(65,30)")} < @p0"); + 
clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(199.99m); + } + + [Fact] + public void Should_Handle_DateTime_Comparison() + { + var date = new DateTime(2024, 1, 1); + var clause = BuildClause(x => x.CreatedAt > date); + + clause.Sql.Should().Be($"{Cast("$.CreatedAt", "DATETIME(6)")} > @p0"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(date); + } + + [Fact] + public void Should_Handle_Multiple_Conditions_With_And() + { + var clause = BuildClause(x => x.Age > 18 && x.IsActive == true); + + clause.Sql.Should().Be( + $"({Cast("$.Age", "SIGNED")} > @p0 AND " + + "JSON_CONTAINS(document, @p1, '$.IsActive'))"); + clause.Parameters.Should().HaveCount(2); + clause.Parameters[0].Value.Should().Be(18); + clause.Parameters[1].Value.Should().Be("true"); + } + + [Fact] + public void Should_Handle_Multiple_Conditions_With_Or() + { + var clause = BuildClause(x => x.Name == "John" || x.Name == "Jane"); + + clause.Sql.Should().Be( + "(JSON_CONTAINS(document, @p0, '$.Name') OR JSON_CONTAINS(document, @p1, '$.Name'))"); + clause.Parameters.Should().HaveCount(2); + clause.Parameters[0].Value.Should().Be("\"John\""); + clause.Parameters[1].Value.Should().Be("\"Jane\""); + } + + [Fact] + public void Should_Handle_Complex_Nested_Conditions() + { + var clause = BuildClause(x => + (x.Age > 18 && x.IsActive == true) || (x.Score >= 95.5 && x.Name == "John")); + + clause.Sql.Should().Be( + $"(({Cast("$.Age", "SIGNED")} > @p0 AND JSON_CONTAINS(document, @p1, '$.IsActive')) OR " + + $"({Cast("$.Score", "DECIMAL(65,30)")} >= @p2 AND JSON_CONTAINS(document, @p3, '$.Name')))"); + clause.Parameters.Should().HaveCount(4); + clause.Parameters[0].Value.Should().Be(18); + clause.Parameters[1].Value.Should().Be("true"); + clause.Parameters[2].Value.Should().Be(95.5); + clause.Parameters[3].Value.Should().Be("\"John\""); + } + + [Fact] + public void Should_Handle_Multiple_Predicates() + { + var predicates = new List>> + { + x => 
x.Age > 18, + x => x.IsActive == true + }; + + var clauses = BuildClauses(predicates).ToList(); + + clauses.Should().HaveCount(2); + clauses[0].Sql.Should().Be($"{Cast("$.Age", "SIGNED")} > @p0"); + clauses[1].Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.IsActive')"); + clauses[0].Parameters.Should().HaveCount(1); + clauses[1].Parameters.Should().HaveCount(1); + clauses[1].Parameters[0].Value.Should().Be("true"); + } + + [Fact] + public void Should_Handle_All_Comparison_Operators() + { + // Equal/NotEqual on integer (containment-eligible) use JSON_CONTAINS with a JSON-number leaf. + (Expression> expression, string expectedSql, object expectedParam)[] containmentCases = + [ + (x => x.Age == 18, "JSON_CONTAINS(document, @p0, '$.Age')", "18"), + (x => x.Age != 18, "NOT JSON_CONTAINS(document, @p0, '$.Age')", "18") + ]; + + foreach (var testCase in containmentCases) + { + var clause = BuildClause(testCase.expression); + + clause.Sql.Should().Be(testCase.expectedSql); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(testCase.expectedParam); + } + + // Range comparisons keep the cast-and-compare form. 
+ (Expression> expression, string expected)[] comparisonCases = + [ + (x => x.Age > 18, $"{Cast("$.Age", "SIGNED")} > @p0"), + (x => x.Age >= 18, $"{Cast("$.Age", "SIGNED")} >= @p0"), + (x => x.Age < 18, $"{Cast("$.Age", "SIGNED")} < @p0"), + (x => x.Age <= 18, $"{Cast("$.Age", "SIGNED")} <= @p0") + ]; + + foreach (var testCase in comparisonCases) + { + var clause = BuildClause(testCase.expression); + + clause.Sql.Should().Be(testCase.expected); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(18); + } + } + + [Fact] + public void Should_Handle_Multiple_Where_Calls() + { + var request = new SearchRequest() + .Where(x => x.Name.Contains("Test")) + .Where(x => x.Age >= 18); + + var clauses = WhereClauseBuilder.BuildClauses(request.Filters).ToList(); + + clauses.Should().HaveCount(2); + + var allParams = clauses.SelectMany(c => c.Parameters).ToList(); + var paramNames = allParams.Select(p => p.ParameterName).ToList(); + + paramNames.Should().HaveCount(2); + paramNames.Should().OnlyHaveUniqueItems("Parameter names should be unique to avoid conflicts"); + + paramNames.Should().Contain("@p0"); + paramNames.Should().Contain("@p1"); + } + + [Fact] + public void Should_Handle_Nested_String_Equality_With_Containment() + { + var clause = BuildClause(x => x.Author.Name == "Alice"); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Author.Name')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"Alice\""); + } + + [Fact] + public void Should_Handle_Deeply_Nested_String_Equality_With_Containment() + { + var clause = BuildClause(x => x.Author.Address.City == "Oslo"); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Author.Address.City')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"Oslo\""); + } + + [Fact] + public void Should_Handle_Nested_Integer_Equality_With_Containment() + { + var clause = BuildClause(x => x.Author.Rank == 5); + + 
clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Author.Rank')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("5"); + } + + [Fact] + public void Should_Handle_Nested_Comparison_With_Path_Accessor() + { + var clause = BuildClause(x => x.Author.Rank > 3); + + clause.Sql.Should().Be($"{Cast("$.Author.Rank", "SIGNED")} > @p0"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be(3); + } + + [Fact] + public void Should_Handle_Nested_Null_Check_With_Path_Accessor() + { + var clause = BuildClause(x => x.Author.Name == null); + + clause.Sql.Should().Be($"{Acc("$.Author.Name")} IS NULL"); + clause.Parameters.Should().HaveCount(0); + } + + [Fact] + public void Should_Handle_Nested_String_Operation_With_Path_Accessor() + { + var clause = BuildClause(x => x.Author.Name.Contains("li")); + + clause.Sql.Should().Be($"{Acc("$.Author.Name")} LIKE @p0"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("%li%"); + } + + [Fact] + public void Should_Handle_CollectionContains_TopLevel() + { + var clause = BuildClause(x => x.Labels.Contains("urgent")); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Labels')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"urgent\""); + } + + [Fact] + public void Should_Handle_CollectionNotContains_TopLevel() + { + var clause = BuildClause(x => !x.Labels.Contains("urgent")); + + clause.Sql.Should().Be("NOT JSON_CONTAINS(document, @p0, '$.Labels')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"urgent\""); + } + + [Fact] + public void Should_Handle_CollectionContains_Nested() + { + var clause = BuildClause(x => x.Author.Roles.Contains("admin")); + + clause.Sql.Should().Be("JSON_CONTAINS(document, @p0, '$.Author.Roles')"); + clause.Parameters.Should().HaveCount(1); + clause.Parameters[0].Value.Should().Be("\"admin\""); + } + + // The scalar accessor: 
SQL NULL for JSON null / missing key, the unquoted value otherwise. + private static string Acc(string path) => + $"(CASE WHEN JSON_TYPE(JSON_EXTRACT(document, '{path}')) = 'NULL' THEN NULL ELSE JSON_UNQUOTE(JSON_EXTRACT(document, '{path}')) END)"; + + private static string Cast(string path, string type) => $"CAST({Acc(path)} AS {type})"; + + private static Clause BuildClause(Expression> predicate) => BuildClauses(predicate).Single(); + + private static IReadOnlyList BuildClauses(params IEnumerable>> predicates) + { + return predicates.SelectMany(predicate => + WhereClauseBuilder.BuildClauses( + [FilterMapper.Map(predicate)])) + .ToList(); + } +}