diff --git a/CLAUDE.md b/CLAUDE.md index a6fa5844..7f94dc5f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -124,7 +124,7 @@ DRAW line ▼ │ ┌─────────────┐ │ │ DataFrame │ │ - │ (Polars) │ │ + │ (Arrow) │ │ └──────┬──────┘ │ │ │ └──────────┬───────────┘ @@ -534,7 +534,7 @@ pub trait Reader { - In-memory databases: `duckdb://memory` - File-based databases: `duckdb://path/to/file.db` -- SQL execution → Polars DataFrame conversion +- SQL execution → Arrow DataFrame conversion - Comprehensive type handling **Connection Parsing** (`connection.rs`): @@ -1333,9 +1333,9 @@ VIZ: "VISUALISE DRAW line MAPPING sale_date AS x, ..." ```rust // duckdb.rs connection.execute(sql) → ResultSet -ResultSet → DataFrame (Polars) +ResultSet → DataFrame (Arrow RecordBatch) -// DataFrame columns: sale_date (Date32), region (String), total (Int64) +// DataFrame columns: sale_date (Date32), region (Utf8), total (Int64) // Date32 values converted to ISO format: "2024-01-01" ``` diff --git a/Cargo.lock b/Cargo.lock index 58958f31..2761357f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,21 +43,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "allocator-api2" version = "0.2.21" @@ -163,15 +148,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "ar_archive_writer" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb93bbb63b9c227414f6eb3a0adfddca591a8ce1e9b60661bb08969b87e340b" -dependencies = [ - "object", -] - [[package]] name = "arbitrary" version = "1.4.2" @@ -181,27 +157,6 @@ dependencies = [ "derive_arbitrary", ] -[[package]] -name = "argminmax" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65" -dependencies = [ - "num-traits", -] - -[[package]] -name = "array-init-cursor" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3" - -[[package]] -name = "arrayref" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" - [[package]] name = "arrayvec" version = "0.7.6" @@ -381,40 +336,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "async-channel" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2" -dependencies = [ - "concurrent-queue", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] - [[package]] name = "async-trait" version = "0.1.89" @@ -448,15 +369,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atoi_simd" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf" -dependencies = [ - "debug_unsafe", -] - [[package]] name = "atomic-waker" version = "1.1.2" @@ -497,26 +409,6 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" -[[package]] -name = "bincode" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" -dependencies = [ - "bincode_derive", - "serde", - "unty", -] - -[[package]] -name = "bincode_derive" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" -dependencies = [ - "virtue", -] - [[package]] name = "bit-set" version = "0.8.0" @@ -543,9 +435,6 @@ name = "bitflags" version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" -dependencies = [ - "serde_core", -] [[package]] name = "bitvec" @@ -559,20 +448,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "blake3" -version = "1.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2d5991425dfd0785aed03aedcf0b321d61975c9b5b3689c774a2610ae0b51e" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq", - "cpufeatures 0.3.0", -] - [[package]] name = "block-buffer" version = "0.10.4" @@ -630,33 +505,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "boxcar" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" - -[[package]] -name = "brotli" -version = "8.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "5.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bumpalo" version = "3.20.2" @@ -702,20 +550,6 @@ name = "bytemuck" version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", -] [[package]] name = "byteorder" @@ -728,9 +562,6 @@ name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" -dependencies = [ - "serde", 
-] [[package]] name = "calloop" @@ -752,15 +583,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "castaway" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" -dependencies = [ - "rustversion", -] - [[package]] name = "cc" version = "1.2.60" @@ -811,21 +633,10 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-link", ] -[[package]] -name = "chrono-tz" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" -dependencies = [ - "chrono", - "phf 0.12.1", -] - [[package]] name = "clap" version = "4.6.1" @@ -914,21 +725,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "compact_str" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "rustversion", - "ryu", - "serde", - "static_assertions", -] - [[package]] name = "concurrent-queue" version = "2.5.0" @@ -985,12 +781,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "constant_time_eq" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d52eff69cd5e647efe296129160853a42795992097e8af39800e1060caeea9b" - [[package]] name = "core-foundation" version = "0.9.4" @@ -1080,34 +870,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-queue" version = "0.3.12" @@ -1215,12 +977,6 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" -[[package]] -name = "debug_unsafe" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eed2c4702fa172d1ce21078faa7c5203e69f5394d48cc436d25928394a867a2" - [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -1344,18 +1100,6 @@ dependencies = [ "wio", ] -[[package]] -name = "dyn-clone" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - [[package]] name = "email_address" version = "0.2.9" @@ -1381,33 +1125,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] 
-name = "ethnum" -version = "1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" - -[[package]] -name = "event-listener" -version = "5.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" -dependencies = [ - "event-listener", - "pin-project-lite", -] - [[package]] name = "fallible-iterator" version = "0.2.0" @@ -1437,12 +1154,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "fast-float2" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" - [[package]] name = "fast-srgb8" version = "1.0.0" @@ -1502,15 +1213,6 @@ dependencies = [ "zlib-rs", ] -[[package]] -name = "float-cmp" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09cf3155332e944990140d967ff5eceb70df778b34f77d8075db46e4704e6d8" -dependencies = [ - "num-traits", -] - [[package]] name = "float-ord" version = "0.3.2" @@ -1609,9 +1311,9 @@ dependencies = [ [[package]] name = "fraction" -version = "0.15.3" +version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f158e3ff0a1b334408dc9fb811cd99b446986f4d8b741bb08f9df1604085ae7" +checksum = "e076045bb43dac435333ed5f04caf35c7463631d0dae2deb2638d94dd0a5b872" dependencies = [ "lazy_static", "num", @@ -1628,16 +1330,6 @@ dependencies = [ "pkg-config", ] -[[package]] -name = "fs4" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" -dependencies = [ - "rustix 1.1.4", - "windows-sys 0.59.0", -] - [[package]] name = "fs_extra" version = "1.3.0" @@ -1650,21 +1342,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" -[[package]] -name = "futures" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.32" @@ -1681,17 +1358,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" -[[package]] -name = "futures-executor" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - [[package]] name = "futures-io" version = "0.3.32" @@ -1727,7 +1393,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ - "futures-channel", "futures-core", "futures-io", "futures-macro", 
@@ -1795,6 +1460,7 @@ version = "0.2.7" dependencies = [ "anyhow", "arrow", + "bytes", "chrono", "clap", "const_format", @@ -1803,9 +1469,8 @@ dependencies = [ "jsonschema", "odbc-api", "palette", + "parquet", "plotters", - "polars", - "polars-ops", "postgres", "proptest", "rand 0.8.6", @@ -1828,13 +1493,13 @@ name = "ggsql-jupyter" version = "0.2.7" dependencies = [ "anyhow", + "arrow", "bytes", "chrono", "clap", "ggsql", "hex", "hmac 0.12.1", - "polars", "serde", "serde_json", "sha2 0.10.9", @@ -1852,13 +1517,15 @@ dependencies = [ name = "ggsql-wasm" version = "0.2.7" dependencies = [ + "arrow", "csv", + "getrandom 0.2.17", "ggsql", "js-sys", - "polars", "serde_json", "sqlite-wasm-rs", "tokio", + "uuid", "wasm-bindgen", "wasm-bindgen-futures", ] @@ -1873,12 +1540,6 @@ dependencies = [ "weezl", ] -[[package]] -name = "glob" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" - [[package]] name = "h2" version = "0.4.13" @@ -1910,16 +1571,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "halfbrown" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7ed2f2edad8a14c8186b847909a41fbb9c3eafa44f88bd891114ed5019da09" -dependencies = [ - "hashbrown 0.16.1", - "serde", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -1941,8 +1592,6 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "allocator-api2", - "equivalent", "foldhash 0.1.5", ] @@ -1955,9 +1604,6 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash 0.2.0", - "rayon", - "serde", - "serde_core", ] [[package]] @@ -2020,15 +1666,6 @@ dependencies = [ "digest 0.11.2", ] -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "http" version = "1.4.0" @@ -2068,12 +1705,6 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" -[[package]] -name = "humantime" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" - [[package]] name = "hybrid-array" version = "0.4.10" @@ -2114,7 +1745,6 @@ dependencies = [ "hyper", "hyper-util", "rustls", - "rustls-native-certs", "tokio", "tokio-rustls", "tower-service", @@ -2303,6 +1933,12 @@ dependencies = [ "serde_core", ] +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + [[package]] name = "ipnet" version = "2.12.0" @@ -2325,15 +1961,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" -[[package]] -name = "itertools" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.18" @@ -2656,25 +2283,6 @@ version = "0.1.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" -[[package]] -name = "lz4" -version = "1.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" -dependencies = [ - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.11.1+lz4-1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "matchers" version = "0.2.0" @@ -2700,15 +2308,6 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" -[[package]] -name = "memmap2" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" -dependencies = [ - "libc", -] - [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2760,15 +2359,6 @@ dependencies = [ "jni-sys 0.3.1", ] -[[package]] -name = "now" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" -dependencies = [ - "chrono", -] - [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -3101,50 +2691,6 @@ dependencies = [ "objc2-foundation", ] -[[package]] -name = "object" -version = "0.37.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" -dependencies = [ - "memchr", -] - -[[package]] -name = "object_store" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbfbfff40aeccab00ec8a910b57ca8ecf4319b335c542f2edcd19dd25a1e2a00" -dependencies = [ - "async-trait", - "base64", - "bytes", - "chrono", - "form_urlencoded", - "futures", - "http", - "http-body-util", - "humantime", - "hyper", - "itertools", - "parking_lot", - "percent-encoding", - "quick-xml", - "rand 0.9.4", - "reqwest 0.12.28", - "ring", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror 2.0.18", - "tokio", - "tracing", - "url", - "walkdir", - "wasm-bindgen-futures", - "web-time", -] - [[package]] name = "odbc-api" version = "13.1.0" @@ -3191,14 +2737,23 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "orbclient" -version = "0.3.51" +version = "0.3.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59aed3b33578edcfa1bc96a321d590d31832b6ad55a26f0313362ce687e9abd6" +checksum = "12c6933ddbbd16539a7672e697bb8d41ac3a4e99ac43eeb40c07236bd7fcb2dd" dependencies = [ "libc", "libredox", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "outref" version = "0.5.2" @@ -3229,12 +2784,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "parking" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" - [[package]] name = "parking_lot" version = "0.12.5" @@ -3258,6 +2807,40 @@ dependencies = [ "windows-link", ] +[[package]] +name = 
"parquet" +version = "56.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0dbd48ad52d7dccf8ea1b90a3ddbfaea4f69878dd7683e51c507d4bc52b5b27" +dependencies = [ + "ahash 0.8.12", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema", + "arrow-select", + "base64", + "bytes", + "chrono", + "half", + "hashbrown 0.16.1", + "num", + "num-bigint", + "paste", + "seq-macro", + "snap", + "thrift", + "twox-hash", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "pathfinder_geometry" version = "0.5.1" @@ -3293,15 +2876,6 @@ dependencies = [ "phf_shared 0.11.3", ] -[[package]] -name = "phf" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" -dependencies = [ - "phf_shared 0.12.1", -] - [[package]] name = "phf" version = "0.13.1" @@ -3369,15 +2943,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "phf_shared" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" -dependencies = [ - "siphasher", -] - [[package]] name = "phf_shared" version = "0.13.1" @@ -3426,16 +2991,6 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" -[[package]] -name = "planus" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71" -dependencies = [ - "array-init-cursor", - "hashbrown 0.15.5", -] - [[package]] name = "plotters" version = "0.3.7" @@ -3460,581 +3015,39 @@ dependencies = [ name = "plotters-backend" version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" - -[[package]] -name = "plotters-bitmap" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ce181e3f6bf82d6c1dc569103ca7b1bd964c60ba03d7e6cdfbb3e3eb7f7405" -dependencies = [ - "gif", - "image", - "plotters-backend", -] - -[[package]] -name = "plotters-svg" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "png" -version = "0.17.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" -dependencies = [ - "bitflags 1.3.2", - "crc32fast", - "fdeflate", - "flate2", - "miniz_oxide", -] - -[[package]] -name = "polars" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bc9ea901050c1bb8747ee411bc7fbb390f3b399931e7484719512965132a248" -dependencies = [ - "getrandom 0.2.17", - "getrandom 0.3.4", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-io", - "polars-lazy", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-sql", - "polars-time", - "polars-utils", - "version_check", -] - -[[package]] -name = "polars-arrow" -version = "0.52.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d3fe43f8702cf7899ff3d516c2e5f7dc84ee6f6a3007e1a831a0ff87940704" -dependencies = [ - "atoi_simd", - "bitflags 2.11.1", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "getrandom 0.2.17", - "getrandom 0.3.4", - "hashbrown 0.16.1", - "itoa", - "lz4", - "num-traits", - "polars-arrow-format", - "polars-error", - "polars-schema", - "polars-utils", - "serde", - "simdutf8", - "streaming-iterator", - "strum_macros 0.27.2", - "version_check", - "zstd", -] - -[[package]] -name = "polars-arrow-format" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675" -dependencies = [ - "planus", - "serde", -] - -[[package]] -name = "polars-compute" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29cc7497378dee3a002f117e0b4e16b7cbe6c8ed3da16a0229c89294af7c3bf" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "either", - "fast-float2", - "hashbrown 0.16.1", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "rand 0.9.4", - "ryu", - "serde", - "strength_reduce", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48409b7440cb1a4aa84953fe3a4189dfbfb300a3298266a92a37363476641e40" -dependencies = [ - "bitflags 2.11.1", - "boxcar", - "bytemuck", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.16.1", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-row", - "polars-schema", - "polars-utils", - "rand 0.9.4", - "rand_distr", - "rayon", - "regex", - "serde", - "serde_json", - "strum_macros 0.27.2", - "uuid", - "version_check", - "xxhash-rust", -] - -[[package]] -name = "polars-dtype" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7007e9e8b7b657cbd339b65246af7e87f5756ee9a860119b9424ddffd2aaf133" -dependencies = [ - "boxcar", - "hashbrown 0.16.1", - "polars-arrow", - "polars-error", - "polars-utils", - "serde", - "uuid", -] - -[[package]] -name = "polars-error" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a6be22566c89f6405f553bfdb7c8a6cb20ec51b35f3172de9a25fa3e252d85" -dependencies = [ - "object_store", - "parking_lot", - "polars-arrow-format", - "regex", - "signal-hook", - "simdutf8", -] - -[[package]] -name = "polars-expr" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6199a50d3e1afd0674fb009e340cbfb0010682b2387187a36328c00f3f2ca87b" -dependencies = [ - "bitflags 2.11.1", - "hashbrown 0.16.1", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-io", - "polars-ops", - "polars-plan", - "polars-row", - "polars-time", - "polars-utils", - "rand 0.9.4", - "rayon", - "recursive", - "regex", - "version_check", -] - -[[package]] -name = "polars-io" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be3714acdff87170141880a07f5d9233490d3bd5531c41898f6969d440feee11" -dependencies = [ - "async-trait", - "atoi_simd", - "blake3", - "bytes", - "chrono", - "chrono-tz", - "fast-float2", - "fs4", - "futures", - "glob", - "hashbrown 0.16.1", - "home", - "itoa", - "memchr", - "memmap2", - 
"num-traits", - "object_store", - "percent-encoding", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-json", - "polars-parquet", - "polars-schema", - "polars-time", - "polars-utils", - "rayon", - "regex", - "reqwest 0.12.28", - "ryu", - "serde", - "serde_json", - "simdutf8", - "tokio", -] - -[[package]] -name = "polars-json" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dd2126daebf58da564fc5840cd55eb8eb2479d24dfced0a1aea2178a9b33b12" -dependencies = [ - "chrono", - "chrono-tz", - "fallible-streaming-iterator", - "hashbrown 0.16.1", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-error", - "polars-utils", - "ryu", - "simd-json", - "streaming-iterator", -] - -[[package]] -name = "polars-lazy" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea136c360d03aafe56e0233495e30044ce43639b8b0360a4a38e840233f048a1" -dependencies = [ - "bitflags 2.11.1", - "chrono", - "either", - "memchr", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-plan", - "polars-stream", - "polars-time", - "polars-utils", - "rayon", - "version_check", -] - -[[package]] -name = "polars-mem-engine" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f6e455ceb6e5aee7ed7d5c8944104e66992173e03a9c42f9670226318672249" -dependencies = [ - "futures", - "memmap2", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "tokio", -] - -[[package]] -name = "polars-ops" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b59c80a019ef0e6f09b4416d2647076a52839305c9eb11919e8298ec667f853" -dependencies = [ - "argminmax", - "base64", - "bytemuck", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.16.1", - "hex", - "indexmap", - "libm", - "memchr", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-schema", - "polars-utils", - "rayon", - "regex", - "regex-syntax", - "strum_macros 0.27.2", - "unicode-normalization", - "unicode-reverse", - "version_check", -] - -[[package]] -name = "polars-parquet" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93c2439d127c59e6bfc9d698419bdb45210068a6f501d44e6096429ad72c2eaa" -dependencies = [ - "async-stream", - "base64", - "brotli", - "bytemuck", - "ethnum", - "flate2", - "futures", - "hashbrown 0.16.1", - "lz4", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-error", - "polars-parquet-format", - "polars-utils", - "serde", - "simdutf8", - "snap", - "streaming-decompression", - "zstd", -] - -[[package]] -name = "polars-parquet-format" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" -dependencies = [ - "async-trait", - "futures", -] - -[[package]] -name = "polars-plan" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65b4619f5c7e9b91f18611c9ed82ebeee4b10052160825c1316ecf4dbd4d97e6" -dependencies = [ - "bitflags 2.11.1", - "bytemuck", - "bytes", - "chrono", - "chrono-tz", - "either", - "futures", - "hashbrown 0.16.1", - "memmap2", - "num-traits", - 
"percent-encoding", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-io", - "polars-ops", - "polars-parquet", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "regex", - "sha2 0.10.9", - "slotmap", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-row" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a18d232f25b83032e280a279a1f40beb8a6f8fc43907b13dc07b1c56f3b11eea" -dependencies = [ - "bitflags 2.11.1", - "bytemuck", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-utils", -] - -[[package]] -name = "polars-schema" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f73e21d429ae1c23f442b0220ccfe773a9734a44e997b5062a741842909d9441" -dependencies = [ - "indexmap", - "polars-error", - "polars-utils", - "serde", - "version_check", -] - -[[package]] -name = "polars-sql" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e67ac1cbb0c972a57af3be12f19aa9803898863fe95c33cdd39df05f5738a75" -dependencies = [ - "bitflags 2.11.1", - "hex", - "polars-core", - "polars-error", - "polars-lazy", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rand 0.9.4", - "regex", - "serde", - "sqlparser", -] +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" [[package]] -name = "polars-stream" -version = "0.52.0" +name = "plotters-bitmap" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ff19612074640a9d65e5928b7223db76ffee63e55b276f1e466d06719eb7362" +checksum = "72ce181e3f6bf82d6c1dc569103ca7b1bd964c60ba03d7e6cdfbb3e3eb7f7405" dependencies = [ - "async-channel", - "async-trait", - "atomic-waker", - "bitflags 2.11.1", - "chrono-tz", - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-queue", - "crossbeam-utils", - "futures", - "memmap2", - "parking_lot", - "percent-encoding", - "pin-project-lite", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-time", - "polars-utils", - "rand 0.9.4", - "rayon", - "recursive", - "slotmap", - "tokio", - "version_check", + "gif", + "image", + "plotters-backend", ] [[package]] -name = "polars-time" -version = "0.52.0" +name = "plotters-svg" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddce7a9f81d5f47d981bcee4a8db004f9596bb51f0f4d9d93667a1a00d88166c" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "chrono-tz", - "now", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-ops", - "polars-utils", - "rayon", - "regex", - "strum_macros 0.27.2", + "plotters-backend", ] [[package]] -name = "polars-utils" -version = "0.52.0" +name = "png" +version = "0.17.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "667c1bc2d2313f934d711f6e3b58d8d9f80351d14ea60af936a26b7dfb06e309" +checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" dependencies = [ - "bincode", - "bytemuck", - "bytes", - "compact_str", - "either", + "bitflags 1.3.2", + "crc32fast", + "fdeflate", "flate2", - "foldhash 0.2.0", - "hashbrown 0.16.1", - "indexmap", - "libc", - "memmap2", - "num-traits", - 
"polars-error", - "rand 0.9.4", - "raw-cpuid", - "rayon", - "regex", - "rmp-serde", - "serde", - "serde_json", - "serde_stacker", - "slotmap", - "stacker", - "uuid", - "version_check", + "miniz_oxide", ] [[package]] @@ -4159,16 +3172,6 @@ dependencies = [ "unarray", ] -[[package]] -name = "psm" -version = "0.1.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3852766467df634d74f0b2d7819bf8dc483a0eb2e3b0f50f756f9cfe8b0d18d8" -dependencies = [ - "ar_archive_writer", - "cc", -] - [[package]] name = "ptr_meta" version = "0.1.4" @@ -4195,16 +3198,6 @@ version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" -[[package]] -name = "quick-xml" -version = "0.38.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "quinn" version = "0.11.9" @@ -4363,16 +3356,6 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" -[[package]] -name = "rand_distr" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" -dependencies = [ - "num-traits", - "rand 0.9.4", -] - [[package]] name = "rand_xorshift" version = "0.4.0" @@ -4382,61 +3365,12 @@ dependencies = [ "rand_core 0.9.5", ] -[[package]] -name = "raw-cpuid" -version = "11.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" -dependencies = [ - "bitflags 2.11.1", -] - [[package]] name = "raw-window-handle" version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" -[[package]] -name = "rayon" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - -[[package]] -name = "recursive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" -dependencies = [ - "recursive-proc-macro-impl", - "stacker", -] - -[[package]] -name = "recursive-proc-macro-impl" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" -dependencies = [ - "quote", - "syn 2.0.117", -] - [[package]] name = "redox_syscall" version = "0.4.1" @@ -4559,7 +3493,6 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", "http", "http-body", "http-body-util", @@ -4572,7 +3505,6 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", - "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -4580,14 +3512,12 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", - "tokio-util", "tower", "tower-http", "tower-service", "url", 
"wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "webpki-roots", ] @@ -4674,25 +3604,6 @@ dependencies = [ "syn 1.0.109", ] -[[package]] -name = "rmp" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" -dependencies = [ - "num-traits", -] - -[[package]] -name = "rmp-serde" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" -dependencies = [ - "rmp", - "serde", -] - [[package]] name = "rsqlite-vfs" version = "0.1.0" @@ -4844,9 +3755,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.12" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8279bb85272c9f10811ae6a6c547ff594d6a7f3c6c6b02ee9726d1d0dcfcdd06" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", @@ -4937,6 +3848,12 @@ version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" +[[package]] +name = "seq-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" + [[package]] name = "serde" version = "1.0.228" @@ -4981,17 +3898,6 @@ dependencies = [ "zmij", ] -[[package]] -name = "serde_stacker" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4936375d50c4be7eff22293a9344f8e46f323ed2b3c243e52f89138d9bb0f4a" -dependencies = [ - "serde", - "serde_core", - "stacker", -] - [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -5041,16 +3947,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signal-hook" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" -dependencies = [ - "libc", - "signal-hook-registry", -] - [[package]] name = "signal-hook-registry" version = "1.4.8" @@ -5067,22 +3963,6 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" -[[package]] -name = "simd-json" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4255126f310d2ba20048db6321c81ab376f6a6735608bf11f0785c41f01f64e3" -dependencies = [ - "ahash 0.8.12", - "halfbrown", - "once_cell", - "ref-cast", - "serde", - "serde_json", - "simdutf8", - "value-trait", -] - [[package]] name = "simd_cesu8" version = "1.1.1" @@ -5111,15 +3991,6 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" -[[package]] -name = "slotmap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038" -dependencies = [ - "version_check", -] - [[package]] name = "smallvec" version = "1.15.1" @@ -5162,9 +4033,9 @@ dependencies = [ [[package]] name = "sqlite-wasm-rs" -version = "0.5.2" 
+version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f4206ed3a67690b9c29b77d728f6acc3ce78f16bf846d83c94f76400320181b" +checksum = "1b2c760607300407ddeaee518acf28c795661b7108c75421303dbefb237d3a36" dependencies = [ "cc", "js-sys", @@ -5172,61 +4043,18 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "sqlparser" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" -dependencies = [ - "log", -] - [[package]] name = "stable_deref_trait" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" -[[package]] -name = "stacker" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d74a23609d509411d10e2176dc2a4346e3b4aea2e7b1869f19fdedbc71c013" -dependencies = [ - "cc", - "cfg-if", - "libc", - "psm", - "windows-sys 0.59.0", -] - -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - -[[package]] -name = "streaming-decompression" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" -dependencies = [ - "fallible-streaming-iterator", -] - [[package]] name = "streaming-iterator" version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b2231b7c3057d5e4ad0156fb3dc807d900806020c5ffa3ee6ff2c8c76fb8520" -[[package]] -name = "strength_reduce" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" - [[package]] name = "stringprep" version = "0.1.5" @@ -5411,6 +4239,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -5581,7 +4420,7 @@ dependencies = [ "indexmap", "toml_datetime 1.1.1+spec-1.1.0", "toml_parser", - "winnow 1.0.1", + "winnow 1.0.2", ] [[package]] @@ -5590,7 +4429,7 @@ version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow 1.0.1", + "winnow 1.0.2", ] [[package]] @@ -5745,11 +4584,17 @@ version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17f77d76d837a7830fe1d4f12b7b4ba4192c1888001c7164257e4bc6d21d96b4" +[[package]] +name = "twox-hash" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ea3136b675547379c4bd395ca6b938e5ad3c3d20fad76e7fe85f9e0d011419c" + [[package]] name = "typenum" -version = "1.19.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" [[package]] name = "unarray" @@ -5799,15 +4644,6 @@ version = "0.1.4" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" -[[package]] -name = "unicode-reverse" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "unicode-segmentation" version = "1.13.2" @@ -5832,12 +4668,6 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" -[[package]] -name = "unty" -version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" - [[package]] name = "ureq" version = "3.3.0" @@ -5905,7 +4735,6 @@ checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" dependencies = [ "getrandom 0.4.2", "js-sys", - "serde_core", "wasm-bindgen", ] @@ -5925,18 +4754,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" -[[package]] -name = "value-trait" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e80f0c733af0720a501b3905d22e2f97662d8eacfe082a75ed7ffb5ab08cb59" -dependencies = [ - "float-cmp", - "halfbrown", - "itoa", - "ryu", -] - [[package]] name = "vcpkg" version = "0.2.15" @@ -5949,12 +4766,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "virtue" -version = "0.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" - [[package]] name = "vsimd" version = "0.8.0" @@ -6006,11 +4817,11 @@ dependencies = [ [[package]] name = "wasip2" -version = "1.0.2+wasi-0.2.9" +version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] @@ -6019,7 +4830,7 @@ version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.51.0", ] [[package]] @@ -6109,19 +4920,6 @@ dependencies = [ "wasmparser", ] -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "wasmparser" version = "0.244.0" @@ -6569,9 +5367,9 @@ dependencies = [ [[package]] name = "winnow" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" dependencies = [ "memchr", ] @@ -6594,6 +5392,12 @@ dependencies = [ "wit-bindgen-rust-macro", ] +[[package]] +name = 
"wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + [[package]] name = "wit-bindgen-core" version = "0.51.0" @@ -6717,12 +5521,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56" -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - [[package]] name = "yeslogic-fontconfig-sys" version = "6.0.0" @@ -6901,31 +5699,3 @@ dependencies = [ "log", "simd-adler32", ] - -[[package]] -name = "zstd" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.16+zstd.1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" -dependencies = [ - "cc", - "pkg-config", -] diff --git a/Cargo.toml b/Cargo.toml index 4577ac13..55e9f6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,13 +30,13 @@ ggsql = { path = "src", version = "0.2.7" } csscolorparser = "0.8.1" tree-sitter = "0.26" -# Data processing -polars = { version = "0.52", default-features = false } -polars-ops = { version = "0.52", features = ["pivot"] } +# Data container +arrow = { version = "56", default-features = false, features = ["ipc"] } # Readers duckdb = { version = "~1.4", features = ["bundled", "vtab-arrow"] } -arrow = { version = "56", default-features = false, features = ["ipc"] } +parquet = { version = "56", default-features = false, features = ["arrow", "snap"] } +bytes = "1" postgres = "0.19" rusqlite = { version = "0.38", features = ["bundled", "chrono", "functions", "window"] } diff --git a/ggsql-jupyter/Cargo.toml b/ggsql-jupyter/Cargo.toml index 1bcc08d4..9fb5f4b2 100644 --- a/ggsql-jupyter/Cargo.toml +++ b/ggsql-jupyter/Cargo.toml @@ -20,8 +20,8 @@ path = "src/lib.rs" # Core ggsql library ggsql = { workspace = true, features = ["duckdb", "vegalite"] } -# Need polars for DataFrame type -polars = { workspace = true } +# Arrow for DataFrame array types +arrow = { workspace = true } # Async runtime tokio = { workspace = true } diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs index c55385ae..dca5cb6b 100644 --- a/ggsql-jupyter/src/connection.rs +++ b/ggsql-jupyter/src/connection.rs @@ -69,13 +69,13 @@ fn list_catalogs(reader: &dyn Reader) -> Result, String> { let mut catalogs = Vec::new(); for i in 0..df.height() { - if let Ok(val) = col.get(i) { - let name = val.to_string().trim_matches('"').to_string(); - catalogs.push(ObjectSchema { - name, - kind: "catalog".to_string(), - }); - } + let name = ggsql::array_util::value_to_string(col, i) + .trim_matches('"') + .to_string(); + catalogs.push(ObjectSchema { + name, + kind: "catalog".to_string(), + }); } Ok(catalogs) } @@ -91,13 +91,13 @@ fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result, let mut schemas = Vec::new(); for i in 0..df.height() { - if let 
diff --git a/ggsql-jupyter/src/connection.rs b/ggsql-jupyter/src/connection.rs
index c55385ae..dca5cb6b 100644
--- a/ggsql-jupyter/src/connection.rs
+++ b/ggsql-jupyter/src/connection.rs
@@ -69,13 +69,13 @@ fn list_catalogs(reader: &dyn Reader) -> Result<Vec<ObjectSchema>, String> {
     let mut catalogs = Vec::new();
     for i in 0..df.height() {
-        if let Ok(val) = col.get(i) {
-            let name = val.to_string().trim_matches('"').to_string();
-            catalogs.push(ObjectSchema {
-                name,
-                kind: "catalog".to_string(),
-            });
-        }
+        let name = ggsql::array_util::value_to_string(col, i)
+            .trim_matches('"')
+            .to_string();
+        catalogs.push(ObjectSchema {
+            name,
+            kind: "catalog".to_string(),
+        });
     }
     Ok(catalogs)
 }
@@ -91,13 +91,13 @@ fn list_schemas(reader: &dyn Reader, catalog: &str) -> Result<Vec<ObjectSchema>, String> {
     let mut schemas = Vec::new();
     for i in 0..df.height() {
-        if let Ok(val) = col.get(i) {
-            let name = val.to_string().trim_matches('"').to_string();
-            schemas.push(ObjectSchema {
-                name,
-                kind: "schema".to_string(),
-            });
-        }
+        let name = ggsql::array_util::value_to_string(col, i)
+            .trim_matches('"')
+            .to_string();
+        schemas.push(ObjectSchema {
+            name,
+            kind: "schema".to_string(),
+        });
     }
     Ok(schemas)
 }
@@ -119,24 +119,26 @@ fn list_tables(
     let mut objects = Vec::new();
     for i in 0..df.height() {
-        if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
-            let name = name_val.to_string().trim_matches('"').to_string();
-            let table_type = type_val.to_string().trim_matches('"').to_uppercase();
-            let kind = if table_type.contains("VIEW") {
-                "view"
-            } else if table_type == "TABLE"
-                || table_type == "BASE TABLE"
-                || table_type.contains("TABLE")
-            {
-                "table"
-            } else {
-                continue; // Skip non-table/view objects (stages, procedures, etc.)
-            };
-            objects.push(ObjectSchema {
-                name,
-                kind: kind.to_string(),
-            });
-        }
+        let name = ggsql::array_util::value_to_string(name_col, i)
+            .trim_matches('"')
+            .to_string();
+        let table_type = ggsql::array_util::value_to_string(type_col, i)
+            .trim_matches('"')
+            .to_uppercase();
+        let kind = if table_type.contains("VIEW") {
+            "view"
+        } else if table_type == "TABLE"
+            || table_type == "BASE TABLE"
+            || table_type.contains("TABLE")
+        {
+            "table"
+        } else {
+            continue; // Skip non-table/view objects (stages, procedures, etc.)
+        };
+        objects.push(ObjectSchema {
+            name,
+            kind: kind.to_string(),
+        });
     }
     Ok(objects)
 }
@@ -159,11 +161,13 @@ fn list_columns(
     let mut fields = Vec::new();
     for i in 0..df.height() {
-        if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
-            let name = name_val.to_string().trim_matches('"').to_string();
-            let dtype = type_val.to_string().trim_matches('"').to_string();
-            fields.push(FieldSchema { name, dtype });
-        }
+        let name = ggsql::array_util::value_to_string(name_col, i)
+            .trim_matches('"')
+            .to_string();
+        let dtype = ggsql::array_util::value_to_string(type_col, i)
+            .trim_matches('"')
+            .to_string();
+        fields.push(FieldSchema { name, dtype });
     }
     Ok(fields)
 }
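Every rewritten call site in `connection.rs` funnels through `ggsql::array_util::value_to_string`, whose implementation is not part of this diff. One plausible shape for such a helper, assuming Arrow's `array_value_to_string` formatter (a sketch, not the repo's actual code):

```rust
use arrow::array::{Array, ArrayRef};
use arrow::util::display::array_value_to_string;

// Render one cell of an Arrow array as a string; nulls become "null".
fn value_to_string(col: &ArrayRef, row: usize) -> String {
    if col.is_null(row) {
        return "null".to_string();
    }
    array_value_to_string(col.as_ref(), row).unwrap_or_default()
}
```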
diff --git a/ggsql-jupyter/src/data_explorer.rs b/ggsql-jupyter/src/data_explorer.rs
index 306a729d..94f7e2fe 100644
--- a/ggsql-jupyter/src/data_explorer.rs
+++ b/ggsql-jupyter/src/data_explorer.rs
@@ -85,11 +85,13 @@ impl DataExplorerState {
         let num_rows = count_df
             .column("n")
             .ok()
-            .and_then(|col| col.get(0).ok())
-            .and_then(|val| {
-                // Polars AnyValue — try common integer representations
-                let s = format!("{}", val);
-                s.parse::<i64>().ok()
+            .and_then(|col| {
+                if col.is_empty() {
+                    None
+                } else {
+                    let s = ggsql::array_util::value_to_string(col, 0);
+                    s.parse::<i64>().ok()
+                }
             })
             .unwrap_or(0);
 
@@ -106,17 +108,19 @@
         let mut columns = Vec::new();
         for i in 0..columns_df.height() {
-            if let (Ok(name_val), Ok(type_val)) = (name_col.get(i), type_col.get(i)) {
-                let name = name_val.to_string().trim_matches('"').to_string();
-                let raw_type = type_val.to_string().trim_matches('"').to_string();
-                let type_display = sql_type_to_display(&raw_type).to_string();
-                let type_name = clean_type_name(&raw_type);
-                columns.push(ColumnInfo {
-                    name,
-                    type_name,
-                    type_display,
-                });
-            }
+            let name = ggsql::array_util::value_to_string(name_col, i)
+                .trim_matches('"')
+                .to_string();
+            let raw_type = ggsql::array_util::value_to_string(type_col, i)
+                .trim_matches('"')
+                .to_string();
+            let type_display = sql_type_to_display(&raw_type).to_string();
+            let type_name = clean_type_name(&raw_type);
+            columns.push(ColumnInfo {
+                name,
+                type_name,
+                type_display,
+            });
         }
 
         Ok(Self {
@@ -307,20 +311,16 @@
         let columns: Vec<Vec<Value>> = (0..df.width())
             .map(|col_idx| {
                 let col = df.get_columns()[col_idx].clone();
+                use arrow::array::Array;
                 (0..df.height())
                     .map(|row_idx| {
-                        match col.get(row_idx) {
-                            Ok(val) => {
-                                if val.is_null() {
-                                    json!(SPECIAL_VALUE_NULL)
-                                } else {
-                                    let s = format!("{}", val);
-                                    // Strip surrounding quotes from string values
-                                    let s = s.trim_matches('"');
-                                    Value::String(s.to_string())
-                                }
-                            }
-                            Err(_) => json!(SPECIAL_VALUE_NULL),
+                        if col.is_null(row_idx) {
+                            json!(SPECIAL_VALUE_NULL)
+                        } else {
+                            let s = ggsql::array_util::value_to_string(&col, row_idx);
+                            // Strip surrounding quotes from string values
+                            let s = s.trim_matches('"');
+                            Value::String(s.to_string())
                         }
                     })
                     .collect()
             })
@@ -495,16 +495,18 @@
         };
 
         let get_str = |name: &str| -> Option<String> {
-            df.column(name)
-                .ok()
-                .and_then(|c| c.get(0).ok())
-                .and_then(|v| {
-                    if v.is_null() {
-                        None
-                    } else {
-                        Some(format!("{}", v).trim_matches('"').to_string())
-                    }
-                })
+            use arrow::array::Array;
+            df.column(name).ok().and_then(|c| {
+                if c.is_empty() || c.is_null(0) {
+                    None
+                } else {
+                    Some(
+                        ggsql::array_util::value_to_string(c, 0)
+                            .trim_matches('"')
+                            .to_string(),
+                    )
+                }
+            })
         };
 
         let get_i64 =
@@ -551,18 +553,18 @@
             let median_expr = dialect.sql_percentile(&col_name, 0.5, &from_query, &[]);
             let median_sql = format!("SELECT {} AS \"median_val\"", median_expr);
             if let Ok(median_df) = reader.execute_sql(&median_sql) {
-                if let Some(v) = median_df
-                    .column("median_val")
-                    .ok()
-                    .and_then(|c| c.get(0).ok())
-                    .and_then(|v| {
-                        if v.is_null() {
-                            None
-                        } else {
-                            Some(format!("{}", v).trim_matches('"').to_string())
-                        }
-                    })
-                {
+                use arrow::array::Array;
+                if let Some(v) = median_df.column("median_val").ok().and_then(|c| {
+                    if c.is_empty() || c.is_null(0) {
+                        None
+                    } else {
+                        Some(
+                            ggsql::array_util::value_to_string(c, 0)
+                                .trim_matches('"')
+                                .to_string(),
+                        )
+                    }
+                }) {
                     number_stats["median"] = json!(v);
                 }
             }
@@ -682,17 +684,17 @@
         let bounds_df = reader.execute_sql(&bounds_sql).ok()?;
 
         let get_f64 = |name: &str| -> Option<f64> {
-            bounds_df
-                .column(name)
-                .ok()
-                .and_then(|c| c.get(0).ok())
-                .and_then(|v| {
-                    if v.is_null() {
-                        None
-                    } else {
-                        format!("{}", v).trim_matches('"').parse::<f64>().ok()
-                    }
-                })
+            use arrow::array::Array;
+            bounds_df.column(name).ok().and_then(|c| {
+                if c.is_empty() || c.is_null(0) {
+                    None
+                } else {
+                    ggsql::array_util::value_to_string(c, 0)
+                        .trim_matches('"')
+                        .parse::<f64>()
+                        .ok()
+                }
+            })
         };
 
         let min_val = get_f64("min_val")?;
@@ -760,15 +762,13 @@
         let bin_col = hist_df.column("clamped_bin").ok()?;
         let cnt_col = hist_df.column("cnt").ok()?;
         for i in 0..hist_df.height() {
-            if let (Ok(bin_val), Ok(cnt_val)) = (bin_col.get(i), cnt_col.get(i)) {
-                let bin_str = format!("{}", bin_val);
-                // Parse bin index — may be float (e.g., "3.0") on some backends
-                if let Ok(bin_idx) = bin_str.parse::<f64>() {
-                    let idx = bin_idx as usize;
-                    if idx < num_bins {
-                        let count_str = format!("{}", cnt_val);
-                        bin_counts[idx] = count_str.parse::<i64>().unwrap_or(0);
-                    }
+            let bin_str = ggsql::array_util::value_to_string(bin_col, i);
+            // Parse bin index — may be float (e.g., "3.0") on some backends
+            if let Ok(bin_idx) = bin_str.parse::<f64>() {
+                let idx = bin_idx as usize;
+                if idx < num_bins {
+                    let count_str = ggsql::array_util::value_to_string(cnt_col, i);
bin_counts[idx] = count_str.parse::().unwrap_or(0); } } } @@ -788,18 +788,18 @@ impl DataExplorerState { let expr = dialect.sql_percentile(&col_name, q_val, &from_query, &[]); let q_sql = format!("SELECT {} AS \"q_val\"", expr); if let Ok(q_df) = reader.execute_sql(&q_sql) { - if let Some(v) = q_df - .column("q_val") - .ok() - .and_then(|c| c.get(0).ok()) - .and_then(|v| { - if v.is_null() { - None - } else { - Some(format!("{}", v).trim_matches('"').to_string()) - } - }) - { + use arrow::array::Array; + if let Some(v) = q_df.column("q_val").ok().and_then(|c| { + if c.is_empty() || c.is_null(0) { + None + } else { + Some( + ggsql::array_util::value_to_string(c, 0) + .trim_matches('"') + .to_string(), + ) + } + }) { quantile_results.push(json!({"q": q_val, "value": v})); } } @@ -846,13 +846,15 @@ impl DataExplorerState { let mut top_total: i64 = 0; for i in 0..df.height() { - if let (Ok(v), Ok(c)) = (val_col.get(i), cnt_col.get(i)) { - let val_str = format!("{}", v).trim_matches('"').to_string(); - let count: i64 = format!("{}", c).parse().unwrap_or(0); - values.push(Value::String(val_str)); - counts.push(count); - top_total += count; - } + let val_str = ggsql::array_util::value_to_string(val_col, i) + .trim_matches('"') + .to_string(); + let count: i64 = ggsql::array_util::value_to_string(cnt_col, i) + .parse() + .unwrap_or(0); + values.push(Value::String(val_str)); + counts.push(count); + top_total += count; } // Compute other_count: total non-null rows minus the top-K sum @@ -865,10 +867,14 @@ impl DataExplorerState { .execute_sql(&count_sql) .ok() .and_then(|df| { - df.column("total") - .ok() - .and_then(|c| c.get(0).ok()) - .and_then(|v| format!("{}", v).parse::().ok()) + use arrow::array::Array; + df.column("total").ok().and_then(|c| { + if c.is_empty() || c.is_null(0) { + None + } else { + ggsql::array_util::value_to_string(c, 0).parse::().ok() + } + }) }) .map(|total| total - top_total) .unwrap_or(0); diff --git a/ggsql-jupyter/src/display.rs b/ggsql-jupyter/src/display.rs index 8de08818..ced44430 100644 --- a/ggsql-jupyter/src/display.rs +++ b/ggsql-jupyter/src/display.rs @@ -4,7 +4,7 @@ //! with appropriate MIME types for rich rendering. 
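+//!
+//! A sketch of the table bundle this module emits (keys grounded in
+//! `format_dataframe` below; values illustrative):
+//!
+//! ```ignore
+//! json!({ "data": { "text/html": html, "text/plain": text } })
+//! ```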
 use crate::executor::ExecutionResult;
-use polars::frame::DataFrame;
+use ggsql::DataFrame;
 use serde_json::{json, Value};
 
 /// Format execution result as Jupyter display_data content
@@ -170,7 +170,7 @@ console.error('Failed to load Vega libraries:', err);
 /// Format DataFrame as HTML table
 fn format_dataframe(df: DataFrame) -> Value {
     let html = dataframe_to_html(&df);
-    let text = format!("{}", df);
+    let text = dataframe_to_text(&df);
 
     json!({
         "data": {
@@ -184,11 +184,13 @@ fn format_dataframe(df: DataFrame) -> Value {
 
 /// Convert DataFrame to HTML table
 fn dataframe_to_html(df: &DataFrame) -> String {
+    use ggsql::array_util::value_to_string;
+
     let mut html = String::from("<table>\n<tr>");
 
     // Header row
     for col in df.get_column_names() {
-        html.push_str(&format!("<th>{}</th>", escape_html(col)));
+        html.push_str(&format!("<th>{}</th>", escape_html(&col)));
     }
     html.push_str("</tr>\n<tbody>\n");
 
@@ -197,10 +199,8 @@ fn dataframe_to_html(df: &DataFrame) -> String {
     for i in 0..row_limit {
         html.push_str("<tr>");
         for col in df.get_columns() {
-            let value = col
-                .get(i)
-                .unwrap_or_else(|_| polars::prelude::AnyValue::Null);
-            html.push_str(&format!("<td>{}</td>", escape_html(&value.to_string())));
+            let value = value_to_string(col, i);
+            html.push_str(&format!("<td>{}</td>", escape_html(&value)));
         }
         html.push_str("</tr>\n");
     }
@@ -217,6 +217,27 @@ fn dataframe_to_html(df: &DataFrame) -> String {
     html
 }
 
+/// Convert DataFrame to plain-text summary (shape + column names + first rows).
+fn dataframe_to_text(df: &ggsql::DataFrame) -> String {
+    use ggsql::array_util::value_to_string;
+
+    let mut s = format!("shape: ({}, {})\n", df.height(), df.width());
+    let names = df.get_column_names();
+    s.push_str(&names.join("\t"));
+    s.push('\n');
+    let row_limit = df.height().min(10);
+    for i in 0..row_limit {
+        let row: Vec<String> = df
+            .get_columns()
+            .iter()
+            .map(|c| value_to_string(c, i))
+            .collect();
+        s.push_str(&row.join("\t"));
+        s.push('\n');
+    }
+    s
+}
+
 /// Escape HTML special characters
 fn escape_html(s: &str) -> String {
     s.replace('&', "&amp;")
@@ -242,10 +263,8 @@ mod tests {
 
     #[test]
     fn test_empty_dataframe_returns_none() {
-        use polars::prelude::*;
-
         // DDL statements return DataFrames with 0 columns
-        let df = DataFrame::new(Vec::<Column>::new()).unwrap();
+        let df = DataFrame::empty();
        let result = ExecutionResult::DataFrame(df);
 
         let display = format_display_data(result);
 
     #[test]
     fn test_empty_rows_dataframe_returns_some() {
-        use polars::prelude::*;
+        use arrow::array::{ArrayRef, Int32Array};
+        use std::sync::Arc;
 
         // SELECT with 0 rows but columns should still display
-        let df = DataFrame::new(vec![Column::new("x".into(), Vec::<i32>::new())]).unwrap();
+        let empty: ArrayRef = Arc::new(Int32Array::from(Vec::<i32>::new()));
+        let df = DataFrame::new(vec![("x", empty)]).unwrap();
         let result = ExecutionResult::DataFrame(df);
 
         let display = format_display_data(result);
diff --git a/ggsql-jupyter/src/executor.rs b/ggsql-jupyter/src/executor.rs
index 434345c3..435c7295 100644
--- a/ggsql-jupyter/src/executor.rs
+++ b/ggsql-jupyter/src/executor.rs
@@ -9,8 +9,8 @@ use ggsql::{
     reader::{connection::parse_connection_string, DuckDBReader, Reader},
     validate::validate,
     writer::{VegaLiteWriter, Writer},
+    DataFrame,
 };
-use polars::frame::DataFrame;
 
 /// Result of executing a ggsql query
 #[derive(Debug)]
diff --git a/ggsql-jupyter/src/util.rs b/ggsql-jupyter/src/util.rs
index 1e1bf152..b722677b 100644
--- a/ggsql-jupyter/src/util.rs
+++ b/ggsql-jupyter/src/util.rs
@@ -1,9 +1,10 @@
-use polars::prelude::{Column, DataFrame};
+use arrow::array::ArrayRef;
+use ggsql::DataFrame;
 
 /// Find a DataFrame column by name, trying multiple names and falling back to
 /// case-insensitive matching. This handles ODBC drivers that return uppercase
 /// column names (e.g. `TABLE_NAME` instead of `table_name`).
-pub fn find_column<'a>(df: &'a DataFrame, names: &[&str]) -> Result<&'a Column, String> {
+pub fn find_column<'a>(df: &'a DataFrame, names: &[&str]) -> Result<&'a ArrayRef, String> {
     // Try exact match first
     for name in names {
         if let Ok(col) = df.column(name) {
diff --git a/ggsql-wasm/Cargo.toml b/ggsql-wasm/Cargo.toml
index e4348e23..79a6543e 100644
--- a/ggsql-wasm/Cargo.toml
+++ b/ggsql-wasm/Cargo.toml
@@ -16,7 +16,7 @@ wasm-bindgen = "0.2"
 wasm-bindgen-futures = "0.4"
 js-sys = "0.3"
 csv = "1"
-polars = { version = "0.52", default-features = false, features = ["dtype-full"] }
+arrow = { workspace = true }
 ggsql = { path = "../src", default-features = false, features = ["vegalite", "sqlite", "builtin-data"] }
 serde_json = "1"
 
@@ -26,4 +26,10 @@ tokio = { version = "1.35", features = ["full"] }
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 tokio = { version = "1.35", default-features = false }
 sqlite-wasm-rs = "0.5.2"
+# Transitive dep feature overrides for wasm32-unknown-unknown.
+# Cargo's feature unification activates these on the transitive deps.
+# - getrandom: pulled in by arrow (via ahash/const-random), needs "js" for wasm
+# - uuid: pulled in by ggsql, needs "js" to obtain randomness on wasm
+getrandom = { version = "0.2", features = ["js"] }
+uuid = { workspace = true, features = ["js"] }
diff --git a/ggsql-wasm/src/lib.rs b/ggsql-wasm/src/lib.rs
index d007a6e3..8fa1b366 100644
--- a/ggsql-wasm/src/lib.rs
+++ b/ggsql-wasm/src/lib.rs
@@ -1,13 +1,17 @@
+use arrow::array::{
+    ArrayRef, BooleanArray, Date32Array, Float64Array, Int64Array, StringArray,
+    TimestampMillisecondArray,
+};
+use ggsql::array_util::value_to_string;
 use ggsql::naming::DATA_PREFIX;
 use ggsql::reader::sqlite::SqliteReader;
 use ggsql::reader::Reader;
 use ggsql::validate::validate;
 use ggsql::writer::{VegaLiteWriter, Writer};
 use ggsql::DataFrame;
-use polars::prelude::IntoColumn;
-use polars::prelude::*;
 use serde_json::json;
 use std::cell::RefCell;
+use std::sync::Arc;
 
 use wasm_bindgen::prelude::*;
 
@@ -46,7 +50,7 @@ fn ensure_vfs_initialized() {
 // Column descriptor → DataFrame conversion (for JS CSV/Parquet parsing)
 // ============================================================================
 
-/// Convert JS column descriptors to a Polars DataFrame.
+/// Convert JS column descriptors to an Arrow-backed DataFrame.
 fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
     let columns = js_sys::Array::from(&columns_js);
     let len = columns.length();
@@ -55,7 +59,10 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
         return Ok(DataFrame::empty());
     }
 
-    let mut series_vec: Vec<Column> = Vec::with_capacity(len as usize);
+    // Collect owned (name, array) pairs; DataFrame::new borrows the names so
+    // we build a parallel Vec to pin them for the lifetime of the call.
+    let mut names: Vec<String> = Vec::with_capacity(len as usize);
+    let mut arrays: Vec<ArrayRef> = Vec::with_capacity(len as usize);
 
     for i in 0..len {
         let col = columns.get(i);
@@ -74,7 +81,7 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
 
         let nulls = js_sys::Uint8Array::new(&nulls_js).to_vec();
 
-        let series = match col_type.as_str() {
+        let array: ArrayRef = match col_type.as_str() {
             "f64" => {
                 let raw = js_sys::Float64Array::new(&values_js).to_vec();
                 let values: Vec<Option<f64>> = raw
                     .into_iter()
                     .zip(nulls.iter())
                     .map(|(v, &n)| if n != 0 { Some(v) } else { None })
                     .collect();
-                Series::new(col_name.as_str().into(), values)
+                Arc::new(Float64Array::from(values))
             }
             "i64" => {
                 let raw = js_sys::Float64Array::new(&values_js).to_vec();
@@ -91,7 +98,7 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
                     .zip(nulls.iter())
                     .map(|(v, &n)| if n != 0 { Some(v as i64) } else { None })
                     .collect();
-                Series::new(col_name.as_str().into(), values)
+                Arc::new(Int64Array::from(values))
             }
             "bool" => {
                 let raw = js_sys::Uint8Array::new(&values_js).to_vec();
@@ -100,7 +107,7 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
                     .zip(nulls.iter())
                     .map(|(v, &n)| if n != 0 { Some(v != 0) } else { None })
                     .collect();
-                Series::new(col_name.as_str().into(), values)
+                Arc::new(BooleanArray::from(values))
             }
             "string" => {
                 let arr = js_sys::Array::from(&values_js);
@@ -108,32 +115,27 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
                     .zip(nulls.iter())
                     .map(|(j, &n)| if n != 0 { arr.get(j).as_string() } else { None })
                     .collect();
-                Series::new(col_name.as_str().into(), values)
+                Arc::new(StringArray::from(values))
             }
             "date" => {
+                // Date32: days since Unix epoch
                 let raw = js_sys::Float64Array::new(&values_js).to_vec();
                 let values: Vec<Option<i32>> = raw
                     .into_iter()
                     .zip(nulls.iter())
                     .map(|(v, &n)| if n != 0 { Some(v as i32) } else { None })
                     .collect();
-                let s = Series::new(col_name.as_str().into(), values);
-                s.cast(&DataType::Date).map_err(|e| {
-                    JsValue::from_str(&format!("Date cast error for '{}': {}", col_name, e))
-                })?
+                Arc::new(Date32Array::from(values))
             }
             "datetime" => {
+                // Timestamp(Millisecond): milliseconds since Unix epoch
                 let raw = js_sys::Float64Array::new(&values_js).to_vec();
                 let values: Vec<Option<i64>> = raw
                     .into_iter()
                     .zip(nulls.iter())
                     .map(|(v, &n)| if n != 0 { Some(v as i64) } else { None })
                     .collect();
-                let s = Series::new(col_name.as_str().into(), values);
-                s.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
-                    .map_err(|e| {
-                        JsValue::from_str(&format!("Datetime cast error for '{}': {}", col_name, e))
-                    })?
+                Arc::new(TimestampMillisecondArray::from(values))
             }
             other => {
                 return Err(JsValue::from_str(&format!(
@@ -143,10 +145,17 @@ fn columns_js_to_dataframe(columns_js: JsValue) -> Result<DataFrame, JsValue> {
             }
         };
 
-        series_vec.push(series.into_column());
+        names.push(col_name);
+        arrays.push(array);
     }
 
-    DataFrame::new(series_vec)
+    let named: Vec<(&str, ArrayRef)> = names
+        .iter()
+        .zip(arrays)
+        .map(|(n, a)| (n.as_str(), a))
+        .collect();
+
+    DataFrame::new(named)
         .map_err(|e| JsValue::from_str(&format!("DataFrame creation error: {}", e)))
 }
 
@@ -217,26 +226,15 @@ impl GgsqlContext {
         let max_rows = 100usize;
         let total_rows = df.height();
         let truncated = total_rows > max_rows;
-        let df = if truncated {
-            df.head(Some(max_rows))
-        } else {
-            df
-        };
+        let df = if truncated { df.slice(0, max_rows) } else { df };
 
-        let columns: Vec<String> = df
-            .get_column_names()
-            .into_iter()
-            .map(|s| s.to_string())
-            .collect();
+        let columns: Vec<String> = df.get_column_names();
 
         let mut rows: Vec<Vec<String>> = Vec::with_capacity(df.height());
         for i in 0..df.height() {
             let mut row = Vec::with_capacity(columns.len());
             for col in df.get_columns() {
-                let val = col
-                    .get(i)
-                    .map_err(|e| JsValue::from_str(&format!("Error reading row {}: {}", i, e)))?;
-                row.push(format!("{}", val));
+                row.push(value_to_string(col, i));
             }
             rows.push(row);
         }
diff --git a/src/CHANGELOG.md b/src/CHANGELOG.md
new file mode 100644
index 00000000..57d3aa17
--- /dev/null
+++ b/src/CHANGELOG.md
@@ -0,0 +1,15 @@
+## [Unreleased]
+
+### Added
+
+- ODBC is now enabled for the CLI as well (#344)
+
+### Removed
+
+- Removed polars from the dependency list along with all of its transitive dependencies. Rewrote the DataFrame struct on top of arrow (#350)
+- Moved ggsql-python to its own repo (posit-dev/ggsql-python) and cleaned up any additional references to it
+- Moved ggsql-r to its own repo (posit-dev/ggsql-r)
+
+## [2.7.0] - 2026-04-20
+
+- First alpha release. No changes tracked before this release.
diff --git a/src/Cargo.toml b/src/Cargo.toml
index 97ffaaa5..239e82df 100644
--- a/src/Cargo.toml
+++ b/src/Cargo.toml
@@ -25,22 +25,17 @@ csscolorparser.workspace = true
 # Color interpolation
 palette.workspace = true
 
-# Data processing
-polars = { workspace = true, features = [
-    "lazy",
-    "cum_agg",
-    "dtype-full",
-    "timezones",
-] }
-polars-ops.workspace = true
+# Arrow (core data container)
+arrow = { workspace = true }
 
 # Readers
 duckdb = { workspace = true, optional = true }
-arrow = { workspace = true, optional = true }
 postgres = { workspace = true, optional = true }
 rusqlite = { workspace = true, optional = true }
 odbc-api = { workspace = true, optional = true }
 toml_edit = { workspace = true, optional = true }
+parquet = { workspace = true, optional = true }
+bytes = { workspace = true }
 
 # Writers
 plotters = { workspace = true, optional = true }
@@ -70,9 +65,9 @@ ureq = "3"
 
 [features]
 default = ["duckdb", "sqlite", "vegalite", "ipc", "parquet", "builtin-data", "odbc"]
-ipc = ["polars/ipc"]
-duckdb = ["dep:duckdb", "dep:arrow"]
-parquet = ["polars/parquet"]
+ipc = []
+duckdb = ["dep:duckdb"]
+parquet = ["dep:parquet"]
 postgres = ["dep:postgres"]
 sqlite = ["dep:rusqlite"]
 odbc = ["dep:odbc-api", "dep:toml_edit"]
diff --git a/src/array_util.rs b/src/array_util.rs
new file mode 100644
index 00000000..5596039c
--- /dev/null
+++ b/src/array_util.rs
@@ -0,0 +1,298 @@
+//! Typed array access helpers for Arrow arrays.
+//!
+//! Replaces the polars pattern of `series.f64()`, `series.str()`, etc.
+//! with arrow downcasting via `as_f64(array)`, `as_str(array)`, etc.
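+//!
+//! A minimal usage sketch (assumes `df` exposes an Arrow-backed Float64
+//! column named "price"; names illustrative):
+//!
+//! ```ignore
+//! use crate::array_util::as_f64;
+//!
+//! let col = df.column("price")?;   // &ArrayRef
+//! let prices = as_f64(col)?;       // &Float64Array, errors on a type mismatch
+//! let first = prices.value(0);
+//! ```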
+
+use arrow::array::{
+    Array, ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, Int16Array, Int32Array,
+    Int64Array, Int8Array, LargeStringArray, StringArray, Time64NanosecondArray,
+    TimestampMicrosecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+};
+use arrow::compute;
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+use crate::{GgsqlError, Result};
+
+// ============================================================================
+// Downcast helpers
+// ============================================================================
+
+macro_rules! downcast_fn {
+    ($fn_name:ident, $arrow_type:ty, $type_name:expr) => {
+        pub fn $fn_name(array: &ArrayRef) -> Result<&$arrow_type> {
+            array.as_any().downcast_ref::<$arrow_type>().ok_or_else(|| {
+                GgsqlError::InternalError(format!(
+                    "Expected {} array, got {:?}",
+                    $type_name,
+                    array.data_type()
+                ))
+            })
+        }
+    };
+}
+
+downcast_fn!(as_f64, Float64Array, "Float64");
+downcast_fn!(as_f32, Float32Array, "Float32");
+downcast_fn!(as_i64, Int64Array, "Int64");
+downcast_fn!(as_i32, Int32Array, "Int32");
+downcast_fn!(as_i16, Int16Array, "Int16");
+downcast_fn!(as_i8, Int8Array, "Int8");
+downcast_fn!(as_u64, UInt64Array, "UInt64");
+downcast_fn!(as_u32, UInt32Array, "UInt32");
+downcast_fn!(as_u16, UInt16Array, "UInt16");
+downcast_fn!(as_u8, UInt8Array, "UInt8");
+downcast_fn!(as_str, StringArray, "String");
+downcast_fn!(as_bool, BooleanArray, "Boolean");
+downcast_fn!(as_date32, Date32Array, "Date32");
+downcast_fn!(
+    as_timestamp_us,
+    TimestampMicrosecondArray,
+    "Timestamp(Microsecond)"
+);
+downcast_fn!(as_time64_ns, Time64NanosecondArray, "Time64(Nanosecond)");
+
+// ============================================================================
+// Cast helper
+// ============================================================================
+
+/// Cast an array to a different data type.
+///
+/// Arrow's `compute::cast` can't cast directly between temporal types (Date32,
+/// Timestamp, Time64) and floating-point types — it only allows going via the
+/// integer backing representation. This wrapper bridges the gap so callers can
+/// treat temporal columns as numeric without special-casing every site.
+pub fn cast_array(array: &ArrayRef, to: &DataType) -> Result<ArrayRef> {
+    let from = array.data_type();
+    let do_cast = |arr: &ArrayRef, dt: &DataType| -> Result<ArrayRef> {
+        compute::cast(arr, dt)
+            .map_err(|e| GgsqlError::InternalError(format!("Failed to cast to {:?}: {}", dt, e)))
+    };
+
+    let bridge = match (from, to) {
+        // Temporal → floating: go via the integer backing type.
+        (DataType::Date32, DataType::Float32 | DataType::Float64) => Some(DataType::Int32),
+        (
+            DataType::Timestamp(_, _) | DataType::Time64(_),
+            DataType::Float32 | DataType::Float64,
+        ) => Some(DataType::Int64),
+        // Floating → temporal: same bridge in reverse.
+        (DataType::Float32 | DataType::Float64, DataType::Date32) => Some(DataType::Int32),
+        (
+            DataType::Float32 | DataType::Float64,
+            DataType::Timestamp(_, _) | DataType::Time64(_),
+        ) => Some(DataType::Int64),
+        _ => None,
+    };
+
+    match bridge {
+        Some(mid) => {
+            let intermediate = do_cast(array, &mid)?;
+            do_cast(&intermediate, to)
+        }
+        None => do_cast(array, to),
+    }
+}
+
+// ============================================================================
+// Array construction helpers
+// ============================================================================
+
+/// Create a Float64 array from optional values.
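+///
+/// ```ignore
+/// // sketch: `None` entries become nulls in the resulting array
+/// let arr = new_f64_array(vec![Some(1.0), None, Some(3.0)]);
+/// ```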
+pub fn new_f64_array(values: Vec<Option<f64>>) -> ArrayRef {
+    Arc::new(Float64Array::from(values))
+}
+
+/// Create a Float64 array from non-null values.
+pub fn new_f64_array_non_null(values: Vec<f64>) -> ArrayRef {
+    Arc::new(Float64Array::from(values))
+}
+
+/// Create an Int32 array from optional values.
+pub fn new_i32_array(values: Vec<Option<i32>>) -> ArrayRef {
+    Arc::new(Int32Array::from(values))
+}
+
+/// Create an Int64 array from optional values.
+pub fn new_i64_array(values: Vec<Option<i64>>) -> ArrayRef {
+    Arc::new(Int64Array::from(values))
+}
+
+/// Create a String array from optional values.
+pub fn new_str_array(values: Vec<Option<String>>) -> ArrayRef {
+    Arc::new(StringArray::from(values))
+}
+
+/// Create a String array from owned strings.
+pub fn new_string_array(values: Vec<String>) -> ArrayRef {
+    let refs: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
+    Arc::new(StringArray::from(refs))
+}
+
+/// Create a Boolean array from optional values.
+pub fn new_bool_array(values: Vec<Option<bool>>) -> ArrayRef {
+    Arc::new(BooleanArray::from(values))
+}
+
+/// Create a constant Float64 array (all same value).
+pub fn new_constant_f64(value: f64, len: usize) -> ArrayRef {
+    Arc::new(Float64Array::from(vec![value; len]))
+}
+
+/// Create a constant String array (all same value).
+pub fn new_constant_str(value: &str, len: usize) -> ArrayRef {
+    Arc::new(StringArray::from(vec![value; len]))
+}
+
+/// Create a constant Boolean array (all same value).
+pub fn new_constant_bool(value: bool, len: usize) -> ArrayRef {
+    Arc::new(BooleanArray::from(vec![value; len]))
+}
+
+// ============================================================================
+// Null handling
+// ============================================================================
+
+/// Replace null values in a Float64 array with a fill value.
+pub fn fill_null_f64(array: &Float64Array, fill: f64) -> Float64Array {
+    let mut builder = arrow::array::Float64Builder::with_capacity(array.len());
+    for v in array.iter() {
+        builder.append_value(v.unwrap_or(fill));
+    }
+    builder.finish()
+}
+
+// ============================================================================
+// Value extraction helpers
+// ============================================================================
+
+/// Get a string representation of a value at an index, for any array type.
+/// Used for building composite group keys (e.g., in dodge).
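+///
+/// A small sketch of the semantics (nulls render as the literal string "null"):
+///
+/// ```ignore
+/// let arr: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None]));
+/// assert_eq!(value_to_string(&arr, 0), "a");
+/// assert_eq!(value_to_string(&arr, 1), "null");
+/// ```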
+pub fn value_to_string(array: &ArrayRef, idx: usize) -> String { + if array.is_null(idx) { + return "null".to_string(); + } + match array.data_type() { + DataType::Int8 => as_i8(array).unwrap().value(idx).to_string(), + DataType::Int16 => as_i16(array).unwrap().value(idx).to_string(), + DataType::Int32 => as_i32(array).unwrap().value(idx).to_string(), + DataType::Int64 => as_i64(array).unwrap().value(idx).to_string(), + DataType::UInt8 => as_u8(array).unwrap().value(idx).to_string(), + DataType::UInt16 => as_u16(array).unwrap().value(idx).to_string(), + DataType::UInt32 => as_u32(array).unwrap().value(idx).to_string(), + DataType::UInt64 => as_u64(array).unwrap().value(idx).to_string(), + DataType::Float32 => as_f32(array).unwrap().value(idx).to_string(), + DataType::Float64 => as_f64(array).unwrap().value(idx).to_string(), + DataType::Utf8 => as_str(array).unwrap().value(idx).to_string(), + DataType::LargeUtf8 => array + .as_any() + .downcast_ref::() + .unwrap() + .value(idx) + .to_string(), + DataType::Boolean => as_bool(array).unwrap().value(idx).to_string(), + DataType::Date32 => { + let days = as_date32(array).unwrap().value(idx); + format!("{}", days) + } + DataType::Date64 => { + let ms = array + .as_any() + .downcast_ref::() + .unwrap() + .value(idx); + format!("{}", ms) + } + _ => arrow::util::display::ArrayFormatter::try_new(array.as_ref(), &Default::default()) + .and_then(|f| Ok(f.value(idx).to_string())) + .unwrap_or_else(|_| format!("{:?}", array.data_type())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_downcast_f64() { + let arr: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])); + let f64_arr = as_f64(&arr).unwrap(); + assert_eq!(f64_arr.value(0), 1.0); + assert_eq!(f64_arr.value(2), 3.0); + } + + #[test] + fn test_downcast_wrong_type() { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + assert!(as_f64(&arr).is_err()); + } + + #[test] + fn test_cast_array() { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let casted = cast_array(&arr, &DataType::Float64).unwrap(); + let f64_arr = as_f64(&casted).unwrap(); + assert_eq!(f64_arr.value(0), 1.0); + } + + #[test] + fn test_cast_date32_to_float64() { + // Arrow can't cast Date32 → Float64 directly; cast_array bridges via Int32. + let arr: ArrayRef = Arc::new(Date32Array::from(vec![19723, 19875])); + let casted = cast_array(&arr, &DataType::Float64).unwrap(); + let f64_arr = as_f64(&casted).unwrap(); + assert_eq!(f64_arr.value(0), 19723.0); + assert_eq!(f64_arr.value(1), 19875.0); + } + + #[test] + fn test_cast_float64_to_date32() { + // Reverse: Float64 → Date32 also needs to bridge via Int32. + let arr: ArrayRef = Arc::new(Float64Array::from(vec![19723.0, 19875.0])); + let casted = cast_array(&arr, &DataType::Date32).unwrap(); + assert_eq!(casted.data_type(), &DataType::Date32); + } + + #[test] + fn test_cast_timestamp_to_float64() { + use arrow::datatypes::TimeUnit; + let arr: ArrayRef = Arc::new(TimestampMicrosecondArray::from(vec![ + 1_000_000_i64, + 2_000_000, + ])); + let casted = cast_array(&arr, &DataType::Float64).unwrap(); + let f64_arr = as_f64(&casted).unwrap(); + assert_eq!(f64_arr.value(0), 1_000_000.0); + // Make sure the reverse also works with a concrete unit/tz. 
+ let back = cast_array(&casted, &DataType::Timestamp(TimeUnit::Microsecond, None)).unwrap(); + assert!(matches!(back.data_type(), DataType::Timestamp(_, _))); + } + + #[test] + fn test_fill_null_f64() { + let arr = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + let filled = fill_null_f64(&arr, 0.0); + assert_eq!(filled.value(0), 1.0); + assert_eq!(filled.value(1), 0.0); + assert_eq!(filled.value(2), 3.0); + assert!(!filled.is_null(1)); + } + + #[test] + fn test_new_constant_f64() { + let arr = new_constant_f64(42.0, 3); + let f64_arr = as_f64(&arr).unwrap(); + assert_eq!(f64_arr.len(), 3); + assert_eq!(f64_arr.value(0), 42.0); + assert_eq!(f64_arr.value(2), 42.0); + } + + #[test] + fn test_value_to_string() { + let arr: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"])); + assert_eq!(value_to_string(&arr, 0), "hello"); + + let arr: ArrayRef = Arc::new(Float64Array::from(vec![3.24])); + assert_eq!(value_to_string(&arr, 0), "3.24"); + } +} diff --git a/src/cli.rs b/src/cli.rs index 049d679e..38f577b4 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -387,27 +387,24 @@ fn print_table_fallback(query: &str, reader: &R, max_rows: usize) { let nrow = data.height().min(max_rows); let ncol = data.width(); - let colnames = data.get_column_names_str(); + let colnames = data.get_column_names(); // We add an extra 'row' for the column names let mut rows: Vec = vec![String::from(""); nrow + 1]; - for col_id in 0..ncol { - let col_name = colnames[col_id]; + let columns = data.get_columns(); + for (col_id, (col_name, column_data)) in colnames.iter().zip(columns.iter()).enumerate() { let mut width = col_name.chars().count(); // End last column without comma - let mut suffix = ", "; - if col_id == ncol - 1 { - suffix = ""; - } + let suffix = if col_id == ncol - 1 { "" } else { ", " }; // Prepopulate formatted column with column name let mut col_fmt: Vec = vec![format!("{}{}", col_name, suffix)]; // Format every cell in column, tracking width - let column_data = data[col_id].as_materialized_series(); - for cell in column_data.iter().take(rows.len()) { + for row_idx in 0..nrow { + let cell = ggsql::array_util::value_to_string(column_data, row_idx); let cell_fmt = format!("{}{}", cell, suffix); let nchar = cell_fmt.chars().count(); if nchar > width { @@ -422,8 +419,8 @@ fn print_table_fallback(query: &str, reader: &R, max_rows: usize) { .collect(); // Push columns to row string - for i in 0..rows.len() { - rows[i].push_str(col_fmt[i].as_str()); + for (row, fmt) in rows.iter_mut().zip(col_fmt.iter()) { + row.push_str(fmt.as_str()); } } diff --git a/src/compute.rs b/src/compute.rs new file mode 100644 index 00000000..28a36789 --- /dev/null +++ b/src/compute.rs @@ -0,0 +1,381 @@ +//! Grouped window operations for position adjustments. +//! +//! Replaces polars lazy evaluation (`cum_sum().over()`, `shift()`, etc.) +//! with direct arrow compute operations. Used primarily by stack.rs. + +use arrow::array::{Array, ArrayRef, Float64Array, UInt32Array}; +use arrow::compute; +use arrow::compute::SortOptions; + +use crate::array_util::{as_f64, fill_null_f64, value_to_string}; +use crate::dataframe::DataFrame; +use crate::{GgsqlError, Result}; + +// ============================================================================ +// Sorting +// ============================================================================ + +/// Sort a DataFrame by multiple columns (all ascending). 
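+///
+/// Sketch (mirrors the unit test below):
+///
+/// ```ignore
+/// let df = crate::df! { "x" => vec!["B", "A"] }?;
+/// let sorted = sort_dataframe(&df, &["x"])?;  // rows now ordered "A", "B"
+/// ```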
+pub fn sort_dataframe(df: &DataFrame, columns: &[&str]) -> Result<DataFrame> {
+    if columns.is_empty() || df.height() == 0 {
+        return Ok(df.clone());
+    }
+
+    // Build sort columns for lexsort
+    let sort_columns: Vec<arrow::compute::SortColumn> = columns
+        .iter()
+        .map(|&name| {
+            let col = df.column(name)?;
+            Ok(arrow::compute::SortColumn {
+                values: col.clone(),
+                options: Some(SortOptions::default()),
+            })
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    let indices = compute::lexsort_to_indices(&sort_columns, None).map_err(|e| {
+        GgsqlError::InternalError(format!("Failed to sort by {:?}: {}", columns, e))
+    })?;
+
+    reorder_by_indices(df, &indices)
+}
+
+/// Reorder all columns of a DataFrame using an index array.
+fn reorder_by_indices(df: &DataFrame, indices: &UInt32Array) -> Result<DataFrame> {
+    let names = df.get_column_names();
+    let mut new_columns: Vec<(&str, ArrayRef)> = Vec::with_capacity(df.width());
+
+    for name in &names {
+        let col = df.column(name)?;
+        let reordered = compute::take(col.as_ref(), indices, None).map_err(|e| {
+            GgsqlError::InternalError(format!("Failed to reorder column '{}': {}", name, e))
+        })?;
+        new_columns.push((name, reordered));
+    }
+
+    DataFrame::new(new_columns)
+}
+
+// ============================================================================
+// Group identification
+// ============================================================================
+
+/// Compute a group ID (0-based) for each row based on one or more columns.
+///
+/// Rows with the same combination of values in `group_cols` get the same ID.
+/// The IDs are assigned in order of first appearance in the (already sorted) data.
+pub fn compute_group_ids(df: &DataFrame, group_cols: &[&str]) -> Result<Vec<usize>> {
+    let n_rows = df.height();
+    if n_rows == 0 {
+        return Ok(Vec::new());
+    }
+
+    // Collect the group column arrays
+    let arrays: Vec<&ArrayRef> = group_cols
+        .iter()
+        .map(|&name| df.column(name))
+        .collect::<Result<Vec<_>>>()?;
+
+    // Assign group IDs by detecting where the composite key changes.
+    // Since data is expected to be sorted by group columns, we just compare
+    // adjacent rows.
+    let mut group_ids = Vec::with_capacity(n_rows);
+    group_ids.push(0usize);
+    let mut current_group = 0usize;
+
+    for i in 1..n_rows {
+        let changed = arrays
+            .iter()
+            .any(|arr| value_to_string(arr, i) != value_to_string(arr, i - 1));
+        if changed {
+            current_group += 1;
+        }
+        group_ids.push(current_group);
+    }
+
+    Ok(group_ids)
+}
+
+// ============================================================================
+// Grouped cumulative operations
+// ============================================================================
+
+/// Compute cumulative sum within groups.
+///
+/// For each row, the result is the running total of `values` within its group
+/// (identified by `group_ids`). Null values are treated as 0.
+pub fn grouped_cumsum(values: &Float64Array, group_ids: &[usize]) -> Float64Array {
+    let mut result = Vec::with_capacity(values.len());
+    let mut running_sum = 0.0;
+    let mut current_group = group_ids.first().copied().unwrap_or(0);
+
+    for (val_opt, &gid) in values.iter().zip(group_ids.iter()) {
+        if gid != current_group {
+            // New group — reset running sum
+            running_sum = 0.0;
+            current_group = gid;
+        }
+        running_sum += val_opt.unwrap_or(0.0);
+        result.push(running_sum);
+    }
+
+    Float64Array::from(result)
+}
+
+/// Compute shifted cumulative sum within groups (lag by 1, fill with 0).
+///
+/// For each row, the result is the cumulative sum of all PREVIOUS rows in the
+/// same group. The first row of each group gets 0.
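+///
+/// Sketch of the lag semantics (same data as the unit tests below):
+///
+/// ```ignore
+/// let v = Float64Array::from(vec![10.0, 20.0, 15.0, 25.0]);
+/// let out = grouped_cumsum_lag(&v, &[0, 0, 1, 1]);
+/// // out == [0.0, 10.0, 0.0, 15.0]
+/// ```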
+pub fn grouped_cumsum_lag(values: &Float64Array, group_ids: &[usize]) -> Float64Array {
+    let mut result = Vec::with_capacity(values.len());
+    let mut running_sum = 0.0;
+    let mut current_group = group_ids.first().copied().unwrap_or(0);
+
+    for (val_opt, &gid) in values.iter().zip(group_ids.iter()) {
+        if gid != current_group {
+            // New group — reset running sum
+            running_sum = 0.0;
+            current_group = gid;
+        }
+        // Lag: output the running sum BEFORE adding current value
+        result.push(running_sum);
+        running_sum += val_opt.unwrap_or(0.0);
+    }
+
+    Float64Array::from(result)
+}
+
+/// Compute group sums, broadcast back to each row.
+///
+/// Each row gets the total sum of its group.
+pub fn grouped_sum_broadcast(values: &Float64Array, group_ids: &[usize]) -> Float64Array {
+    if values.is_empty() {
+        return Float64Array::from(Vec::<f64>::new());
+    }
+
+    let n_groups = group_ids.iter().copied().max().unwrap_or(0) + 1;
+    let mut group_sums = vec![0.0; n_groups];
+
+    for (val_opt, &gid) in values.iter().zip(group_ids.iter()) {
+        group_sums[gid] += val_opt.unwrap_or(0.0);
+    }
+
+    let result: Vec<f64> = group_ids.iter().map(|&gid| group_sums[gid]).collect();
+    Float64Array::from(result)
+}
+
+// ============================================================================
+// Array arithmetic helpers
+// ============================================================================
+
+/// Compute element-wise: a / b (Float64 arrays).
+pub fn divide_arrays(a: &Float64Array, b: &Float64Array) -> Result<Float64Array> {
+    // Manual division to handle divide-by-zero gracefully (return 0 instead of NaN/Inf)
+    let result: Vec<f64> = a
+        .iter()
+        .zip(b.iter())
+        .map(|(av, bv)| {
+            let divisor = bv.unwrap_or(0.0);
+            if divisor == 0.0 {
+                0.0
+            } else {
+                av.unwrap_or(0.0) / divisor
+            }
+        })
+        .collect();
+    Ok(Float64Array::from(result))
+}
+
+/// Compute element-wise: a * scalar.
+pub fn multiply_scalar(a: &Float64Array, scalar: f64) -> Float64Array {
+    let result: Vec<f64> = a.iter().map(|v| v.unwrap_or(0.0) * scalar).collect();
+    Float64Array::from(result)
+}
+
+/// Compute element-wise: a - b.
+pub fn subtract_arrays(a: &Float64Array, b: &Float64Array) -> Float64Array {
+    let result: Vec<f64> = a
+        .iter()
+        .zip(b.iter())
+        .map(|(av, bv)| av.unwrap_or(0.0) - bv.unwrap_or(0.0))
+        .collect();
+    Float64Array::from(result)
+}
+
+/// Compute element-wise: a / scalar.
+pub fn divide_scalar(a: &Float64Array, scalar: f64) -> Float64Array {
+    if scalar == 0.0 {
+        return Float64Array::from(vec![0.0; a.len()]);
+    }
+    let result: Vec<f64> = a.iter().map(|v| v.unwrap_or(0.0) / scalar).collect();
+    Float64Array::from(result)
+}
+
+// ============================================================================
+// Aggregation
+// ============================================================================
+
+/// Get the minimum value from a Float64 array, ignoring nulls.
+pub fn min_f64(array: &ArrayRef) -> Result<Option<f64>> {
+    let f64_array = as_f64(array)?;
+    Ok(compute::min(f64_array))
+}
+
+/// Get the minimum value from a column, casting to Float64 first if needed.
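+///
+/// ```ignore
+/// // sketch: works for temporal columns too, via cast_array's integer bridge;
+/// // returns None when the column is empty or all-null
+/// let min = column_min_f64(&df, "y")?;  // Option<f64>
+/// ```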
+pub fn column_min_f64(df: &DataFrame, col_name: &str) -> Result> { + let col = df.column(col_name)?; + if col.data_type() == &arrow::datatypes::DataType::Float64 { + min_f64(col) + } else { + let casted = crate::array_util::cast_array(col, &arrow::datatypes::DataType::Float64)?; + min_f64(&casted) + } +} + +// ============================================================================ +// Convenience: fill nulls on an ArrayRef +// ============================================================================ + +/// Fill nulls in an ArrayRef (expected Float64) with a value, returning a new ArrayRef. +pub fn fill_null_f64_ref(array: &ArrayRef, fill: f64) -> Result { + let f64_arr = as_f64(array)?; + Ok(std::sync::Arc::new(fill_null_f64(f64_arr, fill))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_grouped_cumsum_single_group() { + let values = Float64Array::from(vec![10.0, 20.0, 30.0]); + let group_ids = vec![0, 0, 0]; + let result = grouped_cumsum(&values, &group_ids); + assert_eq!(result.value(0), 10.0); + assert_eq!(result.value(1), 30.0); + assert_eq!(result.value(2), 60.0); + } + + #[test] + fn test_grouped_cumsum_two_groups() { + // Groups: A(10, 20), B(15, 25) + let values = Float64Array::from(vec![10.0, 20.0, 15.0, 25.0]); + let group_ids = vec![0, 0, 1, 1]; + let result = grouped_cumsum(&values, &group_ids); + assert_eq!(result.value(0), 10.0); // A: 10 + assert_eq!(result.value(1), 30.0); // A: 10+20 + assert_eq!(result.value(2), 15.0); // B: 15 + assert_eq!(result.value(3), 40.0); // B: 15+25 + } + + #[test] + fn test_grouped_cumsum_lag() { + let values = Float64Array::from(vec![10.0, 20.0, 15.0, 25.0]); + let group_ids = vec![0, 0, 1, 1]; + let result = grouped_cumsum_lag(&values, &group_ids); + assert_eq!(result.value(0), 0.0); // A: first → 0 + assert_eq!(result.value(1), 10.0); // A: lag of 10 + assert_eq!(result.value(2), 0.0); // B: first → 0 + assert_eq!(result.value(3), 15.0); // B: lag of 15 + } + + #[test] + fn test_grouped_sum_broadcast() { + let values = Float64Array::from(vec![10.0, 20.0, 15.0, 25.0]); + let group_ids = vec![0, 0, 1, 1]; + let result = grouped_sum_broadcast(&values, &group_ids); + assert_eq!(result.value(0), 30.0); // A total + assert_eq!(result.value(1), 30.0); // A total + assert_eq!(result.value(2), 40.0); // B total + assert_eq!(result.value(3), 40.0); // B total + } + + #[test] + fn test_grouped_cumsum_with_nulls() { + let values = Float64Array::from(vec![Some(10.0), None, Some(20.0)]); + let group_ids = vec![0, 0, 0]; + let result = grouped_cumsum(&values, &group_ids); + assert_eq!(result.value(0), 10.0); + assert_eq!(result.value(1), 10.0); // null treated as 0 + assert_eq!(result.value(2), 30.0); + } + + #[test] + fn test_sort_dataframe() { + let df = crate::df! { + "x" => vec!["B", "A", "C", "A"], + "y" => vec![2.0, 1.0, 3.0, 0.0], + } + .unwrap(); + + let sorted = sort_dataframe(&df, &["x"]).unwrap(); + let x_col = sorted.column("x").unwrap(); + let x_arr = crate::array_util::as_str(x_col).unwrap(); + assert_eq!(x_arr.value(0), "A"); + assert_eq!(x_arr.value(1), "A"); + assert_eq!(x_arr.value(2), "B"); + assert_eq!(x_arr.value(3), "C"); + } + + #[test] + fn test_compute_group_ids() { + let df = crate::df! { + "group" => vec!["A", "A", "B", "B", "C"], + } + .unwrap(); + + let ids = compute_group_ids(&df, &["group"]).unwrap(); + assert_eq!(ids, vec![0, 0, 1, 1, 2]); + } + + #[test] + fn test_compute_group_ids_multi_column() { + let df = crate::df! 
{ + "g1" => vec!["A", "A", "A", "B"], + "g2" => vec!["X", "X", "Y", "X"], + } + .unwrap(); + + let ids = compute_group_ids(&df, &["g1", "g2"]).unwrap(); + assert_eq!(ids, vec![0, 0, 1, 2]); + } + + #[test] + fn test_divide_arrays_with_zero() { + let a = Float64Array::from(vec![10.0, 20.0, 30.0]); + let b = Float64Array::from(vec![2.0, 0.0, 5.0]); + let result = divide_arrays(&a, &b).unwrap(); + assert_eq!(result.value(0), 5.0); + assert_eq!(result.value(1), 0.0); // divide by zero → 0 + assert_eq!(result.value(2), 6.0); + } + + #[test] + fn test_column_min_f64() { + let df = crate::df! { + "x" => vec![3.0, 1.0, 2.0], + } + .unwrap(); + let min = column_min_f64(&df, "x").unwrap(); + assert_eq!(min, Some(1.0)); + } + + #[test] + fn test_multiply_scalar() { + let a = Float64Array::from(vec![1.0, 2.0, 3.0]); + let result = multiply_scalar(&a, 10.0); + assert_eq!(result.value(0), 10.0); + assert_eq!(result.value(1), 20.0); + assert_eq!(result.value(2), 30.0); + } + + #[test] + fn test_subtract_arrays() { + let a = Float64Array::from(vec![10.0, 20.0, 30.0]); + let b = Float64Array::from(vec![1.0, 2.0, 3.0]); + let result = subtract_arrays(&a, &b); + assert_eq!(result.value(0), 9.0); + assert_eq!(result.value(1), 18.0); + assert_eq!(result.value(2), 27.0); + } +} diff --git a/src/data/airquality.parquet b/src/data/airquality.parquet index 06df1686..3858c5b2 100644 Binary files a/src/data/airquality.parquet and b/src/data/airquality.parquet differ diff --git a/src/data/penguins.parquet b/src/data/penguins.parquet index b3be2573..b8f99628 100644 Binary files a/src/data/penguins.parquet and b/src/data/penguins.parquet differ diff --git a/src/dataframe.rs b/src/dataframe.rs new file mode 100644 index 00000000..491fe12b --- /dev/null +++ b/src/dataframe.rs @@ -0,0 +1,647 @@ +//! Thin DataFrame wrapper around Arrow RecordBatch. +//! +//! Provides ergonomic column-by-name access and mutation methods +//! (with_column, rename, drop) that RecordBatch lacks natively. +//! Each mutation returns a new DataFrame (RecordBatch is immutable). + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use std::sync::Arc; + +use crate::{GgsqlError, Result}; + +/// A thin wrapper around Arrow's `RecordBatch` providing named-column access +/// and ergonomic mutation methods. +/// +/// Clone is cheap — the underlying arrays are reference-counted. +#[derive(Debug, Clone)] +pub struct DataFrame { + inner: RecordBatch, +} + +impl DataFrame { + // ======================================================================== + // Construction + // ======================================================================== + + /// Create a DataFrame from named columns. + /// + /// All arrays must have the same length. + pub fn new(columns: Vec<(&str, ArrayRef)>) -> Result { + if columns.is_empty() { + return Ok(Self::empty()); + } + let fields: Vec = columns + .iter() + .map(|(name, arr)| Field::new(*name, arr.data_type().clone(), true)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + let arrays: Vec = columns.into_iter().map(|(_, arr)| arr).collect(); + let rb = RecordBatch::try_new(schema, arrays) + .map_err(|e| GgsqlError::InternalError(format!("Failed to create DataFrame: {}", e)))?; + Ok(Self { inner: rb }) + } + + /// Create an empty DataFrame (0 columns, 0 rows). + pub fn empty() -> Self { + Self { + inner: RecordBatch::new_empty(Arc::new(Schema::empty())), + } + } + + /// Wrap an existing RecordBatch. 
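+    ///
+    /// ```ignore
+    /// // sketch: `batch` is any arrow RecordBatch, e.g. straight from a reader
+    /// let (rows, cols) = (batch.num_rows(), batch.num_columns());
+    /// let df = DataFrame::from_record_batch(batch);
+    /// assert_eq!(df.shape(), (rows, cols));
+    /// ```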
+    pub fn from_record_batch(rb: RecordBatch) -> Self {
+        Self { inner: rb }
+    }
+
+    // ========================================================================
+    // Read access
+    // ========================================================================
+
+    /// Number of rows.
+    pub fn height(&self) -> usize {
+        self.inner.num_rows()
+    }
+
+    /// Number of columns.
+    pub fn width(&self) -> usize {
+        self.inner.num_columns()
+    }
+
+    /// (rows, columns) tuple.
+    pub fn shape(&self) -> (usize, usize) {
+        (self.height(), self.width())
+    }
+
+    /// Get a column by name.
+    pub fn column(&self, name: &str) -> Result<&ArrayRef> {
+        let idx = self.column_index(name)?;
+        Ok(self.inner.column(idx))
+    }
+
+    /// Get all columns as a slice.
+    pub fn get_columns(&self) -> &[ArrayRef] {
+        self.inner.columns()
+    }
+
+    /// Get column names.
+    pub fn get_column_names(&self) -> Vec<String> {
+        self.inner
+            .schema()
+            .fields()
+            .iter()
+            .map(|f| f.name().clone())
+            .collect()
+    }
+
+    /// Get the Arrow schema (reference-counted).
+    pub fn schema(&self) -> Arc<Schema> {
+        self.inner.schema().clone()
+    }
+
+    /// Access the underlying RecordBatch directly.
+    pub fn inner(&self) -> &RecordBatch {
+        &self.inner
+    }
+
+    /// Consume the wrapper and return the RecordBatch.
+    pub fn into_inner(self) -> RecordBatch {
+        self.inner
+    }
+
+    /// Get the data type of a column by name.
+    pub fn column_dtype(&self, name: &str) -> Result<DataType> {
+        let idx = self.column_index(name)?;
+        Ok(self.inner.schema().field(idx).data_type().clone())
+    }
+
+    // ========================================================================
+    // Mutation (returns new DataFrame)
+    // ========================================================================
+
+    /// Add or replace a column. If a column with `name` already exists, it is replaced.
+    pub fn with_column(&self, name: &str, array: ArrayRef) -> Result<Self> {
+        if array.len() != self.height() && self.width() > 0 {
+            return Err(GgsqlError::InternalError(format!(
+                "Cannot add column '{}' with {} rows to DataFrame with {} rows",
+                name,
+                array.len(),
+                self.height()
+            )));
+        }
+
+        let mut fields: Vec<Field> = Vec::with_capacity(self.width() + 1);
+        let mut arrays: Vec<ArrayRef> = Vec::with_capacity(self.width() + 1);
+        let mut replaced = false;
+
+        for (i, field) in self.inner.schema().fields().iter().enumerate() {
+            if field.name() == name {
+                fields.push(Field::new(name, array.data_type().clone(), true));
+                arrays.push(array.clone());
+                replaced = true;
+            } else {
+                fields.push(field.as_ref().clone());
+                arrays.push(self.inner.column(i).clone());
+            }
+        }
+
+        if !replaced {
+            fields.push(Field::new(name, array.data_type().clone(), true));
+            arrays.push(array);
+        }
+
+        let schema = Arc::new(Schema::new(fields));
+        let rb = RecordBatch::try_new(schema, arrays).map_err(|e| {
+            GgsqlError::InternalError(format!("Failed to add column '{}': {}", name, e))
+        })?;
+        Ok(Self { inner: rb })
+    }
+
+    /// Rename a column.
+    pub fn rename(&self, old: &str, new: &str) -> Result<Self> {
+        let idx = self.column_index(old)?;
+
+        let fields: Vec<Field> = self
+            .inner
+            .schema()
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(i, f)| {
+                if i == idx {
+                    Field::new(new, f.data_type().clone(), f.is_nullable())
+                } else {
+                    f.as_ref().clone()
+                }
+            })
+            .collect();
+
+        let schema = Arc::new(Schema::new(fields));
+        let rb = RecordBatch::try_new(schema, self.inner.columns().to_vec()).map_err(|e| {
+            GgsqlError::InternalError(format!(
+                "Failed to rename column '{}' to '{}': {}",
+                old, new, e
+            ))
+        })?;
+        Ok(Self { inner: rb })
+    }
+
+    /// Drop a column by name. Returns an error if the column doesn't exist.
+    pub fn drop(&self, name: &str) -> Result<Self> {
+        let idx = self.column_index(name)?;
+        self.drop_by_index(idx)
+    }
+
+    /// Drop multiple columns by name. Silently ignores names that don't exist.
+    ///
+    /// If every column is dropped, the returned DataFrame preserves the original
+    /// row count (0 columns × N rows), which annotation layers rely on to know
+    /// how many marks to draw.
+    pub fn drop_many<S: AsRef<str>>(&self, names: &[S]) -> Result<Self> {
+        let drop_set: std::collections::HashSet<&str> = names.iter().map(|s| s.as_ref()).collect();
+
+        let mut fields = Vec::new();
+        let mut arrays = Vec::new();
+
+        for (i, field) in self.inner.schema().fields().iter().enumerate() {
+            if !drop_set.contains(field.name().as_str()) {
+                fields.push(field.as_ref().clone());
+                arrays.push(self.inner.column(i).clone());
+            }
+        }
+
+        build_record_batch(fields, arrays, self.height())
+            .map(|inner| Self { inner })
+            .map_err(|e| GgsqlError::InternalError(format!("Failed to drop columns: {}", e)))
+    }
+
+    /// Replace a column's array (keeping the same name).
+    pub fn replace(&self, name: &str, array: ArrayRef) -> Result<Self> {
+        let idx = self.column_index(name)?;
+
+        if array.len() != self.height() {
+            return Err(GgsqlError::InternalError(format!(
+                "Replacement column '{}' has {} rows, expected {}",
+                name,
+                array.len(),
+                self.height()
+            )));
+        }
+
+        let fields: Vec<Field> = self
+            .inner
+            .schema()
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(i, f)| {
+                if i == idx {
+                    Field::new(name, array.data_type().clone(), f.is_nullable())
+                } else {
+                    f.as_ref().clone()
+                }
+            })
+            .collect();
+
+        let mut arrays: Vec<ArrayRef> = self.inner.columns().to_vec();
+        arrays[idx] = array;
+
+        let schema = Arc::new(Schema::new(fields));
+        let rb = RecordBatch::try_new(schema, arrays).map_err(|e| {
+            GgsqlError::InternalError(format!("Failed to replace column '{}': {}", name, e))
+        })?;
+        Ok(Self { inner: rb })
+    }
+
+    /// Slice the DataFrame (offset and length).
+    pub fn slice(&self, offset: usize, length: usize) -> Self {
+        Self {
+            inner: self.inner.slice(offset, length),
+        }
+    }
+
+    // ========================================================================
+    // Private helpers
+    // ========================================================================
+
+    fn column_index(&self, name: &str) -> Result<usize> {
+        self.inner
+            .schema()
+            .index_of(name)
+            .map_err(|_| GgsqlError::InternalError(format!("Column '{}' not found", name)))
+    }
+
+    fn drop_by_index(&self, idx: usize) -> Result<Self> {
+        let mut fields = Vec::with_capacity(self.width() - 1);
+        let mut arrays = Vec::with_capacity(self.width() - 1);
+
+        for (i, field) in self.inner.schema().fields().iter().enumerate() {
+            if i != idx {
+                fields.push(field.as_ref().clone());
+                arrays.push(self.inner.column(i).clone());
+            }
+        }
+
+        build_record_batch(fields, arrays, self.height())
+            .map(|inner| Self { inner })
+            .map_err(|e| {
+                GgsqlError::InternalError(format!("Failed to drop column at index {}: {}", idx, e))
+            })
+    }
+}
+
+/// Build a `RecordBatch`, preserving `row_count` even when the schema has no fields.
+///
+/// Arrow's default constructor discards the row count for zero-column batches,
+/// which would silently lose "N rows × 0 columns" — a state annotation layers
+/// depend on (they draw one mark per row regardless of data columns).
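+///
+/// Sketch of the zero-column path this exists for:
+///
+/// ```ignore
+/// let rb = build_record_batch(vec![], vec![], 3)?;  // 3 rows × 0 columns
+/// assert_eq!(rb.num_rows(), 3);
+/// ```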
+fn build_record_batch(
+    fields: Vec<Field>,
+    arrays: Vec<ArrayRef>,
+    row_count: usize,
+) -> std::result::Result<RecordBatch, arrow::error::ArrowError> {
+    let schema = Arc::new(Schema::new(fields));
+    if arrays.is_empty() {
+        let options = RecordBatchOptions::new().with_row_count(Some(row_count));
+        RecordBatch::try_new_with_options(schema, arrays, &options)
+    } else {
+        RecordBatch::try_new(schema, arrays)
+    }
+}
+
+/// Convenience macro for creating test DataFrames, similar to polars' `df!`.
+///
+/// # Examples
+///
+/// ```ignore
+/// let df = df! {
+///     "name" => vec!["Alice", "Bob"],
+///     "age" => vec![30i32, 25],
+/// }.unwrap();
+/// ```
+#[macro_export]
+macro_rules! df {
+    ($($col_name:expr => $values:expr),+ $(,)?) => {{
+        {
+            let columns: Vec<(&str, arrow::array::ArrayRef)> = vec![
+                $(
+                    ($col_name, $crate::dataframe::into_array_ref($values)),
+                )+
+            ];
+            $crate::dataframe::DataFrame::new(columns)
+        }
+    }};
+}
+
+// ============================================================================
+// Conversion helpers for the df! macro
+// ============================================================================
+
+/// Convert typed Vecs into ArrayRef. Used by the `df!` macro.
+pub fn into_array_ref<T: IntoArrayRef>(values: T) -> ArrayRef {
+    values.into_array_ref()
+}
+
+/// Trait for converting typed collections into Arrow arrays.
+pub trait IntoArrayRef {
+    fn into_array_ref(self) -> ArrayRef;
+}
+
+// --- Vec<f64> ---
+impl IntoArrayRef for Vec<f64> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Float64Array::from(self))
+    }
+}
+
+// --- Vec<Option<f64>> ---
+impl IntoArrayRef for Vec<Option<f64>> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Float64Array::from(self))
+    }
+}
+
+// --- Vec<i32> ---
+impl IntoArrayRef for Vec<i32> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Int32Array::from(self))
+    }
+}
+
+// --- Vec<Option<i32>> ---
+impl IntoArrayRef for Vec<Option<i32>> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Int32Array::from(self))
+    }
+}
+
+// --- Vec<i64> ---
+impl IntoArrayRef for Vec<i64> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Int64Array::from(self))
+    }
+}
+
+// --- Vec<Option<i64>> ---
+impl IntoArrayRef for Vec<Option<i64>> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::Int64Array::from(self))
+    }
+}
+
+// --- Vec<bool> ---
+impl IntoArrayRef for Vec<bool> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::BooleanArray::from(self))
+    }
+}
+
+// --- Vec<Option<bool>> ---
+impl IntoArrayRef for Vec<Option<bool>> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::BooleanArray::from(self))
+    }
+}
+
+// --- Vec<&str> ---
+impl IntoArrayRef for Vec<&str> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::StringArray::from(self))
+    }
+}
+
+// --- Vec<Option<&str>> ---
+impl IntoArrayRef for Vec<Option<&str>> {
+    fn into_array_ref(self) -> ArrayRef {
+        Arc::new(arrow::array::StringArray::from(self))
+    }
+}
+
+// --- Vec<String> ---
+impl IntoArrayRef for Vec<String> {
+    fn into_array_ref(self) -> ArrayRef {
+        let refs: Vec<&str> = self.iter().map(|s| s.as_str()).collect();
+        Arc::new(arrow::array::StringArray::from(refs))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Array, Float64Array, Int32Array, StringArray};
+
+    #[test]
+    fn test_new_and_accessors() {
+        let df = DataFrame::new(vec![
+            ("x", Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef),
+            (
+                "y",
+                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])) as ArrayRef,
+            ),
+        ])
+        .unwrap();
+
+        assert_eq!(df.height(), 3);
+        assert_eq!(df.width(), 2);
+        assert_eq!(df.shape(), (3, 2));
+        assert_eq!(
+            df.get_column_names(),
+            vec!["x".to_string(),
"y".to_string()] + ); + assert!(df.column("x").is_ok()); + assert!(df.column("z").is_err()); + } + + #[test] + fn test_empty() { + let df = DataFrame::empty(); + assert_eq!(df.height(), 0); + assert_eq!(df.width(), 0); + } + + #[test] + fn test_with_column_add() { + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df + .with_column( + "y", + Arc::new(Float64Array::from(vec![10.0, 20.0])) as ArrayRef, + ) + .unwrap(); + + assert_eq!(df2.width(), 2); + assert_eq!( + df2.get_column_names(), + vec!["x".to_string(), "y".to_string()] + ); + } + + #[test] + fn test_with_column_replace() { + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df + .with_column("x", Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef) + .unwrap(); + + assert_eq!(df2.width(), 1); + let col = df2.column("x").unwrap(); + let arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 10); + } + + #[test] + fn test_rename() { + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df.rename("x", "renamed").unwrap(); + assert!(df2.column("renamed").is_ok()); + assert!(df2.column("x").is_err()); + } + + #[test] + fn test_drop() { + let df = DataFrame::new(vec![ + ("x", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "y", + Arc::new(Float64Array::from(vec![1.0, 2.0])) as ArrayRef, + ), + ]) + .unwrap(); + + let df2 = df.drop("x").unwrap(); + assert_eq!(df2.width(), 1); + assert_eq!(df2.get_column_names(), vec!["y"]); + } + + #[test] + fn test_drop_many() { + let df = DataFrame::new(vec![ + ("a", Arc::new(Int32Array::from(vec![1])) as ArrayRef), + ("b", Arc::new(Int32Array::from(vec![2])) as ArrayRef), + ("c", Arc::new(Int32Array::from(vec![3])) as ArrayRef), + ]) + .unwrap(); + + let df2 = df.drop_many(&["a", "c"]).unwrap(); + assert_eq!(df2.get_column_names(), vec!["b".to_string()]); + } + + #[test] + fn test_drop_last_column_preserves_row_count() { + // Annotation layers rely on "N rows × 0 columns" to know how many marks to draw. + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df.drop("x").unwrap(); + assert_eq!(df2.width(), 0); + assert_eq!(df2.height(), 3); + + let df3 = df.drop_many(&["x"]).unwrap(); + assert_eq!(df3.width(), 0); + assert_eq!(df3.height(), 3); + } + + #[test] + fn test_replace() { + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df + .replace("x", Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef) + .unwrap(); + let arr = df2 + .column("x") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value(0), 10); + assert_eq!(arr.value(1), 20); + } + + #[test] + fn test_slice() { + let df = DataFrame::new(vec![( + "x", + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, + )]) + .unwrap(); + + let df2 = df.slice(1, 3); + assert_eq!(df2.height(), 3); + let arr = df2 + .column("x") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value(0), 2); + assert_eq!(arr.value(2), 4); + } + + #[test] + fn test_df_macro() { + let df = df! 
{ + "name" => vec!["Alice", "Bob"], + "age" => vec![30i32, 25], + "score" => vec![95.5, 87.3], + } + .unwrap(); + + assert_eq!(df.height(), 2); + assert_eq!(df.width(), 3); + assert_eq!( + df.get_column_names(), + vec!["name".to_string(), "age".to_string(), "score".to_string()] + ); + + let names = df + .column("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(names.value(0), "Alice"); + } + + #[test] + fn test_df_macro_with_optionals() { + let df = df! { + "x" => vec![Some(1.0), None, Some(3.0)], + } + .unwrap(); + + assert_eq!(df.height(), 3); + let col = df + .column("x") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(!col.is_null(0)); + assert!(col.is_null(1)); + assert!(!col.is_null(2)); + } +} diff --git a/src/execute/casting.rs b/src/execute/casting.rs index 0f731812..ce0b2dd7 100644 --- a/src/execute/casting.rs +++ b/src/execute/casting.rs @@ -7,7 +7,7 @@ use crate::naming; use crate::plot::scale::coerce_dtypes; use crate::plot::{CastTargetType, Plot}; use crate::reader::SqlDialect; -use polars::prelude::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, TimeUnit}; use std::collections::HashMap; use super::schema::TypeInfo; @@ -173,14 +173,14 @@ pub fn update_type_info_for_casting(type_info: &mut [TypeInfo], requirements: &[ entry.1 = match req.target_type { CastTargetType::Number => DataType::Float64, CastTargetType::Integer => DataType::Int64, - CastTargetType::Date => DataType::Date, - CastTargetType::DateTime => DataType::Datetime(TimeUnit::Microseconds, None), - CastTargetType::Time => DataType::Time, - CastTargetType::String => DataType::String, + CastTargetType::Date => DataType::Date32, + CastTargetType::DateTime => DataType::Timestamp(TimeUnit::Microsecond, None), + CastTargetType::Time => DataType::Time64(TimeUnit::Nanosecond), + CastTargetType::String => DataType::Utf8, CastTargetType::Boolean => DataType::Boolean, }; // Update is_discrete flag based on new type - entry.2 = matches!(entry.1, DataType::String | DataType::Boolean); + entry.2 = matches!(entry.1, DataType::Utf8 | DataType::Boolean); } } } diff --git a/src/execute/cte.rs b/src/execute/cte.rs index 041bb298..dc4637f4 100644 --- a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -153,13 +153,9 @@ pub fn materialize_ctes(ctes: &[CteDefinition], reader: &dyn Reader) -> Result = df - .get_column_names() - .iter() - .map(|s| s.to_string()) - .collect(); + let current_names: Vec = df.get_column_names(); for (old, new) in current_names.iter().zip(cte.column_aliases.iter()) { - df.rename(old, new.into()).map_err(|e| { + df = df.rename(old, new).map_err(|e| { GgsqlError::ReaderError(format!( "Failed to apply column alias '{}' for CTE '{}': {}", new, cte.name, e diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 7a6725be..6af5c641 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -11,7 +11,7 @@ use crate::plot::{ }; use crate::reader::SqlDialect; use crate::{naming, DataFrame, GgsqlError, Result}; -use polars::prelude::DataType; +use arrow::datatypes::DataType; use std::collections::{HashMap, HashSet}; use super::casting::TypeRequirement; @@ -153,8 +153,6 @@ pub fn build_layer_select_list( /// Note: Prefixed aesthetic names persist through the entire pipeline. /// We do NOT rename `__ggsql_aes_x__` back to `x`. 
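+/// A sketch of the two remapping paths (column names illustrative):
+///
+/// ```ignore
+/// // stat remapping: an existing stat column is renamed to the prefixed
+/// // aesthetic column, e.g. some stat output -> "__ggsql_aes_y__"
+/// // literal remapping: a constant column is materialized via literal_to_array
+/// let df = apply_remappings_post_query(df, &layer)?;
+/// ```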
pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result<DataFrame> { - use polars::prelude::IntoColumn; - let mut df = df; let row_count = df.height(); @@ -168,7 +166,7 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result<DataFrame> { // Check if this stat column exists in the DataFrame if df.column(name).is_ok() { - df.rename(name, target_col_name.into()).map_err(|e| { + df = df.rename(name, &target_col_name).map_err(|e| { GgsqlError::InternalError(format!( "Failed to rename stat column '{}' to '{}': {}", name, target_aesthetic, e @@ -178,16 +176,13 @@ pub fn apply_remappings_post_query(df: DataFrame, layer: &Layer) -> Result<DataFrame> { // Add constant column for literal values - let series = literal_to_series(&target_col_name, lit, row_count); - df = df - .with_column(series.into_column()) - .map_err(|e| { - GgsqlError::InternalError(format!( - "Failed to add literal column '{}': {}", - target_col_name, e - )) - })? - .clone(); + let array = literal_to_array(lit, row_count); + df = df.with_column(&target_col_name, array).map_err(|e| { + GgsqlError::InternalError(format!( + "Failed to add literal column '{}': {}", + target_col_name, e + )) + })?; } } } @@ -197,46 +192,55 @@ -pub fn literal_to_series(name: &str, lit: &ParameterValue, len: usize) -> polars::prelude::Series { +pub fn literal_to_array(lit: &ParameterValue, len: usize) -> arrow::array::ArrayRef { + use crate::array_util::{cast_array, new_constant_bool, new_constant_f64, new_constant_str}; use crate::plot::ArrayElement; - use polars::prelude::{DataType, NamedFrom, Series, TimeUnit}; + use arrow::datatypes::{DataType, TimeUnit}; + use std::sync::Arc; match lit { - ParameterValue::Number(n) => Series::new(name.into(), vec![*n; len]), + ParameterValue::Number(n) => new_constant_f64(*n, len), ParameterValue::String(s) => { // Try to parse as temporal types (DateTime > Date > Time) match ArrayElement::String(s.clone()).try_as_temporal() { - ArrayElement::DateTime(micros) => Series::new(name.into(), vec![micros; len]) - .cast(&DataType::Datetime(TimeUnit::Microseconds, None)) - .expect("DateTime cast should not fail"), - ArrayElement::Date(days) => Series::new(name.into(), vec![days; len]) - .cast(&DataType::Date) - .expect("Date cast should not fail"), - ArrayElement::Time(nanos) => Series::new(name.into(), vec![nanos; len]) - .cast(&DataType::Time) - .expect("Time cast should not fail"), + ArrayElement::DateTime(micros) => { + let arr: arrow::array::ArrayRef = + Arc::new(arrow::array::Int64Array::from(vec![micros; len])); + cast_array(&arr, &DataType::Timestamp(TimeUnit::Microsecond, None)) + .expect("DateTime cast should not fail") + } + ArrayElement::Date(days) => { + let arr: arrow::array::ArrayRef = + Arc::new(arrow::array::Int32Array::from(vec![days; len])); + cast_array(&arr, &DataType::Date32).expect("Date cast should not fail") + } + ArrayElement::Time(nanos) => { + let arr: arrow::array::ArrayRef = + Arc::new(arrow::array::Int64Array::from(vec![nanos; len])); + cast_array(&arr, &DataType::Time64(TimeUnit::Nanosecond)) + .expect("Time cast should not fail") + } ArrayElement::String(_) => { // Parsing failed, use original string - Series::new(name.into(), vec![s.as_str(); len]) + new_constant_str(s, len) } _ => unreachable!("try_as_temporal only returns String or temporal types"), } } - ParameterValue::Boolean(b) => Series::new(name.into(), vec![*b; len]), + ParameterValue::Boolean(b) => new_constant_bool(*b, len), ParameterValue::Array(_) | ParameterValue::Null => { unreachable!("Arrays are never moved to
mappings; NULL is filtered in process_annotation_layers()") } @@ -294,7 +298,7 @@ pub fn apply_pre_stat_transform( .iter() .find(|c| c.name == aes_col_name) .map(|c| c.dtype.clone()) - .unwrap_or(DataType::String); // Default to String if not found + .unwrap_or(DataType::Utf8); // Default to Utf8 if not found // Find scale for this aesthetic if let Some(scale) = scales.iter().find(|s| s.aesthetic == *aesthetic) { @@ -1065,77 +1069,68 @@ mod tests { } #[test] - fn test_literal_to_series_date_parsing() { - use polars::prelude::DataType; - - // Date literal should parse to Date type - let series = literal_to_series( - "date_col", - &ParameterValue::String("1973-06-01".to_string()), - 5, - ); + fn test_literal_to_array_date_parsing() { + use arrow::array::Array; + use arrow::datatypes::DataType; + + // Date literal should parse to Date32 type + let array = literal_to_array(&ParameterValue::String("1973-06-01".to_string()), 5); assert_eq!( - series.dtype(), - &DataType::Date, - "Date string should parse to Date type" + array.data_type(), + &DataType::Date32, + "Date string should parse to Date32 type" ); - assert_eq!(series.len(), 5); + assert_eq!(array.len(), 5); } #[test] - fn test_literal_to_series_datetime_parsing() { - use polars::prelude::{DataType, TimeUnit}; + fn test_literal_to_array_datetime_parsing() { + use arrow::array::Array; + use arrow::datatypes::{DataType, TimeUnit}; - // DateTime literal should parse to Datetime type - let series = literal_to_series( - "dt_col", + // DateTime literal should parse to Timestamp type + let array = literal_to_array( &ParameterValue::String("2024-03-17T14:30:00".to_string()), 3, ); assert!( matches!( - series.dtype(), - DataType::Datetime(TimeUnit::Microseconds, None) + array.data_type(), + DataType::Timestamp(TimeUnit::Microsecond, None) ), - "DateTime string should parse to Datetime type" + "DateTime string should parse to Timestamp type" ); - assert_eq!(series.len(), 3); + assert_eq!(array.len(), 3); } #[test] - fn test_literal_to_series_time_parsing() { - use polars::prelude::DataType; - - // Time literal should parse to Time type - let series = literal_to_series( - "time_col", - &ParameterValue::String("14:30:00".to_string()), - 4, - ); + fn test_literal_to_array_time_parsing() { + use arrow::array::Array; + use arrow::datatypes::{DataType, TimeUnit}; + + // Time literal should parse to Time64 type + let array = literal_to_array(&ParameterValue::String("14:30:00".to_string()), 4); assert_eq!( - series.dtype(), - &DataType::Time, - "Time string should parse to Time type" + array.data_type(), + &DataType::Time64(TimeUnit::Nanosecond), + "Time string should parse to Time64 type" ); - assert_eq!(series.len(), 4); + assert_eq!(array.len(), 4); } #[test] - fn test_literal_to_series_string_fallback() { - use polars::prelude::DataType; - - // Non-temporal string should remain String type - let series = literal_to_series( - "text_col", - &ParameterValue::String("not a date".to_string()), - 2, - ); + fn test_literal_to_array_string_fallback() { + use arrow::array::Array; + use arrow::datatypes::DataType; + + // Non-temporal string should remain Utf8 type + let array = literal_to_array(&ParameterValue::String("not a date".to_string()), 2); assert_eq!( - series.dtype(), - &DataType::String, - "Non-temporal string should remain String type" + array.data_type(), + &DataType::Utf8, + "Non-temporal string should remain Utf8 type" ); - assert_eq!(series.len(), 2); + assert_eq!(array.len(), 2); } #[test] diff --git a/src/execute/mod.rs b/src/execute/mod.rs 
index 917ae429..d9215b41 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -388,11 +388,9 @@ fn get_unique_facet_values( facet_aesthetic: &str, layers: &[Layer], layers_missing_facet: &[bool], -) -> Option<polars::prelude::Series> { - use polars::prelude::*; - +) -> Option<arrow::array::ArrayRef> { let aes_col = naming::aesthetic_column(facet_aesthetic); - let mut all_values: Vec<Series> = Vec::new(); + let mut all_arrays: Vec<arrow::array::ArrayRef> = Vec::new(); for (idx, layer) in layers.iter().enumerate() { // Skip layers that are missing the facet column @@ -403,23 +401,32 @@ fn get_unique_facet_values( if let Some(ref data_key) = layer.data_key { if let Some(df) = data_map.get(data_key) { if let Ok(col) = df.column(&aes_col) { - all_values.push(col.as_materialized_series().clone()); + all_arrays.push(col.clone()); } } } } - if all_values.is_empty() { + if all_arrays.is_empty() { return None; } - // Concatenate all series and get unique values - let mut combined = all_values.remove(0); - for s in all_values { - let _ = combined.extend(&s); - } + // Concatenate all arrays + let refs: Vec<&dyn arrow::array::Array> = all_arrays.iter().map(|a| a.as_ref()).collect(); + let combined = arrow::compute::concat(&refs).ok()?; - combined.unique().ok() + // Get unique values by collecting strings into a set + use crate::array_util::value_to_string; + let mut seen = std::collections::HashSet::new(); + let mut unique_indices = Vec::new(); + for i in 0..combined.len() { + let key = value_to_string(&combined, i); + if seen.insert(key) { + unique_indices.push(i as u32); + } + } + let indices = arrow::array::UInt32Array::from(unique_indices); + arrow::compute::take(&*combined, &indices, None).ok() } /// Cross-join a DataFrame with facet values (duplicate for each facet panel). @@ -428,10 +435,10 @@ fn get_unique_facet_values( /// The facet column is added with the appropriate values. fn cross_join_with_facet_values( df: &DataFrame, - unique_values: &polars::prelude::Series, + unique_values: &arrow::array::ArrayRef, facet_aesthetic: &str, ) -> Result<DataFrame> { - use polars::prelude::*; + use arrow::array::{Array, UInt32Array}; let aes_col = naming::aesthetic_column(facet_aesthetic); let n_values = unique_values.len(); @@ -442,22 +449,21 @@ fn cross_join_with_facet_values( let n_rows = df.height(); - // Create the repeated data manually (polars cross_join requires an import we may not have) - // For each row in df, repeat n_values times - // For facet column, for each row's repetitions, cycle through unique_values - // 1.
Repeat each original column n_values times - let mut new_columns: Vec<Column> = Vec::new(); - for col in df.get_columns() { - // Repeat each value n_values times: [a, b, c] with n_values=2 -> [a, a, b, b, c, c] - let indices: Vec<u32> = (0..n_rows) - .flat_map(|i| std::iter::repeat_n(i as u32, n_values)) - .collect(); - let idx = IdxCa::new(PlSmallStr::EMPTY, &indices); - let repeated = col.as_materialized_series().take(&idx).map_err(|e| { - crate::GgsqlError::InternalError(format!("Failed to repeat column: {}", e)) + // [a, b, c] with n_values=2 -> [a, a, b, b, c, c] + let repeat_indices: Vec<u32> = (0..n_rows) + .flat_map(|i| std::iter::repeat_n(i as u32, n_values)) + .collect(); + let repeat_idx = UInt32Array::from(repeat_indices); + + let col_names = df.get_column_names(); + let mut new_columns: Vec<(&str, arrow::array::ArrayRef)> = Vec::new(); + for name in &col_names { + let col = df.column(name)?; + let repeated = arrow::compute::take(col.as_ref(), &repeat_idx, None).map_err(|e| { + GgsqlError::InternalError(format!("Failed to repeat column '{}': {}", name, e)) })?; - new_columns.push(repeated.into()); + new_columns.push((name, repeated)); } // 2. Create the facet column: tile unique_values for each row @@ -465,18 +471,12 @@ fn cross_join_with_facet_values( let facet_indices: Vec<u32> = (0..n_rows) .flat_map(|_| (0..n_values).map(|j| j as u32)) .collect(); - let facet_idx = IdxCa::new(PlSmallStr::EMPTY, &facet_indices); - let facet_col = unique_values - .take(&facet_idx) - .map_err(|e| { - crate::GgsqlError::InternalError(format!("Failed to create facet column: {}", e)) - })? - .with_name(aes_col.into()); - new_columns.push(facet_col.into()); - - DataFrame::new(new_columns).map_err(|e| { - crate::GgsqlError::InternalError(format!("Failed to create expanded DataFrame: {}", e)) - }) + let facet_idx = UInt32Array::from(facet_indices); + let facet_col = arrow::compute::take(unique_values.as_ref(), &facet_idx, None) + .map_err(|e| GgsqlError::InternalError(format!("Failed to create facet column: {}", e)))?; + new_columns.push((&aes_col, facet_col)); + + DataFrame::new(new_columns) } /// Handle layers missing the facet column based on facet.missing setting. @@ -859,24 +859,22 @@ fn prune_dataframe(df: &DataFrame, required: &HashSet<String>) -> Result<DataFrame> { if row_count > 0 { // Create a 0-column DataFrame with the correct row count // We do this by creating a dummy column and then dropping it - use polars::prelude::df; - let with_rows = df! { + let with_rows = crate::df! { "__dummy__" => vec![0i32; row_count] - } - .map_err(|e| GgsqlError::InternalError(format!("Failed to create DataFrame: {}", e)))?; - - let result = with_rows.drop("__dummy__").map_err(|e| { - GgsqlError::InternalError(format!("Failed to drop dummy column: {}", e)) - })?; - return Ok(result); + }?; + return with_rows.drop("__dummy__"); } else { - // 0 rows - just return empty DataFrame - return Ok(DataFrame::default()); + return Ok(DataFrame::empty()); } } - df.select(&columns_to_keep) - .map_err(|e| GgsqlError::InternalError(format!("Failed to prune columns: {}", e))) + // Keep only the columns in columns_to_keep + let drop_cols: Vec<String> = df + .get_column_names() + .into_iter() + .filter(|name| !columns_to_keep.contains(name)) + .collect(); + df.drop_many(&drop_cols) } /// Prune all DataFrames in the data map based on layer requirements.
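The dummy-column trick in `prune_dataframe` above is what keeps annotation layers working after all mapped columns are pruned away: a DataFrame with rows but no columns still reports its `height()`, and that is what determines how many marks get drawn. A minimal sketch of the pattern (it uses the crate's own `df!` macro and `DataFrame` wrapper exercised in the tests above; the crate path `ggsql` is an assumption):

```rust
use ggsql::df;

fn main() {
    // Build "3 rows x 0 columns": create a throwaway column, then drop it.
    let with_rows = df! {
        "__dummy__" => vec![0i32; 3]
    }
    .unwrap();
    let no_cols = with_rows.drop("__dummy__").unwrap();
    assert_eq!(no_cols.width(), 0); // no columns survive
    assert_eq!(no_cols.height(), 3); // row count is preserved for mark drawing
}
```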
@@ -1540,7 +1538,7 @@ mod tests { // Should have prefixed aesthetic-named columns (using internal names) let col_names: Vec<String> = layer_df - .get_column_names_str() + .get_column_names() .iter() .map(|s| s.to_string()) .collect(); @@ -1595,7 +1593,7 @@ mod tests { // With new approach, columns are renamed to prefixed aesthetic names (using internal names) let col_names: Vec<String> = layer_df - .get_column_names_str() + .get_column_names() .iter() .map(|s| s.to_string()) .collect(); @@ -1682,7 +1680,7 @@ mod tests { layer_df.column(&yend_col).is_ok(), "DataFrame should have '{}' column: {:?}", yend_col, - layer_df.get_column_names_str() + layer_df.get_column_names() ); } @@ -1820,7 +1818,7 @@ mod tests { // Should have prefixed aesthetic-named columns let col_names: Vec<String> = layer_df - .get_column_names_str() + .get_column_names() .iter() .map(|s| s.to_string()) .collect(); @@ -2028,7 +2026,7 @@ mod tests { layer_df.column(&facet_col).is_ok(), "Should have '{}' column: {:?}", facet_col, - layer_df.get_column_names_str() + layer_df.get_column_names() ); } @@ -2208,7 +2206,7 @@ mod tests { assert!( ref_df.column(&facet_col).is_ok(), "ref data should have facet column after broadcast: {:?}", - ref_df.get_column_names_str() + ref_df.get_column_names() ); } diff --git a/src/execute/position.rs b/src/execute/position.rs index 53cdd4f7..40c029b9 100644 --- a/src/execute/position.rs +++ b/src/execute/position.rs @@ -58,10 +58,12 @@ pub fn apply_position_adjustments( #[cfg(test)] mod tests { use super::*; + use crate::array_util::as_f64; + use crate::df; use crate::plot::facet::{Facet, FacetLayout}; use crate::plot::layer::{Geom, Position}; use crate::plot::{AestheticValue, Mappings, ParameterValue, Scale, ScaleType}; - use polars::prelude::*; + use arrow::array::Array; fn make_continuous_scale(aesthetic: &str) -> Scale { let mut scale = Scale::new(aesthetic); @@ -77,10 +79,10 @@ mod tests { fn make_test_df() -> DataFrame { df!
{ - "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], } .unwrap() } @@ -153,9 +155,25 @@ mod tests { let pos2_col = result_df.column("__ggsql_aes_pos2__").unwrap(); let pos2end_col = result_df.column("__ggsql_aes_pos2end__").unwrap(); - // Verify stacking was applied - assert!(pos2_col.f64().is_ok() || pos2_col.i64().is_ok()); - assert!(pos2end_col.f64().is_ok() || pos2end_col.i64().is_ok()); + // Verify stacking was applied (column should be numeric) + assert!( + matches!( + pos2_col.data_type(), + arrow::datatypes::DataType::Float64 + | arrow::datatypes::DataType::Int64 + | arrow::datatypes::DataType::Int32 + ), + "pos2 should be numeric" + ); + assert!( + matches!( + pos2end_col.data_type(), + arrow::datatypes::DataType::Float64 + | arrow::datatypes::DataType::Int64 + | arrow::datatypes::DataType::Int32 + ), + "pos2end should be numeric" + ); } #[test] @@ -184,14 +202,17 @@ mod tests { let offset_col = result_df.column("__ggsql_aes_pos1offset__"); assert!(offset_col.is_ok(), "pos1offset column should be created"); - let offset = offset_col.unwrap().f64().unwrap(); + let offset = as_f64(offset_col.unwrap()).unwrap(); // With 2 groups (X, Y) and default width 0.9: // - adjusted_width = 0.9 / 2 = 0.45 // - center_offset = 0.5 // - Group X: center = (0 - 0.5) * 0.45 = -0.225 // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 - let offsets: Vec = offset.into_iter().flatten().collect(); + let offsets: Vec = (0..offset.len()) + .filter(|&i| !offset.is_null(i)) + .map(|i| offset.value(i)) + .collect(); assert!( offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), "Should have offset -0.225 for group X, got {:?}", @@ -233,15 +254,15 @@ mod tests { apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); let result_df = data_map.get("__ggsql_layer_0__").unwrap(); - let offset = result_df - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); + let offset_col = result_df.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); // With 2 groups and custom width 0.6: // - adjusted_width = 0.6 / 2 = 0.3 - let offsets: Vec = offset.into_iter().flatten().collect(); + let offsets: Vec = (0..offset.len()) + .filter(|&i| !offset.is_null(i)) + .map(|i| offset.value(i)) + .collect(); assert!(offsets.iter().any(|&v| (v - (-0.15)).abs() < 0.001)); assert!(offsets.iter().any(|&v| (v - 0.15).abs() < 0.001)); @@ -275,8 +296,11 @@ mod tests { let offset_col = result_df.column("__ggsql_aes_pos1offset__"); assert!(offset_col.is_ok()); - let offset = offset_col.unwrap().f64().unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset = as_f64(offset_col.unwrap()).unwrap(); + let offsets: Vec = (0..offset.len()) + .filter(|&i| !offset.is_null(i)) + .map(|i| offset.value(i)) + .collect(); // With default width 0.9, offsets should be in range [-0.45, 0.45] for &v in &offsets { @@ -311,12 +335,12 @@ mod tests { apply_position_adjustments(&mut spec_with_layer, &mut data_map).unwrap(); let result_df = data_map.get("__ggsql_layer_0__").unwrap(); - let offset = result_df - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = 
offset.into_iter().flatten().collect(); + let offset_col = result_df.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec<f64> = (0..offset.len()) + .filter(|&i| !offset.is_null(i)) + .map(|i| offset.value(i)) + .collect(); // With custom width 0.6, offsets should be in range [-0.3, 0.3] for &v in &offsets { @@ -333,11 +357,11 @@ mod tests { // Two facet panels (F1, F2) each with the same x="A" and two // fill groups (X, Y). Stacking within each panel should start from 0. let df = df! { - "__ggsql_aes_pos1__" => ["A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 30.0, 40.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], - "__ggsql_aes_facet1__" => ["F1", "F1", "F2", "F2"], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 30.0, 40.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], + "__ggsql_aes_facet1__" => vec!["F1", "F1", "F2", "F2"], } .unwrap(); @@ -390,29 +414,27 @@ mod tests { let result_df = data_map.get("__ggsql_layer_0__").unwrap(); // Sort by facet then fill so we can assert in predictable order - let result_df = result_df - .clone() - .lazy() - .sort_by_exprs( - [col("__ggsql_aes_facet1__"), col("__ggsql_aes_fill__")], - SortMultipleOptions::default(), - ) - .collect() - .unwrap(); - - let pos2 = result_df - .column("__ggsql_aes_pos2__") - .unwrap() - .f64() - .unwrap(); - let pos2end = result_df - .column("__ggsql_aes_pos2end__") - .unwrap() - .f64() - .unwrap(); - - let pos2_vals: Vec<f64> = pos2.into_iter().flatten().collect(); - let pos2end_vals: Vec<f64> = pos2end.into_iter().flatten().collect(); + // Build sort indices based on (facet, fill) lexicographic order + let facet_col = + crate::array_util::as_str(result_df.column("__ggsql_aes_facet1__").unwrap()).unwrap(); + let fill_col = + crate::array_util::as_str(result_df.column("__ggsql_aes_fill__").unwrap()).unwrap(); + let mut indices: Vec<usize> = (0..result_df.height()).collect(); + indices.sort_by(|&a, &b| { + let fa = facet_col.value(a); + let fb = facet_col.value(b); + let cmp1 = fa.cmp(fb); + if cmp1 != std::cmp::Ordering::Equal { + return cmp1; + } + fill_col.value(a).cmp(fill_col.value(b)) + }); + + let pos2_arr = as_f64(result_df.column("__ggsql_aes_pos2__").unwrap()).unwrap(); + let pos2end_arr = as_f64(result_df.column("__ggsql_aes_pos2end__").unwrap()).unwrap(); + + let pos2_vals: Vec<f64> = indices.iter().map(|&i| pos2_arr.value(i)).collect(); + let pos2end_vals: Vec<f64> = indices.iter().map(|&i| pos2end_arr.value(i)).collect(); // Expected (sorted by facet, fill): // F1/X: pos2end=0, pos2=10 (first in panel, starts at 0) @@ -440,11 +462,11 @@ mod tests { // adjusted_width = 0.9 / 4 = 0.225 // offsets would be different (spread across 4 positions) let df = df!
{ - "__ggsql_aes_pos1__" => ["A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 30.0, 40.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], - "__ggsql_aes_facet1__" => ["F1", "F1", "F2", "F2"], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 30.0, 40.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], + "__ggsql_aes_facet1__" => vec!["F1", "F1", "F2", "F2"], } .unwrap(); diff --git a/src/execute/scale.rs b/src/execute/scale.rs index a4bef2fb..5a531f8d 100644 --- a/src/execute/scale.rs +++ b/src/execute/scale.rs @@ -15,7 +15,7 @@ use crate::plot::{ ScaleType, ScaleTypeKind, Schema, }; use crate::{DataFrame, GgsqlError, Result}; -use polars::prelude::Column; +use arrow::array::ArrayRef; use std::collections::{HashMap, HashSet}; use super::schema::TypeInfo; @@ -121,7 +121,7 @@ pub fn create_missing_scales_post_stat( ); if !column_refs.is_empty() { scale.scale_type = Some(ScaleType::infer_for_aesthetic( - column_refs[0].dtype(), + column_refs[0].data_type(), &scale.aesthetic, )); } @@ -219,63 +219,59 @@ pub fn apply_binning_to_dataframe( break_values: &[f64], closed_left: bool, ) -> Result { - use polars::prelude::*; + use crate::array_util::{as_f64, cast_array, new_f64_array}; + use arrow::array::Array; + use arrow::datatypes::DataType; - let column = df.column(col_name).map_err(|e| { - GgsqlError::InternalError(format!("Column '{}' not found: {}", col_name, e)) - })?; - - let series = column.as_materialized_series(); + let column = df.column(col_name)?; // Cast to f64 for binning - let float_series = series.cast(&DataType::Float64).map_err(|e| { + let float_col = cast_array(column, &DataType::Float64).map_err(|e| { GgsqlError::InternalError(format!("Cannot bin column '{}': {}", col_name, e)) })?; - let ca = float_series - .f64() - .map_err(|e| GgsqlError::InternalError(e.to_string()))?; + let f64_arr = as_f64(&float_col)?; // Apply binning: replace values with bin centers let num_bins = break_values.len() - 1; - let binned: Float64Chunked = ca.apply_values(|val| { - for i in 0..num_bins { - let lower = break_values[i]; - let upper = break_values[i + 1]; - let is_last = i == num_bins - 1; - - let in_bin = if closed_left { - // Left-closed: [lower, upper) except last bin is [lower, upper] - if is_last { - val >= lower && val <= upper - } else { - val >= lower && val < upper - } - } else { - // Right-closed: (lower, upper] except first bin is [lower, upper] - if i == 0 { - val >= lower && val <= upper + let binned: Vec> = (0..f64_arr.len()) + .map(|idx| { + if f64_arr.is_null(idx) { + return None; + } + let val = f64_arr.value(idx); + for i in 0..num_bins { + let lower = break_values[i]; + let upper = break_values[i + 1]; + let is_last = i == num_bins - 1; + + let in_bin = if closed_left { + if is_last { + val >= lower && val <= upper + } else { + val >= lower && val < upper + } } else { - val > lower && val <= upper - } - }; + if i == 0 { + val >= lower && val <= upper + } else { + val > lower && val <= upper + } + }; - if in_bin { - return (lower + upper) / 2.0; + if in_bin { + return Some((lower + upper) / 2.0); + } } - } - f64::NAN // Outside all bins - }); + Some(f64::NAN) // Outside all bins + }) + .collect(); - let binned_series = binned.into_series().with_name(col_name.into()); + let binned_array = new_f64_array(binned); // Replace column in DataFrame - let mut new_df = df.clone(); - let _ = new_df - .replace(col_name, 
binned_series) - .map_err(|e| GgsqlError::InternalError(format!("Failed to replace column: {}", e)))?; - - Ok(new_df) + df.with_column(col_name, binned_array) + .map_err(|e| GgsqlError::InternalError(format!("Failed to replace column: {}", e))) } // ============================================================================= @@ -451,7 +447,7 @@ pub fn collect_dtypes_for_aesthetic( aesthetic: &str, layer_type_info: &[Vec<TypeInfo>], aesthetic_ctx: &AestheticContext, -) -> Vec<polars::prelude::DataType> { +) -> Vec<arrow::datatypes::DataType> { let mut dtypes = Vec::new(); let aesthetics_to_check = aesthetic_ctx .internal_position_family(aesthetic) @@ -582,7 +578,7 @@ pub fn find_schema_columns_for_aesthetic( /// /// Used to include literal mappings in scale resolution. pub fn column_info_from_literal(aesthetic: &str, lit: &ParameterValue) -> Option<ColumnInfo> { - use polars::prelude::DataType; + use arrow::datatypes::DataType; match lit { ParameterValue::Number(n) => Some(ColumnInfo { @@ -594,7 +590,7 @@ pub fn column_info_from_literal(aesthetic: &str, lit: &ParameterValue) -> Option<ColumnInfo> }), ParameterValue::String(s) => Some(ColumnInfo { name: naming::const_column(aesthetic), - dtype: DataType::String, + dtype: DataType::Utf8, is_discrete: true, min: Some(ArrayElement::String(s.clone())), max: Some(ArrayElement::String(s.clone())), @@ -621,14 +617,12 @@ pub fn coerce_column_to_type( column_name: &str, target_type: ArrayElementType, ) -> Result<DataFrame> { - use polars::prelude::{DataType, NamedFrom, Series, TimeUnit}; - - let column = df.column(column_name).map_err(|e| { - GgsqlError::ValidationError(format!("Column '{}' not found: {}", column_name, e)) - })?; + use crate::array_util::*; + use arrow::array::Array; + use arrow::datatypes::{DataType, TimeUnit}; - let series = column.as_materialized_series(); - let dtype = series.dtype(); + let column = df.column(column_name)?; + let dtype = column.data_type(); // Check if already the target type let already_target_type = matches!( @@ -638,10 +632,10 @@ pub fn coerce_column_to_type( DataType::Float64 | DataType::Int64 | DataType::Int32 | DataType::Float32, ArrayElementType::Number, ) - | (DataType::Date, ArrayElementType::Date) - | (DataType::Datetime(_, _), ArrayElementType::DateTime) - | (DataType::Time, ArrayElementType::Time) - | (DataType::String, ArrayElementType::String) + | (DataType::Date32, ArrayElementType::Date) + | (DataType::Timestamp(_, _), ArrayElementType::DateTime) + | (DataType::Time64(_), ArrayElementType::Time) + | (DataType::Utf8, ArrayElementType::String) ); if already_target_type { @@ -649,139 +643,105 @@ pub fn coerce_column_to_type( } // Coerce based on target type - let new_series: Series = match target_type { - ArrayElementType::Boolean => { - // Convert to boolean - match dtype { - DataType::String => { - let str_series = series.str().map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot convert column '{}' to string for boolean coercion: {}", - column_name, e - )) - })?; - - let bool_vec: Vec<Option<bool>> = str_series - .into_iter() - .enumerate() - .map(|(idx, opt_s)| match opt_s { - None => Ok(None), - Some(s) => match s.to_lowercase().as_str() { + let new_array: arrow::array::ArrayRef = match target_type { + ArrayElementType::Boolean => match dtype { + DataType::Utf8 => { + let str_arr = as_str(column)?; + let bool_vec: Vec<Option<bool>> = (0..str_arr.len()) + .enumerate() + .map(|(idx, i)| { + if str_arr.is_null(i) { + Ok(None) + } else { + match str_arr.value(i).to_lowercase().as_str() { "true" | "yes" | "1" => Ok(Some(true)), "false" | "no" | "0" => Ok(Some(false)), - _ =>
Err(GgsqlError::ValidationError(format!( + s => Err(GgsqlError::ValidationError(format!( "Column '{}' row {}: Cannot coerce string '{}' to boolean", column_name, idx, s ))), - }, - }) - .collect::<Result<Vec<_>>>()?; - - Series::new(column_name.into(), bool_vec) - } - DataType::Int64 | DataType::Int32 | DataType::Float64 | DataType::Float32 => { - let f64_series = series.cast(&DataType::Float64).map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot cast column '{}' to float64: {}", - column_name, e - )) - })?; - let ca = f64_series.f64().map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot get float64 chunked array: {}", - e - )) - })?; - let bool_vec: Vec<Option<bool>> = - ca.into_iter().map(|opt| opt.map(|n| n != 0.0)).collect(); - Series::new(column_name.into(), bool_vec) - } - _ => { - return Err(GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' of type {:?} to boolean", - column_name, dtype - ))); - } + } + } + }) + .collect::<Result<Vec<_>>>()?; + new_bool_array(bool_vec) } - } - - ArrayElementType::Number => { - // Convert to float64 - series.cast(&DataType::Float64).map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' to number: {}", - column_name, e - )) - })? - } - - ArrayElementType::Date => { - // Convert to date (from string) - match dtype { - DataType::String => { - let str_series = series.str().map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot convert column '{}' to string for date coercion: {}", - column_name, e - )) - })?; - - let date_vec: Vec<Option<i32>> = str_series - .into_iter() + DataType::Int64 | DataType::Int32 | DataType::Float64 | DataType::Float32 => { + let f64_col = cast_array(column, &DataType::Float64)?; + let f64_arr = as_f64(&f64_col)?; + let bool_vec: Vec<Option<bool>> = (0..f64_arr.len()) + .map(|i| { + if f64_arr.is_null(i) { + None + } else { + Some(f64_arr.value(i) != 0.0) + } + }) + .collect(); + new_bool_array(bool_vec) + } + _ => { + return Err(GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' of type {:?} to boolean", + column_name, dtype + ))); + } + }, + + ArrayElementType::Number => cast_array(column, &DataType::Float64).map_err(|e| { + GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' to number: {}", + column_name, e + )) + })?, + + ArrayElementType::Date => match dtype { + DataType::Utf8 => { + let str_arr = as_str(column)?; + let date_vec: Vec<Option<i32>> = (0..str_arr.len()) .enumerate() - .map(|(idx, opt_s)| { - match opt_s { - None => Ok(None), - Some(s) => { - ArrayElement::from_date_string(s) - .and_then(|e| match e { - ArrayElement::Date(d) => Some(d), - _ => None, - }) - .ok_or_else(|| { - GgsqlError::ValidationError(format!( - "Column '{}' row {}: Cannot coerce string '{}' to date (expected YYYY-MM-DD)", - column_name, idx, s - )) - }) - .map(Some) - } + .map(|(idx, i)| { + if str_arr.is_null(i) { + Ok(None) + } else { + let s = str_arr.value(i); + ArrayElement::from_date_string(s) + .and_then(|e| match e { + ArrayElement::Date(d) => Some(d), + _ => None, + }) + .ok_or_else(|| { + GgsqlError::ValidationError(format!( + "Column '{}' row {}: Cannot coerce string '{}' to date (expected YYYY-MM-DD)", + column_name, idx, s + )) + }) + .map(Some) } }) .collect::<Result<Vec<_>>>()?; - - Series::new(column_name.into(), date_vec) - .cast(&DataType::Date) - .map_err(|e| { - GgsqlError::ValidationError(format!("Cannot create date series: {}", e)) - })?
- } - _ => { - return Err(GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' of type {:?} to date", - column_name, dtype - ))); - } + let i32_arr = new_i32_array(date_vec); + cast_array(&i32_arr, &DataType::Date32)? } - } - - ArrayElementType::DateTime => { - // Convert to datetime (from string) - match dtype { - DataType::String => { - let str_series = series.str().map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot convert column '{}' to string for datetime coercion: {}", - column_name, e - )) - })?; - - let dt_vec: Vec<Option<i64>> = str_series - .into_iter() - .enumerate() - .map(|(idx, opt_s)| match opt_s { - None => Ok(None), - Some(s) => ArrayElement::from_datetime_string(s) + _ => { + return Err(GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' of type {:?} to date", + column_name, dtype + ))); + } + }, + + ArrayElementType::DateTime => match dtype { + DataType::Utf8 => { + let str_arr = as_str(column)?; + let dt_vec: Vec<Option<i64>> = (0..str_arr.len()) + .enumerate() + .map(|(idx, i)| { + if str_arr.is_null(i) { + Ok(None) + } else { + let s = str_arr.value(i); + ArrayElement::from_datetime_string(s) .and_then(|e| match e { ArrayElement::DateTime(dt) => Some(dt), _ => None, @@ -792,95 +752,68 @@ pub fn coerce_column_to_type( column_name, idx, s )) }) - .map(Some), - }) - .collect::<Result<Vec<_>>>()?; - - Series::new(column_name.into(), dt_vec) - .cast(&DataType::Datetime(TimeUnit::Microseconds, None)) - .map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot create datetime series: {}", - e - )) - })? - } - _ => { - return Err(GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' of type {:?} to datetime", - column_name, dtype - ))); - } + .map(Some) + } + }) + .collect::<Result<Vec<_>>>()?; + let i64_arr = new_i64_array(dt_vec); + cast_array(&i64_arr, &DataType::Timestamp(TimeUnit::Microsecond, None))? } - } - - ArrayElementType::Time => { - // Convert to time (from string) - match dtype { - DataType::String => { - let str_series = series.str().map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot convert column '{}' to string for time coercion: {}", - column_name, e - )) - })?; + _ => { + return Err(GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' of type {:?} to datetime", + column_name, dtype + ))); + } + }, - let time_vec: Vec<Option<i64>> = str_series - .into_iter() + ArrayElementType::Time => match dtype { + DataType::Utf8 => { + let str_arr = as_str(column)?; + let time_vec: Vec<Option<i64>> = (0..str_arr.len()) .enumerate() - .map(|(idx, opt_s)| { - match opt_s { - None => Ok(None), - Some(s) => { - ArrayElement::from_time_string(s) - .and_then(|e| match e { - ArrayElement::Time(t) => Some(t), - _ => None, - }) - .ok_or_else(|| { - GgsqlError::ValidationError(format!( - "Column '{}' row {}: Cannot coerce string '{}' to time (expected HH:MM:SS)", - column_name, idx, s - )) - }) - .map(Some) - } + .map(|(idx, i)| { + if str_arr.is_null(i) { + Ok(None) + } else { + let s = str_arr.value(i); + ArrayElement::from_time_string(s) + .and_then(|e| match e { + ArrayElement::Time(t) => Some(t), + _ => None, + }) + .ok_or_else(|| { + GgsqlError::ValidationError(format!( + "Column '{}' row {}: Cannot coerce string '{}' to time (expected HH:MM:SS)", + column_name, idx, s + )) + }) + .map(Some) } }) .collect::<Result<Vec<_>>>()?; - - Series::new(column_name.into(), time_vec) - .cast(&DataType::Time) - .map_err(|e| { - GgsqlError::ValidationError(format!("Cannot create time series: {}", e)) - })?
- } - _ => { - return Err(GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' of type {:?} to time", - column_name, dtype - ))); - } + let i64_arr = new_i64_array(time_vec); + cast_array(&i64_arr, &DataType::Time64(TimeUnit::Nanosecond))? } - } - - ArrayElementType::String => { - // Convert to string - series - .cast(&polars::prelude::DataType::String) - .map_err(|e| { - GgsqlError::ValidationError(format!( - "Cannot coerce column '{}' to string: {}", - column_name, e - )) - })? - } + _ => { + return Err(GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' of type {:?} to time", + column_name, dtype + ))); + } + }, + + ArrayElementType::String => cast_array(column, &DataType::Utf8).map_err(|e| { + GgsqlError::ValidationError(format!( + "Cannot coerce column '{}' to string: {}", + column_name, e + )) + })?, }; // Replace the column in the DataFrame - let mut new_df = df.clone(); - let _ = new_df.replace(column_name, new_series); - Ok(new_df) + df.with_column(column_name, new_array) + .map_err(|e| GgsqlError::ValidationError(format!("Failed to replace column: {}", e))) } /// Coerce columns mapped to an aesthetic in all relevant DataFrames. @@ -1015,7 +948,7 @@ pub fn resolve_scales(spec: &mut Plot, data_map: &mut HashMap<String, DataFrame>) // by create_missing_scales_post_stat() which runs before position adjustments) if spec.scales[idx].scale_type.is_none() { spec.scales[idx].scale_type = Some(ScaleType::infer_for_aesthetic( - column_refs[0].dtype(), + column_refs[0].data_type(), &aesthetic, )); } @@ -1055,7 +988,7 @@ pub fn find_columns_for_aesthetic<'a>( aesthetic: &str, data_map: &'a HashMap<String, DataFrame>, aesthetic_ctx: &AestheticContext, -) -> Vec<&'a Column> { +) -> Vec<&'a ArrayRef> { let mut column_refs = Vec::new(); let aesthetics_to_check = aesthetic_ctx .internal_position_family(aesthetic) @@ -1264,71 +1197,78 @@ pub fn apply_oob_to_column_numeric( range_max: f64, oob_mode: &str, ) -> Result<DataFrame> { - use polars::prelude::*; + use crate::array_util::*; + use arrow::array::Array; + use arrow::datatypes::DataType; - let col = df.column(col_name).map_err(|e| { - GgsqlError::ValidationError(format!("Column '{}' not found: {}", col_name, e)) - })?; + let col = df.column(col_name)?; // Try to cast column to f64 for comparison - let series = col.as_materialized_series(); - let f64_col = series.cast(&DataType::Float64).map_err(|_| { + let f64_col = cast_array(col, &DataType::Float64).map_err(|_| { GgsqlError::ValidationError(format!( "Cannot apply oob to non-numeric column '{}'", col_name )) })?; - let f64_ca = f64_col.f64().map_err(|_| { - GgsqlError::ValidationError(format!( - "Cannot apply oob to non-numeric column '{}'", - col_name - )) - })?; + let f64_arr = as_f64(&f64_col)?; match oob_mode { OOB_CENSOR => { // Filter out rows where values are outside [range_min, range_max] - let mask: BooleanChunked = f64_ca - .into_iter() - .map(|opt| opt.is_none_or(|v| v >= range_min && v <= range_max)) + // Build a boolean mask + let mask_values: Vec<bool> = (0..f64_arr.len()) + .map(|i| { + if f64_arr.is_null(i) { + true // Keep nulls + } else { + let v = f64_arr.value(i); + v >= range_min && v <= range_max + } + }) .collect(); - let result = df.filter(&mask).map_err(|e| { - GgsqlError::InternalError(format!("Failed to filter DataFrame: {}", e)) - })?; - Ok(result) + // Filter all columns using the mask + let mask = arrow::array::BooleanArray::from(mask_values); + let mut new_columns = Vec::new(); + let schema = df.schema(); + for (i, field) in schema.fields().iter().enumerate() { + let col_arr =
df.get_columns()[i].clone(); + let filtered = arrow::compute::filter(&col_arr, &mask) + .map_err(|e| GgsqlError::InternalError(format!("Failed to filter: {}", e)))?; + new_columns.push((field.name().as_str(), filtered)); + } + DataFrame::new(new_columns) } OOB_SQUISH => { // Clamp values to [range_min, range_max] - let clamped: Float64Chunked = f64_ca - .into_iter() - .map(|opt| opt.map(|v| v.clamp(range_min, range_max))) + let clamped: Vec<Option<f64>> = (0..f64_arr.len()) + .map(|i| { + if f64_arr.is_null(i) { + None + } else { + Some(f64_arr.value(i).clamp(range_min, range_max)) + } + }) .collect(); - // Restore temporal type if original column was temporal - // This ensures Date/DateTime/Time values serialize to ISO strings in JSON - let original_dtype = series.dtype().clone(); - let clamped_series = clamped.into_series(); + let clamped_array = new_f64_array(clamped); - let restored_series = match &original_dtype { - DataType::Date | DataType::Datetime(_, _) | DataType::Time => { - clamped_series.cast(&original_dtype).map_err(|e| { + // Restore temporal type if original column was temporal + let original_dtype = col.data_type().clone(); + let restored_array = match &original_dtype { + DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => { + cast_array(&clamped_array, &original_dtype).map_err(|e| { GgsqlError::InternalError(format!( "Failed to restore temporal type for '{}': {}", col_name, e )) })? } - _ => clamped_series, + _ => clamped_array, }; - // Replace column with clamped values, maintaining original name - let named_series = restored_series.with_name(col_name.into()); - - df.clone() - .with_column(named_series) - .map(|df| df.clone()) + df.with_column(col_name, restored_array) .map_err(|e| GgsqlError::InternalError(format!("Failed to replace column: {}", e))) } _ => Ok(df.clone()), @@ -1339,13 +1279,23 @@ pub fn apply_oob_to_column_numeric( /// /// Used after OOB transformations to remove rows that were censored to NULL. pub fn filter_null_rows(df: &DataFrame, col_name: &str) -> Result<DataFrame> { - let col = df.column(col_name).map_err(|e| { - GgsqlError::ValidationError(format!("Column '{}' not found: {}", col_name, e)) - })?; + use arrow::array::Array; + + let col = df.column(col_name)?; - let mask = col.is_not_null(); - df.filter(&mask) - .map_err(|e| GgsqlError::InternalError(format!("Failed to filter NULL rows: {}", e))) + // Build boolean mask: true where NOT null + let mask_values: Vec<bool> = (0..col.len()).map(|i| !col.is_null(i)).collect(); + let mask = arrow::array::BooleanArray::from(mask_values); + + let mut new_columns = Vec::new(); + let schema = df.schema(); + for (i, field) in schema.fields().iter().enumerate() { + let col_arr = df.get_columns()[i].clone(); + let filtered = arrow::compute::filter(&col_arr, &mask) + .map_err(|e| GgsqlError::InternalError(format!("Failed to filter NULL rows: {}", e)))?; + new_columns.push((field.name().as_str(), filtered)); + } + DataFrame::new(new_columns) } /// Apply oob transformation to a single discrete/categorical column in a DataFrame.
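Both `OOB_CENSOR` and `filter_null_rows` above use the same pattern: build one `BooleanArray` mask from the target column, then run every column through `arrow::compute::filter` so row alignment is preserved across the whole frame. A self-contained sketch of that pattern against the arrow crate directly (the example values are illustrative):

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array};

fn main() {
    // One null plus one out-of-range value; censoring keeps nulls.
    let col: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.0), None, Some(99.0)]));
    let vals = col.as_any().downcast_ref::<Float64Array>().unwrap();

    // Mask semantics mirror OOB_CENSOR above: keep nulls, keep values in [0, 10].
    let mask_values: Vec<bool> = (0..col.len())
        .map(|i| col.is_null(i) || (0.0..=10.0).contains(&vals.value(i)))
        .collect();
    let mask = BooleanArray::from(mask_values);

    // Applying the same mask to every column keeps rows aligned.
    let kept = arrow::compute::filter(col.as_ref(), &mask).unwrap();
    assert_eq!(kept.len(), 2); // 1.0 and the null survive; 99.0 is censored
}
```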
@@ -1358,51 +1308,37 @@ pub fn apply_oob_to_column_discrete( allowed_values: &HashSet<String>, oob_mode: &str, ) -> Result<DataFrame> { - use polars::prelude::*; + use crate::array_util::*; + use arrow::array::Array; // For discrete columns, only censor makes sense (squish is validated out earlier) if oob_mode != OOB_CENSOR { return Ok(df.clone()); } - let col = df.column(col_name).map_err(|e| { - GgsqlError::ValidationError(format!("Column '{}' not found: {}", col_name, e)) - })?; - - let series = col.as_materialized_series(); + let col = df.column(col_name)?; - // Build new series: keep allowed values, set others to null - // This preserves all rows (unlike filtering) so other aesthetics can still be visualized - let new_ca: StringChunked = (0..series.len()) + // Build new string array: keep allowed values, set others to null + let new_values: Vec<Option<String>> = (0..col.len()) .map(|i| { - match series.get(i) { - Ok(val) => { - // Null values are kept as null - if val.is_null() { - return None; - } - // Convert value to string and check membership - let s = val.to_string(); - // Remove quotes if present (polars adds quotes around strings) - let clean = s.trim_matches('"').to_string(); - if allowed_values.contains(&clean) { - Some(clean) - } else { - None // CENSOR to null (not filter row!) - } + if col.is_null(i) { + None + } else { + let s = value_to_string(col, i); + if allowed_values.contains(&s) { + Some(s) + } else { + None // CENSOR to null (not filter row!) } - Err(_) => None, } }) .collect(); - // Replace column (keep all rows) - let new_series = new_ca.into_series().with_name(col_name.into()); - let mut result = df.clone(); - result - .with_column(new_series) - .map_err(|e| GgsqlError::InternalError(format!("Failed to replace column: {}", e)))?; - Ok(result) + let refs: Vec<Option<&str>> = new_values.iter().map(|o| o.as_deref()).collect(); + let new_array = new_str_array(refs); + + df.with_column(col_name, new_array) + .map_err(|e| GgsqlError::InternalError(format!("Failed to replace column: {}", e))) } #[cfg(test)] @@ -1410,7 +1346,7 @@ mod tests { use super::*; use crate::plot::ArrayElement; use crate::Geom; - use polars::prelude::DataType; + use arrow::datatypes::DataType; #[test] fn test_aesthetic_context_internal_family() { @@ -1451,24 +1387,27 @@ mod tests { assert_eq!(ScaleType::infer(&DataType::UInt16), ScaleType::continuous()); // Temporal types now use Continuous scale (with temporal transforms) - assert_eq!(ScaleType::infer(&DataType::Date), ScaleType::continuous()); + assert_eq!(ScaleType::infer(&DataType::Date32), ScaleType::continuous()); assert_eq!( - ScaleType::infer(&DataType::Datetime( - polars::prelude::TimeUnit::Microseconds, + ScaleType::infer(&DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, None )), ScaleType::continuous() ); - assert_eq!(ScaleType::infer(&DataType::Time), ScaleType::continuous()); + assert_eq!( + ScaleType::infer(&DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond)), + ScaleType::continuous() + ); // Test discrete types - assert_eq!(ScaleType::infer(&DataType::String), ScaleType::discrete()); + assert_eq!(ScaleType::infer(&DataType::Utf8), ScaleType::discrete()); assert_eq!(ScaleType::infer(&DataType::Boolean), ScaleType::discrete()); } #[test] fn test_resolve_scales_infers_input_range() { - use polars::prelude::*; + use crate::df; // Create a Plot with a scale that needs range inference let mut spec = Plot::new(); @@ -1487,7 +1426,7 @@ // Create data with numeric values let df = df!
{ - "value" => &[1.0f64, 5.0, 10.0] + "value" => vec![1.0f64, 5.0, 10.0] } .unwrap(); @@ -1515,7 +1454,7 @@ mod tests { #[test] fn test_resolve_scales_preserves_explicit_input_range() { - use polars::prelude::*; + use crate::df; // Create a Plot with a scale that already has a range let mut spec = Plot::new(); @@ -1535,7 +1474,7 @@ mod tests { // Create data with different values let df = df! { - "value" => &[1.0f64, 5.0, 10.0] + "value" => vec![1.0f64, 5.0, 10.0] } .unwrap(); @@ -1559,7 +1498,7 @@ mod tests { #[test] fn test_resolve_scales_from_aesthetic_family_input_range() { - use polars::prelude::*; + use crate::df; // Create a Plot where "pos2" scale should get range from pos2min and pos2max columns let mut spec = Plot::new(); @@ -1580,8 +1519,8 @@ mod tests { // Create data where pos2min/pos2max columns have different ranges let df = df! { - "low" => &[5.0f64, 10.0, 15.0], - "high" => &[20.0f64, 25.0, 30.0] + "low" => vec![5.0f64, 10.0, 15.0], + "high" => vec![20.0f64, 25.0, 30.0] } .unwrap(); @@ -1609,7 +1548,7 @@ mod tests { #[test] fn test_resolve_scales_partial_input_range_explicit_min_null_max() { - use polars::prelude::*; + use crate::df; // Create a Plot with a scale that has [0, null] (explicit min, infer max) let mut spec = Plot::new(); @@ -1629,7 +1568,7 @@ mod tests { // Create data with values 1-10 let df = df! { - "value" => &[1.0f64, 5.0, 10.0] + "value" => vec![1.0f64, 5.0, 10.0] } .unwrap(); @@ -1653,7 +1592,7 @@ mod tests { #[test] fn test_resolve_scales_partial_input_range_null_min_explicit_max() { - use polars::prelude::*; + use crate::df; // Create a Plot with a scale that has [null, 100] (infer min, explicit max) let mut spec = Plot::new(); @@ -1673,7 +1612,7 @@ mod tests { // Create data with values 1-10 let df = df! { - "value" => &[1.0f64, 5.0, 10.0] + "value" => vec![1.0f64, 5.0, 10.0] } .unwrap(); @@ -1697,8 +1636,8 @@ mod tests { #[test] fn test_resolve_scales_polar_theta_no_expansion() { + use crate::df; use crate::plot::projection::{Coord, Projection}; - use polars::prelude::*; // Create a Plot with a polar projection let mut spec = Plot::new(); @@ -1725,7 +1664,7 @@ mod tests { // Create data with numeric values let df = df! { - "value" => &[10.0f64, 20.0, 30.0] + "value" => vec![10.0f64, 20.0, 30.0] } .unwrap(); @@ -1759,4 +1698,37 @@ mod tests { _ => panic!("Expected Number elements"), } } + + #[test] + fn test_apply_oob_censor_date32() { + // Regression: Arrow can't cast Date32 directly to Float64. + // apply_oob_to_column_numeric must route through Int32 first. 
+ use arrow::array::{ArrayRef, Date32Array}; + use std::sync::Arc; + + // Days since epoch: 2024-01-01 = 19723, 2024-06-01 = 19875, 2024-12-01 = 20058 + let dates: ArrayRef = Arc::new(Date32Array::from(vec![19723, 19875, 20058])); + let df = DataFrame::new(vec![("date", dates)]).unwrap(); + + // Censor to [2024-03-01, 2024-09-01] ≈ [19783, 19967] → keeps only row 1 + let result = apply_oob_to_column_numeric(&df, "date", 19783.0, 19967.0, OOB_CENSOR) + .expect("oob censor should handle Date32"); + assert_eq!(result.height(), 1); + } + + #[test] + fn test_apply_oob_squish_date32_restores_temporal_type() { + use arrow::array::{ArrayRef, Date32Array}; + use std::sync::Arc; + + let dates: ArrayRef = Arc::new(Date32Array::from(vec![19000, 19875, 21000])); + let df = DataFrame::new(vec![("date", dates)]).unwrap(); + + let result = apply_oob_to_column_numeric(&df, "date", 19723.0, 20089.0, OOB_SQUISH) + .expect("oob squish should handle Date32"); + assert_eq!( + result.column("date").unwrap().data_type(), + &DataType::Date32 + ); + } } diff --git a/src/execute/schema.rs b/src/execute/schema.rs index 3df3c55f..138a0176 100644 --- a/src/execute/schema.rs +++ b/src/execute/schema.rs @@ -6,9 +6,11 @@ //! 2. Apply casting to queries //! 3. complete_schema_ranges() - get min/max from cast queries +use crate::array_util::*; use crate::plot::{AestheticValue, ArrayElement, ColumnInfo, Layer, ParameterValue, Schema}; use crate::{naming, DataFrame, Result}; -use polars::prelude::{DataType, TimeUnit}; +use arrow::array::Array; +use arrow::datatypes::{DataType, TimeUnit}; /// Simple type info tuple: (name, dtype, is_discrete) pub type TypeInfo = (String, DataType, bool); @@ -45,7 +47,7 @@ pub fn build_minmax_query(source_query: &str, column_names: &[&str]) -> String { /// Extract a value from a DataFrame at a given column and row index /// -/// Converts Polars values to ArrayElement for storage in ColumnInfo. +/// Converts Arrow array values to ArrayElement for storage in ColumnInfo. 
pub fn extract_series_value( df: &DataFrame, column: &str, @@ -54,96 +56,67 @@ pub fn extract_series_value( use crate::plot::ArrayElement; let col = df.column(column).ok()?; - let series = col.as_materialized_series(); - if row >= series.len() { + if row >= col.len() { return None; } - match series.dtype() { - DataType::Int8 => series - .i8() - .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::Int16 => series - .i16() + if col.is_null(row) { + return None; + } + + match col.data_type() { + DataType::Int8 => as_i8(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::Int32 => series - .i32() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::Int16 => as_i16(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::Int64 => series - .i64() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::Int32 => as_i32(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::UInt8 => series - .u8() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::Int64 => as_i64(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::UInt16 => series - .u16() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::UInt8 => as_u8(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::UInt32 => series - .u32() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::UInt16 => as_u16(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::UInt64 => series - .u64() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::UInt32 => as_u32(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::Float32 => series - .f32() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::UInt64 => as_u64(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|v| ArrayElement::Number(v as f64)), - DataType::Float64 => series - .f64() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::Float32 => as_f32(col) .ok() - .and_then(|ca| ca.get(row)) - .map(ArrayElement::Number), - DataType::Boolean => series - .bool() + .map(|a| ArrayElement::Number(a.value(row) as f64)), + DataType::Float64 => as_f64(col).ok().map(|a| ArrayElement::Number(a.value(row))), + DataType::Boolean => as_bool(col) .ok() - .and_then(|ca| ca.get(row)) - .map(ArrayElement::Boolean), - DataType::String => series - .str() + .map(|a| ArrayElement::Boolean(a.value(row))), + DataType::Utf8 => as_str(col) .ok() - .and_then(|ca| ca.get(row)) - .map(|s| ArrayElement::String(s.to_string())), - DataType::Date => { + .map(|a| ArrayElement::String(a.value(row).to_string())), + DataType::Date32 => { // Return numeric days since epoch (for range computation) - series - .date() + as_date32(col) .ok() - .and_then(|ca| ca.physical().get(row)) - .map(|days| ArrayElement::Number(days as f64)) + .map(|a| ArrayElement::Number(a.value(row) as f64)) } - DataType::Datetime(_, _) => { + DataType::Timestamp(_, _) => { // Return numeric microseconds since epoch (for range computation) - series - .datetime() + as_timestamp_us(col) .ok() - .and_then(|ca| ca.physical().get(row)) - .map(|us| ArrayElement::Number(us as f64)) + .map(|a| ArrayElement::Number(a.value(row) as f64)) } - DataType::Time => { + DataType::Time64(_) => { // Return numeric nanoseconds since midnight (for range 
computation) - series - .time() + as_time64_ns(col) .ok() - .and_then(|ca| ca.physical().get(row)) - .map(|ns| ArrayElement::Number(ns as f64)) + .map(|a| ArrayElement::Number(a.value(row) as f64)) } _ => None, } } @@ -169,16 +142,16 @@ where ); let schema_df = execute_query(&schema_query)?; - let type_info: Vec<TypeInfo> = schema_df - .get_columns() + let schema = schema_df.schema(); + let type_info: Vec<TypeInfo> = schema + .fields() .iter() - .map(|col| { - let dtype = col.dtype().clone(); - let is_discrete = - matches!(dtype, DataType::String | DataType::Boolean) || dtype.is_categorical(); - (col.name().to_string(), is_discrete, dtype) + .map(|field| { + let dtype = field.data_type().clone(); + let is_discrete = matches!(dtype, DataType::Utf8 | DataType::Boolean) + || matches!(dtype, DataType::Dictionary(_, _)); + (field.name().clone(), dtype, is_discrete) }) - .map(|(name, is_discrete, dtype)| (name, dtype, is_discrete)) .collect(); Ok(type_info) } @@ -253,21 +226,23 @@ pub fn add_literal_columns_to_type_info(layers: &[Layer], layer_type_info: &mut [Vec<TypeInfo>]) for (aesthetic, value) in &layer.mappings.aesthetics { if let AestheticValue::Literal(lit) = value { let (dtype, is_discrete) = match lit { - ParameterValue::String(_) => (DataType::String, true), + ParameterValue::String(_) => (DataType::Utf8, true), ParameterValue::Number(_) => (DataType::Float64, false), ParameterValue::Boolean(_) => (DataType::Boolean, true), ParameterValue::Array(arr) => { // Infer dtype from first element (arrays are homogeneous) if let Some(first) = arr.first() { match first { - ArrayElement::String(_) => (DataType::String, true), + ArrayElement::String(_) => (DataType::Utf8, true), ArrayElement::Number(_) => (DataType::Float64, false), ArrayElement::Boolean(_) => (DataType::Boolean, true), - ArrayElement::Date(_) => (DataType::Date, false), + ArrayElement::Date(_) => (DataType::Date32, false), ArrayElement::DateTime(_) => { - (DataType::Datetime(TimeUnit::Microseconds, None), false) + (DataType::Timestamp(TimeUnit::Microsecond, None), false) + } + ArrayElement::Time(_) => { + (DataType::Time64(TimeUnit::Nanosecond), false) } - ArrayElement::Time(_) => (DataType::Time, false), ArrayElement::Null => { // Null element: default to Float64 (DataType::Float64, false) @@ -328,7 +303,7 @@ pub fn build_aesthetic_schema(layer: &Layer, schema: &Schema) -> Schema { // Column not in schema - add with Unknown type aesthetic_schema.push(ColumnInfo { name: aes_col_name, - dtype: DataType::Unknown(Default::default()), + dtype: DataType::Utf8, is_discrete: false, min: None, max: None, @@ -338,21 +313,23 @@ AestheticValue::Literal(lit) => { // Literals become columns with appropriate type let (dtype, is_discrete) = match lit { - ParameterValue::String(_) => (DataType::String, true), + ParameterValue::String(_) => (DataType::Utf8, true), ParameterValue::Number(_) => (DataType::Float64, false), ParameterValue::Boolean(_) => (DataType::Boolean, true), ParameterValue::Array(arr) => { // Infer dtype from first element (arrays are homogeneous) if let Some(first) = arr.first() { match first { - ArrayElement::String(_) => (DataType::String, true), + ArrayElement::String(_) => (DataType::Utf8, true), ArrayElement::Number(_) => (DataType::Float64, false), ArrayElement::Boolean(_) => (DataType::Boolean, true), - ArrayElement::Date(_) => (DataType::Date, false), + ArrayElement::Date(_) => (DataType::Date32, false), ArrayElement::DateTime(_) => { - (DataType::Datetime(TimeUnit::Microseconds, None),
false) + (DataType::Timestamp(TimeUnit::Microsecond, None), false) + } + ArrayElement::Time(_) => { + (DataType::Time64(TimeUnit::Nanosecond), false) } - ArrayElement::Time(_) => (DataType::Time, false), ArrayElement::Null => { // Null element: default to Float64 (DataType::Float64, false) diff --git a/src/format.rs b/src/format.rs index 32a1d7ed..813b6821 100644 --- a/src/format.rs +++ b/src/format.rs @@ -206,11 +206,13 @@ pub fn apply_label_template( /// let formatted_df = format_dataframe_column(&df, "_aesthetic_label", "Region: {:Title}")?; /// ``` pub fn format_dataframe_column( - df: &polars::prelude::DataFrame, + df: &crate::DataFrame, column_name: &str, template: &str, -) -> Result { - use polars::prelude::*; +) -> Result { + use crate::array_util::{as_f64, as_str, cast_array, new_str_array}; + use arrow::array::Array; + use arrow::datatypes::DataType; // Get the column let column = df @@ -218,48 +220,54 @@ pub fn format_dataframe_column( .map_err(|e| format!("Column '{}' not found: {}", column_name, e))?; // Step 1: Convert entire column to strings - let string_values: Vec> = if let Ok(str_col) = column.str() { + let string_values: Vec> = if let Ok(str_col) = as_str(column) { // String column (includes temporal data auto-converted to ISO format) - str_col - .into_iter() - .map(|opt| opt.map(|s| s.to_string())) + (0..str_col.len()) + .map(|i| { + if str_col.is_null(i) { + None + } else { + Some(str_col.value(i).to_string()) + } + }) .collect() - } else if let Ok(num_col) = column.cast(&DataType::Float64) { + } else if let Ok(cast) = cast_array(column, &DataType::Float64) { // Numeric column - use shared format_number helper for clean integer formatting use crate::plot::format_number; - let f64_col = num_col - .f64() - .map_err(|e| format!("Failed to cast column to f64: {}", e))?; + let f64_col = as_f64(&cast).map_err(|e| format!("Failed to cast column to f64: {}", e))?; - f64_col - .into_iter() - .map(|opt| opt.map(format_number)) + (0..f64_col.len()) + .map(|i| { + if f64_col.is_null(i) { + None + } else { + Some(format_number(f64_col.value(i))) + } + }) .collect() } else { return Err(format!( "Formatting doesn't support type {:?} in column '{}'. 
Try string or numeric types instead.", - column.dtype(), + column.data_type(), column_name )); }; // Step 2: Apply formatting template to all string values let placeholders = parse_placeholders(template); - let formatted_values: Vec> = string_values + let formatted_owned: Vec> = string_values .into_iter() .map(|opt| opt.map(|s| format_value(&s, template, &placeholders))) .collect(); - let formatted_col = Series::new(column_name.into(), formatted_values); + let formatted_refs: Vec> = + formatted_owned.iter().map(|opt| opt.as_deref()).collect(); + let formatted_col = new_str_array(formatted_refs); // Replace column in DataFrame - let mut new_df = df.clone(); - new_df - .replace(column_name, formatted_col) - .map_err(|e| format!("Failed to replace column: {}", e))?; - - Ok(new_df) + df.with_column(column_name, formatted_col) + .map_err(|e| format!("Failed to replace column: {}", e)) } /// Format a single value using template and parsed placeholders diff --git a/src/lib.rs b/src/lib.rs index 83bdd5b6..e80a2688 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,9 @@ ggsql splits queries at the `VISUALISE` boundary: // Allow complex types in test code (e.g., test case tuples with many elements) #![cfg_attr(test, allow(clippy::type_complexity))] +pub mod array_util; +pub mod compute; +pub mod dataframe; pub mod format; pub mod naming; pub mod parser; @@ -68,8 +71,8 @@ pub use util::{and_list, and_list_quoted, or_list, or_list_quoted}; // #[cfg(feature = "engine")] // pub mod engine; -// DataFrame abstraction (wraps Polars) -pub use polars::prelude::DataFrame; +// DataFrame abstraction (wraps Arrow RecordBatch) +pub use dataframe::DataFrame; /// Main library error type #[derive(thiserror::Error, Debug)] @@ -131,10 +134,11 @@ mod integration_tests { // Verify DataFrame has temporal type (DuckDB returns Datetime for DATE + INTERVAL) assert_eq!(df.get_column_names(), vec!["date", "revenue"]); let date_col = df.column("date").unwrap(); - // DATE + INTERVAL returns Datetime in DuckDB, which is still temporal + use arrow::array::Array; + // DATE + INTERVAL returns Timestamp in DuckDB (arrow), which is still temporal assert!(matches!( - date_col.dtype(), - polars::prelude::DataType::Date | polars::prelude::DataType::Datetime(_, _) + date_col.data_type(), + arrow::datatypes::DataType::Date32 | arrow::datatypes::DataType::Timestamp(_, _) )); // Create visualization spec @@ -190,11 +194,11 @@ mod integration_tests { let df = reader.execute_sql(sql).unwrap(); - // Verify DataFrame has Datetime type + // Verify DataFrame has Timestamp type let timestamp_col = df.column("timestamp").unwrap(); assert!(matches!( - timestamp_col.dtype(), - polars::prelude::DataType::Datetime(_, _) + timestamp_col.data_type(), + arrow::datatypes::DataType::Timestamp(_, _) )); // Create visualization spec @@ -243,16 +247,16 @@ mod integration_tests { // Verify types are preserved // DuckDB treats numeric literals as DECIMAL, which we convert to Float64 assert!(matches!( - df.column("int_col").unwrap().dtype(), - polars::prelude::DataType::Int32 + df.column("int_col").unwrap().data_type(), + arrow::datatypes::DataType::Int32 )); assert!(matches!( - df.column("float_col").unwrap().dtype(), - polars::prelude::DataType::Float64 + df.column("float_col").unwrap().data_type(), + arrow::datatypes::DataType::Float64 )); assert!(matches!( - df.column("bool_col").unwrap().dtype(), - polars::prelude::DataType::Boolean + df.column("bool_col").unwrap().data_type(), + arrow::datatypes::DataType::Boolean )); // Create visualization spec @@ 
-299,16 +303,16 @@ mod integration_tests { // Verify types assert!(matches!( - df.column("int_col").unwrap().dtype(), - polars::prelude::DataType::Int32 + df.column("int_col").unwrap().data_type(), + arrow::datatypes::DataType::Int32 )); assert!(matches!( - df.column("float_col").unwrap().dtype(), - polars::prelude::DataType::Float64 + df.column("float_col").unwrap().data_type(), + arrow::datatypes::DataType::Float64 )); assert!(matches!( - df.column("str_col").unwrap().dtype(), - polars::prelude::DataType::String + df.column("str_col").unwrap().data_type(), + arrow::datatypes::DataType::Utf8 )); // Create viz spec @@ -403,8 +407,8 @@ mod integration_tests { // DATE_TRUNC returns Date type (not Datetime) let day_col = df.column("day").unwrap(); assert!(matches!( - day_col.dtype(), - polars::prelude::DataType::Date | polars::prelude::DataType::Datetime(_, _) + day_col.data_type(), + arrow::datatypes::DataType::Date32 | arrow::datatypes::DataType::Timestamp(_, _) )); let mut spec = Plot::new(); @@ -443,16 +447,16 @@ mod integration_tests { // All should be Float64 assert!(matches!( - df.column("small").unwrap().dtype(), - polars::prelude::DataType::Float64 + df.column("small").unwrap().data_type(), + arrow::datatypes::DataType::Float64 )); assert!(matches!( - df.column("medium").unwrap().dtype(), - polars::prelude::DataType::Float64 + df.column("medium").unwrap().data_type(), + arrow::datatypes::DataType::Float64 )); assert!(matches!( - df.column("large").unwrap().dtype(), - polars::prelude::DataType::Float64 + df.column("large").unwrap().data_type(), + arrow::datatypes::DataType::Float64 )); let mut spec = Plot::new(); @@ -497,20 +501,20 @@ mod integration_tests { // Verify types assert!(matches!( - df.column("tiny").unwrap().dtype(), - polars::prelude::DataType::Int8 + df.column("tiny").unwrap().data_type(), + arrow::datatypes::DataType::Int8 )); assert!(matches!( - df.column("small").unwrap().dtype(), - polars::prelude::DataType::Int16 + df.column("small").unwrap().data_type(), + arrow::datatypes::DataType::Int16 )); assert!(matches!( - df.column("int").unwrap().dtype(), - polars::prelude::DataType::Int32 + df.column("int").unwrap().data_type(), + arrow::datatypes::DataType::Int32 )); assert!(matches!( - df.column("big").unwrap().dtype(), - polars::prelude::DataType::Int64 + df.column("big").unwrap().data_type(), + arrow::datatypes::DataType::Int64 )); let mut spec = Plot::new(); diff --git a/src/plot/facet/resolve.rs b/src/plot/facet/resolve.rs index fb6b2b80..e2b0c4cc 100644 --- a/src/plot/facet/resolve.rs +++ b/src/plot/facet/resolve.rs @@ -22,20 +22,27 @@ impl FacetDataContext { /// /// Extracts unique values from each facet variable for label resolution. 
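From here on, every file in this diff reads typed values through small `array_util` accessors (`as_f64`, `as_str`, `value_to_string`) instead of Polars' `ChunkedArray` methods. The helpers themselves are not shown in these hunks, so the following is a minimal sketch of the shape their call sites imply; the names match the diff, but the bodies and the `String` error type are assumptions:

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float64Array, StringArray};
use arrow::util::display::array_value_to_string;

// Assumed: typed downcast views over an ArrayRef, reporting the actual type on mismatch.
fn as_f64(col: &ArrayRef) -> Result<&Float64Array, String> {
    col.as_any()
        .downcast_ref::<Float64Array>()
        .ok_or_else(|| format!("expected Float64, got {:?}", col.data_type()))
}

fn as_str(col: &ArrayRef) -> Result<&StringArray, String> {
    col.as_any()
        .downcast_ref::<StringArray>()
        .ok_or_else(|| format!("expected Utf8, got {:?}", col.data_type()))
}

// Assumed: stringify one value of any supported array via arrow's display
// utilities (used below when collecting unique facet levels).
fn value_to_string(col: &ArrayRef, row: usize) -> String {
    array_value_to_string(col, row).unwrap_or_default()
}

fn main() {
    let nums: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.5), None]));
    assert_eq!(as_f64(&nums).unwrap().value(0), 1.5);
    assert!(as_str(&nums).is_err());

    let names: ArrayRef = Arc::new(StringArray::from(vec!["Adelie"]));
    assert_eq!(value_to_string(&names, 0), "Adelie");
}
```

The companion idiom, `(0..arr.len()).filter(|&i| !arr.is_null(i)).map(|i| arr.value(i))`, is the Arrow replacement for Polars' `.into_iter().flatten()` and recurs throughout the tests below.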
pub fn from_dataframe(df: &DataFrame, variables: &[String]) -> Self { + use crate::array_util::value_to_string; + use arrow::array::Array; + use std::collections::HashSet; + let mut unique_values = HashMap::new(); let mut num_levels = 1; for (i, var) in variables.iter().enumerate() { if let Ok(col) = df.column(var) { - let unique = col.unique().ok(); - let values: Vec = unique - .as_ref() - .map(|u| { - (0..u.len()) - .filter_map(|j| u.get(j).ok().map(|v| format!("{}", v))) - .collect() - }) - .unwrap_or_default(); + // Collect unique values manually + let mut seen = HashSet::new(); + let mut values = Vec::new(); + for j in 0..col.len() { + if col.is_null(j) { + continue; + } + let s = value_to_string(col, j); + if seen.insert(s.clone()) { + values.push(s); + } + } if i == 0 { num_levels = values.len().max(1); @@ -325,8 +332,8 @@ fn apply_defaults(facet: &mut Facet, context: &FacetDataContext) { #[cfg(test)] mod tests { use super::*; + use crate::df; use crate::plot::facet::FacetLayout; - use polars::prelude::*; /// Default position names for cartesian coords const CARTESIAN: &[&str] = &["x", "y"]; @@ -547,8 +554,8 @@ mod tests { #[test] fn test_context_from_dataframe() { let df = df! { - "category" => &["A", "B", "C", "A", "B", "C"], - "value" => &[1, 2, 3, 4, 5, 6], + "category" => vec!["A", "B", "C", "A", "B", "C"], + "value" => vec![1i32, 2, 3, 4, 5, 6], } .unwrap(); @@ -559,7 +566,7 @@ mod tests { #[test] fn test_context_from_dataframe_missing_column() { let df = df! { - "other" => &[1, 2, 3], + "other" => vec![1i32, 2, 3], } .unwrap(); @@ -570,7 +577,7 @@ mod tests { #[test] fn test_context_from_dataframe_empty_variables() { let df = df! { - "x" => &[1, 2, 3], + "x" => vec![1i32, 2, 3], } .unwrap(); diff --git a/src/plot/layer/geom/area.rs b/src/plot/layer/geom/area.rs index 7427e0e2..a9df6bff 100644 --- a/src/plot/layer/geom/area.rs +++ b/src/plot/layer/geom/area.rs @@ -65,7 +65,7 @@ impl GeomTrait for Area { _aesthetics: &Mappings, _group_by: &[String], _parameters: &std::collections::HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, _dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { // Area geom needs ordering by pos1 (domain axis) for proper rendering diff --git a/src/plot/layer/geom/density.rs b/src/plot/layer/geom/density.rs index a52e777f..89910be9 100644 --- a/src/plot/layer/geom/density.rs +++ b/src/plot/layer/geom/density.rs @@ -109,7 +109,7 @@ impl GeomTrait for Density { aesthetics: &Mappings, group_by: &[String], parameters: &std::collections::HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, ) -> crate::Result { // Density geom: no tails limit (don't set tails parameter, defaults to None) @@ -608,6 +608,7 @@ mod tests { use crate::reader::duckdb::DuckDBReader; use crate::reader::AnsiDialect; use crate::reader::Reader; + use arrow::array::Array; #[test] fn test_density_sql_no_groups() { @@ -772,27 +773,29 @@ mod tests { // Verify density integrates to ~2 (one per group) // Compute grid spacing dynamically from actual data + use crate::array_util::{as_f64, cast_array}; + use arrow::datatypes::DataType; let x_col = df.column("__ggsql_stat_x").expect("x exists"); // Cast to f64 if needed (AnsiDialect generates f32 from REAL) - let x_col = x_col - .cast(&polars::prelude::DataType::Float64) - .expect("can cast to f64"); - let x_vals = x_col.f64().expect("x is f64"); - let x_min = 
x_vals.into_iter().flatten().fold(f64::INFINITY, f64::min); - let x_max = x_vals - .into_iter() - .flatten() + let x_col = cast_array(x_col, &DataType::Float64).expect("can cast to f64"); + let x_vals = as_f64(&x_col).expect("x is f64"); + let x_min = (0..x_vals.len()) + .filter(|&i| !x_vals.is_null(i)) + .map(|i| x_vals.value(i)) + .fold(f64::INFINITY, f64::min); + let x_max = (0..x_vals.len()) + .filter(|&i| !x_vals.is_null(i)) + .map(|i| x_vals.value(i)) .fold(f64::NEG_INFINITY, f64::max); let dx = (x_max - x_min) / 511.0; // (n - 1) for 512 points let density_col = df .column("__ggsql_stat_density") .expect("density column exists"); - let total: f64 = density_col - .f64() - .expect("density is f64") - .into_iter() - .flatten() + let density_vals = as_f64(density_col).expect("density is f64"); + let total: f64 = (0..density_vals.len()) + .filter(|&i| !density_vals.is_null(i)) + .map(|i| density_vals.value(i)) .sum(); let integral = total * dx; @@ -895,34 +898,33 @@ mod tests { // Compute integral using trapezoidal rule // Get actual grid spacing from the data (dynamically computed range) + use crate::array_util::{as_f64, cast_array}; + use arrow::datatypes::DataType; let x_col = df.column("__ggsql_stat_x").expect("x exists"); // Cast to f64 if needed (AnsiDialect generates f32 from REAL) - let x_col = x_col - .cast(&polars::prelude::DataType::Float64) - .expect("can cast to f64"); - let x_vals = x_col.f64().expect("x is f64"); - let x_min = x_vals.into_iter().flatten().fold(f64::INFINITY, f64::min); - let x_max = x_vals - .into_iter() - .flatten() + let x_col = cast_array(x_col, &DataType::Float64).expect("can cast to f64"); + let x_vals = as_f64(&x_col).expect("x is f64"); + let x_min = (0..x_vals.len()) + .filter(|&i| !x_vals.is_null(i)) + .map(|i| x_vals.value(i)) + .fold(f64::INFINITY, f64::min); + let x_max = (0..x_vals.len()) + .filter(|&i| !x_vals.is_null(i)) + .map(|i| x_vals.value(i)) .fold(f64::NEG_INFINITY, f64::max); let dx = (x_max - x_min) / (df.height() as f64 - 1.0); let density_col = df.column("__ggsql_stat_density").expect("density exists"); - let total: f64 = density_col - .f64() - .expect("density is f64") - .into_iter() - .flatten() + let density_vals = as_f64(density_col).expect("density is f64"); + let total: f64 = (0..density_vals.len()) + .filter(|&i| !density_vals.is_null(i)) + .map(|i| density_vals.value(i)) .sum(); let integral = total * dx; // Verify all density values are non-negative - let all_non_negative = density_col - .f64() - .expect("density is f64") - .into_iter() - .all(|v| v.map(|x| x >= 0.0).unwrap_or(true)); + let all_non_negative = (0..density_vals.len()) + .all(|i| density_vals.is_null(i) || density_vals.value(i) >= 0.0); assert!( all_non_negative, "All density values should be non-negative for kernel '{}'", @@ -1029,17 +1031,15 @@ mod tests { .column("__ggsql_stat_density") .expect("density exists"); - let unweighted_values: Vec = density_unweighted - .f64() - .expect("f64") - .into_iter() - .flatten() + let unweighted_arr = crate::array_util::as_f64(density_unweighted).expect("f64"); + let unweighted_values: Vec = (0..unweighted_arr.len()) + .filter(|&i| !unweighted_arr.is_null(i)) + .map(|i| unweighted_arr.value(i)) .collect(); - let weighted_values: Vec = density_weighted - .f64() - .expect("f64") - .into_iter() - .flatten() + let weighted_arr = crate::array_util::as_f64(density_weighted).expect("f64"); + let weighted_values: Vec = (0..weighted_arr.len()) + .filter(|&i| !weighted_arr.is_null(i)) + .map(|i| weighted_arr.value(i)) 
.collect(); assert_eq!(unweighted_values.len(), weighted_values.len()); @@ -1096,16 +1096,16 @@ mod tests { // With REMAPPING intensity AS y, we get: __ggsql_aes_pos1__, __ggsql_aes_pos2__ // (pos2 is mapped from intensity, not the default density) - let col_names: Vec<&str> = df.get_column_names().iter().map(|s| s.as_str()).collect(); + let col_names = df.get_column_names(); // Should have pos1 and pos2 aesthetics after remapping (internal names) assert!( - col_names.contains(&"__ggsql_aes_pos1__"), + col_names.iter().any(|s| s == "__ggsql_aes_pos1__"), "Should have pos1 aesthetic, got: {:?}", col_names ); assert!( - col_names.contains(&"__ggsql_aes_pos2__"), + col_names.iter().any(|s| s == "__ggsql_aes_pos2__"), "Should have pos2 aesthetic, got: {:?}", col_names ); @@ -1117,11 +1117,8 @@ mod tests { let y_col = df .column("__ggsql_aes_pos2__") .expect("pos2 aesthetic exists"); - let all_non_negative = y_col - .f64() - .expect("y is f64") - .into_iter() - .all(|v| v.map(|x| x >= 0.0).unwrap_or(true)); + let y_arr = crate::array_util::as_f64(y_col).expect("y is f64"); + let all_non_negative = (0..y_arr.len()).all(|i| y_arr.is_null(i) || y_arr.value(i) >= 0.0); assert!( all_non_negative, "All y values (from intensity) should be non-negative" diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index 4a4165ad..66400e56 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -292,23 +292,40 @@ pub fn extract_histogram_min_max(df: &DataFrame) -> Result<(f64, f64)> { )); } - let min_val = df - .column("min_val") - .ok() - .and_then(|s| s.get(0).ok()) - .and_then(|s| s.try_extract::().ok()) - .ok_or_else(|| { - GgsqlError::ValidationError("Could not extract min value for histogram".to_string()) - })?; - - let max_val = df - .column("max_val") - .ok() - .and_then(|s| s.get(0).ok()) - .and_then(|s| s.try_extract::().ok()) - .ok_or_else(|| { - GgsqlError::ValidationError("Could not extract max value for histogram".to_string()) - })?; + let extract = |name: &str| -> Option { + use arrow::array::Array; + use arrow::datatypes::DataType; + let col = df.column(name).ok()?; + if col.is_null(0) { + return None; + } + let casted = crate::array_util::cast_array(col, &DataType::Float64).ok()?; + crate::array_util::as_f64(&casted).ok().map(|a| a.value(0)) + }; + + let min_val = extract("min_val").ok_or_else(|| { + GgsqlError::ValidationError("Could not extract min value for histogram".to_string()) + })?; + + let max_val = extract("max_val").ok_or_else(|| { + GgsqlError::ValidationError("Could not extract max value for histogram".to_string()) + })?; Ok((min_val, max_val)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::df; + + #[test] + fn test_extract_min_max_null_errors() { + let df = df! 
{ + "min_val" => vec![None::], + "max_val" => vec![None::], + } + .unwrap(); + assert!(extract_histogram_min_max(&df).is_err()); + } +} diff --git a/src/plot/layer/geom/line.rs b/src/plot/layer/geom/line.rs index e0009961..a8ded3b1 100644 --- a/src/plot/layer/geom/line.rs +++ b/src/plot/layer/geom/line.rs @@ -50,7 +50,7 @@ impl GeomTrait for Line { _aesthetics: &Mappings, _group_by: &[String], _parameters: &std::collections::HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, _dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { // Line geom needs ordering by pos1 (domain axis) for proper rendering diff --git a/src/plot/layer/geom/ribbon.rs b/src/plot/layer/geom/ribbon.rs index 46470c72..87d4636c 100644 --- a/src/plot/layer/geom/ribbon.rs +++ b/src/plot/layer/geom/ribbon.rs @@ -50,7 +50,7 @@ impl GeomTrait for Ribbon { _aesthetics: &Mappings, _group_by: &[String], _parameters: &std::collections::HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, _dialect: &dyn crate::reader::SqlDialect, ) -> crate::Result { // Ribbon geom needs ordering by pos1 (domain axis) for proper rendering diff --git a/src/plot/layer/geom/smooth.rs b/src/plot/layer/geom/smooth.rs index 4509fc5c..2f053235 100644 --- a/src/plot/layer/geom/smooth.rs +++ b/src/plot/layer/geom/smooth.rs @@ -98,7 +98,7 @@ impl GeomTrait for Smooth { aesthetics: &Mappings, group_by: &[String], parameters: &std::collections::HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn SqlDialect, ) -> crate::Result { // Get method from parameters (validated by ParamConstraint::string_option) diff --git a/src/plot/layer/geom/tile.rs b/src/plot/layer/geom/tile.rs index bee4becd..3633f944 100644 --- a/src/plot/layer/geom/tile.rs +++ b/src/plot/layer/geom/tile.rs @@ -344,7 +344,7 @@ fn generate_continuous_position_expressions( mod tests { use super::*; use crate::plot::types::{AestheticValue, ColumnInfo}; - use polars::prelude::DataType; + use arrow::datatypes::DataType; // ==================== Helper Functions ==================== @@ -357,7 +357,7 @@ mod tests { ColumnInfo { name: "__ggsql_aes_pos1__".to_string(), dtype: if discrete_cols.contains(&"pos1") { - DataType::String + DataType::Utf8 } else { DataType::Float64 }, @@ -389,7 +389,7 @@ mod tests { ColumnInfo { name: "__ggsql_aes_pos2__".to_string(), dtype: if discrete_cols.contains(&"pos2") { - DataType::String + DataType::Utf8 } else { DataType::Float64 }, @@ -424,7 +424,7 @@ mod tests { for col_name in extra_cols { schema.push(ColumnInfo { name: col_name.to_string(), - dtype: DataType::String, + dtype: DataType::Utf8, is_discrete: true, min: None, max: None, diff --git a/src/plot/layer/geom/violin.rs b/src/plot/layer/geom/violin.rs index 06d2239c..7fef55eb 100644 --- a/src/plot/layer/geom/violin.rs +++ b/src/plot/layer/geom/violin.rs @@ -10,7 +10,6 @@ use crate::{ }, DataFrame, GgsqlError, Mappings, Result, }; -use polars::prelude::*; use std::collections::HashMap; /// Valid kernel types for violin density estimation @@ -122,7 +121,7 @@ impl GeomTrait for Violin { aesthetics: &Mappings, group_by: &[String], parameters: &HashMap, - _execute_query: &dyn Fn(&str) -> crate::Result, + _execute_query: &dyn Fn(&str) -> crate::Result, dialect: &dyn crate::reader::SqlDialect, ) -> Result { stat_violin(query, aesthetics, group_by, parameters, dialect) @@ -171,13 +170,11 @@ fn scale_offset_column(df: DataFrame, 
offset_col: &str, half_width: f64) -> Resu } // Get global max of offset column - let max_val = df - .column(offset_col) - .map_err(|e| GgsqlError::InternalError(format!("Failed to get offset column: {}", e)))? - .f64() - .map_err(|e| GgsqlError::InternalError(format!("Offset column must be f64: {}", e)))? - .max() - .unwrap_or(1.0); + use arrow::array::Array; + let offset_arr = df.column(offset_col)?; + let f64_arr = crate::array_util::as_f64(offset_arr) + .map_err(|e| GgsqlError::InternalError(format!("Offset column must be f64: {}", e)))?; + let max_val = arrow::compute::max(f64_arr).unwrap_or(1.0); if max_val <= 0.0 { return Ok(df); @@ -185,11 +182,17 @@ fn scale_offset_column(df: DataFrame, offset_col: &str, half_width: f64) -> Resu // Scale: new_offset = offset * half_width / max_val let scale_factor = half_width / max_val; - let scaled = df - .lazy() - .with_column((col(offset_col) * lit(scale_factor)).alias(offset_col)) - .collect() - .map_err(|e| GgsqlError::InternalError(format!("Failed to scale offset: {}", e)))?; + let scaled_values: Vec> = (0..f64_arr.len()) + .map(|i| { + if f64_arr.is_null(i) { + None + } else { + Some(f64_arr.value(i) * scale_factor) + } + }) + .collect(); + let scaled_array = crate::array_util::new_f64_array(scaled_values); + let scaled = df.with_column(offset_col, scaled_array)?; Ok(scaled) } @@ -239,6 +242,19 @@ mod tests { use crate::reader::duckdb::DuckDBReader; use crate::reader::AnsiDialect; use crate::reader::Reader; + use arrow::array::Array; + + /// Count unique non-null string values in an ArrayRef. + fn count_unique_strings(col: &arrow::array::ArrayRef) -> usize { + let arr = crate::array_util::as_str(col).expect("expected string array"); + let mut seen = std::collections::HashSet::new(); + for i in 0..arr.len() { + if !arr.is_null(i) { + seen.insert(arr.value(i).to_string()); + } + } + seen.len() + } // ==================== Helper Functions ==================== @@ -312,18 +328,17 @@ mod tests { let df = execute(&stat_query).expect("Generated SQL should execute"); // Should have columns: pos2 (y), density, and species (the x grouping) - let col_names: Vec<&str> = - df.get_column_names().iter().map(|s| s.as_str()).collect(); - assert!(col_names.contains(&"__ggsql_stat_pos2")); - assert!(col_names.contains(&"__ggsql_stat_density")); - assert!(col_names.contains(&"species")); + let col_names = df.get_column_names(); + assert!(col_names.iter().any(|s| s == "__ggsql_stat_pos2")); + assert!(col_names.iter().any(|s| s == "__ggsql_stat_density")); + assert!(col_names.iter().any(|s| s == "species")); // Should have multiple rows per species (512 grid points per species) assert!(df.height() > 0); // Verify we have all three species let species_col = df.column("species").unwrap(); - let unique_species = species_col.n_unique().unwrap(); + let unique_species = count_unique_strings(species_col); assert_eq!(unique_species, 3, "Should have 3 unique species"); } _ => panic!("Expected Transformed result"), @@ -377,24 +392,23 @@ mod tests { let df = execute(&stat_query).expect("Generated SQL should execute"); // Should have columns: pos2 (y), density, species (x), and island (color group) - let col_names: Vec<&str> = - df.get_column_names().iter().map(|s| s.as_str()).collect(); - assert!(col_names.contains(&"__ggsql_stat_pos2")); - assert!(col_names.contains(&"__ggsql_stat_density")); - assert!(col_names.contains(&"species")); - assert!(col_names.contains(&"island")); + let col_names = df.get_column_names(); + assert!(col_names.iter().any(|s| s == 
"__ggsql_stat_pos2")); + assert!(col_names.iter().any(|s| s == "__ggsql_stat_density")); + assert!(col_names.iter().any(|s| s == "species")); + assert!(col_names.iter().any(|s| s == "island")); // Should have multiple rows per species-island combination assert!(df.height() > 0); // Verify we have multiple species let species_col = df.column("species").unwrap(); - let unique_species = species_col.n_unique().unwrap(); + let unique_species = count_unique_strings(species_col); assert!(unique_species >= 2, "Should have at least 2 unique species"); // Verify we have multiple islands let island_col = df.column("island").unwrap(); - let unique_islands = island_col.n_unique().unwrap(); + let unique_islands = count_unique_strings(island_col); assert!(unique_islands >= 2, "Should have at least 2 unique islands"); } _ => panic!("Expected Transformed result"), @@ -502,13 +516,14 @@ mod tests { #[test] fn test_violin_post_process_scales_offset() { + use crate::df; let violin = Violin; let offset_col = naming::aesthetic_column("offset"); // Create a DataFrame with offset values let df = df! { - offset_col.as_str() => [0.0, 0.5, 1.0, 0.25], - "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0, 4.0], + offset_col.as_str() => vec![0.0, 0.5, 1.0, 0.25], + "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0, 4.0], } .unwrap(); @@ -517,8 +532,11 @@ mod tests { let parameters = HashMap::new(); let result = violin.post_process(df, ¶meters).unwrap(); - let scaled_offset = result.column(&offset_col).unwrap().f64().unwrap(); - let values: Vec = scaled_offset.into_iter().flatten().collect(); + let scaled_arr = crate::array_util::as_f64(result.column(&offset_col).unwrap()).unwrap(); + let values: Vec = (0..scaled_arr.len()) + .filter(|&i| !scaled_arr.is_null(i)) + .map(|i| scaled_arr.value(i)) + .collect(); // Max offset (1.0) should be scaled to 0.45 (half_width) // Other values should be proportionally scaled @@ -533,13 +551,14 @@ mod tests { #[test] fn test_violin_post_process_custom_width() { + use crate::df; let violin = Violin; let offset_col = naming::aesthetic_column("offset"); // Create a DataFrame with offset values let df = df! { - offset_col.as_str() => [0.0, 0.5, 1.0], - "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0], + offset_col.as_str() => vec![0.0, 0.5, 1.0], + "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0], } .unwrap(); @@ -549,8 +568,11 @@ mod tests { let result = violin.post_process(df, ¶meters).unwrap(); - let scaled_offset = result.column(&offset_col).unwrap().f64().unwrap(); - let values: Vec = scaled_offset.into_iter().flatten().collect(); + let scaled_arr = crate::array_util::as_f64(result.column(&offset_col).unwrap()).unwrap(); + let values: Vec = (0..scaled_arr.len()) + .filter(|&i| !scaled_arr.is_null(i)) + .map(|i| scaled_arr.value(i)) + .collect(); // Max offset (1.0) should be scaled to 0.3 (half_width) assert!((values[0] - 0.0).abs() < 1e-6, "0.0 should stay 0.0"); @@ -560,11 +582,12 @@ mod tests { #[test] fn test_violin_post_process_no_offset_column() { + use crate::df; let violin = Violin; // Create a DataFrame without offset column let df = df! 
{ - "__ggsql_aes_pos2__" => [1.0, 2.0, 3.0], + "__ggsql_aes_pos2__" => vec![1.0, 2.0, 3.0], } .unwrap(); diff --git a/src/plot/layer/orientation.rs b/src/plot/layer/orientation.rs index c2b0b096..2f7a196a 100644 --- a/src/plot/layer/orientation.rs +++ b/src/plot/layer/orientation.rs @@ -266,9 +266,7 @@ pub fn flip_dataframe_position_columns( df: DataFrame, aesthetic_ctx: &AestheticContext, ) -> DataFrame { - use polars::prelude::*; - - // Collect renames needed before consuming df + // Collect renames needed let renames: Vec<(String, String)> = df .get_column_names() .iter() @@ -289,21 +287,21 @@ pub fn flip_dataframe_position_columns( return df; } - let mut lazy = df.lazy(); + let mut result = df; // First pass: rename to temp names for (from, to) in &renames { let temp = format!("{}_temp", to); - lazy = lazy.rename([from.as_str()], [temp.as_str()], true); + result = result.rename(from, &temp).expect("rename should not fail"); } // Second pass: remove temp suffix for (_, to) in &renames { let temp = format!("{}_temp", to); - lazy = lazy.rename([temp.as_str()], [to.as_str()], true); + result = result.rename(&temp, to).expect("rename should not fail"); } - lazy.collect().expect("rename should not fail") + result } #[cfg(test)] diff --git a/src/plot/layer/position/dodge.rs b/src/plot/layer/position/dodge.rs index 09e9dfe9..58a022ae 100644 --- a/src/plot/layer/position/dodge.rs +++ b/src/plot/layer/position/dodge.rs @@ -10,9 +10,9 @@ use super::{ compute_dodge_offsets, is_continuous_scale, non_facet_partition_cols, Layer, PositionTrait, PositionType, }; +use crate::array_util::{new_f64_array_non_null, value_to_string}; use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue}; -use crate::{naming, DataFrame, GgsqlError, Plot, Result}; -use polars::prelude::*; +use crate::{compute, naming, DataFrame, Plot, Result}; use std::collections::HashMap; /// Result of computing group indices for dodge/jitter operations. 
@@ -51,11 +51,8 @@ pub fn compute_group_indices( for row_idx in 0..n_rows { let mut key_parts: Vec = Vec::with_capacity(group_cols.len()); for col_name in group_cols { - let col = df.column(col_name).unwrap(); - let val = col.get(row_idx).map_err(|e| { - GgsqlError::InternalError(format!("Failed to get value at row {}: {}", row_idx, e)) - })?; - key_parts.push(format!("{}", val)); + let col = df.column(col_name)?; + key_parts.push(value_to_string(col, row_idx)); } composite_keys.push(key_parts.join("\x00")); // Use null byte as separator } @@ -193,45 +190,45 @@ fn apply_dodge_with_width( // Compute dodge offsets using shared logic let offsets = compute_dodge_offsets(&indices, n_groups, bar_width, dodge_pos1, dodge_pos2); - let mut lf = df.lazy(); + let mut result = df; // Apply the computed offsets if let Some(pos1_offsets) = offsets.pos1 { - lf = lf.with_column( - lit(Series::new(pos1offset_col.clone().into(), pos1_offsets)).alias(&pos1offset_col), - ); + result = result.with_column(&pos1offset_col, new_f64_array_non_null(pos1_offsets))?; } if let Some(pos2_offsets) = offsets.pos2 { - lf = lf.with_column( - lit(Series::new(pos2offset_col.clone().into(), pos2_offsets)).alias(&pos2offset_col), - ); + result = result.with_column(&pos2offset_col, new_f64_array_non_null(pos2_offsets))?; } // If offset column exists (e.g., violin), scale it by the offset scale factor if has_offset_col { - lf = lf.with_column((col(&offset_col) / lit(offsets.offset_scale)).alias(&offset_col)); + let col = result.column(&offset_col)?; + let casted = crate::array_util::cast_array(col, &arrow::datatypes::DataType::Float64)?; + let f64_arr = crate::array_util::as_f64(&casted)?; + let scaled = compute::divide_scalar(f64_arr, offsets.offset_scale); + result = result.with_column( + &offset_col, + std::sync::Arc::new(scaled) as arrow::array::ArrayRef, + )?; } - // Collect the result - let final_df = lf.collect().map_err(|e| { - GgsqlError::InternalError(format!("Dodge position adjustment failed: {}", e)) - })?; - - Ok((final_df, Some(offsets.adjusted_width))) + Ok((result, Some(offsets.adjusted_width))) } #[cfg(test)] mod tests { use super::*; + use crate::array_util::as_f64; + use crate::df; use crate::plot::layer::Geom; use crate::plot::{AestheticValue, Mappings, Scale, ScaleType}; fn make_test_df() -> DataFrame { df! 
{ - "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], } .unwrap() } @@ -301,18 +298,15 @@ mod tests { "pos2offset column should NOT be created when pos2 is continuous" ); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); // With 2 groups (X, Y) and default width 0.9: // - adjusted_width = 0.9 / 2 = 0.45 // - center_offset = 0.5 // - Group X: center = (0 - 0.5) * 0.45 = -0.225 // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 - let offsets: Vec = offset.into_iter().flatten().collect(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); assert!( offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), "Should have offset -0.225 for group X, got {:?}", @@ -359,18 +353,15 @@ mod tests { "pos2offset column should be created" ); - let offset = result - .column("__ggsql_aes_pos2offset__") - .unwrap() - .f64() - .unwrap(); + let offset_col = result.column("__ggsql_aes_pos2offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); // With 2 groups (X, Y) and default width 0.9: // - adjusted_width = 0.9 / 2 = 0.45 // - center_offset = 0.5 // - Group X: center = (0 - 0.5) * 0.45 = -0.225 // - Group Y: center = (1 - 0.5) * 0.45 = +0.225 - let offsets: Vec = offset.into_iter().flatten().collect(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); assert!( offsets.iter().any(|&v| (v - (-0.225)).abs() < 0.001), "Should have offset -0.225 for group X, got {:?}", @@ -421,19 +412,17 @@ mod tests { // center_offset = (2 - 1) / 2 = 0.5 // Group 0 (X): col=0, row=0 → pos1=(-0.5)*0.45=-0.225, pos2=(-0.5)*0.45=-0.225 // Group 1 (Y): col=1, row=0 → pos1=(0.5)*0.45=0.225, pos2=(-0.5)*0.45=-0.225 - let pos1_offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let pos2_offset = result - .column("__ggsql_aes_pos2offset__") - .unwrap() - .f64() - .unwrap(); - - let pos1_offsets: Vec = pos1_offset.into_iter().flatten().collect(); - let pos2_offsets: Vec = pos2_offset.into_iter().flatten().collect(); + let pos1_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let pos1_offset = as_f64(pos1_col).unwrap(); + let pos2_col = result.column("__ggsql_aes_pos2offset__").unwrap(); + let pos2_offset = as_f64(pos2_col).unwrap(); + + let pos1_offsets: Vec = (0..pos1_offset.len()) + .map(|i| pos1_offset.value(i)) + .collect(); + let pos2_offsets: Vec = (0..pos2_offset.len()) + .map(|i| pos2_offset.value(i)) + .collect(); // Verify we have both expected pos1 offsets assert!( @@ -473,9 +462,9 @@ mod tests { let dodge = Dodge; let df = df! 
{ - "__ggsql_aes_pos1__" => ["A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_fill__" => ["G1", "G2", "G3", "G4"], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_fill__" => vec!["G1", "G2", "G3", "G4"], } .unwrap(); @@ -512,19 +501,17 @@ mod tests { // G3: col=0, row=1 → (-0.5, +0.5) * adjusted_width // G4: col=1, row=1 → (+0.5, +0.5) * adjusted_width - let pos1_offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let pos2_offset = result - .column("__ggsql_aes_pos2offset__") - .unwrap() - .f64() - .unwrap(); + let pos1_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let pos1_offset = as_f64(pos1_col).unwrap(); + let pos2_col = result.column("__ggsql_aes_pos2offset__").unwrap(); + let pos2_offset = as_f64(pos2_col).unwrap(); - let pos1_offsets: Vec = pos1_offset.into_iter().flatten().collect(); - let pos2_offsets: Vec = pos2_offset.into_iter().flatten().collect(); + let pos1_offsets: Vec = (0..pos1_offset.len()) + .map(|i| pos1_offset.value(i)) + .collect(); + let pos2_offsets: Vec = (0..pos2_offset.len()) + .map(|i| pos2_offset.value(i)) + .collect(); // Verify we have both positive and negative offsets in both dimensions assert!( @@ -599,18 +586,15 @@ mod tests { let (result, width) = dodge.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); // With 2 groups and custom width 0.6: // - adjusted_width = 0.6 / 2 = 0.3 // - center_offset = 0.5 // - Group X: center = (0 - 0.5) * 0.3 = -0.15 // - Group Y: center = (1 - 0.5) * 0.3 = +0.15 - let offsets: Vec = offset.into_iter().flatten().collect(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); assert!( offsets.iter().any(|&v| (v - (-0.15)).abs() < 0.001), "Should have offset -0.15 for group X, got {:?}", diff --git a/src/plot/layer/position/identity.rs b/src/plot/layer/position/identity.rs index c62db982..571a2cc7 100644 --- a/src/plot/layer/position/identity.rs +++ b/src/plot/layer/position/identity.rs @@ -34,7 +34,7 @@ impl PositionTrait for Identity { #[cfg(test)] mod tests { use super::*; - use polars::prelude::*; + use crate::df; #[test] fn test_identity_no_change() { @@ -42,8 +42,8 @@ mod tests { assert_eq!(identity.position_type(), PositionType::Identity); let df = df! 
{ - "x" => [1, 2, 3], - "y" => [10, 20, 30], + "x" => vec![1i32, 2, 3], + "y" => vec![10i32, 20, 30], } .unwrap(); diff --git a/src/plot/layer/position/jitter.rs b/src/plot/layer/position/jitter.rs index 852b05ae..27cf4f3f 100644 --- a/src/plot/layer/position/jitter.rs +++ b/src/plot/layer/position/jitter.rs @@ -18,9 +18,11 @@ use super::{ compute_dodge_offsets, compute_group_indices, is_continuous_scale, non_facet_partition_cols, Layer, PositionTrait, PositionType, }; +use crate::array_util::{as_f64, cast_array, new_f64_array_non_null}; use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue}; use crate::{naming, DataFrame, GgsqlError, Plot, Result}; -use polars::prelude::*; +use arrow::array::Array; +use arrow::datatypes::DataType; use rand::Rng; /// Valid distribution types for jitter position @@ -341,30 +343,32 @@ fn compute_density_scales( let discrete_col_name = naming::aesthetic_column(discrete_col); // Extract values from the continuous axis - let values: Vec = df - .column(&continuous_col_name) - .map_err(|_| { - GgsqlError::InternalError(format!( - "Missing {} column for density jitter", - continuous_col - )) - })? - .cast(&DataType::Float64) - .map_err(|_| { - GgsqlError::InternalError(format!( - "{} must be numeric for density jitter", - continuous_col - )) - })? - .f64() - .map_err(|_| { - GgsqlError::InternalError(format!( - "{} must be numeric for density jitter", - continuous_col - )) - })? - .into_iter() - .map(|v| v.unwrap_or(0.0)) + let col = df.column(&continuous_col_name).map_err(|_| { + GgsqlError::InternalError(format!( + "Missing {} column for density jitter", + continuous_col + )) + })?; + let casted = cast_array(col, &DataType::Float64).map_err(|_| { + GgsqlError::InternalError(format!( + "{} must be numeric for density jitter", + continuous_col + )) + })?; + let f64_arr = as_f64(&casted).map_err(|_| { + GgsqlError::InternalError(format!( + "{} must be numeric for density jitter", + continuous_col + )) + })?; + let values: Vec = (0..f64_arr.len()) + .map(|i| { + if f64_arr.is_null(i) { + 0.0 + } else { + f64_arr.value(i) + } + }) .collect(); // Build density grouping columns: discrete axis + relevant partition_by columns @@ -540,7 +544,7 @@ fn apply_jitter(df: DataFrame, layer: &Layer, spec: &Plot) -> Result let pos1offset_col = naming::aesthetic_column("pos1offset"); let pos2offset_col = naming::aesthetic_column("pos2offset"); - let mut result = df.lazy(); + let mut result = df; // Compute dodge centers if we have groups to dodge let dodge_offsets = if n_groups > 1 { @@ -594,9 +598,7 @@ fn apply_jitter(df: DataFrame, layer: &Layer, spec: &Plot) -> Result jitters }; - result = result.with_column( - lit(Series::new(pos1offset_col.clone().into(), offsets)).alias(&pos1offset_col), - ); + result = result.with_column(&pos1offset_col, new_f64_array_non_null(offsets))?; } // Add pos2offset if pos2 is discrete @@ -622,28 +624,26 @@ fn apply_jitter(df: DataFrame, layer: &Layer, spec: &Plot) -> Result jitters }; - result = result.with_column( - lit(Series::new(pos2offset_col.clone().into(), offsets)).alias(&pos2offset_col), - ); + result = result.with_column(&pos2offset_col, new_f64_array_non_null(offsets))?; } - result - .collect() - .map_err(|e| GgsqlError::InternalError(format!("Jitter position adjustment failed: {}", e))) + Ok(result) } #[cfg(test)] mod tests { use super::*; + use crate::array_util::{as_f64, as_str, value_to_string}; + use crate::df; use crate::plot::layer::Geom; use crate::plot::{AestheticValue, Mappings, Scale, 
ScaleType}; fn make_test_df() -> DataFrame { df! { - "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], } .unwrap() } @@ -712,12 +712,9 @@ mod tests { "pos2offset column should NOT be created when pos2 is continuous" ); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // With default width 0.9 and 2 groups (dodge=true): // effective_width = 0.9 / 2 = 0.45 @@ -753,12 +750,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // With dodge=false and width 0.9, pure jitter in range [-0.45, 0.45] for &v in &offsets { @@ -795,12 +789,9 @@ mod tests { "pos2offset column should be created" ); - let offset = result - .column("__ggsql_aes_pos2offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos2offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // With default width 0.9 and 2 groups (dodge=true), effective range is [-0.45, 0.45] for &v in &offsets { @@ -879,12 +870,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // With custom width 0.6 and 2 groups (dodge=true): // effective_width = 0.6 / 2 = 0.3 @@ -913,21 +901,18 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let fill_col = result.column("__ggsql_aes_fill__").unwrap(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let fill_arr = result.column("__ggsql_aes_fill__").unwrap(); // Collect offsets by group let mut group_x_offsets = vec![]; let mut group_y_offsets = vec![]; for i in 0..result.height() { - let fill_val = fill_col.get(i).unwrap(); - let offset_val = offset.get(i).unwrap(); - if fill_val.to_string().contains("X") { + let fill_val = value_to_string(fill_arr, i); + let offset_val = offset.value(i); + if fill_val.contains('X') { group_x_offsets.push(offset_val); } else { group_y_offsets.push(offset_val); @@ -965,12 +950,9 @@ mod tests { let (result, _) = 
jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // Without groups, pure jitter with full width range [-0.45, 0.45] for &v in &offsets { @@ -1036,12 +1018,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // Normal distribution is centered at 0 // Values can exceed the width bounds (unlike uniform), but should be centered @@ -1147,9 +1126,9 @@ mod tests { // Create data with clear density peaks // Values 1.0 appears 5 times, values 2.0 and 3.0 appear once each let df = df! { - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], } .unwrap(); @@ -1183,12 +1162,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // Due to randomness, we can't assert exact values // But we can verify that offsets were generated @@ -1204,10 +1180,10 @@ mod tests { // Group X: dense at 1.0 // Group Y: dense at 3.0 let df = df! 
{ - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 3.0, 3.0, 3.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "X", "X", "Y", "Y", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 3.0, 3.0, 3.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "X", "X", "Y", "Y", "Y"], } .unwrap(); @@ -1246,24 +1222,22 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); // Verify offsets were created and are within expected bounds - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); assert_eq!(offsets.len(), 6); // With 2 groups, we should see separated dodge positions // Group X centered at negative, Group Y centered at positive - let fill_col = result.column("__ggsql_aes_fill__").unwrap(); + let fill_arr = result.column("__ggsql_aes_fill__").unwrap(); + let fill_str = as_str(fill_arr).unwrap(); let mut group_x_offsets = vec![]; let mut group_y_offsets = vec![]; for i in 0..result.height() { - let fill_val = fill_col.get(i).unwrap(); - let offset_val = offset.get(i).unwrap(); - if fill_val.to_string().contains("X") { + let fill_val = fill_str.value(i); + let offset_val = offset.value(i); + if fill_val.contains('X') { group_x_offsets.push(offset_val); } else { group_y_offsets.push(offset_val); @@ -1381,9 +1355,9 @@ mod tests { let jitter = Jitter; let df = df! { - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], } .unwrap(); @@ -1417,12 +1391,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); // Due to randomness, we can't assert exact values // But we can verify that offsets were generated @@ -1437,9 +1408,9 @@ mod tests { let jitter = Jitter; let df = df! 
{ - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A", "B", "B"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A", "B", "B"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], } .unwrap(); @@ -1477,12 +1448,9 @@ mod tests { let (result, _) = jitter.apply_adjustment(df, &layer, &spec).unwrap(); // Verify offsets were created - let offset = result - .column("__ggsql_aes_pos1offset__") - .unwrap() - .f64() - .unwrap(); - let offsets: Vec = offset.into_iter().flatten().collect(); + let offset_col = result.column("__ggsql_aes_pos1offset__").unwrap(); + let offset = as_f64(offset_col).unwrap(); + let offsets: Vec = (0..offset.len()).map(|i| offset.value(i)).collect(); assert_eq!(offsets.len(), 7); } @@ -1492,9 +1460,9 @@ mod tests { let jitter = Jitter; let df = df! { - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 2.0, 3.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0], } .unwrap(); @@ -1539,9 +1507,9 @@ mod tests { let jitter = Jitter; let df = df! { - "__ggsql_aes_pos1__" => ["A", "A", "A", "A", "A"], - "__ggsql_aes_pos2__" => [1.0, 1.0, 1.0, 2.0, 3.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "A", "A"], + "__ggsql_aes_pos2__" => vec![1.0, 1.0, 1.0, 2.0, 3.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0], } .unwrap(); diff --git a/src/plot/layer/position/stack.rs b/src/plot/layer/position/stack.rs index 2cef1e0e..97f13a79 100644 --- a/src/plot/layer/position/stack.rs +++ b/src/plot/layer/position/stack.rs @@ -7,9 +7,12 @@ //! 
- If pos1 is continuous and pos2 is discrete → stack horizontally (modify pos1/pos1end) use super::{is_continuous_scale, Layer, PositionTrait, PositionType}; +use crate::array_util::{as_f64, cast_array}; use crate::plot::types::{DefaultParamValue, ParamConstraint, ParamDefinition, ParameterValue}; -use crate::{naming, DataFrame, GgsqlError, Plot, Result}; -use polars::prelude::*; +use crate::{compute, naming, DataFrame, GgsqlError, Plot, Result}; +use arrow::array::{Array, Float64Array}; +use arrow::datatypes::DataType; +use std::sync::Arc; /// Stack mode for position adjustments #[derive(Clone, Copy)] @@ -143,29 +146,37 @@ fn has_zero_baseline_per_row(df: &DataFrame, col_a: &str, col_b: &str) -> bool { }; // Cast columns to f64 for comparison - handle both Int64 and Float64 sources - let Ok(a_casted) = a.cast(&polars::datatypes::DataType::Float64) else { + let Ok(a_casted) = cast_array(a, &DataType::Float64) else { return false; }; - let Ok(b_casted) = b.cast(&polars::datatypes::DataType::Float64) else { + let Ok(b_casted) = cast_array(b, &DataType::Float64) else { return false; }; - let Ok(a_vals) = a_casted.f64() else { + let Ok(a_vals) = as_f64(&a_casted) else { return false; }; - let Ok(b_vals) = b_casted.f64() else { + let Ok(b_vals) = as_f64(&b_casted) else { return false; }; - // Collect values to avoid borrow issues - let a_vec: Vec> = a_vals.into_iter().collect(); - let b_vec: Vec> = b_vals.into_iter().collect(); - // For each row, either a or b must be 0 - a_vec - .into_iter() - .zip(b_vec) - .all(|(a_val, b_val)| a_val == Some(0.0) || b_val == Some(0.0)) + for i in 0..a_vals.len() { + let a_val = if a_vals.is_null(i) { + None + } else { + Some(a_vals.value(i)) + }; + let b_val = if b_vals.is_null(i) { + None + } else { + Some(b_vals.value(i)) + }; + if a_val != Some(0.0) && b_val != Some(0.0) { + return false; + } + } + true } /// Determine stacking direction based on scale types and axis configuration. @@ -224,115 +235,94 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re } // Stacking currently only supports non-negative values - let min_result = df - .clone() - .lazy() - .select([col(&stack_col).min()]) - .collect() - .map_err(|e| GgsqlError::InternalError(format!("Failed to check min value: {}", e)))?; - - if let Some(min_col) = min_result.get_columns().first() { - if let Ok(min_val) = min_col.get(0) { - if let Ok(min) = min_val.try_extract::() { - if min < 0.0 { - let axis = match direction { - StackDirection::Vertical => "y", - StackDirection::Horizontal => "x", - }; - return Err(GgsqlError::ValidationError(format!( - "position 'stack' requires non-negative {} values", - axis - ))); - } - } + if let Some(min) = compute::column_min_f64(&df, &stack_col)? 
{ + if min < 0.0 { + let axis = match direction { + StackDirection::Vertical => "y", + StackDirection::Horizontal => "x", + }; + return Err(GgsqlError::ValidationError(format!( + "position 'stack' requires non-negative {} values", + axis + ))); } } - // Convert to lazy for transformations - let lf = df.lazy(); - // Sort by group column and partition_by columns to ensure consistent stacking order - // This ensures that within each group (e.g., x position), the stacking order is - // consistent even if data arrives in different orders or has missing values - let mut sort_cols = vec![col(&group_col)]; + let mut sort_col_names: Vec<&str> = vec![&group_col]; for partition_col in &layer.partition_by { - sort_cols.push(col(partition_col)); + sort_col_names.push(partition_col); } - let sort_options = SortMultipleOptions::default(); - let lf = lf.sort_by_exprs(&sort_cols, sort_options); - - // For stacking, compute cumulative sums within each group: - // 1. stack_col = cumulative sum (the bar top/end) - // 2. stack_end_col = lag(stack_col, 1, 0) - the bar bottom/start (previous stack top) - // The cumsum naturally stacks across the grouping column values + let df = compute::sort_dataframe(&df, &sort_col_names)?; + + // Cast stack column to f64 if needed, then fill nulls with 0 + let stack_col_array = df.column(&stack_col)?.clone(); + let stack_col_f64 = if stack_col_array.data_type() == &arrow::datatypes::DataType::Float64 { + stack_col_array + } else { + crate::array_util::cast_array(&stack_col_array, &arrow::datatypes::DataType::Float64)? + }; + let filled = compute::fill_null_f64_ref(&stack_col_f64, 0.0)?; + let df = df.with_column(&stack_col, filled)?; - // Build the partition columns for .over(): group column + facet columns. + // Build the group columns for .over(): group column + facet columns. // Facet columns must be included so stacking resets per facet panel, // matching ggplot2 where position adjustments are computed per-panel. 
- let mut over_cols: Vec<Expr> = vec![col(&group_col)]; - if let Some(ref facet) = spec.facet { - for aes in facet.layout.internal_facet_names() { - let facet_col = naming::aesthetic_column(&aes); - over_cols.push(col(&facet_col)); - } + // Collect facet column names as owned Strings + let facet_col_names: Vec<String> = spec + .facet + .as_ref() + .map(|f| { + f.layout + .internal_facet_names() + .into_iter() + .map(|aes| naming::aesthetic_column(&aes)) + .collect() + }) + .unwrap_or_default(); + + let mut over_col_refs: Vec<&str> = vec![&group_col]; + for name in &facet_col_names { + over_col_refs.push(name); } - // Treat NA heights as 0 for stacking + // Compute group IDs + let group_ids = compute::compute_group_ids(&df, &over_col_refs)?; + + // Get the stack column values as Float64 + let stack_arr = df.column(&stack_col)?; + let values = as_f64(stack_arr)?; + // Compute cumulative sums (shared by all modes) - let lf = lf - .with_column(col(&stack_col).fill_null(lit(0.0)).alias(&stack_col)) - .with_column( - col(&stack_col) - .cum_sum(false) - .over(&over_cols) - .alias("__cumsum__"), - ) - .with_column( - col(&stack_col) - .cum_sum(false) - .shift(lit(1)) - .fill_null(lit(0.0)) - .over(&over_cols) - .alias("__cumsum_lag__"), - ); + let cumsum = compute::grouped_cumsum(values, &group_ids); + let cumsum_lag = compute::grouped_cumsum_lag(values, &group_ids); // Apply mode-specific transformation - let (stack_expr, stack_end_expr, temp_cols): (Expr, Expr, Vec<&str>) = match mode { - StackMode::Normal => ( - col("__cumsum__").alias(&stack_col), - col("__cumsum_lag__").alias(&stack_end_col), - vec!["__cumsum__", "__cumsum_lag__"], - ), + let (new_stack, new_stack_end): (Float64Array, Float64Array) = match mode { + StackMode::Normal => (cumsum, cumsum_lag), StackMode::Fill(target) => { - let total = col(&stack_col).sum().over(&over_cols); - ( - (col("__cumsum__") / total.clone() * lit(target)).alias(&stack_col), - (col("__cumsum_lag__") / total * lit(target)).alias(&stack_end_col), - vec!["__cumsum__", "__cumsum_lag__"], - ) + let group_sum = compute::grouped_sum_broadcast(values, &group_ids); + let cumsum_div = compute::divide_arrays(&cumsum, &group_sum)?; + let cumsum_lag_div = compute::divide_arrays(&cumsum_lag, &group_sum)?; + let new_stack = compute::multiply_scalar(&cumsum_div, target); + let new_stack_end = compute::multiply_scalar(&cumsum_lag_div, target); + (new_stack, new_stack_end) } StackMode::Center => { - let half_total = col(&stack_col).sum().over(&over_cols) / lit(2.0); - ( - (col("__cumsum__") - half_total.clone()).alias(&stack_col), - (col("__cumsum_lag__") - half_total).alias(&stack_end_col), - vec!["__cumsum__", "__cumsum_lag__"], - ) + let group_sum = compute::grouped_sum_broadcast(values, &group_ids); + let half_sum = compute::divide_scalar(&group_sum, 2.0); + let new_stack = compute::subtract_arrays(&cumsum, &half_sum); + let new_stack_end = compute::subtract_arrays(&cumsum_lag, &half_sum); + (new_stack, new_stack_end) } }; - let mut result = lf - .with_columns([stack_expr, stack_end_expr]) - .collect() - .map_err(|e| { - GgsqlError::InternalError(format!("Stack position adjustment failed: {}", e)) - })?; - - for col_name in temp_cols { - result = result - .drop(col_name) - .map_err(|e| GgsqlError::InternalError(format!("Failed to drop temp column: {}", e)))?; - } + let result = df + .with_column(&stack_col, Arc::new(new_stack) as arrow::array::ArrayRef)?
+ .with_column( + &stack_end_col, + Arc::new(new_stack_end) as arrow::array::ArrayRef, + )?; Ok(result) } @@ -340,15 +330,17 @@ fn apply_stack(df: DataFrame, layer: &Layer, spec: &Plot, mode: StackMode) -> Re #[cfg(test)] mod tests { use super::*; + use crate::array_util::{as_f64, as_str}; + use crate::df; use crate::plot::layer::Geom; use crate::plot::{AestheticValue, Mappings}; fn make_test_df() -> DataFrame { df! { - "__ggsql_aes_pos1__" => ["A", "A", "B", "B"], - "__ggsql_aes_pos2__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "A", "B", "B"], + "__ggsql_aes_pos2__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], } .unwrap() } @@ -391,11 +383,8 @@ mod tests { let (result, width) = stack.apply_adjustment(df, &layer, &spec).unwrap(); assert!(width.is_none()); - let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); - let pos2end_col = result.column("__ggsql_aes_pos2end__").unwrap(); - - assert!(pos2_col.f64().is_ok() || pos2_col.i64().is_ok()); - assert!(pos2end_col.f64().is_ok() || pos2end_col.i64().is_ok()); + assert!(result.column("__ggsql_aes_pos2__").is_ok()); + assert!(result.column("__ggsql_aes_pos2end__").is_ok()); } #[test] @@ -431,23 +420,19 @@ mod tests { let (result_centered, _) = stack.apply_adjustment(df, &layer_centered, &spec).unwrap(); // Normal stacking should have pos2end starting at 0 - let pos2end_normal = result_normal.column("__ggsql_aes_pos2end__").unwrap(); - let first_normal = pos2end_normal.get(0).unwrap(); + let pos2end_normal_col = result_normal.column("__ggsql_aes_pos2end__").unwrap(); + let pos2end_normal = as_f64(pos2end_normal_col).unwrap(); // First element's pos2end should be 0 for normal stack - if let polars::prelude::AnyValue::Float64(v) = first_normal { - assert_eq!(v, 0.0); - } + assert_eq!(pos2end_normal.value(0), 0.0); // Centered stacking should have negative values - let pos2end_centered = result_centered.column("__ggsql_aes_pos2end__").unwrap(); - let first_centered = pos2end_centered.get(0).unwrap(); + let pos2end_centered_col = result_centered.column("__ggsql_aes_pos2end__").unwrap(); + let pos2end_centered = as_f64(pos2end_centered_col).unwrap(); // First element's pos2end should be negative for centered stack (shifted by -total/2) - if let polars::prelude::AnyValue::Float64(v) = first_centered { - assert!( - v < 0.0, - "Centered stack should have negative pos2end for first element" - ); - } + assert!( + pos2end_centered.value(0) < 0.0, + "Centered stack should have negative pos2end for first element" + ); } fn make_continuous_scale(aesthetic: &str) -> crate::plot::Scale { @@ -487,10 +472,10 @@ mod tests { // Create data with numeric pos1 values and pos1end column with zero baselines let df = df! 
{ - "__ggsql_aes_pos1__" => [10.0, 20.0, 15.0, 25.0], - "__ggsql_aes_pos1end__" => [0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_pos2__" => ["A", "A", "B", "B"], - "__ggsql_aes_fill__" => ["X", "Y", "X", "Y"], + "__ggsql_aes_pos1__" => vec![10.0, 20.0, 15.0, 25.0], + "__ggsql_aes_pos1end__" => vec![0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_pos2__" => vec!["A", "A", "B", "B"], + "__ggsql_aes_fill__" => vec!["X", "Y", "X", "Y"], } .unwrap(); @@ -536,9 +521,10 @@ mod tests { // Verify stacking occurred - values should be cumulative sums let pos1_col = result.column("__ggsql_aes_pos1__").unwrap(); - let pos1_vals: Vec = pos1_col.f64().unwrap().into_iter().flatten().collect(); + let pos1_arr = as_f64(pos1_col).unwrap(); + let pos1_vals: Vec = (0..pos1_arr.len()).map(|i| pos1_arr.value(i)).collect(); - // Should have cumulative sums (10, 30, 15, 40) for groups A and B + // Should have cumulative values > original max, got {:?} assert!( pos1_vals.iter().any(|&v| v > 20.0), "Should have cumulative values > original max, got {:?}", @@ -563,7 +549,8 @@ mod tests { // pos2 should sum to 100 within each group (A and B) let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); - let pos2_vals: Vec = pos2_col.f64().unwrap().into_iter().flatten().collect(); + let pos2_arr = as_f64(pos2_col).unwrap(); + let pos2_vals: Vec = (0..pos2_arr.len()).map(|i| pos2_arr.value(i)).collect(); // For group A: values 10, 20 -> normalized: 10/30, 20/30 -> cumsum: 10/30, 30/30 // Multiplied by 100: ~33.33, 100 @@ -594,7 +581,8 @@ mod tests { let (result, _) = stack.apply_adjustment(df, &layer, &spec).unwrap(); let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); - let pos2_vals: Vec = pos2_col.f64().unwrap().into_iter().flatten().collect(); + let pos2_arr = as_f64(pos2_col).unwrap(); + let pos2_vals: Vec = (0..pos2_arr.len()).map(|i| pos2_arr.value(i)).collect(); // Max values should be 1 (normalized to sum to 1) let max_val = pos2_vals.iter().cloned().fold(f64::MIN, f64::max); @@ -611,10 +599,10 @@ mod tests { // Create data with NA values in pos2 let df = df! 
{ - "__ggsql_aes_pos1__" => ["A", "A", "A", "B", "B", "B"], - "__ggsql_aes_pos2__" => [Some(10.0), None, Some(20.0), Some(15.0), Some(25.0), None], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "Y", "Z", "X", "Y", "Z"], + "__ggsql_aes_pos1__" => vec!["A", "A", "A", "B", "B", "B"], + "__ggsql_aes_pos2__" => vec![Some(10.0), None, Some(20.0), Some(15.0), Some(25.0), None], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "Y", "Z", "X", "Y", "Z"], } .unwrap(); @@ -647,19 +635,21 @@ mod tests { // Get pos2 values - should have no nulls after stacking let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); - let pos2_vals: Vec> = pos2_col.f64().unwrap().into_iter().collect(); + let pos2_arr = as_f64(pos2_col).unwrap(); // All values should be non-null (NA treated as 0) - assert!( - pos2_vals.iter().all(|v| v.is_some()), - "Expected no null values after stacking, got {:?}", - pos2_vals - ); + for i in 0..pos2_arr.len() { + assert!( + !pos2_arr.is_null(i), + "Expected no null values after stacking, got null at index {}", + i + ); + } // For group A: 10, 0 (NA), 20 -> cumsum: 10, 10, 30 // For group B: 15, 25, 0 (NA) -> cumsum: 15, 40, 40 // Check that the cumsum for group A ends at 30 (10 + 0 + 20) - let group_a_max = pos2_vals[2].unwrap(); // Third row is last for group A + let group_a_max = pos2_arr.value(2); // Third row is last for group A assert!( (group_a_max - 30.0).abs() < 0.01, "Expected group A max ~30 (NA treated as 0), got {}", @@ -673,10 +663,10 @@ mod tests { // Create data in shuffled order - categories not in order within groups let df = df! { - "__ggsql_aes_pos1__" => ["A", "B", "A", "B", "A", "B"], - "__ggsql_aes_pos2__" => [10.0, 15.0, 30.0, 35.0, 20.0, 25.0], - "__ggsql_aes_pos2end__" => [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - "__ggsql_aes_fill__" => ["X", "X", "Z", "Z", "Y", "Y"], + "__ggsql_aes_pos1__" => vec!["A", "B", "A", "B", "A", "B"], + "__ggsql_aes_pos2__" => vec![10.0, 15.0, 30.0, 35.0, 20.0, 25.0], + "__ggsql_aes_pos2end__" => vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + "__ggsql_aes_fill__" => vec!["X", "X", "Z", "Z", "Y", "Y"], } .unwrap(); @@ -716,9 +706,13 @@ mod tests { let fill_col = result.column("__ggsql_aes_fill__").unwrap(); let pos2_col = result.column("__ggsql_aes_pos2__").unwrap(); - let pos1_vals: Vec<&str> = pos1_col.str().unwrap().into_iter().flatten().collect(); - let fill_vals: Vec<&str> = fill_col.str().unwrap().into_iter().flatten().collect(); - let pos2_vals: Vec = pos2_col.f64().unwrap().into_iter().flatten().collect(); + let pos1_arr = as_str(pos1_col).unwrap(); + let fill_arr = as_str(fill_col).unwrap(); + let pos2_arr = as_f64(pos2_col).unwrap(); + + let pos1_vals: Vec<&str> = (0..pos1_arr.len()).map(|i| pos1_arr.value(i)).collect(); + let fill_vals: Vec<&str> = (0..fill_arr.len()).map(|i| fill_arr.value(i)).collect(); + let pos2_vals: Vec = (0..pos2_arr.len()).map(|i| pos2_arr.value(i)).collect(); // Should be sorted: A-X, A-Y, A-Z, B-X, B-Y, B-Z assert_eq!(pos1_vals, vec!["A", "A", "A", "B", "B", "B"]); diff --git a/src/plot/scale/scale_type/binned.rs b/src/plot/scale/scale_type/binned.rs index dd4b028e..561465d2 100644 --- a/src/plot/scale/scale_type/binned.rs +++ b/src/plot/scale/scale_type/binned.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; -use polars::prelude::DataType; +use arrow::datatypes::DataType; use super::{ expand_numeric_range, resolve_common_steps, ScaleDataContext, ScaleTypeKind, ScaleTypeTrait, @@ -93,13 +93,13 @@ impl 
ScaleTypeTrait for Binned { | DataType::Float32 | DataType::Float64 => Ok(()), // Accept temporal types - DataType::Date | DataType::Datetime(_, _) | DataType::Time => Ok(()), + DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => Ok(()), // Reject discrete types - DataType::String => Err("Binned scale cannot be used with String data. \ + DataType::Utf8 => Err("Binned scale cannot be used with String data. \ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), DataType::Boolean => Err("Binned scale cannot be used with Boolean data. \ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), - DataType::Categorical(_, _) => Err("Binned scale cannot be used with Categorical data. \ + DataType::Dictionary(_, _) => Err("Binned scale cannot be used with Categorical data. \ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), // Other types - provide generic message other => Err(format!( @@ -138,9 +138,9 @@ impl ScaleTypeTrait for Binned { // First check column data type for temporal transforms if let Some(dtype) = column_dtype { match dtype { - DataType::Date => return TransformKind::Date, - DataType::Datetime(_, _) => return TransformKind::DateTime, - DataType::Time => return TransformKind::Time, + DataType::Date32 => return TransformKind::Date, + DataType::Timestamp(_, _) => return TransformKind::DateTime, + DataType::Time64(_) => return TransformKind::Time, _ => {} } } @@ -617,7 +617,7 @@ impl ScaleTypeTrait for Binned { let transform = scale.transform.as_ref(); let is_temporal = matches!( column_dtype, - DataType::Date | DataType::Datetime(..) | DataType::Time + DataType::Date32 | DataType::Timestamp(..) 
| DataType::Time64(_) ); // Build CASE WHEN clauses for each bin @@ -835,6 +835,7 @@ mod tests { use super::*; use crate::plot::scale::Scale; use crate::reader::AnsiDialect; + use arrow::datatypes::TimeUnit; #[test] fn test_pre_stat_transform_sql_even_breaks() { @@ -1018,7 +1019,8 @@ mod tests { ); // Date column - no casting needed (types match) - let sql = binned.pre_stat_transform_sql("date_col", &DataType::Date, &scale, &AnsiDialect); + let sql = + binned.pre_stat_transform_sql("date_col", &DataType::Date32, &scale, &AnsiDialect); // Should successfully generate SQL (not return None due to filtered-out breaks) assert!(sql.is_some(), "SQL should be generated for Date breaks"); @@ -1067,10 +1069,10 @@ mod tests { ]), ); - use polars::prelude::TimeUnit; + use arrow::datatypes::TimeUnit; let sql = binned.pre_stat_transform_sql( "datetime_col", - &DataType::Datetime(TimeUnit::Microseconds, None), + &DataType::Timestamp(TimeUnit::Microsecond, None), &scale, &AnsiDialect, ); @@ -1101,7 +1103,12 @@ mod tests { ]), ); - let sql = binned.pre_stat_transform_sql("time_col", &DataType::Time, &scale, &AnsiDialect); + let sql = binned.pre_stat_transform_sql( + "time_col", + &DataType::Time64(TimeUnit::Nanosecond), + &scale, + &AnsiDialect, + ); // Should successfully generate SQL assert!(sql.is_some(), "SQL should be generated for Time breaks"); @@ -1142,7 +1149,7 @@ mod tests { // Date column - no column casting, but break values are formatted as ISO dates let sql = binned - .pre_stat_transform_sql("date_col", &DataType::Date, &scale, &AnsiDialect) + .pre_stat_transform_sql("date_col", &DataType::Date32, &scale, &AnsiDialect) .unwrap(); // Should NOT contain column CAST (column is already DATE) @@ -1239,7 +1246,7 @@ mod tests { fn test_datetime_column_with_datetime_transform() { // DATETIME column + datetime transform → temporal literals use crate::plot::scale::transform::Transform; - use polars::prelude::TimeUnit; + use arrow::datatypes::TimeUnit; let binned = Binned; let mut scale = Scale::new("x"); @@ -1260,7 +1267,7 @@ mod tests { let sql = binned .pre_stat_transform_sql( "datetime_col", - &DataType::Datetime(TimeUnit::Microseconds, None), + &DataType::Timestamp(TimeUnit::Microsecond, None), &scale, &AnsiDialect, ) @@ -1926,7 +1933,7 @@ mod tests { #[test] fn test_validate_dtype_accepts_numeric() { use super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let binned = Binned; assert!(binned.validate_dtype(&DataType::Int64).is_ok()); @@ -1936,22 +1943,22 @@ mod tests { #[test] fn test_validate_dtype_accepts_temporal() { use super::ScaleTypeTrait; - use polars::prelude::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, TimeUnit}; let binned = Binned; - assert!(binned.validate_dtype(&DataType::Date).is_ok()); + assert!(binned.validate_dtype(&DataType::Date32).is_ok()); assert!(binned - .validate_dtype(&DataType::Datetime(TimeUnit::Microseconds, None)) + .validate_dtype(&DataType::Timestamp(TimeUnit::Microsecond, None)) .is_ok()); } #[test] fn test_validate_dtype_rejects_string() { use super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let binned = Binned; - let result = binned.validate_dtype(&DataType::String); + let result = binned.validate_dtype(&DataType::Utf8); assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.contains("String")); @@ -1961,7 +1968,7 @@ mod tests { #[test] fn test_validate_dtype_rejects_boolean() { use super::ScaleTypeTrait; - use polars::prelude::DataType; + use 
arrow::datatypes::DataType; let binned = Binned; let result = binned.validate_dtype(&DataType::Boolean); @@ -1982,7 +1989,7 @@ mod tests { // Issue: breaks like [2600, 3550, 4050, 4750, 6400] were getting terminal // breaks removed when data range was ~[2700, 6300]. use super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let binned = Binned; let mut scale = Scale::new("fill"); @@ -2036,7 +2043,7 @@ mod tests { // When BOTH explicit breaks AND explicit range are provided, // breaks should be filtered to the range. use super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let binned = Binned; let mut scale = Scale::new("fill"); diff --git a/src/plot/scale/scale_type/continuous.rs b/src/plot/scale/scale_type/continuous.rs index 099ba168..22c26c93 100644 --- a/src/plot/scale/scale_type/continuous.rs +++ b/src/plot/scale/scale_type/continuous.rs @@ -1,6 +1,6 @@ //! Continuous scale type implementation -use polars::prelude::DataType; +use arrow::datatypes::DataType; use super::{ ScaleTypeKind, ScaleTypeTrait, TransformKind, OOB_CENSOR, OOB_SQUISH, OOB_VALUES_CONTINUOUS, @@ -38,13 +38,13 @@ impl ScaleTypeTrait for Continuous { | DataType::Float32 | DataType::Float64 => Ok(()), // Accept temporal types - DataType::Date | DataType::Datetime(_, _) | DataType::Time => Ok(()), + DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => Ok(()), // Reject discrete types - DataType::String => Err("Continuous scale cannot be used with String data. \ + DataType::Utf8 => Err("Continuous scale cannot be used with String data. \ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), DataType::Boolean => Err("Continuous scale cannot be used with Boolean data. \ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), - DataType::Categorical(_, _) => Err("Continuous scale cannot be used with Categorical data. \ + DataType::Dictionary(_, _) => Err("Continuous scale cannot be used with Categorical data. 
\ Use DISCRETE scale type instead, or ensure the column contains numeric or temporal data.".to_string()), // Other types - provide generic message other => Err(format!( @@ -85,9 +85,9 @@ impl ScaleTypeTrait for Continuous { // First check column data type for temporal transforms if let Some(dtype) = column_dtype { match dtype { - DataType::Date => return TransformKind::Date, - DataType::Datetime(_, _) => return TransformKind::DateTime, - DataType::Time => return TransformKind::Time, + DataType::Date32 => return TransformKind::Date, + DataType::Timestamp(_, _) => return TransformKind::DateTime, + DataType::Time64(_) => return TransformKind::Time, _ => {} } } @@ -382,14 +382,16 @@ mod tests { #[test] fn test_validate_dtype_accepts_temporal() { use super::ScaleTypeTrait; - use polars::prelude::TimeUnit; + use arrow::datatypes::TimeUnit; let continuous = Continuous; - assert!(continuous.validate_dtype(&DataType::Date).is_ok()); + assert!(continuous.validate_dtype(&DataType::Date32).is_ok()); assert!(continuous - .validate_dtype(&DataType::Datetime(TimeUnit::Microseconds, None)) + .validate_dtype(&DataType::Timestamp(TimeUnit::Microsecond, None)) + .is_ok()); + assert!(continuous + .validate_dtype(&DataType::Time64(TimeUnit::Nanosecond)) .is_ok()); - assert!(continuous.validate_dtype(&DataType::Time).is_ok()); } #[test] @@ -397,7 +399,7 @@ mod tests { use super::ScaleTypeTrait; let continuous = Continuous; - let result = continuous.validate_dtype(&DataType::String); + let result = continuous.validate_dtype(&DataType::Utf8); assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.contains("String")); diff --git a/src/plot/scale/scale_type/discrete.rs b/src/plot/scale/scale_type/discrete.rs index c483cb3e..97c31344 100644 --- a/src/plot/scale/scale_type/discrete.rs +++ b/src/plot/scale/scale_type/discrete.rs @@ -1,6 +1,6 @@ //! Discrete scale type implementation -use polars::prelude::DataType; +use arrow::datatypes::DataType; use super::super::transform::{Transform, TransformKind}; use super::{ScaleTypeKind, ScaleTypeTrait}; @@ -24,7 +24,7 @@ impl ScaleTypeTrait for Discrete { fn validate_dtype(&self, dtype: &DataType) -> Result<(), String> { match dtype { // Accept discrete types - DataType::String | DataType::Boolean | DataType::Categorical(_, _) => Ok(()), + DataType::Utf8 | DataType::Boolean | DataType::Dictionary(_, _) => Ok(()), // Reject numeric types DataType::Int8 | DataType::Int16 @@ -38,11 +38,11 @@ impl ScaleTypeTrait for Discrete { | DataType::Float64 => Err("Discrete scale cannot be used with numeric data. \ Use CONTINUOUS or BINNED scale type instead, or ensure the column contains categorical data.".to_string()), // Reject temporal types - DataType::Date => Err("Discrete scale cannot be used with Date data. \ + DataType::Date32 => Err("Discrete scale cannot be used with Date data. \ Use CONTINUOUS scale type instead (dates are treated as continuous temporal data).".to_string()), - DataType::Datetime(_, _) => Err("Discrete scale cannot be used with DateTime data. \ + DataType::Timestamp(_, _) => Err("Discrete scale cannot be used with DateTime data. \ Use CONTINUOUS scale type instead (datetimes are treated as continuous temporal data).".to_string()), - DataType::Time => Err("Discrete scale cannot be used with Time data. \ + DataType::Time64(_) => Err("Discrete scale cannot be used with Time data. 
\ Use CONTINUOUS scale type instead (times are treated as continuous temporal data).".to_string()), // Other types - provide generic message other => Err(format!( @@ -84,7 +84,7 @@ impl ScaleTypeTrait for Discrete { if let Some(dtype) = column_dtype { match dtype { DataType::Boolean => return TransformKind::Bool, - DataType::String | DataType::Categorical(_, _) => return TransformKind::String, + DataType::Utf8 | DataType::Dictionary(_, _) => return TransformKind::String, _ => {} } } @@ -323,7 +323,7 @@ mod tests { // String column → String transform assert_eq!( - discrete.default_transform("color", Some(&DataType::String)), + discrete.default_transform("color", Some(&DataType::Utf8)), TransformKind::String ); @@ -397,8 +397,8 @@ mod tests { let result = discrete.resolve_transform( "color", None, - Some(&DataType::String), // String column - Some(&bool_range), // But bool input range + Some(&DataType::Utf8), // String column + Some(&bool_range), // But bool input range ); assert!(result.is_ok()); assert_eq!(result.unwrap().transform_kind(), TransformKind::Bool); @@ -427,7 +427,7 @@ mod tests { assert!(result.is_ok()); assert_eq!(result.unwrap().transform_kind(), TransformKind::Bool); - let result = discrete.resolve_transform("color", None, Some(&DataType::String), None); + let result = discrete.resolve_transform("color", None, Some(&DataType::Utf8), None); assert!(result.is_ok()); assert_eq!(result.unwrap().transform_kind(), TransformKind::String); } @@ -476,7 +476,7 @@ mod tests { scale.explicit_input_range = true; let sql = - discrete.pre_stat_transform_sql("category", &DataType::String, &scale, &AnsiDialect); + discrete.pre_stat_transform_sql("category", &DataType::Utf8, &scale, &AnsiDialect); assert!(sql.is_some()); let sql = sql.unwrap(); @@ -500,7 +500,7 @@ mod tests { scale.explicit_input_range = false; let sql = - discrete.pre_stat_transform_sql("category", &DataType::String, &scale, &AnsiDialect); + discrete.pre_stat_transform_sql("category", &DataType::Utf8, &scale, &AnsiDialect); // Should return None (no OOB handling for inferred ranges) assert!(sql.is_none()); @@ -539,7 +539,7 @@ mod tests { ]); scale.explicit_input_range = true; - let sql = discrete.pre_stat_transform_sql("text", &DataType::String, &scale, &AnsiDialect); + let sql = discrete.pre_stat_transform_sql("text", &DataType::Utf8, &scale, &AnsiDialect); assert!(sql.is_some()); let sql = sql.unwrap(); @@ -557,7 +557,7 @@ mod tests { scale.explicit_input_range = true; let sql = - discrete.pre_stat_transform_sql("category", &DataType::String, &scale, &AnsiDialect); + discrete.pre_stat_transform_sql("category", &DataType::Utf8, &scale, &AnsiDialect); // Should return None for empty range assert!(sql.is_none()); @@ -572,7 +572,7 @@ mod tests { use super::ScaleTypeTrait; let discrete = Discrete; - assert!(discrete.validate_dtype(&DataType::String).is_ok()); + assert!(discrete.validate_dtype(&DataType::Utf8).is_ok()); } #[test] @@ -601,19 +601,19 @@ mod tests { #[test] fn test_validate_dtype_rejects_temporal() { use super::ScaleTypeTrait; - use polars::prelude::TimeUnit; + use arrow::datatypes::TimeUnit; let discrete = Discrete; - let result = discrete.validate_dtype(&DataType::Date); + let result = discrete.validate_dtype(&DataType::Date32); assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.contains("Date")); assert!(err.contains("CONTINUOUS")); - let result = discrete.validate_dtype(&DataType::Datetime(TimeUnit::Microseconds, None)); + let result = 
discrete.validate_dtype(&DataType::Timestamp(TimeUnit::Microsecond, None)); assert!(result.is_err()); - let result = discrete.validate_dtype(&DataType::Time); + let result = discrete.validate_dtype(&DataType::Time64(TimeUnit::Nanosecond)); assert!(result.is_err()); } } diff --git a/src/plot/scale/scale_type/identity.rs b/src/plot/scale/scale_type/identity.rs index 45427844..0acb14b9 100644 --- a/src/plot/scale/scale_type/identity.rs +++ b/src/plot/scale/scale_type/identity.rs @@ -1,6 +1,6 @@ //! Identity scale type implementation -use polars::prelude::DataType; +use arrow::datatypes::DataType; use super::{CastTargetType, ScaleTypeKind, ScaleTypeTrait}; use crate::plot::ArrayElement; diff --git a/src/plot/scale/scale_type/mod.rs b/src/plot/scale/scale_type/mod.rs index 9a6a351e..db0654da 100644 --- a/src/plot/scale/scale_type/mod.rs +++ b/src/plot/scale/scale_type/mod.rs @@ -20,7 +20,8 @@ //! assert_eq!(continuous.name(), "continuous"); //! ``` -use polars::prelude::{ChunkAgg, Column, DataType}; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::DataType; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -130,15 +131,15 @@ impl ScaleDataContext { } } - /// Create from multiple Polars Columns. + /// Create from multiple Arrow ArrayRef columns. /// /// Aggregates min/max or unique values across all columns. - pub fn from_columns(columns: &[&Column], is_discrete: bool) -> Self { + pub fn from_columns(columns: &[&ArrayRef], is_discrete: bool) -> Self { if columns.is_empty() { return Self::new(); } - let dtype = Some(columns[0].dtype().clone()); + let dtype = Some(columns[0].data_type().clone()); let range = if is_discrete { // Aggregate unique values across all columns @@ -180,18 +181,19 @@ impl Default for ScaleDataContext { } /// Compute numeric min/max from multiple columns. -fn compute_column_range_multi(columns: &[&Column]) -> Option<Vec<f64>> { +fn compute_column_range_multi(columns: &[&ArrayRef]) -> Option<Vec<f64>> { + use crate::array_util::cast_array; + let mut global_min: Option<f64> = None; let mut global_max: Option<f64> = None; for column in columns { - let series = column.as_materialized_series(); - if let Ok(ca) = series.cast(&DataType::Float64) { - if let Ok(f64_series) = ca.f64() { - if let Some(min) = f64_series.min() { + if let Ok(cast) = cast_array(column, &DataType::Float64) { + if let Ok(f64_arr) = crate::array_util::as_f64(&cast) { + if let Some(min) = arrow::compute::min(f64_arr) { global_min = Some(global_min.map_or(min, |m| m.min(min))); } - if let Some(max) = f64_series.max() { + if let Some(max) = arrow::compute::max(f64_arr) { global_max = Some(global_max.map_or(max, |m| m.max(max))); } } @@ -227,7 +229,7 @@ fn merge_with_context( /// Compute unique values from multiple columns, sorted. /// NULL values are included at the end of the result. -fn compute_unique_values_multi(columns: &[&Column]) -> Vec<ArrayElement> { +fn compute_unique_values_multi(columns: &[&ArrayRef]) -> Vec<ArrayElement> { compute_unique_values_native(columns, true) }
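The `crate::array_util` downcast helpers leaned on throughout this migration (`as_f64`, `as_bool`, `as_str`, `cast_array`, …) are defined elsewhere in the PR and not shown here. A minimal sketch of what a helper like `as_f64` presumably does, using stock arrow-rs APIs (the error type is simplified for the example):

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float64Array};

// Downcast an ArrayRef to a concrete &Float64Array, or report what it was.
fn as_f64(array: &ArrayRef) -> Result<&Float64Array, String> {
    array
        .as_any()
        .downcast_ref::<Float64Array>()
        .ok_or_else(|| format!("expected Float64, got {:?}", array.data_type()))
}

fn main() {
    let arr: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.5), None, Some(3.0)]));
    let f = as_f64(&arr).unwrap();
    // Arrow keeps validity separate from values: check is_null before value().
    assert!(f.is_null(1));
    assert_eq!(f.value(0), 1.5);
}
```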
@@ -243,13 +245,16 @@ fn compute_unique_values_multi(columns: &[&Column]) -> Vec<ArrayElement> { /// /// If `include_null` is true, `ArrayElement::Null` is appended at the end if any null /// values exist in the data. -pub fn compute_unique_values_native(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +pub fn compute_unique_values_native( + columns: &[&ArrayRef], + include_null: bool, +) -> Vec<ArrayElement> { if columns.is_empty() { return Vec::new(); } // Use first column's dtype to determine handling - let dtype = columns[0].dtype(); + let dtype = columns[0].data_type(); match dtype { DataType::Boolean => compute_unique_bool(columns, include_null), @@ -263,34 +268,34 @@ pub fn compute_unique_values_native(columns: &[&Column], include_null: bool) -> | DataType::UInt64 | DataType::Float32 | DataType::Float64 => compute_unique_numeric(columns, include_null), - DataType::Date => compute_unique_date(columns, include_null), - DataType::Datetime(_, _) => compute_unique_datetime(columns, include_null), - DataType::Time => compute_unique_time(columns, include_null), - _ => compute_unique_string(columns, include_null), // String/Categorical/fallback + DataType::Date32 => compute_unique_date(columns, include_null), + DataType::Timestamp(_, _) => compute_unique_datetime(columns, include_null), + DataType::Time64(_) => compute_unique_time(columns, include_null), + _ => compute_unique_string(columns, include_null), // Utf8/Dictionary/fallback } } /// Compute unique boolean values from columns. -fn compute_unique_bool(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_bool(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { let mut has_false = false; let mut has_true = false; let mut has_null = false; for column in columns { - if let Ok(ca) = column.as_materialized_series().bool() { - for val in ca.into_iter() { - match val { - Some(true) => has_true = true, - Some(false) => has_false = true, - None => has_null = true, + if let Ok(ca) = crate::array_util::as_bool(column) { + for i in 0..ca.len() { + if ca.is_null(i) { + has_null = true; + } else if ca.value(i) { + has_true = true; + } else { + has_false = true; } - // Early exit if all values have been encountered if has_null && has_true && has_false { break; } } } - // Early exit if all values have been encountered if has_null && has_true && has_false { break; } @@ -311,20 +316,21 @@ fn compute_unique_bool(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { -fn compute_unique_numeric(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_numeric(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { let mut values: Vec<f64> = Vec::new(); let mut has_null = false; for column in columns { - if let Ok(series) = column.as_materialized_series().cast(&DataType::Float64) { - if let Ok(ca) = series.f64() { - for val in ca.into_iter() { - match val { - Some(v) if v.is_finite() && !values.contains(&v) => { + if let Ok(cast) = crate::array_util::cast_array(column, &DataType::Float64) { + if let Ok(f64_arr) = crate::array_util::as_f64(&cast) { + for i in 0..f64_arr.len() { + if f64_arr.is_null(i) { + has_null = true; + } else { + let v = f64_arr.value(i); + if v.is_finite() && !values.contains(&v) { values.push(v); } - None => has_null = true, - _ => {} // Skip NaN/Inf or duplicates } } }
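For the temporal variants that follow, the pattern is the same but keyed on Arrow's physical representations: `Date32` is days since the Unix epoch as `i32`, timestamps here are microseconds as `i64`, and `Time64` is nanoseconds as `i64`. A `BTreeSet` over the physical values yields sorted, de-duplicated output for free — a standalone illustration (the day values are made up):

```rust
use arrow::array::{Array, Date32Array};
use std::collections::BTreeSet;

fn main() {
    // Date32 stores i32 days since 1970-01-01; duplicates and a null included.
    let dates = Date32Array::from(vec![Some(19723), Some(19722), None, Some(19723)]);

    let mut uniques: BTreeSet<i32> = BTreeSet::new();
    let mut has_null = false;
    for i in 0..dates.len() {
        if dates.is_null(i) {
            has_null = true;
        } else {
            uniques.insert(dates.value(i));
        }
    }

    // Sorted and de-duplicated; the null is tracked separately, as above.
    assert_eq!(uniques.into_iter().collect::<Vec<_>>(), vec![19722, 19723]);
    assert!(has_null);
}
```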
@@ -344,21 +350,19 @@ fn compute_unique_numeric(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { -fn compute_unique_date(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_date(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { use std::collections::BTreeSet; let mut values: BTreeSet<i32> = BTreeSet::new(); let mut has_null = false; for column in columns { - if let Ok(ca) = column.as_materialized_series().date() { - // Access the underlying physical Int32 chunked array - for val in ca.phys.into_iter() { - match val { - Some(days) => { - values.insert(days); - } - None => has_null = true, + if let Ok(ca) = crate::array_util::as_date32(column) { + for i in 0..ca.len() { + if ca.is_null(i) { + has_null = true; + } else { + values.insert(ca.value(i)); } } } @@ -374,21 +378,19 @@ fn compute_unique_date(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { -fn compute_unique_datetime(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_datetime(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { use std::collections::BTreeSet; let mut values: BTreeSet<i64> = BTreeSet::new(); let mut has_null = false; for column in columns { - if let Ok(ca) = column.as_materialized_series().datetime() { - // Access the underlying physical Int64 chunked array - for val in ca.phys.into_iter() { - match val { - Some(micros) => { - values.insert(micros); - } - None => has_null = true, + if let Ok(ca) = crate::array_util::as_timestamp_us(column) { + for i in 0..ca.len() { + if ca.is_null(i) { + has_null = true; + } else { + values.insert(ca.value(i)); } } } @@ -404,21 +406,19 @@ fn compute_unique_datetime(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { -fn compute_unique_time(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_time(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { use std::collections::BTreeSet; let mut values: BTreeSet<i64> = BTreeSet::new(); let mut has_null = false; for column in columns { - if let Ok(ca) = column.as_materialized_series().time() { - // Access the underlying physical Int64 chunked array - for val in ca.phys.into_iter() { - match val { - Some(nanos) => { - values.insert(nanos); - } - None => has_null = true, + if let Ok(ca) = crate::array_util::as_time64_ns(column) { + for i in 0..ca.len() { + if ca.is_null(i) { + has_null = true; + } else { + values.insert(ca.value(i)); } } } @@ -434,25 +434,28 @@ fn compute_unique_time(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { -fn compute_unique_string(columns: &[&Column], include_null: bool) -> Vec<ArrayElement> { +fn compute_unique_string(columns: &[&ArrayRef], include_null: bool) -> Vec<ArrayElement> { use std::collections::BTreeSet; let mut values: BTreeSet<String> = BTreeSet::new(); let mut has_null = false; for column in columns { - let series = column.as_materialized_series(); - if let Ok(unique) = series.unique() { - for i in 0..unique.len() { - if let Ok(val) = unique.get(i) { - if val.is_null() { - has_null = true; - } else { - let s = val.to_string(); - // Remove surrounding quotes from string representation - let clean = s.trim_matches('"').to_string(); - values.insert(clean); - } + // Try to get as string array, falling back to value_to_string for other types + if let Ok(str_arr) = crate::array_util::as_str(column) { + for i in 0..str_arr.len() { + if str_arr.is_null(i) { + has_null = true; + } else { + values.insert(str_arr.value(i).to_string()); + } + } + } else { + for i in 0..column.len() { + if column.is_null(i) { + has_null = true; + } else { + values.insert(crate::array_util::value_to_string(column, i)); } } } @@ -575,9 +578,9 @@ pub trait ScaleTypeTrait: std::fmt::Debug + std::fmt::Display + Send + Sync { // First check column data type for temporal transforms if let Some(dtype) = column_dtype { match dtype { - DataType::Date => return TransformKind::Date, - DataType::Datetime(_, _) => return TransformKind::DateTime, - DataType::Time => return TransformKind::Time, + DataType::Date32 => return TransformKind::Date, + DataType::Timestamp(_, _) => return TransformKind::DateTime, + DataType::Time64(_) => return TransformKind::Time, _ => {} } } @@ -1119,8 +1122,10 @@ impl ScaleType { | DataType::Float64 => Self::continuous(), // Temporal types are fundamentally continuous (days/µs/ns since epoch) // The temporal transform is inferred from the column data type - DataType::Date | DataType::Datetime(_, _) | DataType::Time => Self::continuous(), -
DataType::Boolean | DataType::String => Self::discrete(), + DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => { + Self::continuous() + } + DataType::Boolean | DataType::Utf8 => Self::discrete(), _ => Self::discrete(), } } @@ -1887,8 +1892,8 @@ fn type_family(dtype: &DataType) -> TypeFamily { | DataType::UInt64 | DataType::Float32 | DataType::Float64 => TypeFamily::Numeric, - DataType::Date | DataType::Datetime(_, _) | DataType::Time => TypeFamily::Temporal, - DataType::String => TypeFamily::String, + DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => TypeFamily::Temporal, + DataType::Utf8 => TypeFamily::String, _ => TypeFamily::String, // Unknown types treated as String } } @@ -1927,7 +1932,7 @@ fn numeric_rank(dtype: &DataType) -> u8 { /// Returns Ok(DataType) with the common type, or Err if incompatible temporal types. pub fn coerce_dtypes(dtypes: &[DataType]) -> Result<DataType> { if dtypes.is_empty() { - return Ok(DataType::String); // Default to String for empty + return Ok(DataType::Utf8); // Default to String for empty } if dtypes.len() == 1 { @@ -1939,7 +1944,7 @@ pub fn coerce_dtypes(dtypes: &[DataType]) -> Result<DataType> { // Check if any type is String - result is String if families.contains(&TypeFamily::String) { - return Ok(DataType::String); + return Ok(DataType::Utf8); } // Check for mixed families @@ -1948,7 +1953,7 @@ pub fn coerce_dtypes(dtypes: &[DataType]) -> Result<DataType> { if has_numeric && has_temporal { // Incompatible families - coerce to String - return Ok(DataType::String); + return Ok(DataType::Utf8); } // All numeric - find highest rank @@ -1971,9 +1976,9 @@ pub fn coerce_dtypes(dtypes: &[DataType]) -> Result<DataType> { let all_same = dtypes.iter().all(|d| { matches!( (first, d), - (DataType::Date, DataType::Date) - | (DataType::Datetime(_, _), DataType::Datetime(_, _)) - | (DataType::Time, DataType::Time) + (DataType::Date32, DataType::Date32) | (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) | (DataType::Time64(_), DataType::Time64(_)) ) }); @@ -1989,7 +1994,7 @@ pub fn coerce_dtypes(dtypes: &[DataType]) -> Result<DataType> { } // Fallback to String - Ok(DataType::String) + Ok(DataType::Utf8) }
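As a worked restatement of the coercion rules just above (self-contained sketch; `TypeFamily` is redeclared locally so the snippet compiles on its own):

```rust
use arrow::datatypes::{DataType, TimeUnit};

#[derive(Debug, PartialEq)]
enum TypeFamily {
    Numeric,
    Temporal,
    String,
}

// Condensed version of the type_family arms shown in the diff above.
fn type_family(dtype: &DataType) -> TypeFamily {
    match dtype {
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
        | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
        | DataType::Float32 | DataType::Float64 => TypeFamily::Numeric,
        DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => {
            TypeFamily::Temporal
        }
        _ => TypeFamily::String, // Utf8 and unknown types
    }
}

fn main() {
    // coerce_dtypes: any String in the mix wins; a numeric + temporal pair is
    // incompatible and also collapses to Utf8; all-numeric widens by rank.
    assert_eq!(type_family(&DataType::Utf8), TypeFamily::String);
    assert_eq!(
        type_family(&DataType::Timestamp(TimeUnit::Microsecond, None)),
        TypeFamily::Temporal
    );
    assert_eq!(type_family(&DataType::Float32), TypeFamily::Numeric);
}
```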
/// Convert a Polars DataType to the corresponding CastTargetType. @@ -2008,10 +2013,10 @@ pub fn dtype_to_cast_target(dtype: &DataType) -> CastTargetType { | DataType::UInt64 | DataType::Float32 | DataType::Float64 => CastTargetType::Number, - DataType::Date => CastTargetType::Date, - DataType::Datetime(_, _) => CastTargetType::DateTime, - DataType::Time => CastTargetType::Time, - DataType::String => CastTargetType::String, + DataType::Date32 => CastTargetType::Date, + DataType::Timestamp(_, _) => CastTargetType::DateTime, + DataType::Time64(_) => CastTargetType::Time, + DataType::Utf8 => CastTargetType::String, _ => CastTargetType::String, // Unknown types treated as String } } @@ -2027,10 +2032,10 @@ pub fn needs_cast(column_dtype: &DataType, target_dtype: &DataType) -> Option<CastTargetType> (DataType::Boolean, DataType::Boolean) => true, - (DataType::Date, DataType::Date) => true, - (DataType::Datetime(_, _), DataType::Datetime(_, _)) => true, - (DataType::Time, DataType::Time) => true, - (DataType::String, DataType::String) => true, + (DataType::Date32, DataType::Date32) => true, + (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) => true, + (DataType::Time64(_), DataType::Time64(_)) => true, + (DataType::Utf8, DataType::Utf8) => true, // For numeric, check if target is Float64 and column is any numeric ( DataType::Int8 @@ -2078,6 +2083,7 @@ pub fn needs_cast(column_dtype: &DataType, target_dtype: &DataType) -> Option<CastTargetType> diff --git a/src/plot/scale/scale_type/ordinal.rs b/src/plot/scale/scale_type/ordinal.rs --- a/src/plot/scale/scale_type/ordinal.rs +++ b/src/plot/scale/scale_type/ordinal.rs fn validate_dtype(&self, dtype: &DataType) -> Result<(), String> { match dtype { // Accept discrete types - DataType::String | DataType::Boolean | DataType::Categorical(_, _) => Ok(()), + DataType::Utf8 | DataType::Boolean | DataType::Dictionary(_, _) => Ok(()), // Accept integer types (useful for ordered categories like years, rankings) DataType::Int8 | DataType::Int16 @@ -45,11 +45,11 @@ impl ScaleTypeTrait for Ordinal { .to_string(), ), // Reject temporal types - DataType::Date => Err("Ordinal scale cannot be used with Date data. \ Use CONTINUOUS scale type instead (dates are treated as continuous temporal data).".to_string()), - DataType::Datetime(_, _) => Err("Ordinal scale cannot be used with DateTime data. \ Use CONTINUOUS scale type instead (datetimes are treated as continuous temporal data).".to_string()), - DataType::Time => Err("Ordinal scale cannot be used with Time data. 
\ Use CONTINUOUS scale type instead (times are treated as continuous temporal data).".to_string()), // Other types - provide generic message other => Err(format!( @@ -82,7 +82,7 @@ impl ScaleTypeTrait for Ordinal { // Infer from column type match column_dtype { Some(DataType::Boolean) => TransformKind::Bool, - Some(DataType::String) | Some(DataType::Categorical(_, _)) => TransformKind::String, + Some(DataType::Utf8) | Some(DataType::Dictionary(_, _)) => TransformKind::String, // Numeric types use Identity to preserve numeric sorting Some( DataType::Int8 @@ -506,7 +506,7 @@ mod tests { fn test_ordinal_default_transform_numeric() { use super::super::ScaleTypeTrait; use crate::plot::scale::TransformKind; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; @@ -526,7 +526,7 @@ mod tests { // String/Boolean use their respective transforms assert_eq!( - ordinal.default_transform("color", Some(&DataType::String)), + ordinal.default_transform("color", Some(&DataType::Utf8)), TransformKind::String ); assert_eq!( @@ -542,16 +542,16 @@ mod tests { #[test] fn test_validate_dtype_accepts_string() { use super::super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; - assert!(ordinal.validate_dtype(&DataType::String).is_ok()); + assert!(ordinal.validate_dtype(&DataType::Utf8).is_ok()); } #[test] fn test_validate_dtype_accepts_boolean() { use super::super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; assert!(ordinal.validate_dtype(&DataType::Boolean).is_ok()); @@ -560,7 +560,7 @@ mod tests { #[test] fn test_validate_dtype_accepts_integer() { use super::super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; // Integers are valid for ordinal scales (years, rankings, etc.) @@ -572,7 +572,7 @@ mod tests { #[test] fn test_validate_dtype_rejects_float() { use super::super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; let result = ordinal.validate_dtype(&DataType::Float64); @@ -588,10 +588,10 @@ mod tests { #[test] fn test_validate_dtype_rejects_temporal() { use super::super::ScaleTypeTrait; - use polars::prelude::DataType; + use arrow::datatypes::DataType; let ordinal = Ordinal; - let result = ordinal.validate_dtype(&DataType::Date); + let result = ordinal.validate_dtype(&DataType::Date32); assert!(result.is_err()); let err = result.unwrap_err(); assert!(err.contains("Date")); diff --git a/src/plot/types.rs b/src/plot/types.rs index 3c70c625..d7ee7379 100644 --- a/src/plot/types.rs +++ b/src/plot/types.rs @@ -5,8 +5,8 @@ //! to capture what the user specified in their query. use crate::reader::SqlDialect; +use arrow::datatypes::DataType; use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; -use polars::prelude::DataType; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/src/reader/data.rs b/src/reader/data.rs index b143a4b2..1dd90934 100644 --- a/src/reader/data.rs +++ b/src/reader/data.rs @@ -12,6 +12,26 @@ use crate::{naming, GgsqlError}; // beneath this block. // 3. Add a match arm in `builtin_parquet_bytes()` for your dataset. // 4. Add the dataset name to `KNOWN_DATASETS`. +// +// Parquet compatibility +// --------------------- +// The file must be readable by arrow-rs without `skip_arrow_metadata`. +// The test `all_builtin_parquets_load` enforces this in CI. 
+// +// Known-compatible writers: +// - Python `pyarrow` (`pq.write_table(...)`) +// - Rust `arrow-rs` + `parquet` (`ArrowWriter`) +// - DuckDB (`COPY ... TO 'file.parquet'`) +// +// Known-incompatible writers: +// - R `nanoparquet` — writes ARROW:schema with a different flatbuffers +// alignment that arrow-rs's strict reader rejects. +// +// If you receive a file from an incompatible source, round-trip it with a +// compatible writer. Example with pyarrow: +// import pyarrow.parquet as pq +// pq.write_table(pq.read_table('input.parquet'), 'output.parquet', +// compression='snappy') // ============================================================================= #[cfg(feature = "builtin-data")] @@ -63,7 +83,14 @@ pub fn register_builtin_datasets_duckdb( let mut tmp_path = env::temp_dir(); tmp_path.push(format!("{}.parquet", name)); if !tmp_path.exists() { - fs::write(&tmp_path, parquet_bytes).expect("Failed to write dataset"); + fs::write(&tmp_path, parquet_bytes).map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to write builtin dataset '{}' to {}: {}", + name, + tmp_path.display(), + e + )) + })?; } let create_sql = format!( @@ -83,13 +110,12 @@ } // ============================================================================= -// Polars-based builtin data loading +// Arrow-based builtin data loading // ============================================================================= -#[cfg(feature = "parquet")] +#[cfg(all(feature = "builtin-data", feature = "parquet"))] pub fn load_builtin_dataframe(name: &str) -> Result<crate::DataFrame> { - use polars::prelude::*; - use std::io::Cursor; + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; let parquet_bytes = match name { "penguins" => PENGUINS, @@ -102,10 +128,35 @@ pub fn load_builtin_dataframe(name: &str) -> Result<crate::DataFrame> { + let batches: Vec<RecordBatch> = reader + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| { + GgsqlError::ReaderError(format!("Failed to load builtin dataset '{}': {}", name, e)) + })?; + + if batches.is_empty() { + return Ok(crate::DataFrame::empty()); + } + + let rb = if batches.len() == 1 { + batches.into_iter().next().unwrap() + } else { + arrow::compute::concat_batches(&batches[0].schema(), &batches).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to concat batches for '{}': {}", name, e)) + })? + }; + + Ok(crate::DataFrame::from_record_batch(rb)) } /// Known builtin dataset names in the ggsql namespace @@ -495,23 +546,32 @@ mod duckdb_tests { } } -#[cfg(feature = "builtin-data")] +#[cfg(all(feature = "builtin-data", feature = "parquet"))] #[cfg(test)] mod builtin_data_tests { use super::*; + /// Every entry in `KNOWN_DATASETS` must load cleanly via arrow-rs without + /// the `skip_arrow_metadata` workaround. If this test fails on a newly + /// added parquet file, the file was written by an incompatible tool + /// (see the compatibility notes at the top of this module). #[test] - fn test_load_builtin_parquet_penguins() { - let df = load_builtin_dataframe("penguins").unwrap(); - assert!(df.height() > 0); - assert!(df.width() > 0); - } - - #[test] - fn test_load_builtin_parquet_airquality() { - let df = load_builtin_dataframe("airquality").unwrap(); - assert!(df.height() > 0); - assert!(df.width() > 0); + fn all_builtin_parquets_load() { + for name in KNOWN_DATASETS { + let df = load_builtin_dataframe(name).unwrap_or_else(|e| { + panic!( + "Builtin dataset '{}' failed to load — likely an incompatible \ parquet writer. See parquet compatibility notes in \ src/reader/data.rs. 
Underlying error: {}", + name, e + ) + }); + assert!( + df.height() > 0 && df.width() > 0, + "Builtin dataset '{}' loaded with zero rows or columns", + name + ); + } } #[test] diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 53dfb288..8d449e8e 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -1,17 +1,15 @@ //! DuckDB data source implementation //! -//! Provides a reader for DuckDB databases with direct Polars DataFrame integration. +//! Provides a reader for DuckDB databases with Arrow DataFrame integration. use crate::reader::{connection::ConnectionInfo, Reader}; use crate::{naming, DataFrame, GgsqlError, Result}; -use arrow::ipc::reader::FileReader; +use arrow::array::ArrayRef; use duckdb::vtab::arrow::{arrow_recordbatch_to_query_params, ArrowVTab}; use duckdb::{params, Connection}; -use polars::io::SerWriter; -use polars::prelude::*; use std::cell::RefCell; use std::collections::HashSet; -use std::io::Cursor; +use std::sync::Arc; /// DuckDB SQL dialect with native function support. /// @@ -155,40 +153,11 @@ impl DuckDBReader { use super::validate_table_name; -/// Convert a Polars DataFrame to DuckDB Arrow query parameters via IPC serialization -fn dataframe_to_arrow_params(df: DataFrame) -> Result<[usize; 2]> { - // Serialize DataFrame to IPC format - let mut buffer = Vec::new(); - { - let mut writer = IpcWriter::new(&mut buffer); - writer.finish(&mut df.clone()).map_err(|e| { - GgsqlError::ReaderError(format!("Failed to serialize DataFrame: {}", e)) - })?; - } - - // Read IPC into arrow crate's RecordBatch - let cursor = Cursor::new(buffer); - let reader = FileReader::try_new(cursor, None) - .map_err(|e| GgsqlError::ReaderError(format!("Failed to read IPC: {}", e)))?; - - // Collect all batches and concatenate if needed - let batches: Vec<_> = reader.filter_map(|r| r.ok()).collect(); - - if batches.is_empty() { - return Err(GgsqlError::ReaderError( - "DataFrame produced no Arrow batches".into(), - )); - } - - // For single batch, use directly; for multiple, concatenate - let rb = if batches.len() == 1 { - batches.into_iter().next().unwrap() - } else { - arrow::compute::concat_batches(&batches[0].schema(), &batches) - .map_err(|e| GgsqlError::ReaderError(format!("Failed to concat batches: {}", e)))? - }; - - Ok(arrow_recordbatch_to_query_params(rb)) +/// Convert a DataFrame to DuckDB Arrow query parameters. +/// +/// Since our DataFrame is already an Arrow RecordBatch, this is a simple passthrough. 
+fn dataframe_to_arrow_params(df: &DataFrame) -> Result<[usize; 2]> { + Ok(arrow_recordbatch_to_query_params(df.inner().clone())) }
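The `ColumnBuilder` that follows accumulates each DuckDB column as a `Vec<Option<T>>` and then converts it to a typed Arrow array. That conversion pattern in isolation (hypothetical values; note the `as_deref` dance for string columns, which need `&str` views of the owned `String`s):

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};

fn main() {
    // Numeric columns: From<Vec<Option<T>>> carries nulls into the validity bitmap.
    let ints: ArrayRef = Arc::new(Int64Array::from(vec![Some(1i64), None, Some(3)]));
    assert_eq!(ints.null_count(), 1);

    // String columns: borrow &str views from the owned Strings first.
    let owned: Vec<Option<String>> = vec![Some("a".into()), None, Some("c".into())];
    let strs: ArrayRef = Arc::new(StringArray::from(
        owned.iter().map(|s| s.as_deref()).collect::<Vec<_>>(),
    ));
    assert_eq!(strs.len(), 3);
    assert!(strs.is_null(1));
}
```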
/// Helper struct for building typed columns from rows @@ -306,18 +275,19 @@ impl ColumnBuilder { Ok(()) } - fn build(self, column_name: &str) -> Result<Series> { - use polars::prelude::*; + fn build(self, column_name: &str) -> Result<(String, ArrayRef)> { + use arrow::array::*; use ColumnBuilder::*; - Ok(match self { - TinyInt(values) => Series::new(column_name.into(), values), - SmallInt(values) => Series::new(column_name.into(), values), - Int(values) => Series::new(column_name.into(), values), - BigInt(values) => Series::new(column_name.into(), values), - UTinyInt(values) => Series::new(column_name.into(), values), - USmallInt(values) => Series::new(column_name.into(), values), - UInt(values) => Series::new(column_name.into(), values), + let name = column_name.to_string(); + let array: ArrayRef = match self { + TinyInt(values) => Arc::new(Int8Array::from(values)), + SmallInt(values) => Arc::new(Int16Array::from(values)), + Int(values) => Arc::new(Int32Array::from(values)), + BigInt(values) => Arc::new(Int64Array::from(values)), + UTinyInt(values) => Arc::new(Int16Array::from(values)), + USmallInt(values) => Arc::new(Int32Array::from(values)), + UInt(values) => Arc::new(Int64Array::from(values)), UBigInt(values) => { // Check if all values fit in i64 let all_fit = values .iter() .all(|opt_val| opt_val.map_or(true, |val| val <= i64::MAX as u64)); if all_fit { let i64_values: Vec<Option<i64>> = values .into_iter() .map(|opt_val| opt_val.map(|val| val as i64)) .collect(); - Series::new(column_name.into(), i64_values) + Arc::new(Int64Array::from(i64_values)) } else { eprintln!( "Warning: UBigInt overflow in column '{}', converting to string", column_name ); let string_values: Vec<Option<String>> = values .into_iter() .map(|opt_val| opt_val.map(|val| val.to_string())) .collect(); - Series::new(column_name.into(), string_values) + Arc::new(StringArray::from( + string_values + .iter() + .map(|s| s.as_deref()) + .collect::<Vec<_>>(), + )) } } - Float(values) => Series::new(column_name.into(), values), - Double(values) => Series::new(column_name.into(), values), - Boolean(values) => Series::new(column_name.into(), values), - Text(values) => Series::new(column_name.into(), values), + Float(values) => Arc::new(Float32Array::from(values)), + Double(values) => Arc::new(Float64Array::from(values)), + Boolean(values) => Arc::new(BooleanArray::from(values)), + Text(values) => Arc::new(StringArray::from( + values.iter().map(|s| s.as_deref()).collect::<Vec<_>>(), + )), Date32(values) => { - let series = Series::new(column_name.into(), values); - series - .cast(&DataType::Date) - .map_err(|e| GgsqlError::ReaderError(format!("Date cast failed: {}", e)))? + // Arrow Date32 stores days since epoch directly + Arc::new(Date32Array::from(values)) } Timestamp(values) => { - let series = Series::new(column_name.into(), values); - series - .cast(&DataType::Datetime(TimeUnit::Microseconds, None)) - .map_err(|e| GgsqlError::ReaderError(format!("Timestamp cast failed: {}", e)))? + // DuckDB timestamps are in microseconds + Arc::new(TimestampMicrosecondArray::from(values)) } Time64(values) => { - let series = Series::new(column_name.into(), values); - series - .cast(&DataType::Time) - .map_err(|e| GgsqlError::ReaderError(format!("Time cast failed: {}", e)))? + // DuckDB time values are in nanoseconds + Arc::new(Time64NanosecondArray::from(values)) } - Decimal(values) => Series::new(column_name.into(), values), + Decimal(values) => Arc::new(Float64Array::from(values)), HugeInt(values) => { // Check if all values fit in i64 let all_fit = values.iter().all(|opt_val| { @@ -378,7 +349,7 @@ impl ColumnBuilder { .into_iter() .map(|opt_val| opt_val.map(|val| val as i64)) .collect(); - Series::new(column_name.into(), i64_values) + Arc::new(Int64Array::from(i64_values)) } else { eprintln!( "Warning: HugeInt overflow in column '{}', converting to string", column_name ); @@ -388,7 +359,7 @@ impl ColumnBuilder { .into_iter() .map(|opt_val| opt_val.map(|val| val.to_string())) .collect(); - Series::new(column_name.into(), string_values) + Arc::new(StringArray::from( + string_values + .iter() + .map(|s| s.as_deref()) + .collect::<Vec<_>>(), + )) } } Blob(values) => { eprintln!( "Warning: Converting Blob column '{}' to string (debug format)", column_name ); - Series::new(column_name.into(), values) + Arc::new(StringArray::from( + values.iter().map(|s| s.as_deref()).collect::<Vec<_>>(), + )) } Fallback(values) => { eprintln!( "Warning: Using fallback string conversion for column '{}'", column_name ); - Series::new(column_name.into(), values) + Arc::new(StringArray::from( + values.iter().map(|s| s.as_deref()).collect::<Vec<_>>(), + )) } - }) + }; + Ok((name, array)) } } impl Reader for DuckDBReader { fn execute_sql(&self, sql: &str) -> Result<DataFrame> { - use polars::prelude::*; - // Register builtin datasets if referenced #[cfg(feature = "builtin-data")] super::data::register_builtin_datasets_duckdb(sql, &self.conn)?; @@ -436,10 +415,7 @@ ... .execute(&sql, params![]) .map_err(|e| GgsqlError::ReaderError(format!("Failed to execute DDL: {}", e)))?; - // Return empty DataFrame for DDL statements - return DataFrame::new(Vec::<Column>::new()).map_err(|e| { - GgsqlError::ReaderError(format!("Failed to create empty DataFrame: {}", e)) - }); + return Ok(DataFrame::empty()); } // Prepare and execute statement to get schema @@ -510,19 +486,22 @@ ... return Err(err); } - // Build Series from column builders (may be empty if query returned 0 rows) + // Build named arrays from column builders let column_builders = builders_cell.into_inner(); - let mut columns = Vec::new(); - for (col_idx, builder) in column_builders.into_iter().enumerate() { - let series = builder.build(&column_names[col_idx])?; - columns.push(series.into()); + let mut named_arrays: Vec<(&str, ArrayRef)> = Vec::new(); + // We need to hold the (String, ArrayRef) pairs so we can borrow the names + let built: Vec<(String, ArrayRef)> = column_builders + .into_iter() + .enumerate() + .map(|(col_idx, builder)| builder.build(&column_names[col_idx])) + .collect::<Result<Vec<_>>>()?; + + for (name, array) in &built { + named_arrays.push((name.as_str(), array.clone())); } // Create DataFrame from typed columns - let df = DataFrame::new(columns) - .map_err(|e| GgsqlError::ReaderError(format!("Failed to create DataFrame: {}", e)))?; - - Ok(df) + DataFrame::new(named_arrays) } fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> { @@ -559,7 +538,7 @@ if total_rows <= MAX_ARROW_BATCH_ROWS { // Small DataFrame: register in a single batch - let params = dataframe_to_arrow_params(df)?; + let params = dataframe_to_arrow_params(&df)?; let sql = format!( "{} TEMP TABLE {} AS SELECT * FROM arrow(?, ?)", create_or_replace, @@ -571,7 +550,7 @@ impl Reader
@@ -571,7 +550,7 @@ impl Reader for DuckDBReader {
         } else {
             // Large DataFrame: create table from first chunk, then insert remaining chunks
             let first_chunk = df.slice(0, MAX_ARROW_BATCH_ROWS);
-            let params = dataframe_to_arrow_params(first_chunk)?;
+            let params = dataframe_to_arrow_params(&first_chunk)?;
             let create_sql = format!(
                 "{} TEMP TABLE {} AS SELECT * FROM arrow(?, ?)",
                 create_or_replace,
@@ -584,8 +563,8 @@ impl Reader for DuckDBReader {
             let mut offset = MAX_ARROW_BATCH_ROWS;
             while offset < total_rows {
                 let chunk_size = std::cmp::min(MAX_ARROW_BATCH_ROWS, total_rows - offset);
-                let chunk = df.slice(offset as i64, chunk_size);
-                let params = dataframe_to_arrow_params(chunk)?;
+                let chunk = df.slice(offset, chunk_size);
+                let params = dataframe_to_arrow_params(&chunk)?;
                 let insert_sql = format!(
                     "INSERT INTO {} SELECT * FROM arrow(?, ?)",
                     naming::quote_ident(name)
@@ -638,6 +617,8 @@ impl Reader for DuckDBReader {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::array_util::{as_i32, as_i64, as_str};
+    use crate::df;
 
     #[test]
     fn test_create_in_memory() {
@@ -651,7 +632,10 @@ mod tests {
 
         let df = reader.execute_sql("SELECT 1 as x, 2 as y").unwrap();
         assert_eq!(df.shape(), (1, 2));
-        assert_eq!(df.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
@@ -674,7 +658,10 @@ mod tests {
 
         let df = reader.execute_sql("SELECT * FROM test").unwrap();
         assert_eq!(df.shape(), (2, 2));
-        assert_eq!(df.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
@@ -710,18 +697,21 @@ mod tests {
             .unwrap();
 
         assert_eq!(df.shape(), (2, 2));
-        assert_eq!(df.get_column_names(), vec!["region", "total"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["region".to_string(), "total".to_string()]
+        );
     }
 
     #[test]
     fn test_register_and_query() {
         let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
 
-        // Create a DataFrame
-        let df = DataFrame::new(vec![
-            Column::new("x".into(), vec![1i32, 2, 3]),
-            Column::new("y".into(), vec![10i32, 20, 30]),
-        ])
+        // Create a DataFrame using the df! macro
+        let df = df! {
+            "x" => vec![1i32, 2, 3],
+            "y" => vec![10i32, 20, 30],
+        }
         .unwrap();
 
         // Register the DataFrame
@@ -732,15 +722,18 @@ mod tests {
             .execute_sql("SELECT * FROM my_table ORDER BY x")
             .unwrap();
         assert_eq!(result.shape(), (3, 2));
-        assert_eq!(result.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            result.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
     fn test_register_duplicate_name_errors() {
         let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
 
-        let df1 = DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap();
-        let df2 = DataFrame::new(vec![Column::new("b".into(), vec![2i32])]).unwrap();
+        let df1 = df! { "a" => vec![1i32] }.unwrap();
+        let df2 = df! { "b" => vec![2i32] }.unwrap();
 
         // First registration should succeed
         reader.register("dup_table", df1, false).unwrap();
@@ -755,7 +748,7 @@ mod tests {
     #[test]
     fn test_register_invalid_table_names() {
         let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
-        let df = DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap();
+        let df = df! { "a" => vec![1i32] }.unwrap();
{ "a" => vec![1i32] }.unwrap(); // Empty name let result = reader.register("", df.clone(), false); @@ -781,10 +774,10 @@ mod tests { let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); // Create an empty DataFrame with schema - let df = DataFrame::new(vec![ - Column::new("x".into(), Vec::::new()), - Column::new("y".into(), Vec::::new()), - ]) + let df = df! { + "x" => Vec::::new(), + "y" => Vec::<&str>::new(), + } .unwrap(); reader.register("empty_table", df, false).unwrap(); @@ -792,13 +785,16 @@ mod tests { // Query should return empty result with correct schema let result = reader.execute_sql("SELECT * FROM empty_table").unwrap(); assert_eq!(result.shape(), (0, 2)); - assert_eq!(result.get_column_names(), vec!["x", "y"]); + assert_eq!( + result.get_column_names(), + vec!["x".to_string(), "y".to_string()] + ); } #[test] fn test_unregister() { let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); - let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap(); + let df = df! { "x" => vec![1i32, 2, 3] }.unwrap(); reader.register("test_data", df, false).unwrap(); @@ -834,7 +830,7 @@ mod tests { #[test] fn test_reregister_after_unregister() { let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); - let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap(); + let df = df! { "x" => vec![1i32, 2, 3] }.unwrap(); reader.register("data", df.clone(), false).unwrap(); reader.unregister("data").unwrap(); @@ -856,11 +852,11 @@ mod tests { let values: Vec = (0..n).map(|i| i as f64 * 1.5).collect(); let names: Vec = (0..n).map(|i| format!("item_{}", i)).collect(); - let df = DataFrame::new(vec![ - Column::new("id".into(), ids), - Column::new("value".into(), values), - Column::new("name".into(), names), - ]) + let df = df! 
+            "id" => ids,
+            "value" => values,
+            "name" => names,
+        }
         .unwrap();
 
         reader.register("large_table", df, false).unwrap();
@@ -869,25 +865,16 @@ mod tests {
         let result = reader
             .execute_sql("SELECT COUNT(*) as cnt FROM large_table")
             .unwrap();
-        let count = result.column("cnt").unwrap().i64().unwrap().get(0).unwrap();
+        let count = as_i64(result.column("cnt").unwrap()).unwrap().value(0);
         assert_eq!(count, n as i64);
 
         // Verify first and last rows survived chunking intact
         let result = reader
             .execute_sql("SELECT id, name FROM large_table ORDER BY id LIMIT 1")
             .unwrap();
+        assert_eq!(as_i32(result.column("id").unwrap()).unwrap().value(0), 0);
         assert_eq!(
-            result.column("id").unwrap().i32().unwrap().get(0).unwrap(),
-            0
-        );
-        assert_eq!(
-            result
-                .column("name")
-                .unwrap()
-                .str()
-                .unwrap()
-                .get(0)
-                .unwrap(),
+            as_str(result.column("name").unwrap()).unwrap().value(0),
             "item_0"
         );
 
@@ -895,17 +882,11 @@ mod tests {
             .execute_sql("SELECT id, name FROM large_table ORDER BY id DESC LIMIT 1")
             .unwrap();
         assert_eq!(
-            result.column("id").unwrap().i32().unwrap().get(0).unwrap(),
+            as_i32(result.column("id").unwrap()).unwrap().value(0),
             (n - 1)
         );
         assert_eq!(
-            result
-                .column("name")
-                .unwrap()
-                .str()
-                .unwrap()
-                .get(0)
-                .unwrap(),
+            as_str(result.column("name").unwrap()).unwrap().value(0),
             format!("item_{}", n - 1)
         );
     }
diff --git a/src/reader/mod.rs b/src/reader/mod.rs
index 0a4b1b9f..d68cca8e 100644
--- a/src/reader/mod.rs
+++ b/src/reader/mod.rs
@@ -493,6 +493,7 @@ pub fn execute_with_reader(reader: &dyn Reader, query: &str) -> Result<String> {
 #[cfg(all(feature = "duckdb", feature = "vegalite"))]
 mod tests {
     use super::*;
+    use crate::df;
     use crate::writer::{VegaLiteWriter, Writer};
 
     #[test]
@@ -832,13 +833,11 @@ mod tests {
 
     #[test]
     fn test_register_and_query() {
-        use polars::prelude::*;
-
         let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
 
         let df = df! {
-            "x" => [1i32, 2, 3],
-            "y" => [10i32, 20, 30],
+            "x" => vec![1i32, 2, 3],
+            "y" => vec![10i32, 20, 30],
         }
         .unwrap();
 
@@ -858,20 +857,18 @@ mod tests {
 
     #[test]
     fn test_register_and_join() {
-        use polars::prelude::*;
-
         let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();
 
         let sales = df! {
-            "id" => [1i32, 2, 3],
-            "amount" => [100i32, 200, 300],
-            "product_id" => [1i32, 1, 2],
+            "id" => vec![1i32, 2, 3],
+            "amount" => vec![100i32, 200, 300],
+            "product_id" => vec![1i32, 1, 2],
         }
         .unwrap();
 
         let products = df! {
-            "id" => [1i32, 2],
-            "name" => ["Widget", "Gadget"],
+            "id" => vec![1i32, 2],
+            "name" => vec!["Widget", "Gadget"],
         }
         .unwrap();
 
diff --git a/src/reader/odbc.rs b/src/reader/odbc.rs
index 6f474da7..6fbf7d39 100644
--- a/src/reader/odbc.rs
+++ b/src/reader/odbc.rs
@@ -5,15 +5,16 @@
 use crate::reader::Reader;
 use crate::{naming, DataFrame, GgsqlError, Result};
+use arrow::array::*;
+use arrow::datatypes::DataType;
 use odbc_api::sys::{Date as OdbcDate, Time as OdbcTime, Timestamp as OdbcTimestamp};
 use odbc_api::{
     buffers::{AnyBuffer, AnySlice, BufferDesc, ColumnarBuffer},
     ConnectionOptions, Cursor, DataType as OdbcDataType, Environment,
 };
-use polars::prelude::*;
 use std::cell::RefCell;
 use std::collections::HashSet;
-use std::sync::OnceLock;
+use std::sync::{Arc, OnceLock};
 
 /// Global ODBC environment (must be a singleton per process).
 fn odbc_env() -> &'static Environment {
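For orientation (the body of `odbc_env` is unchanged, so it does not appear in this diff): `odbc_api::Environment` must be constructed exactly once per process, which is why the function hands out a `'static` reference. A plausible shape, using the `OnceLock` imported above — the `ENV` static name is assumed, not taken from the source:

```rust
use odbc_api::Environment;
use std::sync::OnceLock;

// Assumed sketch of the process-wide ODBC environment singleton.
static ENV: OnceLock<Environment> = OnceLock::new();

fn odbc_env() -> &'static Environment {
    ENV.get_or_init(|| Environment::new().expect("failed to create ODBC environment"))
}
```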
@@ -115,8 +116,7 @@ impl Reader for OdbcReader {
 
         let Some(cursor) = cursor else {
             // DDL or non-query statement — return empty DataFrame
-            return DataFrame::new(Vec::<Column>::new())
-                .map_err(|e| GgsqlError::ReaderError(format!("Empty DataFrame error: {}", e)));
+            return Ok(DataFrame::empty());
         };
 
         cursor_to_dataframe(cursor)
@@ -134,12 +134,13 @@ impl Reader for OdbcReader {
         // Build CREATE TEMP TABLE with typed columns
         let schema = df.schema();
         let col_defs: Vec<String> = schema
+            .fields()
             .iter()
-            .map(|(col_name, dtype)| {
+            .map(|field| {
                 format!(
                     "{} {}",
-                    naming::quote_ident(col_name),
-                    polars_dtype_to_sql(dtype)
+                    naming::quote_ident(field.name()),
+                    arrow_dtype_to_sql(field.data_type())
                 )
             })
             .collect();
@@ -166,17 +167,16 @@ impl Reader for OdbcReader {
         );
 
         // Convert all columns to string representation for text insertion
-        let string_columns: Vec<Vec<Option<String>>> = df
-            .get_columns()
+        let columns = df.get_columns();
+        let string_columns: Vec<Vec<Option<String>>> = columns
             .iter()
             .map(|col| {
                 (0..num_rows)
                     .map(|row| {
-                        let val = col.get(row).ok()?;
-                        if val == AnyValue::Null {
+                        if col.is_null(row) {
                             None
                         } else {
-                            Some(format!("{}", val))
+                            Some(crate::array_util::value_to_string(col, row))
                         }
                     })
                     .collect()
             })
@@ -278,16 +278,16 @@ impl Reader for OdbcReader {
     }
 }
 
-/// Map a Polars data type to a SQL column type string.
-fn polars_dtype_to_sql(dtype: &DataType) -> &'static str {
+/// Map an Arrow data type to a SQL column type string.
+fn arrow_dtype_to_sql(dtype: &DataType) -> &'static str {
     match dtype {
         DataType::Boolean => "BOOLEAN",
         DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => "BIGINT",
         DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => "BIGINT",
         DataType::Float32 | DataType::Float64 => "DOUBLE PRECISION",
-        DataType::Date => "DATE",
-        DataType::Datetime(_, _) => "TIMESTAMP",
-        DataType::Time => "TIME",
+        DataType::Date32 => "DATE",
+        DataType::Timestamp(_, _) => "TIMESTAMP",
+        DataType::Time64(_) => "TIME",
         _ => "TEXT",
     }
 }
@@ -439,29 +439,24 @@ impl ColumnBuilder {
         Ok(())
     }
 
-    fn into_series(self, name: &str) -> Series {
-        match self {
-            Self::Int8(v) => Series::new(name.into(), v),
-            Self::Int16(v) => Series::new(name.into(), v),
-            Self::Int32(v) => Series::new(name.into(), v),
-            Self::Int64(v) => Series::new(name.into(), v),
-            Self::Float32(v) => Series::new(name.into(), v),
-            Self::Float64(v) => Series::new(name.into(), v),
-            Self::Boolean(v) => Series::new(name.into(), v),
-            Self::Date(v) => {
-                let ca = Int32Chunked::new(name.into(), &v);
-                ca.into_date().into_series()
+    fn into_named_array(self, name: &str) -> (String, ArrayRef) {
+        let array: ArrayRef = match self {
+            Self::Int8(v) => Arc::new(Int8Array::from(v)),
+            Self::Int16(v) => Arc::new(Int16Array::from(v)),
+            Self::Int32(v) => Arc::new(Int32Array::from(v)),
+            Self::Int64(v) => Arc::new(Int64Array::from(v)),
+            Self::Float32(v) => Arc::new(Float32Array::from(v)),
+            Self::Float64(v) => Arc::new(Float64Array::from(v)),
+            Self::Boolean(v) => Arc::new(BooleanArray::from(v)),
+            Self::Date(v) => Arc::new(Date32Array::from(v)),
+            Self::Time(v) => Arc::new(Time64NanosecondArray::from(v)),
+            Self::Timestamp(v) => Arc::new(TimestampMicrosecondArray::from(v)),
+            Self::Text(v) => {
+                let refs: Vec<Option<&str>> = v.iter().map(|s| s.as_deref()).collect();
+                Arc::new(StringArray::from(refs))
             }
-            Self::Time(v) => {
-                let ca = Int64Chunked::new(name.into(), &v);
-                ca.into_time().into_series()
-            }
-            Self::Timestamp(v) => {
-                let ca = Int64Chunked::new(name.into(), &v);
-                ca.into_datetime(TimeUnit::Microseconds, None).into_series()
-            }
-            Self::Text(v) => Series::new(name.into(), v),
-        }
+        };
+        (name.to_string(), array)
     }
 }
@@ -492,7 +487,7 @@ fn odbc_timestamp_to_micros(ts: &OdbcTimestamp) -> Option<i64> {
         .map(|dt| dt.and_utc().timestamp_micros())
 }
 
-/// Convert an ODBC cursor to a Polars DataFrame using typed buffers.
+/// Convert an ODBC cursor to a DataFrame using typed buffers.
 fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
     let col_count = cursor
         .num_result_cols()
@@ -500,8 +495,7 @@ fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
         as usize;
 
     if col_count == 0 {
-        return DataFrame::new(Vec::<Column>::new())
-            .map_err(|e| GgsqlError::ReaderError(e.to_string()));
+        return Ok(DataFrame::empty());
     }
 
     // Collect column names and types, build buffer descriptors
@@ -551,14 +545,19 @@ fn cursor_to_dataframe(mut cursor: impl Cursor) -> Result<DataFrame> {
         }
     }
 
-    // Convert builders to Polars Series
-    let series: Vec<Column> = col_names
+    // Convert builders to named arrays
+    let built: Vec<(String, ArrayRef)> = col_names
         .iter()
         .zip(builders)
-        .map(|(name, builder)| Column::from(builder.into_series(name)))
+        .map(|(name, builder)| builder.into_named_array(name))
+        .collect();
+
+    let named_arrays: Vec<(&str, ArrayRef)> = built
+        .iter()
+        .map(|(name, arr)| (name.as_str(), arr.clone()))
         .collect();
 
-    DataFrame::new(series).map_err(|e| GgsqlError::ReaderError(e.to_string()))
+    DataFrame::new(named_arrays)
 }
 
 // ============================================================================
@@ -844,11 +843,11 @@ account = "otheraccount"
     }
 
     #[test]
-    fn test_polars_dtype_to_sql() {
-        assert_eq!(polars_dtype_to_sql(&DataType::Int64), "BIGINT");
-        assert_eq!(polars_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION");
-        assert_eq!(polars_dtype_to_sql(&DataType::Boolean), "BOOLEAN");
-        assert_eq!(polars_dtype_to_sql(&DataType::Date), "DATE");
-        assert_eq!(polars_dtype_to_sql(&DataType::String), "TEXT");
+    fn test_arrow_dtype_to_sql() {
+        assert_eq!(arrow_dtype_to_sql(&DataType::Int64), "BIGINT");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Float64), "DOUBLE PRECISION");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Boolean), "BOOLEAN");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Date32), "DATE");
+        assert_eq!(arrow_dtype_to_sql(&DataType::Utf8), "TEXT");
     }
 }
diff --git a/src/reader/sqlite.rs b/src/reader/sqlite.rs
index c6fd92d5..33c077cd 100644
--- a/src/reader/sqlite.rs
+++ b/src/reader/sqlite.rs
@@ -1,15 +1,17 @@
 //! SQLite data source implementation
 //!
-//! Provides a reader for SQLite databases with Polars DataFrame integration.
+//! Provides a reader for SQLite databases with Arrow DataFrame integration.
 //! Works on both native targets and wasm32-unknown-unknown (via sqlite-wasm-rs).
 
 use crate::reader::Reader;
 use crate::{naming, DataFrame, GgsqlError, Result};
+use arrow::array::*;
+use arrow::datatypes::{DataType, TimeUnit};
 use chrono::Datelike;
-use polars::prelude::*;
 use rusqlite::Connection;
 use std::cell::RefCell;
 use std::collections::HashSet;
+use std::sync::Arc;
 
 /// SQLite SQL dialect.
 ///
@@ -95,7 +97,7 @@ impl super::SqlDialect for SqliteDialect {
 /// SQLite database reader
 ///
 /// Executes SQL queries against SQLite databases (in-memory or file-based)
-/// and returns results as Polars DataFrames.
+/// and returns results as DataFrames.
 pub struct SqliteReader {
     conn: Connection,
     registered_tables: RefCell<HashSet<String>>,
@@ -190,8 +192,8 @@ fn validate_table_name(name: &str) -> Result<()> {
     Ok(())
 }
 
-/// Map a Polars DataType to a SQLite column type string
-fn polars_type_to_sqlite(dtype: &DataType) -> &'static str {
+/// Map an Arrow DataType to a SQLite column type string
+fn arrow_type_to_sqlite(dtype: &DataType) -> &'static str {
     match dtype {
         DataType::Float32 | DataType::Float64 => "REAL",
         DataType::Int8
@@ -203,47 +205,116 @@ fn arrow_type_to_sqlite(dtype: &DataType) -> &'static str {
         | DataType::UInt32
         | DataType::UInt64 => "INTEGER",
         DataType::Boolean => "INTEGER",
-        DataType::Date => "TEXT",
-        DataType::Datetime(_, _) => "TEXT",
-        DataType::Time => "TEXT",
+        DataType::Date32 => "TEXT",
+        DataType::Timestamp(_, _) => "TEXT",
+        DataType::Time64(_) => "TEXT",
         _ => "TEXT",
     }
 }
 
-/// Convert a Polars AnyValue to a rusqlite Value for parameter binding
-fn anyvalue_to_sqlite(value: AnyValue, _dtype: &DataType) -> rusqlite::types::Value {
+/// Convert an Arrow array value at a given row index to a rusqlite Value for parameter binding.
+fn array_value_to_sqlite(array: &ArrayRef, row_idx: usize) -> rusqlite::types::Value {
+    use crate::array_util;
     use rusqlite::types::Value;
 
-    match value {
-        AnyValue::Null => Value::Null,
-        AnyValue::Boolean(b) => Value::Integer(b as i64),
-        AnyValue::Int8(v) => Value::Integer(v as i64),
-        AnyValue::Int16(v) => Value::Integer(v as i64),
-        AnyValue::Int32(v) => Value::Integer(v as i64),
-        AnyValue::Int64(v) => Value::Integer(v),
-        AnyValue::UInt8(v) => Value::Integer(v as i64),
-        AnyValue::UInt16(v) => Value::Integer(v as i64),
-        AnyValue::UInt32(v) => Value::Integer(v as i64),
-        AnyValue::UInt64(v) => Value::Integer(v as i64),
-        AnyValue::Float32(v) => Value::Real(v as f64),
-        AnyValue::Float64(v) => Value::Real(v),
-        AnyValue::String(s) => Value::Text(s.to_string()),
-        AnyValue::StringOwned(s) => Value::Text(s.to_string()),
-        AnyValue::Date(days) => chrono::NaiveDate::from_num_days_from_ce_opt(days + 719_163)
-            .and_then(|d| to_sql_value(&d))
-            .unwrap_or(Value::Null),
-        AnyValue::Datetime(us, _, _) => chrono::DateTime::from_timestamp_micros(us)
-            .map(|d| d.naive_utc())
-            .and_then(|d| to_sql_value(&d))
-            .unwrap_or(Value::Null),
-        AnyValue::Time(ns) => {
+    if array.is_null(row_idx) {
+        return Value::Null;
+    }
+
+    match array.data_type() {
+        DataType::Boolean => {
+            let arr = array.as_any().downcast_ref::<BooleanArray>().unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::Int8 => {
+            let arr = array_util::as_i8(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::Int16 => {
+            let arr = array_util::as_i16(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::Int32 => {
+            let arr = array_util::as_i32(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::Int64 => {
+            let arr = array_util::as_i64(array).unwrap();
+            Value::Integer(arr.value(row_idx))
+        }
+        DataType::UInt8 => {
+            let arr = array_util::as_u8(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::UInt16 => {
+            let arr = array_util::as_u16(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::UInt32 => {
+            let arr = array_util::as_u32(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::UInt64 => {
+            let arr = array_util::as_u64(array).unwrap();
+            Value::Integer(arr.value(row_idx) as i64)
+        }
+        DataType::Float32 => {
+            let arr = array_util::as_f32(array).unwrap();
+            Value::Real(arr.value(row_idx) as f64)
+        }
+        DataType::Float64 => {
+            let arr = array_util::as_f64(array).unwrap();
+            Value::Real(arr.value(row_idx))
+        }
+        DataType::Utf8 => {
+            let arr = array_util::as_str(array).unwrap();
+            Value::Text(arr.value(row_idx).to_string())
+        }
+        DataType::Date32 => {
+            let arr = array.as_any().downcast_ref::<Date32Array>().unwrap();
+            let days = arr.value(row_idx);
+            chrono::NaiveDate::from_num_days_from_ce_opt(days + 719_163)
+                .and_then(|d| to_sql_value(&d))
+                .unwrap_or(Value::Null)
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            let arr = array
+                .as_any()
+                .downcast_ref::<TimestampMicrosecondArray>()
+                .unwrap();
+            let us = arr.value(row_idx);
+            chrono::DateTime::from_timestamp_micros(us)
+                .map(|d| d.naive_utc())
+                .and_then(|d| to_sql_value(&d))
+                .unwrap_or(Value::Null)
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            let arr = array
+                .as_any()
+                .downcast_ref::<TimestampMillisecondArray>()
+                .unwrap();
+            let ms = arr.value(row_idx);
+            chrono::DateTime::from_timestamp_millis(ms)
+                .map(|d| d.naive_utc())
+                .and_then(|d| to_sql_value(&d))
+                .unwrap_or(Value::Null)
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            let arr = array
+                .as_any()
+                .downcast_ref::<Time64NanosecondArray>()
+                .unwrap();
+            let ns = arr.value(row_idx);
             let secs = (ns / 1_000_000_000) as u32;
             let nanos = (ns % 1_000_000_000) as u32;
             chrono::NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos)
                 .and_then(|t| to_sql_value(&t))
                 .unwrap_or(Value::Null)
         }
-        _ => Value::Text(format!("{}", value)),
+        _ => {
+            // Fallback: use array_util::value_to_string
+            Value::Text(crate::array_util::value_to_string(array, row_idx))
+        }
     }
 }
@@ -260,7 +331,7 @@ fn to_sql_value(v: &dyn rusqlite::types::ToSql) -> Option<rusqlite::types::Value>
 impl Reader for SqliteReader {
     fn execute_sql(&self, sql: &str) -> Result<DataFrame> {
         // Handle ggsql:name namespaced identifiers (builtin datasets)
-        #[cfg(feature = "parquet")]
+        #[cfg(all(feature = "builtin-data", feature = "parquet"))]
         {
             let dataset_names = super::data::extract_builtin_dataset_names(sql)?;
             for name in &dataset_names {
@@ -288,9 +359,7 @@ impl Reader for SqliteReader {
             self.conn
                 .execute_batch(&sql)
                 .map_err(|e| GgsqlError::ReaderError(format!("Failed to execute DDL: {}", e)))?;
-            return DataFrame::new(Vec::<Column>::new()).map_err(|e| {
-                GgsqlError::ReaderError(format!("Failed to create empty DataFrame: {}", e))
-            });
+            return Ok(DataFrame::empty());
         }
 
         let mut stmt = self
@@ -330,16 +399,24 @@ impl Reader for SqliteReader {
             }
         }
 
-        // Build Series from collected values
-        let mut columns = Vec::with_capacity(column_count);
-        for (col_idx, values) in col_values.into_iter().enumerate() {
-            let name = &column_names[col_idx];
-            let series = sqlite_values_to_series(name, values)?;
-            columns.push(series.into());
+        // Build named arrays from collected values
+        let mut named_arrays: Vec<(&str, ArrayRef)> = Vec::with_capacity(column_count);
+        // Hold the built arrays so we can borrow names
+        let built: Vec<(String, ArrayRef)> = col_values
+            .into_iter()
+            .enumerate()
+            .map(|(col_idx, values)| {
+                let name = column_names[col_idx].clone();
+                let array = sqlite_values_to_array(&name, values)?;
+                Ok((name, array))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        for (name, array) in &built {
+            named_arrays.push((name.as_str(), array.clone()));
         }
 
-        DataFrame::new(columns)
-            .map_err(|e| GgsqlError::ReaderError(format!("Failed to create DataFrame: {}", e)))
+        DataFrame::new(named_arrays)
     }
 
     fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> {
@@ -361,13 +438,14 @@ impl Reader for SqliteReader {
         }
 
         // Build CREATE TABLE statement
-        let col_defs: Vec<String> = df
-            .get_columns()
+        let col_names = df.get_column_names();
+        let schema = df.schema();
+        let col_defs: Vec<String> = schema
+            .fields()
             .iter()
-            .map(|col| {
-                let col_name = col.name().to_string();
-                let col_type = polars_type_to_sqlite(col.dtype());
-                format!("{} {}", naming::quote_ident(&col_name), col_type)
+            .map(|field| {
+                let col_type = arrow_type_to_sqlite(field.data_type());
+                format!("{} {}", naming::quote_ident(field.name()), col_type)
             })
             .collect();
 
@@ -380,7 +458,7 @@ impl Reader for SqliteReader {
             GgsqlError::ReaderError(format!("Failed to create table '{}': {}", name, e))
         })?;
 
-        // Insert data using params_from_iter, wrapped in a transaction
+        // Insert data row by row, wrapped in a transaction
         if df.height() > 0 {
             let placeholders: Vec<&str> = vec!["?"; df.width()];
             let insert_sql = format!(
@@ -389,8 +467,8 @@ impl Reader for SqliteReader {
                 placeholders.join(", ")
             );
 
-            let dtypes: Vec<DataType> =
-                df.get_columns().iter().map(|c| c.dtype().clone()).collect();
+            let columns = df.get_columns();
+            let _ = &col_names; // keep col_names alive
 
             self.conn.execute_batch("BEGIN").map_err(|e| {
                 GgsqlError::ReaderError(format!("Failed to begin transaction: {}", e))
@@ -402,20 +480,10 @@ impl Reader for SqliteReader {
             })?;
 
             for row_idx in 0..df.height() {
-                let values: Vec<rusqlite::types::Value> = df
-                    .get_columns()
+                let values: Vec<rusqlite::types::Value> = columns
                     .iter()
-                    .enumerate()
-                    .map(|(col_idx, col)| {
-                        let value = col.get(row_idx).map_err(|e| {
-                            GgsqlError::ReaderError(format!(
-                                "Failed to get value at row {}, col {}: {}",
-                                row_idx, col_idx, e
-                            ))
-                        })?;
-                        Ok(anyvalue_to_sqlite(value, &dtypes[col_idx]))
-                    })
-                    .collect::<Result<Vec<_>>>()?;
+                    .map(|col| array_value_to_sqlite(col, row_idx))
+                    .collect();
 
                 stmt.execute(rusqlite::params_from_iter(values))
                     .map_err(|e| {
@@ -472,8 +540,8 @@ impl Reader for SqliteReader {
 }
 
 /// Try to parse all non-null TEXT values as ISO-8601 dates (YYYY-MM-DD).
-/// Returns a Date series if all non-null values parse, None otherwise.
-fn try_parse_as_date(name: &str, values: &[rusqlite::types::Value]) -> Option<Series> {
+/// Returns a Date32 array if all non-null values parse, None otherwise.
+fn try_parse_as_date(values: &[rusqlite::types::Value]) -> Option<ArrayRef> {
     use rusqlite::types::{FromSql, Value, ValueRef};
 
     // Days between 0001-01-01 (CE day 1) and 1970-01-01 (Unix epoch)
@@ -493,14 +561,13 @@ fn try_parse_as_date(values: &[rusqlite::types::Value]) -> Option<ArrayRef> {
         }
     }
 
-    let series = Series::new(name.into(), parsed);
-    series.cast(&DataType::Date).ok()
+    Some(Arc::new(Date32Array::from(parsed)) as ArrayRef)
 }
 
 /// Try to parse all non-null TEXT values as ISO-8601 datetimes.
-/// Returns a Datetime series if all non-null values parse, None otherwise.
-fn try_parse_as_datetime(name: &str, values: &[rusqlite::types::Value]) -> Option<Series> {
+/// Returns a TimestampMillisecond array if all non-null values parse, None otherwise.
+fn try_parse_as_datetime(values: &[rusqlite::types::Value]) -> Option<ArrayRef> {
     use rusqlite::types::{FromSql, Value, ValueRef};
 
     let mut parsed: Vec<Option<i64>> = Vec::with_capacity(values.len());
@@ -521,25 +588,24 @@ fn try_parse_as_datetime(values: &[rusqlite::types::Value]) -> Option<ArrayRef> {
         }
     }
 
-    let series = Series::new(name.into(), parsed);
-    series
-        .cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
-        .ok()
+    Some(Arc::new(TimestampMillisecondArray::from(parsed)) as ArrayRef)
 }
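These two helpers are what give SQLite TEXT columns a temporal type on the way back out. A sketch of the intended effect, using the reader API exercised by the tests further down (table and column names are illustrative):

```rust
// ISO-8601 TEXT is detected and returned as Date32; other text stays Utf8.
let reader = SqliteReader::new().unwrap();
reader.execute_sql("CREATE TABLE t (d TEXT, s TEXT)").unwrap();
reader
    .execute_sql("INSERT INTO t VALUES ('2024-01-01', 'hello')")
    .unwrap();

let df = reader.execute_sql("SELECT * FROM t").unwrap();
assert_eq!(df.column_dtype("d").unwrap(), DataType::Date32);
assert_eq!(df.column_dtype("s").unwrap(), DataType::Utf8);
```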
-/// Infer the best Polars type from a column of SQLite values and build a Series.
+/// Infer the best Arrow type from a column of SQLite values and build an ArrayRef.
 ///
 /// SQLite uses dynamic typing, so we infer the column type from all values:
-/// - All Integer → Int64
-/// - All Integer/Real → Float64
-/// - All Text → String
-/// - Mixed → String fallback
-fn sqlite_values_to_series(name: &str, values: Vec<rusqlite::types::Value>) -> Result<Series> {
+/// - All Integer -> Int64
+/// - All Integer/Real -> Float64
+/// - All Text -> String (with temporal detection)
+/// - Mixed -> String fallback
+fn sqlite_values_to_array(name: &str, values: Vec<rusqlite::types::Value>) -> Result<ArrayRef> {
     use rusqlite::types::Value;
 
+    let _ = name; // name is unused now but kept for consistency
+
     if values.is_empty() {
         // Default to String for empty columns
-        return Ok(Series::new(name.into(), Vec::<Option<String>>::new()));
+        return Ok(Arc::new(StringArray::from(Vec::<Option<String>>::new())) as ArrayRef);
     }
 
     // Determine the dominant type
@@ -560,11 +626,11 @@ fn sqlite_values_to_array(name: &str, values: Vec<rusqlite::types::Value>) -> Result<ArrayRef> {
 
     // If we have text, try temporal detection before falling back to String
     if has_text && !has_blob {
-        if let Some(series) = try_parse_as_date(name, &values) {
-            return Ok(series);
+        if let Some(array) = try_parse_as_date(&values) {
+            return Ok(array);
         }
-        if let Some(series) = try_parse_as_datetime(name, &values) {
-            return Ok(series);
+        if let Some(array) = try_parse_as_datetime(&values) {
+            return Ok(array);
         }
     }
 
@@ -579,7 +645,8 @@ fn sqlite_values_to_array(name: &str, values: Vec<rusqlite::types::Value>) -> Result<ArrayRef> {
                 Value::Blob(b) => Some(format!("{:?}", b)),
             })
             .collect();
-        return Ok(Series::new(name.into(), vals));
+        let refs: Vec<Option<&str>> = vals.iter().map(|s| s.as_deref()).collect();
+        return Ok(Arc::new(StringArray::from(refs)) as ArrayRef);
     }
 
     // If we have any reals, use f64
@@ -593,7 +660,7 @@ fn sqlite_values_to_array(name: &str, values: Vec<rusqlite::types::Value>) -> Result<ArrayRef> {
                 _ => None,
             })
             .collect();
-        return Ok(Series::new(name.into(), vals));
+        return Ok(Arc::new(Float64Array::from(vals)) as ArrayRef);
     }
 
     // Pure integers
@@ -606,17 +673,19 @@ fn sqlite_values_to_array(name: &str, values: Vec<rusqlite::types::Value>) -> Result<ArrayRef> {
                 _ => None,
             })
             .collect();
-        return Ok(Series::new(name.into(), vals));
+        return Ok(Arc::new(Int64Array::from(vals)) as ArrayRef);
     }
 
     // All nulls — default to String
-    let vals: Vec<Option<String>> = values.into_iter().map(|_| None).collect();
-    Ok(Series::new(name.into(), vals))
+    let vals: Vec<Option<String>> = values.iter().map(|_| None).collect();
+    Ok(Arc::new(StringArray::from(vals)) as ArrayRef)
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::array_util::as_i64;
+    use crate::df;
 
     #[test]
     fn test_create_in_memory() {
@@ -630,7 +699,10 @@ mod tests {
 
         let df = reader.execute_sql("SELECT 1 as x, 2 as y").unwrap();
         assert_eq!(df.shape(), (1, 2));
-        assert_eq!(df.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
@@ -641,8 +713,8 @@ mod tests {
             .unwrap();
 
         assert_eq!(df.shape(), (1, 2));
-        assert_eq!(df.column("x").unwrap().dtype(), &DataType::Int64);
-        assert_eq!(df.column("y").unwrap().dtype(), &DataType::Int64);
+        assert_eq!(df.column_dtype("x").unwrap(), DataType::Int64);
+        assert_eq!(df.column_dtype("y").unwrap(), DataType::Int64);
     }
 
     #[test]
@@ -687,7 +759,10 @@ mod tests {
 
         let df = reader.execute_sql("SELECT * FROM test").unwrap();
         assert_eq!(df.shape(), (2, 2));
-        assert_eq!(df.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
@@ -701,10 +776,10 @@ mod tests {
     fn test_register_and_query() {
         let reader = SqliteReader::new().unwrap();
 
-        let df = DataFrame::new(vec![
-            Column::new("x".into(), vec![1i32, 2, 3]),
-            Column::new("y".into(), vec![10i32, 20, 30]),
-        ])
+        let df = df! {
+            "x" => vec![1i32, 2, 3],
+            "y" => vec![10i32, 20, 30],
+        }
         .unwrap();
 
         reader.register("my_table", df, false).unwrap();
@@ -713,15 +788,18 @@ mod tests {
             .execute_sql("SELECT * FROM my_table ORDER BY x")
             .unwrap();
         assert_eq!(result.shape(), (3, 2));
-        assert_eq!(result.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            result.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
     fn test_register_duplicate_name_errors() {
         let reader = SqliteReader::new().unwrap();
 
-        let df1 = DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap();
-        let df2 = DataFrame::new(vec![Column::new("b".into(), vec![2i32])]).unwrap();
+        let df1 = df! { "a" => vec![1i32] }.unwrap();
+        let df2 = df! { "b" => vec![2i32] }.unwrap();
 
         reader.register("dup_table", df1, false).unwrap();
 
@@ -734,7 +812,7 @@ mod tests {
     #[test]
     fn test_register_invalid_table_names() {
         let reader = SqliteReader::new().unwrap();
-        let df = DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap();
+        let df = df! { "a" => vec![1i32] }.unwrap();
 
         let result = reader.register("", df.clone(), false);
         assert!(result.is_err());
@@ -757,23 +835,28 @@ mod tests {
     fn test_register_empty_dataframe() {
         let reader = SqliteReader::new().unwrap();
 
-        let df = DataFrame::new(vec![
-            Column::new("x".into(), Vec::<i32>::new()),
-            Column::new("y".into(), Vec::<String>::new()),
-        ])
-        .unwrap();
+        // Create an empty DataFrame with schema by slicing a 1-row df to 0 rows
+        let df = df! {
+            "x" => vec![0i32],
+            "y" => vec!["placeholder"],
+        }
+        .unwrap()
+        .slice(0, 0);
 
         reader.register("empty_table", df, false).unwrap();
 
         let result = reader.execute_sql("SELECT * FROM empty_table").unwrap();
         assert_eq!(result.shape(), (0, 2));
-        assert_eq!(result.get_column_names(), vec!["x", "y"]);
+        assert_eq!(
+            result.get_column_names(),
+            vec!["x".to_string(), "y".to_string()]
+        );
     }
 
     #[test]
     fn test_unregister() {
         let reader = SqliteReader::new().unwrap();
-        let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap();
+        let df = df! { "x" => vec![1i32, 2, 3] }.unwrap();
 
         reader.register("test_data", df, false).unwrap();
 
@@ -804,7 +887,7 @@ mod tests {
     #[test]
     fn test_reregister_after_unregister() {
         let reader = SqliteReader::new().unwrap();
-        let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap();
+        let df = df! { "x" => vec![1i32, 2, 3] }.unwrap();
 
         reader.register("data", df.clone(), false).unwrap();
         reader.unregister("data").unwrap();
@@ -823,11 +906,11 @@ mod tests {
         let values: Vec<f64> = (0..n).map(|i| i as f64 * 1.5).collect();
         let names: Vec<String> = (0..n).map(|i| format!("item_{}", i)).collect();
 
-        let df = DataFrame::new(vec![
-            Column::new("id".into(), ids),
-            Column::new("value".into(), values),
-            Column::new("name".into(), names),
-        ])
+        let df = df! {
+            "id" => ids,
+            "value" => values,
+            "name" => names,
+        }
         .unwrap();
 
         reader.register("large_table", df, false).unwrap();
@@ -835,7 +918,7 @@ mod tests {
         let result = reader
             .execute_sql("SELECT COUNT(*) as cnt FROM large_table")
             .unwrap();
-        let count = result.column("cnt").unwrap().i64().unwrap().get(0).unwrap();
+        let count = as_i64(result.column("cnt").unwrap()).unwrap().value(0);
         assert_eq!(count, n as i64);
     }
 
@@ -861,15 +944,18 @@ mod tests {
             .unwrap();
 
         assert_eq!(df.shape(), (2, 2));
-        assert_eq!(df.get_column_names(), vec!["region", "total"]);
+        assert_eq!(
+            df.get_column_names(),
+            vec!["region".to_string(), "total".to_string()]
+        );
     }
 
     #[test]
     fn test_register_with_replace() {
         let reader = SqliteReader::new().unwrap();
 
-        let df1 = DataFrame::new(vec![Column::new("x".into(), vec![1i32])]).unwrap();
-        let df2 = DataFrame::new(vec![Column::new("x".into(), vec![2i32, 3])]).unwrap();
+        let df1 = df! { "x" => vec![1i32] }.unwrap();
+        let df2 = df! { "x" => vec![2i32, 3] }.unwrap();
 
         reader.register("data", df1, false).unwrap();
         reader.register("data", df2, true).unwrap();
@@ -902,7 +988,7 @@ mod tests {
     fn test_boolean_roundtrip() {
         let reader = SqliteReader::new().unwrap();
 
-        let df = DataFrame::new(vec![Column::new("flag".into(), vec![true, false, true])]).unwrap();
+        let df = df! { "flag" => vec![true, false, true] }.unwrap();
 
         reader.register("bool_data", df, false).unwrap();
 
@@ -934,21 +1020,17 @@ mod tests {
     fn test_date_column_roundtrip() {
         let reader = SqliteReader::new().unwrap();
 
-        // Register a DataFrame with a Date column
-        let dates = Series::new("d".into(), vec![19000i32, 19001, 19002]);
-        let dates = dates.cast(&DataType::Date).unwrap();
-        let df = DataFrame::new(vec![
-            dates.into_column(),
-            Column::new("v".into(), vec![1, 2, 3]),
-        ])
-        .unwrap();
+        // Register a DataFrame with a Date column (Date32 in Arrow)
+        let dates: ArrayRef = Arc::new(Date32Array::from(vec![19000i32, 19001, 19002]));
+        let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let df = DataFrame::new(vec![("d", dates), ("v", values)]).unwrap();
 
         reader.register("date_data", df, false).unwrap();
 
         let result = reader.execute_sql("SELECT * FROM date_data").unwrap();
         assert_eq!(result.height(), 3);
-        assert_eq!(result.column("d").unwrap().dtype(), &DataType::Date);
-        assert_eq!(result.column("v").unwrap().dtype(), &DataType::Int64);
+        assert_eq!(result.column_dtype("d").unwrap(), DataType::Date32);
+        assert_eq!(result.column_dtype("v").unwrap(), DataType::Int64);
     }
 
     #[test]
@@ -969,11 +1051,11 @@ mod tests {
         assert_eq!(result.height(), 2);
         assert!(
             matches!(
-                result.column("ts").unwrap().dtype(),
-                DataType::Datetime(_, _)
+                result.column_dtype("ts").unwrap(),
+                DataType::Timestamp(_, _)
             ),
-            "Expected Datetime, got {:?}",
-            result.column("ts").unwrap().dtype()
+            "Expected Timestamp, got {:?}",
+            result.column_dtype("ts").unwrap()
         );
     }
 
@@ -989,7 +1071,7 @@ mod tests {
             .unwrap();
 
         let result = reader.execute_sql("SELECT * FROM str_data").unwrap();
-        assert_eq!(result.column("name").unwrap().dtype(), &DataType::String);
+        assert_eq!(result.column_dtype("name").unwrap(), DataType::Utf8);
     }
 
     #[test]
@@ -998,14 +1080,10 @@ mod tests {
 
         let reader = SqliteReader::new().unwrap();
 
-        // Register a table with a date column
-        let dates = Series::new("date".into(), vec![19000i32, 19001, 19002]);
-        let dates = dates.cast(&DataType::Date).unwrap();
-        let df = DataFrame::new(vec![
-            dates.into_column(),
-            Column::new("value".into(), vec![10, 20, 30]),
-        ])
-        .unwrap();
+        // Register a table with a date column (Date32 in Arrow)
+        let dates: ArrayRef = Arc::new(Date32Array::from(vec![19000i32, 19001, 19002]));
+        let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
+        let df = DataFrame::new(vec![("date", dates), ("value", values)]).unwrap();
 
         reader.register("ts_data", df, false).unwrap();
         let spec = reader
@@ -1161,7 +1239,7 @@ mod tests {
     }
 }
 
-#[cfg(feature = "parquet")]
+#[cfg(all(feature = "builtin-data", feature = "parquet"))]
 #[cfg(test)]
 mod builtin_data_tests {
     use super::*;
@@ -1196,9 +1274,9 @@ mod builtin_data_tests {
             .execute_sql("SELECT Date FROM ggsql:airquality LIMIT 5")
             .unwrap();
         assert_eq!(
-            result.column("Date").unwrap().dtype(),
-            &DataType::Date,
-            "airquality Date column should be detected as Date, not String"
+            result.column_dtype("Date").unwrap(),
+            DataType::Date32,
+            "airquality Date column should be detected as Date32, not String"
         );
     }
 }
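The vegalite writer changes that follow serialise one JSON object per DataFrame row. For a frame with an Int64 column `x` and a Date32 column `d`, the intended output shape is roughly the following (illustrative values; Date32 day 19000 is 2022-01-08):

```rust
// Sketch of the expected dataframe_to_values output shape.
let values = dataframe_to_values(&df).unwrap();
assert_eq!(
    serde_json::Value::Array(values),
    serde_json::json!([
        {"x": 1, "d": "2022-01-08"},
        {"x": 2, "d": "2022-01-09"}
    ])
);
```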
diff --git a/src/writer/vegalite/data.rs b/src/writer/vegalite/data.rs
index 65da1a1b..3953b68c 100644
--- a/src/writer/vegalite/data.rs
+++ b/src/writer/vegalite/data.rs
@@ -1,8 +1,9 @@
 //! DataFrame to JSON conversion utilities for Vega-Lite writer
 //!
-//! This module handles converting Polars DataFrames to Vega-Lite JSON data values,
+//! This module handles converting Arrow DataFrames to Vega-Lite JSON data values,
 //! including temporal type handling and binned data transformations.
 
+use crate::array_util::*;
 use crate::plot::scale::ScaleTypeKind;
 
 /// Column name for row index (used to preserve data order in Vega-Lite)
@@ -12,7 +13,8 @@ pub(super) const ROW_INDEX_COLUMN: &str = "__ggsql_row_index__";
 use crate::plot::ArrayElement;
 use crate::plot::ParameterValue;
 use crate::{naming, AestheticValue, DataFrame, GgsqlError, Plot, Result};
-use polars::prelude::*;
+use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::{DataType, TimeUnit};
 use serde_json::{json, Map, Value};
 use std::collections::HashMap;
 
@@ -24,22 +26,23 @@ pub(super) enum TemporalType {
     Time,
 }
 
-/// Convert Polars DataFrame to Vega-Lite data values (array of objects)
+/// Convert DataFrame to Vega-Lite data values (array of objects)
 pub(super) fn dataframe_to_values(df: &DataFrame) -> Result<Vec<Value>> {
     let mut values = Vec::new();
     let height = df.height();
     let column_names = df.get_column_names();
+    let columns = df.get_columns();
 
     for row_idx in 0..height {
         let mut row_obj = Map::new();
 
         for (col_idx, col_name) in column_names.iter().enumerate() {
-            let column = df.get_columns().get(col_idx).ok_or_else(|| {
+            let column = columns.get(col_idx).ok_or_else(|| {
                 GgsqlError::WriterError(format!("Failed to get column {}", col_name))
             })?;
 
-            // Get value from series and convert to JSON Value
-            let value = series_value_at(column.as_materialized_series(), row_idx)?;
+            // Get value from array and convert to JSON Value
+            let value = series_value_at(column, row_idx)?;
 
             row_obj.insert(col_name.to_string(), value);
         }
@@ -49,120 +52,70 @@ pub(super) fn dataframe_to_values(df: &DataFrame) -> Result<Vec<Value>> {
     Ok(values)
 }
 
-/// Get a single value from a series at a given index as JSON Value
-pub(super) fn series_value_at(series: &Series, idx: usize) -> Result<Value> {
-    use DataType::*;
+/// Get a single value from an arrow array at a given index as JSON Value
+pub(super) fn series_value_at(array: &ArrayRef, idx: usize) -> Result<Value> {
+    if array.is_null(idx) {
+        return Ok(Value::Null);
+    }
 
-    match series.dtype() {
-        Int8 => {
-            let ca = series
-                .i8()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to i8: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Int16 => {
-            let ca = series
-                .i16()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to i16: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Int32 => {
-            let ca = series
-                .i32()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to i32: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Int64 => {
-            let ca = series
-                .i64()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to i64: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Float32 => {
-            let ca = series
-                .f32()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to f32: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Float64 => {
-            let ca = series
-                .f64()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to f64: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        Boolean => {
-            let ca = series
-                .bool()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to bool: {}", e)))?;
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
-        }
-        String => {
-            let ca = series
-                .str()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to string: {}", e)))?;
+    match array.data_type() {
+        DataType::Int8 => Ok(json!(as_i8(array)?.value(idx))),
+        DataType::Int16 => Ok(json!(as_i16(array)?.value(idx))),
+        DataType::Int32 => Ok(json!(as_i32(array)?.value(idx))),
+        DataType::Int64 => Ok(json!(as_i64(array)?.value(idx))),
+        DataType::UInt8 => Ok(json!(as_u8(array)?.value(idx))),
+        DataType::UInt16 => Ok(json!(as_u16(array)?.value(idx))),
+        DataType::UInt32 => Ok(json!(as_u32(array)?.value(idx))),
+        DataType::UInt64 => Ok(json!(as_u64(array)?.value(idx))),
+        DataType::Float32 => Ok(json!(as_f32(array)?.value(idx))),
+        DataType::Float64 => Ok(json!(as_f64(array)?.value(idx))),
+        DataType::Boolean => Ok(json!(as_bool(array)?.value(idx))),
+        DataType::Utf8 => {
             // Keep strings as strings (don't parse to numbers)
-            Ok(ca.get(idx).map(|v| json!(v)).unwrap_or(Value::Null))
+            Ok(json!(as_str(array)?.value(idx)))
         }
-        Date => {
+        DataType::Date32 => {
             // Convert days since epoch to ISO date string: "YYYY-MM-DD"
-            let ca = series
-                .date()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to date: {}", e)))?;
-            if let Some(days) = ca.phys.get(idx) {
-                let unix_epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
-                let date = unix_epoch + chrono::Duration::days(days as i64);
-                Ok(json!(date.format("%Y-%m-%d").to_string()))
-            } else {
-                Ok(Value::Null)
-            }
+            let days = as_date32(array)?.value(idx);
+            let unix_epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
+            let date = unix_epoch + chrono::Duration::days(days as i64);
+            Ok(json!(date.format("%Y-%m-%d").to_string()))
         }
-        Datetime(time_unit, _) => {
+        DataType::Timestamp(time_unit, _) => {
             // Convert timestamp to ISO datetime: "YYYY-MM-DDTHH:MM:SS.sssZ"
-            let ca = series.datetime().map_err(|e| {
-                GgsqlError::WriterError(format!("Failed to cast to datetime: {}", e))
+            let timestamp = as_timestamp_us(array).map(|a| a.value(idx)).or_else(|_| {
+                // Try casting to microsecond timestamp first
+                let cast = cast_array(array, &DataType::Timestamp(TimeUnit::Microsecond, None))?;
+                Ok(as_timestamp_us(&cast)?.value(idx))
             })?;
-            if let Some(timestamp) = ca.phys.get(idx) {
-                // Convert to microseconds based on time unit
-                let micros = match time_unit {
-                    TimeUnit::Microseconds => timestamp,
-                    TimeUnit::Milliseconds => timestamp * 1_000,
-                    TimeUnit::Nanoseconds => timestamp / 1_000,
-                };
-                let secs = micros / 1_000_000;
-                let nsecs = ((micros % 1_000_000) * 1000) as u32;
-                let dt = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, nsecs)
-                    .unwrap_or_else(|| {
-                        chrono::DateTime::<chrono::Utc>::from_timestamp(0, 0).unwrap()
-                    });
-                Ok(json!(dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()))
-            } else {
-                Ok(Value::Null)
-            }
+            // timestamp is in microseconds for TimestampMicrosecondArray
+            let micros = match time_unit {
+                TimeUnit::Microsecond => timestamp,
+                TimeUnit::Millisecond => timestamp * 1_000,
+                TimeUnit::Nanosecond => timestamp / 1_000,
+                TimeUnit::Second => timestamp * 1_000_000,
+            };
+            let secs = micros / 1_000_000;
+            let nsecs = ((micros % 1_000_000) * 1000) as u32;
+            let dt = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, nsecs)
+                .unwrap_or_else(|| chrono::DateTime::<chrono::Utc>::from_timestamp(0, 0).unwrap());
+            Ok(json!(dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()))
         }
-        Time => {
+        DataType::Time64(_) => {
             // Convert nanoseconds since midnight to ISO time: "HH:MM:SS.sss"
-            let ca = series
-                .time()
-                .map_err(|e| GgsqlError::WriterError(format!("Failed to cast to time: {}", e)))?;
-            if let Some(nanos) = ca.phys.get(idx) {
-                let hours = nanos / 3_600_000_000_000;
-                let minutes = (nanos % 3_600_000_000_000) / 60_000_000_000;
-                let seconds = (nanos % 60_000_000_000) / 1_000_000_000;
-                let millis = (nanos % 1_000_000_000) / 1_000_000;
-                Ok(json!(format!(
-                    "{:02}:{:02}:{:02}.{:03}",
-                    hours, minutes, seconds, millis
-                )))
-            } else {
-                Ok(Value::Null)
-            }
+            let nanos = as_time64_ns(array)?.value(idx);
+            let hours = nanos / 3_600_000_000_000;
+            let minutes = (nanos % 3_600_000_000_000) / 60_000_000_000;
+            let seconds = (nanos % 60_000_000_000) / 1_000_000_000;
+            let millis = (nanos % 1_000_000_000) / 1_000_000;
+            Ok(json!(format!(
+                "{:02}:{:02}:{:02}.{:03}",
+                hours, minutes, seconds, millis
+            )))
         }
         _ => {
             // Fallback: convert to string
-            Ok(json!(series
-                .get(idx)
-                .map(|v| v.to_string())
-                .unwrap_or_default()))
+            Ok(json!(value_to_string(array, idx)))
         }
     }
 }
@@ -198,7 +151,7 @@ pub(super) fn find_bin_for_value(value: f64, breaks: &[f64]) -> Option<(f64, f64)> {
     None
 }
 
-/// Convert Polars DataFrame to Vega-Lite data values with bin columns.
+/// Convert DataFrame to Vega-Lite data values with bin columns.
 ///
 /// For columns with binned scales, this replaces the center value with bin_start
 /// and adds a corresponding bin_end column.
@@ -209,17 +162,18 @@ pub(super) fn dataframe_to_values_with_bins(
     let mut values = Vec::new();
     let height = df.height();
     let column_names = df.get_column_names();
+    let columns = df.get_columns();
 
     for row_idx in 0..height {
         let mut row_obj = Map::new();
 
         for (col_idx, col_name) in column_names.iter().enumerate() {
-            let column = df.get_columns().get(col_idx).ok_or_else(|| {
+            let column = columns.get(col_idx).ok_or_else(|| {
                 GgsqlError::WriterError(format!("Failed to get column {}", col_name))
             })?;
 
-            // Get value from series and convert to JSON Value
-            let value = series_value_at(column.as_materialized_series(), row_idx)?;
+            // Get value from array and convert to JSON Value
+            let value = series_value_at(column, row_idx)?;
 
             // Check if this column has binned data
             let col_name_str = col_name.to_string();
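The timestamp arm above normalises every Arrow time unit to microseconds before formatting with chrono. A self-contained check of that arithmetic (plain chrono, no project types):

```rust
use chrono::{DateTime, Utc};

// 1,700,000,000 seconds since the epoch is 2023-11-14T22:13:20Z.
let micros: i64 = 1_700_000_000 * 1_000_000; // seconds -> microseconds
let secs = micros / 1_000_000;
let nsecs = ((micros % 1_000_000) * 1000) as u32;
let dt = DateTime::<Utc>::from_timestamp(secs, nsecs).unwrap();
assert_eq!(
    dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
    "2023-11-14T22:13:20.000Z"
);
```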
diff --git a/src/writer/vegalite/encoding.rs b/src/writer/vegalite/encoding.rs
index 17c97c06..112831d4 100644
--- a/src/writer/vegalite/encoding.rs
+++ b/src/writer/vegalite/encoding.rs
@@ -3,11 +3,13 @@
 //! This module handles building Vega-Lite encoding channels from ggsql aesthetic mappings,
 //! including type inference, scale properties, and title handling.
 
+use crate::array_util::as_str;
 use crate::plot::aesthetic::{is_position_aesthetic, AestheticContext};
 use crate::plot::scale::{linetype_to_stroke_dash, shape_to_svg_path, ScaleTypeKind};
 use crate::plot::{CoordKind, ParameterValue};
 use crate::{AestheticValue, DataFrame, GgsqlError, Plot, Result};
-use polars::prelude::*;
+use arrow::array::Array;
+use arrow::datatypes::DataType;
 use serde_json::{json, Value};
 use std::collections::{HashMap, HashSet};
 
@@ -213,12 +215,15 @@ pub(super) fn count_binned_legend_scales(spec: &Plot) -> usize {
         .count()
 }
 
-/// Check if a string column contains numeric values
-pub(super) fn is_numeric_string_column(series: &Series) -> bool {
-    if let Ok(ca) = series.str() {
+/// Check if a string (Utf8) column contains numeric values
+pub(super) fn is_numeric_string_column(array: &arrow::array::ArrayRef) -> bool {
+    if let Ok(ca) = as_str(array) {
         // Check first few non-null values to see if they're numeric
-        for val in ca.into_iter().flatten().take(5) {
-            if val.parse::<f64>().is_err() {
+        for i in 0..ca.len().min(5) {
+            if ca.is_null(i) {
+                continue;
+            }
+            if ca.value(i).parse::<f64>().is_err() {
                 return false;
             }
         }
@@ -231,18 +236,25 @@ pub(super) fn is_numeric_string_column(array: &arrow::array::ArrayRef) -> bool {
 
 /// Infer Vega-Lite field type from DataFrame column
 pub(super) fn infer_field_type(df: &DataFrame, field: &str) -> String {
     if let Ok(column) = df.column(field) {
-        use DataType::*;
-        match column.dtype() {
-            Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 => {
+        match column.data_type() {
+            DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Float32
+            | DataType::Float64 => "quantitative",
+            DataType::Boolean => "nominal",
+            DataType::Utf8
+                // Check if string column contains numeric values
+                if is_numeric_string_column(column) =>
+            {
                 "quantitative"
             }
-            Boolean => "nominal",
-            String
-                // Check if string column contains numeric values
-                if is_numeric_string_column(column.as_materialized_series()) => {
-                "quantitative"
-            }
-            Date | Datetime(_, _) | Time => "temporal",
+            DataType::Date32 | DataType::Timestamp(_, _) | DataType::Time64(_) => "temporal",
             _ => "nominal",
         }
         .to_string()
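In summary, the rewritten `infer_field_type` maps Arrow dtypes to Vega-Lite field types as sketched below; the `df` and its column names are hypothetical, chosen only to illustrate each branch:

```rust
// Hypothetical df: price (Float64), day (Date32), region (Utf8), code (Utf8 of "12", "34", ...)
assert_eq!(infer_field_type(&df, "price"), "quantitative");
assert_eq!(infer_field_type(&df, "day"), "temporal");
assert_eq!(infer_field_type(&df, "region"), "nominal");
assert_eq!(infer_field_type(&df, "code"), "quantitative"); // numeric strings promote
```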
diff --git a/src/writer/vegalite/layer.rs b/src/writer/vegalite/layer.rs
index e1bd85e5..34ecc9c7 100644
--- a/src/writer/vegalite/layer.rs
+++ b/src/writer/vegalite/layer.rs
@@ -12,7 +12,7 @@ use crate::plot::layer::is_transposed;
 use crate::plot::{ArrayElement, ParameterValue};
 use crate::writer::vegalite::POINTS_TO_PIXELS;
 use crate::{naming, AestheticValue, DataFrame, Geom, GgsqlError, Layer, Result};
-use polars::prelude::ChunkCompareEq;
+use arrow::array::Array;
 use serde_json::{json, Map, Value};
 use std::any::Any;
 use std::collections::HashMap;
@@ -294,7 +294,7 @@ pub struct PathRenderer;
 ///
 /// Used by both line segmentation and text font run-length encoding.
 fn find_change_starts(df: &DataFrame, columns: &[String]) -> Result<Vec<usize>> {
-    use polars::prelude::*;
+    use crate::array_util::value_to_string;
 
     let n_rows = df.height();
 
@@ -302,34 +302,35 @@ fn find_change_starts(df: &DataFrame, columns: &[String]) -> Result<Vec<usize>> {
         return Ok(vec![0]);
     }
 
-    // Initialize change mask as all false (no changes)
-    let mut change_mask = BooleanChunked::full("change_mask".into(), false, n_rows - 1);
-
-    // For each column, OR its change mask into the accumulator
-    for col_name in columns {
-        let series = df.column(col_name).map_err(|e| {
-            GgsqlError::InternalError(format!("Column '{}' not found: {}", col_name, e))
-        })?;
-
-        // Compare each row with the previous row
-        // curr = series[1..n], prev = series[0..n-1]
-        let curr = series.slice(1, n_rows - 1);
-        let prev = series.slice(0, n_rows - 1);
+    // Build a change mask manually: for each row i (1..n), check if any column differs from row i-1
+    let mut change_starts = vec![0];
 
-        // Get boolean mask where values differ
-        let not_equal = curr.not_equal(&prev).map_err(|e| {
-            GgsqlError::InternalError(format!("Failed to compare column '{}': {}", col_name, e))
-        })?;
+    for i in 1..n_rows {
+        let mut changed = false;
+        for col_name in columns {
+            let col = df.column(col_name).map_err(|e| {
+                GgsqlError::InternalError(format!("Column '{}' not found: {}", col_name, e))
+            })?;
 
-        // OR with accumulator (change if this column OR any previous column changed)
-        change_mask = &change_mask | &not_equal;
-    }
+            // Compare values using string representation
+            let curr_null = col.is_null(i);
+            let prev_null = col.is_null(i - 1);
 
-    // Extract indices where mask is true (offset by 1 since we compared with previous)
-    let mut change_starts = vec![0];
-    for (idx, changed) in change_mask.into_iter().enumerate() {
-        if changed == Some(true) {
-            change_starts.push(idx + 1);
+            if curr_null != prev_null {
+                changed = true;
+                break;
+            }
+            if !curr_null {
+                let curr_val = value_to_string(col, i);
+                let prev_val = value_to_string(col, i - 1);
+                if curr_val != prev_val {
+                    changed = true;
+                    break;
+                }
+            }
+        }
+        if changed {
+            change_starts.push(i);
         }
     }
 
@@ -345,7 +346,10 @@ fn aesthetic_varies_within_groups(
     aesthetic_col: &str,
     group_boundaries: &[usize],
 ) -> Result<bool> {
-    let series = df.column(aesthetic_col).map_err(|e| {
+    use crate::array_util::value_to_string;
+    use std::collections::HashSet;
+
+    let col = df.column(aesthetic_col).map_err(|e| {
         GgsqlError::InternalError(format!("Column '{}' not found: {}", aesthetic_col, e))
     })?;
 
@@ -358,14 +362,17 @@ fn aesthetic_varies_within_groups(
             continue; // Single-row groups can't vary
         }
 
-        // Slice the series for this group and check uniqueness
-        let segment = series.slice(start as i64, end - start);
-        let n_unique = segment.n_unique().map_err(|e| {
-            GgsqlError::InternalError(format!("Failed to count unique values: {}", e))
-        })?;
-
-        if n_unique > 1 {
-            return Ok(true);
+        // Count unique values in this segment
+        let mut unique = HashSet::new();
+        for i in start..end {
+            if col.is_null(i) {
+                unique.insert("__null__".to_string());
+            } else {
+                unique.insert(value_to_string(col, i));
+            }
+            if unique.len() > 1 {
+                return Ok(true);
+            }
         }
     }
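A concrete example of what the rewritten scan should produce — row 0 always opens a run, and a new run starts wherever a grouping column's value (or null-ness) differs from the previous row (the `df!` column here is hypothetical):

```rust
// "A","A","B","B","C" -> runs start at rows 0, 2 and 4.
let df = df! { "g" => vec!["A", "A", "B", "B", "C"] }.unwrap();
let starts = find_change_starts(&df, &["g".to_string()]).unwrap();
assert_eq!(starts, vec![0, 2, 4]);
```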
@@ -798,13 +805,14 @@ impl TextRenderer {
     /// - DataFrame where each row represents a run's font properties (family, fontweight, italic, hjust, vjust, angle)
     /// - Vec of run lengths corresponding to each row
     fn build_font_rle(df: &DataFrame) -> Result<(DataFrame, Vec<usize>)> {
-        use polars::prelude::*;
+        use arrow::array::ArrayRef;
+        use arrow::compute;
 
         let nrows = df.height();
         if nrows == 0 {
             // Return empty DataFrame and empty run lengths
-            return Ok((DataFrame::default(), Vec::new()));
+            return Ok((DataFrame::empty(), Vec::new()));
         }
 
         // Collect font property column names that exist in the DataFrame
@@ -818,7 +826,7 @@ impl TextRenderer {
         ];
 
         let mut font_column_names = Vec::new();
-        let mut font_columns: HashMap<&str, &polars::prelude::Column> = HashMap::new();
+        let mut font_columns: HashMap<&str, &ArrayRef> = HashMap::new();
 
         for aesthetic in font_aesthetics {
             let col_name = naming::aesthetic_column(aesthetic);
@@ -841,29 +849,38 @@ impl TextRenderer {
             })
             .collect();
 
-        // Extract rows at change indices (only font columns)
-        let indices_ca = UInt32Chunked::from_vec(
-            "indices".into(),
-            change_indices.iter().map(|&i| i as u32).collect(),
-        );
+        // Extract rows at change indices (only font columns) using arrow take
+        let indices_array: ArrayRef = std::sync::Arc::new(arrow::array::UInt32Array::from(
+            change_indices
+                .iter()
+                .map(|&i| i as u32)
+                .collect::<Vec<u32>>(),
+        ));
 
-        let mut result_cols = Vec::new();
+        let mut result_cols: Vec<(&str, ArrayRef)> = Vec::new();
         for aesthetic in font_aesthetics {
             if let Some(col) = font_columns.get(aesthetic) {
-                let taken = col.take(&indices_ca).map_err(|e| {
+                let taken = compute::take(
+                    col.as_ref(),
+                    indices_array
+                        .as_any()
+                        .downcast_ref::<arrow::array::UInt32Array>()
+                        .unwrap(),
+                    None,
+                )
+                .map_err(|e| {
                     GgsqlError::InternalError(format!(
                         "Failed to take indices from {}: {}",
                         aesthetic, e
                     ))
                 })?;
-                result_cols.push(taken);
+                let col_name = naming::aesthetic_column(aesthetic);
+                result_cols.push((Box::leak(col_name.into_boxed_str()), taken));
             }
         }
 
         // Create result DataFrame (only font properties, no run_length column)
-        let result_df = DataFrame::new(result_cols).map_err(|e| {
-            GgsqlError::InternalError(format!("Failed to create run DataFrame: {}", e))
-        })?;
+        let result_df = DataFrame::new(result_cols)?;
 
         Ok((result_df, run_lengths))
     }
@@ -1093,29 +1110,35 @@ impl TextRenderer {
     ) -> Result<()> {
         // Helper to extract string column values using aesthetic column naming
         let get_str = |aesthetic: &str| -> Option<String> {
+            use crate::array_util::as_str;
             let col_name = naming::aesthetic_column(aesthetic);
-            df.column(&col_name)
-                .ok()
-                .and_then(|col| col.str().ok())
-                .and_then(|ca| ca.get(row_idx))
-                .map(|s| s.to_string())
+            let col = df.column(&col_name).ok()?;
+            if col.is_null(row_idx) {
+                return None;
+            }
+            as_str(col).ok().map(|ca| ca.value(row_idx).to_string())
         };
 
         // Helper to extract numeric column values (for angle)
         let get_f64 = |aesthetic: &str| -> Option<f64> {
-            use polars::prelude::*;
+            use crate::array_util::{as_f64, as_str, cast_array};
+            use arrow::datatypes::DataType;
 
             let col_name = naming::aesthetic_column(aesthetic);
             let col = df.column(&col_name).ok()?;
+            if col.is_null(row_idx) {
+                return None;
+            }
+
             // Try as string first (for string-encoded numbers)
-            if let Ok(ca) = col.str() {
-                return ca.get(row_idx).and_then(|s| s.parse::<f64>().ok());
+            if let Ok(ca) = as_str(col) {
+                return ca.value(row_idx).parse::<f64>().ok();
             }
 
             // Try as numeric types directly
-            if let Ok(casted) = col.cast(&DataType::Float64) {
-                if let Ok(ca) = casted.f64() {
-                    return ca.get(row_idx);
+            if let Ok(casted) = cast_array(col, &DataType::Float64) {
+                if let Ok(ca) = as_f64(&casted) {
+                    return Some(ca.value(row_idx));
                 }
             }
 
@@ -1273,7 +1296,7 @@ impl GeomRenderer for TextRenderer {
             let suffix = format!("_font_{}", run_idx);
 
             // Slice the contiguous run from the DataFrame (more efficient than boolean masking)
-            let sliced = df.slice(position as i64, length);
+            let sliced = df.slice(position, length);
 
             let mut values = if binned_columns.is_empty() {
                 dataframe_to_values(&sliced)?
@@ -1776,32 +1799,59 @@ impl BoxplotRenderer {
         let type_col = type_col.as_str();
 
         // Get the type column for filtering
-        let type_series = data
+        let type_array = data
             .column(type_col)
-            .and_then(|s| s.str())
+            .map_err(|e| GgsqlError::WriterError(e.to_string()))?;
+        let type_str_array = crate::array_util::as_str(type_array)
             .map_err(|e| GgsqlError::WriterError(e.to_string()))?;
 
         // Check for outliers
-        let has_outliers = type_series.equal("outlier").any();
+        let has_outliers = (0..type_str_array.len())
+            .any(|i| !type_str_array.is_null(i) && type_str_array.value(i) == "outlier");
 
         // Split data by type into separate datasets
         let mut type_datasets: HashMap<String, Vec<Value>> = HashMap::new();
 
         for type_name in &["lower_whisker", "upper_whisker", "box", "median", "outlier"] {
-            let mask = type_series.equal(*type_name);
-            let filtered = data
-                .filter(&mask)
-                .map_err(|e| GgsqlError::WriterError(e.to_string()))?;
+            // Collect row indices matching this type
+            let matching_indices: Vec<usize> = (0..type_str_array.len())
+                .filter(|&i| !type_str_array.is_null(i) && type_str_array.value(i) == *type_name)
+                .collect();
 
             // Skip empty datasets (e.g., no outliers)
-            if filtered.height() == 0 {
+            if matching_indices.is_empty() {
                 continue;
             }
 
-            // Drop the type column since type is now encoded in the source key
-            let filtered = filtered
-                .drop(type_col)
-                .map_err(|e| GgsqlError::WriterError(e.to_string()))?;
+            // Build filtered DataFrame by taking matching rows
+            let indices_arr: arrow::array::ArrayRef =
+                std::sync::Arc::new(arrow::array::UInt32Array::from(
+                    matching_indices
+                        .iter()
+                        .map(|&i| i as u32)
+                        .collect::<Vec<u32>>(),
+                ));
+            let indices_u32 = indices_arr
+                .as_any()
+                .downcast_ref::<arrow::array::UInt32Array>()
+                .unwrap();
+
+            let column_names = data.get_column_names();
+            let columns = data.get_columns();
+            let mut new_cols: Vec<(&str, arrow::array::ArrayRef)> = Vec::new();
+            for (col_idx, col_name) in column_names.iter().enumerate() {
+                if col_name == type_col {
+                    continue; // Drop the type column
+                }
+                let taken = arrow::compute::take(columns[col_idx].as_ref(), indices_u32, None)
+                    .map_err(|e| {
+                        GgsqlError::WriterError(format!("Failed to filter column: {}", e))
+                    })?;
+                new_cols.push((Box::leak(col_name.clone().into_boxed_str()), taken));
+            }
+            let filtered = DataFrame::new(new_cols).map_err(|e| {
+                GgsqlError::WriterError(format!("Failed to create filtered DataFrame: {}", e))
+            })?;
 
             let values = if binned_columns.is_empty() {
                 dataframe_to_values(&filtered)?
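`arrow::compute::take` is the standard gather kernel this filtering now relies on: build a `UInt32Array` of surviving row indices, then take each column. A standalone sketch:

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int64Array, UInt32Array};
use arrow::compute;

// Keep rows 0 and 2 of a three-row column.
let col: ArrayRef = Arc::new(Int64Array::from(vec![10, 20, 30]));
let indices = UInt32Array::from(vec![0u32, 2]);
let taken = compute::take(col.as_ref(), &indices, None).unwrap();
let taken = taken.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!((taken.value(0), taken.value(1)), (10, 30));
```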
         let df = df! {
-            naming::aesthetic_column("x").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("y").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("label").as_str() => &["A", "B", "C"],
-            naming::aesthetic_column("typeface").as_str() => &["Arial", "Arial", "Arial"],
+            x_col.as_str() => vec![1.0, 2.0, 3.0],
+            y_col.as_str() => vec![10.0, 20.0, 30.0],
+            label_col.as_str() => vec!["A", "B", "C"],
+            typeface_col.as_str() => vec!["Arial", "Arial", "Arial"],
         }
         .unwrap();
@@ -2404,18 +2458,22 @@ mod tests {
     #[test]
     fn test_text_varying_font() {
+        use crate::df;
         use crate::naming;
-        use polars::prelude::*;

         let renderer = TextRenderer;
         let layer = Layer::new(crate::plot::Geom::text());

         // Create DataFrame with different fonts per row
+        let x_col = naming::aesthetic_column("x");
+        let y_col = naming::aesthetic_column("y");
+        let label_col = naming::aesthetic_column("label");
+        let typeface_col = naming::aesthetic_column("typeface");
         let df = df! {
-            naming::aesthetic_column("x").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("y").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("label").as_str() => &["A", "B", "C"],
-            naming::aesthetic_column("typeface").as_str() => &["Arial", "Courier", "Times"],
+            x_col.as_str() => vec![1.0, 2.0, 3.0],
+            y_col.as_str() => vec![10.0, 20.0, 30.0],
+            label_col.as_str() => vec!["A", "B", "C"],
+            typeface_col.as_str() => vec!["Arial", "Courier", "Times"],
         }
         .unwrap();
@@ -2438,20 +2496,26 @@ mod tests {
     #[test]
     fn test_text_nested_layers_structure() {
+        use crate::df;
         use crate::naming;
-        use polars::prelude::*;

         let renderer = TextRenderer;
         let layer = Layer::new(crate::plot::Geom::text());

         // Create DataFrame with different fonts
+        let x_col = naming::aesthetic_column("x");
+        let y_col = naming::aesthetic_column("y");
+        let label_col = naming::aesthetic_column("label");
+        let typeface_col = naming::aesthetic_column("typeface");
+        let fontweight_col = naming::aesthetic_column("fontweight");
+        let italic_col = naming::aesthetic_column("italic");
         let df = df! {
-            naming::aesthetic_column("x").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("y").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("label").as_str() => &["A", "B", "C"],
-            naming::aesthetic_column("typeface").as_str() => &["Arial", "Courier", "Arial"],
-            naming::aesthetic_column("fontweight").as_str() => &["bold", "normal", "bold"],
-            naming::aesthetic_column("italic").as_str() => &["false", "true", "false"],
+            x_col.as_str() => vec![1.0, 2.0, 3.0],
+            y_col.as_str() => vec![10.0, 20.0, 30.0],
+            label_col.as_str() => vec!["A", "B", "C"],
+            typeface_col.as_str() => vec!["Arial", "Courier", "Arial"],
+            fontweight_col.as_str() => vec!["bold", "normal", "bold"],
+            italic_col.as_str() => vec!["false", "true", "false"],
         }
         .unwrap();
@@ -2519,18 +2583,22 @@ mod tests {
     #[test]
     fn test_text_varying_angle() {
+        use crate::df;
         use crate::naming;
-        use polars::prelude::*;

         let renderer = TextRenderer;
         let layer = Layer::new(crate::plot::Geom::text());

         // Create DataFrame with different angles
+        let x_col = naming::aesthetic_column("x");
+        let y_col = naming::aesthetic_column("y");
+        let label_col = naming::aesthetic_column("label");
+        let rotation_col = naming::aesthetic_column("rotation");
         let df = df! {
-            naming::aesthetic_column("x").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("y").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("label").as_str() => &["A", "B", "C"],
-            naming::aesthetic_column("rotation").as_str() => &["0", "45", "90"],
+            x_col.as_str() => vec![1.0, 2.0, 3.0],
+            y_col.as_str() => vec![10.0, 20.0, 30.0],
+            label_col.as_str() => vec!["A", "B", "C"],
+            rotation_col.as_str() => vec!["0", "45", "90"],
         }
         .unwrap();
@@ -2587,18 +2655,22 @@ mod tests {
     #[test]
     fn test_text_varying_angle_numeric() {
+        use crate::df;
         use crate::naming;
-        use polars::prelude::*;

         let renderer = TextRenderer;
         let layer = Layer::new(crate::plot::Geom::text());

         // Create DataFrame with numeric angle column (matching actual query)
+        let x_col = naming::aesthetic_column("x");
+        let y_col = naming::aesthetic_column("y");
+        let label_col = naming::aesthetic_column("label");
+        let rotation_col = naming::aesthetic_column("rotation");
         let df = df! {
-            naming::aesthetic_column("x").as_str() => &[1, 2, 3],
-            naming::aesthetic_column("y").as_str() => &[1, 2, 3],
-            naming::aesthetic_column("label").as_str() => &["A", "B", "C"],
-            naming::aesthetic_column("rotation").as_str() => &[0i32, 180i32, 0i32], // integer column
+            x_col.as_str() => vec![1i32, 2, 3],
+            y_col.as_str() => vec![1i32, 2, 3],
+            label_col.as_str() => vec!["A", "B", "C"],
+            rotation_col.as_str() => vec![0i32, 180i32, 0i32], // integer column
         }
         .unwrap();
@@ -3923,17 +3995,19 @@ mod tests {
     #[test]
     fn test_path_renderer_varying_aesthetics_metadata() {
+        use crate::df;
         use crate::plot::{AestheticValue, Geom, Layer};
-        use polars::prelude::*;

         let renderer = PathRenderer;
         let mut layer = Layer::new(Geom::line());

         // Create DataFrame with varying stroke
+        let pos1_col = naming::aesthetic_column("pos1");
+        let pos2_col = naming::aesthetic_column("pos2");
         let df = df! {
-            naming::aesthetic_column("pos1").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("pos2").as_str() => &[10.0, 20.0, 30.0],
-            "color".to_string().as_str() => &[1.0, 2.0, 3.0],
+            pos1_col.as_str() => vec![1.0, 2.0, 3.0],
+            pos2_col.as_str() => vec![10.0, 20.0, 30.0],
+            "color" => vec![1.0, 2.0, 3.0],
         }
         .unwrap();
@@ -3962,17 +4036,20 @@ mod tests {
     #[test]
     fn test_path_renderer_trail_mark_for_varying_linewidth() {
+        use crate::df;
         use crate::plot::{AestheticValue, Geom, Layer};
-        use polars::prelude::*;

         let renderer = PathRenderer;
         let mut layer = Layer::new(Geom::line());

         // Create DataFrame with varying linewidth
+        let pos1_col = naming::aesthetic_column("pos1");
+        let pos2_col = naming::aesthetic_column("pos2");
+        let linewidth_col = naming::aesthetic_column("linewidth");
         let df = df! {
-            naming::aesthetic_column("pos1").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("pos2").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("linewidth").as_str() => &[1.0, 3.0, 5.0],
+            pos1_col.as_str() => vec![1.0, 2.0, 3.0],
+            pos2_col.as_str() => vec![10.0, 20.0, 30.0],
+            linewidth_col.as_str() => vec![1.0, 3.0, 5.0],
         }
         .unwrap();
@@ -4023,19 +4100,23 @@ mod tests {
     #[test]
     fn test_path_renderer_trail_mark_with_stroke_legend() {
+        use crate::df;
         use crate::plot::{AestheticValue, Geom, Layer};
-        use polars::prelude::*;

         let context = RenderContext::default_for_test();
         let renderer = PathRenderer;
         let mut layer = Layer::new(Geom::line());

         // Create DataFrame with varying linewidth and stroke
+        let pos1_col = naming::aesthetic_column("pos1");
+        let pos2_col = naming::aesthetic_column("pos2");
+        let linewidth_col = naming::aesthetic_column("linewidth");
+        let stroke_col = naming::aesthetic_column("stroke");
         let df = df! {
-            naming::aesthetic_column("pos1").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("pos2").as_str() => &[10.0, 20.0, 30.0],
-            naming::aesthetic_column("linewidth").as_str() => &[1.0, 3.0, 5.0],
-            naming::aesthetic_column("stroke").as_str() => &["A", "A", "B"],
+            pos1_col.as_str() => vec![1.0, 2.0, 3.0],
+            pos2_col.as_str() => vec![10.0, 20.0, 30.0],
+            linewidth_col.as_str() => vec![1.0, 3.0, 5.0],
+            stroke_col.as_str() => vec!["A", "A", "B"],
         }
         .unwrap();
@@ -4105,18 +4186,20 @@ mod tests {
     #[test]
     fn test_path_renderer_segmentation_for_varying_stroke() {
+        use crate::df;
         use crate::plot::{AestheticValue, Geom, Layer};
-        use polars::prelude::*;

         let renderer = PathRenderer;
         let mut layer = Layer::new(Geom::line());

         // Create DataFrame with varying stroke
+        let pos1_col = naming::aesthetic_column("pos1");
+        let pos2_col = naming::aesthetic_column("pos2");
         let df = df! {
-            naming::aesthetic_column("pos1").as_str() => &[1.0, 2.0, 3.0],
-            naming::aesthetic_column("pos2").as_str() => &[10.0, 20.0, 30.0],
-            "color".to_string().as_str() => &[1.0, 2.0, 3.0],
-            ROW_INDEX_COLUMN => &[0, 1, 2],
+            pos1_col.as_str() => vec![1.0, 2.0, 3.0],
+            pos2_col.as_str() => vec![10.0, 20.0, 30.0],
+            "color" => vec![1.0, 2.0, 3.0],
+            ROW_INDEX_COLUMN => vec![0i32, 1, 2],
         }
         .unwrap();

diff --git a/src/writer/vegalite/mod.rs b/src/writer/vegalite/mod.rs
index 8de2bd87..69cf307d 100644
--- a/src/writer/vegalite/mod.rs
+++ b/src/writer/vegalite/mod.rs
@@ -1204,9 +1204,9 @@ impl Writer for VegaLiteWriter {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::df;
     use crate::plot::{Labels, Layer, ParameterValue};
     use crate::Geom;
-    use polars::prelude::*;
     use serde_json::Value;
     use std::collections::HashMap;
     use std::sync::LazyLock;
@@ -1365,8 +1365,8 @@ mod tests {
     /// Helper to create a simple DataFrame with x and y columns for testing
     fn simple_df() -> DataFrame {
         df! {
-            "x" => &[1, 2, 3],
-            "y" => &[4, 5, 6],
+            "x" => vec![1, 2, 3],
+            "y" => vec![4, 5, 6],
         }
         .unwrap()
     }
@@ -1514,8 +1514,8 @@ mod tests {
         // Create simple DataFrame
         let df = df! {
-            "x" => &[1, 2, 3],
-            "y" => &[4, 5, 6],
+            "x" => vec![1, 2, 3],
+            "y" => vec![4, 5, 6],
         }
         .unwrap();
@@ -1561,8 +1561,8 @@ mod tests {
         spec.labels = Some(labels);

         let df = df! {
-            "date" => &["2024-01-01", "2024-01-02"],
-            "value" => &[10, 20],
+            "date" => vec!["2024-01-01", "2024-01-02"],
+            "value" => vec![10, 20],
         }
         .unwrap();
@@ -1702,10 +1702,10 @@ mod tests {
         // Create DataFrame
{ - "x" => &[1, 2, 3], - "y" => &[1, 2, 3], - "label" => &["A", "B", "C"], - "value" => &[1.0, 2.0, 3.0], + "x" => vec![1, 2, 3], + "y" => vec![1, 2, 3], + "label" => vec!["A", "B", "C"], + "value" => vec![1.0, 2.0, 3.0], } .unwrap(); @@ -1752,8 +1752,8 @@ mod tests { spec.layers.push(layer); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], } .unwrap(); @@ -1783,8 +1783,8 @@ mod tests { transform_spec(&mut spec); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], } .unwrap(); @@ -1798,7 +1798,7 @@ mod tests { #[test] fn test_numeric_type_inference_integers() { let df = df! { - "x" => &[1i64, 2, 3], + "x" => vec![1i64, 2, 3], } .unwrap(); @@ -1808,7 +1808,7 @@ mod tests { #[test] fn test_nominal_type_inference_strings() { let df = df! { - "category" => &["A", "B", "C"], + "category" => vec!["A", "B", "C"], } .unwrap(); @@ -1818,7 +1818,7 @@ mod tests { #[test] fn test_numeric_string_type_inference() { let df = df! { - "numbers_as_strings" => &["1.5", "2.5", "3.5"], + "numbers_as_strings" => vec!["1.5", "2.5", "3.5"], } .unwrap(); @@ -1883,8 +1883,8 @@ mod tests { spec.layers.push(point_layer); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], } .unwrap(); @@ -1992,9 +1992,9 @@ mod tests { spec.scales.push(scale); let df = df! { - "x" => &[1, 2, 3], - "y" => &[10, 45, 80], - "value" => &[10.0, 45.0, 80.0], + "x" => vec![1, 2, 3], + "y" => vec![10, 45, 80], + "value" => vec![10.0, 45.0, 80.0], } .unwrap(); @@ -2160,9 +2160,9 @@ mod tests { spec.layers.push(layer); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], - "stroke" => &["red", "blue", "green"], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], + "stroke" => vec!["red", "blue", "green"], } .unwrap(); @@ -2201,8 +2201,8 @@ mod tests { spec.layers.push(layer); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], } .unwrap(); @@ -2555,10 +2555,10 @@ mod tests { spec.scales.push(x_scale); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], - "category" => &["A", "A", "B"], - "__ggsql_aes_facet1__" => &["A", "A", "B"], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], + "category" => vec!["A", "A", "B"], + "__ggsql_aes_facet1__" => vec!["A", "A", "B"], } .unwrap(); @@ -2639,10 +2639,10 @@ mod tests { spec.scales.push(y_scale); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], - "category" => &["A", "A", "B"], - "__ggsql_aes_facet1__" => &["A", "A", "B"], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], + "category" => vec!["A", "A", "B"], + "__ggsql_aes_facet1__" => vec!["A", "A", "B"], } .unwrap(); @@ -2720,10 +2720,10 @@ mod tests { spec.scales.push(x_scale); let df = df! { - "x" => &[1, 2, 3], - "y" => &[4, 5, 6], - "category" => &["A", "A", "B"], - "__ggsql_aes_facet1__" => &["A", "A", "B"], + "x" => vec![1, 2, 3], + "y" => vec![4, 5, 6], + "category" => vec!["A", "A", "B"], + "__ggsql_aes_facet1__" => vec!["A", "A", "B"], } .unwrap(); @@ -2814,10 +2814,10 @@ mod tests { spec.layers.push(layer); let df = df! { - "x1" => &[0, 1], - "y1" => &[0, 1], - "x2" => &[1, 2], - "y2" => &[1, 2], + "x1" => vec![0, 1], + "y1" => vec![0, 1], + "x2" => vec![1, 2], + "y2" => vec![1, 2], } .unwrap();
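The `get_str`/`get_f64` helpers in the text renderer delegate to the crate-internal `crate::array_util` module (`as_str`, `as_f64`, `cast_array`), whose source this diff does not show. As a rough, hypothetical approximation of their behavior using plain downcasts and `arrow::compute::cast` (an assumption for illustration, not the crate's actual implementation):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array, Int32Array, StringArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;

/// Illustrative stand-in for the renderer's get_f64: try a string parse
/// first (rotation is sometimes stored as "45"), then fall back to a
/// numeric cast to Float64.
fn get_f64(col: &ArrayRef, row: usize) -> Option<f64> {
    if col.is_null(row) {
        return None;
    }
    // String-encoded numbers.
    if let Some(s) = col.as_any().downcast_ref::<StringArray>() {
        return s.value(row).parse::<f64>().ok();
    }
    // Numeric columns: cast the whole array, then read one value.
    let casted = cast(col.as_ref(), &DataType::Float64).ok()?;
    casted
        .as_any()
        .downcast_ref::<Float64Array>()
        .map(|a| a.value(row))
}

fn main() {
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["0", "45", "90"]));
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![0, 180, 0]));
    assert_eq!(get_f64(&strings, 1), Some(45.0));
    assert_eq!(get_f64(&ints, 1), Some(180.0));
}
```

The null check before either branch matters: Arrow's `value(row)` returns whatever bytes sit in the buffer even for null slots, so without the explicit `is_null` guard a null rotation would silently read as a garbage number instead of `None`.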