From aecd6b01aa6493d21fa86385991d08df4d33edd7 Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:22:20 +0100 Subject: [PATCH 1/7] deps: update dependencies and remove unused packages; add .idea to .gitignore --- .gitignore | 2 + Cargo.lock | 493 +++++++++++++++++++---------------------------------- 2 files changed, 176 insertions(+), 319 deletions(-) diff --git a/.gitignore b/.gitignore index 7bc7850..194f21f 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ target/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ *.log + +.idea diff --git a/Cargo.lock b/Cargo.lock index 611a045..9296c4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,15 +19,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - [[package]] name = "adler2" version = "2.0.1" @@ -52,7 +43,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", - "getrandom 0.3.3", + "getrandom 0.3.4", "once_cell", "serde", "version_check", @@ -61,19 +52,13 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - [[package]] name = "android_system_properties" version = "0.1.5" @@ -91,15 +76,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" [[package]] name = "anyhow" -version = "1.0.99" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "autocfg" @@ -107,21 +92,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "backtrace" -version = "0.3.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets", -] - [[package]] name = "base64" version = "0.13.1" @@ -130,15 +100,15 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64ct" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "bitflags" -version = "2.9.4" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" [[package]] name = "block-buffer" @@ -151,9 +121,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.19.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" [[package]] name = "byteorder" @@ -198,9 +168,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.35" +version = "1.2.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "590f9024a68a8c40351881787f1934dc11afd69090f5edb6831464694d836ea3" +checksum = "9f50d563227a1c37cc0a263f64eca3334388c01c5e4c4861a9def205c614383c" dependencies = [ "find-msvc-tools", "jobserver", @@ -210,17 +180,16 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "android-tzdata", "iana-time-zone", "js-sys", "num-traits", @@ -267,18 +236,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.47" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eac00902d9d136acd712710d71823fb8ac8004ca445a89e73a41d45aa712931" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.47" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ad9bbf750e73b5884fb8a211a9424a1906c1e156724260fdae972f31d70e1d6" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstyle", "clap_lex", @@ -286,9 +255,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" [[package]] name = "compact_str" @@ -307,9 +276,9 @@ dependencies = [ [[package]] name = "console" -version = "0.15.11" +version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +checksum = "03e45a4a8926227e4197636ba97a9fc9b00477e9f4bd711395687c5f0734bec4" dependencies = [ "encode_unicode", "libc", @@ -414,9 +383,9 @@ checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", "typenum", @@ -424,21 +393,21 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +checksum = "52cd9d68cf7efc6ddfaaee42e7288d3a99d613d4b50f76ce9827ae0c6e14f938" dependencies = [ "csv-core", "itoa", "ryu", - "serde", + "serde_core", ] [[package]] name = "csv-core" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" dependencies = [ "memchr", ] @@ -480,18 +449,18 @@ dependencies = [ [[package]] name = "dary_heap" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" dependencies = [ "serde", ] [[package]] name = "deranged" -version = "0.5.3" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d630bccd429a5bb5a64b5e94f693bfc48c9f8566418fda4c494cc94f911f87cc" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" dependencies = [ "powerfmt", ] @@ -561,15 +530,15 @@ dependencies = [ [[package]] name = "find-msvc-tools" -version = "0.1.0" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e178e4fba8a2726903f6ba98a6d221e76f9c12c650d5dc0e6afdc50677b49650" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" [[package]] name = "flate2" -version = "1.1.2" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" dependencies = [ "crc32fast", "miniz_oxide", @@ -599,35 +568,30 @@ checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", - "wasi 0.11.1+wasi-snapshot-preview1", + "wasi", ] [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasip2", ] -[[package]] -name = "gimli" -version = "0.31.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" - [[package]] name = "half" -version = "2.6.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" dependencies = [ "cfg-if", "crunchy", + "zerocopy", ] [[package]] @@ -641,9 +605,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.63" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -671,14 +635,14 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "indicatif" -version = "0.17.11" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" dependencies = [ "console", - "number_prefix", "portable-atomic", "unicode-width", + "unit-prefix", "web-time", ] @@ -691,17 +655,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "io-uring" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" -dependencies = [ - "bitflags", - "cfg-if", - "libc", -] - [[package]] name = "itertools" version = "0.13.0" @@ -722,9 +675,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "7ee5b5339afb4c41626dde77b7a611bd4f2c202b897852b4bcf5d03eddc61010" [[package]] name = "jobserver" @@ -732,15 +685,15 @@ version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", "libc", ] [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -754,15 +707,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.175" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "macro_rules_attribute" @@ -792,9 +745,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.5" +version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "minimal-lexical" @@ -809,34 +762,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", -] - -[[package]] -name = "mio" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" -dependencies = [ - "libc", - "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys", + "simd-adler32", ] [[package]] name = "monostate" -version = "0.1.14" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe1be9d0c75642e3e50fedc7ecadf1ef1cbce6eb66462153fc44245343fbee" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" dependencies = [ "monostate-impl", "serde", + "serde_core", ] [[package]] name = "monostate-impl" -version = "0.1.14" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c402a4092d5e204f32c9e155431046831fa712637043c58cb73bc6bc6c9663b5" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" dependencies = [ "proc-macro2", "quote", @@ -901,21 +845,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - -[[package]] -name = "object" -version = "0.36.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -1021,9 +950,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.11.1" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +checksum = "f59e70c4aef1e55797c2e8fd94a4f2a973fc972cfde0e0b05f683667b0cd39dd" [[package]] name = "portable-atomic-util" @@ -1051,18 +980,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.101" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.40" +version = "1.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" dependencies = [ "proc-macro2", ] @@ -1129,7 +1058,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.3", + "getrandom 0.3.4", ] [[package]] @@ -1171,9 +1100,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.2" +version = "1.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d7fd106d8c02486a8d64e778353d1cffe08ce79ac2e82f540c86d0facf6912" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" dependencies = [ "aho-corasick", "memchr", @@ -1183,9 +1112,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.10" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b9458fa0bfeeac22b5ca447c63aaf45f28439a709ccd244698632f9aa6394d6" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" dependencies = [ "aho-corasick", "memchr", @@ -1194,15 +1123,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.6" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" - -[[package]] -name = "rustc-demangle" -version = "0.1.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "rustversion" @@ -1212,9 +1135,9 @@ checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "62049b2877bf12821e8f9ad256ee38fdc31db7387ec2d3b3f403024de2034aea" [[package]] name = "safetensors" @@ -1237,18 +1160,28 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -1257,14 +1190,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.143" +version = "1.0.146" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +checksum = "217ca874ae0207aac254aa02c957ded05585a90892cc8d87f9e5fa49669dadd8" dependencies = [ "itoa", "memchr", "ryu", "serde", + "serde_core", ] [[package]] @@ -1296,10 +1230,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] -name = "slab" -version = "0.4.11" +name = "simd-adler32" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "smallvec" @@ -1339,9 +1273,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.106" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -1376,11 +1310,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl 2.0.16", + "thiserror-impl 2.0.17", ] [[package]] @@ -1396,9 +1330,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", @@ -1407,9 +1341,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.43" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83bde6f1ec10e72d583d91623c939f623002284ef622b87de38cfd546cbf2031" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "num-conv", @@ -1436,9 +1370,9 @@ dependencies = [ [[package]] name = "tokenizers" -version = "0.22.0" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af10f51be57162b69d90a15cb226eef12c9e4faecbd5e3ea98a86bfb920b3d71" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" dependencies = [ "ahash", "aho-corasick", @@ -1446,7 +1380,7 @@ dependencies = [ "dary_heap", "derive_builder", "esaxx-rs", - "getrandom 0.3.3", + "getrandom 0.3.4", "indicatif", "itertools 0.14.0", "log", @@ -1462,7 +1396,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror 2.0.16", + "thiserror 2.0.17", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -1470,24 +1404,19 @@ dependencies = [ [[package]] name = "tokio" -version = "1.47.1" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ - "backtrace", - "io-uring", - "libc", - "mio", "pin-project-lite", - "slab", "tokio-macros", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", @@ -1508,15 +1437,15 @@ dependencies = [ [[package]] name = "typenum" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "unicode-normalization-alignments" @@ -1535,9 +1464,9 @@ checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" [[package]] name = "unicode_categories" @@ -1545,6 +1474,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + [[package]] name = "version_check" version = "0.9.5" @@ -1568,45 +1503,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] -name = "wasi" -version = "0.14.2+wasi-0.2.4" +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.100" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" -dependencies = [ - "bumpalo", - "log", - "proc-macro2", - "quote", - "syn", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1614,31 +1536,31 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ + "bumpalo", "proc-macro2", "quote", "syn", - "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", @@ -1665,9 +1587,9 @@ dependencies = [ [[package]] name = "windows-core" -version = "0.61.2" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", @@ -1678,9 +1600,9 @@ dependencies = [ [[package]] name = "windows-implement" -version = "0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", @@ -1689,9 +1611,9 @@ dependencies = [ [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", @@ -1700,124 +1622,57 @@ dependencies = [ [[package]] name = "windows-link" -version = "0.1.3" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-result" -version = "0.3.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ "windows-link", ] [[package]] name = "windows-strings" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ "windows-link", ] [[package]] name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-targets" -version = "0.52.6" +version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows-link", ] [[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - -[[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags", -] +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", @@ -1865,9 +1720,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" +version = "2.0.16+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", From 961b994af7451f4414d658fc91986269e80e029b Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Mon, 22 Dec 2025 19:17:17 +0100 Subject: [PATCH 2/7] deps: update dependencies and add new packages --- Cargo.lock | 59 ++++++++++++++++++++++++++++++++++++++++++++++-------- Cargo.toml | 4 ++-- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9296c4b..8f35898 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -59,6 +59,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "android_system_properties" version = "0.1.5" @@ -319,10 +328,11 @@ dependencies = [ [[package]] name = "criterion" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" dependencies = [ + "alloca", "anes", "cast", "ciborium", @@ -331,6 +341,7 @@ dependencies = [ "itertools 0.13.0", "num-traits", "oorandom", + "page_size", "plotters", "rayon", "regex", @@ -342,9 +353,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.6.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" dependencies = [ "cast", "itertools 0.13.0", @@ -879,6 +890,16 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "password-hash" version = "0.4.2" @@ -1284,9 +1305,9 @@ dependencies = [ [[package]] name = "tch" -version = "0.20.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a760143efe7e4bb5b56e95d01f52ee6773bc315202e7c47db6a6429b0705a1f2" +checksum = "9e09b91610202dc4820c21eb474a42b386ef69f323b1c0902b5472ba7456ebb5" dependencies = [ "half", "lazy_static", @@ -1425,9 +1446,9 @@ dependencies = [ [[package]] name = "torch-sys" -version = "0.20.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +checksum = "aef40c585e342df95b66a1fa7c923188623999c2b657227befb481dfb03a6a42" dependencies = [ "anyhow", "cc", @@ -1576,6 +1597,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -1585,6 +1622,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" diff --git a/Cargo.toml b/Cargo.toml index e0eece7..41d2fc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ authors = ["Maxime-Cllt"] [dependencies] -tch = "0.20.0" +tch = "0.22.0" tokenizers = "0.22.0" csv = "1.3.1" serde = { version = "1.0.219", features = ["derive"] } @@ -23,7 +23,7 @@ once_cell = "1.21.3" [dev-dependencies] tokio = { version = "1.47.1", features = ["macros", "rt-multi-thread", "sync"] } -criterion = "0.7.0" +criterion = "0.8.1" [lib] name = "datalib" From 017b5c91f97085642480034a59620bd47b01de45 Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Tue, 23 Dec 2025 18:17:39 +0100 Subject: [PATCH 3/7] deps: update .gitignore to include CSV and JSON files; update config.toml for environment variables and linker arguments --- .cargo/config.toml | 6 ++++++ .gitignore | 2 ++ 2 files changed, 8 insertions(+) diff --git a/.cargo/config.toml b/.cargo/config.toml index 396033c..6b46a16 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,8 @@ +[env] +LIBTORCH_USE_PYTORCH = "1" +LIBTORCH_BYPASS_VERSION_CHECK = "1" +DYLD_LIBRARY_PATH = "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/lib" + [build] rustflags = [ "-C", "target-feature=+crt-static", @@ -31,4 +36,5 @@ rustflags = [ "-C", "target-cpu=apple-m1", "-C", "link-arg=-dead_strip", "-C", "link-arg=-dead_strip_dylibs", + "-C", "link-arg=-Wl,-rpath,/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/lib", ] \ No newline at end of file diff --git a/.gitignore b/.gitignore index 194f21f..c8e1cc3 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ target/ *.log .idea +*.csv +*.json From f1fc4f1ef7f746c802b5707bcf49677fefacf382 Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Tue, 23 Dec 2025 18:30:57 +0100 Subject: [PATCH 4/7] refactor: reorganize module structure and update imports for core components --- src/benches/application_bench.rs | 10 ++--- src/{structs => core/io/file}/csv_file.rs | 13 +++--- src/{structs => core/io/file}/json_output.rs | 2 +- src/core/io/file/mod.rs | 3 ++ src/{enums => core/io/file}/separator.rs | 2 +- src/core/io/mod.rs | 3 ++ src/{enums => core/io/tracing}/color.rs | 0 src/{enums => core/io/tracing}/log_level.rs | 8 +++- src/{structs => core/io/tracing}/logger.rs | 7 +-- src/{enums => core/io/tracing}/mod.rs | 2 +- src/core/mod.rs | 2 + src/{ => core}/utils/mod.rs | 2 +- src/{ => core}/utils/regex.rs | 4 +- src/{ => core}/utils/util.rs | 12 +++--- src/{structs => detection}/anomaly.rs | 2 +- src/{structs => detection}/inferable_value.rs | 0 src/detection/mod.rs | 2 + src/lib.rs | 7 +-- src/main.rs | 43 ++++++++----------- src/model/mod.rs | 2 + src/{structs => model}/model.rs | 12 +++--- src/{structs => model}/tokenizer.rs | 8 ++-- src/structs/mod.rs | 7 --- src/tests/csv_tests.rs | 2 +- src/tests/model_tests.rs | 2 +- src/tests/test.rs | 10 ----- src/tests/utils_tests.rs | 6 +-- 27 files changed, 84 insertions(+), 89 deletions(-) rename src/{structs => core/io/file}/csv_file.rs (95%) rename src/{structs => core/io/file}/json_output.rs (96%) create mode 100644 src/core/io/file/mod.rs rename src/{enums => core/io/file}/separator.rs (98%) create mode 100644 src/core/io/mod.rs rename src/{enums => core/io/tracing}/color.rs (100%) rename src/{enums => core/io/tracing}/log_level.rs (78%) rename src/{structs => core/io/tracing}/logger.rs (92%) rename src/{enums => core/io/tracing}/mod.rs (64%) create mode 100644 src/core/mod.rs rename src/{ => core}/utils/mod.rs (51%) rename src/{ => core}/utils/regex.rs (98%) rename src/{ => core}/utils/util.rs (94%) rename src/{structs => detection}/anomaly.rs (96%) rename src/{structs => detection}/inferable_value.rs (100%) create mode 100644 src/detection/mod.rs create mode 100644 src/model/mod.rs rename src/{structs => model}/model.rs (95%) rename src/{structs => model}/tokenizer.rs (97%) delete mode 100644 src/structs/mod.rs delete mode 100644 src/tests/test.rs diff --git a/src/benches/application_bench.rs b/src/benches/application_bench.rs index 9eca6c6..b0a6e41 100644 --- a/src/benches/application_bench.rs +++ b/src/benches/application_bench.rs @@ -1,13 +1,13 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use datalib::enums::log_level::LogLevel; -use datalib::structs::csv_file::CsvFile; -use datalib::structs::logger::log_and_print_message; -use datalib::structs::model::Model; +use datalib::core::io::file::csv_file::CsvFile; +use datalib::model::model::Model; use std::time::Duration; +use datalib::core::io::tracing::log_level::LogLevel; +use datalib::core::io::tracing::logger::log_and_print_message; #[allow(dead_code)] fn test_analyse_file() { - const FILEPATH: &str = r""; // Path to the CSV file to be analyzed + const FILEPATH: &str = r"/Users/maximecolliat/RustroverProjects/DataLint/Amazon.csv"; // Path to the CSV file to be analyzed let perfage_iae: Model = match Model::from_config_file("config.json") { Ok(perfage) => perfage, diff --git a/src/structs/csv_file.rs b/src/core/io/file/csv_file.rs similarity index 95% rename from src/structs/csv_file.rs rename to src/core/io/file/csv_file.rs index bc67b16..55ab172 100644 --- a/src/structs/csv_file.rs +++ b/src/core/io/file/csv_file.rs @@ -1,9 +1,7 @@ -use crate::enums::log_level::LogLevel; -use crate::enums::separator::SeparatorType; -use crate::structs::inferable_value::InferableValue; -use crate::structs::logger::{log_and_print_message, print_message}; -use crate::utils::regex::{get_safe_regex_set, get_unsafe_value_regex_set}; -use crate::utils::util::get_file_name; + +use crate::detection::inferable_value::InferableValue; +use crate::core::utils::regex::{get_safe_regex_set, get_unsafe_value_regex_set}; +use crate::core::utils::util::get_file_name; use csv::{Reader, ReaderBuilder, StringRecord}; use rayon::iter::IntoParallelRefIterator; use rayon::prelude::ParallelIterator; @@ -14,6 +12,9 @@ use std::error::Error; use std::fs::File; use std::io; use std::io::{BufRead, BufReader, BufWriter, Read, Write}; +use crate::core::io::file::separator::SeparatorType; +use crate::core::io::tracing::log_level::LogLevel; +use crate::core::io::tracing::logger::{log_and_print_message, print_message}; /// Represents a CSV file with its path and separator. pub struct CsvFile { diff --git a/src/structs/json_output.rs b/src/core/io/file/json_output.rs similarity index 96% rename from src/structs/json_output.rs rename to src/core/io/file/json_output.rs index 5a6528f..ac89c7e 100644 --- a/src/structs/json_output.rs +++ b/src/core/io/file/json_output.rs @@ -1,4 +1,4 @@ -use crate::structs::anomaly::Anomaly; +use crate::detection::anomaly::Anomaly; use serde::{Deserialize, Serialize}; diff --git a/src/core/io/file/mod.rs b/src/core/io/file/mod.rs new file mode 100644 index 0000000..8a854e6 --- /dev/null +++ b/src/core/io/file/mod.rs @@ -0,0 +1,3 @@ +pub mod separator; +pub mod json_output; +pub mod csv_file; \ No newline at end of file diff --git a/src/enums/separator.rs b/src/core/io/file/separator.rs similarity index 98% rename from src/enums/separator.rs rename to src/core/io/file/separator.rs index 7afa98d..28e9297 100644 --- a/src/enums/separator.rs +++ b/src/core/io/file/separator.rs @@ -14,7 +14,7 @@ impl SeparatorType { /// Returns the separator as a `char`. #[inline] #[must_use] - pub(crate) const fn as_char(&self) -> char { + pub const fn as_char(&self) -> char { match self { Self::Comma => ',', Self::Semicolon => ';', diff --git a/src/core/io/mod.rs b/src/core/io/mod.rs new file mode 100644 index 0000000..9b2120a --- /dev/null +++ b/src/core/io/mod.rs @@ -0,0 +1,3 @@ + +pub mod file; +pub mod tracing; \ No newline at end of file diff --git a/src/enums/color.rs b/src/core/io/tracing/color.rs similarity index 100% rename from src/enums/color.rs rename to src/core/io/tracing/color.rs diff --git a/src/enums/log_level.rs b/src/core/io/tracing/log_level.rs similarity index 78% rename from src/enums/log_level.rs rename to src/core/io/tracing/log_level.rs index 70060d7..a8a56b1 100644 --- a/src/enums/log_level.rs +++ b/src/core/io/tracing/log_level.rs @@ -20,7 +20,7 @@ impl LogLevel { #[cfg(test)] mod tests { - use crate::enums::log_level::LogLevel; + use crate::core::io::tracing::log_level::LogLevel; #[tokio::test] async fn test_log_level_as_str() { @@ -33,4 +33,10 @@ mod tests { assert_eq!(format!("{:?}", LogLevel::Error), "Error"); assert_eq!(format!("{:?}", LogLevel::Info), "Info"); } + + #[tokio::test] + async fn test_log_level_repr() { + assert_eq!(LogLevel::Error as u8, 0); + assert_eq!(LogLevel::Info as u8, 1); + } } diff --git a/src/structs/logger.rs b/src/core/io/tracing/logger.rs similarity index 92% rename from src/structs/logger.rs rename to src/core/io/tracing/logger.rs index 447110b..bdb2f03 100644 --- a/src/structs/logger.rs +++ b/src/core/io/tracing/logger.rs @@ -1,8 +1,9 @@ -use crate::enums::color::Color; -use crate::enums::log_level::LogLevel; + use std::fs::File; use std::io::{BufWriter, Write}; use std::sync::{Mutex, MutexGuard}; +use crate::core::io::tracing::color::Color; +use crate::core::io::tracing::log_level::LogLevel; /// Logger struct to handle logging messages to a file #[non_exhaustive] @@ -18,7 +19,7 @@ impl Logger { let log_file: File = std::fs::OpenOptions::new() .create(true) .append(true) - .open("DataLint.log") + .open("../../DataLint.log") .unwrap(); Self { log_file } diff --git a/src/enums/mod.rs b/src/core/io/tracing/mod.rs similarity index 64% rename from src/enums/mod.rs rename to src/core/io/tracing/mod.rs index 67ef69b..231da0d 100644 --- a/src/enums/mod.rs +++ b/src/core/io/tracing/mod.rs @@ -1,3 +1,3 @@ pub mod color; pub mod log_level; -pub mod separator; +pub mod logger; diff --git a/src/core/mod.rs b/src/core/mod.rs new file mode 100644 index 0000000..2d3a819 --- /dev/null +++ b/src/core/mod.rs @@ -0,0 +1,2 @@ +pub mod utils; +pub mod io; \ No newline at end of file diff --git a/src/utils/mod.rs b/src/core/utils/mod.rs similarity index 51% rename from src/utils/mod.rs rename to src/core/utils/mod.rs index df75e14..5dad056 100644 --- a/src/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -1,2 +1,2 @@ pub mod regex; -pub mod util; +pub mod util; \ No newline at end of file diff --git a/src/utils/regex.rs b/src/core/utils/regex.rs similarity index 98% rename from src/utils/regex.rs rename to src/core/utils/regex.rs index b82b42f..cb3ff9a 100644 --- a/src/utils/regex.rs +++ b/src/core/utils/regex.rs @@ -1,4 +1,4 @@ -use crate::utils::regex::safe_regex::{ +use crate::core::utils::regex::safe_regex::{ get_datetime_regex, get_email_regex, get_numeric_regex, get_phone_number_regex, get_simple_word_regex, }; @@ -47,7 +47,7 @@ pub mod safe_regex { #[cfg(test)] mod test { - use crate::utils::regex::safe_regex::{ + use crate::core::utils::regex::safe_regex::{ get_datetime_regex, get_email_regex, get_numeric_regex, get_simple_word_regex, }; use regex::Regex; diff --git a/src/utils/util.rs b/src/core/utils/util.rs similarity index 94% rename from src/utils/util.rs rename to src/core/utils/util.rs index e9f146c..3d16f0e 100644 --- a/src/utils/util.rs +++ b/src/core/utils/util.rs @@ -1,11 +1,11 @@ -use crate::enums::color::Color; -use crate::enums::log_level::LogLevel; -use crate::structs::anomaly::Anomaly; -use crate::structs::json_output::JsonOutput; -use crate::structs::logger::{log_and_print_message, log_message, print_message}; +use crate::core::io::tracing::log_level::LogLevel; +use crate::detection::anomaly::Anomaly; +use crate::core::io::file::json_output::JsonOutput; use std::io::{Error, ErrorKind}; use std::path::PathBuf; use std::time::Instant; +use crate::core::io::tracing::color::Color; +use crate::core::io::tracing::logger::{log_and_print_message, log_message, print_message}; /// Create a JSON file with the analysis results. pub fn generate_json_file( @@ -34,7 +34,7 @@ pub fn generate_json_file( format!("Error while getting the current directory: {e}").as_str(), &LogLevel::Error, ); - std::path::PathBuf::from(".") + std::path::PathBuf::from("../../..") }); let current_dir: &str = binding.to_str().unwrap(); let save_path: String = format!("{JSON_DIR}/{output_file_name}.{JSON_DIR}"); diff --git a/src/structs/anomaly.rs b/src/detection/anomaly.rs similarity index 96% rename from src/structs/anomaly.rs rename to src/detection/anomaly.rs index 39913a2..980356e 100644 --- a/src/structs/anomaly.rs +++ b/src/detection/anomaly.rs @@ -1,5 +1,5 @@ -use crate::enums::color::Color; use serde::{Deserialize, Serialize}; +use crate::core::io::tracing::color::Color; /// Represents an anomaly detected in a CSV file. #[derive(Serialize, Deserialize, Clone)] diff --git a/src/structs/inferable_value.rs b/src/detection/inferable_value.rs similarity index 100% rename from src/structs/inferable_value.rs rename to src/detection/inferable_value.rs diff --git a/src/detection/mod.rs b/src/detection/mod.rs new file mode 100644 index 0000000..7881072 --- /dev/null +++ b/src/detection/mod.rs @@ -0,0 +1,2 @@ +pub mod anomaly; +pub mod inferable_value; \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 8a04079..543f431 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ -pub mod enums; -pub mod structs; -pub mod utils; +pub mod detection; +pub mod model; + +pub mod core; diff --git a/src/main.rs b/src/main.rs index 1f50ac8..f75512c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,12 @@ mod benches; #[cfg(test)] mod tests; -use datalib::enums::log_level::LogLevel; -use datalib::structs::anomaly::Anomaly; -use datalib::structs::csv_file::CsvFile; -use datalib::structs::logger::log_and_print_message; -use datalib::structs::model::Model; -use datalib::utils::util::{ +use datalib::detection::anomaly::Anomaly; +use datalib::core::io::file::csv_file::CsvFile; +use datalib::model::model::Model; +use datalib::core::io::tracing::log_level::LogLevel; +use datalib::core::io::tracing::logger::log_and_print_message; +use datalib::core::utils::util::{ file_exists, generate_json_file, get_file_from_args, print_report, run_post_execution, }; use std::process::exit; @@ -17,20 +17,18 @@ use std::time::Instant; fn main() { let args: Vec = std::env::args().collect(); - let args: [String; 2] = - get_file_from_args(&args).expect("Error parsing command line arguments. Usage: datalib "); - + let args: [String; 2] = get_file_from_args(&args) + .expect("Error parsing command line arguments. Usage: datalib "); let start_time: Instant = Instant::now(); - let perfage_iae: Model = - Model::from_config_file("config.json").unwrap_or_else(|e| { - log_and_print_message( - &format!("Error loading model configuration: {e}"), - &LogLevel::Error, - ); - exit(1); - }); + let perfage_iae: Model = Model::from_config_file("config.json").unwrap_or_else(|e| { + log_and_print_message( + &format!("Error loading model configuration: {e}"), + &LogLevel::Error, + ); + exit(1); + }); [&perfage_iae.model_path, &perfage_iae.vocabulary_path] .iter() @@ -41,23 +39,16 @@ fn main() { }); let csv_struct: CsvFile = CsvFile::from_file(&args[0]).unwrap_or_else(|e| { - log_and_print_message( - &format!("Error reading CSV file: {e}"), - &LogLevel::Error, - ); + log_and_print_message(&format!("Error reading CSV file: {e}"), &LogLevel::Error); exit(1); }); let (dangerous_output, ai_analyze, regex_analyze): (Vec, u32, u32) = perfage_iae.analyse_file(&csv_struct).unwrap_or_else(|e| { - log_and_print_message( - &format!("Error analyzing file: {e}"), - &LogLevel::Error, - ); + log_and_print_message(&format!("Error analyzing file: {e}"), &LogLevel::Error); exit(1); }); - print_report( &start_time, &dangerous_output, diff --git a/src/model/mod.rs b/src/model/mod.rs new file mode 100644 index 0000000..0e2590d --- /dev/null +++ b/src/model/mod.rs @@ -0,0 +1,2 @@ +pub mod model; +pub mod tokenizer; diff --git a/src/structs/model.rs b/src/model/model.rs similarity index 95% rename from src/structs/model.rs rename to src/model/model.rs index 99bee85..47be6d6 100644 --- a/src/structs/model.rs +++ b/src/model/model.rs @@ -1,15 +1,15 @@ -use crate::enums::log_level::LogLevel; -use crate::structs::anomaly::Anomaly; -use crate::structs::csv_file::CsvFile; -use crate::structs::inferable_value::InferableValue; -use crate::structs::logger::print_message; -use crate::structs::tokenizer::ModelTokenizer; +use crate::core::io::tracing::log_level::LogLevel; +use crate::detection::anomaly::Anomaly; +use crate::core::io::file::csv_file::CsvFile; +use crate::detection::inferable_value::InferableValue; +use crate::model::tokenizer::ModelTokenizer; use csv::StringRecord; use serde::Deserialize; use std::error::Error; use std::fs::File; use tch::{CModule, Device, Tensor}; use tokenizers::{Encoding, Tokenizer}; +use crate::core::io::tracing::logger::print_message; /// Represents the model configuration for the anomaly detection system. /// It contains the paths to the model and vocabulary files. diff --git a/src/structs/tokenizer.rs b/src/model/tokenizer.rs similarity index 97% rename from src/structs/tokenizer.rs rename to src/model/tokenizer.rs index fa61184..b487055 100644 --- a/src/structs/tokenizer.rs +++ b/src/model/tokenizer.rs @@ -1,10 +1,10 @@ -use crate::enums::log_level::LogLevel; -use crate::structs::inferable_value::InferableValue; -use crate::structs::logger::print_message; +use crate::detection::inferable_value::InferableValue; use rayon::iter::IntoParallelRefIterator; use rayon::iter::ParallelIterator; use std::error::Error; use tokenizers::{Encoding, Tokenizer}; +use crate::core::io::tracing::log_level::LogLevel; +use crate::core::io::tracing::logger::print_message; /// Represents a tokenizer for the model, providing methods to encode and decode text data. #[non_exhaustive] @@ -90,7 +90,7 @@ mod tests { async fn test_ids_to_vector_basic() { const WORDS: [&str; 5] = ["TEST", "--", "IN", "", "RUST IS FUN BUT WINDOWS IS NOT"]; - let path: PathBuf = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("model/tokenizer.json"); + let path: PathBuf = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../model/tokenizer.json"); let tokenizer: Tokenizer = Tokenizer::from_file(path).unwrap_or_else(|e| { print_message( &format!("Error reading vocabulary file: {e}"), diff --git a/src/structs/mod.rs b/src/structs/mod.rs deleted file mode 100644 index 8310a2c..0000000 --- a/src/structs/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod anomaly; -pub mod csv_file; -pub mod inferable_value; -pub mod json_output; -pub mod logger; -pub mod model; -pub mod tokenizer; \ No newline at end of file diff --git a/src/tests/csv_tests.rs b/src/tests/csv_tests.rs index 35b549a..9b5d425 100644 --- a/src/tests/csv_tests.rs +++ b/src/tests/csv_tests.rs @@ -1,7 +1,7 @@ use crate::tests::csv_tests::csv_utils::generate_csv_file; use crate::tests::utils_tests::delete_file; use csv::StringRecord; -use datalib::structs::csv_file::CsvFile; +use datalib::core::io::file::csv_file::CsvFile; #[tokio::test] async fn test_get_headers() { diff --git a/src/tests/model_tests.rs b/src/tests/model_tests.rs index 8c7a0c0..6ed4751 100644 --- a/src/tests/model_tests.rs +++ b/src/tests/model_tests.rs @@ -1,4 +1,4 @@ -use datalib::structs::model::Model; +use datalib::model::model::Model; use std::fs::File; use std::io::Write; diff --git a/src/tests/test.rs b/src/tests/test.rs deleted file mode 100644 index d29cfe1..0000000 --- a/src/tests/test.rs +++ /dev/null @@ -1,10 +0,0 @@ -use crate::structs::color::{BLUE, GREEN, RED, RESET, YELLOW}; - -#[tokio::test] -async fn test_red_color_code() { - assert_eq!(RED, "\x1b[31m"); - assert_eq!(GREEN, "\x1b[32m"); - assert_eq!(YELLOW, "\x1b[33m"); - assert_eq!(BLUE, "\x1b[34m"); - assert_eq!(RESET, "\x1b[0m"); -} diff --git a/src/tests/utils_tests.rs b/src/tests/utils_tests.rs index 8bba38d..99338cc 100644 --- a/src/tests/utils_tests.rs +++ b/src/tests/utils_tests.rs @@ -1,7 +1,7 @@ use crate::tests::csv_tests::csv_utils::generate_csv_file; -use datalib::structs::anomaly::Anomaly; -use datalib::structs::json_output::JsonOutput; -use datalib::utils::util::{file_exists, generate_json_file, get_file_from_args, get_file_name}; +use datalib::detection::anomaly::Anomaly; +use datalib::core::io::file::json_output::JsonOutput; +use datalib::core::utils::util::{file_exists, generate_json_file, get_file_from_args, get_file_name}; #[tokio::test] async fn test_get_file_from_args() { From f3acef4d3893be13a43395ae6d456f2ef029c508 Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Tue, 23 Dec 2025 18:57:05 +0100 Subject: [PATCH 5/7] refactor: optimize CSV processing and model initialization for improved performance --- src/core/io/file/csv_file.rs | 15 +-- src/core/utils/regex.rs | 101 ++++++++++++------ src/main.rs | 19 ++-- src/model/model.rs | 191 ++++++++++++++++++++++++++++++++++- src/model/tokenizer.rs | 10 +- 5 files changed, 280 insertions(+), 56 deletions(-) diff --git a/src/core/io/file/csv_file.rs b/src/core/io/file/csv_file.rs index 55ab172..0315c40 100644 --- a/src/core/io/file/csv_file.rs +++ b/src/core/io/file/csv_file.rs @@ -3,8 +3,6 @@ use crate::detection::inferable_value::InferableValue; use crate::core::utils::regex::{get_safe_regex_set, get_unsafe_value_regex_set}; use crate::core::utils::util::get_file_name; use csv::{Reader, ReaderBuilder, StringRecord}; -use rayon::iter::IntoParallelRefIterator; -use rayon::prelude::ParallelIterator; use regex::RegexSet; use std::borrow::Cow; use std::collections::HashSet; @@ -48,9 +46,10 @@ impl CsvFile { let first_line: String = Self::read_first_line(csv_file_path).unwrap(); + // Use sequential iteration for small array (6 elements) - faster than parallel overhead POSSIBLE_SEPARATORS - .par_iter() - .find_any(|sep| first_line.contains(sep.as_char())) + .iter() + .find(|sep| first_line.contains(sep.as_char())) .cloned() .ok_or_else(|| { io::Error::new( @@ -202,10 +201,12 @@ impl CsvFile { .has_headers(true) .from_reader(csv_file); - let safe_regex_set: RegexSet = get_safe_regex_set(); // Regex for safe values - let unsafe_regex_set: RegexSet = get_unsafe_value_regex_set(); // Regex for unsafe values + let safe_regex_set: &RegexSet = get_safe_regex_set(); // Cached regex for safe values + let unsafe_regex_set: &RegexSet = get_unsafe_value_regex_set(); // Cached regex for unsafe values let mut seen_words: HashSet = HashSet::new(); // Store seen words to avoid duplicates - let mut batch_data: Vec = Vec::new(); + + // Pre-allocate with reasonable default to avoid initial reallocations + let mut batch_data: Vec = Vec::with_capacity(1000); for (row_number, record) in rdr.records().enumerate() { let record: StringRecord = match record { diff --git a/src/core/utils/regex.rs b/src/core/utils/regex.rs index cb3ff9a..26c78f6 100644 --- a/src/core/utils/regex.rs +++ b/src/core/utils/regex.rs @@ -2,47 +2,74 @@ use crate::core::utils::regex::safe_regex::{ get_datetime_regex, get_email_regex, get_numeric_regex, get_phone_number_regex, get_simple_word_regex, }; +use once_cell::sync::Lazy; use regex::RegexSet; pub mod safe_regex { + use once_cell::sync::Lazy; use regex::Regex; /// Date and time pattern, supporting various formats - #[inline] - #[must_use] - pub fn get_datetime_regex() -> Regex { + static DATETIME_REGEX: Lazy = Lazy::new(|| { Regex::new( r"(?i)\b(?:\d{4}[-/]\d{2}[-/]\d{2}|\d{2}[-/]\d{2}[-/]\d{4})\s?(?:\d{2}[:]\d{2}[:]\d{2})?\b", ) - .unwrap() - } + .unwrap() + }); /// Numeric pattern, allowing for integers and decimals with optional signs + static NUMERIC_REGEX: Lazy = Lazy::new(|| { + Regex::new(r"^[-.]?\d+([.,]\d*)?\s*$").unwrap() + }); + + /// Email pattern, case-insensitive, allowing for common email formats + static EMAIL_REGEX: Lazy = Lazy::new(|| { + Regex::new(r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Za-z]{2,}\b").unwrap() + }); + + /// Simple word pattern, allowing only letters (case-insensitive) + static SIMPLE_WORD_REGEX: Lazy = Lazy::new(|| { + Regex::new("^[A-Za-z]+$").unwrap() + }); + + /// Phone number pattern, allowing for international formats + static PHONE_NUMBER_REGEX: Lazy = Lazy::new(|| { + Regex::new("[+]?[0-9]{1,2}").unwrap() + }); + + /// Get the cached datetime regex #[inline] #[must_use] - pub fn get_numeric_regex() -> Regex { - Regex::new(r"^[-.]?\d+([.,]\d*)?\s*$").unwrap() + pub fn get_datetime_regex() -> &'static Regex { + &DATETIME_REGEX } - /// Email pattern, case-insensitive, allowing for common email formats + /// Get the cached numeric regex #[inline] #[must_use] - pub fn get_email_regex() -> Regex { - Regex::new(r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Za-z]{2,}\b").unwrap() + pub fn get_numeric_regex() -> &'static Regex { + &NUMERIC_REGEX } - /// Simple word pattern, allowing only letters (case-insensitive) + /// Get the cached email regex #[inline] #[must_use] - pub fn get_simple_word_regex() -> Regex { - Regex::new("^[A-Za-z]+$").unwrap() + pub fn get_email_regex() -> &'static Regex { + &EMAIL_REGEX } - /// Phone number pattern, allowing for international formats + /// Get the cached simple word regex #[inline] #[must_use] - pub fn get_phone_number_regex() -> Regex { - Regex::new("[+]?[0-9]{1,2}").unwrap() + pub fn get_simple_word_regex() -> &'static Regex { + &SIMPLE_WORD_REGEX + } + + /// Get the cached phone number regex + #[inline] + #[must_use] + pub fn get_phone_number_regex() -> &'static Regex { + &PHONE_NUMBER_REGEX } #[cfg(test)] @@ -64,7 +91,7 @@ pub mod safe_regex { "03-02-2024 00:00:00", "03/02/2024 07:45:30", ]; - let regex: Regex = get_datetime_regex(); + let regex: &Regex = get_datetime_regex(); for date in &VALID_DATES { assert!(regex.is_match(date), "Erreur sur: {date}"); } @@ -73,7 +100,7 @@ pub mod safe_regex { #[tokio::test] async fn test_invalid_dates() { const INVALID_DATES: [&str; 3] = ["random text", "not a date", "is it a date?"]; - let regex: Regex = get_datetime_regex(); + let regex: &Regex = get_datetime_regex(); for date in &INVALID_DATES { assert!(!regex.is_match(date), "Erreur sur: {date}"); @@ -97,7 +124,7 @@ pub mod safe_regex { "-4645464664.6515", ]; - let regex: Regex = get_numeric_regex(); + let regex: &Regex = get_numeric_regex(); for num in &VALID_NUMBERS { assert!(regex.is_match(num), "Erreur sur: {num}"); @@ -106,7 +133,7 @@ pub mod safe_regex { #[tokio::test] async fn test_invalid_numbers() { - let regex: Regex = get_numeric_regex(); + let regex: &Regex = get_numeric_regex(); const INVALID_NUMBERS: [&str; 8] = [ "abc", "123abc", "--3.14", "3..14", "3,14,15", "..5", "az4a4z6", "0.0.0", ]; @@ -126,7 +153,7 @@ pub mod safe_regex { "a@b.io", ]; - let regex: Regex = get_email_regex(); + let regex: &Regex = get_email_regex(); for email in &VALID_EMAILS { assert!(regex.is_match(email), "Erreur sur: {email}"); @@ -144,7 +171,7 @@ pub mod safe_regex { "user domain.com", ]; - let regex: Regex = get_email_regex(); + let regex: &Regex = get_email_regex(); for email in &INVALID_EMAILS { assert!(!regex.is_match(email), "Erreur sur: {email}"); @@ -162,7 +189,7 @@ pub mod safe_regex { "NUMERI:", ]; - let regex: Regex = get_simple_word_regex(); + let regex: &Regex = get_simple_word_regex(); for email in &VALID_WORD { assert!(!regex.is_match(email), "Erreur sur: {email}"); @@ -227,10 +254,8 @@ pub mod usafe_regex { } } -/// Return a `RegexSet` for unsafe values -#[inline] -#[must_use] -pub fn get_safe_regex_set() -> RegexSet { +/// Cached RegexSet for safe values +static SAFE_REGEX_SET: Lazy = Lazy::new(|| { RegexSet::new([ get_numeric_regex().as_str(), get_datetime_regex().as_str(), @@ -239,15 +264,27 @@ pub fn get_safe_regex_set() -> RegexSet { get_phone_number_regex().as_str(), ]) .unwrap() -} +}); -/// Return a `RegexSet` for unsafe values -#[inline] -#[must_use] -pub fn get_unsafe_value_regex_set() -> RegexSet { +/// Cached RegexSet for unsafe values +static UNSAFE_VALUE_REGEX_SET: Lazy = Lazy::new(|| { RegexSet::new([ usafe_regex::sql_keyword_regex(), usafe_regex::illegal_char_regex(), ]) .unwrap() +}); + +/// Return a reference to the cached `RegexSet` for safe values +#[inline] +#[must_use] +pub fn get_safe_regex_set() -> &'static RegexSet { + &SAFE_REGEX_SET +} + +/// Return a reference to the cached `RegexSet` for unsafe values +#[inline] +#[must_use] +pub fn get_unsafe_value_regex_set() -> &'static RegexSet { + &UNSAFE_VALUE_REGEX_SET } diff --git a/src/main.rs b/src/main.rs index f75512c..2b103df 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,11 +6,11 @@ mod tests; use datalib::detection::anomaly::Anomaly; use datalib::core::io::file::csv_file::CsvFile; -use datalib::model::model::Model; +use datalib::model::model::InitializedModel; use datalib::core::io::tracing::log_level::LogLevel; use datalib::core::io::tracing::logger::log_and_print_message; use datalib::core::utils::util::{ - file_exists, generate_json_file, get_file_from_args, print_report, run_post_execution, + generate_json_file, get_file_from_args, print_report, run_post_execution, }; use std::process::exit; use std::time::Instant; @@ -22,29 +22,22 @@ fn main() { let start_time: Instant = Instant::now(); - let perfage_iae: Model = Model::from_config_file("config.json").unwrap_or_else(|e| { + // Load and initialize model once (much more efficient than Model::analyse_file) + let mut model: InitializedModel = InitializedModel::new("config.json").unwrap_or_else(|e| { log_and_print_message( - &format!("Error loading model configuration: {e}"), + &format!("Error loading model: {e}"), &LogLevel::Error, ); exit(1); }); - [&perfage_iae.model_path, &perfage_iae.vocabulary_path] - .iter() - .for_each(|path| { - if !file_exists(path) { - exit(1); - } - }); - let csv_struct: CsvFile = CsvFile::from_file(&args[0]).unwrap_or_else(|e| { log_and_print_message(&format!("Error reading CSV file: {e}"), &LogLevel::Error); exit(1); }); let (dangerous_output, ai_analyze, regex_analyze): (Vec, u32, u32) = - perfage_iae.analyse_file(&csv_struct).unwrap_or_else(|e| { + model.analyse_file(&csv_struct).unwrap_or_else(|e| { log_and_print_message(&format!("Error analyzing file: {e}"), &LogLevel::Error); exit(1); }); diff --git a/src/model/model.rs b/src/model/model.rs index 47be6d6..dab1800 100644 --- a/src/model/model.rs +++ b/src/model/model.rs @@ -21,6 +21,15 @@ pub struct Model { pub vocabulary_path: String, } +/// Initialized model with loaded PyTorch model and tokenizer. +/// This struct should be reused across multiple file analyses to avoid +/// the expensive model reloading overhead. +pub struct InitializedModel { + model: CModule, + device: Device, + tokenizer: Tokenizer, +} + impl Model { /// Load the model configuration from a JSON file and return a Model instance. pub fn from_config_file(json_path: &str) -> Result> { @@ -50,6 +59,10 @@ impl Model { /// Analyse a CSV file and return a tuple containing the detected anomalies, /// the number of AI analyses performed, and the number of regex analyses performed. + /// + /// **Note**: This method reinitializes the model on every call, which is expensive. + /// Consider using `InitializedModel::new()` and `InitializedModel::analyse_file()` instead + /// for better performance when analyzing multiple files. pub fn analyse_file( &self, csv_file_struct: &CsvFile, @@ -114,7 +127,9 @@ impl Model { .sigmoid(); } - let mut all_outputs: Vec = Vec::new(); + // Pre-allocate vector with exact capacity to avoid reallocations + let num_batches = (encodings.len() + MAX_BATCH_SIZE - 1) / MAX_BATCH_SIZE; + let mut all_outputs: Vec = Vec::with_capacity(num_batches); for batch in encodings.chunks(MAX_BATCH_SIZE) { let output: Tensor = @@ -154,7 +169,179 @@ impl Model { ai_analyze: &mut u32, ) -> Vec { const THRESHOLD: f64 = 0.8; - let mut anomalies: Vec = Vec::new(); + // Pre-allocate with conservative estimate (10% anomaly rate) + let estimated_capacity = batch_data.len() / 10; + let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); + + // Get prediction scores as a 1D vector + let scores = predictions.select(1, 1).iter::().unwrap(); + + for (i, score) in scores.enumerate() { + *ai_analyze += 1; + + // Check if the score exceeds the threshold and if the corresponding data exists + if score > THRESHOLD && let Some(data) = batch_data.get(i) + { + let column_name: String = + headers.get(data.column_index).unwrap_or("unknown").into(); + let row_number: u32 = u32::try_from(data.row_number + 2).unwrap_or(u32::MAX); + + anomalies.push(Anomaly::new( + data.value.clone(), + column_name, + row_number, + score as f32, + )); + } + } + + anomalies + } +} + +impl InitializedModel { + /// Load and initialize the model once from a configuration file. + /// This is much more efficient than using `Model::analyse_file()` repeatedly, + /// as the model is loaded only once and can be reused across multiple analyses. + /// + /// # Example + /// ```no_run + /// use datalib::model::model::InitializedModel; + /// let mut model = InitializedModel::new("config.json")?; + /// let (anomalies, ai_count, regex_count) = model.analyse_file(&csv_file)?; + /// ``` + pub fn new(config_path: &str) -> Result> { + let config = Model::from_config_file(config_path)?; + let device: Device = Device::cuda_if_available(); + let model: CModule = + CModule::load_on_device(&config.model_path, device).unwrap_or_else(|e| { + print_message(&format!("Error loading model: {e}"), &LogLevel::Error); + std::process::exit(1); + }); + let tokenizer: Tokenizer = ModelTokenizer::from_config_file(&config.vocabulary_path)?; + + Ok(Self { + model, + device, + tokenizer, + }) + } + + /// Analyse a CSV file using the pre-loaded model. + /// Returns a tuple containing the detected anomalies, AI analysis count, and regex analysis count. + /// This method reuses the model already loaded in memory, making it much faster + /// than calling `Model::analyse_file()` repeatedly. + pub fn analyse_file( + &mut self, + csv_file_struct: &CsvFile, + ) -> Result<(Vec, u32, u32), Box> { + let mut regex_analyze: u32 = 0; + let mut ai_analyze: u32 = 0; + + let batch_data: Vec = + csv_file_struct.collect_unsafe_value(csv_file_struct, &mut regex_analyze)?; + + if batch_data.is_empty() { + return Ok((Vec::new(), ai_analyze, regex_analyze)); + } + + let (encodings, max_seq_length) = ModelTokenizer::encode_words(&self.tokenizer, &batch_data); + + let predictions: Tensor = Self::run_sigmoid_inference_batched( + &encodings, + max_seq_length, + &mut self.model, + self.device, + ); + + let anomalies: Vec = Self::process_output( + &batch_data, + &predictions, + &csv_file_struct.get_headers()?, + &mut ai_analyze, + ); + + Ok((anomalies, ai_analyze, regex_analyze)) + } + + /// Execute the inference in batches using sigmoid activation. + fn run_sigmoid_inference_batched( + encodings: &[Encoding], + max_seq_length: i64, + model: &mut CModule, + device: Device, + ) -> Tensor { + const MAX_BATCH_SIZE: usize = 32; + model.set_eval(); + + // Fast path for small batches (optional performance boost) + if encodings.len() < 5000 { + return Self::run_single_batch_inference(encodings, max_seq_length, model, device) + .sigmoid(); + } + + // Pre-allocate vector with exact capacity to avoid reallocations + let num_batches = (encodings.len() + MAX_BATCH_SIZE - 1) / MAX_BATCH_SIZE; + let mut all_outputs: Vec = Vec::with_capacity(num_batches); + + for batch in encodings.chunks(MAX_BATCH_SIZE) { + let output: Tensor = + Self::run_single_batch_inference(batch, max_seq_length, model, device); + all_outputs.push(output); + } + + Tensor::cat(&all_outputs, 0).sigmoid() + } + + /// Run inference for a single batch of encodings. + fn run_single_batch_inference( + batch: &[Encoding], + max_seq_length: i64, + model: &CModule, + device: Device, + ) -> Tensor { + let (padded_ids, attention_masks) = ModelTokenizer::build_tokens(batch, max_seq_length); + let batch_size: i64 = i64::try_from(batch.len()).unwrap_or(0); + + let input_ids: Tensor = Tensor::from_slice(&padded_ids) + .view((batch_size, max_seq_length)) + .to_device(device); + + let attention_mask: Tensor = Tensor::from_slice(&attention_masks) + .view((batch_size, max_seq_length)) + .to_device(device); + + Self::forward(model, input_ids, attention_mask) + } + + /// Forward pass through the model with input IDs and attention mask. + fn forward(model: &CModule, input_ids: Tensor, attention_mask: Tensor) -> Tensor { + let output: Tensor = tch::no_grad(|| { + model + .forward_ts(&[input_ids, attention_mask]) + .unwrap_or_else(|e| { + print_message( + &format!("Error during model inference: {e}"), + &LogLevel::Error, + ); + std::process::exit(1); + }) + }); + + output + } + + /// Extract anomalies from the model's predictions and batch data. + fn process_output( + batch_data: &[InferableValue], + predictions: &Tensor, + headers: &StringRecord, + ai_analyze: &mut u32, + ) -> Vec { + const THRESHOLD: f64 = 0.8; + // Pre-allocate with conservative estimate (10% anomaly rate) + let estimated_capacity = batch_data.len() / 10; + let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); // Get prediction scores as a 1D vector let scores = predictions.select(1, 1).iter::().unwrap(); diff --git a/src/model/tokenizer.rs b/src/model/tokenizer.rs index b487055..226689f 100644 --- a/src/model/tokenizer.rs +++ b/src/model/tokenizer.rs @@ -30,7 +30,7 @@ impl ModelTokenizer { ) -> (Vec, i64) { let encodings: Vec = batch_data .iter() - .map(|data| tokenizer.encode(data.value.clone(), true).unwrap()) + .map(|data| tokenizer.encode(data.value.as_str(), true).unwrap()) .collect(); let max_seq_length: i64 = encodings @@ -46,8 +46,14 @@ impl ModelTokenizer { #[inline] #[must_use] pub fn ids_to_vector(encoding: &Encoding) -> (Vec, i64) { + const PARALLEL_THRESHOLD: usize = 1000; let ids: &[u32] = encoding.get_ids(); - let ids: Vec = ids.par_iter().map(|&x| i64::from(x)).collect(); + // Use parallel iteration only for large sequences to avoid overhead + let ids: Vec = if ids.len() > PARALLEL_THRESHOLD { + ids.par_iter().map(|&x| i64::from(x)).collect() + } else { + ids.iter().map(|&x| i64::from(x)).collect() + }; let seq_length: i64 = i64::try_from(ids.len()).unwrap_or(i64::MAX); (ids, seq_length) } From 1a1b2e4f22469e414d37d7fe66b8a94f97b6dafa Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Tue, 23 Dec 2025 19:08:17 +0100 Subject: [PATCH 6/7] chore: update version to 1.0.1 and enhance performance with increased batch size and parallel tokenization --- .github/workflows/cd.yml | 10 +++++----- .github/workflows/ci.yml | 2 +- Cargo.lock | 2 +- Cargo.toml | 2 +- src/model/model.rs | 4 ++-- src/model/tokenizer.rs | 3 ++- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index b29816b..99b73b6 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -19,7 +19,7 @@ on: env: CARGO_TERM_COLOR: always - LIBTORCH_VERSION: 2.7.1 + LIBTORCH_VERSION: 2.9.1 LIBTORCH_CXX11_ABI: 1 LIBTORCH_BYPASS_VERSION_CHECK: 1 @@ -32,14 +32,14 @@ jobs: matrix: include: # Linux - - { os: ubuntu-latest, target: x86_64-unknown-linux-gnu, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.7.1%2Bcpu.zip" } + - { os: ubuntu-latest, target: x86_64-unknown-linux-gnu, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.9.1%2Bcpu.zip" } # Windows - - { os: windows-latest, target: x86_64-pc-windows-msvc, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.7.1%2Bcpu.zip" } + - { os: windows-latest, target: x86_64-pc-windows-msvc, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.9.1%2Bcpu.zip" } # macOS - - { os: macos-latest, target: x86_64-apple-darwin, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.7.1.zip" } - - { os: macos-latest, target: aarch64-apple-darwin, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.7.1.zip" } + - { os: macos-latest, target: x86_64-apple-darwin, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.9.1.zip" } + - { os: macos-latest, target: aarch64-apple-darwin, libtorch_url: "https://download.pytorch.org/libtorch/cpu/libtorch-macos-2.9.1.zip" } steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index af29bde..7b899c4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ on: env: CARGO_TERM_COLOR: always - LIBTORCH_VERSION: 2.7.1 + LIBTORCH_VERSION: 2.9.1 jobs: build: diff --git a/Cargo.lock b/Cargo.lock index 8f35898..b8b6318 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "DataLint" -version = "1.0.0" +version = "1.0.1" dependencies = [ "chrono", "criterion", diff --git a/Cargo.toml b/Cargo.toml index 41d2fc6..299a954 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "DataLint" -version = "1.0.0" +version = "1.0.1" edition = "2024" description = "A Rust executable for linting potential unsafe values in CSV files" license = "GPL-3" diff --git a/src/model/model.rs b/src/model/model.rs index dab1800..35b0695 100644 --- a/src/model/model.rs +++ b/src/model/model.rs @@ -118,7 +118,7 @@ impl Model { model: &mut CModule, device: Device, ) -> Tensor { - const MAX_BATCH_SIZE: usize = 32; + const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization model.set_eval(); // Fast path for small batches (optional performance boost) @@ -271,7 +271,7 @@ impl InitializedModel { model: &mut CModule, device: Device, ) -> Tensor { - const MAX_BATCH_SIZE: usize = 32; + const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization model.set_eval(); // Fast path for small batches (optional performance boost) diff --git a/src/model/tokenizer.rs b/src/model/tokenizer.rs index 226689f..6c48877 100644 --- a/src/model/tokenizer.rs +++ b/src/model/tokenizer.rs @@ -28,8 +28,9 @@ impl ModelTokenizer { tokenizer: &Tokenizer, batch_data: &[InferableValue], ) -> (Vec, i64) { + // Use parallel iteration for tokenization (CPU-bound operation) let encodings: Vec = batch_data - .iter() + .par_iter() .map(|data| tokenizer.encode(data.value.as_str(), true).unwrap()) .collect(); From 066cd491fecf754338f5d6857ec8b0115f4a4721 Mon Sep 17 00:00:00 2001 From: Maxime <98154358+Maxime-Cllt@users.noreply.github.com> Date: Wed, 24 Dec 2025 08:42:06 +0100 Subject: [PATCH 7/7] refactor: update model imports and enhance documentation with error handling details --- src/benches/application_bench.rs | 2 +- src/core/io/file/csv_file.rs | 34 ++- src/core/io/file/json_output.rs | 4 + src/core/io/tracing/logger.rs | 11 +- src/core/utils/regex.rs | 22 +- src/core/utils/util.rs | 12 + src/main.rs | 2 +- src/model/mod.rs | 391 ++++++++++++++++++++++++++++++- src/model/model.rs | 370 ----------------------------- src/model/tokenizer.rs | 20 +- src/tests/model_tests.rs | 2 +- 11 files changed, 476 insertions(+), 394 deletions(-) delete mode 100644 src/model/model.rs diff --git a/src/benches/application_bench.rs b/src/benches/application_bench.rs index b0a6e41..b9cb4ee 100644 --- a/src/benches/application_bench.rs +++ b/src/benches/application_bench.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, Criterion}; use datalib::core::io::file::csv_file::CsvFile; -use datalib::model::model::Model; +use datalib::model::Model; use std::time::Duration; use datalib::core::io::tracing::log_level::LogLevel; use datalib::core::io::tracing::logger::log_and_print_message; diff --git a/src/core/io/file/csv_file.rs b/src/core/io/file/csv_file.rs index 0315c40..be23abe 100644 --- a/src/core/io/file/csv_file.rs +++ b/src/core/io/file/csv_file.rs @@ -32,6 +32,10 @@ impl CsvFile { } /// Return the separator as a char. + /// + /// # Panics + /// + /// Panics if the file cannot be read or no valid separator is found. #[inline] #[must_use] pub fn find_separator_in_file(csv_file_path: &str) -> SeparatorType { @@ -61,6 +65,10 @@ impl CsvFile { } /// Return the headers of the CSV file as a `StringRecord`. + /// + /// # Errors + /// + /// Returns an error if the file cannot be read. pub fn get_headers(&self) -> Result> { let binding: String = Self::read_first_line(&self.csv_file_path)?; let first_line: &str = binding.trim(); @@ -71,6 +79,10 @@ impl CsvFile { } /// Read the first line of a file and return it as a String. + /// + /// # Errors + /// + /// Returns an error if the file cannot be opened or read, or if the file is empty. #[inline] pub fn read_first_line(file_path: &str) -> io::Result { let file: File = File::open(file_path)?; @@ -88,6 +100,10 @@ impl CsvFile { } /// Check if a file is encoded in UTF-8. + /// + /// # Errors + /// + /// Returns an error if the file cannot be opened or read. #[inline] pub fn is_file_utf8(file_path: &str) -> Result> { const CHUNK_SIZE: usize = 16 * 1024; // 16 KB @@ -110,6 +126,10 @@ impl CsvFile { } /// Convert a file to UTF-8 encoding and save it in the "utf8" directory. + /// + /// # Errors + /// + /// Returns an error if file operations fail (open, read, write). #[inline] pub fn convert_file_to_utf8(input_path: &str) -> Result> { const UTF8: &str = "utf8"; @@ -149,6 +169,10 @@ impl CsvFile { } /// Create a `CsvFile` instance from a file path, checking its encoding and separator. + /// + /// # Errors + /// + /// Returns an error if file encoding check fails, UTF-8 conversion fails, or separator detection fails. pub fn from_file(csv_file_path: &str) -> Result> { let is_utf8: bool = Self::is_file_utf8(csv_file_path).map_err(|e| { log_and_print_message( @@ -158,7 +182,9 @@ impl CsvFile { e })?; - let csv_file_path: String = if !is_utf8 { + let csv_file_path: String = if is_utf8 { + String::from(csv_file_path) + } else { Self::convert_file_to_utf8(csv_file_path).map_err(|e| { log_and_print_message( &format!("Error converting file to UTF-8: {e}"), @@ -166,8 +192,6 @@ impl CsvFile { ); e })? - } else { - String::from(csv_file_path) }; let separator: u8 = u8::from(match Self::find_separator_in_file(&csv_file_path) { @@ -189,6 +213,10 @@ impl CsvFile { } /// Collect unsafe values from the CSV file based on regex patterns. + /// + /// # Errors + /// + /// Returns an error if the CSV file cannot be opened or read. #[inline] pub fn collect_unsafe_value( &self, diff --git a/src/core/io/file/json_output.rs b/src/core/io/file/json_output.rs index ac89c7e..83cf128 100644 --- a/src/core/io/file/json_output.rs +++ b/src/core/io/file/json_output.rs @@ -35,6 +35,10 @@ impl JsonOutput { } /// Save the `JsonOutput` to a file in pretty JSON format + /// + /// # Errors + /// + /// Returns an error if JSON serialization fails or file writing fails. pub fn save_to_file(&self, file_path: &str) -> std::io::Result<()> { let json_data: String = serde_json::to_string_pretty(self)?; std::fs::write(file_path, json_data) diff --git a/src/core/io/tracing/logger.rs b/src/core/io/tracing/logger.rs index bdb2f03..d42e26a 100644 --- a/src/core/io/tracing/logger.rs +++ b/src/core/io/tracing/logger.rs @@ -42,13 +42,22 @@ impl Logger { pub static LOGGER: std::sync::LazyLock> = std::sync::LazyLock::new(|| Mutex::new(Logger::new())); -/// Static logger instance +/// Log a message and print it to the console. +/// +/// # Panics +/// +/// Panics if the logger mutex is poisoned. pub fn log_and_print_message(message: &str, log_level: &LogLevel) { print_message(message, log_level); let logger: MutexGuard = LOGGER.lock().unwrap(); logger.log(log_level, message); } +/// Log a message without printing to console. +/// +/// # Panics +/// +/// Panics if the logger mutex is poisoned. pub fn log_message(message: &str, log_level: &LogLevel) { let logger: MutexGuard = LOGGER.lock().unwrap(); logger.log(log_level, message); diff --git a/src/core/utils/regex.rs b/src/core/utils/regex.rs index 26c78f6..33c7ed8 100644 --- a/src/core/utils/regex.rs +++ b/src/core/utils/regex.rs @@ -2,15 +2,15 @@ use crate::core::utils::regex::safe_regex::{ get_datetime_regex, get_email_regex, get_numeric_regex, get_phone_number_regex, get_simple_word_regex, }; -use once_cell::sync::Lazy; +use std::sync::LazyLock; use regex::RegexSet; pub mod safe_regex { - use once_cell::sync::Lazy; + use std::sync::LazyLock; use regex::Regex; /// Date and time pattern, supporting various formats - static DATETIME_REGEX: Lazy = Lazy::new(|| { + static DATETIME_REGEX: LazyLock = LazyLock::new(|| { Regex::new( r"(?i)\b(?:\d{4}[-/]\d{2}[-/]\d{2}|\d{2}[-/]\d{2}[-/]\d{4})\s?(?:\d{2}[:]\d{2}[:]\d{2})?\b", ) @@ -18,22 +18,22 @@ pub mod safe_regex { }); /// Numeric pattern, allowing for integers and decimals with optional signs - static NUMERIC_REGEX: Lazy = Lazy::new(|| { + static NUMERIC_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"^[-.]?\d+([.,]\d*)?\s*$").unwrap() }); /// Email pattern, case-insensitive, allowing for common email formats - static EMAIL_REGEX: Lazy = Lazy::new(|| { + static EMAIL_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"(?i)\b[A-Z0-9._%+-]+@[A-Z0-9.-]+\.[A-Za-z]{2,}\b").unwrap() }); /// Simple word pattern, allowing only letters (case-insensitive) - static SIMPLE_WORD_REGEX: Lazy = Lazy::new(|| { + static SIMPLE_WORD_REGEX: LazyLock = LazyLock::new(|| { Regex::new("^[A-Za-z]+$").unwrap() }); /// Phone number pattern, allowing for international formats - static PHONE_NUMBER_REGEX: Lazy = Lazy::new(|| { + static PHONE_NUMBER_REGEX: LazyLock = LazyLock::new(|| { Regex::new("[+]?[0-9]{1,2}").unwrap() }); @@ -254,8 +254,8 @@ pub mod usafe_regex { } } -/// Cached RegexSet for safe values -static SAFE_REGEX_SET: Lazy = Lazy::new(|| { +/// Cached `RegexSet` for safe values +static SAFE_REGEX_SET: LazyLock = LazyLock::new(|| { RegexSet::new([ get_numeric_regex().as_str(), get_datetime_regex().as_str(), @@ -266,8 +266,8 @@ static SAFE_REGEX_SET: Lazy = Lazy::new(|| { .unwrap() }); -/// Cached RegexSet for unsafe values -static UNSAFE_VALUE_REGEX_SET: Lazy = Lazy::new(|| { +/// Cached `RegexSet` for unsafe values +static UNSAFE_VALUE_REGEX_SET: LazyLock = LazyLock::new(|| { RegexSet::new([ usafe_regex::sql_keyword_regex(), usafe_regex::illegal_char_regex(), diff --git a/src/core/utils/util.rs b/src/core/utils/util.rs index 3d16f0e..0d855b5 100644 --- a/src/core/utils/util.rs +++ b/src/core/utils/util.rs @@ -8,6 +8,10 @@ use crate::core::io::tracing::color::Color; use crate::core::io::tracing::logger::{log_and_print_message, log_message, print_message}; /// Create a JSON file with the analysis results. +/// +/// # Panics +/// +/// Panics if the current directory path contains invalid UTF-8. pub fn generate_json_file( dangerous_output: Vec, regex_analyze: u32, @@ -70,6 +74,10 @@ pub fn generate_json_file( } /// Get the CSV file path and output JSON file name from command line arguments. +/// +/// # Errors +/// +/// Returns an error if the number of arguments is not exactly 2. pub fn get_file_from_args(args: &[String]) -> Result<[String; 2], Error> { if args.len() != 3 { return Err(Error::new( @@ -173,6 +181,10 @@ pub fn run_post_execution(file_path: &str) { } /// Extract the file name without the extension from a given file path. +/// +/// # Panics +/// +/// Panics if the file path is empty (though this should never happen in practice). #[inline] #[must_use] pub fn get_file_name(file_path: &str) -> &str { diff --git a/src/main.rs b/src/main.rs index 2b103df..4a0c120 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ mod tests; use datalib::detection::anomaly::Anomaly; use datalib::core::io::file::csv_file::CsvFile; -use datalib::model::model::InitializedModel; +use datalib::model::InitializedModel; use datalib::core::io::tracing::log_level::LogLevel; use datalib::core::io::tracing::logger::log_and_print_message; use datalib::core::utils::util::{ diff --git a/src/model/mod.rs b/src/model/mod.rs index 0e2590d..b3e7ce5 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -1,2 +1,391 @@ -pub mod model; pub mod tokenizer; + +use crate::core::io::tracing::log_level::LogLevel; +use crate::detection::anomaly::Anomaly; +use crate::core::io::file::csv_file::CsvFile; +use crate::detection::inferable_value::InferableValue; +use crate::model::tokenizer::ModelTokenizer; +use csv::StringRecord; +use serde::Deserialize; +use std::error::Error; +use std::fs::File; +use tch::{CModule, Device, Tensor}; +use tokenizers::{Encoding, Tokenizer}; +use crate::core::io::tracing::logger::print_message; + +/// Represents the model configuration for the anomaly detection system. +/// +/// It contains the paths to the model and vocabulary files. +/// The model is used for inference, while the vocabulary is used for tokenization. +/// The model is expected to be a `PyTorch` model, and the vocabulary is expected to be a tokenizer configuration file. +#[derive(Deserialize)] +pub struct Model { + pub model_path: String, + pub vocabulary_path: String, +} + +/// Initialized model with loaded `PyTorch` model and tokenizer. +/// This struct should be reused across multiple file analyses to avoid +/// the expensive model reloading overhead. +pub struct InitializedModel { + model: CModule, + device: Device, + tokenizer: Tokenizer, +} + +impl Model { + /// Load the model configuration from a JSON file and return a Model instance. + /// + /// # Errors + /// + /// Returns an error if the file cannot be opened or the JSON is invalid. + pub fn from_config_file(json_path: &str) -> Result> { + let json_file: File = File::open(json_path)?; + let model = serde_json::from_reader(json_file).unwrap_or_else(|e| { + print_message( + &format!("Error reading model configuration from JSON: {e}"), + &LogLevel::Error, + ); + std::process::exit(1); + }); + + Ok(model) + } + + /// Init the model, device, and tokenizer based on the model path and vocabulary path. + fn init_model(&self) -> Result<(CModule, Device, Tokenizer), Box> { + let device: Device = Device::cuda_if_available(); + let model: CModule = + CModule::load_on_device(&self.model_path, device).unwrap_or_else(|e| { + print_message(&format!("Error loading model: {e}"), &LogLevel::Error); + std::process::exit(1); + }); + let tokenizer: Tokenizer = ModelTokenizer::from_config_file(&self.vocabulary_path)?; + Ok((model, device, tokenizer)) + } + + /// Analyse a CSV file and return a tuple containing the detected anomalies, + /// the number of AI analyses performed, and the number of regex analyses performed. + /// + /// **Note**: This method reinitializes the model on every call, which is expensive. + /// Consider using `InitializedModel::new()` and `InitializedModel::analyse_file()` instead + /// for better performance when analyzing multiple files. + /// + /// # Errors + /// + /// Returns an error if CSV processing fails or model initialization encounters issues. + pub fn analyse_file( + &self, + csv_file_struct: &CsvFile, + ) -> Result<(Vec, u32, u32), Box> { + let mut regex_analyze: u32 = 0; + let mut ai_analyze: u32 = 0; + + let batch_data: Vec = + csv_file_struct.collect_unsafe_value(csv_file_struct, &mut regex_analyze)?; + + if batch_data.is_empty() { + return Ok((Vec::new(), ai_analyze, regex_analyze)); + } + + let (mut model, device, tokenizer): (CModule, Device, Tokenizer) = self.init_model()?; + + let (encodings, max_seq_length) = ModelTokenizer::encode_words(&tokenizer, &batch_data); + + let predictions: Tensor = + Self::run_sigmoid_inference_batched(&encodings, max_seq_length, &mut model, device); + + let anomalies: Vec = Self::process_output( + &batch_data, + &predictions, + &csv_file_struct.get_headers()?, + &mut ai_analyze, + ); + + Ok((anomalies, ai_analyze, regex_analyze)) + } + + /// Forward pass through the model with input IDs and attention mask. + fn forward(model: &CModule, input_ids: Tensor, attention_mask: Tensor) -> Tensor { + let output: Tensor = tch::no_grad(|| { + model + .forward_ts(&[input_ids, attention_mask]) + .unwrap_or_else(|e| { + print_message( + &format!("Error during model inference: {e}"), + &LogLevel::Error, + ); + std::process::exit(1); + }) + }); + + output + } + + /// Execute the inference in batches using sigmoid activation. + fn run_sigmoid_inference_batched( + encodings: &[Encoding], + max_seq_length: i64, + model: &mut CModule, + device: Device, + ) -> Tensor { + const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization + model.set_eval(); + + // Fast path for small batches (optional performance boost) + if encodings.len() < 5000 { + return Self::run_single_batch_inference(encodings, max_seq_length, model, device) + .sigmoid(); + } + + // Pre-allocate vector with exact capacity to avoid reallocations + let num_batches = encodings.len().div_ceil(MAX_BATCH_SIZE); + let mut all_outputs: Vec = Vec::with_capacity(num_batches); + + for batch in encodings.chunks(MAX_BATCH_SIZE) { + let output: Tensor = + Self::run_single_batch_inference(batch, max_seq_length, model, device); + all_outputs.push(output); + } + + Tensor::cat(&all_outputs, 0).sigmoid() + } + + /// Run inference for a single batch of encodings. + fn run_single_batch_inference( + batch: &[Encoding], + max_seq_length: i64, + model: &CModule, + device: Device, + ) -> Tensor { + let (padded_ids, attention_masks) = ModelTokenizer::build_tokens(batch, max_seq_length); + let batch_size: i64 = i64::try_from(batch.len()).unwrap_or(0); + + let input_ids: Tensor = Tensor::from_slice(&padded_ids) + .view((batch_size, max_seq_length)) + .to_device(device); + + let attention_mask: Tensor = Tensor::from_slice(&attention_masks) + .view((batch_size, max_seq_length)) + .to_device(device); + + Self::forward(model, input_ids, attention_mask) + } + + /// Extract anomalies from the model's predictions and batch data. + fn process_output( + batch_data: &[InferableValue], + predictions: &Tensor, + headers: &StringRecord, + ai_analyze: &mut u32, + ) -> Vec { + const THRESHOLD: f64 = 0.8; + // Pre-allocate with conservative estimate (10% anomaly rate) + let estimated_capacity = batch_data.len() / 10; + let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); + + // Get prediction scores as a 1D vector + let scores = predictions.select(1, 1).iter::().unwrap(); + + for (i, score) in scores.enumerate() { + *ai_analyze += 1; + + // Check if the score exceeds the threshold and if the corresponding data exists + if score > THRESHOLD && let Some(data) = batch_data.get(i) + { + let column_name: String = + headers.get(data.column_index).unwrap_or("unknown").into(); + let row_number: u32 = u32::try_from(data.row_number + 2).unwrap_or(u32::MAX); + + #[allow(clippy::cast_possible_truncation)] + anomalies.push(Anomaly::new( + data.value.clone(), + column_name, + row_number, + score as f32, + )); + } + } + + anomalies + } +} + +impl InitializedModel { + /// Load and initialize the model once from a configuration file. + /// This is much more efficient than using `Model::analyse_file()` repeatedly, + /// as the model is loaded only once and can be reused across multiple analyses. + /// + /// # Errors + /// + /// Returns an error if the configuration file cannot be loaded or model initialization fails. + /// + /// # Example + /// ```no_run + /// use datalib::model::InitializedModel; + /// let mut model = InitializedModel::new("config.json")?; + /// let (anomalies, ai_count, regex_count) = model.analyse_file(&csv_file)?; + /// ``` + pub fn new(config_path: &str) -> Result> { + let config = Model::from_config_file(config_path)?; + let device: Device = Device::cuda_if_available(); + let model: CModule = + CModule::load_on_device(&config.model_path, device).unwrap_or_else(|e| { + print_message(&format!("Error loading model: {e}"), &LogLevel::Error); + std::process::exit(1); + }); + let tokenizer: Tokenizer = ModelTokenizer::from_config_file(&config.vocabulary_path)?; + + Ok(Self { + model, + device, + tokenizer, + }) + } + + /// Analyse a CSV file using the pre-loaded model. + /// Returns a tuple containing the detected anomalies, AI analysis count, and regex analysis count. + /// This method reuses the model already loaded in memory, making it much faster + /// than calling `Model::analyse_file()` repeatedly. + /// + /// # Errors + /// + /// Returns an error if CSV processing or model inference fails. + pub fn analyse_file( + &mut self, + csv_file_struct: &CsvFile, + ) -> Result<(Vec, u32, u32), Box> { + let mut regex_analyze: u32 = 0; + let mut ai_analyze: u32 = 0; + + let batch_data: Vec = + csv_file_struct.collect_unsafe_value(csv_file_struct, &mut regex_analyze)?; + + if batch_data.is_empty() { + return Ok((Vec::new(), ai_analyze, regex_analyze)); + } + + let (encodings, max_seq_length) = ModelTokenizer::encode_words(&self.tokenizer, &batch_data); + + let predictions: Tensor = Self::run_sigmoid_inference_batched( + &encodings, + max_seq_length, + &mut self.model, + self.device, + ); + + let anomalies: Vec = Self::process_output( + &batch_data, + &predictions, + &csv_file_struct.get_headers()?, + &mut ai_analyze, + ); + + Ok((anomalies, ai_analyze, regex_analyze)) + } + + /// Execute the inference in batches using sigmoid activation. + fn run_sigmoid_inference_batched( + encodings: &[Encoding], + max_seq_length: i64, + model: &mut CModule, + device: Device, + ) -> Tensor { + const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization + model.set_eval(); + + // Fast path for small batches (optional performance boost) + if encodings.len() < 5000 { + return Self::run_single_batch_inference(encodings, max_seq_length, model, device) + .sigmoid(); + } + + // Pre-allocate vector with exact capacity to avoid reallocations + let num_batches = encodings.len().div_ceil(MAX_BATCH_SIZE); + let mut all_outputs: Vec = Vec::with_capacity(num_batches); + + for batch in encodings.chunks(MAX_BATCH_SIZE) { + let output: Tensor = + Self::run_single_batch_inference(batch, max_seq_length, model, device); + all_outputs.push(output); + } + + Tensor::cat(&all_outputs, 0).sigmoid() + } + + /// Run inference for a single batch of encodings. + fn run_single_batch_inference( + batch: &[Encoding], + max_seq_length: i64, + model: &CModule, + device: Device, + ) -> Tensor { + let (padded_ids, attention_masks) = ModelTokenizer::build_tokens(batch, max_seq_length); + let batch_size: i64 = i64::try_from(batch.len()).unwrap_or(0); + + let input_ids: Tensor = Tensor::from_slice(&padded_ids) + .view((batch_size, max_seq_length)) + .to_device(device); + + let attention_mask: Tensor = Tensor::from_slice(&attention_masks) + .view((batch_size, max_seq_length)) + .to_device(device); + + Self::forward(model, input_ids, attention_mask) + } + + /// Forward pass through the model with input IDs and attention mask. + fn forward(model: &CModule, input_ids: Tensor, attention_mask: Tensor) -> Tensor { + let output: Tensor = tch::no_grad(|| { + model + .forward_ts(&[input_ids, attention_mask]) + .unwrap_or_else(|e| { + print_message( + &format!("Error during model inference: {e}"), + &LogLevel::Error, + ); + std::process::exit(1); + }) + }); + + output + } + + /// Extract anomalies from the model's predictions and batch data. + fn process_output( + batch_data: &[InferableValue], + predictions: &Tensor, + headers: &StringRecord, + ai_analyze: &mut u32, + ) -> Vec { + const THRESHOLD: f64 = 0.8; + // Pre-allocate with conservative estimate (10% anomaly rate) + let estimated_capacity = batch_data.len() / 10; + let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); + + // Get prediction scores as a 1D vector + let scores = predictions.select(1, 1).iter::().unwrap(); + + for (i, score) in scores.enumerate() { + *ai_analyze += 1; + + // Check if the score exceeds the threshold and if the corresponding data exists + if score > THRESHOLD && let Some(data) = batch_data.get(i) + { + let column_name: String = + headers.get(data.column_index).unwrap_or("unknown").into(); + let row_number: u32 = u32::try_from(data.row_number + 2).unwrap_or(u32::MAX); + + #[allow(clippy::cast_possible_truncation)] + anomalies.push(Anomaly::new( + data.value.clone(), + column_name, + row_number, + score as f32, + )); + } + } + + anomalies + } +} diff --git a/src/model/model.rs b/src/model/model.rs deleted file mode 100644 index 35b0695..0000000 --- a/src/model/model.rs +++ /dev/null @@ -1,370 +0,0 @@ -use crate::core::io::tracing::log_level::LogLevel; -use crate::detection::anomaly::Anomaly; -use crate::core::io::file::csv_file::CsvFile; -use crate::detection::inferable_value::InferableValue; -use crate::model::tokenizer::ModelTokenizer; -use csv::StringRecord; -use serde::Deserialize; -use std::error::Error; -use std::fs::File; -use tch::{CModule, Device, Tensor}; -use tokenizers::{Encoding, Tokenizer}; -use crate::core::io::tracing::logger::print_message; - -/// Represents the model configuration for the anomaly detection system. -/// It contains the paths to the model and vocabulary files. -/// The model is used for inference, while the vocabulary is used for tokenization. -/// The model is expected to be a PyTorch model, and the vocabulary is expected to be a tokenizer configuration file. -#[derive(Deserialize)] -pub struct Model { - pub model_path: String, - pub vocabulary_path: String, -} - -/// Initialized model with loaded PyTorch model and tokenizer. -/// This struct should be reused across multiple file analyses to avoid -/// the expensive model reloading overhead. -pub struct InitializedModel { - model: CModule, - device: Device, - tokenizer: Tokenizer, -} - -impl Model { - /// Load the model configuration from a JSON file and return a Model instance. - pub fn from_config_file(json_path: &str) -> Result> { - let json_file: File = File::open(json_path)?; - let model = serde_json::from_reader(json_file).unwrap_or_else(|e| { - print_message( - &format!("Error reading model configuration from JSON: {e}"), - &LogLevel::Error, - ); - std::process::exit(1); - }); - - Ok(model) - } - - /// Init the model, device, and tokenizer based on the model path and vocabulary path. - fn init_model(&self) -> Result<(CModule, Device, Tokenizer), Box> { - let device: Device = Device::cuda_if_available(); - let model: CModule = - CModule::load_on_device(&self.model_path, device).unwrap_or_else(|e| { - print_message(&format!("Error loading model: {e}"), &LogLevel::Error); - std::process::exit(1); - }); - let tokenizer: Tokenizer = ModelTokenizer::from_config_file(&self.vocabulary_path)?; - Ok((model, device, tokenizer)) - } - - /// Analyse a CSV file and return a tuple containing the detected anomalies, - /// the number of AI analyses performed, and the number of regex analyses performed. - /// - /// **Note**: This method reinitializes the model on every call, which is expensive. - /// Consider using `InitializedModel::new()` and `InitializedModel::analyse_file()` instead - /// for better performance when analyzing multiple files. - pub fn analyse_file( - &self, - csv_file_struct: &CsvFile, - ) -> Result<(Vec, u32, u32), Box> { - let mut regex_analyze: u32 = 0; - let mut ai_analyze: u32 = 0; - - let batch_data: Vec = - csv_file_struct.collect_unsafe_value(csv_file_struct, &mut regex_analyze)?; - - if batch_data.is_empty() { - return Ok((Vec::new(), ai_analyze, regex_analyze)); - } - - let (mut model, device, tokenizer): (CModule, Device, Tokenizer) = self.init_model()?; - - let (encodings, max_seq_length) = ModelTokenizer::encode_words(&tokenizer, &batch_data); - - let predictions: Tensor = - Self::run_sigmoid_inference_batched(&encodings, max_seq_length, &mut model, device); - - let anomalies: Vec = Self::process_output( - &batch_data, - &predictions, - &csv_file_struct.get_headers()?, - &mut ai_analyze, - ); - - Ok((anomalies, ai_analyze, regex_analyze)) - } - - /// Forward pass through the model with input IDs and attention mask. - fn forward(model: &CModule, input_ids: Tensor, attention_mask: Tensor) -> Tensor { - let output: Tensor = tch::no_grad(|| { - model - .forward_ts(&[input_ids, attention_mask]) - .unwrap_or_else(|e| { - print_message( - &format!("Error during model inference: {e}"), - &LogLevel::Error, - ); - std::process::exit(1); - }) - }); - - output - } - - /// Execute the inference in batches using sigmoid activation. - fn run_sigmoid_inference_batched( - encodings: &[Encoding], - max_seq_length: i64, - model: &mut CModule, - device: Device, - ) -> Tensor { - const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization - model.set_eval(); - - // Fast path for small batches (optional performance boost) - if encodings.len() < 5000 { - return Self::run_single_batch_inference(encodings, max_seq_length, model, device) - .sigmoid(); - } - - // Pre-allocate vector with exact capacity to avoid reallocations - let num_batches = (encodings.len() + MAX_BATCH_SIZE - 1) / MAX_BATCH_SIZE; - let mut all_outputs: Vec = Vec::with_capacity(num_batches); - - for batch in encodings.chunks(MAX_BATCH_SIZE) { - let output: Tensor = - Self::run_single_batch_inference(batch, max_seq_length, model, device); - all_outputs.push(output); - } - - Tensor::cat(&all_outputs, 0).sigmoid() - } - - /// Run inference for a single batch of encodings. - fn run_single_batch_inference( - batch: &[Encoding], - max_seq_length: i64, - model: &CModule, - device: Device, - ) -> Tensor { - let (padded_ids, attention_masks) = ModelTokenizer::build_tokens(batch, max_seq_length); - let batch_size: i64 = i64::try_from(batch.len()).unwrap_or(0); - - let input_ids: Tensor = Tensor::from_slice(&padded_ids) - .view((batch_size, max_seq_length)) - .to_device(device); - - let attention_mask: Tensor = Tensor::from_slice(&attention_masks) - .view((batch_size, max_seq_length)) - .to_device(device); - - Self::forward(model, input_ids, attention_mask) - } - - /// Extract anomalies from the model's predictions and batch data. - fn process_output( - batch_data: &[InferableValue], - predictions: &Tensor, - headers: &StringRecord, - ai_analyze: &mut u32, - ) -> Vec { - const THRESHOLD: f64 = 0.8; - // Pre-allocate with conservative estimate (10% anomaly rate) - let estimated_capacity = batch_data.len() / 10; - let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); - - // Get prediction scores as a 1D vector - let scores = predictions.select(1, 1).iter::().unwrap(); - - for (i, score) in scores.enumerate() { - *ai_analyze += 1; - - // Check if the score exceeds the threshold and if the corresponding data exists - if score > THRESHOLD && let Some(data) = batch_data.get(i) - { - let column_name: String = - headers.get(data.column_index).unwrap_or("unknown").into(); - let row_number: u32 = u32::try_from(data.row_number + 2).unwrap_or(u32::MAX); - - anomalies.push(Anomaly::new( - data.value.clone(), - column_name, - row_number, - score as f32, - )); - } - } - - anomalies - } -} - -impl InitializedModel { - /// Load and initialize the model once from a configuration file. - /// This is much more efficient than using `Model::analyse_file()` repeatedly, - /// as the model is loaded only once and can be reused across multiple analyses. - /// - /// # Example - /// ```no_run - /// use datalib::model::model::InitializedModel; - /// let mut model = InitializedModel::new("config.json")?; - /// let (anomalies, ai_count, regex_count) = model.analyse_file(&csv_file)?; - /// ``` - pub fn new(config_path: &str) -> Result> { - let config = Model::from_config_file(config_path)?; - let device: Device = Device::cuda_if_available(); - let model: CModule = - CModule::load_on_device(&config.model_path, device).unwrap_or_else(|e| { - print_message(&format!("Error loading model: {e}"), &LogLevel::Error); - std::process::exit(1); - }); - let tokenizer: Tokenizer = ModelTokenizer::from_config_file(&config.vocabulary_path)?; - - Ok(Self { - model, - device, - tokenizer, - }) - } - - /// Analyse a CSV file using the pre-loaded model. - /// Returns a tuple containing the detected anomalies, AI analysis count, and regex analysis count. - /// This method reuses the model already loaded in memory, making it much faster - /// than calling `Model::analyse_file()` repeatedly. - pub fn analyse_file( - &mut self, - csv_file_struct: &CsvFile, - ) -> Result<(Vec, u32, u32), Box> { - let mut regex_analyze: u32 = 0; - let mut ai_analyze: u32 = 0; - - let batch_data: Vec = - csv_file_struct.collect_unsafe_value(csv_file_struct, &mut regex_analyze)?; - - if batch_data.is_empty() { - return Ok((Vec::new(), ai_analyze, regex_analyze)); - } - - let (encodings, max_seq_length) = ModelTokenizer::encode_words(&self.tokenizer, &batch_data); - - let predictions: Tensor = Self::run_sigmoid_inference_batched( - &encodings, - max_seq_length, - &mut self.model, - self.device, - ); - - let anomalies: Vec = Self::process_output( - &batch_data, - &predictions, - &csv_file_struct.get_headers()?, - &mut ai_analyze, - ); - - Ok((anomalies, ai_analyze, regex_analyze)) - } - - /// Execute the inference in batches using sigmoid activation. - fn run_sigmoid_inference_batched( - encodings: &[Encoding], - max_seq_length: i64, - model: &mut CModule, - device: Device, - ) -> Tensor { - const MAX_BATCH_SIZE: usize = 512; // Increased from 32 for better GPU utilization - model.set_eval(); - - // Fast path for small batches (optional performance boost) - if encodings.len() < 5000 { - return Self::run_single_batch_inference(encodings, max_seq_length, model, device) - .sigmoid(); - } - - // Pre-allocate vector with exact capacity to avoid reallocations - let num_batches = (encodings.len() + MAX_BATCH_SIZE - 1) / MAX_BATCH_SIZE; - let mut all_outputs: Vec = Vec::with_capacity(num_batches); - - for batch in encodings.chunks(MAX_BATCH_SIZE) { - let output: Tensor = - Self::run_single_batch_inference(batch, max_seq_length, model, device); - all_outputs.push(output); - } - - Tensor::cat(&all_outputs, 0).sigmoid() - } - - /// Run inference for a single batch of encodings. - fn run_single_batch_inference( - batch: &[Encoding], - max_seq_length: i64, - model: &CModule, - device: Device, - ) -> Tensor { - let (padded_ids, attention_masks) = ModelTokenizer::build_tokens(batch, max_seq_length); - let batch_size: i64 = i64::try_from(batch.len()).unwrap_or(0); - - let input_ids: Tensor = Tensor::from_slice(&padded_ids) - .view((batch_size, max_seq_length)) - .to_device(device); - - let attention_mask: Tensor = Tensor::from_slice(&attention_masks) - .view((batch_size, max_seq_length)) - .to_device(device); - - Self::forward(model, input_ids, attention_mask) - } - - /// Forward pass through the model with input IDs and attention mask. - fn forward(model: &CModule, input_ids: Tensor, attention_mask: Tensor) -> Tensor { - let output: Tensor = tch::no_grad(|| { - model - .forward_ts(&[input_ids, attention_mask]) - .unwrap_or_else(|e| { - print_message( - &format!("Error during model inference: {e}"), - &LogLevel::Error, - ); - std::process::exit(1); - }) - }); - - output - } - - /// Extract anomalies from the model's predictions and batch data. - fn process_output( - batch_data: &[InferableValue], - predictions: &Tensor, - headers: &StringRecord, - ai_analyze: &mut u32, - ) -> Vec { - const THRESHOLD: f64 = 0.8; - // Pre-allocate with conservative estimate (10% anomaly rate) - let estimated_capacity = batch_data.len() / 10; - let mut anomalies: Vec = Vec::with_capacity(estimated_capacity); - - // Get prediction scores as a 1D vector - let scores = predictions.select(1, 1).iter::().unwrap(); - - for (i, score) in scores.enumerate() { - *ai_analyze += 1; - - // Check if the score exceeds the threshold and if the corresponding data exists - if score > THRESHOLD && let Some(data) = batch_data.get(i) - { - let column_name: String = - headers.get(data.column_index).unwrap_or("unknown").into(); - let row_number: u32 = u32::try_from(data.row_number + 2).unwrap_or(u32::MAX); - - anomalies.push(Anomaly::new( - data.value.clone(), - column_name, - row_number, - score as f32, - )); - } - } - - anomalies - } -} diff --git a/src/model/tokenizer.rs b/src/model/tokenizer.rs index 6c48877..5fe3355 100644 --- a/src/model/tokenizer.rs +++ b/src/model/tokenizer.rs @@ -12,6 +12,10 @@ pub struct ModelTokenizer; impl ModelTokenizer { /// Load the tokenizer from a configuration file. + /// + /// # Errors + /// + /// Returns an error if the tokenizer file cannot be loaded or is invalid. pub fn from_config_file(file_path: &str) -> Result> { let tokenizer: Tokenizer = Tokenizer::from_file(file_path).unwrap_or_else(|e| { print_message( @@ -24,6 +28,10 @@ impl ModelTokenizer { } /// Encode the words from a batch of `InferableValue` into a vector of `Encoding` and returns the maximum sequence length. + /// + /// # Panics + /// + /// Panics if tokenization fails for any value in the batch. pub fn encode_words( tokenizer: &Tokenizer, batch_data: &[InferableValue], @@ -34,11 +42,13 @@ impl ModelTokenizer { .map(|data| tokenizer.encode(data.value.as_str(), true).unwrap()) .collect(); - let max_seq_length: i64 = encodings - .iter() - .map(|e| e.get_ids().len()) - .max() - .unwrap_or(0) as i64; + let max_seq_length: i64 = i64::try_from( + encodings + .iter() + .map(|e| e.get_ids().len()) + .max() + .unwrap_or(0) + ).unwrap_or(i64::MAX); (encodings, max_seq_length) } diff --git a/src/tests/model_tests.rs b/src/tests/model_tests.rs index 6ed4751..65413dc 100644 --- a/src/tests/model_tests.rs +++ b/src/tests/model_tests.rs @@ -1,4 +1,4 @@ -use datalib::model::model::Model; +use datalib::model::Model; use std::fs::File; use std::io::Write;