diff --git a/Cargo.lock b/Cargo.lock index 39489ed94..a3e9336cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "zstd", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -359,6 +365,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] @@ -859,9 +867,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080" +checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96" dependencies = [ "arrow", "arrow-ipc", @@ -887,7 +895,6 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", - "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -902,7 +909,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "regex", "sqlparser", "tempfile", @@ -915,9 +922,9 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66" +checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee" dependencies = [ "arrow", "async-trait", @@ -941,9 +948,9 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00" +checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038" dependencies = [ "arrow", "async-trait", @@ -964,9 +971,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c" +checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12" dependencies = [ "ahash", "apache-avro", @@ -989,9 +996,9 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b" +checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2" dependencies = [ "futures", "log", @@ -1000,9 +1007,9 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1" +checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b" dependencies = [ "arrow", "async-compression", @@ -1025,7 +1032,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -1036,9 +1043,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-avro" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4ea5111aab9d3f2a8bff570343cccb03ce4c203875ef5a566b7d6f1eb72559e" +checksum = "1371cb4ef13c2e3a15685d37a07398cf13e3b0a85e705024b769fc4c511f5fef" dependencies = [ "apache-avro", "arrow", @@ -1061,9 +1068,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6" +checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc" dependencies = [ "arrow", "async-trait", @@ -1086,9 +1093,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352" +checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c" dependencies = [ "arrow", "async-trait", @@ -1111,9 +1118,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923" +checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b" dependencies = [ "arrow", "async-trait", @@ -1136,21 +1143,21 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand 0.8.5", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292" +checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea" [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84" +checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a" dependencies = [ "arrow", "dashmap", @@ -1160,16 +1167,16 @@ dependencies = [ "log", "object_store", "parking_lot", - "rand 0.8.5", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964" +checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b" dependencies = [ "arrow", "chrono", @@ -1188,9 +1195,9 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839" +checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a" dependencies = [ "arrow", "datafusion-common", @@ -1201,9 +1208,9 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c" +checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef" dependencies = [ "abi_stable", "arrow", @@ -1211,7 +1218,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "futures", "log", "prost", @@ -1221,9 +1230,9 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e" +checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357" dependencies = [ "arrow", "arrow-buffer", @@ -1241,7 +1250,7 @@ dependencies = [ "itertools 0.14.0", "log", "md-5", - "rand 0.8.5", + "rand 0.9.1", "regex", "sha2", "unicode-segmentation", @@ -1250,9 +1259,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba" +checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4" dependencies = [ "ahash", "arrow", @@ -1271,9 +1280,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4" +checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696" dependencies = [ "ahash", "arrow", @@ -1284,9 +1293,9 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537" +checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b" dependencies = [ "arrow", "arrow-ord", @@ -1305,9 +1314,9 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a" +checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b" dependencies = [ "arrow", "async-trait", @@ -1321,10 +1330,11 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193" +checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4" dependencies = [ + "arrow", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -1338,9 +1348,9 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313" +checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1348,9 +1358,9 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1" +checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8" dependencies = [ "datafusion-expr", "quote", @@ -1359,9 +1369,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515" +checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058" dependencies = [ "arrow", "chrono", @@ -1378,9 +1388,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65" +checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6" dependencies = [ "ahash", "arrow", @@ -1395,14 +1405,14 @@ dependencies = [ "itertools 0.14.0", "log", "paste", - "petgraph", + "petgraph 0.8.2", ] [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863" +checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df" dependencies = [ "ahash", "arrow", @@ -1414,9 +1424,9 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2" +checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae" dependencies = [ "arrow", "datafusion-common", @@ -1433,9 +1443,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55" +checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23" dependencies = [ "ahash", "arrow", @@ -1463,9 +1473,9 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83" +checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6" dependencies = [ "arrow", "chrono", @@ -1479,9 +1489,9 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f" +checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e" dependencies = [ "arrow", "datafusion-common", @@ -1499,6 +1509,7 @@ dependencies = [ "datafusion-proto", "datafusion-substrait", "futures", + "log", "mimalloc", "object_store", "prost", @@ -1506,6 +1517,7 @@ dependencies = [ "pyo3", "pyo3-async-runtimes", "pyo3-build-config", + "pyo3-log", "tokio", "url", "uuid", @@ -1513,9 +1525,9 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f" +checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49" dependencies = [ "arrow", "async-trait", @@ -1537,9 +1549,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607" +checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df" dependencies = [ "arrow", "bigdecimal", @@ -1554,9 +1566,9 @@ dependencies = [ [[package]] name = "datafusion-substrait" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061efc0937f0ce3abb37ed0d56cfa01dd0e654b90e408656d05e846c8b7599fe" +checksum = "f2f3973b1a4f6e9ee7fd99a22d58e1c06e6723a28dc911a60df575974c8339aa" dependencies = [ "async-recursion", "async-trait", @@ -2717,6 +2729,18 @@ dependencies = [ "indexmap", ] +[[package]] +name = "petgraph" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.3", + "indexmap", + "serde", +] + [[package]] name = "phf" version = "0.11.3" @@ -2837,7 +2861,7 @@ dependencies = [ "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.7.1", "prettyplease", "prost", "prost-types", @@ -2937,6 +2961,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.24.2" @@ -3661,9 +3696,9 @@ dependencies = [ [[package]] name = "substrait" -version = "0.55.1" +version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "048fe52a3664881ccdfdc9bdb0f4e8805f3444ee64abf299d365c54f6a2ffabb" +checksum = "13de2e20128f2a018dab1cfa30be83ae069219a65968c6f89df66ad124de2397" dependencies = [ "heck", "pbjson", @@ -4016,9 +4051,9 @@ checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typify" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03ba3643450cfd95a1aca2e1938fef63c1c1994489337998aff4ad771f21ef8" +checksum = "6c6c647a34e851cf0260ccc14687f17cdcb8302ff1a8a687a24b97ca0f82406f" dependencies = [ "typify-impl", "typify-macro", @@ -4026,9 +4061,9 @@ dependencies = [ [[package]] name = "typify-impl" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bce48219a2f3154aaa2c56cbf027728b24a3c8fe0a47ed6399781de2b3f3eeaf" +checksum = "741b7f1e2e1338c0bee5ad5a7d3a9bbd4e24c33765c08b7691810e68d879365d" dependencies = [ "heck", "log", @@ -4046,9 +4081,9 @@ dependencies = [ [[package]] name = "typify-macro" -version = "0.3.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b5780d745920ed73c5b7447496a9b5c42ed2681a9b70859377aec423ecf02b" +checksum = "7560adf816a1e8dad7c63d8845ef6e31e673e39eab310d225636779230cbedeb" dependencies = [ "proc-macro2", "quote", @@ -4116,9 +4151,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 8107d76d3..1f7895a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,11 +37,12 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} -arrow = { version = "55.0.0", features = ["pyarrow"] } -datafusion = { version = "47.0.0", features = ["avro", "unicode_expressions"] } -datafusion-substrait = { version = "47.0.0", optional = true } -datafusion-proto = { version = "47.0.0" } -datafusion-ffi = { version = "47.0.0" } +pyo3-log = "0.12.4" +arrow = { version = "55.1.0", features = ["pyarrow"] } +datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } +datafusion-substrait = { version = "48.0.0", optional = true } +datafusion-proto = { version = "48.0.0" } +datafusion-ffi = { version = "48.0.0" } prost = "0.13.1" # keep in line with `datafusion-substrait` uuid = { version = "1.16", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] } @@ -49,6 +50,7 @@ async-trait = "0.1.88" futures = "0.3" object_store = { version = "0.12.1", features = ["aws", "gcp", "azure", "http"] } url = "2" +log = "0.4.27" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/docs/source/api/dataframe.rst b/docs/source/api/dataframe.rst index a9e9e47c8..0efa2c6ed 100644 --- a/docs/source/api/dataframe.rst +++ b/docs/source/api/dataframe.rst @@ -174,7 +174,7 @@ HTML Rendering Customization ---------------------------- DataFusion provides extensive customization options for HTML table rendering through the -``datafusion.html_formatter`` module. +``datafusion.dataframe_formatter`` module. Configuring the HTML Formatter ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -183,7 +183,7 @@ You can customize how DataFrames are rendered by configuring the formatter: .. code-block:: python - from datafusion.html_formatter import configure_formatter + from datafusion.dataframe_formatter import configure_formatter configure_formatter( max_cell_length=30, # Maximum length of cell content before truncation @@ -206,7 +206,7 @@ For advanced styling needs, you can create a custom style provider class: .. code-block:: python - from datafusion.html_formatter import configure_formatter + from datafusion.dataframe_formatter import configure_formatter class CustomStyleProvider: def get_cell_style(self) -> str: @@ -225,7 +225,7 @@ You can register custom formatters for specific data types: .. code-block:: python - from datafusion.html_formatter import get_formatter + from datafusion.dataframe_formatter import get_formatter formatter = get_formatter() @@ -285,7 +285,7 @@ The HTML formatter maintains global state that can be managed: .. code-block:: python - from datafusion.html_formatter import reset_formatter, reset_styles_loaded_state, get_formatter + from datafusion.dataframe_formatter import reset_formatter, reset_styles_loaded_state, get_formatter # Reset the formatter to default settings reset_formatter() @@ -303,7 +303,7 @@ This example shows how to create a dashboard-like styling for your DataFrames: .. code-block:: python - from datafusion.html_formatter import configure_formatter, get_formatter + from datafusion.dataframe_formatter import configure_formatter, get_formatter # Define custom CSS custom_css = """ diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index c1f9806b3..a40af1234 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -176,7 +176,7 @@ By convention the ``datafusion-python`` library expects a Python object that has ``TableProvider`` PyCapsule to have this capsule accessible by calling a function named ``__datafusion_table_provider__``. You can see a complete working example of how to share a ``TableProvider`` from one python library to DataFusion Python in the -`repository examples folder `_. +`repository examples folder `_. This section has been written using ``TableProvider`` as an example. It is the first extension that has been written using this approach and the most thoroughly implemented. diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index 075ebd5a1..1b4ca6bee 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -323,6 +323,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] @@ -748,9 +750,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080" +checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96" dependencies = [ "arrow", "arrow-ipc", @@ -775,7 +777,6 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", - "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -790,7 +791,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "regex", "sqlparser", "tempfile", @@ -803,9 +804,9 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66" +checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee" dependencies = [ "arrow", "async-trait", @@ -829,9 +830,9 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00" +checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038" dependencies = [ "arrow", "async-trait", @@ -852,9 +853,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c" +checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12" dependencies = [ "ahash", "arrow", @@ -876,9 +877,9 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b" +checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2" dependencies = [ "futures", "log", @@ -887,9 +888,9 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1" +checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b" dependencies = [ "arrow", "async-compression", @@ -912,7 +913,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -923,9 +924,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6" +checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc" dependencies = [ "arrow", "async-trait", @@ -948,9 +949,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352" +checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c" dependencies = [ "arrow", "async-trait", @@ -973,9 +974,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923" +checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b" dependencies = [ "arrow", "async-trait", @@ -998,21 +999,21 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292" +checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea" [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84" +checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a" dependencies = [ "arrow", "dashmap", @@ -1022,16 +1023,16 @@ dependencies = [ "log", "object_store", "parking_lot", - "rand", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964" +checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b" dependencies = [ "arrow", "chrono", @@ -1050,9 +1051,9 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839" +checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a" dependencies = [ "arrow", "datafusion-common", @@ -1063,9 +1064,9 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c" +checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef" dependencies = [ "abi_stable", "arrow", @@ -1073,7 +1074,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "futures", "log", "prost", @@ -1081,11 +1084,25 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-ffi-example" +version = "0.2.0" +dependencies = [ + "arrow", + "arrow-array", + "arrow-schema", + "async-trait", + "datafusion", + "datafusion-ffi", + "pyo3", + "pyo3-build-config", +] + [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e" +checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357" dependencies = [ "arrow", "arrow-buffer", @@ -1103,7 +1120,7 @@ dependencies = [ "itertools", "log", "md-5", - "rand", + "rand 0.9.1", "regex", "sha2", "unicode-segmentation", @@ -1112,9 +1129,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba" +checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4" dependencies = [ "ahash", "arrow", @@ -1133,9 +1150,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4" +checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696" dependencies = [ "ahash", "arrow", @@ -1146,9 +1163,9 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537" +checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b" dependencies = [ "arrow", "arrow-ord", @@ -1167,9 +1184,9 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a" +checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b" dependencies = [ "arrow", "async-trait", @@ -1183,10 +1200,11 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193" +checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4" dependencies = [ + "arrow", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -1200,9 +1218,9 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313" +checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1210,9 +1228,9 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1" +checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8" dependencies = [ "datafusion-expr", "quote", @@ -1221,9 +1239,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515" +checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058" dependencies = [ "arrow", "chrono", @@ -1240,9 +1258,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65" +checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6" dependencies = [ "ahash", "arrow", @@ -1262,9 +1280,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863" +checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df" dependencies = [ "ahash", "arrow", @@ -1276,9 +1294,9 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2" +checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae" dependencies = [ "arrow", "datafusion-common", @@ -1295,9 +1313,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55" +checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23" dependencies = [ "ahash", "arrow", @@ -1325,9 +1343,9 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83" +checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6" dependencies = [ "arrow", "chrono", @@ -1341,9 +1359,9 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f" +checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e" dependencies = [ "arrow", "datafusion-common", @@ -1352,9 +1370,9 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f" +checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49" dependencies = [ "arrow", "async-trait", @@ -1376,9 +1394,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607" +checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df" dependencies = [ "arrow", "bigdecimal", @@ -1441,19 +1459,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "ffi-table-provider" -version = "0.1.0" -dependencies = [ - "arrow", - "arrow-array", - "arrow-schema", - "datafusion", - "datafusion-ffi", - "pyo3", - "pyo3-build-config", -] - [[package]] name = "fixedbitset" version = "0.5.7" @@ -1487,6 +1492,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1665,6 +1676,11 @@ name = "hashbrown" version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -2270,12 +2286,14 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset", + "hashbrown 0.15.3", "indexmap", + "serde", ] [[package]] @@ -2304,7 +2322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -2483,19 +2501,27 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.3", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -2503,8 +2529,14 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.3.3", ] [[package]] @@ -3031,9 +3063,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 0e17567b9..b26ab48e3 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -16,17 +16,18 @@ # under the License. [package] -name = "ffi-table-provider" -version = "0.1.0" +name = "datafusion-ffi-example" +version = "0.2.0" edition = "2021" [dependencies] -datafusion = { version = "47.0.0" } -datafusion-ffi = { version = "47.0.0" } +datafusion = { version = "48.0.0" } +datafusion-ffi = { version = "48.0.0" } pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } +async-trait = "0.1.88" [build-dependencies] pyo3-build-config = "0.23" diff --git a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py new file mode 100644 index 000000000..7ea6b295c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udaf +from datafusion_ffi_example import MySumUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3, None], type=pa.int64()), + pa.array([1, 1, 2, 2], type=pa.int64()), + ], + names=["a", "b"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_aggregate_register(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + ctx.register_udaf(my_udaf) + + result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect() + + assert len(result) == 2 + assert result[0].num_columns == 1 + + result = [r.column(0) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected + + +def test_ffi_aggregate_call_directly(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + + result = ( + ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect() + ) + + assert len(result) == 2 + assert result[0].num_columns == 2 + + result = [r.column(1) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py new file mode 100644 index 000000000..72aadf64c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyCatalogProvider + + +def test_catalog_provider(): + ctx = SessionContext() + + my_catalog_name = "my_catalog" + expected_schema_name = "my_schema" + expected_table_name = "my_table" + expected_table_columns = ["units", "price"] + + catalog_provider = MyCatalogProvider() + ctx.register_catalog_provider(my_catalog_name, catalog_provider) + my_catalog = ctx.catalog(my_catalog_name) + + my_catalog_schemas = my_catalog.names() + assert expected_schema_name in my_catalog_schemas + my_database = my_catalog.database(expected_schema_name) + assert expected_table_name in my_database.names() + my_table = my_database.table(expected_table_name) + assert expected_table_columns == my_table.schema.names + + result = ctx.table( + f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}" + ).collect() + assert len(result) == 2 + + col0_result = [r.column(0) for r in result] + col1_result = [r.column(1) for r in result] + expected_col0 = [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array([5, 7], type=pa.int32()), + ] + expected_col1 = [ + pa.array([1, 2, 5], type=pa.float64()), + pa.array([1.5, 2.5], type=pa.float64()), + ] + assert col0_result == expected_col0 + assert col1_result == expected_col1 diff --git a/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py new file mode 100644 index 000000000..0c949c34a --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udf +from datafusion_ffi_example import IsNullUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, None])], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_scalar_register(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + ctx.register_udf(my_udf) + + result = ctx.sql("select my_custom_is_null(a) from test_table").collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected + + +def test_ffi_scalar_call_directly(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + + result = ctx.table("test_table").select(my_udf(col("a"))).collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_window_udf.py b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py new file mode 100644 index 000000000..7d96994b9 --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udwf +from datafusion_ffi_example import MyRankUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([40, 10, 30, 20], type=pa.int64()), + ], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_window_register(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + ctx.register_udwf(my_udwf) + + result = ctx.sql( + "select a, my_custom_rank() over (order by a) from test_table" + ).collect() + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected + + +def test_ffi_window_call_directly(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + + result = ( + ctx.table("test_table") + .select(col("a"), my_udwf().order_by(col("a")).build()) + .collect() + ) + + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected diff --git a/examples/datafusion-ffi-example/src/aggregate_udf.rs b/examples/datafusion-ffi-example/src/aggregate_udf.rs new file mode 100644 index 000000000..9481fe9c6 --- /dev/null +++ b/examples/datafusion-ffi-example/src/aggregate_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::DataType; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_aggregate::sum::Sum; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; +use datafusion_ffi::udaf::FFI_AggregateUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MySumUDF { + inner: Arc, +} + +#[pymethods] +impl MySumUDF { + #[new] + fn new() -> Self { + Self { + inner: Arc::new(Sum::new()), + } + } + + fn __datafusion_aggregate_udf__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_aggregate_udf".into(); + + let func = Arc::new(AggregateUDF::from(self.clone())); + let provider = FFI_AggregateUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl AggregateUDFImpl for MySumUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_sum" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + self.inner.return_type(arg_types) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> DataFusionResult> { + self.inner.accumulator(acc_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs new file mode 100644 index 000000000..54e61cf3e --- /dev/null +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::datatypes::Schema; +use async_trait::async_trait; +use datafusion::{ + catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider, + }, + common::exec_err, + datasource::MemTable, + error::{DataFusionError, Result}, +}; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use pyo3::types::PyCapsule; + +pub fn my_table() -> Arc { + use arrow::datatypes::{DataType, Field}; + use datafusion::common::record_batch; + + let schema = Arc::new(Schema::new(vec![ + Field::new("units", DataType::Int32, true), + Field::new("price", DataType::Float64, true), + ])); + + let partitions = vec![ + record_batch!( + ("units", Int32, vec![10, 20, 30]), + ("price", Float64, vec![1.0, 2.0, 5.0]) + ) + .unwrap(), + record_batch!( + ("units", Int32, vec![5, 7]), + ("price", Float64, vec![1.5, 2.5]) + ) + .unwrap(), + ]; + + Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) +} + +#[derive(Debug)] +pub struct FixedSchemaProvider { + inner: MemorySchemaProvider, +} + +impl Default for FixedSchemaProvider { + fn default() -> Self { + let inner = MemorySchemaProvider::new(); + + let table = my_table(); + + let _ = inner.register_table("my_table".to_string(), table).unwrap(); + + Self { inner } + } +} + +#[async_trait] +impl SchemaProvider for FixedSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.inner.table_names() + } + + async fn table(&self, name: &str) -> Result>, DataFusionError> { + self.inner.table(name).await + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + self.inner.register_table(name, table) + } + + fn deregister_table(&self, name: &str) -> Result>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + +/// This catalog provider is intended only for unit tests. It prepopulates with one +/// schema and only allows for schemas named after four types of fruit. +#[pyclass( + name = "MyCatalogProvider", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = + FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b65..f5f96cd49 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,16 +15,28 @@ // specific language governing permissions and limitations // under the License. +use crate::aggregate_udf::MySumUDF; +use crate::catalog_provider::MyCatalogProvider; +use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; +use crate::window_udf::MyRankUDF; use pyo3::prelude::*; +pub(crate) mod aggregate_udf; +pub(crate) mod catalog_provider; +pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; +pub(crate) mod window_udf; #[pymodule] fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/scalar_udf.rs b/examples/datafusion-ffi-example/src/scalar_udf.rs new file mode 100644 index 000000000..727666638 --- /dev/null +++ b/examples/datafusion-ffi-example/src/scalar_udf.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{Array, BooleanArray}; +use arrow_schema::DataType; +use datafusion::common::ScalarValue; +use datafusion::error::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_ffi::udf::FFI_ScalarUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "IsNullUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct IsNullUDF { + signature: Signature, +} + +#[pymethods] +impl IsNullUDF { + #[new] + fn new() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } + } + + fn __datafusion_scalar_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_scalar_udf".into(); + + let func = Arc::new(ScalarUDF::from(self.clone())); + let provider = FFI_ScalarUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl ScalarUDFImpl for IsNullUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_is_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let input = &args.args[0]; + + Ok(match input { + ColumnarValue::Array(arr) => match arr.is_nullable() { + true => { + let nulls = arr.nulls().unwrap(); + let nulls = BooleanArray::from_iter(nulls.iter().map(|x| Some(!x))); + ColumnarValue::Array(Arc::new(nulls)) + } + false => ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))), + }, + ColumnarValue::Scalar(sv) => { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(sv == &ScalarValue::Null))) + } + }) + } +} diff --git a/examples/datafusion-ffi-example/src/window_udf.rs b/examples/datafusion-ffi-example/src/window_udf.rs new file mode 100644 index 000000000..e0d397956 --- /dev/null +++ b/examples/datafusion-ffi-example/src/window_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{DataType, FieldRef}; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_window::rank::rank_udwf; +use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; +use datafusion::logical_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl}; +use datafusion_ffi::udwf::FFI_WindowUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MyRankUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MyRankUDF { + inner: Arc, +} + +#[pymethods] +impl MyRankUDF { + #[new] + fn new() -> Self { + Self { inner: rank_udwf() } + } + + fn __datafusion_window_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_window_udf".into(); + + let func = Arc::new(WindowUDF::from(self.clone())); + let provider = FFI_WindowUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl WindowUDFImpl for MyRankUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_rank" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> DataFusionResult> { + self.inner + .inner() + .partition_evaluator(partition_evaluator_args) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> DataFusionResult { + self.inner.inner().field(field_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index c3468eb4a..77fed2a94 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -21,6 +21,10 @@ See https://datafusion.apache.org/python for more information. """ +from __future__ import annotations + +from typing import Any + try: import importlib.metadata as importlib_metadata except ImportError: @@ -28,7 +32,7 @@ from datafusion.col import col, column -from . import functions, object_store, substrait, unparser +from . import catalog, functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. from ._internal import Config @@ -42,12 +46,12 @@ SessionContext, SQLOptions, ) -from .dataframe import DataFrame +from .dataframe import DataFrame, ParquetColumnOptions, ParquetWriterOptions +from .dataframe_formatter import configure_formatter from .expr import ( Expr, WindowFrame, ) -from .html_formatter import configure_formatter from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream @@ -76,6 +80,8 @@ "ExecutionPlan", "Expr", "LogicalPlan", + "ParquetColumnOptions", + "ParquetWriterOptions", "RecordBatch", "RecordBatchStream", "RuntimeEnvBuilder", @@ -87,6 +93,7 @@ "TableFunction", "WindowFrame", "WindowUDF", + "catalog", "col", "column", "common", @@ -130,3 +137,18 @@ def str_lit(value): def lit(value) -> Expr: """Create a literal expression.""" return Expr.literal(value) + + +def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Creates a new expression representing a scalar value with metadata. + + Args: + value: A valid PyArrow scalar value or easily castable to one. + metadata: Metadata to attach to the expression. + """ + return Expr.literal_with_metadata(value, metadata) + + +def lit_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Alias for literal_with_metadata.""" + return literal_with_metadata(value, metadata) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 6c3f188cc..5f1a317f6 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -26,46 +26,115 @@ if TYPE_CHECKING: import pyarrow as pa +try: + from warnings import deprecated # Python 3.13+ +except ImportError: + from typing_extensions import deprecated # Python 3.12 + + +__all__ = [ + "Catalog", + "Schema", + "Table", +] + class Catalog: """DataFusion data catalog.""" - def __init__(self, catalog: df_internal.Catalog) -> None: + def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog - def names(self) -> list[str]: - """Returns the list of databases in this catalog.""" - return self.catalog.names() + def __repr__(self) -> str: + """Print a string representation of the catalog.""" + return self.catalog.__repr__() + + def names(self) -> set[str]: + """This is an alias for `schema_names`.""" + return self.schema_names() + + def schema_names(self) -> set[str]: + """Returns the list of schemas in this catalog.""" + return self.catalog.schema_names() - def database(self, name: str = "public") -> Database: + def schema(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" - return Database(self.catalog.database(name)) + schema = self.catalog.schema(name) + + return ( + Schema(schema) + if isinstance(schema, df_internal.catalog.RawSchema) + else schema + ) + + @deprecated("Use `schema` instead.") + def database(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + return self.schema(name) + + def new_in_memory_schema(self, name: str) -> Schema: + """Create a new schema in this catalog using an in-memory provider.""" + self.catalog.new_in_memory_schema(name) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: + """Register a schema with this catalog.""" + return self.catalog.register_schema(name, schema) -class Database: - """DataFusion Database.""" + def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: + """Deregister a schema from this catalog.""" + return self.catalog.deregister_schema(name, cascade) - def __init__(self, db: df_internal.Database) -> None: + +class Schema: + """DataFusion Schema.""" + + def __init__(self, schema: df_internal.catalog.RawSchema) -> None: """This constructor is not typically called by the end user.""" - self.db = db + self._raw_schema = schema + + def __repr__(self) -> str: + """Print a string representation of the schema.""" + return self._raw_schema.__repr__() def names(self) -> set[str]: - """Returns the list of all tables in this database.""" - return self.db.names() + """This is an alias for `table_names`.""" + return self.table_names() + + def table_names(self) -> set[str]: + """Returns the list of all tables in this schema.""" + return self._raw_schema.table_names def table(self, name: str) -> Table: - """Return the table with the given ``name`` from this database.""" - return Table(self.db.table(name)) + """Return the table with the given ``name`` from this schema.""" + return Table(self._raw_schema.table(name)) + + def register_table(self, name, table) -> None: + """Register a table provider in this schema.""" + return self._raw_schema.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Deregister a table provider from this schema.""" + return self._raw_schema.deregister_table(name) + + +@deprecated("Use `Schema` instead.") +class Database(Schema): + """See `Schema`.""" class Table: """DataFusion table.""" - def __init__(self, table: df_internal.Table) -> None: + def __init__(self, table: df_internal.catalog.RawTable) -> None: """This constructor is not typically called by the end user.""" self.table = table + def __repr__(self) -> str: + """Print a string representation of the table.""" + return self.table.__repr__() + @property def schema(self) -> pa.Schema: """Returns the schema associated with this table.""" diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 26c3d2e22..f752272bb 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,8 +19,11 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Protocol +import pyarrow as pa + try: from warnings import deprecated # Python 3.13+ except ImportError: @@ -42,7 +45,6 @@ import pandas as pd import polars as pl - import pyarrow as pa from datafusion.plan import ExecutionPlan, LogicalPlan @@ -78,6 +80,15 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -496,6 +507,10 @@ def __init__( self.ctx = SessionContextInternal(config, runtime) + def __repr__(self) -> str: + """Print a string representation of the Session Context.""" + return self.ctx.__repr__() + @classmethod def global_ctx(cls) -> SessionContext: """Retrieve the global context as a `SessionContext` wrapper. @@ -535,7 +550,7 @@ def register_listing_table( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".parquet", schema: pa.Schema | None = None, file_sort_order: list[list[Expr | SortExpr]] | None = None, @@ -556,6 +571,7 @@ def register_listing_table( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order_raw = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -742,6 +758,21 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def catalog_names(self) -> set[str]: + """Returns the list of catalogs in this context.""" + return self.ctx.catalog_names() + + def new_in_memory_catalog(self, name: str) -> Catalog: + """Create a new catalog in this context using an in-memory provider.""" + self.ctx.new_in_memory_catalog(name) + return self.catalog(name) + + def register_catalog_provider( + self, name: str, provider: CatalogProviderExportable + ) -> None: + """Register a catalog provider.""" + self.ctx.register_catalog_provider(name, provider) + def register_table_provider( self, name: str, provider: TableProviderExportable ) -> None: @@ -774,7 +805,7 @@ def register_parquet( self, name: str, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -802,6 +833,7 @@ def register_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_parquet( name, str(path), @@ -865,7 +897,7 @@ def register_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> None: """Register a JSON file as a table. @@ -886,6 +918,7 @@ def register_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_json( name, str(path), @@ -902,7 +935,7 @@ def register_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, file_extension: str = ".avro", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, ) -> None: """Register an Avro file as a table. @@ -918,6 +951,7 @@ def register_avro( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -977,7 +1011,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. @@ -997,6 +1031,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) return DataFrame( self.ctx.read_json( str(path), @@ -1016,7 +1051,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -1041,6 +1076,7 @@ def read_csv( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) path = [str(p) for p in path] if isinstance(path, list) else str(path) @@ -1060,7 +1096,7 @@ def read_csv( def read_parquet( self, path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -1089,6 +1125,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] + table_partition_cols = self._convert_table_partition_cols(table_partition_cols) file_sort_order = ( [sort_list_to_raw_sort_list(f) for f in file_sort_order] if file_sort_order is not None @@ -1110,7 +1147,7 @@ def read_avro( self, path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. @@ -1126,6 +1163,7 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] + file_partition_cols = self._convert_table_partition_cols(file_partition_cols) return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) @@ -1142,3 +1180,41 @@ def read_table(self, table: Table) -> DataFrame: def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream: """Execute the ``plan`` and return the results.""" return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions)) + + @staticmethod + def _convert_table_partition_cols( + table_partition_cols: list[tuple[str, str | pa.DataType]], + ) -> list[tuple[str, pa.DataType]]: + warn = False + converted_table_partition_cols = [] + + for col, data_type in table_partition_cols: + if isinstance(data_type, str): + warn = True + if data_type == "string": + converted_data_type = pa.string() + elif data_type == "int": + converted_data_type = pa.int32() + else: + message = ( + f"Unsupported literal data type '{data_type}' for partition " + "column. Supported types are 'string' and 'int'" + ) + raise ValueError(message) + else: + converted_data_type = data_type + + converted_table_partition_cols.append((col, converted_data_type)) + + if warn: + message = ( + "using literals for table_partition_cols data types is deprecated," + "use pyarrow types instead" + ) + warnings.warn( + message, + category=DeprecationWarning, + stacklevel=2, + ) + + return converted_table_partition_cols diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index a1df7e080..c747c24d5 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -38,6 +38,8 @@ from typing_extensions import deprecated # Python 3.12 from datafusion._internal import DataFrame as DataFrameInternal +from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal +from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal from datafusion.expr import Expr, SortExpr, sort_or_default from datafusion.plan import ExecutionPlan, LogicalPlan from datafusion.record_batch import RecordBatchStream @@ -50,7 +52,6 @@ import polars as pl import pyarrow as pa - from datafusion._internal import DataFrame as DataFrameInternal from datafusion._internal import expr as expr_internal from enum import Enum @@ -114,6 +115,173 @@ def get_default_level(self) -> Optional[int]: return None +class ParquetWriterOptions: + """Advanced parquet writer options. + + Allows settings the writer options that apply to the entire file. Some options can + also be set on a column by column basis, with the field `column_specific_options` + (see `ParquetColumnOptions`). + + Attributes: + data_pagesize_limit: Sets best effort maximum size of data page in bytes. + write_batch_size: Sets write_batch_size in bytes. + writer_version: Sets parquet writer version. Valid values are `1.0` and + `2.0`. + skip_arrow_metadata: Skip encoding the embedded arrow metadata in the + KV_meta. + compression: Compression type to use. Default is "zstd(3)". + Available compression types are + - "uncompressed": No compression. + - "snappy": Snappy compression. + - "gzip(n)": Gzip compression with level n. + - "brotli(n)": Brotli compression with level n. + - "lz4": LZ4 compression. + - "lz4_raw": LZ4_RAW compression. + - "zstd(n)": Zstandard compression with level n. + dictionary_enabled: Sets if dictionary encoding is enabled. If None, uses + the default parquet writer setting. + dictionary_page_size_limit: Sets best effort maximum dictionary page size, + in bytes. + statistics_enabled: Sets if statistics are enabled for any column Valid + values are `none`, `chunk`, and `page`. If None, uses the default + parquet writer setting. + max_row_group_size: Target maximum number of rows in each row group + (defaults to 1M rows). Writing larger row groups requires more memory to + write, but can get better compression and be faster to read. + created_by: Sets "created by" property. + column_index_truncate_length: Sets column index truncate length. + statistics_truncate_length: Sets statistics truncate length. If None, uses + the default parquet writer setting. + data_page_row_count_limit: Sets best effort maximum number of rows in a data + page. + encoding: Sets default encoding for any column. Valid values are `plain`, + `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, + `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and + `byte_stream_split`. If None, uses the default parquet writer setting. + bloom_filter_on_write: Write bloom filters for all columns when creating + parquet files. + bloom_filter_fpp: Sets bloom filter false positive probability. If None, + uses the default parquet writer setting + bloom_filter_ndv: Sets bloom filter number of distinct values. If None, uses + the default parquet writer setting. + allow_single_file_parallelism: Controls whether DataFusion will attempt to + speed up writing parquet files by serializing them in parallel. Each + column in each row group in each output file are serialized in parallel + leveraging a maximum possible core count of n_files * n_row_groups * + n_columns. + maximum_parallel_row_group_writers: By default parallel parquet writer is + tuned for minimum memory usage in a streaming execution plan. You may + see a performance benefit when writing large parquet files by increasing + `maximum_parallel_row_group_writers` and + `maximum_buffered_record_batches_per_stream` if your system has idle + cores and can tolerate additional memory usage. Boosting these values is + likely worthwhile when writing out already in-memory data, such as from + a cached data frame. + maximum_buffered_record_batches_per_stream: See + `maximum_parallel_row_group_writers`. + column_specific_options: Overrides options for specific columns. If a column + is not a part of this dictionary, it will use the parameters provided here. + """ + + def __init__( + self, + data_pagesize_limit: int = 1024 * 1024, + write_batch_size: int = 1024, + writer_version: str = "1.0", + skip_arrow_metadata: bool = False, + compression: Optional[str] = "zstd(3)", + dictionary_enabled: Optional[bool] = True, + dictionary_page_size_limit: int = 1024 * 1024, + statistics_enabled: Optional[str] = "page", + max_row_group_size: int = 1024 * 1024, + created_by: str = "datafusion-python", + column_index_truncate_length: Optional[int] = 64, + statistics_truncate_length: Optional[int] = None, + data_page_row_count_limit: int = 20_000, + encoding: Optional[str] = None, + bloom_filter_on_write: bool = False, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + allow_single_file_parallelism: bool = True, + maximum_parallel_row_group_writers: int = 1, + maximum_buffered_record_batches_per_stream: int = 2, + column_specific_options: Optional[dict[str, ParquetColumnOptions]] = None, + ) -> None: + """Initialize the ParquetWriterOptions.""" + self.data_pagesize_limit = data_pagesize_limit + self.write_batch_size = write_batch_size + self.writer_version = writer_version + self.skip_arrow_metadata = skip_arrow_metadata + self.compression = compression + self.dictionary_enabled = dictionary_enabled + self.dictionary_page_size_limit = dictionary_page_size_limit + self.statistics_enabled = statistics_enabled + self.max_row_group_size = max_row_group_size + self.created_by = created_by + self.column_index_truncate_length = column_index_truncate_length + self.statistics_truncate_length = statistics_truncate_length + self.data_page_row_count_limit = data_page_row_count_limit + self.encoding = encoding + self.bloom_filter_on_write = bloom_filter_on_write + self.bloom_filter_fpp = bloom_filter_fpp + self.bloom_filter_ndv = bloom_filter_ndv + self.allow_single_file_parallelism = allow_single_file_parallelism + self.maximum_parallel_row_group_writers = maximum_parallel_row_group_writers + self.maximum_buffered_record_batches_per_stream = ( + maximum_buffered_record_batches_per_stream + ) + self.column_specific_options = column_specific_options + + +class ParquetColumnOptions: + """Parquet options for individual columns. + + Contains the available options that can be applied for an individual Parquet column, + replacing the global options in `ParquetWriterOptions`. + + Attributes: + encoding: Sets encoding for the column path. Valid values are: `plain`, + `plain_dictionary`, `rle`, `bit_packed`, `delta_binary_packed`, + `delta_length_byte_array`, `delta_byte_array`, `rle_dictionary`, and + `byte_stream_split`. These values are not case-sensitive. If `None`, uses + the default parquet options + dictionary_enabled: Sets if dictionary encoding is enabled for the column path. + If `None`, uses the default parquet options + compression: Sets default parquet compression codec for the column path. Valid + values are `uncompressed`, `snappy`, `gzip(level)`, `lzo`, `brotli(level)`, + `lz4`, `zstd(level)`, and `lz4_raw`. These values are not case-sensitive. If + `None`, uses the default parquet options. + statistics_enabled: Sets if statistics are enabled for the column Valid values + are: `none`, `chunk`, and `page` These values are not case sensitive. If + `None`, uses the default parquet options. + bloom_filter_enabled: Sets if bloom filter is enabled for the column path. If + `None`, uses the default parquet options. + bloom_filter_fpp: Sets bloom filter false positive probability for the column + path. If `None`, uses the default parquet options. + bloom_filter_ndv: Sets bloom filter number of distinct values. If `None`, uses + the default parquet options. + """ + + def __init__( + self, + encoding: Optional[str] = None, + dictionary_enabled: Optional[bool] = None, + compression: Optional[str] = None, + statistics_enabled: Optional[str] = None, + bloom_filter_enabled: Optional[bool] = None, + bloom_filter_fpp: Optional[float] = None, + bloom_filter_ndv: Optional[int] = None, + ) -> None: + """Initialize the ParquetColumnOptions.""" + self.encoding = encoding + self.dictionary_enabled = dictionary_enabled + self.compression = compression + self.statistics_enabled = statistics_enabled + self.bloom_filter_enabled = bloom_filter_enabled + self.bloom_filter_fpp = bloom_filter_fpp + self.bloom_filter_ndv = bloom_filter_ndv + + class DataFrame: """Two dimensional table representation of data. @@ -737,6 +905,58 @@ def write_parquet( self.df.write_parquet(str(path), compression.value, compression_level) + def write_parquet_with_options( + self, path: str | pathlib.Path, options: ParquetWriterOptions + ) -> None: + """Execute the :py:class:`DataFrame` and write the results to a Parquet file. + + Allows advanced writer options to be set with `ParquetWriterOptions`. + + Args: + path: Path of the Parquet file to write. + options: Sets the writer parquet options (see `ParquetWriterOptions`). + """ + options_internal = ParquetWriterOptionsInternal( + options.data_pagesize_limit, + options.write_batch_size, + options.writer_version, + options.skip_arrow_metadata, + options.compression, + options.dictionary_enabled, + options.dictionary_page_size_limit, + options.statistics_enabled, + options.max_row_group_size, + options.created_by, + options.column_index_truncate_length, + options.statistics_truncate_length, + options.data_page_row_count_limit, + options.encoding, + options.bloom_filter_on_write, + options.bloom_filter_fpp, + options.bloom_filter_ndv, + options.allow_single_file_parallelism, + options.maximum_parallel_row_group_writers, + options.maximum_buffered_record_batches_per_stream, + ) + + column_specific_options_internal = {} + for column, opts in (options.column_specific_options or {}).items(): + column_specific_options_internal[column] = ParquetColumnOptionsInternal( + bloom_filter_enabled=opts.bloom_filter_enabled, + encoding=opts.encoding, + dictionary_enabled=opts.dictionary_enabled, + compression=opts.compression, + statistics_enabled=opts.statistics_enabled, + bloom_filter_fpp=opts.bloom_filter_fpp, + bloom_filter_ndv=opts.bloom_filter_ndv, + ) + + self.df.write_parquet_with_options( + str(path), + options_internal, + column_specific_options_internal, + ) + def write_json(self, path: str | pathlib.Path) -> None: """Execute the :py:class:`DataFrame` and write the results to a JSON file. @@ -832,7 +1052,7 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram columns = list(columns) return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls)) - def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any: + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """Export an Arrow PyCapsule Stream. This will execute and collect the DataFrame. We will attempt to respect the @@ -891,3 +1111,17 @@ def fill_null(self, value: Any, subset: list[str] | None = None) -> DataFrame: - For columns not in subset, the original column is kept unchanged """ return DataFrame(self.df.fill_null(value, subset)) + + @staticmethod + def default_str_repr( + batches: list[pa.RecordBatch], + schema: pa.Schema, + has_more: bool, + table_uuid: str | None = None, + ) -> str: + """Return the default string representation of a DataFrame. + + This method is used by the default formatter and implemented in Rust for + performance reasons. + """ + return DataFrameInternal.default_str_repr(batches, schema, has_more, table_uuid) diff --git a/python/datafusion/dataframe_formatter.py b/python/datafusion/dataframe_formatter.py new file mode 100644 index 000000000..27f00f9c3 --- /dev/null +++ b/python/datafusion/dataframe_formatter.py @@ -0,0 +1,739 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""HTML formatting utilities for DataFusion DataFrames.""" + +from __future__ import annotations + +from typing import ( + Any, + Callable, + Optional, + Protocol, + runtime_checkable, +) + +from datafusion._internal import DataFrame as DataFrameInternal + + +def _validate_positive_int(value: Any, param_name: str) -> None: + """Validate that a parameter is a positive integer. + + Args: + value: The value to validate + param_name: Name of the parameter (used in error message) + + Raises: + ValueError: If the value is not a positive integer + """ + if not isinstance(value, int) or value <= 0: + msg = f"{param_name} must be a positive integer" + raise ValueError(msg) + + +def _validate_bool(value: Any, param_name: str) -> None: + """Validate that a parameter is a boolean. + + Args: + value: The value to validate + param_name: Name of the parameter (used in error message) + + Raises: + TypeError: If the value is not a boolean + """ + if not isinstance(value, bool): + msg = f"{param_name} must be a boolean" + raise TypeError(msg) + + +@runtime_checkable +class CellFormatter(Protocol): + """Protocol for cell value formatters.""" + + def __call__(self, value: Any) -> str: + """Format a cell value to string representation.""" + ... + + +@runtime_checkable +class StyleProvider(Protocol): + """Protocol for HTML style providers.""" + + def get_cell_style(self) -> str: + """Get the CSS style for table cells.""" + ... + + def get_header_style(self) -> str: + """Get the CSS style for header cells.""" + ... + + +class DefaultStyleProvider: + """Default implementation of StyleProvider.""" + + def get_cell_style(self) -> str: + """Get the CSS style for table cells. + + Returns: + CSS style string + """ + return ( + "border: 1px solid black; padding: 8px; text-align: left; " + "white-space: nowrap;" + ) + + def get_header_style(self) -> str: + """Get the CSS style for header cells. + + Returns: + CSS style string + """ + return ( + "border: 1px solid black; padding: 8px; text-align: left; " + "background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; " + "max-width: fit-content;" + ) + + +class DataFrameHtmlFormatter: + """Configurable HTML formatter for DataFusion DataFrames. + + This class handles the HTML rendering of DataFrames for display in + Jupyter notebooks and other rich display contexts. + + This class supports extension through composition. Key extension points: + - Provide a custom StyleProvider for styling cells and headers + - Register custom formatters for specific types + - Provide custom cell builders for specialized cell rendering + + Args: + max_cell_length: Maximum characters to display in a cell before truncation + max_width: Maximum width of the HTML table in pixels + max_height: Maximum height of the HTML table in pixels + max_memory_bytes: Maximum memory in bytes for rendered data (default: 2MB) + min_rows_display: Minimum number of rows to display + repr_rows: Default number of rows to display in repr output + enable_cell_expansion: Whether to add expand/collapse buttons for long cell + values + custom_css: Additional CSS to include in the HTML output + show_truncation_message: Whether to display a message when data is truncated + style_provider: Custom provider for cell and header styles + use_shared_styles: Whether to load styles and scripts only once per notebook + session + """ + + # Class variable to track if styles have been loaded in the notebook + _styles_loaded = False + + def __init__( + self, + max_cell_length: int = 25, + max_width: int = 1000, + max_height: int = 300, + max_memory_bytes: int = 2 * 1024 * 1024, # 2 MB + min_rows_display: int = 20, + repr_rows: int = 10, + enable_cell_expansion: bool = True, + custom_css: Optional[str] = None, + show_truncation_message: bool = True, + style_provider: Optional[StyleProvider] = None, + use_shared_styles: bool = True, + ) -> None: + """Initialize the HTML formatter. + + Parameters + ---------- + max_cell_length : int, default 25 + Maximum length of cell content before truncation. + max_width : int, default 1000 + Maximum width of the displayed table in pixels. + max_height : int, default 300 + Maximum height of the displayed table in pixels. + max_memory_bytes : int, default 2097152 (2MB) + Maximum memory in bytes for rendered data. + min_rows_display : int, default 20 + Minimum number of rows to display. + repr_rows : int, default 10 + Default number of rows to display in repr output. + enable_cell_expansion : bool, default True + Whether to allow cells to expand when clicked. + custom_css : str, optional + Custom CSS to apply to the HTML table. + show_truncation_message : bool, default True + Whether to show a message indicating that content has been truncated. + style_provider : StyleProvider, optional + Provider of CSS styles for the HTML table. If None, DefaultStyleProvider + is used. + use_shared_styles : bool, default True + Whether to use shared styles across multiple tables. + + Raises: + ------ + ValueError + If max_cell_length, max_width, max_height, max_memory_bytes, + min_rows_display, or repr_rows is not a positive integer. + TypeError + If enable_cell_expansion, show_truncation_message, or use_shared_styles is + not a boolean, + or if custom_css is provided but is not a string, + or if style_provider is provided but does not implement the StyleProvider + protocol. + """ + # Validate numeric parameters + _validate_positive_int(max_cell_length, "max_cell_length") + _validate_positive_int(max_width, "max_width") + _validate_positive_int(max_height, "max_height") + _validate_positive_int(max_memory_bytes, "max_memory_bytes") + _validate_positive_int(min_rows_display, "min_rows_display") + _validate_positive_int(repr_rows, "repr_rows") + + # Validate boolean parameters + _validate_bool(enable_cell_expansion, "enable_cell_expansion") + _validate_bool(show_truncation_message, "show_truncation_message") + _validate_bool(use_shared_styles, "use_shared_styles") + + # Validate custom_css + if custom_css is not None and not isinstance(custom_css, str): + msg = "custom_css must be None or a string" + raise TypeError(msg) + + # Validate style_provider + if style_provider is not None and not isinstance(style_provider, StyleProvider): + msg = "style_provider must implement the StyleProvider protocol" + raise TypeError(msg) + + self.max_cell_length = max_cell_length + self.max_width = max_width + self.max_height = max_height + self.max_memory_bytes = max_memory_bytes + self.min_rows_display = min_rows_display + self.repr_rows = repr_rows + self.enable_cell_expansion = enable_cell_expansion + self.custom_css = custom_css + self.show_truncation_message = show_truncation_message + self.style_provider = style_provider or DefaultStyleProvider() + self.use_shared_styles = use_shared_styles + # Registry for custom type formatters + self._type_formatters: dict[type, CellFormatter] = {} + # Custom cell builders + self._custom_cell_builder: Optional[Callable[[Any, int, int, str], str]] = None + self._custom_header_builder: Optional[Callable[[Any], str]] = None + + def register_formatter(self, type_class: type, formatter: CellFormatter) -> None: + """Register a custom formatter for a specific data type. + + Args: + type_class: The type to register a formatter for + formatter: Function that takes a value of the given type and returns + a formatted string + """ + self._type_formatters[type_class] = formatter + + def set_custom_cell_builder( + self, builder: Callable[[Any, int, int, str], str] + ) -> None: + """Set a custom cell builder function. + + Args: + builder: Function that takes (value, row, col, table_id) and returns HTML + """ + self._custom_cell_builder = builder + + def set_custom_header_builder(self, builder: Callable[[Any], str]) -> None: + """Set a custom header builder function. + + Args: + builder: Function that takes a field and returns HTML + """ + self._custom_header_builder = builder + + @classmethod + def is_styles_loaded(cls) -> bool: + """Check if HTML styles have been loaded in the current session. + + This method is primarily intended for debugging UI rendering issues + related to style loading. + + Returns: + True if styles have been loaded, False otherwise + + Example: + >>> from datafusion.dataframe_formatter import DataFrameHtmlFormatter + >>> DataFrameHtmlFormatter.is_styles_loaded() + False + """ + return cls._styles_loaded + + def format_html( + self, + batches: list, + schema: Any, + has_more: bool = False, + table_uuid: str | None = None, + ) -> str: + """Format record batches as HTML. + + This method is used by DataFrame's _repr_html_ implementation and can be + called directly when custom HTML rendering is needed. + + Args: + batches: List of Arrow RecordBatch objects + schema: Arrow Schema object + has_more: Whether there are more batches not shown + table_uuid: Unique ID for the table, used for JavaScript interactions + + Returns: + HTML string representation of the data + + Raises: + TypeError: If schema is invalid and no batches are provided + """ + if not batches: + return "No data to display" + + # Validate schema + if schema is None or not hasattr(schema, "__iter__"): + msg = "Schema must be provided" + raise TypeError(msg) + + # Generate a unique ID if none provided + table_uuid = table_uuid or f"df-{id(batches)}" + + # Build HTML components + html = [] + + # Only include styles and scripts if: + # 1. Not using shared styles, OR + # 2. Using shared styles but they haven't been loaded yet + include_styles = ( + not self.use_shared_styles or not DataFrameHtmlFormatter._styles_loaded + ) + + if include_styles: + html.extend(self._build_html_header()) + # If we're using shared styles, mark them as loaded + if self.use_shared_styles: + DataFrameHtmlFormatter._styles_loaded = True + + html.extend(self._build_table_container_start()) + + # Add table header and body + html.extend(self._build_table_header(schema)) + html.extend(self._build_table_body(batches, table_uuid)) + + html.append("") + html.append("") + + # Add footer (JavaScript and messages) + if include_styles and self.enable_cell_expansion: + html.append(self._get_javascript()) + + # Always add truncation message if needed (independent of styles) + if has_more and self.show_truncation_message: + html.append("
Data truncated due to size.
") + + return "\n".join(html) + + def format_str( + self, + batches: list, + schema: Any, + has_more: bool = False, + table_uuid: str | None = None, + ) -> str: + """Format record batches as a string. + + This method is used by DataFrame's __repr__ implementation and can be + called directly when string rendering is needed. + + Args: + batches: List of Arrow RecordBatch objects + schema: Arrow Schema object + has_more: Whether there are more batches not shown + table_uuid: Unique ID for the table, used for JavaScript interactions + + Returns: + String representation of the data + + Raises: + TypeError: If schema is invalid and no batches are provided + """ + return DataFrameInternal.default_str_repr(batches, schema, has_more, table_uuid) + + def _build_html_header(self) -> list[str]: + """Build the HTML header with CSS styles.""" + html = [] + html.append("") + return html + + def _build_table_container_start(self) -> list[str]: + """Build the opening tags for the table container.""" + html = [] + html.append( + f'
' + ) + html.append('') + return html + + def _build_table_header(self, schema: Any) -> list[str]: + """Build the HTML table header with column names.""" + html = [] + html.append("") + html.append("") + for field in schema: + if self._custom_header_builder: + html.append(self._custom_header_builder(field)) + else: + html.append( + f"" + ) + html.append("") + html.append("") + return html + + def _build_table_body(self, batches: list, table_uuid: str) -> list[str]: + """Build the HTML table body with data rows.""" + html = [] + html.append("") + + row_count = 0 + for batch in batches: + for row_idx in range(batch.num_rows): + row_count += 1 + html.append("") + + for col_idx, column in enumerate(batch.columns): + # Get the raw value from the column + raw_value = self._get_cell_value(column, row_idx) + + # Always check for type formatters first to format the value + formatted_value = self._format_cell_value(raw_value) + + # Then apply either custom cell builder or standard cell formatting + if self._custom_cell_builder: + # Pass both the raw value and formatted value to let the + # builder decide + cell_html = self._custom_cell_builder( + raw_value, row_count, col_idx, table_uuid + ) + html.append(cell_html) + else: + # Standard cell formatting with formatted value + if ( + len(str(raw_value)) > self.max_cell_length + and self.enable_cell_expansion + ): + cell_html = self._build_expandable_cell( + formatted_value, row_count, col_idx, table_uuid + ) + else: + cell_html = self._build_regular_cell(formatted_value) + html.append(cell_html) + + html.append("") + + html.append("") + return html + + def _get_cell_value(self, column: Any, row_idx: int) -> Any: + """Extract a cell value from a column. + + Args: + column: Arrow array + row_idx: Row index + + Returns: + The raw cell value + """ + try: + value = column[row_idx] + + if hasattr(value, "as_py"): + return value.as_py() + except (AttributeError, TypeError): + pass + else: + return value + + def _format_cell_value(self, value: Any) -> str: + """Format a cell value for display. + + Uses registered type formatters if available. + + Args: + value: The cell value to format + + Returns: + Formatted cell value as string + """ + # Check for custom type formatters + for type_cls, formatter in self._type_formatters.items(): + if isinstance(value, type_cls): + return formatter(value) + + # If no formatter matched, return string representation + return str(value) + + def _build_expandable_cell( + self, formatted_value: str, row_count: int, col_idx: int, table_uuid: str + ) -> str: + """Build an expandable cell for long content.""" + short_value = str(formatted_value)[: self.max_cell_length] + return ( + f"" + ) + + def _build_regular_cell(self, formatted_value: str) -> str: + """Build a regular table cell.""" + return ( + f"" + ) + + def _build_html_footer(self, has_more: bool) -> list[str]: + """Build the HTML footer with JavaScript and messages.""" + html = [] + + # Add JavaScript for interactivity only if cell expansion is enabled + # and we're not using the shared styles approach + if self.enable_cell_expansion and not self.use_shared_styles: + html.append(self._get_javascript()) + + # Add truncation message if needed + if has_more and self.show_truncation_message: + html.append("
Data truncated due to size.
") + + return html + + def _get_default_css(self) -> str: + """Get default CSS styles for the HTML table.""" + return """ + .expandable-container { + display: inline-block; + max-width: 200px; + } + .expandable { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + display: block; + } + .full-text { + display: none; + white-space: normal; + } + .expand-btn { + cursor: pointer; + color: blue; + text-decoration: underline; + border: none; + background: none; + font-size: inherit; + display: block; + margin-top: 5px; + } + """ + + def _get_javascript(self) -> str: + """Get JavaScript code for interactive elements.""" + return """ + + """ + + +class FormatterManager: + """Manager class for the global DataFrame HTML formatter instance.""" + + _default_formatter: DataFrameHtmlFormatter = DataFrameHtmlFormatter() + + @classmethod + def set_formatter(cls, formatter: DataFrameHtmlFormatter) -> None: + """Set the global DataFrame HTML formatter. + + Args: + formatter: The formatter instance to use globally + """ + cls._default_formatter = formatter + _refresh_formatter_reference() + + @classmethod + def get_formatter(cls) -> DataFrameHtmlFormatter: + """Get the current global DataFrame HTML formatter. + + Returns: + The global HTML formatter instance + """ + return cls._default_formatter + + +def get_formatter() -> DataFrameHtmlFormatter: + """Get the current global DataFrame HTML formatter. + + This function is used by the DataFrame._repr_html_ implementation to access + the shared formatter instance. It can also be used directly when custom + HTML rendering is needed. + + Returns: + The global HTML formatter instance + + Example: + >>> from datafusion.html_formatter import get_formatter + >>> formatter = get_formatter() + >>> formatter.max_cell_length = 50 # Increase cell length + """ + return FormatterManager.get_formatter() + + +def set_formatter(formatter: DataFrameHtmlFormatter) -> None: + """Set the global DataFrame HTML formatter. + + Args: + formatter: The formatter instance to use globally + + Example: + >>> from datafusion.html_formatter import get_formatter, set_formatter + >>> custom_formatter = DataFrameHtmlFormatter(max_cell_length=100) + >>> set_formatter(custom_formatter) + """ + FormatterManager.set_formatter(formatter) + + +def configure_formatter(**kwargs: Any) -> None: + """Configure the global DataFrame HTML formatter. + + This function creates a new formatter with the provided configuration + and sets it as the global formatter for all DataFrames. + + Args: + **kwargs: Formatter configuration parameters like max_cell_length, + max_width, max_height, enable_cell_expansion, etc. + + Raises: + ValueError: If any invalid parameters are provided + + Example: + >>> from datafusion.html_formatter import configure_formatter + >>> configure_formatter( + ... max_cell_length=50, + ... max_height=500, + ... enable_cell_expansion=True, + ... use_shared_styles=True + ... ) + """ + # Valid parameters accepted by DataFrameHtmlFormatter + valid_params = { + "max_cell_length", + "max_width", + "max_height", + "max_memory_bytes", + "min_rows_display", + "repr_rows", + "enable_cell_expansion", + "custom_css", + "show_truncation_message", + "style_provider", + "use_shared_styles", + } + + # Check for invalid parameters + invalid_params = set(kwargs) - valid_params + if invalid_params: + msg = ( + f"Invalid formatter parameters: {', '.join(invalid_params)}. " + f"Valid parameters are: {', '.join(valid_params)}" + ) + raise ValueError(msg) + + # Create and set formatter with validated parameters + set_formatter(DataFrameHtmlFormatter(**kwargs)) + + +def reset_formatter() -> None: + """Reset the global DataFrame HTML formatter to default settings. + + This function creates a new formatter with default configuration + and sets it as the global formatter for all DataFrames. + + Example: + >>> from datafusion.html_formatter import reset_formatter + >>> reset_formatter() # Reset formatter to default settings + """ + formatter = DataFrameHtmlFormatter() + # Reset the styles_loaded flag to ensure styles will be reloaded + DataFrameHtmlFormatter._styles_loaded = False + set_formatter(formatter) + + +def reset_styles_loaded_state() -> None: + """Reset the styles loaded state to force reloading of styles. + + This can be useful when switching between notebook sessions or + when styles need to be refreshed. + + Example: + >>> from datafusion.html_formatter import reset_styles_loaded_state + >>> reset_styles_loaded_state() # Force styles to reload in next render + """ + DataFrameHtmlFormatter._styles_loaded = False + + +def _refresh_formatter_reference() -> None: + """Refresh formatter reference in any modules using it. + + This helps ensure that changes to the formatter are reflected in existing + DataFrames that might be caching the formatter reference. + """ + # This is a no-op but signals modules to refresh their reference diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 9e58873d0..e785cab06 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -435,6 +435,20 @@ def literal(value: Any) -> Expr: value = pa.scalar(value) return Expr(expr_internal.RawExpr.literal(value)) + @staticmethod + def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: + """Creates a new expression representing a scalar value with metadata. + + Args: + value: A valid PyArrow scalar value or easily castable to one. + metadata: Metadata to attach to the expression. + """ + if isinstance(value, str): + value = pa.scalar(value, type=pa.string_view()) + value = value if isinstance(value, pa.Scalar) else pa.scalar(value) + + return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata)) + @staticmethod def string_literal(value: str) -> Expr: """Creates a new expression representing a UTF8 literal value. @@ -1172,6 +1186,10 @@ def __init__( end_bound = end_bound.cast(pa.uint64()) self.window_frame = expr_internal.WindowFrame(units, start_bound, end_bound) + def __repr__(self) -> str: + """Print a string representation of the window frame.""" + return self.window_frame.__repr__() + def get_frame_units(self) -> str: """Returns the window frame units for the bounds.""" return self.window_frame.get_frame_units() diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py index 12a7e4553..37558b913 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/html_formatter.py @@ -14,698 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""HTML formatting utilities for DataFusion DataFrames.""" -from __future__ import annotations +"""Deprecated module for dataframe formatting.""" -from typing import ( - Any, - Callable, - Optional, - Protocol, - runtime_checkable, -) - - -def _validate_positive_int(value: Any, param_name: str) -> None: - """Validate that a parameter is a positive integer. - - Args: - value: The value to validate - param_name: Name of the parameter (used in error message) - - Raises: - ValueError: If the value is not a positive integer - """ - if not isinstance(value, int) or value <= 0: - msg = f"{param_name} must be a positive integer" - raise ValueError(msg) - - -def _validate_bool(value: Any, param_name: str) -> None: - """Validate that a parameter is a boolean. - - Args: - value: The value to validate - param_name: Name of the parameter (used in error message) - - Raises: - TypeError: If the value is not a boolean - """ - if not isinstance(value, bool): - msg = f"{param_name} must be a boolean" - raise TypeError(msg) - - -@runtime_checkable -class CellFormatter(Protocol): - """Protocol for cell value formatters.""" - - def __call__(self, value: Any) -> str: - """Format a cell value to string representation.""" - ... - - -@runtime_checkable -class StyleProvider(Protocol): - """Protocol for HTML style providers.""" - - def get_cell_style(self) -> str: - """Get the CSS style for table cells.""" - ... - - def get_header_style(self) -> str: - """Get the CSS style for header cells.""" - ... - - -class DefaultStyleProvider: - """Default implementation of StyleProvider.""" - - def get_cell_style(self) -> str: - """Get the CSS style for table cells. - - Returns: - CSS style string - """ - return ( - "border: 1px solid black; padding: 8px; text-align: left; " - "white-space: nowrap;" - ) - - def get_header_style(self) -> str: - """Get the CSS style for header cells. - - Returns: - CSS style string - """ - return ( - "border: 1px solid black; padding: 8px; text-align: left; " - "background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; " - "max-width: fit-content;" - ) - - -class DataFrameHtmlFormatter: - """Configurable HTML formatter for DataFusion DataFrames. - - This class handles the HTML rendering of DataFrames for display in - Jupyter notebooks and other rich display contexts. - - This class supports extension through composition. Key extension points: - - Provide a custom StyleProvider for styling cells and headers - - Register custom formatters for specific types - - Provide custom cell builders for specialized cell rendering - - Args: - max_cell_length: Maximum characters to display in a cell before truncation - max_width: Maximum width of the HTML table in pixels - max_height: Maximum height of the HTML table in pixels - max_memory_bytes: Maximum memory in bytes for rendered data (default: 2MB) - min_rows_display: Minimum number of rows to display - repr_rows: Default number of rows to display in repr output - enable_cell_expansion: Whether to add expand/collapse buttons for long cell - values - custom_css: Additional CSS to include in the HTML output - show_truncation_message: Whether to display a message when data is truncated - style_provider: Custom provider for cell and header styles - use_shared_styles: Whether to load styles and scripts only once per notebook - session - """ - - # Class variable to track if styles have been loaded in the notebook - _styles_loaded = False - - def __init__( - self, - max_cell_length: int = 25, - max_width: int = 1000, - max_height: int = 300, - max_memory_bytes: int = 2 * 1024 * 1024, # 2 MB - min_rows_display: int = 20, - repr_rows: int = 10, - enable_cell_expansion: bool = True, - custom_css: Optional[str] = None, - show_truncation_message: bool = True, - style_provider: Optional[StyleProvider] = None, - use_shared_styles: bool = True, - ) -> None: - """Initialize the HTML formatter. - - Parameters - ---------- - max_cell_length : int, default 25 - Maximum length of cell content before truncation. - max_width : int, default 1000 - Maximum width of the displayed table in pixels. - max_height : int, default 300 - Maximum height of the displayed table in pixels. - max_memory_bytes : int, default 2097152 (2MB) - Maximum memory in bytes for rendered data. - min_rows_display : int, default 20 - Minimum number of rows to display. - repr_rows : int, default 10 - Default number of rows to display in repr output. - enable_cell_expansion : bool, default True - Whether to allow cells to expand when clicked. - custom_css : str, optional - Custom CSS to apply to the HTML table. - show_truncation_message : bool, default True - Whether to show a message indicating that content has been truncated. - style_provider : StyleProvider, optional - Provider of CSS styles for the HTML table. If None, DefaultStyleProvider - is used. - use_shared_styles : bool, default True - Whether to use shared styles across multiple tables. - - Raises: - ------ - ValueError - If max_cell_length, max_width, max_height, max_memory_bytes, - min_rows_display, or repr_rows is not a positive integer. - TypeError - If enable_cell_expansion, show_truncation_message, or use_shared_styles is - not a boolean, - or if custom_css is provided but is not a string, - or if style_provider is provided but does not implement the StyleProvider - protocol. - """ - # Validate numeric parameters - _validate_positive_int(max_cell_length, "max_cell_length") - _validate_positive_int(max_width, "max_width") - _validate_positive_int(max_height, "max_height") - _validate_positive_int(max_memory_bytes, "max_memory_bytes") - _validate_positive_int(min_rows_display, "min_rows_display") - _validate_positive_int(repr_rows, "repr_rows") - - # Validate boolean parameters - _validate_bool(enable_cell_expansion, "enable_cell_expansion") - _validate_bool(show_truncation_message, "show_truncation_message") - _validate_bool(use_shared_styles, "use_shared_styles") - - # Validate custom_css - if custom_css is not None and not isinstance(custom_css, str): - msg = "custom_css must be None or a string" - raise TypeError(msg) - - # Validate style_provider - if style_provider is not None and not isinstance(style_provider, StyleProvider): - msg = "style_provider must implement the StyleProvider protocol" - raise TypeError(msg) - - self.max_cell_length = max_cell_length - self.max_width = max_width - self.max_height = max_height - self.max_memory_bytes = max_memory_bytes - self.min_rows_display = min_rows_display - self.repr_rows = repr_rows - self.enable_cell_expansion = enable_cell_expansion - self.custom_css = custom_css - self.show_truncation_message = show_truncation_message - self.style_provider = style_provider or DefaultStyleProvider() - self.use_shared_styles = use_shared_styles - # Registry for custom type formatters - self._type_formatters: dict[type, CellFormatter] = {} - # Custom cell builders - self._custom_cell_builder: Optional[Callable[[Any, int, int, str], str]] = None - self._custom_header_builder: Optional[Callable[[Any], str]] = None - - def register_formatter(self, type_class: type, formatter: CellFormatter) -> None: - """Register a custom formatter for a specific data type. - - Args: - type_class: The type to register a formatter for - formatter: Function that takes a value of the given type and returns - a formatted string - """ - self._type_formatters[type_class] = formatter - - def set_custom_cell_builder( - self, builder: Callable[[Any, int, int, str], str] - ) -> None: - """Set a custom cell builder function. - - Args: - builder: Function that takes (value, row, col, table_id) and returns HTML - """ - self._custom_cell_builder = builder - - def set_custom_header_builder(self, builder: Callable[[Any], str]) -> None: - """Set a custom header builder function. - - Args: - builder: Function that takes a field and returns HTML - """ - self._custom_header_builder = builder - - @classmethod - def is_styles_loaded(cls) -> bool: - """Check if HTML styles have been loaded in the current session. - - This method is primarily intended for debugging UI rendering issues - related to style loading. - - Returns: - True if styles have been loaded, False otherwise - - Example: - >>> from datafusion.html_formatter import DataFrameHtmlFormatter - >>> DataFrameHtmlFormatter.is_styles_loaded() - False - """ - return cls._styles_loaded - - def format_html( - self, - batches: list, - schema: Any, - has_more: bool = False, - table_uuid: str | None = None, - ) -> str: - """Format record batches as HTML. - - This method is used by DataFrame's _repr_html_ implementation and can be - called directly when custom HTML rendering is needed. - - Args: - batches: List of Arrow RecordBatch objects - schema: Arrow Schema object - has_more: Whether there are more batches not shown - table_uuid: Unique ID for the table, used for JavaScript interactions - - Returns: - HTML string representation of the data - - Raises: - TypeError: If schema is invalid and no batches are provided - """ - if not batches: - return "No data to display" - - # Validate schema - if schema is None or not hasattr(schema, "__iter__"): - msg = "Schema must be provided" - raise TypeError(msg) - - # Generate a unique ID if none provided - table_uuid = table_uuid or f"df-{id(batches)}" - - # Build HTML components - html = [] - - # Only include styles and scripts if: - # 1. Not using shared styles, OR - # 2. Using shared styles but they haven't been loaded yet - include_styles = ( - not self.use_shared_styles or not DataFrameHtmlFormatter._styles_loaded - ) - - if include_styles: - html.extend(self._build_html_header()) - # If we're using shared styles, mark them as loaded - if self.use_shared_styles: - DataFrameHtmlFormatter._styles_loaded = True - - html.extend(self._build_table_container_start()) - - # Add table header and body - html.extend(self._build_table_header(schema)) - html.extend(self._build_table_body(batches, table_uuid)) - - html.append("
" + f"{field.name}
" + f"
" + "" + "" + f"{formatted_value}" + f"" + f"
" + f"
{formatted_value}
") - html.append("
") - - # Add footer (JavaScript and messages) - if include_styles and self.enable_cell_expansion: - html.append(self._get_javascript()) - - # Always add truncation message if needed (independent of styles) - if has_more and self.show_truncation_message: - html.append("
Data truncated due to size.
") - - return "\n".join(html) - - def _build_html_header(self) -> list[str]: - """Build the HTML header with CSS styles.""" - html = [] - html.append("") - return html +import warnings - def _build_table_container_start(self) -> list[str]: - """Build the opening tags for the table container.""" - html = [] - html.append( - f'
' - ) - html.append('') - return html +from datafusion.dataframe_formatter import * # noqa: F403 - def _build_table_header(self, schema: Any) -> list[str]: - """Build the HTML table header with column names.""" - html = [] - html.append("") - html.append("") - for field in schema: - if self._custom_header_builder: - html.append(self._custom_header_builder(field)) - else: - html.append( - f"" - ) - html.append("") - html.append("") - return html - - def _build_table_body(self, batches: list, table_uuid: str) -> list[str]: - """Build the HTML table body with data rows.""" - html = [] - html.append("") - - row_count = 0 - for batch in batches: - for row_idx in range(batch.num_rows): - row_count += 1 - html.append("") - - for col_idx, column in enumerate(batch.columns): - # Get the raw value from the column - raw_value = self._get_cell_value(column, row_idx) - - # Always check for type formatters first to format the value - formatted_value = self._format_cell_value(raw_value) - - # Then apply either custom cell builder or standard cell formatting - if self._custom_cell_builder: - # Pass both the raw value and formatted value to let the - # builder decide - cell_html = self._custom_cell_builder( - raw_value, row_count, col_idx, table_uuid - ) - html.append(cell_html) - else: - # Standard cell formatting with formatted value - if ( - len(str(raw_value)) > self.max_cell_length - and self.enable_cell_expansion - ): - cell_html = self._build_expandable_cell( - formatted_value, row_count, col_idx, table_uuid - ) - else: - cell_html = self._build_regular_cell(formatted_value) - html.append(cell_html) - - html.append("") - - html.append("") - return html - - def _get_cell_value(self, column: Any, row_idx: int) -> Any: - """Extract a cell value from a column. - - Args: - column: Arrow array - row_idx: Row index - - Returns: - The raw cell value - """ - try: - value = column[row_idx] - - if hasattr(value, "as_py"): - return value.as_py() - except (AttributeError, TypeError): - pass - else: - return value - - def _format_cell_value(self, value: Any) -> str: - """Format a cell value for display. - - Uses registered type formatters if available. - - Args: - value: The cell value to format - - Returns: - Formatted cell value as string - """ - # Check for custom type formatters - for type_cls, formatter in self._type_formatters.items(): - if isinstance(value, type_cls): - return formatter(value) - - # If no formatter matched, return string representation - return str(value) - - def _build_expandable_cell( - self, formatted_value: str, row_count: int, col_idx: int, table_uuid: str - ) -> str: - """Build an expandable cell for long content.""" - short_value = str(formatted_value)[: self.max_cell_length] - return ( - f"" - ) - - def _build_regular_cell(self, formatted_value: str) -> str: - """Build a regular table cell.""" - return ( - f"" - ) - - def _build_html_footer(self, has_more: bool) -> list[str]: - """Build the HTML footer with JavaScript and messages.""" - html = [] - - # Add JavaScript for interactivity only if cell expansion is enabled - # and we're not using the shared styles approach - if self.enable_cell_expansion and not self.use_shared_styles: - html.append(self._get_javascript()) - - # Add truncation message if needed - if has_more and self.show_truncation_message: - html.append("
Data truncated due to size.
") - - return html - - def _get_default_css(self) -> str: - """Get default CSS styles for the HTML table.""" - return """ - .expandable-container { - display: inline-block; - max-width: 200px; - } - .expandable { - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; - display: block; - } - .full-text { - display: none; - white-space: normal; - } - .expand-btn { - cursor: pointer; - color: blue; - text-decoration: underline; - border: none; - background: none; - font-size: inherit; - display: block; - margin-top: 5px; - } - """ - - def _get_javascript(self) -> str: - """Get JavaScript code for interactive elements.""" - return """ - - """ - - -class FormatterManager: - """Manager class for the global DataFrame HTML formatter instance.""" - - _default_formatter: DataFrameHtmlFormatter = DataFrameHtmlFormatter() - - @classmethod - def set_formatter(cls, formatter: DataFrameHtmlFormatter) -> None: - """Set the global DataFrame HTML formatter. - - Args: - formatter: The formatter instance to use globally - """ - cls._default_formatter = formatter - _refresh_formatter_reference() - - @classmethod - def get_formatter(cls) -> DataFrameHtmlFormatter: - """Get the current global DataFrame HTML formatter. - - Returns: - The global HTML formatter instance - """ - return cls._default_formatter - - -def get_formatter() -> DataFrameHtmlFormatter: - """Get the current global DataFrame HTML formatter. - - This function is used by the DataFrame._repr_html_ implementation to access - the shared formatter instance. It can also be used directly when custom - HTML rendering is needed. - - Returns: - The global HTML formatter instance - - Example: - >>> from datafusion.html_formatter import get_formatter - >>> formatter = get_formatter() - >>> formatter.max_cell_length = 50 # Increase cell length - """ - return FormatterManager.get_formatter() - - -def set_formatter(formatter: DataFrameHtmlFormatter) -> None: - """Set the global DataFrame HTML formatter. - - Args: - formatter: The formatter instance to use globally - - Example: - >>> from datafusion.html_formatter import get_formatter, set_formatter - >>> custom_formatter = DataFrameHtmlFormatter(max_cell_length=100) - >>> set_formatter(custom_formatter) - """ - FormatterManager.set_formatter(formatter) - - -def configure_formatter(**kwargs: Any) -> None: - """Configure the global DataFrame HTML formatter. - - This function creates a new formatter with the provided configuration - and sets it as the global formatter for all DataFrames. - - Args: - **kwargs: Formatter configuration parameters like max_cell_length, - max_width, max_height, enable_cell_expansion, etc. - - Raises: - ValueError: If any invalid parameters are provided - - Example: - >>> from datafusion.html_formatter import configure_formatter - >>> configure_formatter( - ... max_cell_length=50, - ... max_height=500, - ... enable_cell_expansion=True, - ... use_shared_styles=True - ... ) - """ - # Valid parameters accepted by DataFrameHtmlFormatter - valid_params = { - "max_cell_length", - "max_width", - "max_height", - "max_memory_bytes", - "min_rows_display", - "repr_rows", - "enable_cell_expansion", - "custom_css", - "show_truncation_message", - "style_provider", - "use_shared_styles", - } - - # Check for invalid parameters - invalid_params = set(kwargs) - valid_params - if invalid_params: - msg = ( - f"Invalid formatter parameters: {', '.join(invalid_params)}. " - f"Valid parameters are: {', '.join(valid_params)}" - ) - raise ValueError(msg) - - # Create and set formatter with validated parameters - set_formatter(DataFrameHtmlFormatter(**kwargs)) - - -def reset_formatter() -> None: - """Reset the global DataFrame HTML formatter to default settings. - - This function creates a new formatter with default configuration - and sets it as the global formatter for all DataFrames. - - Example: - >>> from datafusion.html_formatter import reset_formatter - >>> reset_formatter() # Reset formatter to default settings - """ - formatter = DataFrameHtmlFormatter() - # Reset the styles_loaded flag to ensure styles will be reloaded - DataFrameHtmlFormatter._styles_loaded = False - set_formatter(formatter) - - -def reset_styles_loaded_state() -> None: - """Reset the styles loaded state to force reloading of styles. - - This can be useful when switching between notebook sessions or - when styles need to be refreshed. - - Example: - >>> from datafusion.html_formatter import reset_styles_loaded_state - >>> reset_styles_loaded_state() # Force styles to reload in next render - """ - DataFrameHtmlFormatter._styles_loaded = False - - -def _refresh_formatter_reference() -> None: - """Refresh formatter reference in any modules using it. - - This helps ensure that changes to the formatter are reflected in existing - DataFrames that might be caching the formatter reference. - """ - # This is a no-op but signals modules to refresh their reference +warnings.warn( + "The module 'html_formatter' is deprecated and will be removed in the next release." + "Please use 'dataframe_formatter' instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index ef5ebf96f..551e20a6f 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -34,7 +34,7 @@ def read_parquet( path: str | pathlib.Path, - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, parquet_pruning: bool = True, file_extension: str = ".parquet", skip_metadata: bool = True, @@ -83,7 +83,7 @@ def read_json( schema: pa.Schema | None = None, schema_infer_max_records: int = 1000, file_extension: str = ".json", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a line-delimited JSON data source. @@ -124,7 +124,7 @@ def read_csv( delimiter: str = ",", schema_infer_max_records: int = 1000, file_extension: str = ".csv", - table_partition_cols: list[tuple[str, str]] | None = None, + table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, ) -> DataFrame: """Read a CSV data source. @@ -171,7 +171,7 @@ def read_csv( def read_avro( path: str | pathlib.Path, schema: pa.Schema | None = None, - file_partition_cols: list[tuple[str, str]] | None = None, + file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_extension: str = ".avro", ) -> DataFrame: """Create a :py:class:`DataFrame` for reading Avro data source. diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 9ec3679a6..bd686acbb 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload import pyarrow as pa @@ -77,6 +77,12 @@ def __str__(self) -> str: return self.name.lower() +class ScalarUDFExportable(Protocol): + """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" + + def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). @@ -96,12 +102,19 @@ def __init__( See helper method :py:func:`udf` for argument details. """ + if hasattr(func, "__datafusion_scalar_udf__"): + self._udf = df_internal.ScalarUDF.from_pycapsule(func) + return if isinstance(input_types, pa.DataType): input_types = [input_types] self._udf = df_internal.ScalarUDF( name, func, input_types, return_type, str(volatility) ) + def __repr__(self) -> str: + """Print a string representation of the Scalar UDF.""" + return self._udf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDF. @@ -130,6 +143,10 @@ def udf( name: Optional[str] = None, ) -> ScalarUDF: ... + @overload + @staticmethod + def udf(func: ScalarUDFExportable) -> ScalarUDF: ... + @staticmethod def udf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Function (UDF). @@ -143,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 Args: func (Callable, optional): Only needed when calling as a function. - Skip this argument when using ``udf`` as a decorator. + Skip this argument when using `udf` as a decorator. If you have a Rust + backed ScalarUDF within a PyCapsule, you can pass this parameter + and ignore the rest. They will be determined directly from the + underlying function. See the online documentation for more information. input_types (list[pa.DataType]): The data types of the arguments to ``func``. This list must be of the same length as the number of arguments. @@ -211,12 +231,31 @@ def wrapper(*args: Any, **kwargs: Any): return decorator + if hasattr(args[0], "__datafusion_scalar_udf__"): + return ScalarUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF: + """Create a Scalar UDF from ScalarUDF PyCapsule object. + + This function will instantiate a Scalar UDF that uses a DataFusion + ScalarUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return ScalarUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @@ -238,6 +277,12 @@ def evaluate(self) -> pa.Scalar: """Return the resultant value.""" +class AggregateUDFExportable(Protocol): + """Type hint for object that has __datafusion_aggregate_udf__ PyCapsule.""" + + def __datafusion_aggregate_udf__(self) -> object: ... # noqa: D105 + + class AggregateUDF: """Class for performing scalar user-defined functions (UDF). @@ -259,6 +304,9 @@ def __init__( See :py:func:`udaf` for a convenience function and argument descriptions. """ + if hasattr(accumulator, "__datafusion_aggregate_udf__"): + self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator) + return self._udaf = df_internal.AggregateUDF( name, accumulator, @@ -268,6 +316,10 @@ def __init__( str(volatility), ) + def __repr__(self) -> str: + """Print a string representation of the Aggregate UDF.""" + return self._udaf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDAF. @@ -299,7 +351,7 @@ def udaf( ) -> AggregateUDF: ... @staticmethod - def udaf(*args: Any, **kwargs: Any): # noqa: D417 + def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Aggregate Function (UDAF). This class allows you to define an aggregate function that can be used in @@ -356,6 +408,10 @@ def udf4() -> Summarize: Args: accum: The accumulator python function. Only needed when calling as a function. Skip this argument when using ``udaf`` as a decorator. + If you have a Rust backed AggregateUDF within a PyCapsule, you can + pass this parameter and ignore the rest. They will be determined + directly from the underlying function. See the online documentation + for more information. input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. @@ -414,12 +470,32 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + if hasattr(args[0], "__datafusion_aggregate_udf__"): + return AggregateUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF: + """Create an Aggregate UDF from AggregateUDF PyCapsule object. + + This function will instantiate a Aggregate UDF that uses a DataFusion + AggregateUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return AggregateUDF( + name=name, + accumulator=func, + input_types=None, + return_type=None, + state_type=None, + volatility=None, + ) + class WindowEvaluator: """Evaluator class for user-defined window functions (UDWF). @@ -580,6 +656,12 @@ def include_rank(self) -> bool: return False +class WindowUDFExportable(Protocol): + """Type hint for object that has __datafusion_window_udf__ PyCapsule.""" + + def __datafusion_window_udf__(self) -> object: ... # noqa: D105 + + class WindowUDF: """Class for performing window user-defined functions (UDF). @@ -600,10 +682,17 @@ def __init__( See :py:func:`udwf` for a convenience function and argument descriptions. """ + if hasattr(func, "__datafusion_window_udf__"): + self._udwf = df_internal.WindowUDF.from_pycapsule(func) + return self._udwf = df_internal.WindowUDF( name, func, input_types, return_type, str(volatility) ) + def __repr__(self) -> str: + """Print a string representation of the Window UDF.""" + return self._udwf.__repr__() + def __call__(self, *args: Expr) -> Expr: """Execute the UDWF. @@ -671,7 +760,10 @@ def biased_numbers() -> BiasedNumbers: Args: func: Only needed when calling as a function. Skip this argument when - using ``udwf`` as a decorator. + using ``udwf`` as a decorator. If you have a Rust backed WindowUDF + within a PyCapsule, you can pass this parameter and ignore the rest. + They will be determined directly from the underlying function. See + the online documentation for more information. input_types: The data types of the arguments. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. @@ -680,6 +772,9 @@ def biased_numbers() -> BiasedNumbers: Returns: A user-defined window function that can be used in window function calls. """ + if hasattr(args[0], "__datafusion_window_udf__"): + return WindowUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return WindowUDF._create_window_udf(*args, **kwargs) @@ -747,6 +842,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + @staticmethod + def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: + """Create a Window UDF from WindowUDF PyCapsule object. + + This function will instantiate a Window UDF that uses a DataFusion + WindowUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return WindowUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class TableFunction: """Class for performing user-defined table functions (UDTF). diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 23b328458..21b0a3e0a 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. +import datafusion as dfn import pyarrow as pa +import pyarrow.dataset as ds import pytest +from datafusion import SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -27,7 +30,7 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ["public"] + assert default.names() == {"public"} for db in [default.database("public"), default.database()]: assert db.names() == {"csv1", "csv", "csv2"} @@ -41,3 +44,100 @@ def test_basic(ctx, database): pa.field("float", pa.float64(), nullable=True), ] ) + + +class CustomTableProvider: + def __init__(self): + pass + + +def create_dataset() -> pa.dataset.Dataset: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ds.dataset([batch]) + + +class CustomSchemaProvider: + def __init__(self): + self.tables = {"table1": create_dataset()} + + def table_names(self) -> set[str]: + return set(self.tables.keys()) + + def register_table(self, name: str, table: Table): + self.tables[name] = table + + def deregister_table(self, name, cascade: bool = True): + del self.tables[name] + + +class CustomCatalogProvider: + def __init__(self): + self.schemas = {"my_schema": CustomSchemaProvider()} + + def schema_names(self) -> set[str]: + return set(self.schemas.keys()) + + def schema(self, name: str): + return self.schemas[name] + + def register_schema(self, name: str, schema: dfn.catalog.Schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade: bool): + del self.schemas[name] + + +def test_python_catalog_provider(ctx: SessionContext): + ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) + + # Check the default catalog provider + assert ctx.catalog("datafusion").names() == {"public"} + + my_catalog = ctx.catalog("my_catalog") + assert my_catalog.names() == {"my_schema"} + + my_catalog.register_schema("second_schema", CustomSchemaProvider()) + assert my_catalog.schema_names() == {"my_schema", "second_schema"} + + my_catalog.deregister_schema("my_schema") + assert my_catalog.schema_names() == {"second_schema"} + + +def test_python_schema_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.deregister_schema("public") + + catalog.register_schema("test_schema1", CustomSchemaProvider()) + assert catalog.names() == {"test_schema1"} + + catalog.register_schema("test_schema2", CustomSchemaProvider()) + catalog.deregister_schema("test_schema1") + assert catalog.names() == {"test_schema2"} + + +def test_python_table_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.register_schema("custom_schema", CustomSchemaProvider()) + schema = catalog.schema("custom_schema") + + assert schema.table_names() == {"table1"} + + schema.deregister_table("table1") + schema.register_table("table2", create_dataset()) + assert schema.table_names() == {"table2"} + + # Use the default schema instead of our custom schema + + schema = catalog.schema() + + schema.register_table("table3", create_dataset()) + assert schema.table_names() == {"table3"} + + schema.deregister_table("table3") + schema.register_table("table4", create_dataset()) + assert schema.table_names() == {"table4"} diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 64220ce9c..3b816bc85 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -27,6 +27,8 @@ import pytest from datafusion import ( DataFrame, + ParquetColumnOptions, + ParquetWriterOptions, SessionContext, WindowFrame, column, @@ -35,14 +37,14 @@ from datafusion import ( functions as f, ) -from datafusion.expr import Window -from datafusion.html_formatter import ( +from datafusion.dataframe_formatter import ( DataFrameHtmlFormatter, configure_formatter, get_formatter, reset_formatter, reset_styles_loaded_state, ) +from datafusion.expr import Window from pyarrow.csv import write_csv MB = 1024 * 1024 @@ -66,6 +68,21 @@ def df(): return ctx.from_arrow(batch) +@pytest.fixture +def large_df(): + ctx = SessionContext() + + rows = 100000 + data = { + "a": list(range(rows)), + "b": [f"s-{i}" for i in range(rows)], + "c": [float(i + 0.1) for i in range(rows)], + } + batch = pa.record_batch(data) + + return ctx.from_arrow(batch) + + @pytest.fixture def struct_df(): ctx = SessionContext() @@ -1632,6 +1649,395 @@ def test_write_compressed_parquet_default_compression_level(df, tmp_path, compre df.write_parquet(str(path), compression=compression) +def test_write_parquet_with_options_default_compression(df, tmp_path): + """Test that the default compression is ZSTD.""" + df.write_parquet(tmp_path) + + for file in tmp_path.rglob("*.parquet"): + metadata = pq.ParquetFile(file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["compression"].lower() == "zstd" + + +@pytest.mark.parametrize( + "compression", + ["gzip(6)", "brotli(7)", "zstd(15)", "snappy", "uncompressed"], +) +def test_write_parquet_with_options_compression(df, tmp_path, compression): + import re + + path = tmp_path + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + # test that the actual compression scheme is the one written + for _root, _dirs, files in os.walk(path): + for file in files: + if file.endswith(".parquet"): + metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict() + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["compression"].lower() == re.sub( + r"\(\d+\)", "", compression + ) + + result = pq.read_table(str(path)).to_pydict() + expected = df.to_pydict() + + assert result == expected + + +@pytest.mark.parametrize( + "compression", + ["gzip(12)", "brotli(15)", "zstd(23)"], +) +def test_write_parquet_with_options_wrong_compression_level(df, tmp_path, compression): + path = tmp_path + + with pytest.raises(Exception, match=r"valid compression range .*? exceeded."): + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + +@pytest.mark.parametrize("compression", ["wrong", "wrong(12)"]) +def test_write_parquet_with_options_invalid_compression(df, tmp_path, compression): + path = tmp_path + + with pytest.raises(Exception, match="Unknown or unsupported parquet compression"): + df.write_parquet_with_options( + str(path), ParquetWriterOptions(compression=compression) + ) + + +@pytest.mark.parametrize( + ("writer_version", "format_version"), + [("1.0", "1.0"), ("2.0", "2.6"), (None, "1.0")], +) +def test_write_parquet_with_options_writer_version( + df, tmp_path, writer_version, format_version +): + """Test the Parquet writer version. Note that writer_version=2.0 results in + format_version=2.6""" + if writer_version is None: + df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) + else: + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(writer_version=writer_version) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["format_version"] == format_version + + +@pytest.mark.parametrize("writer_version", ["1.2.3", "custom-version", "0"]) +def test_write_parquet_with_options_wrong_writer_version(df, tmp_path, writer_version): + """Test that invalid writer versions in Parquet throw an exception.""" + with pytest.raises( + Exception, match="Unknown or unsupported parquet writer version" + ): + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(writer_version=writer_version) + ) + + +@pytest.mark.parametrize("dictionary_enabled", [True, False, None]) +def test_write_parquet_with_options_dictionary_enabled( + df, tmp_path, dictionary_enabled +): + """Test enabling/disabling the dictionaries in Parquet.""" + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(dictionary_enabled=dictionary_enabled) + ) + # by default, the dictionary is enabled, so None results in True + result = dictionary_enabled if dictionary_enabled is not None else True + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["has_dictionary_page"] == result + + +@pytest.mark.parametrize( + ("statistics_enabled", "has_statistics"), + [("page", True), ("chunk", True), ("none", False), (None, True)], +) +def test_write_parquet_with_options_statistics_enabled( + df, tmp_path, statistics_enabled, has_statistics +): + """Test configuring the statistics in Parquet. In pyarrow we can only check for + column-level statistics, so "page" and "chunk" are tested in the same way.""" + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(statistics_enabled=statistics_enabled) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + if has_statistics: + assert col["statistics"] is not None + else: + assert col["statistics"] is None + + +@pytest.mark.parametrize("max_row_group_size", [1000, 5000, 10000, 100000]) +def test_write_parquet_with_options_max_row_group_size( + large_df, tmp_path, max_row_group_size +): + """Test configuring the max number of rows per group in Parquet. These test cases + guarantee that the number of rows for each row group is max_row_group_size, given + the total number of rows is a multiple of max_row_group_size.""" + large_df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(max_row_group_size=max_row_group_size) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + for row_group in metadata["row_groups"]: + assert row_group["num_rows"] == max_row_group_size + + +@pytest.mark.parametrize("created_by", ["datafusion", "datafusion-python", "custom"]) +def test_write_parquet_with_options_created_by(df, tmp_path, created_by): + """Test configuring the created by metadata in Parquet.""" + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(created_by=created_by)) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + assert metadata["created_by"] == created_by + + +@pytest.mark.parametrize("statistics_truncate_length", [5, 25, 50]) +def test_write_parquet_with_options_statistics_truncate_length( + df, tmp_path, statistics_truncate_length +): + """Test configuring the truncate limit in Parquet's row-group-level statistics.""" + ctx = SessionContext() + data = { + "a": [ + "a_the_quick_brown_fox_jumps_over_the_lazy_dog", + "m_the_quick_brown_fox_jumps_over_the_lazy_dog", + "z_the_quick_brown_fox_jumps_over_the_lazy_dog", + ], + "b": ["a_smaller", "m_smaller", "z_smaller"], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, + ParquetWriterOptions(statistics_truncate_length=statistics_truncate_length), + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + statistics = col["statistics"] + assert len(statistics["min"]) <= statistics_truncate_length + assert len(statistics["max"]) <= statistics_truncate_length + + +def test_write_parquet_with_options_default_encoding(tmp_path): + """Test that, by default, Parquet files are written with dictionary encoding. + Note that dictionary encoding is not used for boolean values, so it is not tested + here.""" + ctx = SessionContext() + data = { + "a": [1, 2, 3], + "b": ["1", "2", "3"], + "c": [1.01, 2.02, 3.03], + } + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options(tmp_path, ParquetWriterOptions()) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == ("PLAIN", "RLE", "RLE_DICTIONARY") + + +@pytest.mark.parametrize( + ("encoding", "data_types", "result"), + [ + ("plain", ["int", "float", "str", "bool"], ("PLAIN", "RLE")), + ("rle", ["bool"], ("RLE",)), + ("delta_binary_packed", ["int"], ("RLE", "DELTA_BINARY_PACKED")), + ("delta_length_byte_array", ["str"], ("RLE", "DELTA_LENGTH_BYTE_ARRAY")), + ("delta_byte_array", ["str"], ("RLE", "DELTA_BYTE_ARRAY")), + ("byte_stream_split", ["int", "float"], ("RLE", "BYTE_STREAM_SPLIT")), + ], +) +def test_write_parquet_with_options_encoding(tmp_path, encoding, data_types, result): + """Test different encodings in Parquet in their respective support column types.""" + ctx = SessionContext() + + data = {} + for data_type in data_types: + if data_type == "int": + data["int"] = [1, 2, 3] + elif data_type == "float": + data["float"] = [1.01, 2.02, 3.03] + elif data_type == "str": + data["str"] = ["a", "b", "c"] + elif data_type == "bool": + data["bool"] = [True, False, True] + + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, ParquetWriterOptions(encoding=encoding, dictionary_enabled=False) + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + assert col["encodings"] == result + + +@pytest.mark.parametrize("encoding", ["bit_packed"]) +def test_write_parquet_with_options_unsupported_encoding(df, tmp_path, encoding): + """Test that unsupported Parquet encodings do not work.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises(BaseException, match="Encoding .*? is not supported"): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +@pytest.mark.parametrize("encoding", ["non_existent", "unknown", "plain123"]) +def test_write_parquet_with_options_invalid_encoding(df, tmp_path, encoding): + """Test that invalid Parquet encodings do not work.""" + with pytest.raises(Exception, match="Unknown or unsupported parquet encoding"): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +@pytest.mark.parametrize("encoding", ["plain_dictionary", "rle_dictionary"]) +def test_write_parquet_with_options_dictionary_encoding_fallback( + df, tmp_path, encoding +): + """Test that the dictionary encoding cannot be used as fallback in Parquet.""" + # BaseException is used since this throws a Rust panic: https://github.com/PyO3/pyo3/issues/3519 + with pytest.raises( + BaseException, match="Dictionary encoding can not be used as fallback encoding" + ): + df.write_parquet_with_options(tmp_path, ParquetWriterOptions(encoding=encoding)) + + +def test_write_parquet_with_options_bloom_filter(df, tmp_path): + """Test Parquet files with and without (default) bloom filters. Since pyarrow does + not expose any information about bloom filters, the easiest way to confirm that they + are actually written is to compare the file size.""" + path_no_bloom_filter = tmp_path / "1" + path_bloom_filter = tmp_path / "2" + + df.write_parquet_with_options(path_no_bloom_filter, ParquetWriterOptions()) + df.write_parquet_with_options( + path_bloom_filter, ParquetWriterOptions(bloom_filter_on_write=True) + ) + + size_no_bloom_filter = 0 + for file in path_no_bloom_filter.rglob("*.parquet"): + size_no_bloom_filter += os.path.getsize(file) + + size_bloom_filter = 0 + for file in path_bloom_filter.rglob("*.parquet"): + size_bloom_filter += os.path.getsize(file) + + assert size_no_bloom_filter < size_bloom_filter + + +def test_write_parquet_with_options_column_options(df, tmp_path): + """Test writing Parquet files with different options for each column, which replace + the global configs (when provided).""" + data = { + "a": [1, 2, 3], + "b": ["a", "b", "c"], + "c": [False, True, False], + "d": [1.01, 2.02, 3.03], + "e": [4, 5, 6], + } + + column_specific_options = { + "a": ParquetColumnOptions(statistics_enabled="none"), + "b": ParquetColumnOptions(encoding="plain", dictionary_enabled=False), + "c": ParquetColumnOptions( + compression="snappy", encoding="rle", dictionary_enabled=False + ), + "d": ParquetColumnOptions( + compression="zstd(6)", + encoding="byte_stream_split", + dictionary_enabled=False, + statistics_enabled="none", + ), + # column "e" will use the global configs + } + + results = { + "a": { + "statistics": False, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + "b": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE"), + }, + "c": { + "statistics": True, + "compression": "snappy", + "encodings": ("RLE",), + }, + "d": { + "statistics": False, + "compression": "zstd", + "encodings": ("RLE", "BYTE_STREAM_SPLIT"), + }, + "e": { + "statistics": True, + "compression": "brotli", + "encodings": ("PLAIN", "RLE", "RLE_DICTIONARY"), + }, + } + + ctx = SessionContext() + df = ctx.from_arrow(pa.record_batch(data)) + df.write_parquet_with_options( + tmp_path, + ParquetWriterOptions( + compression="brotli(8)", column_specific_options=column_specific_options + ), + ) + + for file in tmp_path.rglob("*.parquet"): + parquet = pq.ParquetFile(file) + metadata = parquet.metadata.to_dict() + + for row_group in metadata["row_groups"]: + for col in row_group["columns"]: + column_name = col["path_in_schema"] + result = results[column_name] + assert (col["statistics"] is not None) == result["statistics"] + assert col["compression"].lower() == result["compression"].lower() + assert col["encodings"] == result["encodings"] + + def test_dataframe_export(df) -> None: # Guarantees that we have the canonical implementation # reading our dataframe export diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index adca783b5..40a98dc4d 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -19,7 +19,14 @@ import pyarrow as pa import pytest -from datafusion import SessionContext, col, functions, lit +from datafusion import ( + SessionContext, + col, + functions, + lit, + lit_with_metadata, + literal_with_metadata, +) from datafusion.expr import ( Aggregate, AggregateFunction, @@ -103,7 +110,7 @@ def test_limit(test_ctx): plan = plan.to_variant() assert isinstance(plan, Limit) - assert "Skip: Some(Literal(Int64(5)))" in str(plan) + assert "Skip: Some(Literal(Int64(5), None))" in str(plan) def test_aggregate_query(test_ctx): @@ -824,3 +831,52 @@ def test_expr_functions(ctx, function, expected_result): assert len(result) == 1 assert result[0].column(0).equals(expected_result) + + +def test_literal_metadata(ctx): + result = ( + ctx.from_pydict({"a": [1]}) + .select( + lit(1).alias("no_metadata"), + lit_with_metadata(2, {"key1": "value1"}).alias("lit_with_metadata_fn"), + literal_with_metadata(3, {"key2": "value2"}).alias( + "literal_with_metadata_fn" + ), + ) + .collect() + ) + + expected_schema = pa.schema( + [ + pa.field("no_metadata", pa.int64(), nullable=False), + pa.field( + "lit_with_metadata_fn", + pa.int64(), + nullable=False, + metadata={"key1": "value1"}, + ), + pa.field( + "literal_with_metadata_fn", + pa.int64(), + nullable=False, + metadata={"key2": "value2"}, + ), + ] + ) + + expected = pa.RecordBatch.from_pydict( + { + "no_metadata": pa.array([1]), + "lit_with_metadata_fn": pa.array([2]), + "literal_with_metadata_fn": pa.array([3]), + }, + schema=expected_schema, + ) + + assert result[0] == expected + + # Testing result[0].schema == expected_schema does not check each key/value pair + # so we want to explicitly test these + for expected_field in expected_schema: + actual_field = result[0].schema.field(expected_field.name) + assert expected_field.metadata == actual_field.metadata diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index b6348e3a0..41cee4ef3 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path): assert result.to_pydict() == {"cnt": [100]} -@pytest.mark.parametrize("path_to_str", [True, False]) -def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): +@pytest.mark.parametrize( + ("path_to_str", "legacy_data_type"), [(True, False), (False, False), (False, True)] +) +def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) (dir_root / "grp=a").mkdir(exist_ok=False) @@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str): dir_root = str(dir_root) if path_to_str else dir_root + partition_data_type = "string" if legacy_data_type else pa.string() + ctx.register_parquet( "datapp", dir_root, - table_partition_cols=[("grp", "string")], + table_partition_cols=[("grp", partition_data_type)], parquet_pruning=True, file_extension=".parquet", ) @@ -488,9 +492,9 @@ def test_register_listing_table( ): dir_root = tmp_path / "dataset_parquet_partitioned" dir_root.mkdir(exist_ok=False) - (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True) - (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True) + (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True) table = pa.Table.from_arrays( [ @@ -501,13 +505,13 @@ def test_register_listing_table( names=["int", "str", "float"], ) pa.parquet.write_table( - table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet" + table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet" + table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet" ) pa.parquet.write_table( - table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet" + table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet" ) dir_root = f"file://{dir_root}/" if path_to_str else dir_root @@ -515,7 +519,7 @@ def test_register_listing_table( ctx.register_listing_table( "my_table", dir_root, - table_partition_cols=[("grp", "string"), ("date_id", "int")], + table_partition_cols=[("grp", pa.string()), ("date", pa.date64())], file_extension=".parquet", schema=table.schema if pass_schema else None, file_sort_order=file_sort_order, @@ -531,7 +535,7 @@ def test_register_listing_table( assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2} result = ctx.sql( - "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp" # noqa: E501 + "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' GROUP BY grp" # noqa: E501 ).collect() result = pa.Table.from_batches(result) diff --git a/python/tests/test_wrapper_coverage.py b/python/tests/test_wrapper_coverage.py index 926a65961..f484cb282 100644 --- a/python/tests/test_wrapper_coverage.py +++ b/python/tests/test_wrapper_coverage.py @@ -28,14 +28,14 @@ from enum import EnumMeta as EnumType -def missing_exports(internal_obj, wrapped_obj) -> None: +def missing_exports(internal_obj, wrapped_obj) -> None: # noqa: C901 """ Identify if any of the rust exposted structs or functions do not have wrappers. Special handling for: - Raw* classes: Internal implementation details that shouldn't be exposed - _global_ctx: Internal implementation detail - - __self__, __class__: Python special attributes + - __self__, __class__, __repr__: Python special attributes """ # Special case enums - EnumType overrides a some of the internal functions, # so check all of the values exist and move on @@ -45,6 +45,9 @@ def missing_exports(internal_obj, wrapped_obj) -> None: assert value in dir(wrapped_obj) return + if "__repr__" in internal_obj.__dict__ and "__repr__" not in wrapped_obj.__dict__: + pytest.fail(f"Missing __repr__: {internal_obj.__name__}") + for internal_attr_name in dir(internal_obj): wrapped_attr_name = internal_attr_name.removeprefix("Raw") assert wrapped_attr_name in dir(wrapped_obj) diff --git a/src/catalog.rs b/src/catalog.rs index 83f8d08cb..9a24f2d44 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -15,44 +15,51 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use crate::errors::{PyDataFusionError, PyDataFusionResult}; -use crate::utils::wait_for_future; +use crate::dataset::Dataset; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::utils::{validate_pycapsule, wait_for_future}; +use async_trait::async_trait; +use datafusion::catalog::MemorySchemaProvider; +use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; +use pyo3::IntoPyObjectExt; +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; -#[pyclass(name = "Catalog", module = "datafusion", subclass)] +#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub struct PyDatabase { - pub database: Arc, +#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +pub struct PySchema { + pub schema: Arc, } -#[pyclass(name = "Table", module = "datafusion", subclass)] +#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] pub struct PyTable { pub table: Arc, } -impl PyCatalog { - pub fn new(catalog: Arc) -> Self { +impl From> for PyCatalog { + fn from(catalog: Arc) -> Self { Self { catalog } } } -impl PyDatabase { - pub fn new(database: Arc) -> Self { - Self { database } +impl From> for PySchema { + fn from(schema: Arc) -> Self { + Self { schema } } } @@ -68,36 +75,103 @@ impl PyTable { #[pymethods] impl PyCatalog { - fn names(&self) -> Vec { - self.catalog.schema_names() + #[new] + fn new(catalog: PyObject) -> Self { + let catalog_provider = + Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc; + catalog_provider.into() + } + + fn schema_names(&self) -> HashSet { + self.catalog.schema_names().into_iter().collect() } #[pyo3(signature = (name="public"))] - fn database(&self, name: &str) -> PyResult { - match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), - None => Err(PyKeyError::new_err(format!( - "Database with name {name} doesn't exist." - ))), - } + fn schema(&self, name: &str) -> PyResult { + let schema = self + .catalog + .schema(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)), + None => PySchema::from(schema).into_py_any(py), + } + }) + } + + fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> { + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let _ = self + .catalog + .register_schema(name, schema) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { + let capsule = schema_provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPySchemaProvider::new(schema_provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self + .catalog + .register_schema(name, provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> { + let _ = self + .catalog + .deregister_schema(name, cascade) + .map_err(py_datafusion_err)?; + + Ok(()) } fn __repr__(&self) -> PyResult { - Ok(format!( - "Catalog(schema_names=[{}])", - self.names().join(";") - )) + let mut names: Vec = self.schema_names().into_iter().collect(); + names.sort(); + Ok(format!("Catalog(schema_names=[{}])", names.join(", "))) } } #[pymethods] -impl PyDatabase { - fn names(&self) -> HashSet { - self.database.table_names().into_iter().collect() +impl PySchema { + #[new] + fn new(schema_provider: PyObject) -> Self { + let schema_provider = + Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc; + schema_provider.into() + } + + #[getter] + fn table_names(&self) -> HashSet { + self.schema.table_names().into_iter().collect() } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { - if let Some(table) = wait_for_future(py, self.database.table(name))?? { + if let Some(table) = wait_for_future(py, self.schema.table(name))?? { Ok(PyTable::new(table)) } else { Err(PyDataFusionError::Common(format!( @@ -107,14 +181,44 @@ impl PyDatabase { } fn __repr__(&self) -> PyResult { - Ok(format!( - "Database(table_names=[{}])", - Vec::from_iter(self.names()).join(";") - )) + let mut names: Vec = self.table_names().into_iter().collect(); + names.sort(); + Ok(format!("Schema(table_names=[{}])", names.join(";"))) } - // register_table - // deregister_table + fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let capsule = table_provider + .getattr("__datafusion_table_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc + }; + + let _ = self + .schema + .register_table(name.to_string(), provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_table(&self, name: &str) -> PyResult<()> { + let _ = self + .schema + .deregister_table(name) + .map_err(py_datafusion_err)?; + + Ok(()) + } } #[pymethods] @@ -145,3 +249,265 @@ impl PyTable { // fn has_exact_statistics // fn supports_filter_pushdown } + +#[derive(Debug)] +pub(crate) struct RustWrappedPySchemaProvider { + schema_provider: PyObject, + owner_name: Option, +} + +impl RustWrappedPySchemaProvider { + pub fn new(schema_provider: PyObject) -> Self { + let owner_name = Python::with_gil(|py| { + schema_provider + .bind(py) + .getattr("owner_name") + .ok() + .map(|name| name.to_string()) + }); + + Self { + schema_provider, + owner_name, + } + } + + fn table_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let py_table_method = provider.getattr("table")?; + + let py_table = py_table_method.call((name,), None)?; + if py_table.is_none() { + return Ok(None); + } + + if py_table.hasattr("__datafusion_table_provider__")? { + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + + Ok(Some(Arc::new(ds) as Arc)) + } + }) + } +} + +#[async_trait] +impl SchemaProvider for RustWrappedPySchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + + provider + .getattr("table_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get table_names: {err}"); + Vec::default() + }) + }) + } + + async fn table( + &self, + name: &str, + ) -> datafusion::common::Result>, DataFusionError> { + self.table_inner(name).map_err(to_datafusion_err) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion::common::Result>> { + let py_table = PyTable::new(table); + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let _ = provider + .call_method1("register_table", (name, py_table)) + .map_err(to_datafusion_err)?; + // Since the definition of `register_table` says that an error + // will be returned if the table already exists, there is no + // case where we want to return a table provider as output. + Ok(None) + }) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let table = provider + .call_method1("deregister_table", (name,)) + .map_err(to_datafusion_err)?; + if table.is_none() { + return Ok(None); + } + + // If we can turn this table provider into a `Dataset`, return it. + // Otherwise, return None. + let dataset = match Dataset::new(&table, py) { + Ok(dataset) => Some(Arc::new(dataset) as Arc), + Err(_) => None, + }; + + Ok(dataset) + }) + } + + fn table_exist(&self, name: &str) -> bool { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .call_method1("table_exist", (name,)) + .and_then(|pyobj| pyobj.extract()) + .unwrap_or(false) + }) + } +} + +#[derive(Debug)] +pub(crate) struct RustWrappedPyCatalogProvider { + pub(crate) catalog_provider: PyObject, +} + +impl RustWrappedPyCatalogProvider { + pub fn new(catalog_provider: PyObject) -> Self { + Self { catalog_provider } + } + + fn schema_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + + let py_schema = provider.call_method1("schema", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + if py_schema.hasattr("__datafusion_schema_provider__")? { + let capsule = provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + + Ok(Some(Arc::new(py_schema) as Arc)) + } + }) + } +} + +#[async_trait] +impl CatalogProvider for RustWrappedPyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + provider + .getattr("schema_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get schema_names: {err}"); + Vec::default() + }) + }) + } + + fn schema(&self, name: &str) -> Option> { + self.schema_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider schema returned error: {err}"); + None + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion::common::Result>> { + // JRIGHT HERE + // let py_schema: PySchema = schema.into(); + Python::with_gil(|py| { + let py_schema = match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(), + None => &PySchema::from(schema) + .into_py_any(py) + .map_err(to_datafusion_err)?, + }; + + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("register_schema", (name, py_schema)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("deregister_schema", (name, cascade)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/src/context.rs b/src/context.rs index b0af566e4..c97f2f618 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,7 +31,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use crate::catalog::{PyCatalog, PyTable}; +use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -61,7 +62,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::{ DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, }; -use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::disk_manager::DiskManagerMode; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; use datafusion::execution::options::ReadOptions; use datafusion::execution::runtime_env::RuntimeEnvBuilder; @@ -69,8 +70,10 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; +use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext @@ -183,22 +186,49 @@ impl PyRuntimeEnvBuilder { } fn with_disk_manager_disabled(&self) -> Self { - let mut builder = self.builder.clone(); - builder = builder.with_disk_manager(DiskManagerConfig::Disabled); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::Disabled); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_disk_manager_os(&self) -> Self { - let builder = self.builder.clone(); - let builder = builder.with_disk_manager(DiskManagerConfig::NewOs); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_disk_manager_specified(&self, paths: Vec) -> Self { - let builder = self.builder.clone(); let paths = paths.iter().map(|s| s.into()).collect(); - let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths)); - Self { builder } + let mut runtime_builder = self.builder.clone(); + + let mut disk_mgr_builder = runtime_builder + .disk_manager_builder + .clone() + .unwrap_or_default(); + disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths)); + + runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder); + Self { + builder: runtime_builder, + } } fn with_unbounded_memory_pool(&self) -> Self { @@ -353,7 +383,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_extension: &str, schema: Option>, file_sort_order: Option>>, @@ -361,7 +391,12 @@ impl PySessionContext { ) -> PyDataFusionResult<()> { let options = ListingOptions::new(Arc::new(ParquetFormat::new())) .with_file_extension(file_extension) - .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .with_table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .with_file_sort_order( file_sort_order .unwrap_or_default() @@ -582,6 +617,38 @@ impl PySessionContext { Ok(()) } + pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> { + let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc; + let _ = self.ctx.register_catalog(name, catalog); + + Ok(()) + } + + pub fn register_catalog_provider( + &mut self, + name: &str, + provider: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + let provider = if provider.hasattr("__datafusion_catalog_provider__")? { + let capsule = provider + .getattr("__datafusion_catalog_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_catalog_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignCatalogProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPyCatalogProvider::new(provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self.ctx.register_catalog(name, provider); + + Ok(()) + } + /// Construct datafusion dataframe from Arrow Table pub fn register_table_provider( &mut self, @@ -629,7 +696,7 @@ impl PySessionContext { &mut self, name: &str, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -638,7 +705,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult<()> { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -718,7 +790,7 @@ impl PySessionContext { schema: Option>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult<()> { @@ -728,7 +800,12 @@ impl PySessionContext { let mut options = NdJsonReadOptions::default() .file_compression_type(parse_file_compression_type(file_compression_type)?) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -751,15 +828,19 @@ impl PySessionContext { path: PathBuf, schema: Option>, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, py: Python, ) -> PyDataFusionResult<()> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.file_extension = file_extension; options.schema = schema.as_ref().map(|x| &x.0); @@ -799,14 +880,24 @@ impl PySessionContext { } #[pyo3(signature = (name="datafusion"))] - pub fn catalog(&self, name: &str) -> PyResult { - match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::new(catalog)), - None => Err(PyKeyError::new_err(format!( - "Catalog with name {} doesn't exist.", - &name, - ))), - } + pub fn catalog(&self, name: &str) -> PyResult { + let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( + "Catalog with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)), + None => PyCatalog::from(catalog).into_py_any(py), + } + }) + } + + pub fn catalog_names(&self) -> HashSet { + self.ctx.catalog_names().into_iter().collect() } pub fn tables(&self) -> HashSet { @@ -860,7 +951,7 @@ impl PySessionContext { schema: Option>, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult { @@ -868,7 +959,12 @@ impl PySessionContext { .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; let mut options = NdJsonReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema_infer_max_records = schema_infer_max_records; options.file_extension = file_extension; @@ -901,7 +997,7 @@ impl PySessionContext { delimiter: &str, schema_infer_max_records: usize, file_extension: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_compression_type: Option, py: Python, ) -> PyDataFusionResult { @@ -917,7 +1013,12 @@ impl PySessionContext { .delimiter(delimiter[0]) .schema_infer_max_records(schema_infer_max_records) .file_extension(file_extension) - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema = schema.as_ref().map(|x| &x.0); @@ -947,7 +1048,7 @@ impl PySessionContext { pub fn read_parquet( &self, path: &str, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, parquet_pruning: bool, file_extension: &str, skip_metadata: bool, @@ -956,7 +1057,12 @@ impl PySessionContext { py: Python, ) -> PyDataFusionResult { let mut options = ParquetReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) + .table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ) .parquet_pruning(parquet_pruning) .skip_metadata(skip_metadata); options.file_extension = file_extension; @@ -978,12 +1084,16 @@ impl PySessionContext { &self, path: &str, schema: Option>, - table_partition_cols: Vec<(String, String)>, + table_partition_cols: Vec<(String, PyArrowType)>, file_extension: &str, py: Python, ) -> PyDataFusionResult { - let mut options = AvroReadOptions::default() - .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); + let mut options = AvroReadOptions::default().table_partition_cols( + table_partition_cols + .into_iter() + .map(|(name, ty)| (name, ty.0)) + .collect::>(), + ); options.file_extension = file_extension; let df = if let Some(schema) = schema { options.schema = Some(&schema.0); @@ -1082,21 +1192,6 @@ impl PySessionContext { } } -pub fn convert_table_partition_cols( - table_partition_cols: Vec<(String, String)>, -) -> PyDataFusionResult> { - table_partition_cols - .into_iter() - .map(|(name, ty)| match ty.as_str() { - "string" => Ok((name, DataType::Utf8)), - "int" => Ok((name, DataType::Int32)), - _ => Err(crate::errors::PyDataFusionError::Common(format!( - "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'" - ))), - }) - .collect::, _>>() -} - pub fn parse_file_compression_type( file_compression_type: Option, ) -> Result { diff --git a/src/dataframe.rs b/src/dataframe.rs index 7711a0782..f554f340e 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::ffi::CString; use std::sync::Arc; @@ -23,11 +24,12 @@ use arrow::compute::can_cast_types; use arrow::error::ArrowError; use arrow::ffi::FFI_ArrowSchema; use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::pyarrow::FromPyArrow; use datafusion::arrow::datatypes::Schema; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::arrow::util::pretty; use datafusion::common::UnnestOptions; -use datafusion::config::{CsvOptions, TableParquetOptions}; +use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions}; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::TableProvider; use datafusion::error::DataFusionError; @@ -49,7 +51,7 @@ use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{ - get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, + get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future, }; use crate::{ errors::PyDataFusionResult, @@ -149,9 +151,9 @@ fn get_python_formatter_with_config(py: Python) -> PyResult { Ok(PythonFormatter { formatter, config }) } -/// Get the Python formatter from the datafusion.html_formatter module +/// Get the Python formatter from the datafusion.dataframe_formatter module fn import_python_formatter(py: Python) -> PyResult> { - let formatter_module = py.import("datafusion.html_formatter")?; + let formatter_module = py.import("datafusion.dataframe_formatter")?; let get_formatter = formatter_module.getattr("get_formatter")?; get_formatter.call0() } @@ -185,6 +187,101 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult< Ok(config) } +/// Python mapping of `ParquetOptions` (includes just the writer-related options). +#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)] +#[derive(Clone, Default)] +pub struct PyParquetWriterOptions { + options: ParquetOptions, +} + +#[pymethods] +impl PyParquetWriterOptions { + #[new] + #[allow(clippy::too_many_arguments)] + pub fn new( + data_pagesize_limit: usize, + write_batch_size: usize, + writer_version: String, + skip_arrow_metadata: bool, + compression: Option, + dictionary_enabled: Option, + dictionary_page_size_limit: usize, + statistics_enabled: Option, + max_row_group_size: usize, + created_by: String, + column_index_truncate_length: Option, + statistics_truncate_length: Option, + data_page_row_count_limit: usize, + encoding: Option, + bloom_filter_on_write: bool, + bloom_filter_fpp: Option, + bloom_filter_ndv: Option, + allow_single_file_parallelism: bool, + maximum_parallel_row_group_writers: usize, + maximum_buffered_record_batches_per_stream: usize, + ) -> Self { + Self { + options: ParquetOptions { + data_pagesize_limit, + write_batch_size, + writer_version, + skip_arrow_metadata, + compression, + dictionary_enabled, + dictionary_page_size_limit, + statistics_enabled, + max_row_group_size, + created_by, + column_index_truncate_length, + statistics_truncate_length, + data_page_row_count_limit, + encoding, + bloom_filter_on_write, + bloom_filter_fpp, + bloom_filter_ndv, + allow_single_file_parallelism, + maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream, + ..Default::default() + }, + } + } +} + +/// Python mapping of `ParquetColumnOptions`. +#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)] +#[derive(Clone, Default)] +pub struct PyParquetColumnOptions { + options: ParquetColumnOptions, +} + +#[pymethods] +impl PyParquetColumnOptions { + #[new] + pub fn new( + bloom_filter_enabled: Option, + encoding: Option, + dictionary_enabled: Option, + compression: Option, + statistics_enabled: Option, + bloom_filter_fpp: Option, + bloom_filter_ndv: Option, + ) -> Self { + Self { + options: ParquetColumnOptions { + bloom_filter_enabled, + encoding, + dictionary_enabled, + compression, + statistics_enabled, + bloom_filter_fpp, + bloom_filter_ndv, + ..Default::default() + }, + } + } +} + /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. @@ -192,12 +289,68 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult< #[derive(Clone)] pub struct PyDataFrame { df: Arc, + + // In IPython environment cache batches between __repr__ and _repr_html_ calls. + batches: Option<(Vec, bool)>, } impl PyDataFrame { /// creates a new PyDataFrame pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } + Self { + df: Arc::new(df), + batches: None, + } + } + + fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult { + // Get the Python formatter and config + let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; + + let should_cache = *is_ipython_env(py) && self.batches.is_none(); + + let (batches, has_more) = match self.batches.take() { + Some(b) => b, + None => wait_for_future( + py, + collect_record_batches_to_display(self.df.as_ref().clone(), config), + )??, + }; + + if batches.is_empty() { + // This should not be reached, but do it for safety since we index into the vector below + return Ok("No data to display".to_string()); + } + + let table_uuid = uuid::Uuid::new_v4().to_string(); + + // Convert record batches to PyObject list + let py_batches = batches + .iter() + .map(|rb| rb.to_pyarrow(py)) + .collect::>>()?; + + let py_schema = self.schema().into_pyobject(py)?; + + let kwargs = pyo3::types::PyDict::new(py); + let py_batches_list = PyList::new(py, py_batches.as_slice())?; + kwargs.set_item("batches", py_batches_list)?; + kwargs.set_item("schema", py_schema)?; + kwargs.set_item("has_more", has_more)?; + kwargs.set_item("table_uuid", table_uuid)?; + + let method_name = match as_html { + true => "format_html", + false => "format_str", + }; + + let html_result = formatter.call_method(method_name, (), Some(&kwargs))?; + let html_str: String = html_result.extract()?; + if should_cache { + self.batches = Some((batches, has_more)); + } + + Ok(html_str) } } @@ -224,19 +377,32 @@ impl PyDataFrame { } } - fn __repr__(&self, py: Python) -> PyDataFusionResult { - // Get the Python formatter config - let PythonFormatter { - formatter: _, - config, - } = get_python_formatter_with_config(py)?; - let (batches, has_more) = wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??; + fn __repr__(&mut self, py: Python) -> PyDataFusionResult { + self.prepare_repr_string(py, false) + } + + fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { + self.prepare_repr_string(py, true) + } + + #[staticmethod] + #[expect(unused_variables)] + fn default_str_repr<'py>( + batches: Vec>, + schema: &Bound<'py, PyAny>, + has_more: bool, + table_uuid: &str, + ) -> PyResult { + let batches = batches + .into_iter() + .map(|batch| RecordBatch::from_pyarrow_bound(&batch)) + .collect::>>()? + .into_iter() + .filter(|batch| batch.num_rows() > 0) + .collect::>(); + if batches.is_empty() { - // This should not be reached, but do it for safety since we index into the vector below - return Ok("No data to display".to_string()); + return Ok("No data to display".to_owned()); } let batches_as_displ = @@ -250,41 +416,6 @@ impl PyDataFrame { Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } - fn _repr_html_(&self, py: Python) -> PyDataFusionResult { - // Get the Python formatter and config - let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; - let (batches, has_more) = wait_for_future( - py, - collect_record_batches_to_display(self.df.as_ref().clone(), config), - )??; - if batches.is_empty() { - // This should not be reached, but do it for safety since we index into the vector below - return Ok("No data to display".to_string()); - } - - let table_uuid = uuid::Uuid::new_v4().to_string(); - - // Convert record batches to PyObject list - let py_batches = batches - .into_iter() - .map(|rb| rb.to_pyarrow(py)) - .collect::>>()?; - - let py_schema = self.schema().into_pyobject(py)?; - - let kwargs = pyo3::types::PyDict::new(py); - let py_batches_list = PyList::new(py, py_batches.as_slice())?; - kwargs.set_item("batches", py_batches_list)?; - kwargs.set_item("schema", py_schema)?; - kwargs.set_item("has_more", has_more)?; - kwargs.set_item("table_uuid", table_uuid)?; - - let html_result = formatter.call_method("format_html", (), Some(&kwargs))?; - let html_str: String = html_result.extract()?; - - Ok(html_str) - } - /// Calculate summary statistics for a DataFrame fn describe(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); @@ -689,6 +820,34 @@ impl PyDataFrame { Ok(()) } + /// Write a `DataFrame` to a Parquet file, using advanced options. + fn write_parquet_with_options( + &self, + path: &str, + options: PyParquetWriterOptions, + column_specific_options: HashMap, + py: Python, + ) -> PyDataFusionResult<()> { + let table_options = TableParquetOptions { + global: options.options, + column_specific_options: column_specific_options + .into_iter() + .map(|(k, v)| (k, v.options)) + .collect(), + ..Default::default() + }; + + wait_for_future( + py, + self.df.as_ref().clone().write_parquet( + path, + DataFrameWriteOptions::new(), + Option::from(table_options), + ), + )??; + Ok(()) + } + /// Executes a query and writes the results to a partitioned JSON file. fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> { wait_for_future( diff --git a/src/expr.rs b/src/expr.rs index bc7dbeffd..6b1d01d65 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion::logical_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion::logical_expr::expr::AggregateFunctionParams; use datafusion::logical_expr::utils::exprlist_to_fields; use datafusion::logical_expr::{ - ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition, + lit_with_metadata, ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition, }; use pyo3::IntoPyObjectExt; use pyo3::{basic::CompareOp, prelude::*}; @@ -150,7 +150,7 @@ impl PyExpr { Ok(PyScalarVariable::new(data_type, variables).into_bound_py_any(py)?) } Expr::Like(value) => Ok(PyLike::from(value.clone()).into_bound_py_any(py)?), - Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_bound_py_any(py)?), + Expr::Literal(value, metadata) => Ok(PyLiteral::new_with_metadata(value.clone(), metadata.clone()).into_bound_py_any(py)?), Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_bound_py_any(py)?), Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_bound_py_any(py)?), Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_bound_py_any(py)?), @@ -282,6 +282,14 @@ impl PyExpr { lit(value.0).into() } + #[staticmethod] + pub fn literal_with_metadata( + value: PyScalarValue, + metadata: HashMap, + ) -> PyExpr { + lit_with_metadata(value.0, metadata).into() + } + #[staticmethod] pub fn column(value: &str) -> PyExpr { col(value).into() @@ -377,7 +385,7 @@ impl PyExpr { /// Extracts the Expr value into a PyObject that can be shared with Python pub fn python_value(&self, py: Python) -> PyResult { match &self.expr { - Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py), + Expr::Literal(scalar_value, _) => scalar_to_pyarrow(scalar_value, py), _ => Err(py_type_err(format!( "Non Expr::Literal encountered in types: {:?}", &self.expr @@ -417,11 +425,13 @@ impl PyExpr { params: AggregateFunctionParams { args, .. }, .. }) - | Expr::ScalarFunction(ScalarFunction { args, .. }) - | Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { args, .. }, - .. - }) => Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()), + | Expr::ScalarFunction(ScalarFunction { args, .. }) => { + Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()) + } + Expr::WindowFunction(boxed_window_fn) => { + let args = &boxed_window_fn.params.args; + Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()) + } // Expr(s) that require more specific processing Expr::Case(Case { @@ -600,10 +610,10 @@ impl PyExpr { ) -> PyDataFusionResult { match &self.expr { Expr::AggregateFunction(agg_fn) => { - let window_fn = Expr::WindowFunction(WindowFunction::new( + let window_fn = Expr::WindowFunction(Box::new(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()), agg_fn.params.args.clone(), - )); + ))); add_builder_fns_to_window( window_fn, @@ -743,7 +753,7 @@ impl PyExpr { | Operator::QuestionPipe => Err(py_type_err(format!("Unsupported expr: ${op}"))), }, Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type), - Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value), + Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value), _ => Err(py_type_err(format!( "Non Expr::Literal encountered in types: {:?}", expr diff --git a/src/expr/literal.rs b/src/expr/literal.rs index a660ac914..45303a104 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -18,11 +18,22 @@ use crate::errors::PyDataFusionError; use datafusion::common::ScalarValue; use pyo3::{prelude::*, IntoPyObjectExt}; +use std::collections::BTreeMap; #[pyclass(name = "Literal", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLiteral { pub value: ScalarValue, + pub metadata: Option>, +} + +impl PyLiteral { + pub fn new_with_metadata( + value: ScalarValue, + metadata: Option>, + ) -> PyLiteral { + Self { value, metadata } + } } impl From for ScalarValue { @@ -33,7 +44,10 @@ impl From for ScalarValue { impl From for PyLiteral { fn from(value: ScalarValue) -> PyLiteral { - PyLiteral { value } + PyLiteral { + value, + metadata: None, + } } } diff --git a/src/expr/window.rs b/src/expr/window.rs index c5467bf94..052d9eeb4 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -16,7 +16,6 @@ // under the License. use datafusion::common::{DataFusionError, ScalarValue}; -use datafusion::logical_expr::expr::{WindowFunction, WindowFunctionParams}; use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, WindowFrameUnits}; use pyo3::{prelude::*, IntoPyObjectExt}; use std::fmt::{self, Display, Formatter}; @@ -118,10 +117,9 @@ impl PyWindowExpr { /// Returns order by columns in a window function expression pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { order_by, .. }, - .. - }) => py_sort_expr_list(&order_by), + Expr::WindowFunction(boxed_window_fn) => { + py_sort_expr_list(&boxed_window_fn.params.order_by) + } other => Err(not_window_function_err(other)), } } @@ -129,10 +127,9 @@ impl PyWindowExpr { /// Return partition by columns in a window function expression pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { partition_by, .. }, - .. - }) => py_expr_list(&partition_by), + Expr::WindowFunction(boxed_window_fn) => { + py_expr_list(&boxed_window_fn.params.partition_by) + } other => Err(not_window_function_err(other)), } } @@ -140,10 +137,7 @@ impl PyWindowExpr { /// Return input args for window function pub fn get_args(&self, expr: PyExpr) -> PyResult> { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { args, .. }, - .. - }) => py_expr_list(&args), + Expr::WindowFunction(boxed_window_fn) => py_expr_list(&boxed_window_fn.params.args), other => Err(not_window_function_err(other)), } } @@ -151,7 +145,7 @@ impl PyWindowExpr { /// Return window function name pub fn window_func_name(&self, expr: PyExpr) -> PyResult { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { fun, .. }) => Ok(fun.to_string()), + Expr::WindowFunction(boxed_window_fn) => Ok(boxed_window_fn.fun.to_string()), other => Err(not_window_function_err(other)), } } @@ -159,10 +153,9 @@ impl PyWindowExpr { /// Returns a Pywindow frame for a given window function expression pub fn get_frame(&self, expr: PyExpr) -> Option { match expr.expr.unalias() { - Expr::WindowFunction(WindowFunction { - params: WindowFunctionParams { window_frame, .. }, - .. - }) => Some(window_frame.into()), + Expr::WindowFunction(boxed_window_fn) => { + Some(boxed_window_fn.params.window_frame.into()) + } _ => None, } } diff --git a/src/functions.rs b/src/functions.rs index caa79b8ad..eeef48385 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -103,7 +103,7 @@ fn array_cat(exprs: Vec) -> PyExpr { #[pyo3(signature = (array, element, index=None))] fn array_position(array: PyExpr, element: PyExpr, index: Option) -> PyExpr { let index = ScalarValue::Int64(index); - let index = Expr::Literal(index); + let index = Expr::Literal(index, None); datafusion::functions_nested::expr_fn::array_position(array.into(), element.into(), index) .into() } @@ -334,7 +334,7 @@ fn window( .unwrap_or(WindowFrame::new(order_by.as_ref().map(|v| !v.is_empty()))); Ok(PyExpr { - expr: datafusion::logical_expr::Expr::WindowFunction(WindowFunction { + expr: datafusion::logical_expr::Expr::WindowFunction(Box::new(WindowFunction { fun, params: WindowFunctionParams { args: args.into_iter().map(|x| x.expr).collect::>(), @@ -351,7 +351,7 @@ fn window( window_frame, null_treatment: None, }, - }), + })), }) } @@ -682,7 +682,7 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } -// We handle first_value explicitly because the signature expects an order_by +// We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] #[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))] @@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; - m.add_wrapped(wrap_pyfunction!(log))?; + m.add_wrapped(wrap_pyfunction!(self::log))?; m.add_wrapped(wrap_pyfunction!(log10))?; m.add_wrapped(wrap_pyfunction!(log2))?; m.add_wrapped(wrap_pyfunction!(lower))?; diff --git a/src/lib.rs b/src/lib.rs index 7dced1fbd..29d3f41da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,15 +77,17 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime); /// datafusion directory. #[pymodule] fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + // Initialize logging + pyo3_log::init(); + // Register the python classes - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -96,6 +98,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + let catalog = PyModule::new(py, "catalog")?; + catalog::init_module(&catalog)?; + m.add_submodule(&catalog)?; + // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/ let common = PyModule::new(py, "common")?; common::init_module(&common)?; diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs index 4b4c86597..7fbb1dc2a 100644 --- a/src/pyarrow_filter_expression.rs +++ b/src/pyarrow_filter_expression.rs @@ -61,7 +61,7 @@ fn extract_scalar_list<'py>( .iter() .map(|expr| match expr { // TODO: should we also leverage `ScalarValue::to_pyarrow` here? - Expr::Literal(v) => match v { + Expr::Literal(v, _) => match v { // The unwraps here are for infallible conversions ScalarValue::Boolean(Some(b)) => Ok(b.into_bound_py_any(py)?), ScalarValue::Int8(Some(i)) => Ok(i.into_bound_py_any(py)?), @@ -106,7 +106,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression { let op_module = Python::import(py, "operator")?; let pc_expr: PyDataFusionResult> = match expr { Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?), - Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)), + Expr::Literal(scalar, _) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let operator = operator_to_py(op, &op_module)?; let left = PyArrowFilterExpression::try_from(left.as_ref())?.0; diff --git a/src/udaf.rs b/src/udaf.rs index 34a9cd51d..78f4e2b0c 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -19,6 +19,10 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; +use crate::common::data_type::PyScalarValue; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; +use crate::expr::PyExpr; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::array::{Array, ArrayRef}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; @@ -27,11 +31,8 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF, }; - -use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; -use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF}; +use pyo3::types::PyCapsule; #[derive(Debug)] struct RustAccumulator { @@ -183,6 +184,26 @@ impl PyAggregateUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_aggregate_udf__")? { + let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_aggregate_udf")?; + + let udaf = unsafe { capsule.reference::() }; + let udaf: ForeignAggregateUDF = udaf.try_into()?; + + Ok(Self { + function: udaf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udf.rs b/src/udf.rs index 574c9d7b5..de1e3f18c 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -17,6 +17,8 @@ use std::sync::Arc; +use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF}; +use pyo3::types::PyCapsule; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; @@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF; use datafusion::logical_expr::{create_udf, ColumnarValue}; use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; /// Create a Rust callable function from a python function that expects pyarrow arrays fn pyarrow_function_to_rust( @@ -105,6 +108,26 @@ impl PyScalarUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_scalar_udf__")? { + let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_scalar_udf")?; + + let udf = unsafe { capsule.reference::() }; + let udf: ForeignScalarUDF = udf.try_into()?; + + Ok(Self { + function: udf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udwf.rs b/src/udwf.rs index defd9c522..4fb98916b 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -27,16 +27,17 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, }; -use pyo3::types::{PyList, PyTuple}; +use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; #[derive(Debug)] struct RustPartitionEvaluator { @@ -245,6 +246,26 @@ impl PyWindowUDF { Ok(self.function.call(args).into()) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_window_udf__")? { + let capsule = func.getattr("__datafusion_window_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_window_udf")?; + + let udwf = unsafe { capsule.reference::() }; + let udwf: ForeignWindowUDF = udwf.try_into()?; + + Ok(Self { + function: udwf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_window_udf__ does not exist on WindowUDF object.".to_string(), + )) + } + } + fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } @@ -300,13 +321,9 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self.signature } - fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { // TODO: Should nullable always be `true`? - Ok(arrow::datatypes::Field::new( - field_args.name(), - self.return_type.clone(), - true, - )) + Ok(arrow::datatypes::Field::new(field_args.name(), self.return_type.clone(), true).into()) } // TODO: Enable passing partition_evaluator_args to python? diff --git a/src/utils.rs b/src/utils.rs index 90d654385..f4e121fd5 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) } +#[inline] +pub(crate) fn is_ipython_env(py: Python) -> &'static bool { + static IS_IPYTHON_ENV: OnceLock = OnceLock::new(); + IS_IPYTHON_ENV.get_or_init(|| { + py.import("IPython") + .and_then(|ipython| ipython.call_method0("get_ipython")) + .map(|ipython| !ipython.is_none()) + .unwrap_or(false) + }) +} + /// Utility to get the Global Datafussion CTX #[inline] pub(crate) fn get_global_ctx() -> &'static SessionContext {
" - f"{field.name}
" - f"
" - "" - "" - f"{formatted_value}" - f"" - f"
" - f"
{formatted_value}