diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 17f25b639c..d2ad985903 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -9,22 +9,20 @@ on: jobs: format: name: Format - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: cargo fmt --all -- --check - check: name: Check - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: runtime: [async-std, tokio] tls: [native-tls, rustls] steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -46,35 +44,22 @@ jobs: cargo clippy \ --no-default-features \ --all-targets \ - --features offline,all-databases,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }} \ + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }} \ -- -D warnings test: name: Unit Test - runs-on: ubuntu-22.04 - strategy: - matrix: - runtime: [ - # Disabled because of https://github.com/rust-lang/cargo/issues/12964 - # async-std, - # actix, - tokio - ] - tls: [ - # native-tls, - rustls - ] + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx save-if: ${{ false }} - - run: - cargo test - --manifest-path sqlx-core/Cargo.toml - --features offline,all-databases,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + - run: sudo apt-get update && sudo apt-get install -y libodbc2 unixodbc-dev + - run: cargo test + --manifest-path sqlx-core/Cargo.toml + --features offline,all-databases,all-types,runtime-tokio-rustls cli: name: CLI Binaries @@ -92,32 +77,29 @@ jobs: target: x86_64-pc-windows-msvc bin: target/debug/cargo-sqlx.exe # FIXME: macOS build fails because of missing pin-project-internal -# - os: macOS-latest -# target: x86_64-apple-darwin -# bin: target/debug/cargo-sqlx + # - os: macOS-latest + # target: x86_64-apple-darwin + # bin: target/debug/cargo-sqlx steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx save-if: ${{ github.ref == 'refs/heads/main' }} - - run: - cargo build - --manifest-path sqlx-cli/Cargo.toml - --bin cargo-sqlx - ${{ matrix.args }} + - run: cargo build + --manifest-path sqlx-cli/Cargo.toml + --bin cargo-sqlx + ${{ matrix.args }} - uses: actions/upload-artifact@v4 with: name: cargo-sqlx-${{ matrix.target }} path: ${{ matrix.bin }} - sqlite: name: SQLite - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: runtime: [async-std, tokio, actix] @@ -126,7 +108,6 @@ jobs: steps: - uses: actions/checkout@v4 - run: mkdir /tmp/sqlite3-lib && wget -O /tmp/sqlite3-lib/ipaddr.so https://github.com/nalgeon/sqlean/releases/download/0.15.2/ipaddr.so - - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -138,12 +119,11 @@ jobs: --no-default-features \ --features sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }},macros,migrate \ -- -D warnings - - run: - cargo test - --no-default-features - --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - -- - --test-threads=1 + - run: cargo test + --no-default-features + --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} + -- + --test-threads=1 env: DATABASE_URL: sqlite://tests/sqlite/sqlite.db RUSTFLAGS: --cfg sqlite_ipaddr @@ -151,7 +131,7 @@ jobs: postgres: name: Postgres - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: postgres: [14, 10] @@ -161,8 +141,6 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -207,11 +185,10 @@ jobs: postgres_ssl_client_cert: name: Postgres with SSL client cert - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 needs: check steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -225,7 +202,7 @@ jobs: mysql: name: MySQL - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: mysql: [8, 5_7] @@ -235,8 +212,6 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -270,7 +245,7 @@ jobs: mariadb: name: MariaDB - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: mariadb: [10_6, 10_3] @@ -280,8 +255,6 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -308,7 +281,7 @@ jobs: mssql: name: MSSQL - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 strategy: matrix: mssql: [2019, 2022] @@ -318,8 +291,6 @@ jobs: steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 with: prefix-key: v1-sqlx @@ -343,3 +314,38 @@ jobs: cargo test --no-default-features --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mssql://sa:Password123!@localhost/sqlx + + odbc: + name: ODBC (PostgreSQL and SQLite) + runs-on: ubuntu-24.04 + needs: check + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + with: + prefix-key: v1-sqlx + shared-key: odbc + save-if: ${{ github.ref == 'refs/heads/main' }} + - name: Start Postgres (no SSL) + run: | + docker compose -f tests/docker-compose.yml run -d -p 5432:5432 --name postgres_16_no_ssl postgres_16_no_ssl + docker exec postgres_16_no_ssl bash -c "until pg_isready; do sleep 1; done" + - name: Install unixODBC and ODBC drivers (PostgreSQL, SQLite) + run: | + sudo apt-get update + sudo apt-get install -y unixodbc odbcinst unixodbc-common libodbcinst2 odbc-postgresql libsqliteodbc libodbc2 unixodbc-dev + odbcinst -j + - name: Configure system/user DSN for PostgreSQL + run: | + cp tests/odbc.ini ~/.odbc.ini + odbcinst -q -s || true + echo "select 1;" | isql -v SQLX_PG_5432 || true + - name: Run ODBC tests (PostgreSQL DSN) + run: cargo test --no-default-features --features any,odbc,macros,all-types,runtime-tokio-rustls + env: + DATABASE_URL: DSN=SQLX_PG_5432;UID=postgres;PWD=password + - name: Run ODBC tests (SQLite driver) + run: cargo test --no-default-features --features any,odbc,macros,all-types,runtime-tokio-rustls + env: + DATABASE_URL: Driver={SQLite3};Database=./tests/odbc/sqlite.db diff --git a/Cargo.lock b/Cargo.lock index 937b56ec68..dba1eb8e16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -50,6 +50,33 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-activity" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef6978589202a00cd7e118380c448a08b6ed394c3a8df3a430d0898e3a42d046" +dependencies = [ + "android-properties", + "bitflags 2.9.4", + "cc", + "cesu8", + "jni", + "jni-sys", + "libc", + "log", + "ndk", + "ndk-context", + "ndk-sys", + "num_enum", + "thiserror 1.0.69", +] + +[[package]] +name = "android-properties" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7eb209b1518d6bb87b283c20095f5228ecda460da70b44f0802523dea6da04" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -580,6 +607,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block2" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c132eebf10f5cad5289222520a4a058514204aed6d791f1cf4fe8088b82d15f" +dependencies = [ + "objc2", +] + [[package]] name = "blocking" version = "1.6.2" @@ -666,6 +702,20 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "calloop" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b99da2f8558ca23c71f4fd15dc57c906239752dd27ff3c00a1d56b685b7cbfec" +dependencies = [ + "bitflags 2.9.4", + "log", + "polling", + "rustix 0.38.44", + "slab", + "thiserror 1.0.69", +] + [[package]] name = "camino" version = "1.1.12" @@ -715,6 +765,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cexpr" version = "0.6.0" @@ -842,6 +898,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -886,6 +952,30 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-graphics" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-graphics-types", + "foreign-types 0.5.0", + "libc", +] + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -1011,6 +1101,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "cursor-icon" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27ae1dd37df86211c42e150270f82743308803d90a6f6e6651cd730d5e1732f" + [[package]] name = "darling" version = "0.20.11" @@ -1127,6 +1223,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "dispatch" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" + [[package]] name = "displaydoc" version = "0.2.5" @@ -1138,6 +1240,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "dlib" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" +dependencies = [ + "libloading", +] + [[package]] name = "dotenvy" version = "0.15.7" @@ -1150,6 +1261,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" +[[package]] +name = "dpi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b14ccef22fc6f5a8f4d7d768562a182c04ce9a3b3157b91390b52ddfdf1a76" + [[package]] name = "dunce" version = "1.0.5" @@ -1345,7 +1462,28 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared", + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -1354,6 +1492,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -2000,6 +2144,28 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.34" @@ -2100,7 +2266,7 @@ checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" dependencies = [ "bitflags 2.9.4", "libc", - "redox_syscall", + "redox_syscall 0.5.17", ] [[package]] @@ -2290,6 +2456,36 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3f42e7bbe13d351b6bead8286a43aac9534b82bd3cc43e47037f012ebfd62d4" +dependencies = [ + "bitflags 2.9.4", + "jni-sys", + "log", + "ndk-sys", + "num_enum", + "raw-window-handle", + "thiserror 1.0.69", +] + +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + +[[package]] +name = "ndk-sys" +version = "0.6.0+11769913" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6cda3051665f1fb8d9e08fc35c96d5a244fb1be711a03b71118828afc9a873" +dependencies = [ + "jni-sys", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -2414,6 +2610,231 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a973b4e44ce6cad84ce69d797acf9a044532e4184c4f267913d1b546a0727b7a" +dependencies = [ + "num_enum_derive", + "rustversion", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "objc-sys" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb91bdd390c7ce1a8607f35f3ca7151b65afc0ff5ff3b34fa350f7d7c7e4310" + +[[package]] +name = "objc2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46a785d4eeff09c14c487497c162e92766fbb3e4059a71840cecc03d9a50b804" +dependencies = [ + "objc-sys", + "objc2-encode", +] + +[[package]] +name = "objc2-app-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" +dependencies = [ + "bitflags 2.9.4", + "block2", + "libc", + "objc2", + "objc2-core-data", + "objc2-core-image", + "objc2-foundation", + "objc2-quartz-core", +] + +[[package]] +name = "objc2-cloud-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + +[[package]] +name = "objc2-contacts" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-data" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-core-image" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-core-location" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" +dependencies = [ + "block2", + "objc2", + "objc2-contacts", + "objc2-foundation", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" +dependencies = [ + "bitflags 2.9.4", + "block2", + "dispatch", + "libc", + "objc2", +] + +[[package]] +name = "objc2-link-presentation" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" +dependencies = [ + "block2", + "objc2", + "objc2-app-kit", + "objc2-foundation", +] + +[[package]] +name = "objc2-metal" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-quartz-core" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-foundation", + "objc2-metal", +] + +[[package]] +name = "objc2-symbols" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a684efe3dec1b305badae1a28f6555f6ddd3bb2c2267896782858d5a78404dc" +dependencies = [ + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-ui-kit" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-cloud-kit", + "objc2-core-data", + "objc2-core-image", + "objc2-core-location", + "objc2-foundation", + "objc2-link-presentation", + "objc2-quartz-core", + "objc2-symbols", + "objc2-uniform-type-identifiers", + "objc2-user-notifications", +] + +[[package]] +name = "objc2-uniform-type-identifiers" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" +dependencies = [ + "block2", + "objc2", + "objc2-foundation", +] + +[[package]] +name = "objc2-user-notifications" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" +dependencies = [ + "bitflags 2.9.4", + "block2", + "objc2", + "objc2-core-location", + "objc2-foundation", +] + [[package]] name = "object" version = "0.36.7" @@ -2423,6 +2844,26 @@ dependencies = [ "memchr", ] +[[package]] +name = "odbc-api" +version = "19.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b55dec8a12bc5b3a980a71eeab007d3653188e61e2cfd4614a260ef9c41a25" +dependencies = [ + "atoi", + "log", + "odbc-sys", + "thiserror 2.0.16", + "widestring", + "winit", +] + +[[package]] +name = "odbc-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acb069b57ebbad5234fb7197af7ee0c40daceb3946a86fa8d3f7a38393bf2770" + [[package]] name = "once_cell" version = "1.21.3" @@ -2449,7 +2890,7 @@ checksum = "8505734d46c8ab1e19a1dce3aef597ad87dcb4c37e7188231769bd6bd51cebf8" dependencies = [ "bitflags 2.9.4", "cfg-if", - "foreign-types", + "foreign-types 0.3.2", "libc", "once_cell", "openssl-macros", @@ -2501,6 +2942,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "orbclient" +version = "0.3.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba0b26cec2e24f08ed8bb31519a9333140a6599b867dac464bb150bdb796fd43" +dependencies = [ + "libredox", +] + [[package]] name = "os_str_bytes" version = "6.6.1" @@ -2531,7 +2981,7 @@ checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.17", "smallvec", "windows-targets 0.52.6", ] @@ -2934,6 +3384,12 @@ dependencies = [ "rand_core 0.9.3", ] +[[package]] +name = "raw-window-handle" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" + [[package]] name = "rayon" version = "1.11.0" @@ -2963,6 +3419,15 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.17" @@ -3470,6 +3935,15 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smol_str" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" +dependencies = [ + "serde", +] + [[package]] name = "socket2" version = "0.5.10" @@ -3589,6 +4063,7 @@ dependencies = [ "md-5", "memchr", "num-bigint", + "odbc-api", "once_cell", "paste", "percent-encoding", @@ -3734,6 +4209,7 @@ dependencies = [ "anyhow", "async-std", "dotenvy", + "either", "env_logger", "futures", "hex", @@ -4572,6 +5048,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "1.0.2" @@ -4604,6 +5090,12 @@ dependencies = [ "web-sys", ] +[[package]] +name = "widestring" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd7cf3379ca1aac9eea11fba24fd7e315d621f8dfe35c8d7d2be8b793726e07d" + [[package]] name = "winapi" version = "0.3.9" @@ -4700,6 +5192,15 @@ dependencies = [ "windows-link 0.1.3", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -4745,6 +5246,21 @@ dependencies = [ "windows-link 0.2.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -4793,6 +5309,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -4811,6 +5333,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4829,6 +5357,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4859,6 +5393,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4877,6 +5417,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -4895,6 +5441,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -4913,6 +5465,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -4931,6 +5489,46 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +[[package]] +name = "winit" +version = "0.30.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66d4b9ed69c4009f6321f762d6e61ad8a2389cd431b97cb1e146812e9e6c732" +dependencies = [ + "android-activity", + "atomic-waker", + "bitflags 2.9.4", + "block2", + "calloop", + "cfg_aliases", + "concurrent-queue", + "core-foundation", + "core-graphics", + "cursor-icon", + "dpi", + "js-sys", + "libc", + "ndk", + "objc2", + "objc2-app-kit", + "objc2-foundation", + "objc2-ui-kit", + "orbclient", + "pin-project", + "raw-window-handle", + "redox_syscall 0.4.1", + "rustix 0.38.44", + "smol_str", + "tracing", + "unicode-segmentation", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "web-time", + "windows-sys 0.52.0", + "xkbcommon-dl", +] + [[package]] name = "winnow" version = "0.7.13" @@ -4961,6 +5559,25 @@ dependencies = [ "tap", ] +[[package]] +name = "xkbcommon-dl" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039de8032a9a8856a6be89cea3e5d12fdd82306ab7c94d74e6deab2460651c5" +dependencies = [ + "bitflags 2.9.4", + "dlib", + "log", + "once_cell", + "xkeysym", +] + +[[package]] +name = "xkeysym" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9cc00251562a284751c9973bace760d86c0276c471b4be569fe6b068ee97a56" + [[package]] name = "yoke" version = "0.8.0" diff --git a/Cargo.toml b/Cargo.toml index 1afcdeefb3..3fa4943e04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ offline = ["sqlx-macros/offline", "sqlx-core/offline"] # intended mainly for CI and docs all = ["tls", "all-databases", "all-types"] -all-databases = ["mysql", "sqlite", "postgres", "mssql", "any"] +all-databases = ["mysql", "sqlite", "postgres", "mssql", "odbc", "any"] all-types = [ "bigdecimal", "decimal", @@ -131,6 +131,7 @@ postgres = ["sqlx-core/postgres", "sqlx-macros/postgres"] mysql = ["sqlx-core/mysql", "sqlx-macros/mysql"] sqlite = ["sqlx-core/sqlite", "sqlx-macros/sqlite"] mssql = ["sqlx-core/mssql", "sqlx-macros/mssql"] +odbc = ["sqlx-core/odbc"] # types bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] @@ -168,6 +169,7 @@ rand = "0.8" rand_xoshiro = "0.7.0" hex = "0.4.3" tempdir = "0.3.7" +either = "1.6.1" # Needed to test SQLCipher libsqlite3-sys = { version = "0", features = [ "bundled-sqlcipher-vendored-openssl", @@ -191,6 +193,11 @@ name = "any-pool" path = "tests/any/pool.rs" required-features = ["any"] +[[test]] +name = "any-odbc" +path = "tests/any/odbc.rs" +required-features = ["any", "odbc"] + # # Migrations # @@ -326,6 +333,20 @@ name = "mssql" path = "tests/mssql/mssql.rs" required-features = ["mssql"] +# +# ODBC +# + +[[test]] +name = "odbc" +path = "tests/odbc/odbc.rs" +required-features = ["odbc"] + +[[test]] +name = "odbc-types" +path = "tests/odbc/types.rs" +required-features = ["odbc"] + [[test]] name = "mssql-types" path = "tests/mssql/types.rs" diff --git a/contrib/ide/vscode/settings.json b/contrib/ide/vscode/settings.json index 3d1cbfd8a7..977d31f25d 100644 --- a/contrib/ide/vscode/settings.json +++ b/contrib/ide/vscode/settings.json @@ -1,3 +1,16 @@ { - "rust-analyzer.assist.importMergeBehaviour": "last" -} + "rust-analyzer.check.command": "clippy", + "rust-analyzer.cargo.features": [ + "any", + "all-databases", + "macros", + "migrate", + "all-types", + "runtime-actix-rustls" + ], + "rust-analyzer.linkedProjects": [ + "./Cargo.toml", + "./sqlx-core/Cargo.toml", + "./sqlx-macros/Cargo.toml" + ] +} \ No newline at end of file diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 6919a3ebcc..6d03f5fad8 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,7 +20,7 @@ default = ["migrate"] migrate = ["sha2", "crc"] # databases -all-databases = ["postgres", "mysql", "sqlite", "mssql", "any"] +all-databases = ["postgres", "mysql", "sqlite", "mssql", "odbc", "any"] postgres = [ "md-5", "sha2", @@ -46,6 +46,7 @@ mysql = [ sqlite = ["libsqlite3-sys", "futures-executor", "flume"] mssql = ["uuid", "encoding_rs", "regex"] any = [] +odbc = ["odbc-api", "futures-executor", "flume"] # types all-types = [ @@ -172,6 +173,7 @@ hkdf = { version = "0.12.0", optional = true } event-listener = "5.4.0" dotenvy = "0.15" +odbc-api = { version = "19.0.1", optional = true } [dev-dependencies] sqlx = { package = "sqlx-oldapi", path = "..", features = ["postgres", "sqlite", "mysql", "runtime-tokio-rustls"] } diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index 41b0b72946..ab1cecb569 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -46,6 +46,9 @@ pub(crate) enum AnyArgumentBufferKind<'q> { crate::mssql::MssqlArguments, std::marker::PhantomData<&'q ()>, ), + + #[cfg(feature = "odbc")] + Odbc(crate::odbc::OdbcArguments, std::marker::PhantomData<&'q ()>), } // control flow inferred type bounds would be fun @@ -131,3 +134,24 @@ impl<'q> From> for crate::postgres::PgArguments { } } } + +#[cfg(feature = "odbc")] +#[allow(irrefutable_let_patterns)] +impl<'q> From> for crate::odbc::OdbcArguments { + fn from(args: AnyArguments<'q>) -> Self { + let mut buf = AnyArgumentBuffer(AnyArgumentBufferKind::Odbc( + Default::default(), + std::marker::PhantomData, + )); + + for value in args.values { + let _ = value.encode_by_ref(&mut buf); + } + + if let AnyArgumentBufferKind::Odbc(args, _) = buf.0 { + args + } else { + unreachable!() + } + } +} diff --git a/sqlx-core/src/any/column.rs b/sqlx-core/src/any/column.rs index 22049033a8..0b0d5499b7 100644 --- a/sqlx-core/src/any/column.rs +++ b/sqlx-core/src/any/column.rs @@ -13,6 +13,9 @@ use crate::sqlite::{SqliteColumn, SqliteRow, SqliteStatement}; #[cfg(feature = "mssql")] use crate::mssql::{MssqlColumn, MssqlRow, MssqlStatement}; +#[cfg(feature = "odbc")] +use crate::odbc::{OdbcColumn, OdbcRow, OdbcStatement}; + #[derive(Debug, Clone)] pub struct AnyColumn { pub(crate) kind: AnyColumnKind, @@ -34,6 +37,9 @@ pub(crate) enum AnyColumnKind { #[cfg(feature = "mssql")] Mssql(MssqlColumn), + + #[cfg(feature = "odbc")] + Odbc(OdbcColumn), } impl Column for AnyColumn { @@ -52,6 +58,9 @@ impl Column for AnyColumn { #[cfg(feature = "mssql")] AnyColumnKind::Mssql(row) => row.ordinal(), + + #[cfg(feature = "odbc")] + AnyColumnKind::Odbc(row) => row.ordinal(), } } @@ -68,6 +77,9 @@ impl Column for AnyColumn { #[cfg(feature = "mssql")] AnyColumnKind::Mssql(row) => row.name(), + + #[cfg(feature = "odbc")] + AnyColumnKind::Odbc(row) => row.name(), } } @@ -76,368 +88,26 @@ impl Column for AnyColumn { } } -// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] -// to trait bounds - -// all 4 - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -// only 3 (4) - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -// only 2 (6) - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -pub trait AnyColumnIndex: - ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -impl AnyColumnIndex for I where - I: ColumnIndex - + for<'q> ColumnIndex> - + ColumnIndex - + for<'q> ColumnIndex> -{ -} - -// only 1 (4) - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -pub trait AnyColumnIndex: ColumnIndex + for<'q> ColumnIndex> {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -impl AnyColumnIndex for I where - I: ColumnIndex + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -pub trait AnyColumnIndex: ColumnIndex + for<'q> ColumnIndex> {} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -impl AnyColumnIndex for I where - I: ColumnIndex + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -pub trait AnyColumnIndex: ColumnIndex + for<'q> ColumnIndex> {} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -impl AnyColumnIndex for I where - I: ColumnIndex + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -pub trait AnyColumnIndex: - ColumnIndex + for<'q> ColumnIndex> -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -impl AnyColumnIndex for I where - I: ColumnIndex + for<'q> ColumnIndex> -{ +// Callback macro that generates the actual trait and impl +macro_rules! impl_any_column_index_for_databases { + ($(($row:ident, $stmt:ident)),+) => { + pub trait AnyColumnIndex: $(ColumnIndex<$row> + for<'q> ColumnIndex<$stmt<'q>> +)+ Sized {} + + impl AnyColumnIndex for I + where + I: $(ColumnIndex<$row> + for<'q> ColumnIndex<$stmt<'q>> +)+ Sized + {} + }; +} + +// Generate all combinations +for_all_feature_combinations! { + entries: [ + ("postgres", (PgRow, PgStatement)), + ("mysql", (MySqlRow, MySqlStatement)), + ("mssql", (MssqlRow, MssqlStatement)), + ("sqlite", (SqliteRow, SqliteStatement)), + ("odbc", (OdbcRow, OdbcStatement)), + ], + callback: impl_any_column_index_for_databases } diff --git a/sqlx-core/src/any/connection/establish.rs b/sqlx-core/src/any/connection/establish.rs index 290a499cdd..a77efcd410 100644 --- a/sqlx-core/src/any/connection/establish.rs +++ b/sqlx-core/src/any/connection/establish.rs @@ -34,6 +34,13 @@ impl AnyConnection { .await .map(AnyConnectionKind::Mssql) } + + #[cfg(feature = "odbc")] + AnyConnectOptionsKind::Odbc(options) => { + crate::odbc::OdbcConnection::connect_with(options) + .await + .map(AnyConnectionKind::Odbc) + } } .map(AnyConnection) } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index 3eb67c139e..d49d23e543 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -49,6 +49,12 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { .fetch_many((query, arguments.map(Into::into))) .map_ok(|v| v.map_right(Into::into).map_left(Into::into)) .boxed(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn + .fetch_many((query, arguments.map(Into::into))) + .map_ok(|v| v.map_right(Into::into).map_left(Into::into)) + .boxed(), } } @@ -88,6 +94,12 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { .fetch_optional((query, arguments.map(Into::into))) .await? .map(Into::into), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn + .fetch_optional((query, arguments.map(Into::into))) + .await? + .map(Into::into), }) }) } @@ -114,6 +126,9 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.prepare(sql).await.map(Into::into)?, + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.prepare(sql).await.map(Into::into)?, }) }) } @@ -138,6 +153,9 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.describe(sql).await.map(map_describe)?, + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.describe(sql).await.map(map_describe)?, }) }) } diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index 33bc7d983f..a0d71378b5 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -15,6 +15,9 @@ use crate::mssql; #[cfg(feature = "mysql")] use crate::mysql; + +#[cfg(feature = "odbc")] +use crate::odbc; use crate::transaction::Transaction; mod establish; @@ -48,6 +51,9 @@ pub enum AnyConnectionKind { #[cfg(feature = "sqlite")] Sqlite(sqlite::SqliteConnection), + + #[cfg(feature = "odbc")] + Odbc(odbc::OdbcConnection), } impl AnyConnectionKind { @@ -64,6 +70,9 @@ impl AnyConnectionKind { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_) => AnyKind::Mssql, + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_) => AnyKind::Odbc, } } } @@ -78,6 +87,34 @@ impl AnyConnection { pub fn private_get_mut(&mut self) -> &mut AnyConnectionKind { &mut self.0 } + + /// Returns the runtime DBMS name for this connection. + /// + /// For most built-in drivers this returns a well-known constant string: + /// - Postgres -> "PostgreSQL" + /// - MySQL -> "MySQL" + /// - SQLite -> "SQLite" + /// - MSSQL -> "Microsoft SQL Server" + /// + /// For ODBC, this queries the driver at runtime via `SQL_DBMS_NAME`. + pub async fn dbms_name(&mut self) -> Result { + match &mut self.0 { + #[cfg(feature = "postgres")] + AnyConnectionKind::Postgres(_) => Ok("PostgreSQL".to_string()), + + #[cfg(feature = "mysql")] + AnyConnectionKind::MySql(_) => Ok("MySQL".to_string()), + + #[cfg(feature = "sqlite")] + AnyConnectionKind::Sqlite(_) => Ok("SQLite".to_string()), + + #[cfg(feature = "mssql")] + AnyConnectionKind::Mssql(_) => Ok("Microsoft SQL Server".to_string()), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.dbms_name().await, + } + } } macro_rules! delegate_to { @@ -94,6 +131,9 @@ macro_rules! delegate_to { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.$method($($arg),*), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.$method($($arg),*), } }; } @@ -112,6 +152,9 @@ macro_rules! delegate_to_mut { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.$method($($arg),*), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.$method($($arg),*), } }; } @@ -134,6 +177,9 @@ impl Connection for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.close(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.close(), } } @@ -150,6 +196,9 @@ impl Connection for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(conn) => conn.close_hard(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => conn.close_hard(), } } @@ -178,6 +227,10 @@ impl Connection for AnyConnection { // no cache #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_) => 0, + + // no cache + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_) => 0, } } @@ -195,6 +248,10 @@ impl Connection for AnyConnection { // no cache #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_) => Box::pin(futures_util::future::ok(())), + + // no cache + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_) => Box::pin(futures_util::future::ok(())), } } @@ -236,3 +293,10 @@ impl From for AnyConnection { AnyConnection(AnyConnectionKind::Sqlite(conn)) } } + +#[cfg(feature = "odbc")] +impl From for AnyConnection { + fn from(conn: odbc::OdbcConnection) -> Self { + AnyConnection(AnyConnectionKind::Odbc(conn)) + } +} diff --git a/sqlx-core/src/any/decode.rs b/sqlx-core/src/any/decode.rs index 28d1872f6e..e90e0e2b95 100644 --- a/sqlx-core/src/any/decode.rs +++ b/sqlx-core/src/any/decode.rs @@ -1,6 +1,9 @@ use crate::decode::Decode; use crate::types::Type; +#[cfg(feature = "odbc")] +use crate::odbc::Odbc; + #[cfg(feature = "postgres")] use crate::postgres::Postgres; @@ -44,320 +47,37 @@ macro_rules! impl_any_decode { crate::any::value::AnyValueRefKind::Postgres(value) => { <$ty as crate::decode::Decode<'r, crate::postgres::Postgres>>::decode(value) } + + #[cfg(feature = "odbc")] + crate::any::value::AnyValueRefKind::Odbc(value) => { + <$ty as crate::decode::Decode<'r, crate::odbc::Odbc>>::decode(value) + } } } } }; } -// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] -// to trait bounds - -// all 4 - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -// only 3 (4) - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> - + Type - + Decode<'r, Mssql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> - + Type - + Decode<'r, Mssql> - + Type - + Decode<'r, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -pub trait AnyDecode<'r>: - Decode<'r, Sqlite> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Sqlite> - + Type - + Decode<'r, MySql> - + Type - + Decode<'r, Mssql> - + Type -{ -} - -// only 2 (6) - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> + Type + Decode<'r, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> + Type + Decode<'r, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> + Type + Decode<'r, Mssql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> + Type + Decode<'r, Mssql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -pub trait AnyDecode<'r>: - Decode<'r, Postgres> + Type + Decode<'r, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Postgres> + Type + Decode<'r, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -pub trait AnyDecode<'r>: Decode<'r, Mssql> + Type + Decode<'r, MySql> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Mssql> + Type + Decode<'r, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -pub trait AnyDecode<'r>: - Decode<'r, Mssql> + Type + Decode<'r, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, Mssql> + Type + Decode<'r, Sqlite> + Type -{ -} +// Callback macro that generates the actual trait and impl +macro_rules! impl_any_decode_for_databases { + ($($db:ident),+) => { + pub trait AnyDecode<'r>: $(Decode<'r, $db> + Type<$db> +)+ 'r {} -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -pub trait AnyDecode<'r>: - Decode<'r, MySql> + Type + Decode<'r, Sqlite> + Type -{ + impl<'r, T> AnyDecode<'r> for T + where + T: $(Decode<'r, $db> + Type<$db> +)+ 'r + {} + }; } -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -impl<'r, T> AnyDecode<'r> for T where - T: Decode<'r, MySql> + Type + Decode<'r, Sqlite> + Type -{ +// Generate all combinations +for_all_feature_combinations! { + entries: [ + ("postgres", Postgres), + ("mysql", MySql), + ("mssql", Mssql), + ("sqlite", Sqlite), + ("odbc", Odbc), + ], + callback: impl_any_decode_for_databases } - -// only 1 (4) - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -pub trait AnyDecode<'r>: Decode<'r, Postgres> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Postgres> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -pub trait AnyDecode<'r>: Decode<'r, MySql> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, MySql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -pub trait AnyDecode<'r>: Decode<'r, Mssql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Mssql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -pub trait AnyDecode<'r>: Decode<'r, Sqlite> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -impl<'r, T> AnyDecode<'r> for T where T: Decode<'r, Sqlite> + Type {} diff --git a/sqlx-core/src/any/encode.rs b/sqlx-core/src/any/encode.rs index edde3bcd70..2ddf0d89ab 100644 --- a/sqlx-core/src/any/encode.rs +++ b/sqlx-core/src/any/encode.rs @@ -1,6 +1,9 @@ use crate::encode::Encode; use crate::types::Type; +#[cfg(feature = "odbc")] +use crate::odbc::Odbc; + #[cfg(feature = "postgres")] use crate::postgres::Postgres; @@ -39,6 +42,15 @@ macro_rules! impl_any_encode { #[cfg(feature = "sqlite")] crate::any::arguments::AnyArgumentBufferKind::Sqlite(args) => args.add(self), + + #[cfg(feature = "odbc")] + crate::any::arguments::AnyArgumentBufferKind::Odbc(args, _) => { + let _ = + <$ty as crate::encode::Encode<'q, crate::odbc::Odbc>>::encode_by_ref( + self, + &mut args.values, + ); + } } // unused @@ -48,314 +60,26 @@ macro_rules! impl_any_encode { }; } -// FIXME: Find a nice way to auto-generate the below or petition Rust to add support for #[cfg] -// to trait bounds - -// all 4 - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -#[cfg(all( - feature = "postgres", - feature = "mysql", - feature = "mssql", - feature = "sqlite" -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -// only 3 (4) - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mssql"), - all(feature = "postgres", feature = "mysql", feature = "sqlite") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> - + Type - + Encode<'q, Mssql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "mysql"), - all(feature = "postgres", feature = "mssql", feature = "sqlite") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> - + Type - + Encode<'q, Mssql> - + Type - + Encode<'q, Sqlite> - + Type -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "sqlite"), - all(feature = "postgres", feature = "mysql", feature = "mssql") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -pub trait AnyEncode<'q>: - Encode<'q, Sqlite> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type -{ -} - -#[cfg(all( - not(feature = "postgres"), - all(feature = "sqlite", feature = "mysql", feature = "mssql") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Sqlite> - + Type - + Encode<'q, MySql> - + Type - + Encode<'q, Mssql> - + Type -{ -} - -// only 2 (6) - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> + Type + Encode<'q, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mssql", feature = "sqlite")), - all(feature = "postgres", feature = "mysql") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> + Type + Encode<'q, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> + Type + Encode<'q, Mssql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "sqlite")), - all(feature = "postgres", feature = "mssql") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> + Type + Encode<'q, Mssql> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -pub trait AnyEncode<'q>: - Encode<'q, Postgres> + Type + Encode<'q, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql")), - all(feature = "postgres", feature = "sqlite") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Postgres> + Type + Encode<'q, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -pub trait AnyEncode<'q>: Encode<'q, Mssql> + Type + Encode<'q, MySql> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "sqlite")), - all(feature = "mssql", feature = "mysql") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Mssql> + Type + Encode<'q, MySql> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -pub trait AnyEncode<'q>: - Encode<'q, Mssql> + Type + Encode<'q, Sqlite> + Type -{ -} - -#[cfg(all( - not(any(feature = "postgres", feature = "mysql")), - all(feature = "mssql", feature = "sqlite") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, Mssql> + Type + Encode<'q, Sqlite> + Type -{ -} +// Callback macro that generates the actual trait and impl +macro_rules! impl_any_encode_for_databases { + ($($db:ident),+) => { + pub trait AnyEncode<'q>: $(Encode<'q, $db> + Type<$db> +)+ Send {} -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -pub trait AnyEncode<'q>: - Encode<'q, MySql> + Type + Encode<'q, Sqlite> + Type -{ + impl<'q, T> AnyEncode<'q> for T + where + T: $(Encode<'q, $db> + Type<$db> +)+ Send + {} + }; } -#[cfg(all( - not(any(feature = "postgres", feature = "mssql")), - all(feature = "mysql", feature = "sqlite") -))] -impl<'q, T> AnyEncode<'q> for T where - T: Encode<'q, MySql> + Type + Encode<'q, Sqlite> + Type -{ +// Generate all combinations +for_all_feature_combinations! { + entries: [ + ("postgres", Postgres), + ("mysql", MySql), + ("mssql", Mssql), + ("sqlite", Sqlite), + ("odbc", Odbc), + ], + callback: impl_any_encode_for_databases } - -// only 1 (4) - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -pub trait AnyEncode<'q>: Encode<'q, Postgres> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "sqlite")), - feature = "postgres" -))] -impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Postgres> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -pub trait AnyEncode<'q>: Encode<'q, MySql> + Type {} - -#[cfg(all( - not(any(feature = "postgres", feature = "mssql", feature = "sqlite")), - feature = "mysql" -))] -impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, MySql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -pub trait AnyEncode<'q>: Encode<'q, Mssql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "postgres", feature = "sqlite")), - feature = "mssql" -))] -impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Mssql> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -pub trait AnyEncode<'q>: Encode<'q, Sqlite> + Type {} - -#[cfg(all( - not(any(feature = "mysql", feature = "mssql", feature = "postgres")), - feature = "sqlite" -))] -impl<'q, T> AnyEncode<'q> for T where T: Encode<'q, Sqlite> + Type {} diff --git a/sqlx-core/src/any/feature_combinations.rs b/sqlx-core/src/any/feature_combinations.rs new file mode 100644 index 0000000000..0b2730b3e2 --- /dev/null +++ b/sqlx-core/src/any/feature_combinations.rs @@ -0,0 +1,34 @@ +// Shared recursive macro to generate all non-empty combinations of feature flags. +// Pass a list of entries with a feature name and an arbitrary payload which is +// forwarded to the callback when that feature is selected. +// +// Usage: +// for_all_feature_combinations!{ +// entries: [("postgres", Postgres), ("mysql", MySql)], +// callback: my_callback +// } +// will expand to (for the active feature configuration): +// #[cfg(all(feature="postgres"), not(feature="mysql"))] my_callback!(Postgres); +// #[cfg(all(feature="mysql"), not(feature="postgres"))] my_callback!(MySql); +// #[cfg(all(feature="postgres", feature="mysql"))] my_callback!(Postgres, MySql); +// and so on for all non-empty subsets. +#[macro_export] +macro_rules! for_all_feature_combinations { + ( entries: [ $( ( $feat:literal, $payload:tt ) ),* $(,)? ], callback: $callback:ident ) => { + $crate::for_all_feature_combinations!(@recurse [] [] [ $( ( $feat, $payload ) )* ] $callback); + }; + + (@recurse [$($yes:tt)*] [$($no:tt)*] [ ( $feat:literal, $payload:tt ) $($rest:tt)* ] $callback:ident ) => { + $crate::for_all_feature_combinations!(@recurse [ $($yes)* ( $feat, $payload ) ] [ $($no)* ] [ $($rest)* ] $callback); + $crate::for_all_feature_combinations!(@recurse [ $($yes)* ] [ $($no)* $feat ] [ $($rest)* ] $callback); + }; + + // Base case: at least one selected + (@recurse [ $( ( $yfeat:literal, $ypayload:tt ) )+ ] [ $( $nfeat:literal )* ] [] $callback:ident ) => { + #[cfg(all( $( feature = $yfeat ),+ $(, not(feature = $nfeat ))* ))] + $callback!( $( $ypayload ),+ ); + }; + + // Base case: none selected (skip) + (@recurse [] [ $( $nfeat:literal )* ] [] $callback:ident ) => {}; +} diff --git a/sqlx-core/src/any/kind.rs b/sqlx-core/src/any/kind.rs index b8e7b3fb50..2797c9e0ba 100644 --- a/sqlx-core/src/any/kind.rs +++ b/sqlx-core/src/any/kind.rs @@ -14,6 +14,9 @@ pub enum AnyKind { #[cfg(feature = "mssql")] Mssql, + + #[cfg(feature = "odbc")] + Odbc, } impl FromStr for AnyKind { @@ -61,7 +64,27 @@ impl FromStr for AnyKind { Err(Error::Configuration("database URL has the scheme of a MSSQL database but the `mssql` feature is not enabled".into())) } + #[cfg(feature = "odbc")] + _ if url.starts_with("odbc:") || Self::is_odbc_connection_string(url) => { + Ok(AnyKind::Odbc) + } + + #[cfg(not(feature = "odbc"))] + _ if url.starts_with("odbc:") || Self::is_odbc_connection_string(url) => { + Err(Error::Configuration("database URL has the scheme of an ODBC database but the `odbc` feature is not enabled".into())) + } + _ => Err(Error::Configuration(format!("unrecognized database url: {:?}", url).into())) } } } + +impl AnyKind { + fn is_odbc_connection_string(s: &str) -> bool { + let s_upper = s.to_uppercase(); + s_upper.starts_with("DSN=") + || s_upper.starts_with("DRIVER=") + || s_upper.starts_with("FILEDSN=") + || (s_upper.contains("DRIVER=") && s_upper.contains(';')) + } +} diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index 15458d57bf..3de37030db 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -22,6 +22,9 @@ impl MigrateDatabase for Any { #[cfg(feature = "mssql")] AnyKind::Mssql => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyKind::Odbc => unimplemented!(), } }) } @@ -40,6 +43,9 @@ impl MigrateDatabase for Any { #[cfg(feature = "mssql")] AnyKind::Mssql => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyKind::Odbc => unimplemented!(), } }) } @@ -58,6 +64,9 @@ impl MigrateDatabase for Any { #[cfg(feature = "mssql")] AnyKind::Mssql => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyKind::Odbc => unimplemented!(), } }) } @@ -77,6 +86,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -94,6 +106,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -110,6 +125,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -133,6 +151,12 @@ impl Migrate for AnyConnection { let _ = migration; unimplemented!() } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => { + let _ = migration; + unimplemented!() + } } } @@ -149,6 +173,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -165,6 +192,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -181,6 +211,9 @@ impl Migrate for AnyConnection { #[cfg(feature = "mssql")] AnyConnectionKind::Mssql(_conn) => unimplemented!(), + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => unimplemented!(), } } @@ -203,6 +236,12 @@ impl Migrate for AnyConnection { let _ = migration; unimplemented!() } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => { + let _ = migration; + unimplemented!() + } } } @@ -225,6 +264,12 @@ impl Migrate for AnyConnection { let _ = migration; unimplemented!() } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(_conn) => { + let _ = migration; + unimplemented!() + } } } } diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index 385c1f9cf1..f51fef7869 100644 --- a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -2,6 +2,9 @@ use crate::executor::Executor; +#[macro_use] +mod feature_combinations; + #[macro_use] mod decode; @@ -67,7 +70,7 @@ impl_into_maybe_pool!(Any, AnyConnection); // required because some databases have a different handling of NULL impl<'q, T> crate::encode::Encode<'q, Any> for Option where - T: AnyEncode<'q> + 'q, + T: AnyEncode<'q> + 'q + Sync, { fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer<'q>) -> crate::encode::IsNull { match &mut buf.0 { @@ -82,6 +85,14 @@ where #[cfg(feature = "sqlite")] arguments::AnyArgumentBufferKind::Sqlite(args) => args.add(self), + + #[cfg(feature = "odbc")] + arguments::AnyArgumentBufferKind::Odbc(args, _) => { + let _ = as crate::encode::Encode<'q, crate::odbc::Odbc>>::encode_by_ref( + self, + &mut args.values, + ); + } } // unused diff --git a/sqlx-core/src/any/options.rs b/sqlx-core/src/any/options.rs index 3e81198b1b..5ece96c891 100644 --- a/sqlx-core/src/any/options.rs +++ b/sqlx-core/src/any/options.rs @@ -18,6 +18,8 @@ use crate::sqlite::SqliteConnectOptions; use crate::any::kind::AnyKind; #[cfg(feature = "mssql")] use crate::mssql::MssqlConnectOptions; +#[cfg(feature = "odbc")] +use crate::odbc::OdbcConnectOptions; /// Opaque options for connecting to a database. These may only be constructed by parsing from /// a connection url. @@ -43,6 +45,9 @@ impl AnyConnectOptions { #[cfg(feature = "mssql")] AnyConnectOptionsKind::Mssql(_) => AnyKind::Mssql, + + #[cfg(feature = "odbc")] + AnyConnectOptionsKind::Odbc(_) => AnyKind::Odbc, } } } @@ -108,6 +113,9 @@ try_from_any_connect_options_to!( #[cfg(feature = "mssql")] try_from_any_connect_options_to!(MssqlConnectOptions, AnyConnectOptionsKind::Mssql, "mssql"); +#[cfg(feature = "odbc")] +try_from_any_connect_options_to!(OdbcConnectOptions, AnyConnectOptionsKind::Odbc, "odbc"); + #[derive(Debug, Clone)] pub(crate) enum AnyConnectOptionsKind { #[cfg(feature = "postgres")] @@ -121,6 +129,9 @@ pub(crate) enum AnyConnectOptionsKind { #[cfg(feature = "mssql")] Mssql(MssqlConnectOptions), + + #[cfg(feature = "odbc")] + Odbc(OdbcConnectOptions), } #[cfg(feature = "postgres")] @@ -151,6 +162,13 @@ impl From for AnyConnectOptions { } } +#[cfg(feature = "odbc")] +impl From for AnyConnectOptions { + fn from(options: OdbcConnectOptions) -> Self { + Self(AnyConnectOptionsKind::Odbc(options)) + } +} + impl FromStr for AnyConnectOptions { type Err = Error; @@ -171,6 +189,9 @@ impl FromStr for AnyConnectOptions { #[cfg(feature = "mssql")] AnyKind::Mssql => MssqlConnectOptions::from_str(url).map(AnyConnectOptionsKind::Mssql), + + #[cfg(feature = "odbc")] + AnyKind::Odbc => OdbcConnectOptions::from_str(url).map(AnyConnectOptionsKind::Odbc), } .map(AnyConnectOptions) } @@ -205,6 +226,11 @@ impl ConnectOptions for AnyConnectOptions { AnyConnectOptionsKind::Mssql(o) => { o.log_statements(level); } + + #[cfg(feature = "odbc")] + AnyConnectOptionsKind::Odbc(o) => { + o.log_statements(level); + } }; self } @@ -230,6 +256,11 @@ impl ConnectOptions for AnyConnectOptions { AnyConnectOptionsKind::Mssql(o) => { o.log_slow_statements(level, duration); } + + #[cfg(feature = "odbc")] + AnyConnectOptionsKind::Odbc(o) => { + o.log_slow_statements(level, duration); + } }; self } diff --git a/sqlx-core/src/any/row.rs b/sqlx-core/src/any/row.rs index b48f07b585..2a7ba4b2e5 100644 --- a/sqlx-core/src/any/row.rs +++ b/sqlx-core/src/any/row.rs @@ -21,6 +21,9 @@ use crate::sqlite::SqliteRow; #[cfg(feature = "mssql")] use crate::mssql::MssqlRow; +#[cfg(feature = "odbc")] +use crate::odbc::OdbcRow; + pub struct AnyRow { pub(crate) kind: AnyRowKind, pub(crate) columns: Vec, @@ -40,6 +43,9 @@ pub(crate) enum AnyRowKind { #[cfg(feature = "mssql")] Mssql(MssqlRow), + + #[cfg(feature = "odbc")] + Odbc(OdbcRow), } impl Row for AnyRow { @@ -70,6 +76,9 @@ impl Row for AnyRow { #[cfg(feature = "mssql")] AnyRowKind::Mssql(row) => row.try_get_raw(index).map(Into::into), + + #[cfg(feature = "odbc")] + AnyRowKind::Odbc(row) => row.try_get_raw(index).map(Into::into), } } @@ -110,6 +119,9 @@ where #[cfg(feature = "mssql")] AnyRowKind::Mssql(row) => self.index(row), + + #[cfg(feature = "odbc")] + AnyRowKind::Odbc(row) => self.index(row), } } } diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index 248e25847c..b61b679709 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -32,6 +32,11 @@ impl TransactionManager for AnyTransactionManager { AnyConnectionKind::Mssql(conn) => { ::TransactionManager::begin(conn) } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => { + ::TransactionManager::begin(conn) + } } } @@ -56,6 +61,11 @@ impl TransactionManager for AnyTransactionManager { AnyConnectionKind::Mssql(conn) => { ::TransactionManager::commit(conn) } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => { + ::TransactionManager::commit(conn) + } } } @@ -80,6 +90,11 @@ impl TransactionManager for AnyTransactionManager { AnyConnectionKind::Mssql(conn) => { ::TransactionManager::rollback(conn) } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => { + ::TransactionManager::rollback(conn) + } } } @@ -104,6 +119,11 @@ impl TransactionManager for AnyTransactionManager { AnyConnectionKind::Mssql(conn) => { ::TransactionManager::start_rollback(conn) } + + #[cfg(feature = "odbc")] + AnyConnectionKind::Odbc(conn) => { + ::TransactionManager::start_rollback(conn) + } } } } diff --git a/sqlx-core/src/any/type.rs b/sqlx-core/src/any/type.rs index 3df4136b65..1fc4dc53a3 100644 --- a/sqlx-core/src/any/type.rs +++ b/sqlx-core/src/any/type.rs @@ -33,6 +33,11 @@ macro_rules! impl_any_type { crate::any::type_info::AnyTypeInfoKind::Mssql(ty) => { <$ty as crate::types::Type>::compatible(&ty) } + + #[cfg(feature = "odbc")] + crate::any::type_info::AnyTypeInfoKind::Odbc(ty) => { + <$ty as crate::types::Type>::compatible(&ty) + } } } } diff --git a/sqlx-core/src/any/type_info.rs b/sqlx-core/src/any/type_info.rs index 789ad3bb06..60932429f1 100644 --- a/sqlx-core/src/any/type_info.rs +++ b/sqlx-core/src/any/type_info.rs @@ -14,6 +14,9 @@ use crate::sqlite::SqliteTypeInfo; #[cfg(feature = "mssql")] use crate::mssql::MssqlTypeInfo; +#[cfg(feature = "odbc")] +use crate::odbc::OdbcTypeInfo; + #[derive(Debug, Clone, PartialEq)] pub struct AnyTypeInfo(pub AnyTypeInfoKind); @@ -31,6 +34,9 @@ pub enum AnyTypeInfoKind { #[cfg(feature = "mssql")] Mssql(MssqlTypeInfo), + + #[cfg(feature = "odbc")] + Odbc(OdbcTypeInfo), } impl TypeInfo for AnyTypeInfo { @@ -47,6 +53,9 @@ impl TypeInfo for AnyTypeInfo { #[cfg(feature = "mssql")] AnyTypeInfoKind::Mssql(ty) => ty.is_null(), + + #[cfg(feature = "odbc")] + AnyTypeInfoKind::Odbc(ty) => ty.is_null(), } } @@ -63,6 +72,9 @@ impl TypeInfo for AnyTypeInfo { #[cfg(feature = "mssql")] AnyTypeInfoKind::Mssql(ty) => ty.name(), + + #[cfg(feature = "odbc")] + AnyTypeInfoKind::Odbc(ty) => ty.name(), } } } @@ -81,6 +93,9 @@ impl Display for AnyTypeInfo { #[cfg(feature = "mssql")] AnyTypeInfoKind::Mssql(ty) => ty.fmt(f), + + #[cfg(feature = "odbc")] + AnyTypeInfoKind::Odbc(ty) => ty.fmt(f), } } } diff --git a/sqlx-core/src/any/types.rs b/sqlx-core/src/any/types.rs index 6236e83ab0..c78958cbc5 100644 --- a/sqlx-core/src/any/types.rs +++ b/sqlx-core/src/any/types.rs @@ -22,6 +22,7 @@ impl_any_type!(bool); +impl_any_type!(i8); impl_any_type!(i16); impl_any_type!(i32); impl_any_type!(i64); @@ -40,6 +41,7 @@ impl_any_type!(u64); impl_any_encode!(bool); +impl_any_encode!(i8); impl_any_encode!(i16); impl_any_encode!(i32); impl_any_encode!(i64); @@ -58,6 +60,7 @@ impl_any_encode!(u64); impl_any_decode!(bool); +impl_any_decode!(i8); impl_any_decode!(i16); impl_any_decode!(i32); impl_any_decode!(i64); diff --git a/sqlx-core/src/any/value.rs b/sqlx-core/src/any/value.rs index 73dd01fdcf..23a06997c6 100644 --- a/sqlx-core/src/any/value.rs +++ b/sqlx-core/src/any/value.rs @@ -21,6 +21,9 @@ use crate::sqlite::{SqliteValue, SqliteValueRef}; #[cfg(feature = "mssql")] use crate::mssql::{MssqlValue, MssqlValueRef}; +#[cfg(feature = "odbc")] +use crate::odbc::{OdbcValue, OdbcValueRef}; + pub struct AnyValue { pub(crate) kind: AnyValueKind, pub(crate) type_info: AnyTypeInfo, @@ -38,6 +41,9 @@ pub(crate) enum AnyValueKind { #[cfg(feature = "mssql")] Mssql(MssqlValue), + + #[cfg(feature = "odbc")] + Odbc(OdbcValue), } pub struct AnyValueRef<'r> { @@ -57,6 +63,9 @@ pub(crate) enum AnyValueRefKind<'r> { #[cfg(feature = "mssql")] Mssql(MssqlValueRef<'r>), + + #[cfg(feature = "odbc")] + Odbc(OdbcValueRef<'r>), } impl Value for AnyValue { @@ -75,6 +84,9 @@ impl Value for AnyValue { #[cfg(feature = "mssql")] AnyValueKind::Mssql(value) => value.as_ref().into(), + + #[cfg(feature = "odbc")] + AnyValueKind::Odbc(value) => value.as_ref().into(), } } @@ -95,6 +107,9 @@ impl Value for AnyValue { #[cfg(feature = "mssql")] AnyValueKind::Mssql(value) => value.is_null(), + + #[cfg(feature = "odbc")] + AnyValueKind::Odbc(value) => value.is_null(), } } @@ -130,6 +145,9 @@ impl<'r> ValueRef<'r> for AnyValueRef<'r> { #[cfg(feature = "mssql")] AnyValueRefKind::Mssql(value) => ValueRef::to_owned(value).into(), + + #[cfg(feature = "odbc")] + AnyValueRefKind::Odbc(value) => ValueRef::to_owned(value).into(), } } @@ -150,6 +168,9 @@ impl<'r> ValueRef<'r> for AnyValueRef<'r> { #[cfg(feature = "mssql")] AnyValueRefKind::Mssql(value) => value.is_null(), + + #[cfg(feature = "odbc")] + AnyValueRefKind::Odbc(value) => value.is_null(), } } } diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index e670e3b4cd..6ff7de3564 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -55,6 +55,7 @@ impl + ?Sized> ColumnIndex for &'_ I { } } +#[allow(unused_macros)] macro_rules! impl_column_index_for_row { ($R:ident) => { impl crate::column::ColumnIndex<$R> for usize { @@ -71,6 +72,7 @@ macro_rules! impl_column_index_for_row { }; } +#[allow(unused_macros)] macro_rules! impl_column_index_for_statement { ($S:ident) => { impl crate::column::ColumnIndex<$S<'_>> for usize { diff --git a/sqlx-core/src/common/mod.rs b/sqlx-core/src/common/mod.rs index 63ed52815b..59bf8376f7 100644 --- a/sqlx-core/src/common/mod.rs +++ b/sqlx-core/src/common/mod.rs @@ -1,5 +1,6 @@ mod statement_cache; +#[allow(unused_imports)] pub(crate) use statement_cache::StatementCache; use std::fmt::{Debug, Formatter}; use std::ops::{Deref, DerefMut}; diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 8489b1127d..168b1d5779 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -83,7 +83,8 @@ pub mod migrate; feature = "postgres", feature = "mysql", feature = "mssql", - feature = "sqlite" + feature = "sqlite", + feature = "odbc" ), feature = "any" ))] @@ -105,6 +106,10 @@ pub mod mysql; #[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] pub mod mssql; +#[cfg(feature = "odbc")] +#[cfg_attr(docsrs, doc(cfg(feature = "odbc")))] +pub mod odbc; + // Implements test support with automatic DB management. #[cfg(feature = "migrate")] pub mod testing; @@ -112,5 +117,6 @@ pub mod testing; pub use sqlx_rt::test_block_on; /// sqlx uses ahash for increased performance, at the cost of reduced DoS resistance. +#[allow(unused_imports)] use ahash::AHashMap as HashMap; //type HashMap = std::collections::HashMap; diff --git a/sqlx-core/src/odbc/arguments.rs b/sqlx-core/src/odbc/arguments.rs new file mode 100644 index 0000000000..6dc371e7db --- /dev/null +++ b/sqlx-core/src/odbc/arguments.rs @@ -0,0 +1,72 @@ +use crate::arguments::Arguments; +use crate::encode::Encode; +use crate::odbc::Odbc; +use crate::types::Type; + +#[derive(Default, Debug)] +pub struct OdbcArguments { + pub(crate) values: Vec, +} + +#[derive(Debug, Clone)] +pub enum OdbcArgumentValue { + Text(String), + Bytes(Vec), + Int(i64), + Float(f64), + Null, +} + +impl<'q> Arguments<'q> for OdbcArguments { + type Database = Odbc; + + fn reserve(&mut self, additional: usize, _size: usize) { + self.values.reserve(additional); + } + + fn add(&mut self, value: T) + where + T: 'q + Send + Encode<'q, Self::Database> + Type, + { + let _ = value.encode(&mut self.values); + } +} + +// Encode implementations are now in the types module + +impl<'q, T> Encode<'q, Odbc> for Option +where + T: Encode<'q, Odbc> + Type + 'q, +{ + fn produces(&self) -> Option { + if let Some(v) = self { + v.produces() + } else { + T::type_info().into() + } + } + + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + match self { + Some(v) => v.encode(buf), + None => { + buf.push(OdbcArgumentValue::Null); + crate::encode::IsNull::Yes + } + } + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + match self { + Some(v) => v.encode_by_ref(buf), + None => { + buf.push(OdbcArgumentValue::Null); + crate::encode::IsNull::Yes + } + } + } + + fn size_hint(&self) -> usize { + self.as_ref().map_or(0, Encode::size_hint) + } +} diff --git a/sqlx-core/src/odbc/column.rs b/sqlx-core/src/odbc/column.rs new file mode 100644 index 0000000000..e127c16fa8 --- /dev/null +++ b/sqlx-core/src/odbc/column.rs @@ -0,0 +1,39 @@ +use crate::column::Column; +use crate::odbc::{Odbc, OdbcTypeInfo}; + +#[derive(Debug, Clone)] +pub struct OdbcColumn { + pub(crate) name: String, + pub(crate) type_info: OdbcTypeInfo, + pub(crate) ordinal: usize, +} + +impl Column for OdbcColumn { + type Database = Odbc; + + fn ordinal(&self) -> usize { + self.ordinal + } + fn name(&self) -> &str { + &self.name + } + fn type_info(&self) -> &OdbcTypeInfo { + &self.type_info + } +} + +#[cfg(feature = "any")] +impl From for crate::any::AnyColumn { + fn from(col: OdbcColumn) -> Self { + crate::any::AnyColumn { + kind: crate::any::column::AnyColumnKind::Odbc(col.clone()), + type_info: crate::any::AnyTypeInfo::from(col.type_info), + } + } +} + +mod private { + use super::OdbcColumn; + use crate::column::private_column::Sealed; + impl Sealed for OdbcColumn {} +} diff --git a/sqlx-core/src/odbc/connection/executor.rs b/sqlx-core/src/odbc/connection/executor.rs new file mode 100644 index 0000000000..45a32185f8 --- /dev/null +++ b/sqlx-core/src/odbc/connection/executor.rs @@ -0,0 +1,79 @@ +use crate::describe::Describe; +use crate::error::Error; +use crate::executor::{Execute, Executor}; +use crate::odbc::{Odbc, OdbcConnection, OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTypeInfo}; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::TryStreamExt; +use std::borrow::Cow; + +// run method removed; fetch_many implements streaming directly + +impl<'c> Executor<'c> for &'c mut OdbcConnection { + type Database = Odbc; + + fn fetch_many<'e, 'q: 'e, E>( + self, + mut query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database> + 'q, + { + let sql = query.sql().to_string(); + let args = query.take_arguments(); + Box::pin(try_stream! { + let rx = self.worker.execute_stream(&sql, args).await?; + while let Ok(item) = rx.recv_async().await { + r#yield!(item?); + } + Ok(()) + }) + } + + fn fetch_optional<'e, 'q: 'e, E>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database> + 'q, + { + let mut s = self.fetch_many(query); + Box::pin(async move { + while let Some(v) = s.try_next().await? { + if let Either::Right(r) = v { + return Ok(Some(r)); + } + } + Ok(None) + }) + } + + fn prepare_with<'e, 'q: 'e>( + self, + sql: &'q str, + _parameters: &'e [OdbcTypeInfo], + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { + let (_, columns, parameters) = self.worker.prepare(sql).await?; + Ok(OdbcStatement { + sql: Cow::Borrowed(sql), + columns, + parameters, + }) + }) + } + + #[doc(hidden)] + fn describe<'e, 'q: 'e>(self, _sql: &'q str) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { Err(Error::Protocol("ODBC describe not implemented".into())) }) + } +} diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs new file mode 100644 index 0000000000..fc9751bae0 --- /dev/null +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -0,0 +1,77 @@ +use crate::connection::{Connection, LogSettings}; +use crate::error::Error; +use crate::odbc::{Odbc, OdbcConnectOptions}; +use crate::transaction::Transaction; +use futures_core::future::BoxFuture; +use futures_util::future; + +mod executor; +mod worker; + +pub(crate) use worker::ConnectionWorker; + +/// A connection to an ODBC-accessible database. +/// +/// ODBC uses a blocking C API, so we run all calls on a dedicated background thread +/// and communicate over channels to provide async access. +#[derive(Debug)] +pub struct OdbcConnection { + pub(crate) worker: ConnectionWorker, + pub(crate) log_settings: LogSettings, +} + +impl OdbcConnection { + pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result { + let worker = ConnectionWorker::establish(options.clone()).await?; + Ok(Self { + worker, + log_settings: LogSettings::default(), + }) + } + + /// Returns the name of the actual Database Management System (DBMS) this + /// connection is talking to as reported by the ODBC driver. + /// + /// This calls the underlying ODBC API `SQL_DBMS_NAME` via + /// `odbc_api::Connection::database_management_system_name`. + /// + /// See: https://docs.rs/odbc-api/19.0.1/odbc_api/struct.Connection.html#method.database_management_system_name + pub async fn dbms_name(&mut self) -> Result { + self.worker.get_dbms_name().await + } +} + +impl Connection for OdbcConnection { + type Database = Odbc; + + type Options = OdbcConnectOptions; + + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { self.worker.shutdown().await }) + } + + fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { Ok(()) }) + } + + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(self.worker.ping()) + } + + fn begin(&mut self) -> BoxFuture<'_, Result, Error>> + where + Self: Sized, + { + Transaction::begin(self) + } + + #[doc(hidden)] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(future::ok(())) + } + + #[doc(hidden)] + fn should_flush(&self) -> bool { + false + } +} diff --git a/sqlx-core/src/odbc/connection/worker.rs b/sqlx-core/src/odbc/connection/worker.rs new file mode 100644 index 0000000000..b2e7f8b2db --- /dev/null +++ b/sqlx-core/src/odbc/connection/worker.rs @@ -0,0 +1,767 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::thread; + +use flume::{SendError, TrySendError}; +use futures_channel::oneshot; + +use crate::error::Error; +use crate::odbc::{ + OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, + OdbcTypeInfo, +}; +#[allow(unused_imports)] +use crate::row::Row as SqlxRow; +use either::Either; +#[allow(unused_imports)] +use odbc_api::handles::Statement as OdbcStatementTrait; +use odbc_api::handles::StatementImpl; +use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata}; + +// Type aliases for commonly used types +type OdbcConnection = odbc_api::Connection<'static>; +type TransactionResult = Result<(), Error>; +type TransactionSender = oneshot::Sender; +type ExecuteResult = Result, Error>; +type ExecuteSender = flume::Sender; +type PrepareResult = Result<(u64, Vec, usize), Error>; +type PrepareSender = oneshot::Sender; + +#[derive(Debug)] +pub(crate) struct ConnectionWorker { + command_tx: flume::Sender, + join_handle: Option>, +} + +#[derive(Debug)] +enum Command { + Ping { + tx: oneshot::Sender<()>, + }, + Shutdown { + tx: oneshot::Sender<()>, + }, + Begin { + tx: TransactionSender, + }, + Commit { + tx: TransactionSender, + }, + Rollback { + tx: TransactionSender, + }, + Execute { + sql: Box, + args: Option, + tx: ExecuteSender, + }, + Prepare { + sql: Box, + tx: PrepareSender, + }, + GetDbmsName { + tx: oneshot::Sender>, + }, +} + +impl Drop for ConnectionWorker { + fn drop(&mut self) { + self.shutdown_sync(); + } +} + +impl ConnectionWorker { + pub async fn establish(options: OdbcConnectOptions) -> Result { + let (command_tx, command_rx) = flume::bounded(64); + let (conn_tx, conn_rx) = oneshot::channel(); + let thread = thread::Builder::new() + .name("sqlx-odbc-conn".into()) + .spawn(move || worker_thread_main(options, command_rx, conn_tx))?; + + conn_rx.await.map_err(|_| Error::WorkerCrashed)??; + Ok(ConnectionWorker { + command_tx, + join_handle: Some(thread), + }) + } + + pub(crate) async fn ping(&mut self) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + send_command_and_await(&self.command_tx, Command::Ping { tx }, rx).await + } + + pub(crate) async fn shutdown(&mut self) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + send_command_and_await(&self.command_tx, Command::Shutdown { tx }, rx).await + } + + pub(crate) fn shutdown_sync(&mut self) { + // Send shutdown command without waiting for response + // Use try_send to avoid any potential blocking in Drop + + if let Some(join_handle) = self.join_handle.take() { + let (mut tx, _rx) = oneshot::channel(); + while let Err(TrySendError::Full(Command::Shutdown { tx: t })) = + self.command_tx.try_send(Command::Shutdown { tx }) + { + tx = t; + log::warn!("odbc worker thread queue is full, retrying..."); + thread::sleep(std::time::Duration::from_millis(10)); + } + if let Err(e) = join_handle.join() { + let err = e.downcast_ref::(); + log::error!( + "failed to join worker thread while shutting down: {:?}", + err + ); + } + } + } + + pub(crate) async fn begin(&mut self) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + send_transaction_command(&self.command_tx, Command::Begin { tx }, rx).await + } + + pub(crate) async fn commit(&mut self) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + send_transaction_command(&self.command_tx, Command::Commit { tx }, rx).await + } + + pub(crate) async fn rollback(&mut self) -> Result<(), Error> { + let (tx, rx) = oneshot::channel(); + send_transaction_command(&self.command_tx, Command::Rollback { tx }, rx).await + } + + pub(crate) async fn execute_stream( + &mut self, + sql: &str, + args: Option, + ) -> Result, Error>>, Error> { + let (tx, rx) = flume::bounded(64); + self.command_tx + .send_async(Command::Execute { + sql: sql.into(), + args, + tx, + }) + .await + .map_err(|_| Error::WorkerCrashed)?; + Ok(rx) + } + + pub(crate) async fn prepare( + &mut self, + sql: &str, + ) -> Result<(u64, Vec, usize), Error> { + let (tx, rx) = oneshot::channel(); + send_command_and_await( + &self.command_tx, + Command::Prepare { + sql: sql.into(), + tx, + }, + rx, + ) + .await? + } + + pub(crate) async fn get_dbms_name(&mut self) -> Result { + let (tx, rx) = oneshot::channel(); + send_command_and_await(&self.command_tx, Command::GetDbmsName { tx }, rx).await? + } +} + +// Worker thread implementation +fn worker_thread_main( + options: OdbcConnectOptions, + command_rx: flume::Receiver, + conn_tx: oneshot::Sender>, +) { + // Establish connection + let conn = match establish_connection(&options) { + Ok(conn) => { + log::debug!("ODBC connection established successfully"); + let _ = conn_tx.send(Ok(())); + conn + } + Err(e) => { + let _ = conn_tx.send(Err(e)); + return; + } + }; + + let mut stmt_manager = StatementManager::new(&conn); + + // Process commands + while let Ok(cmd) = command_rx.recv() { + log::trace!("Processing command: {:?}", cmd); + match process_command(cmd, &conn, &mut stmt_manager) { + Ok(CommandControlFlow::Continue) => {} + Ok(CommandControlFlow::Shutdown(shutdown_tx)) => { + log::debug!("Shutting down ODBC worker thread"); + drop(stmt_manager); + drop(conn); + send_oneshot(shutdown_tx, (), "shutdown"); + break; + } + Err(()) => { + log::error!("ODBC worker error while processing command"); + } + } + } + // Channel disconnected or shutdown command received, worker thread exits +} + +fn establish_connection(options: &OdbcConnectOptions) -> Result { + // Get or create the shared ODBC environment + // This ensures thread-safe initialization and prevents concurrent environment creation issues + let env = odbc_api::environment().map_err(|e| Error::Configuration(e.to_string().into()))?; + + let conn = env + .connect_with_connection_string(options.connection_string(), Default::default()) + .map_err(|e| Error::Configuration(e.to_string().into()))?; + + Ok(conn) +} + +/// Statement manager to handle preallocated statements +struct StatementManager<'conn> { + conn: &'conn OdbcConnection, + // Reusable statement for direct execution + direct_stmt: Option>>, + // Cache of prepared statements by SQL hash + prepared_cache: HashMap>>, +} + +impl<'conn> StatementManager<'conn> { + fn new(conn: &'conn OdbcConnection) -> Self { + log::debug!("Creating new statement manager"); + Self { + conn, + direct_stmt: None, + prepared_cache: HashMap::new(), + } + } + + fn get_or_create_direct_stmt( + &mut self, + ) -> Result<&mut Preallocated>, Error> { + if self.direct_stmt.is_none() { + log::debug!("Preallocating ODBC direct statement"); + self.direct_stmt = Some(self.conn.preallocate().map_err(Error::from)?); + } + Ok(self.direct_stmt.as_mut().unwrap()) + } + + fn get_or_create_prepared( + &mut self, + sql: &str, + ) -> Result<&mut odbc_api::Prepared>, Error> { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + sql.hash(&mut hasher); + let sql_hash = hasher.finish(); + + match self.prepared_cache.entry(sql_hash) { + Entry::Vacant(e) => { + log::trace!("Preparing statement for SQL: {}", sql); + let prepared = self.conn.prepare(sql)?; + Ok(e.insert(prepared)) + } + Entry::Occupied(e) => { + log::trace!("Using prepared statement for SQL: {}", sql); + Ok(e.into_mut()) + } + } + } +} +// Helper function to send results through oneshot channels with consistent error handling +fn send_oneshot(tx: oneshot::Sender, result: T, operation: &str) { + if tx.send(result).is_err() { + log::warn!("Failed to send {} result: receiver dropped", operation); + } +} + +fn send_stream_result( + tx: &ExecuteSender, + result: ExecuteResult, +) -> Result<(), SendError> { + tx.send(result) +} + +async fn send_command_and_await( + command_tx: &flume::Sender, + cmd: Command, + rx: oneshot::Receiver, +) -> Result { + command_tx + .send_async(cmd) + .await + .map_err(|_| Error::WorkerCrashed)?; + rx.await.map_err(|_| Error::WorkerCrashed) +} + +async fn send_transaction_command( + command_tx: &flume::Sender, + cmd: Command, + rx: oneshot::Receiver, +) -> Result<(), Error> { + send_command_and_await(command_tx, cmd, rx).await??; + Ok(()) +} + +// Utility functions for transaction operations +fn execute_transaction_operation( + conn: &OdbcConnection, + operation: F, + operation_name: &str, +) -> TransactionResult +where + F: FnOnce(&OdbcConnection) -> Result<(), odbc_api::Error>, +{ + log::trace!("{} odbc transaction", operation_name); + operation(conn) + .map_err(|e| Error::Protocol(format!("Failed to {} transaction: {}", operation_name, e))) +} + +#[derive(Debug)] +enum CommandControlFlow { + Shutdown(oneshot::Sender<()>), + Continue, +} + +type CommandResult = Result; + +// Returns a shutdown tx if the command is a shutdown command +fn process_command<'conn>( + cmd: Command, + conn: &'conn OdbcConnection, + stmt_manager: &mut StatementManager<'conn>, +) -> CommandResult { + match cmd { + Command::Ping { tx } => handle_ping(conn, tx), + Command::Begin { tx } => handle_begin(conn, tx), + Command::Commit { tx } => handle_commit(conn, tx), + Command::Rollback { tx } => handle_rollback(conn, tx), + Command::Shutdown { tx } => Ok(CommandControlFlow::Shutdown(tx)), + Command::Execute { sql, args, tx } => handle_execute(stmt_manager, sql, args, tx), + Command::Prepare { sql, tx } => handle_prepare(stmt_manager, sql, tx), + Command::GetDbmsName { tx } => handle_get_dbms_name(conn, tx), + } +} + +// Command handlers +fn handle_ping(conn: &OdbcConnection, tx: oneshot::Sender<()>) -> CommandResult { + match conn.execute("SELECT 1", (), None) { + Ok(_) => send_oneshot(tx, (), "ping"), + Err(e) => log::error!("Ping failed: {}", e), + } + Ok(CommandControlFlow::Continue) +} + +fn handle_begin(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { + let result = execute_transaction_operation(conn, |c| c.set_autocommit(false), "begin"); + send_oneshot(tx, result, "begin transaction"); + Ok(CommandControlFlow::Continue) +} + +fn handle_commit(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { + let result = execute_transaction_operation( + conn, + |c| c.commit().and_then(|_| c.set_autocommit(true)), + "commit", + ); + send_oneshot(tx, result, "commit transaction"); + Ok(CommandControlFlow::Continue) +} + +fn handle_rollback(conn: &OdbcConnection, tx: TransactionSender) -> CommandResult { + let result = execute_transaction_operation( + conn, + |c| c.rollback().and_then(|_| c.set_autocommit(true)), + "rollback", + ); + send_oneshot(tx, result, "rollback transaction"); + Ok(CommandControlFlow::Continue) +} +fn handle_prepare<'conn>( + stmt_manager: &mut StatementManager<'conn>, + sql: Box, + tx: PrepareSender, +) -> CommandResult { + let result = do_prepare(stmt_manager, sql); + send_oneshot(tx, result, "prepare"); + Ok(CommandControlFlow::Continue) +} + +fn do_prepare<'conn>(stmt_manager: &mut StatementManager<'conn>, sql: Box) -> PrepareResult { + log::trace!("Preparing statement: {}", sql); + // Use the statement manager to get or create the prepared statement + let prepared = stmt_manager.get_or_create_prepared(&sql)?; + let columns = collect_columns(prepared); + let params = usize::from(prepared.num_params().unwrap_or(0)); + log::debug!( + "Prepared statement with {} columns and {} parameters", + columns.len(), + params + ); + Ok((0, columns, params)) +} + +fn handle_get_dbms_name( + conn: &OdbcConnection, + tx: oneshot::Sender>, +) -> CommandResult { + log::debug!("Getting DBMS name"); + let result = conn + .database_management_system_name() + .map_err(|e| Error::Protocol(format!("Failed to get DBMS name: {}", e))); + send_oneshot(tx, result, "DBMS name"); + Ok(CommandControlFlow::Continue) +} + +fn handle_execute<'conn>( + stmt_manager: &mut StatementManager<'conn>, + sql: Box, + args: Option, + tx: ExecuteSender, +) -> CommandResult { + log::debug!( + "Executing SQL: {}", + sql.chars().take(100).collect::() + ); + + let result = execute_sql(stmt_manager, &sql, args, &tx); + with_result_send_error(result, &tx, |_| {}); + Ok(CommandControlFlow::Continue) +} + +// SQL execution functions +fn execute_sql<'conn>( + stmt_manager: &mut StatementManager<'conn>, + sql: &str, + args: Option, + tx: &ExecuteSender, +) -> Result<(), Error> { + let params = prepare_parameters(args); + let stmt = stmt_manager.get_or_create_direct_stmt()?; + log::trace!("Starting execution of SQL: {}", sql); + let cursor_result = stmt.execute(sql, ¶ms[..]); + log::trace!("Received cursor result for SQL: {}", sql); + send_exec_result(cursor_result, tx)?; + Ok(()) +} + +// Unified execution logic for both direct and prepared statements +fn send_exec_result( + execution_result: Result, odbc_api::Error>, + tx: &ExecuteSender, +) -> Result<(), Error> +where + C: Cursor + ResultSetMetadata, +{ + match execution_result { + Ok(Some(mut cursor)) => { + handle_cursor(&mut cursor, tx); + Ok(()) + } + Ok(None) => { + let _ = send_done(tx, 0); + Ok(()) + } + Err(e) => Err(Error::from(e)), + } +} + +fn prepare_parameters( + args: Option, +) -> Vec> { + let args = args.map(|a| a.values).unwrap_or_default(); + args.into_iter().map(to_param).collect() +} + +fn to_param(arg: OdbcArgumentValue) -> Box { + match arg { + OdbcArgumentValue::Int(i) => Box::new(i.into_parameter()), + OdbcArgumentValue::Float(f) => Box::new(f.into_parameter()), + OdbcArgumentValue::Text(s) => Box::new(s.into_parameter()), + OdbcArgumentValue::Bytes(b) => Box::new(b.into_parameter()), + OdbcArgumentValue::Null => Box::new(Option::::None.into_parameter()), + } +} + +fn handle_cursor(cursor: &mut C, tx: &ExecuteSender) +where + C: Cursor + ResultSetMetadata, +{ + let columns = collect_columns(cursor); + log::trace!("Processing ODBC result set with {} columns", columns.len()); + + match stream_rows(cursor, &columns, tx) { + Ok(true) => { + log::trace!("Successfully streamed all rows"); + let _ = send_done(tx, 0); + } + Ok(false) => { + log::trace!("Row streaming stopped early (receiver closed)"); + } + Err(e) => { + send_error(tx, e); + } + } +} + +// Unified result sending functions +fn send_done(tx: &ExecuteSender, rows_affected: u64) -> Result<(), SendError> { + send_stream_result(tx, Ok(Either::Left(OdbcQueryResult { rows_affected }))) +} + +fn with_result_send_error( + result: Result, + tx: &ExecuteSender, + handler: impl FnOnce(T), +) { + match result { + Ok(result) => handler(result), + Err(error) => send_error(tx, error), + } +} + +fn send_error(tx: &ExecuteSender, error: Error) { + if let Err(e) = send_stream_result(tx, Err(error)) { + log::error!("Failed to send error from ODBC worker thread: {}", e); + } +} + +fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError> { + send_stream_result(tx, Ok(Either::Right(row))) +} + +// Metadata and row processing +fn collect_columns(cursor: &mut C) -> Vec +where + C: ResultSetMetadata, +{ + let count = cursor.num_result_cols().unwrap_or(0); + + (1..=count) + .map(|i| create_column(cursor, i as u16)) + .collect() +} + +fn create_column(cursor: &mut C, index: u16) -> OdbcColumn +where + C: ResultSetMetadata, +{ + let mut cd = odbc_api::ColumnDescription::default(); + let _ = cursor.describe_col(index, &mut cd); + + OdbcColumn { + name: decode_column_name(cd.name, index), + type_info: OdbcTypeInfo::new(cd.data_type), + ordinal: usize::from(index.checked_sub(1).unwrap()), + } +} + +fn decode_column_name(name_bytes: Vec, index: u16) -> String { + String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1)) +} + +fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result +where + C: Cursor, +{ + let mut receiver_open = true; + let mut row_count = 0; + + while let Some(mut row) = cursor.next_row()? { + let values = collect_row_values(&mut row, columns)?; + let row_data = OdbcRow { + columns: columns.to_vec(), + values: values.into_iter().map(|(_, value)| value).collect(), + }; + + if send_row(tx, row_data).is_err() { + log::debug!("Receiver closed after {} rows", row_count); + receiver_open = false; + break; + } + row_count += 1; + } + + if receiver_open { + log::debug!("Streamed {} rows successfully", row_count); + } + Ok(receiver_open) +} + +fn collect_row_values( + row: &mut CursorRow<'_>, + columns: &[OdbcColumn], +) -> Result, Error> { + columns + .iter() + .enumerate() + .map(|(i, column)| collect_column_value(row, i, column)) + .collect() +} + +fn collect_column_value( + row: &mut CursorRow<'_>, + index: usize, + column: &OdbcColumn, +) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { + use odbc_api::DataType; + + let col_idx = (index + 1) as u16; + let type_info = column.type_info.clone(); + let data_type = type_info.data_type(); + + // Extract value based on data type + let value = match data_type { + // Integer types + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Bit => extract_int(row, col_idx, &type_info)?, + + // Floating point types + DataType::Real => extract_float::(row, col_idx, &type_info)?, + DataType::Float { .. } | DataType::Double => { + extract_float::(row, col_idx, &type_info)? + } + + // String types + DataType::Char { .. } + | DataType::Varchar { .. } + | DataType::LongVarchar { .. } + | DataType::WChar { .. } + | DataType::WVarchar { .. } + | DataType::WLongVarchar { .. } + | DataType::Date + | DataType::Time { .. } + | DataType::Timestamp { .. } + | DataType::Decimal { .. } + | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, + + // Binary types + DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { + extract_binary(row, col_idx, &type_info)? + } + + // Unknown types - try text first, fall back to binary + DataType::Unknown | DataType::Other { .. } => { + match extract_text(row, col_idx, &type_info) { + Ok(v) => v, + Err(_) => extract_binary(row, col_idx, &type_info)?, + } + } + }; + + Ok((type_info, value)) +} + +fn extract_int( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, int) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v)), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int, + float: None, + }) +} + +fn extract_float( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result +where + T: Into + Default, + odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, +{ + let mut nullable = Nullable::::null(); + row.get_data(col_idx, &mut nullable)?; + + let (is_null, float) = match nullable.into_opt() { + None => (true, None), + Some(v) => (false, Some(v.into())), + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob: None, + int: None, + float, + }) +} + +fn extract_text( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_text(col_idx, &mut buf)?; + + let (is_null, text) = if !is_some { + (true, None) + } else { + match String::from_utf8(buf) { + Ok(s) => (false, Some(s)), + Err(e) => return Err(Error::Decode(e.into())), + } + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text, + blob: None, + int: None, + float: None, + }) +} + +fn extract_binary( + row: &mut CursorRow<'_>, + col_idx: u16, + type_info: &OdbcTypeInfo, +) -> Result { + let mut buf = Vec::new(); + let is_some = row.get_binary(col_idx, &mut buf)?; + + let (is_null, blob) = if !is_some { + (true, None) + } else { + (false, Some(buf)) + }; + + Ok(crate::odbc::OdbcValue { + type_info: type_info.clone(), + is_null, + text: None, + blob, + int: None, + float: None, + }) +} diff --git a/sqlx-core/src/odbc/database.rs b/sqlx-core/src/odbc/database.rs new file mode 100644 index 0000000000..b2bf81aca3 --- /dev/null +++ b/sqlx-core/src/odbc/database.rs @@ -0,0 +1,46 @@ +use crate::database::{Database, HasArguments, HasStatement, HasStatementCache, HasValueRef}; +use crate::odbc::{ + OdbcColumn, OdbcConnection, OdbcQueryResult, OdbcRow, OdbcStatement, OdbcTransactionManager, + OdbcTypeInfo, OdbcValue, OdbcValueRef, +}; + +#[derive(Debug)] +pub struct Odbc; + +impl Database for Odbc { + type Connection = OdbcConnection; + + type TransactionManager = OdbcTransactionManager; + + type Row = OdbcRow; + + type QueryResult = OdbcQueryResult; + + type Column = OdbcColumn; + + type TypeInfo = OdbcTypeInfo; + + type Value = OdbcValue; +} + +impl<'r> HasValueRef<'r> for Odbc { + type Database = Odbc; + + type ValueRef = OdbcValueRef<'r>; +} + +impl<'q> HasArguments<'q> for Odbc { + type Database = Odbc; + + type Arguments = crate::odbc::OdbcArguments; + + type ArgumentBuffer = Vec; +} + +impl<'q> HasStatement<'q> for Odbc { + type Database = Odbc; + + type Statement = OdbcStatement<'q>; +} + +impl HasStatementCache for Odbc {} diff --git a/sqlx-core/src/odbc/error.rs b/sqlx-core/src/odbc/error.rs new file mode 100644 index 0000000000..3d8141948a --- /dev/null +++ b/sqlx-core/src/odbc/error.rs @@ -0,0 +1,39 @@ +use crate::error::DatabaseError; +use odbc_api::Error as OdbcApiError; +use std::borrow::Cow; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +#[derive(Debug)] +pub struct OdbcDatabaseError(pub OdbcApiError); + +impl Display for OdbcDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + Display::fmt(&self.0, f) + } +} + +impl std::error::Error for OdbcDatabaseError {} + +impl DatabaseError for OdbcDatabaseError { + fn message(&self) -> &str { + "ODBC error" + } + fn code(&self) -> Option> { + None + } + fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { + self + } + fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) { + self + } + fn into_error(self: Box) -> Box { + self + } +} + +impl From for crate::error::Error { + fn from(value: OdbcApiError) -> Self { + crate::error::Error::Database(Box::new(OdbcDatabaseError(value))) + } +} diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs new file mode 100644 index 0000000000..da41adb1e9 --- /dev/null +++ b/sqlx-core/src/odbc/mod.rs @@ -0,0 +1,71 @@ +//! ODBC database driver (via `odbc-api`). +//! +//! ## Connection Strings +//! +//! When using the `Any` connection type, SQLx accepts standard ODBC connection strings: +//! +//! ```text +//! // DSN-based connection +//! DSN=MyDataSource;UID=myuser;PWD=mypassword +//! +//! // Driver-based connection +//! Driver={ODBC Driver 17 for SQL Server};Server=localhost;Database=test +//! +//! // File DSN +//! FILEDSN=/path/to/myfile.dsn +//! ``` +//! +//! The `odbc:` URL scheme prefix is optional but still supported for backward compatibility: +//! +//! ```text +//! odbc:DSN=MyDataSource +//! ``` + +use crate::executor::Executor; + +mod arguments; +mod column; +mod connection; +mod database; +mod error; +mod options; +mod query_result; +mod row; +mod statement; +mod transaction; +mod type_info; +pub mod types; +mod value; + +pub use arguments::{OdbcArgumentValue, OdbcArguments}; +pub use column::OdbcColumn; +pub use connection::OdbcConnection; +pub use database::Odbc; +pub use options::OdbcConnectOptions; +pub use query_result::OdbcQueryResult; +pub use row::OdbcRow; +pub use statement::OdbcStatement; +pub use transaction::OdbcTransactionManager; +pub use type_info::{DataTypeExt, OdbcTypeInfo}; +pub use value::{OdbcValue, OdbcValueRef}; + +/// An alias for [`Pool`][crate::pool::Pool], specialized for ODBC. +pub type OdbcPool = crate::pool::Pool; + +/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for ODBC. +pub type OdbcPoolOptions = crate::pool::PoolOptions; + +/// An alias for [`Executor<'_, Database = Odbc>`][Executor]. +pub trait OdbcExecutor<'c>: Executor<'c, Database = Odbc> {} +impl<'c, T: Executor<'c, Database = Odbc>> OdbcExecutor<'c> for T {} + +// NOTE: required due to the lack of lazy normalization +impl_into_arguments_for_arguments!(crate::odbc::OdbcArguments); +impl_executor_for_pool_connection!(Odbc, OdbcConnection, OdbcRow); +impl_executor_for_transaction!(Odbc, OdbcRow); +impl_column_index_for_row!(OdbcRow); +impl_column_index_for_statement!(OdbcStatement); +impl_acquire!(Odbc, OdbcConnection); +impl_into_maybe_pool!(Odbc, OdbcConnection); + +// custom Option<..> handling implemented in `arguments.rs` diff --git a/sqlx-core/src/odbc/options/mod.rs b/sqlx-core/src/odbc/options/mod.rs new file mode 100644 index 0000000000..19a217bfcc --- /dev/null +++ b/sqlx-core/src/odbc/options/mod.rs @@ -0,0 +1,77 @@ +use crate::connection::{ConnectOptions, LogSettings}; +use crate::error::Error; +use futures_core::future::BoxFuture; +use log::LevelFilter; +use std::fmt::{self, Debug, Formatter}; +use std::str::FromStr; +use std::time::Duration; + +use crate::odbc::OdbcConnection; + +#[derive(Clone)] +pub struct OdbcConnectOptions { + pub(crate) conn_str: String, + pub(crate) log_settings: LogSettings, +} + +impl OdbcConnectOptions { + pub fn connection_string(&self) -> &str { + &self.conn_str + } +} + +impl Debug for OdbcConnectOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("OdbcConnectOptions") + .field("conn_str", &"") + .finish() + } +} + +impl FromStr for OdbcConnectOptions { + type Err = Error; + + fn from_str(s: &str) -> Result { + // Accept forms: + // - "odbc:DSN=Name;..." -> strip scheme + // - "odbc:Name" -> interpret as DSN + // - "DSN=Name;..." or full ODBC connection string + let mut t = s.trim(); + if let Some(rest) = t.strip_prefix("odbc:") { + t = rest; + } + let conn_str = if t.contains('=') { + // Looks like an ODBC key=value connection string + t.to_string() + } else { + // Bare DSN name + format!("DSN={}", t) + }; + + Ok(Self { + conn_str, + log_settings: LogSettings::default(), + }) + } +} + +impl ConnectOptions for OdbcConnectOptions { + type Connection = OdbcConnection; + + fn connect(&self) -> BoxFuture<'_, Result> + where + Self::Connection: Sized, + { + Box::pin(OdbcConnection::establish(self)) + } + + fn log_statements(&mut self, level: LevelFilter) -> &mut Self { + self.log_settings.log_statements(level); + self + } + + fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self { + self.log_settings.log_slow_statements(level, duration); + self + } +} diff --git a/sqlx-core/src/odbc/query_result.rs b/sqlx-core/src/odbc/query_result.rs new file mode 100644 index 0000000000..282e75f6ea --- /dev/null +++ b/sqlx-core/src/odbc/query_result.rs @@ -0,0 +1,28 @@ +#[derive(Debug, Default)] +pub struct OdbcQueryResult { + pub(super) rows_affected: u64, +} + +impl OdbcQueryResult { + pub fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for OdbcQueryResult { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + } + } +} + +#[cfg(feature = "any")] +impl From for crate::any::AnyQueryResult { + fn from(result: OdbcQueryResult) -> Self { + crate::any::AnyQueryResult { + rows_affected: result.rows_affected, + last_insert_id: None, // ODBC doesn't provide last insert ID + } + } +} diff --git a/sqlx-core/src/odbc/row.rs b/sqlx-core/src/odbc/row.rs new file mode 100644 index 0000000000..cf7c823603 --- /dev/null +++ b/sqlx-core/src/odbc/row.rs @@ -0,0 +1,198 @@ +use crate::column::ColumnIndex; +use crate::database::HasValueRef; +use crate::error::Error; +use crate::odbc::{Odbc, OdbcColumn, OdbcValue}; +use crate::row::Row; +use crate::value::Value; + +#[derive(Debug, Clone)] +pub struct OdbcRow { + pub(crate) columns: Vec, + pub(crate) values: Vec, +} + +impl Row for OdbcRow { + type Database = Odbc; + + fn columns(&self) -> &[OdbcColumn] { + &self.columns + } + + fn try_get_raw( + &self, + index: I, + ) -> Result<>::ValueRef, Error> + where + I: ColumnIndex, + { + let idx = index.index(self)?; + let value = &self.values[idx]; + Ok(value.as_ref()) + } +} + +impl ColumnIndex for &str { + fn index(&self, row: &OdbcRow) -> Result { + // Try exact match first (for performance) + if let Some(pos) = row.columns.iter().position(|col| col.name == *self) { + return Ok(pos); + } + + // Fall back to case-insensitive match (for databases like Snowflake) + row.columns + .iter() + .position(|col| col.name.eq_ignore_ascii_case(self)) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + } +} + +mod private { + use super::OdbcRow; + use crate::row::private_row::Sealed; + impl Sealed for OdbcRow {} +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcColumn, OdbcTypeInfo}; + use crate::type_info::TypeInfo; + use odbc_api::DataType; + + fn create_test_row() -> OdbcRow { + use crate::odbc::OdbcValue; + + OdbcRow { + columns: vec![ + OdbcColumn { + name: "lowercase_col".to_string(), + type_info: OdbcTypeInfo::new(DataType::Integer), + ordinal: 0, + }, + OdbcColumn { + name: "UPPERCASE_COL".to_string(), + type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), + ordinal: 1, + }, + OdbcColumn { + name: "MixedCase_Col".to_string(), + type_info: OdbcTypeInfo::new(DataType::Double), + ordinal: 2, + }, + ], + values: vec![ + OdbcValue { + type_info: OdbcTypeInfo::new(DataType::Integer), + is_null: false, + text: None, + blob: None, + int: Some(42), + float: None, + }, + OdbcValue { + type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), + is_null: false, + text: Some("test".to_string()), + blob: None, + int: None, + float: None, + }, + OdbcValue { + type_info: OdbcTypeInfo::new(DataType::Double), + is_null: false, + text: None, + blob: None, + int: None, + float: Some(std::f64::consts::PI), + }, + ], + } + } + + #[test] + fn test_exact_column_match() { + let row = create_test_row(); + + // Exact matches should work + assert_eq!("lowercase_col".index(&row).unwrap(), 0); + assert_eq!("UPPERCASE_COL".index(&row).unwrap(), 1); + assert_eq!("MixedCase_Col".index(&row).unwrap(), 2); + } + + #[test] + fn test_case_insensitive_column_match() { + let row = create_test_row(); + + // Case-insensitive matches should work + assert_eq!("LOWERCASE_COL".index(&row).unwrap(), 0); + assert_eq!("lowercase_col".index(&row).unwrap(), 0); + assert_eq!("uppercase_col".index(&row).unwrap(), 1); + assert_eq!("UPPERCASE_COL".index(&row).unwrap(), 1); + assert_eq!("mixedcase_col".index(&row).unwrap(), 2); + assert_eq!("MIXEDCASE_COL".index(&row).unwrap(), 2); + assert_eq!("MixedCase_Col".index(&row).unwrap(), 2); + } + + #[test] + fn test_column_not_found() { + let row = create_test_row(); + + let result = "nonexistent_column".index(&row); + assert!(result.is_err()); + if let Err(Error::ColumnNotFound(name)) = result { + assert_eq!(name, "nonexistent_column"); + } else { + panic!("Expected ColumnNotFound error"); + } + } + + #[test] + fn test_try_get_raw() { + let row = create_test_row(); + + // Test accessing by exact name + let value = row.try_get_raw("lowercase_col").unwrap(); + assert!(!value.is_null); + assert_eq!(value.type_info.name(), "INTEGER"); + + // Test accessing by case-insensitive name + let value = row.try_get_raw("LOWERCASE_COL").unwrap(); + assert!(!value.is_null); + assert_eq!(value.type_info.name(), "INTEGER"); + + // Test accessing uppercase column with lowercase name + let value = row.try_get_raw("uppercase_col").unwrap(); + assert!(!value.is_null); + assert_eq!(value.type_info.name(), "VARCHAR"); + } + + #[test] + fn test_columns_method() { + let row = create_test_row(); + let columns = row.columns(); + + assert_eq!(columns.len(), 3); + assert_eq!(columns[0].name, "lowercase_col"); + assert_eq!(columns[1].name, "UPPERCASE_COL"); + assert_eq!(columns[2].name, "MixedCase_Col"); + } +} + +#[cfg(feature = "any")] +impl From for crate::any::AnyRow { + fn from(row: OdbcRow) -> Self { + let columns = row + .columns + .iter() + .map(|col| crate::any::AnyColumn { + kind: crate::any::column::AnyColumnKind::Odbc(col.clone()), + type_info: crate::any::AnyTypeInfo::from(col.type_info.clone()), + }) + .collect(); + + crate::any::AnyRow { + kind: crate::any::row::AnyRowKind::Odbc(row), + columns, + } + } +} diff --git a/sqlx-core/src/odbc/statement.rs b/sqlx-core/src/odbc/statement.rs new file mode 100644 index 0000000000..beeef9807a --- /dev/null +++ b/sqlx-core/src/odbc/statement.rs @@ -0,0 +1,76 @@ +use crate::column::ColumnIndex; +use crate::error::Error; +use crate::odbc::{Odbc, OdbcColumn, OdbcTypeInfo}; +use crate::statement::Statement; +use either::Either; +use std::borrow::Cow; + +#[derive(Debug, Clone)] +pub struct OdbcStatement<'q> { + pub(crate) sql: Cow<'q, str>, + pub(crate) columns: Vec, + pub(crate) parameters: usize, +} + +impl<'q> Statement<'q> for OdbcStatement<'q> { + type Database = Odbc; + + fn to_owned(&self) -> OdbcStatement<'static> { + OdbcStatement { + sql: Cow::Owned(self.sql.to_string()), + columns: self.columns.clone(), + parameters: self.parameters, + } + } + + fn sql(&self) -> &str { + &self.sql + } + fn parameters(&self) -> Option> { + Some(Either::Right(self.parameters)) + } + fn columns(&self) -> &[OdbcColumn] { + &self.columns + } + + // ODBC arguments placeholder + impl_statement_query!(crate::odbc::OdbcArguments); +} + +impl ColumnIndex> for &'_ str { + fn index(&self, statement: &OdbcStatement<'_>) -> Result { + statement + .columns + .iter() + .position(|c| c.name == *self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + } +} + +#[cfg(feature = "any")] +impl<'q> From> for crate::any::AnyStatement<'q> { + fn from(stmt: OdbcStatement<'q>) -> Self { + let mut column_names = crate::HashMap::::default(); + + // First build the columns and collect names + let columns: Vec<_> = stmt + .columns + .into_iter() + .enumerate() + .map(|(index, col)| { + column_names.insert(crate::ext::ustr::UStr::new(&col.name), index); + crate::any::AnyColumn { + kind: crate::any::column::AnyColumnKind::Odbc(col.clone()), + type_info: crate::any::AnyTypeInfo::from(col.type_info), + } + }) + .collect(); + + crate::any::AnyStatement { + sql: stmt.sql, + parameters: Some(either::Either::Right(stmt.parameters)), + columns, + column_names: std::sync::Arc::new(column_names), + } + } +} diff --git a/sqlx-core/src/odbc/transaction.rs b/sqlx-core/src/odbc/transaction.rs new file mode 100644 index 0000000000..2556c16784 --- /dev/null +++ b/sqlx-core/src/odbc/transaction.rs @@ -0,0 +1,32 @@ +use crate::error::Error; +use crate::odbc::Odbc; +use crate::transaction::TransactionManager; +use futures_core::future::BoxFuture; + +pub struct OdbcTransactionManager; + +impl TransactionManager for OdbcTransactionManager { + type Database = Odbc; + + fn begin( + conn: &mut ::Connection, + ) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { conn.worker.begin().await }) + } + + fn commit( + conn: &mut ::Connection, + ) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { conn.worker.commit().await }) + } + + fn rollback( + conn: &mut ::Connection, + ) -> BoxFuture<'_, Result<(), Error>> { + Box::pin(async move { conn.worker.rollback().await }) + } + + fn start_rollback(_conn: &mut ::Connection) { + // no-op best effort + } +} diff --git a/sqlx-core/src/odbc/type_info.rs b/sqlx-core/src/odbc/type_info.rs new file mode 100644 index 0000000000..93e81be24a --- /dev/null +++ b/sqlx-core/src/odbc/type_info.rs @@ -0,0 +1,192 @@ +use crate::type_info::TypeInfo; +use odbc_api::DataType; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +/// Type information for an ODBC type. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct OdbcTypeInfo { + #[cfg_attr(feature = "offline", serde(skip))] + pub(crate) data_type: DataType, +} + +impl OdbcTypeInfo { + /// Create a new OdbcTypeInfo with the given data type + pub const fn new(data_type: DataType) -> Self { + Self { data_type } + } + + /// Get the underlying data type + pub const fn data_type(&self) -> DataType { + self.data_type + } +} + +/// Extension trait for DataType with helper methods +pub trait DataTypeExt { + /// Get the display name for this data type + fn name(self) -> &'static str; + + /// Check if this is a character/string type + fn accepts_character_data(self) -> bool; + + /// Check if this is a binary type + fn accepts_binary_data(self) -> bool; + + /// Check if this is a numeric type + fn accepts_numeric_data(self) -> bool; + + /// Check if this is a date/time type + fn accepts_datetime_data(self) -> bool; +} + +impl DataTypeExt for DataType { + fn name(self) -> &'static str { + match self { + DataType::BigInt => "BIGINT", + DataType::Binary { .. } => "BINARY", + DataType::Bit => "BIT", + DataType::Char { .. } => "CHAR", + DataType::Date => "DATE", + DataType::Decimal { .. } => "DECIMAL", + DataType::Double => "DOUBLE", + DataType::Float { .. } => "FLOAT", + DataType::Integer => "INTEGER", + DataType::LongVarbinary { .. } => "LONGVARBINARY", + DataType::LongVarchar { .. } => "LONGVARCHAR", + DataType::Numeric { .. } => "NUMERIC", + DataType::Real => "REAL", + DataType::SmallInt => "SMALLINT", + DataType::Time { .. } => "TIME", + DataType::Timestamp { .. } => "TIMESTAMP", + DataType::TinyInt => "TINYINT", + DataType::Varbinary { .. } => "VARBINARY", + DataType::Varchar { .. } => "VARCHAR", + DataType::WChar { .. } => "WCHAR", + DataType::WLongVarchar { .. } => "WLONGVARCHAR", + DataType::WVarchar { .. } => "WVARCHAR", + DataType::Unknown => "UNKNOWN", + DataType::Other { .. } => "OTHER", + } + } + + fn accepts_character_data(self) -> bool { + matches!( + self, + DataType::Char { .. } + | DataType::Varchar { .. } + | DataType::LongVarchar { .. } + | DataType::WChar { .. } + | DataType::WVarchar { .. } + | DataType::WLongVarchar { .. } + ) + } + + fn accepts_binary_data(self) -> bool { + matches!( + self, + DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } + ) + } + + fn accepts_numeric_data(self) -> bool { + matches!( + self, + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Real + | DataType::Float { .. } + | DataType::Double + | DataType::Decimal { .. } + | DataType::Numeric { .. } + ) + } + + fn accepts_datetime_data(self) -> bool { + matches!( + self, + DataType::Date | DataType::Time { .. } | DataType::Timestamp { .. } + ) + } +} + +impl TypeInfo for OdbcTypeInfo { + fn is_null(&self) -> bool { + false + } + + fn name(&self) -> &str { + self.data_type.name() + } + + fn is_void(&self) -> bool { + false + } +} + +impl Display for OdbcTypeInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + f.write_str(self.name()) + } +} + +// Provide some common type constants +impl OdbcTypeInfo { + pub const BIGINT: Self = Self::new(DataType::BigInt); + pub const BIT: Self = Self::new(DataType::Bit); + pub const DATE: Self = Self::new(DataType::Date); + pub const DOUBLE: Self = Self::new(DataType::Double); + pub const INTEGER: Self = Self::new(DataType::Integer); + pub const REAL: Self = Self::new(DataType::Real); + pub const SMALLINT: Self = Self::new(DataType::SmallInt); + pub const TINYINT: Self = Self::new(DataType::TinyInt); + pub const UNKNOWN: Self = Self::new(DataType::Unknown); + pub const TIME: Self = Self::new(DataType::Time { precision: 0 }); + pub const TIMESTAMP: Self = Self::new(DataType::Timestamp { precision: 0 }); + + // For types with parameters, use constructor functions + pub const fn varchar(length: Option) -> Self { + Self::new(DataType::Varchar { length }) + } + + pub const fn varbinary(length: Option) -> Self { + Self::new(DataType::Varbinary { length }) + } + + pub const fn char(length: Option) -> Self { + Self::new(DataType::Char { length }) + } + + pub const fn binary(length: Option) -> Self { + Self::new(DataType::Binary { length }) + } + + pub const fn float(precision: usize) -> Self { + Self::new(DataType::Float { precision }) + } + + pub const fn decimal(precision: usize, scale: i16) -> Self { + Self::new(DataType::Decimal { precision, scale }) + } + + pub const fn numeric(precision: usize, scale: i16) -> Self { + Self::new(DataType::Numeric { precision, scale }) + } + + pub const fn time(precision: i16) -> Self { + Self::new(DataType::Time { precision }) + } + + pub const fn timestamp(precision: i16) -> Self { + Self::new(DataType::Timestamp { precision }) + } +} + +#[cfg(feature = "any")] +impl From for crate::any::AnyTypeInfo { + fn from(info: OdbcTypeInfo) -> Self { + crate::any::AnyTypeInfo(crate::any::type_info::AnyTypeInfoKind::Odbc(info)) + } +} diff --git a/sqlx-core/src/odbc/types/bigdecimal.rs b/sqlx-core/src/odbc/types/bigdecimal.rs new file mode 100644 index 0000000000..7b15f65e15 --- /dev/null +++ b/sqlx-core/src/odbc/types/bigdecimal.rs @@ -0,0 +1,54 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use bigdecimal::{BigDecimal, FromPrimitive}; +use odbc_api::DataType; +use std::str::FromStr; + +impl Type for BigDecimal { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::numeric(28, 4) // Standard precision/scale + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Numeric { .. } + | DataType::Decimal { .. } + | DataType::Double + | DataType::Float { .. } + ) || ty.data_type().accepts_character_data() + } +} + +impl<'q> Encode<'q, Odbc> for BigDecimal { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for BigDecimal { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(int) = value.int { + return Ok(BigDecimal::from(int)); + } + if let Some(float) = value.float { + return Ok(BigDecimal::from_f64(float).ok_or(format!("bad float: {}", float))?); + } + if let Some(text) = value.text { + return Ok(BigDecimal::from_str(text).map_err(|e| format!("bad decimal text: {}", e))?); + } + if let Some(bytes) = value.blob { + return Ok(BigDecimal::parse_bytes(bytes, 10) + .ok_or(format!("bad base10 bytes: {:?}", bytes))?); + } + Err(format!("ODBC: cannot decode BigDecimal: {:?}", value).into()) + } +} diff --git a/sqlx-core/src/odbc/types/bool.rs b/sqlx-core/src/odbc/types/bool.rs new file mode 100644 index 0000000000..e574df8f92 --- /dev/null +++ b/sqlx-core/src/odbc/types/bool.rs @@ -0,0 +1,338 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use odbc_api::DataType; + +impl Type for bool { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::BIT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Bit + | DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + | DataType::Real + | DataType::Float { .. } + | DataType::Double + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl<'q> Encode<'q, Odbc> for bool { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(if self { 1 } else { 0 })); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(if *self { 1 } else { 0 })); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for bool { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(i) = value.int { + return Ok(i != 0); + } + + // Handle float values (from DECIMAL/NUMERIC types) + if let Some(f) = value.float { + return Ok(f != 0.0); + } + + if let Some(text) = value.text { + let text = text.trim(); + // Try exact string matches first + return Ok(match text { + "0" | "0.0" | "false" | "FALSE" | "f" | "F" => false, + "1" | "1.0" | "true" | "TRUE" | "t" | "T" => true, + _ => { + // Try parsing as number first + if let Ok(num) = text.parse::() { + num != 0.0 + } else if let Ok(num) = text.parse::() { + num != 0 + } else { + // Fall back to string parsing + text.parse()? + } + } + }); + } + + if let Some(bytes) = value.blob { + let s = std::str::from_utf8(bytes)?; + let s = s.trim(); + return Ok(match s { + "0" | "0.0" | "false" | "FALSE" | "f" | "F" => false, + "1" | "1.0" | "true" | "TRUE" | "t" | "T" => true, + _ => { + // Try parsing as number first + if let Ok(num) = s.parse::() { + num != 0.0 + } else if let Ok(num) = s.parse::() { + num != 0 + } else { + // Fall back to string parsing + s.parse()? + } + } + }); + } + + Err("ODBC: cannot decode bool".into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::type_info::TypeInfo; + use odbc_api::DataType; + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: Some(value), + float: None, + } + } + + fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: None, + float: Some(value), + } + } + + #[test] + fn test_bool_type_compatibility() { + // Standard boolean types + assert!(>::compatible(&OdbcTypeInfo::BIT)); + assert!(>::compatible(&OdbcTypeInfo::TINYINT)); + + // DECIMAL and NUMERIC types (Snowflake compatibility) + assert!(>::compatible(&OdbcTypeInfo::decimal( + 1, 0 + ))); + assert!(>::compatible(&OdbcTypeInfo::numeric( + 1, 0 + ))); + + // Floating point types + assert!(>::compatible(&OdbcTypeInfo::DOUBLE)); + assert!(>::compatible(&OdbcTypeInfo::REAL)); + + // Character types + assert!(>::compatible(&OdbcTypeInfo::varchar( + None + ))); + + // Should not be compatible with binary types + assert!(!>::compatible(&OdbcTypeInfo::varbinary( + None + ))); + } + + #[test] + fn test_bool_decode_from_decimal_text() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "1", + DataType::Decimal { + precision: 1, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + let value = create_test_value_text( + "0", + DataType::Decimal { + precision: 1, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, false); + + // Test with decimal values + let value = create_test_value_text( + "1.0", + DataType::Decimal { + precision: 2, + scale: 1, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + let value = create_test_value_text( + "0.0", + DataType::Decimal { + precision: 2, + scale: 1, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, false); + + Ok(()) + } + + #[test] + fn test_bool_decode_from_float() -> Result<(), BoxDynError> { + let value = create_test_value_float(1.0, DataType::Double); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + let value = create_test_value_float(0.0, DataType::Double); + let decoded = >::decode(value)?; + assert_eq!(decoded, false); + + let value = create_test_value_float(42.5, DataType::Double); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + Ok(()) + } + + #[test] + fn test_bool_decode_from_int() -> Result<(), BoxDynError> { + let value = create_test_value_int(1, DataType::Integer); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + let value = create_test_value_int(0, DataType::Integer); + let decoded = >::decode(value)?; + assert_eq!(decoded, false); + + let value = create_test_value_int(-1, DataType::Integer); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + Ok(()) + } + + #[test] + fn test_bool_decode_string_variants() -> Result<(), BoxDynError> { + // Test various string representations + let test_cases = vec![ + ("true", true), + ("TRUE", true), + ("t", true), + ("T", true), + ("false", false), + ("FALSE", false), + ("f", false), + ("F", false), + ]; + + for (input, expected) in test_cases { + let value = create_test_value_text(input, DataType::Varchar { length: None }); + let decoded = >::decode(value)?; + assert_eq!(decoded, expected, "Failed for input: {}", input); + } + + Ok(()) + } + + #[test] + fn test_bool_decode_with_whitespace() -> Result<(), BoxDynError> { + let value = create_test_value_text( + " 1 ", + DataType::Decimal { + precision: 1, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, true); + + let value = create_test_value_text( + " 0 ", + DataType::Decimal { + precision: 1, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, false); + + Ok(()) + } + + #[test] + fn test_bool_encode() { + let mut buf = Vec::new(); + let result = >::encode(true, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Int(val) = &buf[0] { + assert_eq!(*val, 1); + } else { + panic!("Expected Int argument"); + } + + let mut buf = Vec::new(); + let result = >::encode(false, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Int(val) = &buf[0] { + assert_eq!(*val, 0); + } else { + panic!("Expected Int argument"); + } + } + + #[test] + fn test_bool_type_info() { + let type_info = >::type_info(); + assert_eq!(type_info.name(), "BIT"); + assert!(matches!(type_info.data_type(), DataType::Bit)); + } + + #[test] + fn test_bool_decode_error_handling() { + let value = OdbcValueRef { + type_info: OdbcTypeInfo::BIT, + is_null: false, + text: None, + blob: None, + int: None, + float: None, + }; + + let result = >::decode(value); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "ODBC: cannot decode bool"); + } +} diff --git a/sqlx-core/src/odbc/types/bytes.rs b/sqlx-core/src/odbc/types/bytes.rs new file mode 100644 index 0000000000..6ad56a7554 --- /dev/null +++ b/sqlx-core/src/odbc/types/bytes.rs @@ -0,0 +1,216 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; + +impl Type for Vec { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::varbinary(None) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_binary_data() || ty.data_type().accepts_character_data() + // Allow decoding from character types too + } +} + +impl<'q> Encode<'q, Odbc> for Vec { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Bytes(self)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Bytes(self.clone())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for &'q [u8] { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Bytes(self.to_vec())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Bytes(self.to_vec())); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for Vec { + fn decode(value: OdbcValueRef<'r>) -> Result { + Ok(<&[u8] as Decode<'r, Odbc>>::decode(value)?.to_vec()) + } +} + +impl<'r> Decode<'r, Odbc> for &'r [u8] { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(bytes) = value.blob { + return Ok(bytes); + } + if let Some(text) = value.text { + return Ok(text.as_bytes()); + } + Err(format!("ODBC: cannot decode {:?} as &[u8]", value).into()) + } +} + +impl Type for [u8] { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::varbinary(None) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_binary_data() || ty.data_type().accepts_character_data() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::type_info::TypeInfo; + use odbc_api::DataType; + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + fn create_test_value_blob(data: &'static [u8], data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: Some(data), + int: None, + float: None, + } + } + + #[test] + fn test_vec_u8_type_compatibility() { + // Should be compatible with binary types + assert!( as Type>::compatible( + &OdbcTypeInfo::varbinary(None) + )); + assert!( as Type>::compatible(&OdbcTypeInfo::binary( + None + ))); + + // Should be compatible with character types (for hex decoding) + assert!( as Type>::compatible(&OdbcTypeInfo::varchar( + None + ))); + assert!( as Type>::compatible(&OdbcTypeInfo::char( + None + ))); + + // Should not be compatible with numeric types + assert!(! as Type>::compatible(&OdbcTypeInfo::INTEGER)); + } + + #[test] + fn test_vec_u8_decode_from_blob() -> Result<(), BoxDynError> { + let test_data = b"Hello, ODBC!"; + let value = create_test_value_blob(test_data, DataType::Varbinary { length: None }); + let decoded = as Decode>::decode(value)?; + assert_eq!(decoded, test_data.to_vec()); + + Ok(()) + } + + #[test] + fn test_vec_u8_decode_from_raw_text() -> Result<(), BoxDynError> { + let text = "Hello, World!"; + let value = create_test_value_text(text, DataType::Varchar { length: None }); + let decoded = as Decode>::decode(value)?; + assert_eq!(decoded, text.as_bytes().to_vec()); + + Ok(()) + } + + #[test] + fn test_slice_u8_decode_from_blob() -> Result<(), BoxDynError> { + let test_data = b"Hello, ODBC!"; + let value = create_test_value_blob(test_data, DataType::Varbinary { length: None }); + let decoded = <&[u8] as Decode>::decode(value)?; + assert_eq!(decoded, test_data); + + Ok(()) + } + + #[test] + fn test_slice_u8_decode_from_text() -> Result<(), BoxDynError> { + let text = "Hello"; + let value = create_test_value_text(text, DataType::Varchar { length: None }); + let decoded = <&[u8] as Decode>::decode(value)?; + assert_eq!(decoded, text.as_bytes()); + + Ok(()) + } + + #[test] + fn test_vec_u8_encode() { + let mut buf = Vec::new(); + let data = vec![65, 66, 67, 68, 69]; // "ABCDE" + let result = as Encode>::encode(data, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Bytes(bytes) = &buf[0] { + assert_eq!(*bytes, vec![65, 66, 67, 68, 69]); + } else { + panic!("Expected Bytes argument"); + } + } + + #[test] + fn test_slice_u8_encode() { + let mut buf = Vec::new(); + let data: &[u8] = &[72, 101, 108, 108, 111]; // "Hello" + let result = <&[u8] as Encode>::encode(data, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Bytes(bytes) = &buf[0] { + assert_eq!(*bytes, vec![72, 101, 108, 108, 111]); + } else { + panic!("Expected Bytes argument"); + } + } + + #[test] + fn test_decode_error_handling() { + let value = OdbcValueRef { + type_info: OdbcTypeInfo::varbinary(None), + is_null: false, + text: None, + blob: None, + int: None, + float: None, + }; + assert!( as Decode<'_, Odbc>>::decode(value).is_err()); + } + + #[test] + fn test_type_info() { + let type_info = as Type>::type_info(); + assert_eq!(type_info.name(), "VARBINARY"); + assert!(matches!( + type_info.data_type(), + DataType::Varbinary { length: None } + )); + + let type_info = <[u8] as Type>::type_info(); + assert_eq!(type_info.name(), "VARBINARY"); + assert!(matches!( + type_info.data_type(), + DataType::Varbinary { length: None } + )); + } +} diff --git a/sqlx-core/src/odbc/types/chrono.rs b/sqlx-core/src/odbc/types/chrono.rs new file mode 100644 index 0000000000..178885dacd --- /dev/null +++ b/sqlx-core/src/odbc/types/chrono.rs @@ -0,0 +1,550 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::type_info::TypeInfo; +use crate::types::Type; +use chrono::{DateTime, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use odbc_api::DataType; + +impl Type for NaiveDate { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::DATE + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Date) + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl Type for NaiveTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TIME + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Time { .. }) + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl Type for NaiveDateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TIMESTAMP + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Timestamp { .. }) + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl Type for DateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TIMESTAMP + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Timestamp { .. }) + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl Type for DateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TIMESTAMP + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Timestamp { .. }) + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl Type for DateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TIMESTAMP + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!(ty.data_type(), DataType::Timestamp { .. }) + || ty.data_type().accepts_character_data() + || matches!(ty.data_type(), DataType::Other { .. } | DataType::Unknown) + } +} + +impl<'q> Encode<'q, Odbc> for NaiveDate { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.format("%Y-%m-%d").to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.format("%Y-%m-%d").to_string())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for NaiveTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.format("%H:%M:%S").to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.format("%H:%M:%S").to_string())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for NaiveDateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for DateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for DateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for DateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text( + self.format("%Y-%m-%d %H:%M:%S").to_string(), + )); + crate::encode::IsNull::No + } +} + +// Helper functions for date parsing +fn parse_yyyymmdd_as_naive_date(val: i64) -> Option { + if (19000101..=30001231).contains(&val) { + let year = (val / 10000) as i32; + let month = ((val % 10000) / 100) as u32; + let day = (val % 100) as u32; + NaiveDate::from_ymd_opt(year, month, day) + } else { + None + } +} + +fn parse_yyyymmdd_text_as_naive_date(s: &str) -> Option { + if s.len() == 8 && s.chars().all(|c| c.is_ascii_digit()) { + if let (Ok(y), Ok(m), Ok(d)) = ( + s[0..4].parse::(), + s[4..6].parse::(), + s[6..8].parse::(), + ) { + return NaiveDate::from_ymd_opt(y, m, d); + } + } + None +} + +fn get_text_from_value(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { + if let Some(text) = value.text { + let trimmed = text.trim_matches('\u{0}').trim(); + return Ok(Some(trimmed.to_string())); + } + if let Some(bytes) = value.blob { + let s = std::str::from_utf8(bytes)?; + let trimmed = s.trim_matches('\u{0}').trim(); + return Ok(Some(trimmed.to_string())); + } + Ok(None) +} + +impl<'r> Decode<'r, Odbc> for NaiveDate { + fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle text values first (most common for dates) + if let Some(text) = get_text_from_value(&value)? { + if let Some(date) = parse_yyyymmdd_text_as_naive_date(&text) { + return Ok(date); + } + if let Ok(date) = text.parse() { + return Ok(date); + } + } + + // Handle numeric YYYYMMDD format (for databases that return as numbers) + if let Some(int_val) = value.int { + if let Some(date) = parse_yyyymmdd_as_naive_date(int_val) { + return Ok(date); + } + return Err(format!( + "ODBC: cannot decode NaiveDate from integer '{}': not in YYYYMMDD range", + int_val + ) + .into()); + } + + // Handle float values similarly + if let Some(float_val) = value.float { + if let Some(date) = parse_yyyymmdd_as_naive_date(float_val as i64) { + return Ok(date); + } + return Err(format!( + "ODBC: cannot decode NaiveDate from float '{}': not in YYYYMMDD range", + float_val + ) + .into()); + } + + Err(format!( + "ODBC: cannot decode NaiveDate from value with type '{}'", + value.type_info.name() + ) + .into()) + } +} + +impl<'r> Decode<'r, Odbc> for NaiveTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + let mut s = >::decode(value)?; + if s.ends_with('\u{0}') { + s = s.trim_end_matches('\u{0}').to_string(); + } + let s_trimmed = s.trim(); + Ok(s_trimmed + .parse() + .map_err(|e| format!("ODBC: cannot decode NaiveTime from '{}': {}", s_trimmed, e))?) + } +} + +impl<'r> Decode<'r, Odbc> for NaiveDateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + let mut s = >::decode(value)?; + // Some ODBC drivers (e.g. PostgreSQL) may include trailing spaces or NULs + // in textual representations of timestamps. Trim them before parsing. + if s.ends_with('\u{0}') { + s = s.trim_end_matches('\u{0}').to_string(); + } + let s_trimmed = s.trim(); + // Try strict format first, then fall back to Chrono's FromStr + if let Ok(dt) = NaiveDateTime::parse_from_str(s_trimmed, "%Y-%m-%d %H:%M:%S") { + return Ok(dt); + } + Ok(s_trimmed.parse().map_err(|e| { + format!( + "ODBC: cannot decode NaiveDateTime from '{}': {}", + s_trimmed, e + ) + })?) + } +} + +impl<'r> Decode<'r, Odbc> for DateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + let mut s = >::decode(value)?; + if s.ends_with('\u{0}') { + s = s.trim_end_matches('\u{0}').to_string(); + } + let s_trimmed = s.trim(); + + // First try to parse as a UTC timestamp with timezone + if let Ok(dt) = s_trimmed.parse::>() { + return Ok(dt); + } + + // If that fails, try to parse as a naive datetime and convert to UTC + if let Ok(naive_dt) = NaiveDateTime::parse_from_str(s_trimmed, "%Y-%m-%d %H:%M:%S") { + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc)); + } + + // Finally, try chrono's default naive datetime parser + if let Ok(naive_dt) = s_trimmed.parse::() { + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc)); + } + + Err(format!("ODBC: cannot decode DateTime from '{}'", s_trimmed).into()) + } +} + +impl<'r> Decode<'r, Odbc> for DateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + let mut s = >::decode(value)?; + if s.ends_with('\u{0}') { + s = s.trim_end_matches('\u{0}').to_string(); + } + let s_trimmed = s.trim(); + + // First try to parse as a timestamp with timezone/offset + if let Ok(dt) = s_trimmed.parse::>() { + return Ok(dt); + } + + // If that fails, try to parse as a naive datetime and assume UTC (zero offset) + if let Ok(naive_dt) = NaiveDateTime::parse_from_str(s_trimmed, "%Y-%m-%d %H:%M:%S") { + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc).fixed_offset()); + } + + // Finally, try chrono's default naive datetime parser + if let Ok(naive_dt) = s_trimmed.parse::() { + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc).fixed_offset()); + } + + Err(format!( + "ODBC: cannot decode DateTime from '{}'", + s_trimmed + ) + .into()) + } +} + +impl<'r> Decode<'r, Odbc> for DateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + let mut s = >::decode(value)?; + if s.ends_with('\u{0}') { + s = s.trim_end_matches('\u{0}').to_string(); + } + let s_trimmed = s.trim(); + Ok(s_trimmed + .parse::>() + .map_err(|e| { + format!( + "ODBC: cannot decode DateTime from '{}' as DateTime: {}", + s_trimmed, e + ) + })? + .with_timezone(&Local)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::type_info::TypeInfo; + use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + use odbc_api::DataType; + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: Some(value), + float: None, + } + } + + #[test] + fn test_naive_date_type_compatibility() { + assert!(>::compatible(&OdbcTypeInfo::DATE)); + assert!(>::compatible( + &OdbcTypeInfo::varchar(None) + )); + assert!(>::compatible( + &OdbcTypeInfo::INTEGER + )); + } + + #[test] + fn test_parse_yyyymmdd_as_naive_date() { + // Valid dates + assert_eq!( + parse_yyyymmdd_as_naive_date(20200102), + Some(NaiveDate::from_ymd_opt(2020, 1, 2).unwrap()) + ); + assert_eq!( + parse_yyyymmdd_as_naive_date(19991231), + Some(NaiveDate::from_ymd_opt(1999, 12, 31).unwrap()) + ); + + // Invalid dates + assert_eq!(parse_yyyymmdd_as_naive_date(20201301), None); // Invalid month + assert_eq!(parse_yyyymmdd_as_naive_date(20200230), None); // Invalid day + assert_eq!(parse_yyyymmdd_as_naive_date(123456), None); // Too short + } + + #[test] + fn test_parse_yyyymmdd_text_as_naive_date() { + // Valid dates + assert_eq!( + parse_yyyymmdd_text_as_naive_date("20200102"), + Some(NaiveDate::from_ymd_opt(2020, 1, 2).unwrap()) + ); + assert_eq!( + parse_yyyymmdd_text_as_naive_date("19991231"), + Some(NaiveDate::from_ymd_opt(1999, 12, 31).unwrap()) + ); + + // Invalid formats + assert_eq!(parse_yyyymmdd_text_as_naive_date("2020-01-02"), None); // Dashes + assert_eq!(parse_yyyymmdd_text_as_naive_date("20201301"), None); // Invalid month + assert_eq!(parse_yyyymmdd_text_as_naive_date("abcd1234"), None); // Non-numeric + } + + #[test] + fn test_naive_date_decode_from_text() -> Result<(), BoxDynError> { + // Standard ISO format + let value = create_test_value_text("2020-01-02", DataType::Date); + let decoded = >::decode(value)?; + assert_eq!(decoded, NaiveDate::from_ymd_opt(2020, 1, 2).unwrap()); + + // YYYYMMDD format + let value = create_test_value_text("20200102", DataType::Date); + let decoded = >::decode(value)?; + assert_eq!(decoded, NaiveDate::from_ymd_opt(2020, 1, 2).unwrap()); + + Ok(()) + } + + #[test] + fn test_naive_date_decode_from_int() -> Result<(), BoxDynError> { + let value = create_test_value_int(20200102, DataType::Date); + let decoded = >::decode(value)?; + assert_eq!(decoded, NaiveDate::from_ymd_opt(2020, 1, 2).unwrap()); + + Ok(()) + } + + #[test] + fn test_naive_datetime_decode() -> Result<(), BoxDynError> { + let value = + create_test_value_text("2020-01-02 15:30:45", DataType::Timestamp { precision: 0 }); + let decoded = >::decode(value)?; + let expected = NaiveDate::from_ymd_opt(2020, 1, 2) + .unwrap() + .and_hms_opt(15, 30, 45) + .unwrap(); + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_datetime_utc_decode() -> Result<(), BoxDynError> { + let value = + create_test_value_text("2020-01-02 15:30:45", DataType::Timestamp { precision: 0 }); + let decoded = as Decode>::decode(value)?; + let expected_naive = NaiveDate::from_ymd_opt(2020, 1, 2) + .unwrap() + .and_hms_opt(15, 30, 45) + .unwrap(); + let expected = DateTime::::from_naive_utc_and_offset(expected_naive, Utc); + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_naive_time_decode() -> Result<(), BoxDynError> { + let value = create_test_value_text("15:30:45", DataType::Time { precision: 0 }); + let decoded = >::decode(value)?; + let expected = NaiveTime::from_hms_opt(15, 30, 45).unwrap(); + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_naive_date_encode() { + let mut buf = Vec::new(); + let date = NaiveDate::from_ymd_opt(2020, 1, 2).unwrap(); + let result = >::encode(date, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Text(text) = &buf[0] { + assert_eq!(text, "2020-01-02"); + } else { + panic!("Expected Text argument"); + } + } + + #[test] + fn test_get_text_from_value() -> Result<(), BoxDynError> { + // From text + let value = create_test_value_text(" test ", DataType::Varchar { length: None }); + assert_eq!(get_text_from_value(&value)?, Some("test".to_string())); + + // From empty + let value = OdbcValueRef { + type_info: OdbcTypeInfo::new(DataType::Date), + is_null: false, + text: None, + blob: None, + int: None, + float: None, + }; + assert_eq!(get_text_from_value(&value)?, None); + + Ok(()) + } + + #[test] + fn test_type_info() { + assert_eq!(>::type_info().name(), "DATE"); + assert_eq!(>::type_info().name(), "TIME"); + assert_eq!( + >::type_info().name(), + "TIMESTAMP" + ); + assert_eq!( + as Type>::type_info().name(), + "TIMESTAMP" + ); + } +} diff --git a/sqlx-core/src/odbc/types/decimal.rs b/sqlx-core/src/odbc/types/decimal.rs new file mode 100644 index 0000000000..ba796e9b9d --- /dev/null +++ b/sqlx-core/src/odbc/types/decimal.rs @@ -0,0 +1,274 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use odbc_api::DataType; +use rust_decimal::Decimal; +use std::str::FromStr; + +impl Type for Decimal { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::numeric(28, 4) // Standard precision/scale + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Numeric { .. } + | DataType::Decimal { .. } + | DataType::Double + | DataType::Float { .. } + ) || ty.data_type().accepts_character_data() + } +} + +impl<'q> Encode<'q, Odbc> for Decimal { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +// Helper function for getting text from value for decimal parsing +fn get_text_for_decimal_parsing(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { + if let Some(text) = value.text { + return Ok(Some(text.trim().to_string())); + } + if let Some(bytes) = value.blob { + let s = std::str::from_utf8(bytes)?; + return Ok(Some(s.trim().to_string())); + } + Ok(None) +} + +impl<'r> Decode<'r, Odbc> for Decimal { + fn decode(value: OdbcValueRef<'r>) -> Result { + // Try integer conversion first (most precise) + if let Some(int_val) = value.int { + return Ok(Decimal::from(int_val)); + } + + // Try direct float conversion for better precision + if let Some(float_val) = value.float { + if let Ok(decimal) = Decimal::try_from(float_val) { + return Ok(decimal); + } + } + + // Fall back to string parsing + if let Some(text) = get_text_for_decimal_parsing(&value)? { + return Ok(Decimal::from_str(&text)?); + } + + Err("ODBC: cannot decode Decimal".into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::type_info::TypeInfo; + use odbc_api::DataType; + use std::str::FromStr; + + fn create_test_value_text(text: &str, data_type: DataType) -> OdbcValueRef<'_> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: Some(value), + float: None, + } + } + + fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: None, + float: Some(value), + } + } + + #[test] + fn test_decimal_type_compatibility() { + // Should be compatible with decimal/numeric types + assert!(>::compatible(&OdbcTypeInfo::decimal( + 10, 2 + ))); + assert!(>::compatible(&OdbcTypeInfo::numeric( + 15, 4 + ))); + + // Should be compatible with floating point types + assert!(>::compatible(&OdbcTypeInfo::DOUBLE)); + assert!(>::compatible(&OdbcTypeInfo::float( + 24 + ))); + + // Should be compatible with character types + assert!(>::compatible(&OdbcTypeInfo::varchar( + None + ))); + + // Should not be compatible with binary types + assert!(!>::compatible( + &OdbcTypeInfo::varbinary(None) + )); + } + + #[test] + fn test_decimal_decode_from_text() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "123.456", + DataType::Decimal { + precision: 10, + scale: 3, + }, + ); + let decoded = >::decode(value)?; + let expected = Decimal::from_str("123.456")?; + assert_eq!(decoded, expected); + + // Test with whitespace + let value = create_test_value_text( + " 987.654 ", + DataType::Decimal { + precision: 10, + scale: 3, + }, + ); + let decoded = >::decode(value)?; + let expected = Decimal::from_str("987.654")?; + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_decimal_decode_from_int() -> Result<(), BoxDynError> { + let value = create_test_value_int( + 42, + DataType::Decimal { + precision: 10, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + let expected = Decimal::from(42); + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_decimal_decode_from_float() -> Result<(), BoxDynError> { + let value = create_test_value_float( + 123.456, + DataType::Decimal { + precision: 10, + scale: 3, + }, + ); + let decoded = >::decode(value)?; + + // Check that it's approximately correct (floating point precision issues) + let expected_str = "123.456"; + let expected = Decimal::from_str(expected_str)?; + let diff = (decoded - expected).abs(); + assert!(diff < Decimal::from_str("0.001")?); + + Ok(()) + } + + #[test] + fn test_decimal_decode_negative() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "-123.456", + DataType::Decimal { + precision: 10, + scale: 3, + }, + ); + let decoded = >::decode(value)?; + let expected = Decimal::from_str("-123.456")?; + assert_eq!(decoded, expected); + + Ok(()) + } + + #[test] + fn test_decimal_encode() { + let mut buf = Vec::new(); + let decimal = Decimal::from_str("123.456").unwrap(); + let result = >::encode(decimal, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Text(text) = &buf[0] { + assert_eq!(text, "123.456"); + } else { + panic!("Expected Text argument"); + } + } + + #[test] + fn test_decimal_encode_by_ref() { + let mut buf = Vec::new(); + let decimal = Decimal::from_str("987.654").unwrap(); + let result = >::encode_by_ref(&decimal, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Text(text) = &buf[0] { + assert_eq!(text, "987.654"); + } else { + panic!("Expected Text argument"); + } + } + + #[test] + fn test_decimal_type_info() { + let type_info = >::type_info(); + assert_eq!(type_info.name(), "NUMERIC"); + if let DataType::Numeric { precision, scale } = type_info.data_type() { + assert_eq!(precision, 28); + assert_eq!(scale, 4); + } else { + panic!("Expected Numeric data type"); + } + } + + #[test] + fn test_decimal_decode_error_handling() { + let value = OdbcValueRef { + type_info: OdbcTypeInfo::decimal(10, 2), + is_null: false, + text: None, + blob: None, + int: None, + float: None, + }; + + let result = >::decode(value); + assert!(result.is_err()); + } +} diff --git a/sqlx-core/src/odbc/types/float.rs b/sqlx-core/src/odbc/types/float.rs new file mode 100644 index 0000000000..09ed1fcb90 --- /dev/null +++ b/sqlx-core/src/odbc/types/float.rs @@ -0,0 +1,91 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use odbc_api::DataType; + +impl Type for f64 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::DOUBLE + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Double + | DataType::Float { .. } + | DataType::Real + | DataType::Numeric { .. } + | DataType::Decimal { .. } + | DataType::Integer + | DataType::BigInt + | DataType::SmallInt + | DataType::TinyInt + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for f32 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::float(24) // Standard float precision + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Float { .. } + | DataType::Real + | DataType::Double + | DataType::Numeric { .. } + | DataType::Decimal { .. } + | DataType::Integer + | DataType::BigInt + | DataType::SmallInt + | DataType::TinyInt + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl<'q> Encode<'q, Odbc> for f32 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Float(self as f64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Float(*self as f64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for f64 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Float(self)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Float(*self)); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for f64 { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(f) = value.float { + return Ok(f); + } + if let Some(int) = value.int { + return Ok(int as f64); + } + if let Some(s) = value.text { + return Ok(s.trim().parse()?); + } + Err(format!("ODBC: cannot decode f64: {:?}", value).into()) + } +} + +impl<'r> Decode<'r, Odbc> for f32 { + fn decode(value: OdbcValueRef<'r>) -> Result { + Ok(>::decode(value)? as f32) + } +} diff --git a/sqlx-core/src/odbc/types/int.rs b/sqlx-core/src/odbc/types/int.rs new file mode 100644 index 0000000000..485d963194 --- /dev/null +++ b/sqlx-core/src/odbc/types/int.rs @@ -0,0 +1,593 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use odbc_api::DataType; + +impl Type for i32 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::INTEGER + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Integer + | DataType::SmallInt + | DataType::TinyInt + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for i64 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::BIGINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::BigInt + | DataType::Integer + | DataType::SmallInt + | DataType::TinyInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for i16 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::SMALLINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::SmallInt + | DataType::TinyInt + | DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for i8 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TINYINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for u8 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::TINYINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::TinyInt + | DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for u16 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::SMALLINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::SmallInt + | DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for u32 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::INTEGER + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::Integer + | DataType::BigInt + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl Type for u64 { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::BIGINT + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + matches!( + ty.data_type(), + DataType::BigInt + | DataType::Integer + | DataType::Numeric { .. } + | DataType::Decimal { .. } + ) || ty.data_type().accepts_character_data() // Allow parsing from strings + } +} + +impl<'q> Encode<'q, Odbc> for i32 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for i64 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for i16 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for i8 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for u8 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for u16 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for u32 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(self as i64)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Int(*self as i64)); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for u64 { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + match i64::try_from(self) { + Ok(value) => { + buf.push(OdbcArgumentValue::Int(value)); + crate::encode::IsNull::No + } + Err(_) => { + log::warn!("u64 value {} too large for ODBC, encoding as NULL", self); + buf.push(OdbcArgumentValue::Null); + crate::encode::IsNull::Yes + } + } + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + match i64::try_from(*self) { + Ok(value) => { + buf.push(OdbcArgumentValue::Int(value)); + crate::encode::IsNull::No + } + Err(_) => { + log::warn!("u64 value {} too large for ODBC, encoding as NULL", self); + buf.push(OdbcArgumentValue::Null); + crate::encode::IsNull::Yes + } + } + } +} + +// Helper functions for numeric parsing +fn parse_numeric_as_i64(s: &str) -> Option { + let trimmed = s.trim(); + if let Ok(parsed) = trimmed.parse::() { + Some(parsed) + } else if let Ok(parsed) = trimmed.parse::() { + Some(parsed as i64) + } else { + None + } +} + +fn get_text_for_numeric_parsing(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { + if let Some(text) = value.text { + return Ok(Some(text.trim().to_string())); + } + if let Some(bytes) = value.blob { + let s = std::str::from_utf8(bytes)?; + return Ok(Some(s.trim().to_string())); + } + Ok(None) +} + +impl<'r> Decode<'r, Odbc> for i64 { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(i) = value.int { + return Ok(i); + } + if let Some(f) = value.float { + return Ok(f as i64); + } + if let Some(text) = get_text_for_numeric_parsing(&value)? { + if let Some(parsed) = parse_numeric_as_i64(&text) { + return Ok(parsed); + } + } + Err("ODBC: cannot decode i64".into()) + } +} + +impl<'r> Decode<'r, Odbc> for i32 { + fn decode(value: OdbcValueRef<'r>) -> Result { + Ok(>::decode(value)? as i32) + } +} + +impl<'r> Decode<'r, Odbc> for i16 { + fn decode(value: OdbcValueRef<'r>) -> Result { + Ok(>::decode(value)? as i16) + } +} + +impl<'r> Decode<'r, Odbc> for i8 { + fn decode(value: OdbcValueRef<'r>) -> Result { + Ok(>::decode(value)? as i8) + } +} + +impl<'r> Decode<'r, Odbc> for u8 { + fn decode(value: OdbcValueRef<'r>) -> Result { + let i = >::decode(value)?; + Ok(u8::try_from(i)?) + } +} + +impl<'r> Decode<'r, Odbc> for u16 { + fn decode(value: OdbcValueRef<'r>) -> Result { + let i = >::decode(value)?; + Ok(u16::try_from(i)?) + } +} + +impl<'r> Decode<'r, Odbc> for u32 { + fn decode(value: OdbcValueRef<'r>) -> Result { + let i = >::decode(value)?; + Ok(u32::try_from(i)?) + } +} + +impl<'r> Decode<'r, Odbc> for u64 { + fn decode(value: OdbcValueRef<'r>) -> Result { + let i = >::decode(value)?; + Ok(u64::try_from(i)?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use odbc_api::DataType; + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: Some(value), + float: None, + } + } + + fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: None, + blob: None, + int: None, + float: Some(value), + } + } + + #[test] + fn test_i32_type_compatibility() { + // Standard integer types + assert!(>::compatible(&OdbcTypeInfo::INTEGER)); + assert!(>::compatible(&OdbcTypeInfo::SMALLINT)); + assert!(>::compatible(&OdbcTypeInfo::TINYINT)); + assert!(>::compatible(&OdbcTypeInfo::BIGINT)); + + // DECIMAL and NUMERIC types (Snowflake compatibility) + assert!(>::compatible(&OdbcTypeInfo::decimal( + 10, 2 + ))); + assert!(>::compatible(&OdbcTypeInfo::numeric( + 15, 4 + ))); + + // Character types + assert!(>::compatible(&OdbcTypeInfo::varchar( + None + ))); + + // Should not be compatible with binary types + assert!(!>::compatible(&OdbcTypeInfo::varbinary( + None + ))); + } + + #[test] + fn test_i64_decode_from_text() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "42", + DataType::Decimal { + precision: 10, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + // Test with decimal value (should truncate) + let value = create_test_value_text( + "42.7", + DataType::Decimal { + precision: 10, + scale: 1, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + // Test with whitespace + let value = create_test_value_text( + " 123 ", + DataType::Decimal { + precision: 10, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, 123); + + Ok(()) + } + + #[test] + fn test_i64_decode_from_int() -> Result<(), BoxDynError> { + let value = create_test_value_int(42, DataType::Integer); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + Ok(()) + } + + #[test] + fn test_i64_decode_from_float() -> Result<(), BoxDynError> { + let value = create_test_value_float(42.7, DataType::Double); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + Ok(()) + } + + #[test] + fn test_i32_decode() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "42", + DataType::Decimal { + precision: 10, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + // Test negative + let value = create_test_value_text( + "-123", + DataType::Decimal { + precision: 10, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, -123); + + Ok(()) + } + + #[test] + fn test_u32_type_compatibility() { + // Should be compatible with DECIMAL/NUMERIC + assert!(>::compatible(&OdbcTypeInfo::decimal( + 10, 2 + ))); + assert!(>::compatible(&OdbcTypeInfo::numeric( + 15, 4 + ))); + + // Standard integer types + assert!(>::compatible(&OdbcTypeInfo::INTEGER)); + assert!(>::compatible(&OdbcTypeInfo::BIGINT)); + + // Character types + assert!(>::compatible(&OdbcTypeInfo::varchar( + None + ))); + } + + #[test] + fn test_u64_decode() -> Result<(), BoxDynError> { + let value = create_test_value_text( + "42", + DataType::Numeric { + precision: 20, + scale: 0, + }, + ); + let decoded = >::decode(value)?; + assert_eq!(decoded, 42); + + Ok(()) + } + + #[test] + fn test_decode_error_handling() { + let value = OdbcValueRef { + type_info: OdbcTypeInfo::INTEGER, + is_null: false, + text: None, + blob: None, + int: None, + float: None, + }; + + let result = >::decode(value); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "ODBC: cannot decode i64"); + } + + #[test] + fn test_encode_i32() { + let mut buf = Vec::new(); + let result = >::encode(42i32, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Int(val) = &buf[0] { + assert_eq!(*val, 42); + } else { + panic!("Expected Int argument"); + } + } + + #[test] + fn test_encode_u64_overflow() { + let mut buf = Vec::new(); + let large_val = u64::MAX; + let result = >::encode(large_val, &mut buf); + assert!(matches!(result, crate::encode::IsNull::Yes)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Null = &buf[0] { + // Expected + } else { + panic!("Expected Null argument for overflow"); + } + } + + #[test] + fn test_all_integer_types_support_decimal() { + let decimal_type = OdbcTypeInfo::decimal(10, 2); + let numeric_type = OdbcTypeInfo::numeric(15, 4); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + + assert!(>::compatible(&decimal_type)); + assert!(>::compatible(&numeric_type)); + } +} diff --git a/sqlx-core/src/odbc/types/json.rs b/sqlx-core/src/odbc/types/json.rs new file mode 100644 index 0000000000..dcd7db14c4 --- /dev/null +++ b/sqlx-core/src/odbc/types/json.rs @@ -0,0 +1,125 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use serde::de::Error; +use serde_json::Value; + +impl Type for Value { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::varchar(None) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() + || ty.data_type().accepts_binary_data() + || matches!( + ty.data_type(), + odbc_api::DataType::Other { .. } | odbc_api::DataType::Unknown + ) + } +} + +impl<'q> Encode<'q, Odbc> for Value { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for Value { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(bytes) = value.blob { + serde_json::from_slice(bytes) + } else if let Some(text) = value.text { + serde_json::from_str(text) + } else if let Some(i) = value.int { + Ok(serde_json::Value::from(i)) + } else if let Some(f) = value.float { + Ok(serde_json::Value::from(f)) + } else { + Err(serde_json::Error::custom("not a valid json type")) + } + .map_err(|e| format!("ODBC: cannot decode JSON from {:?}: {}", value, e).into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::type_info::TypeInfo; + use odbc_api::DataType; + use serde_json::{json, Value}; + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + OdbcValueRef { + type_info: OdbcTypeInfo::new(data_type), + is_null: false, + text: Some(text), + blob: None, + int: None, + float: None, + } + } + + #[test] + fn test_json_type_compatibility() { + // Should be compatible with character types + assert!(>::compatible(&OdbcTypeInfo::varchar( + None + ))); + assert!(>::compatible(&OdbcTypeInfo::char(None))); + } + + #[test] + fn test_json_decode_simple() -> Result<(), BoxDynError> { + let json_str = r#"{"name": "test"}"#; + let value = create_test_value_text(json_str, DataType::Varchar { length: None }); + let decoded = >::decode(value)?; + assert!(decoded.is_object()); + assert_eq!(decoded["name"], "test"); + + Ok(()) + } + + #[test] + fn test_json_decode_invalid() { + let invalid_json = r#"{"invalid": json,}"#; + let value = create_test_value_text(invalid_json, DataType::Varchar { length: None }); + let result = >::decode(value); + assert!(result.is_err(), "{:?} should be an error", result); + } + + #[test] + fn test_json_encode() { + let mut buf = Vec::new(); + let json_val = json!({"name": "test"}); + let result = >::encode(json_val, &mut buf); + assert!(matches!(result, crate::encode::IsNull::No)); + assert_eq!(buf.len(), 1); + if let OdbcArgumentValue::Text(text) = &buf[0] { + // Parse the encoded text back to verify it's valid JSON + let reparsed: Value = serde_json::from_str(text).unwrap(); + assert!(reparsed.is_object()); + } else { + panic!("Expected Text argument"); + } + } + + #[test] + fn test_json_type_info() { + let type_info = >::type_info(); + assert_eq!(type_info.name(), "VARCHAR"); + assert!(matches!( + type_info.data_type(), + DataType::Varchar { length: None } + )); + } +} diff --git a/sqlx-core/src/odbc/types/mod.rs b/sqlx-core/src/odbc/types/mod.rs new file mode 100644 index 0000000000..9708b7108f --- /dev/null +++ b/sqlx-core/src/odbc/types/mod.rs @@ -0,0 +1,23 @@ +pub mod bool; +pub mod bytes; +pub mod float; +pub mod int; +pub mod str; + +#[cfg(feature = "bigdecimal")] +pub mod bigdecimal; + +#[cfg(feature = "chrono")] +pub mod chrono; + +#[cfg(feature = "decimal")] +pub mod decimal; + +#[cfg(feature = "json")] +pub mod json; + +#[cfg(feature = "time")] +pub mod time; + +#[cfg(feature = "uuid")] +pub mod uuid; diff --git a/sqlx-core/src/odbc/types/str.rs b/sqlx-core/src/odbc/types/str.rs new file mode 100644 index 0000000000..32207efb6c --- /dev/null +++ b/sqlx-core/src/odbc/types/str.rs @@ -0,0 +1,71 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; + +impl Type for str { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::varchar(None) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_character_data() + } +} + +impl Type for String { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::varchar(None) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_character_data() + } +} + +impl<'q> Encode<'q, Odbc> for String { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self)); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.clone())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for &'q str { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_owned())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text((*self).to_owned())); + crate::encode::IsNull::No + } +} + +impl<'r> Decode<'r, Odbc> for String { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(text) = value.text { + return Ok(text.to_owned()); + } + if let Some(bytes) = value.blob { + return Ok(std::str::from_utf8(bytes)?.to_owned()); + } + Err("ODBC: cannot decode String".into()) + } +} + +impl<'r> Decode<'r, Odbc> for &'r str { + fn decode(value: OdbcValueRef<'r>) -> Result { + if let Some(text) = value.text { + return Ok(text); + } + if let Some(bytes) = value.blob { + return Ok(std::str::from_utf8(bytes)?); + } + Err("ODBC: cannot decode &str".into()) + } +} diff --git a/sqlx-core/src/odbc/types/time.rs b/sqlx-core/src/odbc/types/time.rs new file mode 100644 index 0000000000..3d0d0d0d44 --- /dev/null +++ b/sqlx-core/src/odbc/types/time.rs @@ -0,0 +1,456 @@ +use crate::decode::Decode; +use crate::encode::Encode; +use crate::error::BoxDynError; +use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; +use crate::types::Type; +use time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + +impl Type for OffsetDateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::timestamp(6) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_datetime_data() + || ty.data_type().accepts_character_data() + || ty.data_type().accepts_numeric_data() // For Unix timestamps + } +} + +impl Type for PrimitiveDateTime { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::timestamp(6) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_datetime_data() || ty.data_type().accepts_character_data() + } +} + +impl Type for Date { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::new(odbc_api::DataType::Date) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_datetime_data() || ty.data_type().accepts_character_data() + } +} + +impl Type for Time { + fn type_info() -> OdbcTypeInfo { + OdbcTypeInfo::time(6) + } + fn compatible(ty: &OdbcTypeInfo) -> bool { + ty.data_type().accepts_datetime_data() || ty.data_type().accepts_character_data() + } +} + +impl<'q> Encode<'q, Odbc> for OffsetDateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + let utc_dt = self.to_offset(time::UtcOffset::UTC); + let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); + buf.push(OdbcArgumentValue::Text(primitive_dt.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + let utc_dt = self.to_offset(time::UtcOffset::UTC); + let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); + buf.push(OdbcArgumentValue::Text(primitive_dt.to_string())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for PrimitiveDateTime { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for Date { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +impl<'q> Encode<'q, Odbc> for Time { + fn encode(self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } + + fn encode_by_ref(&self, buf: &mut Vec) -> crate::encode::IsNull { + buf.push(OdbcArgumentValue::Text(self.to_string())); + crate::encode::IsNull::No + } +} + +// Helper function for parsing datetime from Unix timestamp +fn parse_unix_timestamp_as_offset_datetime(timestamp: i64) -> Option { + OffsetDateTime::from_unix_timestamp(timestamp).ok() +} + +impl<'r> Decode<'r, Odbc> for OffsetDateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle numeric timestamps (Unix epoch seconds) first + if let Some(int_val) = value.int { + if let Some(dt) = parse_unix_timestamp_as_offset_datetime(int_val) { + return Ok(dt); + } + } + + if let Some(float_val) = value.float { + if let Some(dt) = parse_unix_timestamp_as_offset_datetime(float_val as i64) { + return Ok(dt); + } + } + + // Handle text values + if let Some(text) = value.text { + let trimmed = text.trim(); + // Try parsing as ISO-8601 timestamp with timezone + if let Ok(dt) = OffsetDateTime::parse( + trimmed, + &time::format_description::well_known::Iso8601::DEFAULT, + ) { + return Ok(dt); + } + // Try parsing as primitive datetime and assume UTC + if let Ok(dt) = PrimitiveDateTime::parse( + trimmed, + &time::format_description::well_known::Iso8601::DEFAULT, + ) { + return Ok(dt.assume_utc()); + } + // Try custom formats that ODBC might return + if let Ok(dt) = time::PrimitiveDateTime::parse( + trimmed, + &time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"), + ) { + return Ok(dt.assume_utc()); + } + } + + Err("ODBC: cannot decode OffsetDateTime".into()) + } +} + +impl<'r> Decode<'r, Odbc> for PrimitiveDateTime { + fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle numeric timestamps (Unix epoch seconds) first + if let Some(int_val) = value.int { + if let Some(offset_dt) = parse_unix_timestamp_as_offset_datetime(int_val) { + let utc_dt = offset_dt.to_offset(time::UtcOffset::UTC); + return Ok(PrimitiveDateTime::new(utc_dt.date(), utc_dt.time())); + } + } + + if let Some(float_val) = value.float { + if let Some(offset_dt) = parse_unix_timestamp_as_offset_datetime(float_val as i64) { + let utc_dt = offset_dt.to_offset(time::UtcOffset::UTC); + return Ok(PrimitiveDateTime::new(utc_dt.date(), utc_dt.time())); + } + } + + // Handle text values + if let Some(text) = value.text { + let trimmed = text.trim(); + // Try parsing as ISO-8601 + if let Ok(dt) = PrimitiveDateTime::parse( + trimmed, + &time::format_description::well_known::Iso8601::DEFAULT, + ) { + return Ok(dt); + } + // Try custom formats that ODBC might return + if let Ok(dt) = PrimitiveDateTime::parse( + trimmed, + &time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"), + ) { + return Ok(dt); + } + if let Ok(dt) = PrimitiveDateTime::parse( + trimmed, + &time::macros::format_description!( + "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]" + ), + ) { + return Ok(dt); + } + } + + Err("ODBC: cannot decode PrimitiveDateTime".into()) + } +} + +// Helper functions for time crate date parsing +fn parse_yyyymmdd_as_time_date(val: i64) -> Option { + if (19000101..=30001231).contains(&val) { + let year = (val / 10000) as i32; + let month = ((val % 10000) / 100) as u8; + let day = (val % 100) as u8; + + if let Ok(month_enum) = time::Month::try_from(month) { + Date::from_calendar_date(year, month_enum, day).ok() + } else { + None + } + } else { + None + } +} + +fn parse_yyyymmdd_text_as_time_date(s: &str) -> Option { + if s.len() == 8 && s.chars().all(|c| c.is_ascii_digit()) { + if let (Ok(y), Ok(m), Ok(d)) = ( + s[0..4].parse::(), + s[4..6].parse::(), + s[6..8].parse::(), + ) { + if let Ok(month_enum) = time::Month::try_from(m) { + return Date::from_calendar_date(y, month_enum, d).ok(); + } + } + } + None +} + +impl<'r> Decode<'r, Odbc> for Date { + fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle numeric YYYYMMDD format first + if let Some(int_val) = value.int { + if let Some(date) = parse_yyyymmdd_as_time_date(int_val) { + return Ok(date); + } + + // Fallback: try as days since Unix epoch + if let Ok(days) = i32::try_from(int_val) { + let epoch = Date::from_calendar_date(1970, time::Month::January, 1)?; + if let Some(date) = epoch.checked_add(time::Duration::days(days as i64)) { + return Ok(date); + } + } + } + + // Handle float values + if let Some(float_val) = value.float { + if let Some(date) = parse_yyyymmdd_as_time_date(float_val as i64) { + return Ok(date); + } + } + + // Handle text values + if let Some(text) = value.text { + let trimmed = text.trim(); + if let Some(date) = parse_yyyymmdd_text_as_time_date(trimmed) { + return Ok(date); + } + + if let Ok(date) = Date::parse( + trimmed, + &time::macros::format_description!("[year]-[month]-[day]"), + ) { + return Ok(date); + } + if let Ok(date) = Date::parse( + trimmed, + &time::format_description::well_known::Iso8601::DEFAULT, + ) { + return Ok(date); + } + } + + Err("ODBC: cannot decode Date".into()) + } +} + +// Helper function for time parsing from seconds since midnight +fn parse_seconds_as_time(seconds: i64) -> Option